#!/usr/bin/env python
#
#   Copyright 2016 Blaise Frederick
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
#
#
# $Author: frederic $
#       $Date: 2016/06/14 12:04:50 $
#       $Id: glmfilt,v 1.25 2016/06/14 12:04:50 frederic Exp $
#
from __future__ import print_function

import sys

import numpy as np
import rapidtide.fit as tide_fit
import rapidtide.io as tide_io


def _parse_args(argv):
    """Split the command line into its parts.

    Parameters
    ----------
    argv : sequence of str
        Full argument vector, program name first.

    Returns
    -------
    tuple or None
        ``(inputfile, outputroot, evfilenames)``, or ``None`` when fewer than
        the three required arguments (tcfile, outputroot, one evfile) follow
        the program name.
    """
    if len(argv) < 4:
        return None
    return argv[1], argv[2], list(argv[3:])


def _read_regressors(evfilenames, timepoints):
    """Read one regressor per ev file and check its length against the data.

    Exits with status 1 if any regressor does not have ``timepoints`` samples.
    Returns the list of regressors in the order the files were given.
    """
    evdata = []
    for i, name in enumerate(evfilenames):
        print("file ", i, " has name ", name)
        print("reading global regressor from ", name)
        evtimeseries = tide_io.readvec(name)
        print("timeseries length = ", len(evtimeseries))
        # multiply by 1.0 to get a floating point copy of what was read
        evdata.append(1.0 * evtimeseries)

    for j, regressor in enumerate(evdata):
        if timepoints != len(regressor):
            print("Input file and ev file ", j, " dimensions do not match")
            sys.exit(1)
    return evdata


def _fit_timecourses(timecourses, evdata):
    """Regress every row of ``timecourses`` on the regressors in ``evdata``.

    Returns
    -------
    meandata : ndarray, shape (numtcs,)
        First value of each fit (presumably the intercept term returned by
        ``tide_fit.mlregress`` -- confirm against that function).
    coefficient : ndarray, shape (numtcs, numregressors)
        Fit coefficient of each regressor for each timecourse.
    Rdata : ndarray, shape (numtcs,)
        Second return value of ``tide_fit.mlregress`` for each timecourse.
    component : ndarray, shape (numtcs, timepoints, numregressors)
        Each regressor scaled by its fit coefficient.
    """
    numtcs, timepoints = timecourses.shape
    numregressors = len(evdata)
    meandata = np.zeros(numtcs, dtype=float)
    coefficient = np.zeros((numtcs, numregressors), dtype=float)
    Rdata = np.zeros(numtcs, dtype=float)
    component = np.zeros((numtcs, timepoints, numregressors), dtype=float)

    for thetc in range(numtcs):
        thistc = timecourses[thetc, :]
        # A constant timecourse is not fit; its outputs stay at the
        # zeros they were initialized to.
        if np.max(thistc) - np.min(thistc) > 0.0:
            thisfit, R = tide_fit.mlregress(evdata, thistc)
            meandata[thetc] = thisfit[0, 0]
            Rdata[thetc] = R
            for j, regressor in enumerate(evdata):
                coefficient[thetc, j] = thisfit[0, j + 1]
                component[thetc, :, j] = thisfit[0, j + 1] * regressor
    return meandata, coefficient, Rdata, component


def _save_results(outputroot, meandata, coefficient, Rdata, component, thefit, residuals):
    """Write every result array to a file whose name starts with ``outputroot``."""
    numregressors = coefficient.shape[1]

    # first save the things with a single timepoint
    tide_io.writenpvecs(meandata, outputroot + "_mean.txt")
    for j in range(numregressors):
        tide_io.writenpvecs(coefficient[:, j], outputroot + "_coefficient_" + str(j).zfill(2))
    tide_io.writenpvecs(Rdata, outputroot + "_R")

    # now save the things with full timecourses
    for j in range(numregressors):
        tide_io.writenpvecs(component[:, :, j], outputroot + "_component_" + str(j).zfill(2))
    tide_io.writenpvecs(thefit, outputroot + "_fit")
    tide_io.writenpvecs(residuals, outputroot + "_residuals")


def main():
    """Fit one or more ev timecourses to every timecourse in a file.

    Command line: ``fitglm tcfile outputroot evfile [evfile_2...evfile_n]``

    Reads the timecourses from ``tcfile`` and one regressor from each ev file,
    regresses each timecourse on the regressors, and writes the mean, the
    per-regressor coefficients and components, R, the summed fit, and the
    residuals to files named from ``outputroot``.

    Exits with status 1 after printing a message if too few arguments are
    given or if an ev file's length differs from the number of timepoints.
    """
    parsed = _parse_args(sys.argv)
    if parsed is None:
        print("usage: fitglm tcfile outputroot evfile [evfile_2...evfile_n]")
        print("    Fits multiple evs to timecourses in a file")
        sys.exit(1)
    inputfile, outputroot, evfilenames = parsed
    for name in evfilenames:
        print(name)

    # read the datafile and the evfiles
    tc_data = tide_io.readvecs(inputfile)
    numtcs, timepoints = tc_data.shape[0], tc_data.shape[1]
    evdata = _read_regressors(evfilenames, timepoints)
    numregressors = len(evdata)

    print("will perform GLM with ", numregressors, " regressors")
    print("numtcs = ", numtcs)
    print("timepoints = ", timepoints)
    print("numregressors = ", numregressors)

    # work on a floating point copy of the input data
    timecourses = 1.0 * tc_data
    meandata, coefficient, Rdata, component = _fit_timecourses(timecourses, evdata)

    # The fit is the sum of the scaled regressors only; the mean term is not
    # added in, so it stays in the residuals.
    thefit = component.sum(axis=2)
    residuals = timecourses - thefit

    print("processing complete: about to save data")
    _save_results(outputroot, meandata, coefficient, Rdata, component, thefit, residuals)

if __name__ == "__main__":
    # Run the tool only when this file is executed as a script, not on import.
    # To profile a run, replace the call to main() below with:
    #     import cProfile as profile
    #     profile.run('main()', 'rapidtide_profile')
    main()
