#!/usr/bin/env python
#
#   Copyright 2016 Blaise Frederick
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
#
#
#       $Author: frederic $
#       $Date: 2016/06/14 12:04:50 $
#       $Id: linfit,v 1.4 2016/06/14 12:04:50 frederic Exp $
#
import argparse
import copy
import sys

import numpy as np
import rapidtide.io as tide_io
import rapidtide.workflows.parser_funcs as pf
from pylab import *
from scipy.stats import pearsonr

# set default variable values
# default for --percentile: the percentile (0-100, as taken by np.percentile) of each
# region's nonzero membership probabilities that is used to rank the regions
DEFAULTPCT = 98.0
# default for --passes: how many matching/refinement passes main() performs
DEFAULTPASSES = 1


def calccorrs(intemplatedata, indata, numregionstemplate, numregions, debug=False):
    """
    Renumber the regions of a label vector so they best match a template labelling.

    Template region i (label i + 1) is correlated with input region j (label j + 1)
    using binary membership vectors.  Pairs are then assigned greedily, highest
    correlation first, with every template region and every input region used at
    most once.  Input region j matched to template row i is relabelled i + 1.

    Parameters
    ----------
    intemplatedata : 1D array
        Template labels, one per location.  Label 0 is not treated as a region.
    indata : 1D array
        Labels to renumber, at the same locations as intemplatedata.
    numregionstemplate : int
        Number of template labels to consider (labels 1..numregionstemplate).
    numregions : int
        Number of input labels to consider (labels 1..numregions).
    debug : bool, optional
        Print diagnostic information.  Default is False.

    Returns
    -------
    outdata : array
        Array with the shape and dtype of indata holding the renumbered labels.
        Locations whose input label is outside 1..max(numregionstemplate, numregions)
        are set to 0.

    Notes
    -----
    A label that is absent from (or covers all of) a vector has a constant membership
    vector, for which the Pearson correlation is undefined.  Such pairs are given the
    lowest possible correlation so they are only matched after every real candidate,
    instead of producing NaN values that np.argmax would select first.
    """
    # first calculate all correlations
    if debug:
        print("calccorrs:")
        print(f"\tlen(intemplatedata): {len(intemplatedata)}")
        print(f"\tlen(indata): {len(indata)}")
        print(f"\tnumregionstemplate: {numregionstemplate}")
        print(f"\tnumregions: {numregions}")

    intemplatedata = np.asarray(intemplatedata)
    indata = np.asarray(indata)

    maxdim = max(numregionstemplate, numregions)
    nomatchval = -1.0  # score for pairs whose correlation is undefined
    usedval = -np.inf  # marks rows/columns already matched; below any real score

    # a membership vector is only usable if it is neither all zeros nor all ones
    templateusable = [
        0 < np.count_nonzero(intemplatedata == (i + 1)) < len(intemplatedata)
        for i in range(numregionstemplate)
    ]
    regionusable = [
        0 < np.count_nonzero(indata == (j + 1)) < len(indata) for j in range(numregions)
    ]

    # The matrix is padded to be square so the matching is a full permutation.
    # Padding rows (template indices beyond numregionstemplate) score 0.0 against
    # usable input regions; everything that is not a real correlation scores nomatchval.
    thecorrmat = np.full((maxdim, maxdim), nomatchval, dtype=np.float64)
    for j in range(numregions):
        if regionusable[j]:
            thecorrmat[numregionstemplate:, j] = 0.0

    for i in range(numregionstemplate):
        if debug:
            print(f"index {i}")
        if not templateusable[i]:
            continue
        cluster1 = np.where(intemplatedata == (i + 1), 1.0, 0.0)
        for j in range(numregions):
            if regionusable[j]:
                cluster2 = np.where(indata == (j + 1), 1.0, 0.0)
                thecorrmat[i, j], dummy = pearsonr(cluster1, cluster2)

    # now find the best match of each region to a template region
    # make sure that each row and each column only get 1 match
    matches = []
    for _ in range(maxdim):
        thisloc = np.unravel_index(thecorrmat.argmax(), thecorrmat.shape)
        matches.append([int(thisloc[0]), int(thisloc[1]), float(thecorrmat[thisloc])])
        thecorrmat[:, thisloc[1]] = usedval
        thecorrmat[thisloc[0], :] = usedval

    # index both lookup tables by template row so they describe the same pairing
    bestmatchlocs = np.zeros((maxdim), dtype=int)
    bestmatchvals = np.zeros((maxdim), dtype=float)
    for templateindex, regionindex, corrval in matches:
        bestmatchlocs[templateindex] = regionindex
        bestmatchvals[templateindex] = corrval

    if debug:
        print("after matching")
        print(matches)
        for i in range(maxdim):
            print(
                f"component {i}: best match of {bestmatchvals[i]} at component {bestmatchlocs[i]}"
            )

    # read from indata and write to a fresh array so relabelling never cascades
    outdata = np.zeros_like(indata)
    for i in range(maxdim):
        outdata[indata == (bestmatchlocs[i] + 1)] = i + 1

    return outdata


