#!/usr/bin/env python
from pyPheWAS.pyPhewasCorev2 import *
import sys, os, math
import pandas as pd
import argparse
from pathlib import Path
import os.path as osp
import time


def parse_args(argv=None):
    """Parse the command-line options of the phenotype assignment tool.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the string attributes ``phenotype``,
        ``group``, ``groupout``, ``case_codes``, ``ctrl_codes``,
        ``code_freq``, ``response_var`` and ``path``.

    Raises:
        SystemExit: raised by argparse when a required option is missing
            or an option is malformed.
    """
    parser = argparse.ArgumentParser(description="pyPheWAS Phenotype Assignment Tool")

    parser.add_argument('--phenotype', required=True, type=str, help='Name of the input phenotype file')
    parser.add_argument('--group', required=False, default='', type=str, help='Name of the group file to add target variable map to')
    parser.add_argument('--groupout', required=True, type=str, help='Name of the output group file')
    parser.add_argument('--case_codes', required=True, type=str, help='Case ICD codes (filename or comma-separated list)')
    parser.add_argument('--ctrl_codes', required=False, default='', type=str, help='Control ICD codes (filename or comma-separated list)')
    parser.add_argument('--code_freq', required=True, type=str, help='Minimum frequency of codes (If 2 comma-separated values are given and ctrl_codes is given, 2nd argument is applied to controls)')
    parser.add_argument('--response_var', required=False, type=str, default='target', help='Name of response variable (e.g. target, genotype, response [default=target])')
    parser.add_argument('--path', required=False, default='.', type=str, help='Path to all input files and destination of output files (default = current directory)')

    return parser.parse_args(argv)


"""
Print Start Message
"""
start = time.time()
print_start_msg()
print('\ncreatePhenotypeFile: Response Variable Assignment Tool\n')


"""
Retrieve and validate all arguments.
"""
args = parse_args()

kwargs = {
	'phenotype':args.phenotype,
	'group':args.group,
	'groupout':args.groupout,
	'path':Path(args.path),
	'case_codes':args.case_codes,
	'ctrl_codes': args.ctrl_codes,
	'code_freq':args.code_freq,
    'response_var':args.response_var,
}

# Change path to absolute path
# kwargs['path'] = os.path.join(os.path.abspath(kwargs['path']),'')

# Assert that files are valid
assert kwargs['phenotype'].endswith('.csv'), "%s is not a valid phenotype file, must be a .csv file" % (kwargs['phenotype'])
assert kwargs['groupout'].endswith('.csv'), "%s is not a valid output file, must be a .csv file" % (kwargs['groupout'])
if len(kwargs['group']) > 0:
	assert kwargs['group'].endswith('.csv'), "%s is not a valid output file, must be a .csv file" % (kwargs['group'])

# Print Arguments
display_kwargs(kwargs)
# Make all arguments local variables
locals().update(kwargs)

# Fill paths
phenotype = path / phenotype
groupout = path / groupout
if len(str(group)) > 0:
	group = path / group

# Assert that all files exist
assert osp.exists(phenotype), "%s does not exist" % phenotype
if len(str(group)) > 0:
	assert osp.exists(group), "%s does not exist" % group
if case_codes.endswith('.csv') | case_codes.endswith('.txt'):
	assert osp.exists(path/case_codes), "%s does not exist" % (path / case_codes)
if ctrl_codes.endswith('.csv') | ctrl_codes.endswith('.txt'):
	assert osp.exists(path/ctrl_codes), "%s does not exist" % (path / ctrl_codes)

# Read group file
if len(str(group)) > 0:
	group_data = pd.read_csv(group)

# Make code frequency an integer
code_freq = code_freq.replace(" ","").split(',')
for i in range(len(code_freq)):
	code_freq[i] = int(code_freq[i])

"""
Parse codes
"""
# Case
if case_codes.endswith('.csv') | case_codes.endswith('.txt'):
	print('Reading case group codes from file')
	with open(path/case_codes,'r') as code_f:
		case_codes = code_f.readlines()[0]
# remove white space and split into an array
case_codes = case_codes.replace(" ","").replace("\n","").split(',')


# Controls
if len(ctrl_codes) > 0:
	if ctrl_codes.endswith('.csv') | ctrl_codes.endswith('.txt'):
		print('Reading control group codes from file')
		with open(path/ctrl_codes,'r') as code_f:
			ctrl_codes = code_f.readlines()[0]
	# remove white space and split into an array
	ctrl_codes = ctrl_codes.replace(" ", "").replace("\n", "").split(',')


"""
Find codes & make groups
"""
phen = pd.read_csv(phenotype)
phen[response_var] = -1

# Cases
print('Finding cases with codes: %s' % '|'.join(case_codes))
# append \Z to force regex to find exact match
for ix in range(len(case_codes)):
	case_codes[ix] = case_codes[ix] + '\Z'
phen['gen'] = phen['ICD_CODE'].str.match('|'.join(case_codes)) # find all ICD code matches
phen['gen'] = phen['gen']*1 # convert to integer
phen['genc'] = phen.groupby('id')['gen'].transform('sum') # count all instances
case_mask = phen['genc']>=code_freq[0]
rm_mask = (phen['genc'] < code_freq[0]) & (phen['genc'] > 0) # need to remove these later
phen.loc[case_mask,response_var] = 1

# Controls
if len(ctrl_codes) > 0:
	print('Finding controls with codes: %s' % '|'.join(ctrl_codes))
	# append \Z to force regex to find exact match
	for ix in range(len(ctrl_codes)):
		ctrl_codes[ix] = ctrl_codes[ix] + '\Z'
	phen['gen_ctrl'] = phen['ICD_CODE'].str.match('|'.join(ctrl_codes))
	phen['gen_ctrl'] = phen['gen_ctrl']*1
	phen['genc_ctrl'] = phen.groupby('id')['gen_ctrl'].transform('sum')
	if len(code_freq) > 1:
		cf = code_freq[1]
	else:
		cf = code_freq[0]
	ctrl_mask = (phen['genc_ctrl']>=cf) & ~ case_mask
	phen.loc[ctrl_mask, response_var] = 0
	# drop other subjects
	sub_mask = (case_mask | ctrl_mask) & ~rm_mask
	phen = phen[sub_mask]
else:
	phen.loc[~case_mask,response_var] = 0
    # drop subjects
    phen = phen[~rm_mask]

phen = phen[['id',response_var]].drop_duplicates()

"""
Save Output
"""
if len(str(group)) > 0:
	print('Merging response variable assignment with provided group file')
	phen = pd.merge(phen, group_data, how='inner',on='id', suffixes=('','_old'))

num_case = phen[phen[response_var]==1].shape[0]
num_ctrl = phen[phen[response_var]==0].shape[0]
print('Cases: %d\nControls: %d' %(num_case, num_ctrl))

print('Saving response variable mapping to %s' % groupout)
phen.to_csv(groupout,index=False)

"""
Calculate runtime
"""
interval = time.time() - start
hour = math.floor(interval/3600.0)
minute = math.floor((interval - hour*3600)/60)
second = math.floor(interval - hour*3600 - minute*60)

if hour > 0:
    time_str = '%dh:%dm:%ds' %(hour,minute,second)
elif minute > 0:
    time_str = '%dm:%ds' % (minute, second)
else:
    time_str = '%ds' % second

print('createPhenotypeFile Complete [Runtime: %s]' %time_str)
