#!/usr/bin/env python

from pyPheWAS.pyPhewasCorev2 import print_start_msg, display_kwargs
from pyPheWAS.NoveltyPheDAS import *
import pandas as pd
from pathlib import Path
import time
import argparse
import math
from ast import literal_eval
import sys
from tqdm import tqdm


def parse_args():
    """Build the NoveltyAnalysis command-line parser and parse ``sys.argv``.

    Every option is read as a plain string; in particular ``--null_int`` is
    returned unparsed (e.g. ``"[0.3, 1.1]"``).

    Returns:
        argparse.Namespace with attributes ``pm_dir``, ``statfile``, ``dx_pm``,
        ``null_int``, ``path`` (default ``'.'``) and ``postfix`` (default ``None``).
    """
    parser = argparse.ArgumentParser(description="Run pyPheWAS Novelty Analysis")

    # (flag, help text) for the options that must always be supplied
    required_options = (
        ('--pm_dir', 'Path to PheCode PubMed directory'),
        ('--statfile', 'Name of the pyPheWAS stat file (e.g. regressions.csv)'),
        ('--dx_pm', 'Name of the Disease Search PubMed file (e.g. dx_PubMED_results.csv)'),
        ('--null_int', 'Null interval for calculating the 2nd gen p-value (e.g. [0.3, 1.1])'),
    )
    for flag, help_text in required_options:
        parser.add_argument(flag, required=True, type=str, help=help_text)

    # (flag, default, help text) for the options that may be omitted
    optional_options = (
        ('--path', '.', 'Path to all input files and destination of output files'),
        ('--postfix', None, 'Descriptive postfix for output files (e.g. poster or ages50-60)'),
    )
    for flag, default, help_text in optional_options:
        parser.add_argument(flag, required=False, default=default, type=str, help=help_text)

    return parser.parse_args()


"""
Print Start Message
"""
# Record the start time first so the runtime reported at the end of the script
# covers the whole run.
start = time.time()
print_start_msg()
print('\nNoveltyAnalysis: pyPheWAS Novelty Finding Index Tool\n')

"""
Retrieve and validate all arguments.
"""
args = parse_args()
kwargs = {
    'path': Path(args.path),
    'pm_dir': Path(args.pm_dir),
    'statfile': args.statfile,
    'dx_pm': args.dx_pm,
    'null_int': args.null_int,
    'postfix': args.postfix,
}

# Validation uses explicit checks instead of `assert`: assert statements are
# removed when Python runs with -O, which would silently skip these checks.
# sys.exit(<message>) prints the message to stderr and exits with status 1.

# Check that valid file names were given
if not kwargs['statfile'].endswith('.csv'):
    sys.exit("%s is not a valid stat file, must be a .csv file" % (kwargs['statfile']))
if not kwargs['dx_pm'].endswith('.csv'):
    sys.exit("%s is not a valid Dx PubMed file, must be a .csv file" % (kwargs['dx_pm']))
# Check that the given files/directories exist
if not (kwargs['path'] / kwargs['statfile']).exists():
    sys.exit("%s does not exist" % (kwargs['path'] / kwargs['statfile']))
if not (kwargs['path'] / kwargs['dx_pm']).exists():
    sys.exit("%s does not exist" % (kwargs['path'] / kwargs['dx_pm']))
if not kwargs['pm_dir'].exists():
    sys.exit("%s does not exist" % kwargs['pm_dir'])

# Check null interval: the command line supplies it as a string such as
# "[0.3, 1.1]", which is converted to a Python object with literal_eval.
null_int_str = kwargs['null_int']
try:
    kwargs['null_int'] = literal_eval(null_int_str)
except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
    print('Error encountered while parsing the null interval: %s' % null_int_str)
    print(e)
    # exit with a non-zero status so callers can detect the failure
    sys.exit(1)
# The isinstance test comes first so a scalar such as "5" is reported cleanly
# instead of raising TypeError from len().
if not isinstance(kwargs['null_int'], (list, tuple)) or len(kwargs['null_int']) != 2:
    sys.exit('Provided null interval does not contain two items/boundaries: %s' % null_int_str)
if not all(isinstance(bound, (int, float)) for bound in kwargs['null_int']):
    sys.exit('Provided null interval boundaries are not numeric: %s' % null_int_str)

# Print Arguments
display_kwargs(kwargs)
# Make all arguments module-level variables. The names are bound explicitly
# rather than through locals().update(kwargs) so that every variable used by the
# rest of the script is visible to readers and static analysis tools.
path = kwargs['path']
pm_dir = kwargs['pm_dir']
statfile = kwargs['statfile']
dx_pm = kwargs['dx_pm']
null_int = kwargs['null_int']
postfix = kwargs['postfix']

