#!/usr/bin/env python
#
#   Copyright 2016 Blaise Frederick
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
#
#
#       $Author: frederic $
#       $Date: 2016/06/14 12:04:50 $
#       $Id: linfit,v 1.4 2016/06/14 12:04:50 frederic Exp $
#
from __future__ import division, print_function

import argparse
import os
import sys

try:
    from sklearnex import patch_sklearn

    print("using sklearnex")
    patch_sklearn()
except ImportError:
    pass

import matplotlib.pyplot as plt
import numpy as np
import rapidtide.io as tide_io
import rapidtide.workflows.parser_funcs as pf
from sklearn.ensemble import (
    AdaBoostClassifier,
    BaggingClassifier,
    ExtraTreesClassifier,
    GradientBoostingClassifier,
    RandomForestClassifier,
)
from sklearn.inspection import permutation_importance
from sklearn.manifold import TSNE
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import RobustScaler, StandardScaler
from sklearn.svm import SVC, LinearSVC


def _get_parser():
    parser = argparse.ArgumentParser(
        prog="supercluster",
        description="Plots the data in text files.",
        usage="%(prog)s texfilename[:col1,col2...,coln] [textfilename]... [options]",
    )

    # Required arguments
    parser.add_argument(
        "datafilename",
        type=lambda x: pf.is_valid_file(parser, x),
        help="The name of the 4 dimensional nifti file to cluster.",
    )

    parser.add_argument(
        "classfilename",
        type=lambda x: pf.is_valid_file(parser, x),
        help="The name of the 3 dimensional nifti file of the class labels.",
    )

    parser.add_argument(
        "outputrootname",
        type=str,
        help=("The root name of the output nifti files."),
    )

    parser.add_argument(
        "--maskfile",
        dest="maskfilename",
        metavar="DATAMASK",
        help=("Use DATAMASK to specify which voxels in the data to classify."),
        default=None,
    )
    parser.add_argument(
        "--prescaler",
        dest="prescaler",
        action="store",
        type=str,
        choices=["standard", "robust", "None"],
        help=("What sort of prescaler to apply to data prior to classification.  Default is None"),
        default="None",
    )
    parser.add_argument(
        "--bagging",
        action="store_true",
        dest="bagging",
        help="Use a bagging classifier.",
        default=False,
    )
    parser.add_argument(
        "--numruns",
        dest="numtests",
        metavar="NUM",
        action="store",
        type=int,
        help=("Set the number of independent iterations to run.  Default is 1."),
        default=1,
    )

    parser.add_argument(
        "--n_estimators",
        dest="n_estimators",
        metavar="ESTIMATORS",
        action="store",
        type=int,
        help=("Set the number of estimators to ESTIMATORS (default is 100)"),
        default=100,
    )

    parser.add_argument(
        "--max_features",
        dest="max_features",
        metavar="NUM",
        action="store",
        type=int,
        help=("Set number of features to NUM (default is -1 - set equal to sqrt(n_estimators)"),
        default=-1,
    )

    parser.add_argument(
        "--n_jobs",
        dest="n_jobs",
        metavar="NUM",
        action="store",
        type=int,
        help=("Number of parallel jobs to run. Default is -1 (maximum)"),
        default=-1,
    )

    parser.add_argument(
        "--type",
        dest="classifiertype",
        action="store",
        type=str,
        choices=[
            "randomforest",
            "knn",
            "gradientboost",
            "adaboost",
            "extratrees",
            "svc",
            "lsvc",
        ],
        help=("Set the classifier type. Default is randomforest."),
        default="randomforest",
    )

    parser.add_argument(
        "--usepermutation",
        action="store_true",
        dest="usepermutation",
        help="Calculate feature importances using permutation method (MUCH slower).",
        default=False,
    )

    kopts = parser.add_argument_group("KNN options")
    kopts.add_argument(
        "--n_neighbors",
        dest="n_neighbors",
        metavar="NUM",
        action="store",
        type=int,
        help=("Set the number of neighbors for KNN classifiers.  Default is 5."),
        default=5,
    )
    kopts.add_argument(
        "--k_algorithm",
        dest="k_algorithm",
        action="store",
        type=str,
        choices=["brute", "kd_tree", "ball_tree", "auto"],
        help=("Set the nearest neighbors algorithm. Default is kd_tree."),
        default="kd_tree",
    )
    kopts.add_argument(
        "--weights",
        dest="weights",
        action="store",
        type=str,
        choices=["uniform", "distance"],
        help=("Set the weighting for KNN classifiers.  Default is distance."),
        default="distance",
    )

    svcopts = parser.add_argument_group("SVC options")
    svcopts.add_argument(
        "--svc_C",
        dest="svc_C",
        metavar="NUM",
        action="store",
        type=float,
        help=("Regularization parameter.  Default is 1.0.  Reduce for more regularization."),
        default=1,
    )
    svcopts.add_argument(
        "--svc_kernel",
        dest="svc_kernel",
        action="store",
        type=str,
        choices=["linear", "poly", "rbf", "sigmoid", "precomputed"],
        help=("Kernel to use.  Default is rbf"),
        default="rbf",
    )
    svcopts.add_argument(
        "--svc_polydegree",
        dest="svc_polydegree",
        metavar="NUM",
        action="store",
        type=int,
        help=("Degree for polynomial kernel.  Default is 3"),
        default=3,
    )
    svcopts.add_argument(
        "--svc_gamma",
        dest="svc_gamma",
        action="store",
        type=str,
        choices=["scale", "auto"],
        help=("Kernel coefficient for rbf, poly, and sigmoid kernels.  Default is scale"),
        default="scale",
    )
    svcopts.add_argument(
        "--probability",
        action="store_true",
        dest="svc_probability",
        help="Calculate classification probabilities.  Very slow.  Default is False.",
        default=False,
    )

    parser.add_argument(
        "--display",
        action="store_true",
        dest="display",
        help="Display feature importances if available.",
        default=False,
    )

    parser.add_argument(
        "--debug",
        action="store_true",
        dest="debug",
        help="Output additional debugging information.",
        default=False,
    )
    return parser