def _get_parser():
    """
    Assemble the argparse parser describing the clustersort command line.

    Returns
    -------
    theparser : argparse.ArgumentParser
        Parser with two positional arguments (datafilename, outputrootname) and the
        optional --maskfile, --regionfile, --percentile, --passes and --debug flags.
    """
    theparser = argparse.ArgumentParser(
        usage="%(prog)s  datafile outputroot",
        description="Renumber regions in a 4D cluster file to optimally match.",
        prog="clustersort",
    )

    # required positional arguments, in command line order
    positionals = (
        ("datafilename", "The name of the 4 dimensional nifti file of cluster labels."),
        ("outputrootname", "The root of the output file names."),
    )
    for argname, helptext in positionals:
        theparser.add_argument(argname, help=helptext)

    # optional input files; each is passed through pf.is_valid_file at parse time
    checkfile = lambda x: pf.is_valid_file(theparser, x)
    fileoptions = (
        (
            "--maskfile",
            "maskfilename",
            "MASK",
            "Only process voxels within the 3D mask MASK.",
        ),
        (
            "--regionfile",
            "regionfilename",
            "FILE",
            "Use the region labels in FILE as a starting template for sorting.",
        ),
    )
    for flag, destination, placeholder, helptext in fileoptions:
        theparser.add_argument(
            flag,
            default=None,
            dest=destination,
            help=helptext,
            metavar=placeholder,
            type=checkfile,
        )

    # numeric tuning options
    theparser.add_argument(
        "--percentile",
        default=DEFAULTPCT,
        dest="percentile",
        help=f"Percentile probability value to use for sorting regions.  Default is {DEFAULTPCT}.",
        metavar="PCT",
        type=float,
    )
    theparser.add_argument(
        "--passes",
        default=DEFAULTPASSES,
        dest="passes",
        help=f"Number of matching/refinement passes.  Default is {DEFAULTPASSES}.",
        metavar="PASSES",
        type=int,
    )

    # diagnostics
    theparser.add_argument(
        "--debug",
        action="store_true",
        default=False,
        dest="debug",
        help="Print extended debugging information.",
    )

    return theparser