"""
Load files
"""
# Disease search PubMed results
dx_pubmed = pd.read_csv(path / dx_pm)

# The first line of the stat file is not part of the CSV table: it is read on its
# own and kept verbatim in reg_hdr (it is written back at the top of the output
# file), then pandas parses the remainder of the same handle.
# The context manager guarantees the handle is closed even if parsing fails.
with open(path / statfile) as reg_f:
    reg_hdr = reg_f.readline()
    # PheWAS codes are read as strings rather than being parsed as numbers
    reg = pd.read_csv(reg_f, dtype={"PheWAS Code": str})

# split confidence interval into lower & upper bounds
reg[['beta_lowlim', 'beta_uplim']] = reg['Conf-interval beta'].str.split(',', expand=True)
# regex=False: '[' and ']' are regular-expression metacharacters and the default
# of the `regex` argument differs between pandas versions, so request a literal
# replacement explicitly.
reg['beta_uplim'] = reg['beta_uplim'].str.replace(']', '', regex=False)
reg['beta_lowlim'] = reg['beta_lowlim'].str.replace('[', '', regex=False)
reg = reg.astype(dtype={'beta_uplim': float, 'beta_lowlim': float})

# convert log odds ratio (beta & its confidence interval) to odds ratios
# NOTE(review): `np` is not imported directly in this file; presumably it is
# provided by the star import from pyPheWAS.NoveltyPheDAS - confirm.
reg['OddsRatio'] = np.exp(reg['beta'])
reg['OR_uplim'] = np.exp(reg['beta_uplim'])
reg['OR_lowlim'] = np.exp(reg['beta_lowlim'])

"""
Combine Mass PheCode & Dx PubMed results
"""
reg = get_joint_PubMed_articles(reg, dx_pubmed, pm_dir)

"""
Run Novelty Calculations
"""
print('Calculating Novelty Finding Index')
reg = calcNoveltyScore(reg, null_int)

"""
Save Regression File w/ Novelty Data
"""
if postfix is not None:
    # Split on the LAST dot only, so a name with several dots keeps its full stem
    # and its real extension (e.g. reg.v2.csv -> reg.v2_<postfix>.csv).
    # statfile always contains a dot here: it was validated to end with '.csv'.
    stem, extension = statfile.rsplit('.', 1)
    outfile = '%s_%s.%s' % (stem, postfix, extension)
else:
    # NOTE: without a postfix the output has the same name as the input stat
    # file, so the input file is overwritten.
    outfile = statfile

print('Saving updated regression file to %s' % (path/outfile))
# newline='' is what pandas asks for when to_csv is given a text-mode handle; it
# keeps the line terminators written by pandas from being translated again.
with open(path / outfile, 'w', newline='') as f:
    # first line of the original stat file, preserved verbatim
    f.write(reg_hdr)
    reg.to_csv(f, index=False)

"""
Plot Novelty Finding Index Results
"""
print('Generating Novelty Finding Index Plots')
# Only regressions whose second gen p-val is exactly 0 (significant) are plotted,
# ordered from the highest Novelty Finding Index to the lowest.
is_significant = reg['sgpv'] == 0
reg_to_plot = reg[is_significant].sort_values(by=['Novelty_Finding_Index'], ascending=False)

# Common prefix of every plot file name
basename = 'NFI' if postfix is None else 'NFI_%s' % postfix

# Fewer than 30 results fit on one plot; otherwise the results are split into
# consecutive groups of 25 to keep the plots legible.
n_results = reg_to_plot.shape[0]
if n_results >= 30:  # lots of results = lots of plots
    group_ids = np.arange(n_results) // 25
    for plot_ix, plot_group in tqdm(reg_to_plot.groupby(group_ids)):
        plotfile = '%s_%d.png' % (basename, plot_ix)
        # NOTE(review): the output path is passed positionally here but as save=
        # in the single-plot branch; presumably the third parameter is `save` - confirm.
        plot_log_odds_ratio_novelty(plot_group, np.log(null_int), path / plotfile)
else:  # just one plot
    plotfile = basename + '.png'
    plot_log_odds_ratio_novelty(reg_to_plot, np.log(null_int), save=path / plotfile)

"""
Calculate runtime
"""
# Elapsed wall-clock time since `start`, broken down into whole hours, minutes
# and seconds (fractions of a second are discarded).
interval = time.time() - start
hour, leftover_seconds = divmod(math.floor(interval), 3600)
minute, second = divmod(leftover_seconds, 60)

# Leading units that are zero are left out of the report
if hour > 0:
    time_str = '%dh:%dm:%ds' % (hour, minute, second)
elif minute > 0:
    time_str = '%dm:%ds' % (minute, second)
else:
    time_str = '%ds' % second

print('NoveltyAnalysis Complete [Runtime: %s]' % time_str)