#!/usr/bin/env python

from pyPheWAS.pyPhewasCorev2 import *
import time
import math
import argparse
import sys
from pathlib import Path
import os.path as osp

def parse_args():
    """Parse the command-line options of the pyProWAS pipeline script.

    --phenotype, --group and --reg_type are mandatory; every other option is
    optional and falls back to the default given in its specification below.
    Boolean-like options (--legacy, --imbalance, --plot_all_pts) are accepted
    as the strings "True"/"False"; no conversion is performed here.

    Returns:
        argparse.Namespace: one attribute per option, named after the flag
        without its leading dashes. The command line is taken from sys.argv.
    """
    cli = argparse.ArgumentParser(description="Run the Entire pyProWAS Pipeline (Lookup, Model, Plot)")

    # Mandatory inputs: (flag, help text). All of them are plain strings.
    mandatory = (
        ('--phenotype', 'Name of the phenotype file (e.g. cpt_data.csv)'),
        ('--group', 'Name of the group file (e.g. groups.csv)'),
        ('--reg_type', 'Type of regression that you would like to use (log, lin, or dur)'),
    )
    for flag, text in mandatory:
        cli.add_argument(flag, required=True, type=str, help=text)

    # Optional settings: (flag, argparse keyword arguments). The order is kept
    # so that the generated --help output lists options in the same sequence.
    optional = (
        # regression / file handling
        ('--path', dict(default='.', type=str, help='Path to all input files and destination of output files')),
        ('--postfix', dict(default=None, type=str, help='Descriptive postfix for output files (e.g. poster or ages50-60)')),
        ('--prowas_cov', dict(default=None, type=str, help='ProCodes to use as covariates in the regression')),
        ('--covariates', dict(default='', type=str, help='Variables to be used as covariates')),
        ('--response', dict(default='genotype', type=str, help='Variable to predict instead of genotype')),
        ('--reg_thresh', dict(default=5, type=int, help='Threshold of subjects presenting a ProCode required for running regression (default: 5)')),
        ('--legacy', dict(default="False", type=str, help='Use legacy solver scheme (default: False)')),
        # plotting
        ('--imbalance', dict(default="True", type=str, help='Whether or not to show the direction of imbalance in the plot')),
        ('--plot_all_pts', dict(default="True", help='Show all points regardless of significance in the Manhattan plot [True (default) or False]')),
        ('--prowas_label', dict(default="plot", type=str, help='Where to put ProCode labels - plot (default) or axis')),
        ('--thresh_type', dict(default=None, type=str, help='Type of threshold to be used in the plot (fdr, bon, or custom)')),
        ('--custom_thresh', dict(default=None, type=float, help='Custom threshold value (float between 0 and 1)')),
        ('--plot_format', dict(default="png", type=str, help='File extension for plot files (default: png)')),
    )
    for flag, spec in optional:
        cli.add_argument(flag, required=False, **spec)

    return cli.parse_args()

"""
Print Start Message
"""
# Take the wall-clock timestamp before any other work: the runtime reported at
# the end of the script is measured relative to this value.
start = time.time()
# print_start_msg is provided by the pyPhewasCorev2 star import; its output is
# not visible from this file (presumably a package banner - confirm there).
print_start_msg()
print('\npyProwasPipeline: Run the entire pyProWAS Pipeline (Lookup, Model, Plot)\n')


"""
Retrieve and validate all arguments.
"""
args = parse_args()
# Gather every option in a single dict. It is validated and normalised below
# and finally exported as module-level variables for the rest of the script.
kwargs = {'path': Path(args.path),
          'phenotype': args.phenotype,
          'group': args.group,
          'reg_type': args.reg_type,
          'postfix': args.postfix,
          'response': args.response,
          'reg_thresh':args.reg_thresh,
          'covariates': args.covariates,
          'prowas_cov': args.prowas_cov,
          'legacy':args.legacy,
          'thresh_type': args.thresh_type,
          'imbalance': args.imbalance,
          'plot_all_pts': args.plot_all_pts,
          'custom_thresh': args.custom_thresh,
          'prowas_label': args.prowas_label,
          'plot_format': args.plot_format
}

# NOTE: validation below relies on assert statements, which Python skips when
# run with -O. Nothing after an assert may therefore be unsafe on bad input.

# Assert that valid file names were given
assert kwargs['phenotype'].endswith('.csv'), "%s is not a valid phenotype file, must be a .csv file" % (kwargs['phenotype'])
assert kwargs['group'].endswith('.csv'), "%s is not a valid group file, must be a .csv file" % (kwargs['group'])
# Assert that valid files were given
assert osp.exists(kwargs['path'] / kwargs['phenotype']), "%s does not exist" %(kwargs['path'] / kwargs['phenotype'])
assert osp.exists(kwargs['path'] / kwargs['group']), "%s does not exist" %(kwargs['path'] / kwargs['group'])