# Classifier types whose fitted (non-bagged) estimator provides feature_importances_.
_IMPORTANCE_CLASSIFIERS = ("randomforest", "extratrees", "adaboost", "gradientboost")


def _fatal(message):
    """Print an error message and terminate with a nonzero exit status."""
    print(message)
    sys.exit(1)


def _prescale(X_known, X_predict, prescaler, intlist):
    """Scale the feature columns of both arrays in place, one interval at a time.

    For each consecutive run of ``interval`` columns listed in ``intlist``, a
    scaler is fit on the labeled data only and then applied to both arrays, so
    that a given voxel has identical feature values whether it is used for
    training or for prediction.

    Parameters
    ----------
    X_known : ndarray, shape (n_known, n_features)
        Features of the labeled voxels.  Modified in place.
    X_predict : ndarray, shape (n_predict, n_features)
        Features of the voxels to classify.  Modified in place.
    prescaler : str
        "standard" selects StandardScaler; any other value selects RobustScaler.
    intlist : list of int
        Widths of the consecutive column intervals to scale independently.
    """
    scalerclass = StandardScaler if prescaler == "standard" else RobustScaler
    thepos = 0
    for interval in intlist:
        thecols = slice(thepos, thepos + interval)
        thescaler = scalerclass().fit(X_known[:, thecols])
        X_known[:, thecols] = thescaler.transform(X_known[:, thecols])
        X_predict[:, thecols] = thescaler.transform(X_predict[:, thecols])
        thepos += interval


