#!/usr/bin/env python

import os
import flask
from flask import Flask
from flask_cors import CORS
import numpy as np
import pandas as pd
from pyPheWAS.pyPhewasExplorerCore import *
import scipy.stats
import argparse
from pathlib import Path


# for dev mode, run before calling this script:
# export FLASK_ENV=development

# Regression type codes; get_signals compares the client's 'rtype' value
# against these to pick the binary, count or duration feature matrix.
bin_reg, count_reg, dur_reg = range(3)

"""
Get arguments
"""

# Command-line interface: one positional input directory plus an optional
# name for the response column (defaults to 'target').
parser = argparse.ArgumentParser(description='pyPheWAS Explorer server')
parser.add_argument(
	'indir',
	type=str,
	help='Input directory for pyPheWAS analysis',
)
parser.add_argument(
	'--response',
	type=str,
	required=False,
	default='target',
	help='Name of column to use as the response (dependent) variable',
)
args = parser.parse_args()

response = args.response
data_path = Path(args.indir)

# Input files, all located inside the analysis directory
group_f = data_path / "group.csv"
icd_f = data_path / "icds.csv"

# Feature matrix files saved by a previous run (if any)
bin_fm_f = data_path / "binary_feature_matrix.csv"
cnt_fm_f = data_path / "count_feature_matrix.csv"
dur_fm_f = data_path / "duration_feature_matrix.csv"

# The saved matrices are only used when all three files are present
FMs_exist = all(fm_file.exists() for fm_file in (bin_fm_f, cnt_fm_f, dur_fm_f))

"""
Set everything up
"""
# load group data
group_data = get_group_file(group_f)
gvars = group_data.columns

# Validate the group file with explicit raises: `assert` statements are
# stripped when Python runs with -O, which would silently skip these checks.
if 'id' not in gvars:
	raise ValueError("Group file (%s) does not contain a subject identifier column ('id')" % group_f)
if response not in gvars:
	raise ValueError("Group file (%s) does not contain the target variable (%s)" % (group_f, response))

# Every column other than the subject id and the response is a group variable
gvars = gvars.drop(['id', response])
if len(gvars) > 10:
	print('WARNING: more than 10 group variables found. Limiting display to first 10 group variables')
	gvars = gvars[0:10]
	print(', '.join(gvars))

if FMs_exist:
	# Reuse the matrices saved by a previous run
	print('Loading binary feature matrix')
	fm_bin = np.loadtxt(bin_fm_f, delimiter=',')
	print('Loading count feature matrix')
	fm_cnt = np.loadtxt(cnt_fm_f, delimiter=',')
	print('Loading duration feature matrix')
	fm_dur = np.loadtxt(dur_fm_f, delimiter=',')
	# Saved matrices must have one row per subject in the group file;
	# otherwise they were built from different data and cannot be used.
	n_subjects = group_data.shape[0]
	if any(fm.shape[0] != n_subjects for fm in (fm_bin, fm_cnt, fm_dur)):
		raise ValueError("Feature matrices and group data do not contain the same number of subjects. Please delete the feature matrices and restart the Explorer")
else:
	print('Building Feature Matrices')
	pheno = get_icd_codes(icd_f)
	fm_bin, fm_cnt, fm_dur, columns = generate_feature_matrix(group_data, pheno)
	# save feature matrices so later runs can load them instead of rebuilding
	h = ','.join(columns)
	np.savetxt(bin_fm_f, fm_bin, delimiter=',', header=h)
	np.savetxt(cnt_fm_f, fm_cnt, delimiter=',', header=h)
	np.savetxt(dur_fm_f, fm_dur, delimiter=',', header=h)

print("pyPheWAS Explorer Ready")
print("Please open http://localhost:8000/ in a web brower (preferably Google Chrome)")
print("---\n\n\n")



"""
create Flask app
"""
# Application object that the route decorator below registers handlers on
app = flask.Flask(__name__)
# Wrap the app with flask_cors so cross-origin requests are accepted
CORS(app)

def _msg_json(text):
	"""Serialize a one-row status message as a JSON records string: [{"msg": text}]."""
	return pd.DataFrame({'msg': [text]}).to_json(orient='records')


def _pair_status(var1, var2):
	"""Return the status keyword for an unusable variable pair, or None if the pair is usable."""
	if var1 == '':
		return "no_data"
	if var1 == var2:
		return "select_2nd_var"
	return None


