#! /usr/bin/env python

import argparse
import sys
import numpy as np
import numpy.lib.recfunctions
import scipy
import scipy.stats
import scipy.special
import RIFT.lalsimutils as lalsimutils
import lalsimulation as lalsim
import lalframe
import lal
import functools
import itertools

# Command-line interface.  Each option's help text describes what the script
# actually does with the value (the originals were all copies of the --fname
# text).
parser = argparse.ArgumentParser(
    description="Read rows of an ILE *.dat file into ChooseWaveformParams "
                "objects and write them out with "
                "lalsimutils.ChooseWaveformParams_array_to_xml."
)
# --fname is passed straight to np.loadtxt, which cannot work without it, so
# fail early with a clear usage message instead of a traceback from numpy.
parser.add_argument("--fname",required=True,help="filename of *.dat file [standard ILE output]")
parser.add_argument("--fname-output-samples",default="my_output",help="output filename passed to ChooseWaveformParams_array_to_xml (default: my_output)")
parser.add_argument("--fref",default=20,type=float,help="reference frequency assigned to P.fref for every sample; matters for quantities that depend on J (default: 20)")
parser.add_argument("--fmin",default=20,type=float,help="minimum frequency assigned to P.fmin for every sample (default: 20)")
opts=  parser.parse_args()


# Index of the column used as the sort key below (presumably lnL, going by the
# name -- confirm against the ILE output format).
col_lnL=9
# ndmin=2 keeps the array two-dimensional even when the file holds a single
# row; without it np.loadtxt returns a 1-D array and both the column slice
# below and the per-row indexing later in the script fail.
dat_orig = dat = np.loadtxt(opts.fname, ndmin=2)
# Row-sorted copy, ascending in column col_lnL.  `dat` itself stays in file order.
# See http://stackoverflow.com/questions/2828059/sorting-arrays-in-numpy-by-column
dat_orig = dat[dat[:,col_lnL].argsort()]

# Convert every row of `dat` into its own ChooseWaveformParams object and
# write the collection out.
#
# Row layout read here: columns 1-2 are the component masses (scaled by
# lal.MSUN_SI), columns 3-5 the spin components of body 1, columns 6-8 the
# spin components of body 2.
P_list = []
P_list_in = []  # not used in this script; kept so the module-level name still exists
for line in dat:
    # A fresh object per row is essential: appending one shared instance that
    # is mutated on every pass leaves the list holding N references to the
    # same object, i.e. N copies of the *last* row.
    P = lalsimutils.ChooseWaveformParams()
    P.fref = opts.fref  # IMPORTANT if you are using a quantity that depends on J
    P.fmin = opts.fmin
    P.m1 = line[1]*lal.MSUN_SI
    P.m2 = line[2]*lal.MSUN_SI
    P.s1x = line[3]
    P.s1y = line[4]
    P.s1z = line[5]
    P.s2x = line[6]
    P.s2y = line[7]
    P.s2z = line[8]

    # Optional extra columns, not currently enabled:
    # if opts.input_tides:
    #     P.lambda1 = line[9]
    #     P.lambda2 = line[10]
    # if opts.input_distance:
    #     P.dist = lal.PC_SI*1e6*line[9]  # Incompatible with tides, note!

    P_list.append(P)

# Every entry was given fref = opts.fref above, so pass that value directly
# rather than reading it back from whichever object the loop touched last.
lalsimutils.ChooseWaveformParams_array_to_xml(P_list,fname=opts.fname_output_samples,fref=opts.fref)