def main():
    """
    Run the clustersort workflow from the command line.

    The 4D file of cluster labels is read, optionally restricted to a mask, and the
    labels at every timepoint are renumbered with calccorrs() to match a template
    (the first timepoint, or the labels in --regionfile).  For each pass, three files
    are written with tide_io.savetonifti, using the output root plus a pass suffix
    when more than one pass is requested:

    * the renumbered 4D label file, with regions ordered by rank,
    * "_probseg": one map per region giving the fraction of timepoints at which each
      voxel carried that region's label,
    * "_maxprob": a 3D map of the most probable region at each voxel.

    The maxprob map of one pass is the template for the next.  The command line is
    also saved to "<outputroot>_commandline.txt".

    Invalid settings or mismatched input dimensions print a message and terminate the
    program with exit status 1.
    """
    # get the command line parameters
    try:
        args = _get_parser().parse_args()
    except SystemExit:
        _get_parser().print_help()
        raise

    # reject settings that would otherwise fail late or silently produce no output
    if args.passes < 1:
        print("number of passes must be at least 1")
        sys.exit(1)
    if not 0.0 <= args.percentile <= 100.0:
        print("percentile must be between 0 and 100")
        sys.exit(1)

    thecommandline = [" ".join(sys.argv)]
    # save the command line
    tide_io.writevec(
        thecommandline,
        args.outputrootname + "_commandline.txt",
    )

    # read in data
    print("reading in data array")
    (
        datafile_img,
        datafile_data,
        datafile_hdr,
        datafiledims,
        datafilesizes,
    ) = tide_io.readfromnifti(args.datafilename)

    if args.maskfilename is not None:
        print("reading in mask array")
        (
            datamask_img,
            datamask_data,
            datamask_hdr,
            datamaskdims,
            datamasksizes,
        ) = tide_io.readfromnifti(args.maskfilename)

    if args.regionfilename is not None:
        print("reading in region array")
        (
            dataregion_img,
            dataregion_data,
            dataregion_hdr,
            dataregiondims,
            dataregionsizes,
        ) = tide_io.readfromnifti(args.regionfilename)

    xsize, ysize, numslices, timepoints = tide_io.parseniftidims(datafiledims)
    xdim, ydim, slicethickness, tr = tide_io.parseniftisizes(datafilesizes)

    # check dimensions; failures exit with a nonzero status so callers can detect them
    if args.maskfilename is not None:
        print("checking mask dimensions")
        if not tide_io.checkspacematch(datafile_hdr, datamask_hdr):
            print("input mask spatial dimensions do not match image")
            sys.exit(1)
        if not datamaskdims[4] == 1:
            print("input mask time must have time dimension of 1")
            sys.exit(1)
    if args.regionfilename is not None:
        print("checking region file dimensions")
        if not tide_io.checkspacematch(datafile_hdr, dataregion_hdr):
            print("input region file spatial dimensions do not match image")
            sys.exit(1)
        if not dataregiondims[4] == 1:
            print("input region file time must have time dimension of 1")
            sys.exit(1)

    # allocating arrays
    print("reshaping arrays")
    numspatiallocs = int(xsize) * int(ysize) * int(numslices)
    print(f"there are {numspatiallocs} voxels")
    rs_datafile = datafile_data.reshape((numspatiallocs, timepoints))

    print("masking arrays")
    if args.maskfilename is not None:
        proclocs = np.where(datamask_data.reshape((numspatiallocs)) > 0.5)
    else:
        # without a mask, process every voxel that is labelled at some timepoint
        themaxes = np.max(rs_datafile, axis=1).reshape((numspatiallocs))
        proclocs = np.where(themaxes > 0)
    procdata = rs_datafile[proclocs, :][0]
    if args.debug:
        print(f"unmasked shape: {rs_datafile.shape}, masked shape: {procdata.shape}")
    if procdata.shape[0] == 0:
        print("there are no voxels to process")
        sys.exit(1)

    # the template for the first pass
    if args.regionfilename is None:
        thetemplate = procdata[:, 0]
    else:
        thetemplate = dataregion_data.reshape((numspatiallocs))[proclocs]

    # Labels are renumbered into the template's numbering, so make room for every
    # template label as well as every label in the data.
    numregions = max(int(np.max(rs_datafile)), int(np.max(thetemplate)))
    if numregions < 1:
        print("no regions found in input data")
        sys.exit(1)

    outputarray = np.zeros_like(procdata)

    for thepass in range(args.passes):
        numregionstemplate = int(np.max(thetemplate))
        if args.passes > 1:
            print(f"\nStarting pass {thepass + 1} of {args.passes}")
            passlabel = f"_pass{str(thepass + 1).zfill(2)}"
        else:
            passlabel = ""

        # match labelling to that of the template
        for thetimepoint in range(timepoints):
            print(f"matching components for run {thetimepoint + 1}")
            outputarray[:, thetimepoint] = calccorrs(
                thetemplate,
                procdata[:, thetimepoint],
                numregionstemplate,
                numregions,
                debug=args.debug,
            )

        # make a probability map: the fraction of timepoints at which each voxel
        # carries each region's label
        probmap = np.zeros((procdata.shape[0], numregions), dtype=np.float64)
        regionpercentileprob = np.zeros((numregions), dtype=np.float64)
        regionsize = np.zeros((numregions), dtype=int)
        weightedregionsize = np.zeros((numregions), dtype=np.float64)
        for theregion in range(numregions):
            probmap[:, theregion] = np.mean(outputarray == (theregion + 1), axis=1)
            regvoxels = probmap[probmap[:, theregion] > 0.0, theregion]
            regionsize[theregion] = len(regvoxels)
            # a region with no voxels keeps a percentile of 0.0; np.percentile
            # cannot be evaluated on an empty array
            if regionsize[theregion] > 0:
                regionpercentileprob[theregion] = np.percentile(regvoxels, args.percentile)
            weightedregionsize[theregion] = regionsize[theregion] * regionpercentileprob[theregion]

        # sort the probability map by rank
        regionpercentileprobinds = np.flip(regionpercentileprob.argsort())
        regionsizeinds = np.flip(regionsize.argsort())
        weightedregionsizeinds = np.flip(weightedregionsize.argsort())
        print("regionpercentileprob", regionpercentileprobinds)
        print("regionsizes", regionsizeinds)
        print("weightedregionsizes", weightedregionsizeinds)
        rankindices = regionpercentileprobinds

        for i in range(numregions):
            print(
                f"region: {rankindices[i]}, "
                + f"size: {regionsize[rankindices[i]]}, "
                + f"{args.percentile}th percentile:{regionpercentileprob[rankindices[i]]}"
            )

        # save the sorted cluster maps
        remappedoutputarray = np.zeros((procdata.shape[0], timepoints), dtype=int)
        for theregion in range(numregions):
            remappedoutputarray[np.where(outputarray == rankindices[theregion] + 1)] = (
                theregion + 1
            )
        remappedtempout = np.zeros((numspatiallocs, timepoints), dtype=int)
        remappedtempout[proclocs, :] = remappedoutputarray
        tide_io.savetonifti(
            remappedtempout.reshape((xsize, ysize, numslices, timepoints)),
            datafile_hdr,
            args.outputrootname + passlabel,
        )

        # write out the individual probability maps, in rank order
        tempout = np.zeros((numspatiallocs, numregions), dtype=np.float64)
        for theregion in range(numregions):
            tempout[proclocs, theregion] = probmap[:, theregion]
        remappedtempout = tempout[:, rankindices]
        # work on copies of the header so the input header is never modified and
        # stays valid for the 4D output of later passes
        probseg_hdr = copy.deepcopy(datafile_hdr)
        probseg_hdr["dim"][4] = numregions
        tide_io.savetonifti(
            remappedtempout.reshape((xsize, ysize, numslices, numregions)),
            probseg_hdr,
            args.outputrootname + "_probseg" + passlabel,
        )

        # write out the maxprob map
        tempout = np.zeros((numspatiallocs), dtype=int)
        remappedtempout = np.zeros((numspatiallocs), dtype=int)
        # voxels that never belong to any region stay 0 rather than defaulting to
        # the first region (argmax of an all-zero row is 0)
        tempout[proclocs] = np.where(
            np.max(probmap, axis=1) > 0.0, np.argmax(probmap, axis=1) + 1, 0
        )
        for theregion in range(numregions):
            remappedtempout[np.where(tempout == rankindices[theregion] + 1)] = theregion + 1
        maxprob_hdr = copy.deepcopy(datafile_hdr)
        maxprob_hdr["dim"][0] = 3
        maxprob_hdr["dim"][4] = 1
        tide_io.savetonifti(
            remappedtempout.reshape((xsize, ysize, numslices)),
            maxprob_hdr,
            args.outputrootname + "_maxprob" + passlabel,
        )

        # set up for next pass
        thetemplate = remappedtempout[proclocs]


# run the workflow only when this file is executed as a script, not when it is imported
if __name__ == "__main__":
    main()