def _group_summary_json():
	"""Build the per-group-variable summary (correlation, extent, per-class values) as JSON records."""
	mask0 = group_data[response] == 0
	mask1 = group_data[response] == 1
	x = group_data[response].values
	records = []
	for gv in gvars:
		col = group_data[gv]
		records.append({
			'var': gv,
			# Spearman correlation between the response and this variable
			'corr': scipy.stats.spearmanr(x, col.values).correlation,
			'extent': np.array([col.min(), col.max()]),
			# raw values split by response == 0 / response == 1
			'g0': col[mask0].values,
			'g1': col[mask1].values,
		})
	# Building from records avoids assigning arrays cell-by-cell through .loc
	var_df = pd.DataFrame(records, columns=['var', 'corr', 'extent', 'g0', 'g1'])
	return var_df.to_json(orient='records')


def _histograms_json():
	"""Stack the 1D histogram of every group variable and return it as JSON records."""
	frames = []
	for g in gvars:
		print(g)
		frames.append(get_1D_histogram(group_data, g, response))
	# DataFrame.append was removed in pandas 2.0; concat with ignore_index also
	# performs the index reset whose result the old code discarded.
	h = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
	return h.to_json(orient='records')


def _run_regression_json(reg_type, covariates):
	"""Run the requested regression, save it to a CSV in data_path, and return it as JSON records.

	reg_type is one of bin_reg / count_reg / dur_reg; covariates is the list
	of covariate names sent by the client. An unrecognized reg_type yields an
	error string instead of raising.
	"""
	reg_specs = {
		bin_reg: (fm_bin, 'binary', 'reg_type,binary(log)'),
		count_reg: (fm_cnt, 'count', 'reg_type,count(lin)'),
		dur_reg: (fm_dur, 'duration', 'reg_type,duration(dur)'),
	}
	if reg_type not in reg_specs:
		return "ERROR executing command: RUN_REG\nUnknown regression type %s" % reg_type
	feature_matrix, reg_str, reg_str_long = reg_specs[reg_type]

	cov = '+'.join(covariates)
	regressions = run_phewas(feature_matrix, group_data, cov, response)

	# write regressions to file, preceded by a one-line description of the run
	header = ','.join([reg_str_long, 'group', 'group.csv', 'response', response, 'covariates', cov]) + '\n'
	# NOTE(review): cov comes from the client and is embedded in the file name;
	# consider validating the covariate names against the group file columns.
	rpath = data_path/('regression_%s_%s_%s.csv' %(reg_str, response, cov))
	with open(rpath, 'w') as f:
		f.write(header)
		regressions.to_csv(f, index=False)

	# p-values are also sent as strings, alongside the numeric column
	regressions['pval_str'] = regressions.apply(lambda x: str(x['pval']),axis=1)
	return regressions.to_json(orient='records')


@app.route('/grab_data', methods=['GET','POST'])
def get_signals():
	"""Handle a data request from the Explorer client.

	The JSON request body must contain 'cmd', one of:
	  - "init": 'ftype' selects "response", "group" or "histograms"
	  - "compute_hist2D": 2D histogram of 'var1' and 'var2'
	  - "independence_tests": comparison statistics for 'var1' and 'var2'
	  - "run_reg": regression of type 'rtype' (-1 means "no data yet")
	    with covariates 'cov'; the result is also written to a CSV file

	Returns a Flask JSON response wrapping either a JSON-records string or an
	"ERROR ..." string for an unknown command, data type or regression type.
	"""
	# get data from the client
	client_data = flask.request.json
	command = client_data['cmd']

	if command == "init":
		datatype = client_data['ftype']
		if datatype == "response":
			data_obj = _msg_json(response)
		elif datatype == "group":
			data_obj = _group_summary_json()
		elif datatype == "histograms":
			data_obj = _histograms_json()
		else:
			data_obj = "ERROR executing command: INIT\nUnknown data type %s" % datatype

	elif command in ("compute_hist2D", "independence_tests"):
		var1 = client_data["var1"]
		var2 = client_data["var2"]
		status = _pair_status(var1, var2)
		if status is not None:
			data_obj = _msg_json(status)
		elif command == "compute_hist2D":
			data_obj = get_2D_histogram(group_data, var1, var2, response).to_json(orient='records')
		else:
			data_obj = variable_comparison(group_data, var1, var2, response).to_json(orient='records')

	elif command == "run_reg":
		reg_type = int(client_data['rtype'])
		if reg_type == -1:
			# init - don't do anything
			data_obj = _msg_json("no_data")
		else:
			# build & send regressions to notebook
			data_obj = _run_regression_json(reg_type, client_data['cov'])
	else:
		data_obj = "ERROR Unknown command %s" % command

	# data_obj is already a string, so the response body is a JSON-encoded string
	return flask.jsonify(data_obj)


# Execute the application only when this file is run as a script (not when imported).
# app.run() is called without arguments, so Flask's defaults apply
# (by default, it should be hosted at localhost:5000, which you will see in the output).
if __name__ == '__main__':
	app.run()