# Assert that valid regression & threshold types were used
assert kwargs['reg_type'] in regression_map.keys(), "%s is not a valid regression type" % kwargs['reg_type']

if kwargs['thresh_type'] is None:
    # No threshold requested: plot with both FDR and Bonferroni thresholds
    kwargs['thresh_type'] = ['fdr','bon']
else:
    assert kwargs['thresh_type'] in threshold_map.keys(), "%s is not a valid threshold type" % (kwargs['thresh_type'])
    # Downstream code iterates over thresh_type, so always store it as a list
    kwargs['thresh_type'] = [kwargs['thresh_type']]
    if kwargs['thresh_type'][0] == 'custom':
        assert kwargs['custom_thresh'] is not None, "Custom threshold specified. Please define --custom_thresh"
        assert 0.0 < kwargs['custom_thresh'] < 1.0, "%s is not a valid threshold (should be between 0.0 and 1.0)" % (kwargs['custom_thresh'])

# Assign the output postfix if none was assigned
if kwargs['postfix'] is None:
    # Default: group file name without extension, prefixed by the covariates (if any)
    kwargs['postfix'] = osp.splitext(kwargs['group'])[0]
    # Compare by value: an identity test against a string literal is unreliable
    if kwargs['covariates'] != '':
        kwargs['postfix'] = kwargs['covariates'] + '_' + kwargs['postfix']
else:
    # make sure there's not an extension
    kwargs['postfix'] = osp.splitext(kwargs['postfix'])[0]

assert kwargs['prowas_label'] in ["plot","axis"], "%s is not a valid ProCode label location" % (kwargs['prowas_label'])

# Convert the "True"/"False" strings to booleans by comparison rather than
# eval(), so user-supplied text is never executed even if asserts are disabled.
assert kwargs['legacy'] in ["True", "False"], "%s is not a valid legacy value (True or False)" % kwargs['legacy']
kwargs['legacy'] = (kwargs['legacy'] == "True")
assert kwargs['imbalance'] in ["True", "False"], "%s is not a valid imbalance value (True or False)" % kwargs['imbalance']
kwargs['imbalance'] = (kwargs['imbalance'] == "True")
assert kwargs['plot_all_pts'] in ["True", "False"], "%s is not a valid plot_all_pts value (\"True\" or \"False\")" % kwargs['plot_all_pts']
kwargs['plot_all_pts'] = (kwargs['plot_all_pts'] == "True")

# Print Arguments
display_kwargs(kwargs)

# Make all arguments module-level variables (path, phenotype, group, ...).
# At module scope locals() and globals() are the same dict; globals() is the
# one that is documented as safe to modify.
globals().update(kwargs)


""" 
Load Data
"""

print("Retrieving phenotype data...")
phenotypes = get_cpt_codes(path, phenotype, regression_map[reg_type])

print("Retrieving group data...")
genotypes = get_group_file(path, group)

# The response variable has to be one of the group file's columns
assert response in genotypes.columns, "response %s is not a column in the group file" % response

# Covariates are given as a '+'-separated list (spaces ignored). Each one must
# be a group-file column, except MaxAgeAtCPT, which is treated as a reserved
# name and must NOT already exist as a column in the group file.
if covariates:
    for cov in covariates.replace(" ", "").split('+'):
        if cov == 'MaxAgeAtCPT':
            assert 'MaxAgeAtCPT' not in genotypes.columns, "covariate %s is a reserved covariate name; please remove or rename this column in the group file" % cov
        else:
            assert cov in genotypes.columns, "covariate %s is not a column in the group file" % cov

# If the group file lacks MaxAgeAtVisit, derive it per subject as the maximum
# AgeAtCPT found in the phenotype data and attach it to the group table.
if 'MaxAgeAtVisit' not in genotypes.columns:
    print('WARNING: MaxAgeAtVisit was not found in group file. Calculating MaxAgeAtVisit from phenotype data')
    max_age = phenotypes.groupby(['id'])['AgeAtCPT'].transform('max')
    phenotypes['MaxAgeAtVisit'] = max_age
    # one row per subject; the left merge keeps every row of the group table
    per_subject = phenotypes[['id','MaxAgeAtVisit']].drop_duplicates(subset='id')
    genotypes = pd.merge(genotypes, per_subject, on='id', how='left')


"""
pyProwasLookup
"""

print("Generating feature matrix...")
fm,columns = generate_feature_matrix(genotypes, phenotypes, regression_map[reg_type], 'CPT', prowas_cov)

