#!/usr/bin/env python

from pyPheWAS.pyPhewasCorev2 import *
import os
import sys
import argparse
import time
import math
import matplotlib.pyplot as plt
from pathlib import Path
import os.path as osp

def parse_args():
    """Parse the pyPhewasPlot command-line options from ``sys.argv``.

    ``--statfile`` and ``--thresh_type`` are required; every other option
    falls back to its default. ``--imbalance`` and ``--plot_all_pts`` are
    returned as the raw strings given on the command line (default
    ``"True"``); no boolean conversion happens here.

    Returns:
        argparse.Namespace: one attribute per option, named after the flag
        without its leading dashes.
    """
    cli = argparse.ArgumentParser(description="pyPheWAS Plotting Tool")

    # (flag, add_argument keyword options), registered in --help order.
    option_specs = (
        ('--statfile', dict(required=True, type=str,
                            help='Name of the statistics/regressions file')),
        ('--thresh_type', dict(required=True, type=str,
                               help=' the type of threshold to be used in the plot')),
        ('--custom_thresh', dict(required=False, default=None, type=float,
                                 help='Custom threshold value (float between 0 and 1)')),
        ('--imbalance', dict(required=False, default="True",
                             help='Show the direction of imbalance in the Manhattan plot [True (default) or False]')),
        ('--plot_all_pts', dict(required=False, default="True",
                                help='Show all points regardless of significance in the Manhattan plot [True (default) or False]')),
        ('--phewas_label', dict(required=False, default="plot", type=str,
                                help='Location of PheCode labels on Log Odds plot [plot (default) or axis]')),
        ('--path', dict(required=False, default='.', type=str,
                        help='Path to all input files and destination of output files')),
        ('--outfile', dict(required=False, default=None, type=str,
                           help='Name of the output file for the plot')),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    return cli.parse_args()

"""
Print Start Message
"""
start = time.time()
print_start_msg()
print('\npyPhewasPlot: Plot Mass PheCode Regression Results\n')


"""
Retrieve and validate all arguments.
"""
args = parse_args()
kwargs = {'path': Path(args.path),
          'statfile': args.statfile,
          'thresh_type': args.thresh_type,
          'imbalance': args.imbalance,
          'plot_all_pts': args.plot_all_pts,
          'custom_thresh':args.custom_thresh,
          'phewas_label': args.phewas_label,
          'outfile':args.outfile,
}

# Assert that a valid threshold type was used
assert kwargs['thresh_type'] in threshold_map.keys(), "%s is not a valid regression type" % (kwargs['thresh_type'])
if kwargs['thresh_type'] == 'custom':
    assert kwargs['custom_thresh'] is not None, "Custom threshold specified. Please define --custom_thresh"
    assert (kwargs['custom_thresh'] < 1.0) & (kwargs['custom_thresh'] > 0.0), "%s is not a valid threshold (should be between 0.0 and 1.0)" % (kwargs['custom_thresh'])

# Assert that valid files were given
assert kwargs['statfile'].endswith('.csv'), "%s is not a valid phenotype file (must be a .csv file)" % (kwargs['feature_matrix'])
assert osp.exists(kwargs['path'] / kwargs['statfile']), "%s does not exist" %(kwargs['path'] / kwargs['statfile'])

assert kwargs['phewas_label'] in ["plot","axis"], "%s is not a valid PheCode label location" % (kwargs['phewas_label'])

assert kwargs['imbalance'] in ["True", "False"], "%s is not a valid imbalance value (\"True\" or \"False\")" % kwargs['imbalance']
kwargs['imbalance'] = eval(kwargs['imbalance'])
assert kwargs['plot_all_pts'] in ["True", "False"], "%s is not a valid plot_all_pts value (\"True\" or \"False\")" % kwargs['plot_all_pts']
kwargs['plot_all_pts'] = eval(kwargs['plot_all_pts'])

# Print Arguments
display_kwargs(kwargs)
# Make all arguments local variables
locals().update(kwargs)


"""
Load Data
"""

ff = open(path / statfile)
header = ff.readline().strip().split(',')
reg_args = {}
for i in range(0,len(header),2):
    reg_args[header[i]] = header[i+1]
print('\nRegression Info')
display_kwargs(reg_args)

# Read in the remaining data (the pandas DataFrame)
regressions = pd.read_csv(ff,dtype={'PheWAS Code':str})

try:
    # make confidence interval numberic instead of a string
    regressions[['lowlim', 'uplim']] = regressions['Conf-interval beta'].str.split(',', expand=True)
    regressions['uplim'] = regressions.uplim.str.replace(']', '')
    regressions['lowlim'] = regressions.lowlim.str.replace('[', '')
    regressions = regressions.astype(dtype={'uplim':float,'lowlim':float})
except Exception as e:
    print('Error reading regression file:')
    print(e)
    sys.exit()


"""
Create plots
"""

# Get the threshold
pvalues = regressions['p-val'].values

if thresh_type == 'bon':
    thresh = get_bon_thresh(pvalues,0.05)
elif thresh_type == 'fdr':
    thresh = get_fdr_thresh(pvalues,0.05)
elif thresh_type == 'custom':
    thresh = custom_thresh
print('%s threshold: %0.5f'%(thresh_type,thresh))

# figure out file names
if outfile is not None:
    file_name, file_format = osp.splitext(outfile)
    savem = path / (file_name + '_Manhattan' + file_format)
    saveb = path / (file_name + '_LogOdds' + file_format)
    savev = path / (file_name + '_Volcano' + file_format)
    file_format = file_format[1:] # remove '.' from from first index
else:
    savem = ''
    saveb = ''
    savev = ''
    file_format = ''

plot_manhattan(regressions, thresh, 'ICD', imbalance, plot_all_pts, savem, file_format)
plot_log_odds_ratio(regressions, thresh, 'ICD', saveb, file_format, phewas_label)
plot_volcano(regressions, 'ICD', savev, file_format)

if outfile is not None:
    print("Saving plots to %s" % (path))
else:
    plt.show()


"""
Calculate runtime
"""

interval = time.time() - start
hour = math.floor(interval/3600.0)
minute = math.floor((interval - hour*3600)/60)
second = math.floor(interval - hour*3600 - minute*60)

if hour > 0:
    time_str = '%dh:%dm:%ds' %(hour,minute,second)
elif minute > 0:
    time_str = '%dm:%ds' % (minute, second)
else:
    time_str = '%ds' % second

print('pyPhewasPlot Complete [Runtime: %s]' %time_str)