def _build_classifier(args, max_features, prescalename):
    """Construct the requested classifier and its descriptive output tag.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line options (classifier type and hyperparameters).
    max_features : int
        Number of features considered per split by the forest classifiers.
    prescalename : str
        Prescaler tag to embed in the method name ("" when no prescaling).

    Returns
    -------
    tuple
        ``(theclassifier, methodname)`` - the unfitted estimator and the tag
        used to build the output file names.
    """
    ctype = args.classifiertype

    if ctype in ("randomforest", "extratrees"):
        forestclass = RandomForestClassifier if ctype == "randomforest" else ExtraTreesClassifier
        if args.bagging:
            theclassifier = BaggingClassifier(
                forestclass(
                    n_estimators=args.n_estimators,
                    max_features=max_features,
                    max_depth=None,
                    min_samples_split=2,
                    n_jobs=args.n_jobs,
                ),
                max_samples=0.5,
                max_features=0.5,
            )
            nameroot = ctype + "_bagging_"
        else:
            theclassifier = forestclass(
                n_estimators=args.n_estimators,
                max_features=max_features,
                max_depth=None,
                min_samples_split=2,
                random_state=0,
                n_jobs=args.n_jobs,
            )
            nameroot = ctype + "_"
        methodname = (
            nameroot
            + prescalename
            + str(args.n_estimators).zfill(3)
            + "_"
            + str(max_features).zfill(3)
        )

    elif ctype == "svc":
        theclassifier = SVC(
            C=args.svc_C,
            kernel=args.svc_kernel,
            degree=args.svc_polydegree,
            gamma=args.svc_gamma,
            probability=args.svc_probability,
        )
        methodname = f"svc_{prescalename}{args.svc_kernel}_{args.svc_polydegree}_{args.svc_C}"

    elif ctype == "lsvc":
        theclassifier = LinearSVC(
            C=args.svc_C,
        )
        # The kernel name is not used by LinearSVC; it is retained in the tag
        # only so that output file names stay the same as in earlier versions.
        methodname = f"lsvc_{prescalename}{args.svc_kernel}_{args.svc_C}"

    elif ctype == "adaboost":
        theclassifier = AdaBoostClassifier(n_estimators=args.n_estimators)
        methodname = f"ada_{prescalename}{str(args.n_estimators).zfill(3)}"

    elif ctype == "gradientboost":
        # "sqrt" is the classifier's equivalent of the former "auto" setting,
        # which newer scikit-learn releases no longer accept.
        theclassifier = GradientBoostingClassifier(
            n_estimators=args.n_estimators, max_features="sqrt"
        )
        methodname = f"gbc_{prescalename}{str(args.n_estimators).zfill(3)}"

    elif ctype == "knn":
        baseknn = KNeighborsClassifier(
            n_neighbors=args.n_neighbors,
            weights=args.weights,
            n_jobs=args.n_jobs,
            algorithm=args.k_algorithm,
        )
        if args.bagging:
            theclassifier = BaggingClassifier(
                baseknn,
                max_samples=0.5,
                max_features=0.5,
            )
            methodname = (
                f"knn_bagging_{prescalename}{args.weights}_{str(args.n_neighbors).zfill(2)}"
            )
        else:
            theclassifier = baseknn
            methodname = f"knn_{prescalename}{args.weights}_{str(args.n_neighbors).zfill(2)}"

    else:
        _fatal("illegal classifier type")

    return theclassifier, methodname