fm_outfile = "feature_matrix_" + postfix + ".csv"
print("Saving feature matrices to %s" % (path /('*_' + fm_outfile)))
# Every matrix is written with the same comma-joined column header
h = ','.join(columns)

# (file-name prefix, index into fm) for each matrix that gets written
outputs = [('agg_measures_', 0), ('cpt_age_', 1)]
if prowas_cov is not None:
    # only save this if it actually means something
    outputs.append(('prowas_cov_', 2))

for position, (prefix, fm_index) in enumerate(outputs):
    if position:
        # progress marker between consecutive saves
        print("...")
    np.savetxt(path / (prefix + fm_outfile), fm[fm_index], delimiter=',', header=h)


""" 
pyProwasModel
"""

print("Running ProWAS regressions...")
if legacy:
    # regressions are run with regularization if only one target group has them; regressions are run without regularization otherwise
    regressions = run_phewas_legacy(fm, genotypes, 'CPT', covariates, response, reg_thresh)
else:
    # all regressions are run with regularization
    regressions = run_phewas(fm, genotypes, 'CPT', covariates, response, reg_thresh)

reg_outfile = "regressions_" + postfix + ".csv"
print("Saving regression data to %s" % (path / reg_outfile))

# The first line of the output file is a metadata header of key,value pairs
# describing how the regressions were produced; the regression table follows.
if prowas_cov is not None:
    # Report the ProCode covariates alongside the regular ones. Join only the
    # non-empty parts so an empty covariate list does not yield a leading '+'.
    covariates = '+'.join(part for part in (covariates, prowas_cov) if part)
h = ['group', group, 'feature_matrix', fm_outfile, 'reg_type', reg_type, 'code_type', 'CPT', 'covariates', covariates]
header = ','.join(h) + '\n'
# Context manager guarantees the file is closed even if writing fails
with open(path / reg_outfile, 'w') as f:
    f.write(header)
    regressions.to_csv(f, index=False)

""" 
pyProwasPlot
"""

try:
    # make confidence interval numeric instead of a string:
    # split 'Conf-interval beta' on the comma, strip the square brackets,
    # and cast the two halves to float
    regressions[['lowlim', 'uplim']] = regressions['Conf-interval beta'].str.split(',', expand=True)
    # regex=False: the brackets are literal characters, not regex syntax
    # (the default for this parameter differs between pandas versions)
    regressions['uplim'] = regressions.uplim.str.replace(']', '', regex=False)
    regressions['lowlim'] = regressions.lowlim.str.replace('[', '', regex=False)
    regressions = regressions.astype(dtype={'uplim':float,'lowlim':float})
except Exception as e:
    print('Error reading regression file:')
    print(e)
    # exit with a non-zero status so callers can detect the failure
    sys.exit(1)

# Get the threshold
pvalues = regressions['p-val'].values

# remove periods if they were given by user (e.g. ".png" -> "png");
# str.replace returns a new string, so the result must be assigned back
plot_format = plot_format.replace('.','')
print('Saving plots to %s' % path)

# One Manhattan plot and one log-odds plot per requested threshold type
for t in thresh_type:
    # Get the threshold type
    if t == 'bon':
        thresh = get_bon_thresh(pvalues,0.05)
    elif t == 'fdr':
        thresh = get_fdr_thresh(pvalues,0.05)
    elif t == 'custom':
        thresh = custom_thresh
    print('%s threshold: %0.5f' % (t, thresh))

    savem = path / ('Manhattan_%s_%s.%s' %(t, postfix, plot_format))
    saveb = path / ('LogOdds_%s_%s.%s' %(t, postfix, plot_format))

    plot_manhattan(regressions, thresh, 'CPT', imbalance, plot_all_pts, savem, plot_format)
    plot_log_odds_ratio(regressions, thresh, 'CPT', saveb, plot_format, prowas_label)

# The volcano plot takes no threshold argument, so it is drawn once
savev = path / ('Volcano_%s.%s' %(postfix, plot_format))
plot_volcano(regressions, 'CPT', savev, plot_format)


"""
Calculate runtime
"""

# Seconds elapsed since the start timestamp taken at the top of the script
interval = time.time() - start

# Peel off whole hours first, then whole minutes, then whole seconds
hour = math.floor(interval/3600.0)
leftover = interval - hour*3600
minute = math.floor(leftover/60)
second = math.floor(leftover - minute*60)

# Show only the units that are needed: hours and minutes are dropped while
# every larger unit is zero
if hour > 0:
    pieces = ['%dh' % hour, '%dm' % minute, '%ds' % second]
elif minute > 0:
    pieces = ['%dm' % minute, '%ds' % second]
else:
    pieces = ['%ds' % second]
time_str = ':'.join(pieces)

print('pyProwasPipeline Complete [Runtime: %s]' %time_str)
