#!/usr/bin/env python
#
#   Copyright 2016 Blaise Frederick
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
#
#
#       $Author: frederic $
#       $Date: 2016/06/14 12:04:51 $
#       $Id: showstxcorr,v 1.11 2016/06/14 12:04:51 frederic Exp $
#
from __future__ import division, print_function

import sys

import numpy as np
import rapidtide.io as tide_io


def usage():
    """Print the statematch command line help to stdout.

    Returns
    -------
    tuple
        An empty tuple.
    """
    # One entry per output line; empty strings produce the blank spacer lines.
    helplines = (
        "",
        "statematch - infer the matching between states in two state timecourses",
        "",
        "usage: statematch timecoursefile1 timecoursefile2 outputroot",
        "",
        "required arguments:",
        "    timecoursefile1  - text file containing state labels",
        "    timecoursefile2  - text file containing state labels",
        "    outputroot       - the root name of the output files",
        "",
    )
    for theline in helplines:
        print(theline)
    return tuple()


# check that required arguments are set
if len(sys.argv) != 4:
    usage()
    # a malformed command line is a failure; report it to the shell
    sys.exit(1)

# get the command line parameters
infile1 = sys.argv[1]
infile2 = sys.argv[2]
outfile = sys.argv[3]

# load the two state label timecourses and cast the labels to integers
inputdata1 = tide_io.readvec(infile1).astype(int)
inputdata2 = tide_io.readvec(infile2).astype(int)
if len(inputdata1) != len(inputdata2):
    print("input file lengths do not match")
    sys.exit(1)
if len(inputdata1) == 0:
    # np.min/np.max raise ValueError on empty arrays, so bail out cleanly
    print("input files contain no data")
    sys.exit(1)

min1 = np.min(inputdata1)
min2 = np.min(inputdata2)
max1 = np.max(inputdata1)
max2 = np.max(inputdata2)

# the number of states is taken to be the span of the labels in each file
if (max1 - min1) != (max2 - min2):
    print("timecourses do not have the same number of states")
    sys.exit(1)
numstates = max1 - min1 + 1

# zero-based state indices for every timepoint
loc1 = inputdata1 - min1
loc2 = inputdata2 - min2

# matcharray[a, b] counts the timepoints where tc1 is in state a while tc2 is
# in state b.  np.add.at accumulates correctly over repeated index pairs
# (plain fancy-index += would count each distinct pair only once).
matcharray = np.zeros((numstates, numstates), dtype=float)
np.add.at(matcharray, (loc1, loc2), 1.0)

# for each tc2 state, pick the tc1 state it most often coincides with
maparray = np.zeros((numstates), dtype=int)
columntotals = np.sum(matcharray, axis=0)
for i in range(numstates):
    if columntotals[i] <= 0.0:
        # this label lies inside the tc2 range but never appears, so there is
        # nothing to match and the percentage would be 0/0; its maparray
        # entry is never used by the remapping below
        print("state", i + min2, "in tc2 never occurs")
        continue
    maxpos = np.argmax(matcharray[:, i])
    maparray[i] = maxpos
    percentmatch = 100.0 * matcharray[maxpos, i] / columntotals[i]
    print(
        "state",
        maxpos + min1,
        "in tc1 matches state",
        i + min2,
        "in tc2",
        percentmatch,
        "% of the time",
    )

# relabel tc2 using the best matching tc1 state for each of its states
remapped2 = maparray[loc2] + min1

tide_io.writenpvecs(remapped2, outfile + "_remapped.txt")
tide_io.writenpvecs(matcharray, outfile + "_matcharray.txt")