def main():
    """Train a classifier on labeled voxels and classify the voxels of a 4D NIFTI file.

    Workflow:

    1. Read the data file, the class label file and (optionally) a mask.
    2. Treat every voxel with a class label > 0 as known (training) data, and
       every voxel inside the mask (or, without a mask, every voxel whose
       values are not constant along the last axis) as data to classify.
    3. Optionally prescale the features.
    4. For each run, split the known data 80/20, fit the classifier, record the
       test accuracy and (for multiple runs) the prediction for every
       processed voxel.
    5. Write the accuracy summary, confusion matrix, predicted label maps,
       command line and (where available) feature importances, all named
       ``<outputrootname>_<methodname>_*``.

    Exits with a nonzero status if the inputs are inconsistent.
    """
    try:
        args = _get_parser().parse_args()
    except SystemExit:
        _get_parser().print_help()
        raise

    if args.numtests < 1:
        # with no runs the classifier would never be fit before predicting
        _fatal("number of runs must be at least 1")

    print("Will perform", args.classifiertype, "classification")

    # read in data
    print("reading in data arrays")
    (
        datafile_img,
        datafile_data,
        datafile_hdr,
        datafiledims,
        datafilesizes,
    ) = tide_io.readfromnifti(args.datafilename)
    (
        classfile_img,
        classfile_data,
        classfile_hdr,
        classfiledims,
        classfilesizes,
    ) = tide_io.readfromnifti(args.classfilename)
    if args.maskfilename is not None:
        (
            datamask_img,
            datamask_data,
            datamask_hdr,
            datamaskdims,
            datamasksizes,
        ) = tide_io.readfromnifti(args.maskfilename)

    xsize, ysize, numslices, timepoints = tide_io.parseniftidims(datafiledims)

    # check dimensions
    print("checking class dimensions")
    if not tide_io.checkspacematch(datafile_hdr, classfile_hdr):
        _fatal("target map spatial dimensions do not match image")
    if not classfiledims[4] == 1:
        _fatal("class file must have time dimension of 1")

    if args.maskfilename is not None:
        print("checking mask dimensions")
        if not tide_io.checkspacematch(datafile_hdr, datamask_hdr):
            _fatal("input mask spatial dimensions do not match image")
        if not datamaskdims[4] == 1:
            _fatal("input mask must have time dimension of 1")

    # flatten the spatial dimensions so each row is one voxel
    print("reshaping arrays")
    numspatiallocs = int(xsize) * int(ysize) * int(numslices)
    rs_datafile = datafile_data.reshape((numspatiallocs, timepoints))
    rs_classfile = classfile_data.reshape((numspatiallocs))

    # find known data (voxels carrying a class label)
    print("calculating knownlocs")
    knownlocs = np.where(rs_classfile > 0)[0]
    if knownlocs.size == 0:
        _fatal("class file contains no labeled (> 0) voxels")

    # select the voxels to classify
    print("masking arrays")
    if args.maskfilename is not None:
        proclocs = np.where(datamask_data.reshape((numspatiallocs)) > 0.9)[0]
    else:
        # without a mask, process every voxel whose values are not all identical
        thediffs = np.max(rs_datafile, axis=1) - np.min(rs_datafile, axis=1)
        proclocs = np.where(thediffs > 0.0)[0]
    if proclocs.size == 0:
        _fatal("no voxels selected for classification")

    # construct the arrays (index-array selection copies, so the prescaling
    # below does not alter rs_datafile)
    Y_known = rs_classfile[knownlocs]
    X_known = rs_datafile[knownlocs, :]
    X_predict = rs_datafile[proclocs, :]
    print(rs_datafile.shape, X_known.shape, X_predict.shape, Y_known.shape)
    n_features = rs_datafile.shape[1]

    print("data has", n_features, "features")
    if args.max_features == -1:
        max_features = int(np.ceil(np.sqrt(n_features)))
    else:
        max_features = args.max_features

    # Scale the data, treating all timepoints as a single interval
    intlist = [timepoints]
    if args.prescaler != "None":
        print("prescaling...")
        _prescale(X_known, X_predict, args.prescaler, intlist)
        prescalename = "prescale_" + args.prescaler + "_"
        print("Done")
    else:
        prescalename = ""

    # set up the classifier
    theclassifier, methodname = _build_classifier(args, max_features, prescalename)
    print("method description tag:", methodname)
    outputprefix = args.outputrootname + "_" + methodname

    # use all of the components
    ncomps = timepoints

    # make space for saved data
    accuracies = []
    saveimportances = args.classifiertype in _IMPORTANCE_CLASSIFIERS and not args.bagging
    if saveimportances:
        feature_importances = []
    if args.numtests > 1:
        Y_predict_all = np.zeros((X_predict.shape[0], args.numtests), dtype=np.double)

    for testrun in range(args.numtests):
        print(f"test {testrun}")
        print("\tsplitting data")
        X_train, X_test, Y_train, Y_test = train_test_split(
            X_known[:, 0:ncomps], Y_known, test_size=0.2
        )

        # check array information
        if args.debug:
            print("Image data:")
            print(f"\tX flags: {X_train.flags}")
            print(f"\tX type: {X_train.dtype}")
            print(f"\tY flags: {Y_train.flags}")
            print(f"\tY type: {Y_train.dtype}")

        # fit the model
        print("\tfitting model")
        theclassifier.fit(X_train, Y_train)

        if saveimportances:
            print("\tcalculating feature importances")
            if args.usepermutation:
                feature_importances.append(
                    permutation_importance(theclassifier, X_train, Y_train)["importances_mean"]
                )
            else:
                feature_importances.append(theclassifier.feature_importances_)

        # now make predictions on the test data
        print("\tpredicting on test data")
        Y_predict = theclassifier.predict(X_test)
        accuracies.append(accuracy_score(Y_test, Y_predict))

        # if we are doing multirun, predict every processed voxel for each training
        if args.numtests > 1:
            print("\tpredicting on all data")
            Y_predict_all[:, testrun] = theclassifier.predict(X_predict)

        # save a status value
        tide_io.writenpvecs(
            np.asarray([testrun], dtype=int),
            outputprefix + "_currentiteration.txt",
        )

        # report and save the running accuracy statistics; the file is
        # rewritten after every run so it always reflects the runs so far
        theaccuracies = np.asarray(accuracies, dtype=np.double)
        accmean = np.mean(theaccuracies)
        accstd = np.std(theaccuracies)
        accmin = np.min(theaccuracies)
        accmax = np.max(theaccuracies)
        print(
            "Accuracies for",
            ncomps,
            "components (mean, std, min, max):",
            accmean,
            accstd,
            accmin,
            accmax,
        )
        with open(outputprefix + "_accuracy.csv", "w") as text_file:
            dfile = os.path.split(args.datafilename)[1]
            cfile = os.path.split(args.classfilename)[1]
            fieldlist = [
                dfile,
                cfile,
                methodname,
                str(accmean),
                str(accstd),
                str(accmin),
                str(accmax),
            ]
            text_file.write(",".join(fieldlist) + "\n")

    # now classify all the known data (with the last fitted model) to calculate metrics
    Y_predict_known = theclassifier.predict(X_known)
    print("prediction of known locations complete")
    confusion = confusion_matrix(Y_known, Y_predict_known)
    classreport = classification_report(Y_known, Y_predict_known)
    tide_io.writenpvecs(
        confusion,
        outputprefix + "_predict_all_confusion_matrix.txt",
    )
    print(classreport)

    if args.numtests == 1:
        # now classify the entire image
        Y_predict_all = theclassifier.predict(X_predict)
        print("prediction of all locations complete")

    # save the data; theheader is the same object as datafile_hdr, so setting
    # the time dimension here also applies where datafile_hdr is passed below
    theheader = datafile_hdr
    theheader["dim"][4] = 1

    # output classification of all known data
    tempout = np.zeros((numspatiallocs), dtype=np.double)
    tempout[knownlocs] = Y_predict_known[:]
    tide_io.savetonifti(
        tempout.reshape((xsize, ysize, numslices, 1)),
        datafile_hdr,
        outputprefix + "_predict_known",
    )

    # save the command line
    tide_io.writevec(
        [" ".join(sys.argv)],
        outputprefix + "_commandline.txt",
    )

    # output classification of all processed voxels
    if args.numtests > 1:
        # one output volume per run
        tempout = np.zeros((numspatiallocs, args.numtests), dtype=np.double)
        if args.debug:
            print(f"Y_predict_all.shape: {Y_predict_all.shape}")
            print(f"tempout.shape: {tempout.shape}")
        tempout[proclocs, :] = Y_predict_all[:, :]
        theheader["dim"][4] = args.numtests
        tide_io.savetonifti(
            tempout.reshape((xsize, ysize, numslices, args.numtests)),
            theheader,
            outputprefix + "_predict_all",
        )
    else:
        tempout = np.zeros((numspatiallocs), dtype=np.double)
        tempout[proclocs] = Y_predict_all[:]
        tide_io.savetonifti(
            tempout.reshape((xsize, ysize, numslices, 1)),
            datafile_hdr,
            outputprefix + "_predict_all",
        )

    # save other classifier-specific metrics
    if saveimportances:
        allimportances = np.array(feature_importances)
        importancemeans = np.mean(allimportances, axis=0)
        importancestds = np.std(allimportances, axis=0)
        tide_io.writenpvecs(
            allimportances,
            outputprefix + "_featureimportances.txt",
        )
        tide_io.writenpvecs(
            importancemeans,
            outputprefix + "_featureimportances_mean.txt",
        )
        tide_io.writenpvecs(
            importancestds,
            outputprefix + "_featureimportances_std.txt",
        )
        if args.display:
            xaxis = range(len(importancemeans))
            plt.errorbar(xaxis, importancemeans, yerr=importancestds)
            plt.show()


# Run the classification workflow only when this file is executed as a script.
if __name__ == "__main__":
    main()
