#!/usr/bin/env python
#
#   Copyright 2016-2019 Blaise Frederick
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
#
#
#       $Author: frederic $
#       $Date: 2016/06/14 12:04:51 $
#       $Id: capfromtcs,v 1.11 2016/06/14 12:04:51 frederic Exp $
#
from __future__ import division, print_function

import getopt
import sys

import joblib
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
import rapidtide.filter as tide_filter
import rapidtide.fit as tide_fit
import rapidtide.io as tide_io
import rapidtide.miscmath as tide_math
import rapidtide.stats as tide_stats
from sklearn import metrics
from sklearn.cluster import DBSCAN, KMeans, MiniBatchKMeans
from sklearn.decomposition import PCA, FastICA, IncrementalPCA
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.feature_selection import (
    RFE,
    SelectFdr,
    SelectKBest,
    SelectPercentile,
    f_classif,
)
from sklearn.preprocessing import StandardScaler

try:
    import hdbscan as hdbs

    hdbpresent = True
    print("hdbscan is present")
except ImportError:
    hdbpresent = False

import capcalc.utils as capcalc_utils


def usage():
    print("")
    print("capfromtcs - calculate and cluster coactivation patterns for a set of timecourses")
    print("")
    print("usage: capfromtcs -i timecoursefile -o outputfile --samplefreq=FREQ --sampletime=TSTEP")
    print("                  [--nodetrend] [-s STARTTIME] [-D DURATION]")
    print("                  [-F LOWERFREQ,UPPERFREQ[,LOWERSTOP,UPPERSTOP]] [-V] [-L] [-R] [-C]")
    print(
        "                  [-m] [-n NUMCLUSTER] [-b BATCHSIZE] [-S SEGMENTSIZE] [-E SEGMENTTYPE] [-I INITIALIZATIONS]"
    )
    print(
        "                  [--noscale] [--nonorm] [--pctnorm] [--varnorm] [--stdnorm] [--ppnorm] [--quality]"
    )
    print("                  [--pca] [--ica] [-p NUMCOMPONENTS] --modelroot=MODELROOT")
    print("")
    print("required arguments:")
    print("    -i, --infile=TIMECOURSEFILE  - text file multiple timeseries")
    print("    -o, --outfile=OUTNAME        - the root name of the output files")
    print("")
    print("    --samplefreq=FREQ            - sample frequency of all timecourses is FREQ ")
    print("           or")
    print("    --sampletime=TSTEP           - time step of all timecourses is TSTEP ")
    print(
        "                                   NB: --samplefreq and --sampletime are two ways to specify"
    )
    print("                                   the same thing.")
    print("")
    print("optional arguments:")
    print("")
    print("  Data selection/partition:")
    print(
        "    -s STARTTIME                 - time of first datapoint to use in seconds in the first file"
    )
    print("    -D DURATION                  - amount of data to use in seconds")
    print("    -S SEGMENTSIZE,[SEGSIZE2,...SEGSIZEN]")
    print(
        "                                 - treat the timecourses as segments of length SEGMENTSIZE for preprocessing."
    )
    print("    -E SEGTYPE,SEGTYPE2[,...SEGTYPEN]")
    print(
        "                                 - group subsegments for summary statistics.  All subsegments in the same group must be the same length"
    )
    print(
        "                                   If there are multiple, comma separated numbers, treat these as subsegment lengths."
    )
    print("                                   Default segmentsize is the entire length")
    print("  --skippts=NUMPTS               - drop first NUMPTS points from each segment")
    print("  Clustering:")
    print(
        "    -m                           - run MiniBatch Kmeans rather than conventional - use with very large datasets"
    )
    print(
        "    -n NUMCLUSTER                - set the number of clusters to NUMCLUSTER (default is 8)"
    )
    print(
        "    -b BATCHSIZE                 - use a batchsize of BATCHSIZE if doing MiniBatch - ignored if not.  Default is 1000"
    )
    print("    --dbscan                     - perform dbscan clustering")
    print("    --hdbscan                    - perform hdbscan clustering")
    print(
        "    -I INITIALIZATIONS           - Restart KMeans INITIALIZATIONS times to find best fit (default is 100)"
    )
    print("")
    print("  Preprocessing:")
    print("    -F LOWERFREQ,UPPERFREQ[,LOWERSTOP,UPPERSTOP]")
    print(
        "                                 - filter data and regressors from LOWERFREQ to UPPERFREQ."
    )
    print(
        "                                   LOWERSTOP and UPPERSTOP can be specified, or will be calculated automatically"
    )
    print("    -V                           - filter data and regressors to VLF band")
    print("    -L                           - filter data and regressors to LFO band")
    print("    -R                           - filter data and regressors to respiratory band")
    print("    -C                           - filter data and regressors to cardiac band")
    print("    --nodetrend                  - do not detrend the data before correlation")
    print("    --noscale                    - don't perform vector magnitude scaling")
    print("    --nonorm                     - don't normalize timecourses")
    print("    --pctnorm                    - scale each timecourse to its percentage of the mean")
    print(
        "    --varnorm                    - scale each timecourse to have a variance of 1.0 (default)"
    )
    print(
        "    --stdnorm                    - scale each timecourse to have a standard deviation of 1.0"
    )
    print(
        "    --ppnorm                     - scale each timecourse to have a peak to peak range of 1.0"
    )
    print(
        "    --pca                        - perform PCA dimensionality reduction prior to analysis"
    )
    print(
        "    --ica                        - perform ICA dimensionality reduction prior to analysis"
    )
    print(
        "    -p NUMCOMPONENTS             - set the number of pca/ica components to NUMCOMPONENTS (default is 8).  Set to -1 to estimate"
    )
    print("    --noscale                    - do not apply standard scaler before cluster fitting")
    print("    --preproconly                - do preprocessing then quit")
    print(
        "    --minout=MINOUT              - transitions out of a state shorter than MINOUT will be patched.  Default is 1"
    )
    print(
        "    --minhold=MINHOLD            - time in a state shorter than MINHOLD will be assigned to the previous state.  Default is 1"
    )
    print("")
    print("  Other:")
    print(
        "    --GBR                        - apply gradient boosting regressor testing on clusters"
    )
    print("    -d                           - display some quality metrics")
    print("    --quality                    - perform a silhouette test to evaluate fit quality")
    print("    -v                           - turn on verbose mode")
    print(
        "    --modelroot=MODELROOT        - reread trained models from a previous run - MODELROOT should "
    )
    print(
        "                                   be the outputfile from the previous run (i.e. what followed -o)"
    )
    print("")
    return


# get the command line parameters
summaryonly = True

# preprocessing options
preprocessingtype = None
detrendorder = 1
timenormmethod = "varnorm"

# clustering/partitioning options
minibatch = False
n_clusters = 8
n_pca = 8
max_iter = 250
n_init = 100
batch_size = 1000
clustertype = "kmeans"
clustertype = "kmeans"
connfilename = None
affinity = "euclidean"
linkage = "ward"
eps = 0.3
min_samples = 100
alpha = 1.0
standardscale = True
skippts = 0
minoutlength = 1
minholdlength = 1

duration = 100000000.0
starttime = 0.0
usebutterworthfilter = False
filtorder = 3
verbose = False

doGBR = False
display = False
preproconly = False

trainedmodelroot = None


# scan the command line arguments
try:
    opts, args = getopt.gnu_getopt(
        sys.argv[1:],
        "di:o:s:D:F:S:E:VLRCmn:p:b:I:v",
        [
            "infile=",
            "outfile=",
            "nodetrend",
            "dbscan",
            "hdbscan",
            "GBR",
            "pca",
            "ica",
            "noscale",
            "preproconly",
            "nonorm",
            "pctnorm",
            "varnorm",
            "stdnorm",
            "ppnorm",
            "skippts=",
            "minout=",
            "minhold=",
            "modelroot=",
            "quality",
            "samplefreq=",
            "sampletime=",
            "help",
        ],
    )
except getopt.GetoptError as err:
    # print help information and exit:
    print(str(err))  # will print something like "option -x not recognized"
    usage()
    sys.exit(2)

if len(args) > 0:
    print("capfromtcs takes no unflagged arguments")
    print(args)
    sys.exit(2)

# unset all required arguments
infilename = []
segsize = -1
subsegs = []
subseggroupIDs = None
sampletime = None
Fs = None
outfilename = None

theprefilter = tide_filter.NoncausalFilter(transferfunc="butterworth")
theprefilter.setbutterorder(filtorder)

# set the default characteristics
theprefilter.settype("None")

for o, a in opts:
    if o == "--infile" or o == "-i":
        infilename.append(a)
        if verbose:
            print("will use", infilename[-1], "as an input file")
    elif o == "--outfile" or o == "-o":
        outfilename = a
        if verbose:
            print("will use", outfilename, "as output file")
    elif o == "-S":
        for seg in a.split(","):
            subsegs.append(int(seg))
        segsize = np.sum(np.asarray(subsegs))
        print("SUBSEGS:", subsegs)
        if verbose:
            print("Setting segment size to ", segsize)
    elif o == "-E":
        subseggroupIDs = []
        for seg in a.split(","):
            subseggroupIDs.append(seg)
        print("SUBSEGGROUPIDS:", subseggroupIDs)
    elif o == "--samplefreq":
        Fs = float(a)
        sampletime = 1.0 / Fs
        linkchar = "="
        if verbose:
            print("Setting sample frequency to ", Fs)
    elif o == "--sampletime":
        sampletime = float(a)
        Fs = 1.0 / sampletime
        linkchar = "="
        if verbose:
            print("Setting sample time step to ", sampletime)
    elif o == "-display":
        display = True
        if verbose:
            print("will display quality metrics")
    elif o == "--preproconly":
        preproconly = True
        if verbose:
            print("only do preprocessing through PCA/ICA")
    elif o == "--noscale":
        standardscale = False
        if verbose:
            print("will not magnitude scale feature vectors")
    elif o == "--nonorm":
        timenormmethod = "none"
        if verbose:
            print("will do no normalization")
    elif o == "--pctnorm":
        timenormmethod = "pctnorm"
        if verbose:
            print("will do percent normalization")
    elif o == "--stdnorm":
        timenormmethod = "stdnorm"
        if verbose:
            print("will do std dev normalization")
    elif o == "--varnorm":
        timenormmethod = "varnorm"
        if verbose:
            print("will do variance normalization")
    elif o == "--ppnorm":
        timenormmethod = "ppnorm"
        if verbose:
            print("will do p-p normalization")
    elif o == "--modelroot":
        trainedmodelroot = a
        if verbose:
            print("will read trained models from", trainedmodelroot, "_*.joblib")
    elif o == "--minhold":
        minholdlength = int(a)
        if verbose:
            print("residency in a state shorter than", minholdlength, "will be patched")
    elif o == "--minout":
        minoutlength = int(a)
        if verbose:
            print(
                "transitions out of a state shorter than",
                minoutlength,
                "will be patched",
            )
    elif o == "--skippts":
        skippts = int(a)
        if verbose:
            print("will drop first", skippts, "points from each segnent")
    elif o == "--quality":
        summaryonly = False
        if verbose:
            print("will do silhouette test")
    elif o == "-v":
        verbose = True
        if verbose:
            print("verbose mode enabled")
    elif o == "--ica":
        preprocessingtype = "ica"
        if verbose:
            print("will perform ica dimensionality reduction step")
    elif o == "--GBR":
        doGBR = True
        if verbose:
            print("will do GBR on clusters")
    elif o == "--pca":
        preprocessingtype = "pca"
        if verbose:
            print("will perform pca dimensionality reduction step")
    elif o == "--hdbscan":
        clustertype = "hdbscan"
        if not hdbpresent:
            print("hdbs is not installed, cannot perform hdbscan clustering.  Exiting")
            sys.exit()
        if verbose:
            print("switching to hdbscan clustering")
    elif o == "--dbscan":
        clustertype = "dbscan"
        if verbose:
            print("switching to dbscan clustering")
    elif o == "--nodetrend":
        detrendorder = 0
        if verbose:
            print("disabling detrending")
    elif o == "-D":
        duration = float(a)
        if verbose:
            print("duration set to", duration)
    elif o == "-s":
        starttime = float(a)
        if verbose:
            print("starttime set to", starttime)
    elif o == "-V":
        theprefilter.settype("vlf")
        if verbose:
            print("prefiltering to vlf band")
    elif o == "-L":
        theprefilter.settype("lfo")
        if verbose:
            print("prefiltering to lfo band")
    elif o == "-R":
        theprefilter.settype("resp")
        if verbose:
            print("prefiltering to respiratory band")
    elif o == "-C":
        theprefilter.settype("card")
        if verbose:
            print("prefiltering to cardiac band")
    elif o == "-F":
        arbvec = a.split(",")
        if len(arbvec) != 2 and len(arbvec) != 4:
            usage()
            sys.exit()
        if len(arbvec) == 2:
            arb_lower = float(arbvec[0])
            arb_upper = float(arbvec[1])
            arb_lowerstop = 0.9 * float(arbvec[0])
            arb_upperstop = 1.1 * float(arbvec[1])
        if len(arbvec) == 4:
            arb_lower = float(arbvec[0])
            arb_upper = float(arbvec[1])
            arb_lowerstop = float(arbvec[2])
            arb_upperstop = float(arbvec[3])
        theprefilter.settype("arb")
        theprefilter.setfreqs(arb_lowerstop, arb_lower, arb_upper, arb_upperstop)
        if verbose:
            print(
                "prefiltering to ",
                arb_lower,
                arb_upper,
                "(stops at ",
                arb_lowerstop,
                arb_upperstop,
                ")",
            )
    elif o == "-m":
        minibatch = True
        print("will perform MiniBatchKMeans")
    elif o == "-b":
        batch_size = int(a)
        print("will use", batch_size, "as batch_size")
    elif o == "-I":
        n_init = int(a)
        print("will do", n_init, "initializations")
    elif o == "-n":
        n_clusters = int(a)
        print("will use", n_clusters, "clusters")
    elif o == "-p":
        n_pca = float(a)
        if n_pca <= 0.0:
            print("will estimate the number of pca components for dimensionality reduction")
        elif n_pca < 1.0:
            print(
                "will use enough pca components to explain at least",
                100.0 * n_pca,
                "% of the variance",
            )
        else:
            n_pca = int(a)
            print("will use", n_pca, "pca components for dimensionality reduction")
    else:
        assert False, "unhandled option"

# check that required arguments are set
if outfilename is None:
    print("outfile must be set")
    usage()
    sys.exit()

# check to make sure groups and subsegments are in agreement
groups = {}
if len(subsegs) > 0:
    print("subsegs:", subsegs)
    if subseggroupIDs is None:
        subseggroupIDs = []
        for i in range(len(subsegs)):
            subseggroupIDs.append("group_" + str(i))
    print("subseggroupIDs:", subseggroupIDs)
    if len(subsegs) != len(subseggroupIDs):
        print("number of subsegment group IDs must match number of subsegs")
        sys.exit()
    for segnum in range(len(subseggroupIDs)):
        try:
            groups[subseggroupIDs[segnum]]["segnum"].append(int(segnum))
        except KeyError:
            groups[subseggroupIDs[segnum]] = {}
            groups[subseggroupIDs[segnum]]["seglen"] = []
            groups[subseggroupIDs[segnum]]["segstart"] = []
            groups[subseggroupIDs[segnum]]["segnum"] = [int(segnum)]
        groups[subseggroupIDs[segnum]]["seglen"].append(subsegs[segnum])
        groups[subseggroupIDs[segnum]]["segstart"].append(int(np.sum(subsegs[0:segnum])))
    for key in groups:
        groupsegs = groups[key]["seglen"]
        if not all(x == groupsegs[0] for x in groupsegs):
            print("all subsegments in a group must have the same length")
            sys.exit()
    print(groups)

if sampletime is None:
    print("sampletime must be set")
    usage()
    sys.exit()

if timenormmethod == "none":
    print("will not normalize timecourses")
elif timenormmethod == "pctnorm":
    print("will normalize timecourses to percentage of mean")
elif timenormmethod == "stdnorm":
    print("will normalize timecourses to standard deviation of 1.0")
elif timenormmethod == "varnorm":
    print("will normalize timecourses to variance of 1.0")
elif timenormmethod == "ppnorm":
    print("will normalize timecourses to p-p deviation of 1.0")

# save the command line
tide_io.writevec([" ".join(sys.argv)], outfilename + "_commandline.txt")

# read in the files and get everything trimmed to the right length
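# the first sample to keep: the requested start time converted to samples, plus any skipped points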
startpoint = max([int(starttime * Fs), 0]) + skippts
if len(infilename) == 1:
    # each row is a timecourse, each column is a timepoint
    print("processing single input file")
    matrixoutput = True
    # skippts is already folded into startpoint above, so no extra slicing is needed here
    inputdata = tide_io.readvecs(infilename[0])
    if verbose:
        print("input data shape is ", inputdata.shape)
    origdims = inputdata.shape
    numpoints = inputdata.shape[1]
    endpoint = min([startpoint + int(duration * Fs), numpoints])
    trimmeddata = inputdata[:, startpoint:endpoint]
elif len(infilename) == 2:
    print("processing two input files")
    inputdata1 = tide_io.readvec(infilename[0])
    inputdata2 = tide_io.readvec(infilename[1])
    endpoint1 = min([startpoint + int(duration * Fs), int(len(inputdata1)), int(len(inputdata2))])
    endpoint2 = min([int(duration * Fs), int(len(inputdata1)), int(len(inputdata2))])
    # trim both timecourses to a common length so the array assignments conform
    numpoints = min([endpoint1 - startpoint, endpoint2])
    origdims = (2, numpoints)
    trimmeddata = np.zeros((2, numpoints), dtype="float")
    trimmeddata[0, :] = inputdata1[startpoint : startpoint + numpoints]
    trimmeddata[1, :] = inputdata2[0:numpoints]
else:
    print(
        "showstxcorr requires 1 multicolumn timecourse file or two single column timecourse files as input"
    )
    usage()
    sys.exit()

# band limit the regressors if that is needed
if theprefilter.gettype() != "None":
    if verbose:
        print("filtering to ", theprefilter.gettype(), " band")
else:
    if verbose:
        print("no prefiltering applied")

thedims = trimmeddata.shape
print("original file dimensions:", origdims)
print("trimmed file dimensions:", thedims)
n_features = thedims[0]
n_samples = thedims[1]
if segsize < 0:
    segsize = n_samples
    subsegs.append(segsize)
print(
    "input dataset has",
    n_features,
    "features and",
    n_samples,
    "samples in segments of size",
    segsize,
)
if len(subsegs) > 1:
    print("    segment is broken into", len(subsegs), "subsegments of length", subsegs)
reformdata = np.reshape(trimmeddata, (n_features, n_samples))
if n_samples % segsize > 0:
    print(
        "segment size (",
        segsize,
        ") is not an even divisor of the total length (",
        n_samples,
        ")- exiting",
    )
    sys.exit()
else:
    numsegs = int(n_samples // segsize)

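# preprocess each feature one segment at a time: detrend (unless disabled), apply the
# selected normalization, then band limit with the prefilter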
for feature in range(n_features):
    if verbose:
        print("preprocessing feature", feature)
    for segment in range(numsegs):
        subsegstart = segment * segsize
        for subseglen in subsegs:
            if detrendorder > 0:
                segdata = tide_fit.detrend(
                    reformdata[feature, subsegstart : subsegstart + subseglen]
                )
            else:
                segdata = reformdata[feature, subsegstart : subsegstart + subseglen]

            if timenormmethod == "none":
                segnorm = segdata - np.mean(segdata)
            elif timenormmethod == "pctnorm":
                segnorm = tide_math.pcnormalize(segdata)
            elif timenormmethod == "varnorm":
                segnorm = tide_math.varnormalize(segdata)
            elif timenormmethod == "stdnorm":
                segnorm = tide_math.stdnormalize(segdata)
            elif timenormmethod == "ppnorm":
                segnorm = tide_math.ppnormalize(segdata)
            else:
                segnorm = segdata

            reformdata[feature, subsegstart : subsegstart + subseglen] = theprefilter.apply(
                Fs, segnorm
            )
            subsegstart += subseglen
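# transpose to the scikit-learn orientation: rows are observations (timepoints),
# columns are features; nan_to_num guards against bad values from the preprocessing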
X = np.nan_to_num(np.transpose(reformdata))

if standardscale:
    X = StandardScaler().fit_transform(X)

if preprocessingtype == "pca":
    print("running PCA")
    print("shape going in:", X.shape)
    if trainedmodelroot is None:
        if n_pca <= 0:
            thepca = PCA(n_components="mle", svd_solver="full").fit(X)
        else:
            thepca = PCA(n_components=n_pca).fit(X)

        # save the model
        joblib.dump(thepca, outfilename + "_pca.joblib")
    else:
        modelfilename = trainedmodelroot + "_pca.joblib"
        print("reading PCA from", modelfilename)
        try:
            thepca = joblib.load(modelfilename)
        except Exception as ex:
            template = (
                "An exception of type {0} occurred when trying to open {1}. Arguments:\n{2!r}"
            )
            message = template.format(type(ex).__name__, modelfilename, ex.args)
            print(message)
            sys.exit()

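    # project the data into PCA space and back out again: X stays in the original
    # feature space but retains only the variance captured by the kept components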
    thetransform = thepca.transform(X)
    X = thepca.inverse_transform(thetransform)
    print("shape coming out:", X.shape)
    for i in range(thepca.n_components_):
        print(
            "component",
            i,
            "explained variance:",
            thepca.explained_variance_[i],
            "explained variance %:",
            100.0 * thepca.explained_variance_ratio_[i],
        )
    tide_io.writenpvecs(thepca.components_, outfilename + "_pcacomponents.txt")
    tide_io.writenpvecs(
        np.transpose(thepca.components_), outfilename + "_pcacomponents_transpose.txt"
    )
elif preprocessingtype == "ica":
    print("running FastICA")
    # fractional/estimated component counts only make sense for PCA; fall back to
    # letting FastICA choose all components
    if n_pca <= 1.0:
        n_pca = None
    if trainedmodelroot is None:
        theica = FastICA(n_components=n_pca, algorithm="deflation").fit(X)

        # save the model
        joblib.dump(theica, outfilename + "_ica.joblib")
    else:
        modelfilename = trainedmodelroot + "_ica.joblib"
        print("reading ICA from", modelfilename)
        try:
            theica = joblib.load(modelfilename)
        except Exception as ex:
            template = (
                "An exception of type {0} occurred when trying to open {1}. Arguments:\n{2!r}"
            )
            message = template.format(type(ex).__name__, modelfilename, ex.args)
            print(message)
            sys.exit()

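    # as with PCA, round-trip through ICA space so X retains only the kept components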
    thetransform = theica.transform(X)
    X = theica.inverse_transform(thetransform)
    tide_io.writenpvecs(theica.components_, outfilename + "_icacomponents.txt")
    tide_io.writenpvecs(
        np.transpose(theica.components_), outfilename + "_icacomponents_transpose.txt"
    )

tide_io.writenpvecs(reformdata, outfilename + "_preprocessed.txt")
if preproconly:
    print("preprocessing done - quitting")
    sys.exit()

if clustertype == "kmeans":
    print("setting up kmeans")
    if trainedmodelroot is None:
        print("training model")
        if minibatch:
            kmeans = MiniBatchKMeans(
                n_clusters=n_clusters, batch_size=batch_size, max_iter=max_iter
            ).fit(X)
        else:
            kmeans = KMeans(n_clusters=n_clusters, max_iter=max_iter, n_init=n_init).fit(X)

        # save the model
        joblib.dump(kmeans, outfilename + "_kmeans.joblib")
    else:
        modelfilename = trainedmodelroot + "_kmeans.joblib"
        print("reading kmeans model from", modelfilename)
        try:
            kmeans = joblib.load(modelfilename)
        except Exception as ex:
            template = (
                "An exception of type {0} occurred when trying to open {1}. Arguments:\n{2!r}"
            )
            message = template.format(type(ex).__name__, modelfilename, ex.args)
            print(message)
            sys.exit()

    theclusters = np.transpose(kmeans.cluster_centers_)
    thestatelabels = kmeans.predict(X)
    # thestatelabels = kmeans.labels_
    print("thestatelabels shape", thestatelabels.shape)
    print("kmeans done")
    tide_io.writenpvecs(theclusters, outfilename + "_clustercenters.txt")

    # make normalized clusters
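    # z-score each cluster center across its features so centers can be compared
    # independent of overall amplitude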
    thenormclusters = theclusters * 0.0
    themeans = np.mean(theclusters, axis=0)
    thestds = np.std(theclusters, axis=0)
    print("themeans:", themeans)
    print("thestds:", thestds)
    print("shape:", theclusters.shape)
    for i in range(theclusters.shape[1]):
        thenormclusters[:, i] = (theclusters[:, i] - themeans[i]) / thestds[i]
    tide_io.writenpvecs(thenormclusters, outfilename + "_norm_clustercenters.txt")

    # save the states
    tide_io.writenpvecs(thestatelabels, outfilename + "_statelabels.txt")

    # find most important features
    print("finding most important features")
    # rfe = RFE(kmeans, 10)
    # rfe.fit(X, thestatelabels)
    # print(rfe.support_)
    # print(rfe.ranking_)

    print(
        "calling SelectPercentiles with X and y of dimensions",
        X.shape,
        thestatelabels.shape,
    )
    selector = SelectPercentile(f_classif, percentile=10)
    selector.fit(X, thestatelabels)
    print(selector.get_params())
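    # convert the f_classif p-values to -log10 scores, normalized to a 0-1 range for the bar plot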
    X_indices = np.arange(X.shape[-1])
    scores = -np.nan_to_num(np.log10(np.nan_to_num(selector.pvalues_)))
    scores /= scores.max()
    sortedscores = np.sort(np.nan_to_num(selector.scores_))[::-1]
    print(sortedscores)
    if display:
        plt.bar(
            X_indices - 0.45,
            scores,
            width=0.2,
            label=r"Univariate score ($-Log(p_{value})$)",
            color="darkorange",
        )
        print(selector.get_support(indices=True))
        fig, ax = plt.subplots(1, 1)
        plt.plot(sortedscores)
        plt.show()

    # now do some stats!
    thesilavgs, thesilclusterstats = capcalc_utils.silhouette_test(
        X, kmeans, n_clusters, numsegs, segsize, summaryonly
    )
    tide_io.writenpvecs(thesilavgs, outfilename + "_silhouettesegmentstats.txt")

    silinfo = []
    for state in range(n_clusters):
        silinfo.append([])
    print("shape going in:", thestatelabels.shape)
    statelabelsbysegment = np.reshape(thestatelabels, (-1, segsize))
    print("shape coming out:", statelabelsbysegment.shape)
    meaninstate = np.zeros((n_clusters, segsize), dtype="float")
    stdinstate = np.zeros((n_clusters, segsize), dtype="float")

    # do the subsegment summaries
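    # for each group: mark the TRs spent in each state in every subsegment instance,
    # then average across instances to get the in-state probability (and its std)
    # at each within-segment timepoint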
    for key in groups:
        groups[key]["meaninstate"] = np.zeros(
            (n_clusters, groups[key]["seglen"][0]), dtype="float"
        )
        groups[key]["stdinstate"] = np.zeros((n_clusters, groups[key]["seglen"][0]), dtype="float")
        for state in range(n_clusters):
            tcbyseg = []
            for seginstance in range(len(groups[key]["segnum"])):
                startpos = groups[key]["segstart"][seginstance]
                endpos = startpos + groups[key]["seglen"][seginstance]
                tcbyseg.append(np.where(statelabelsbysegment[:, startpos:endpos] == state, 1, 0))
            groups[key]["meaninstate"][state, :] = np.mean(np.concatenate(tcbyseg, axis=0), axis=0)
            groups[key]["stdinstate"][state, :] = np.std(np.concatenate(tcbyseg, axis=0), axis=0)
        tide_io.writenpvecs(
            groups[key]["meaninstate"],
            outfilename + "_" + str(key) + "_meaninstate.txt",
        )
        tide_io.writenpvecs(
            groups[key]["stdinstate"], outfilename + "_" + str(key) + "_stdinstate.txt"
        )
    allstatestats = []
    alllenlists = []
    for i in range(n_clusters):
        alllenlists.append([])
    for segment in range(numsegs):
        thesestatelabels = thestatelabels[segment * segsize : (segment + 1) * segsize]

        outputaffine = np.eye(4)
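        # statestats returns the state-to-state transition count matrix, the per-state
        # dwell statistics, and the per-state run length lists (after patching runs
        # shorter than minout/minhold)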
        rawtransmat, thestats, lenlist = capcalc_utils.statestats(
            thesestatelabels, n_clusters, 0, minout=minoutlength, minhold=minholdlength
        )
        allstatestats.append(thestats)
        for i in range(n_clusters):
            alllenlists[i] += lenlist[i]
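        # normalize each row of the transition count matrix to probabilities;
        # offdiagtransmat zeroes self-transitions first, giving the distribution of exits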
        normtransmat = 1.0 * rawtransmat
        for i in range(n_clusters):
            if np.sum(rawtransmat[i, :]) > 0.0:
                normtransmat[i, :] /= np.sum(rawtransmat[i, :])
        offdiagtransmat = 1.0 * rawtransmat
        for i in range(n_clusters):
            offdiagtransmat[i, i] = 0.0
            if np.sum(offdiagtransmat[i, :]) > 0.0:
                offdiagtransmat[i, :] /= np.sum(offdiagtransmat[i, :])
        init_img = nib.Nifti1Image(normtransmat, outputaffine)
        init_hdr = init_img.header
        init_sizes = init_hdr["pixdim"]
        tide_io.savetonifti(
            np.transpose(rawtransmat),
            init_hdr,
            outfilename + "_seg_" + str(segment).zfill(4) + "_rawtransmat",
        )
        tide_io.savetonifti(
            np.transpose(normtransmat),
            init_hdr,
            outfilename + "_seg_" + str(segment).zfill(4) + "_normtransmat",
        )
        tide_io.savetonifti(
            np.transpose(offdiagtransmat),
            init_hdr,
            outfilename + "_seg_" + str(segment).zfill(4) + "_offdiagtransmat",
        )

        # write as text as well
        rows = []
        cols = []
        for i in range(n_clusters):
            rows.append("from state " + str(i + 1))
            cols.append("to state " + str(i + 1))
        df = pd.DataFrame(data=rawtransmat, columns=cols)
        df.insert(0, "sources", pd.Series(rows))
        df.to_csv(
            outfilename + "_seg_" + str(segment).zfill(4) + "_rawtransmat.csv",
            index=False,
        )
        df = pd.DataFrame(data=normtransmat, columns=cols)
        df.insert(0, "sources", pd.Series(rows))
        df.to_csv(
            outfilename + "_seg_" + str(segment).zfill(4) + "_normtransmat.csv",
            index=False,
        )
        df = pd.DataFrame(data=offdiagtransmat, columns=cols)
        df.insert(0, "sources", pd.Series(rows))
        df.to_csv(
            outfilename + "_seg_" + str(segment).zfill(4) + "_offdiagtransmat.csv",
            index=False,
        )
        # rawtransmat files are an n_clusters by n_clusters matrix with the total number of transitions from each state to each other state.
        # normtransmat files are an n_clusters by n_clusters matrix with the fraction of transitions from each state to each other state (each row sums to 1).

        cols = [
            "% TRs in state",
            "Number of runs in state",
            "Total TRs in state",
            "Min run (TRs)",
            "Max run (TRs)",
            "Mean run (TRs)",
            "Median run (TRs)",
            "StdDev run (TRs)",
        ]
        df = pd.DataFrame(data=thestats, columns=cols)
        df.to_csv(
            outfilename + "_seg_" + str(segment).zfill(4) + "_statestats.csv",
            index=False,
        )
        # tide_io.writenpvecs(np.transpose(thestats), outfilename + '_seg_' + str(segment).zfill(4) + '_statestats.txt')
        thetimestats = 1.0 * thestats
        thetimestats[:, 2:] *= sampletime
        cols = [
            "% Seconds in state",
            "Number of runs in state",
            "Total seconds in state",
            "Min run (sec)",
            "Max run (sec)",
            "Mean run (sec)",
            "Median run (sec)",
            "StdDev run (sec)",
        ]
        df = pd.DataFrame(data=thetimestats, columns=cols)
        df.to_csv(
            outfilename + "_seg_" + str(segment).zfill(4) + "_statetimestats.csv",
            index=False,
        )
        # tide_io.writenpvecs(np.transpose(thetimestats), outfilename + '_seg_' + str(segment).zfill(4) + '_statetimestats.txt')

        tide_io.writenpvecs(
            thesestatelabels,
            outfilename + "_seg_" + str(segment).zfill(4) + "_statelabels.txt",
        )
        print("Segment %d average silhouette Coefficient: %0.3f" % (segment, thesilavgs[segment]))
        for state in range(n_clusters):
            tc = np.where(thesestatelabels == state, 1, 0)
            tide_io.writenpvecs(
                tc,
                outfilename
                + "_seg_"
                + str(segment).zfill(4)
                + "_instate_"
                + str(state).zfill(2)
                + ".txt",
            )
        if not summaryonly:
            cols = ["Mean", "Median", "Min", "Max"]
            df = pd.DataFrame(data=np.transpose(thesilclusterstats[segment, :, :]), columns=cols)
            df.to_csv(
                outfilename + "_seg_" + str(segment).zfill(4) + "_silhouetteclusterstats.csv",
                index=False,
            )
            # tide_io.writenpvecs(thesilclusterstats[segment, :, :],
            #             outfilename + '_seg_' + str(segment).zfill(4) + '_silhouetteclusterstats.txt')

        if not summaryonly:
            # silhouette cluster stats are only computed when the silhouette test is run
            for state in range(n_clusters):
                if thestats[state, 2] > 0:
                    silinfo[state].append(thesilclusterstats[segment, 0, state])

    # now generate some summary information
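    # pool state run lengths across all segments and write one run-length histogram per state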
    themaxlen = 0
    for i in range(n_clusters):
        themaxlen = int(np.max([themaxlen, np.max(alllenlists[i])]))
    for i in range(n_clusters):
        thishist = tide_stats.makeandsavehistogram(
            np.array(alllenlists[i]),
            themaxlen,
            0,
            outfilename + "_" + str(i).zfill(2) + "_lenhist",
            therange=[1, themaxlen],
        )
    silavgs = []
    if not summaryonly:
        for state in range(n_clusters):
            silavgs.append(np.mean(np.asarray(silinfo[state], dtype="float")))
        tide_io.writenpvecs(
            np.asarray(silavgs, dtype="float"),
            outfilename + "_overallsilhouettemean.txt",
        )
    pctarray = np.asarray(allstatestats[:], dtype="float")
    cols = [
        "% TRs in state",
        "Number of runs in state",
        "Total TRs in state",
        "Min run (TRs)",
        "Max run (TRs)",
        "Mean run (TRs)",
        "Median run (TRs)",
        "StdDev run (TRs)",
    ]
    df = pd.DataFrame(data=np.mean(pctarray, axis=0), columns=cols)
    df.to_csv(
        outfilename + "_overallmeanstats.csv",
        index=False,
    )
    # tide_io.writenpvecs(np.transpose(np.mean(pctarray, axis=0)), outfilename + '_overallmeanstats.txt')

    if doGBR:
        clf = GradientBoostingRegressor().fit(X, thestatelabels)
        print("GBR fitting score is:", clf.score(X, thestatelabels))
        tide_io.writenpvecs(
            np.reshape(clf.feature_importances_, (n_features, 1)),
            outfilename + "_featureimportances.txt",
        )

elif clustertype == "dbscan":
    if trainedmodelroot is None:
        db = DBSCAN(eps=eps, min_samples=min_samples, n_jobs=-1).fit(X)

        # save the model
        joblib.dump(db, outfilename + "_dbscan.joblib")
    else:
        modelfilename = trainedmodelroot + "_dbscan.joblib"
        print("reading dbscan model from", modelfilename)
        try:
            db = joblib.load(modelfilename)
        except Exception as ex:
            template = (
                "An exception of type {0} occurred when trying to open {1}. Arguments:\n{2!r}"
            )
            message = template.format(type(ex).__name__, modelfilename, ex.args)
            print(message)
            sys.exit()

        # DBSCAN has no predict() method; refit with the stored hyperparameters so the
        # labels correspond to the current data
        db.fit(X)

    print("dbscan done")

    # core_samples_mask = np.zeros_like(db.labels_, dtype=bool)
    # core_samples_mask[db.core_sample_indices_] = True

    thestatelabels = db.labels_
    print(thestatelabels)
    print("thestatelabels shape", thestatelabels.shape)
    tide_io.writenpvecs(thestatelabels, outfilename + "_statelabels.txt")

    print("core_sample_indices:", db.core_sample_indices_)
    core_centers = np.transpose(X[db.core_sample_indices_, :])
    tide_io.writenpvecs(core_centers, outfilename + "_core_centers.txt")

    # Number of clusters in labels, ignoring noise if present.
    n_clusters_ = len(set(thestatelabels)) - (1 if -1 in thestatelabels else 0)
    print("Estimated number of clusters: %d" % n_clusters_)

    methodname = "dbscan_" + str(n_clusters).zfill(2)

elif clustertype == "hdbscan":
    if trainedmodelroot is None:
        hdb = hdbs.HDBSCAN(
            min_samples=min_samples,
            alpha=alpha,
        ).fit(X)

        # save the model
        joblib.dump(hdb, outfilename + "_hdbscan.joblib")
    else:
        modelfilename = trainedmodelroot + "_hdbscan.joblib"
        print("reading hdbscan model from", modelfilename)
        try:
            hdb = joblib.load(modelfilename)
        except Exception as ex:
            template = (
                "An exception of type {0} occurred when trying to open {1}. Arguments:\n{2!r}"
            )
            message = template.format(type(ex).__name__, modelfilename, ex.args)
            print(message)
            sys.exit()

        # HDBSCAN has no predict() method; refit with the stored hyperparameters so the
        # labels correspond to the current data
        hdb.fit(X)

    thestatelabels = hdb.labels_
    print(thestatelabels)

    # Number of clusters in labels, ignoring noise if present.
    n_clusters_ = len(set(thestatelabels)) - (1 if -1 in thestatelabels else 0)

    print("Estimated number of clusters: %d" % n_clusters_)
    methodname = "hdbscan_" + str(n_clusters).zfill(2)

else:
    print("unknown clustering type")
    sys.exit()
