#!/usr/bin/env python
#
#   Copyright 2016 Blaise Frederick
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
#
#
#       $Author: frederic $
#       $Date: 2016/06/14 12:04:50 $
#       $Id: linfit,v 1.4 2016/06/14 12:04:50 frederic Exp $
#
import argparse
import os
import sys

import numpy as np
import rapidtide.io as tide_io
import rapidtide.miscmath as tide_math
import rapidtide.workflows.parser_funcs as pf
import scipy.sparse as ss
from pylab import *
from scipy.cluster.hierarchy import dendrogram
from sklearn.cluster import DBSCAN, AgglomerativeClustering, KMeans, MiniBatchKMeans
from sklearn.decomposition import PCA, FastICA
from sklearn.manifold import TSNE
from sklearn.metrics import davies_bouldin_score, silhouette_score
from sklearn.neighbors import kneighbors_graph
from sklearn.preprocessing import RobustScaler, StandardScaler

# Clustering methods that are always available (all provided by scikit-learn).
clusteringmethods = ["kmeans", "agglomerative", "dbscan"]

# hdbscan is an optional dependency: only offer it when it can be imported.
try:
    import hdbscan as hdbs

    clusteringmethods.append("hdbscan")
except Exception:
    # Deliberately best-effort: any failure to import the optional package
    # just leaves "hdbscan" out of the list.  Catching Exception (rather than
    # using a bare except) lets KeyboardInterrupt and SystemExit propagate.
    pass


def plot_dendrogram(model, **kwargs):
    """Draw a dendrogram from the merge tree stored in ``model.children_``.

    The model's merge distances and cluster sizes are not used: the linkage
    matrix is filled with placeholder values (merge heights 0, 1, 2, ... and
    observation counts 2, 3, 4, ...), so only the tree topology is meaningful.

    Parameters
    ----------
    model : object
        Any object exposing a ``children_`` array with one row of two child
        indices per merge.
    **kwargs
        Passed straight through to ``scipy.cluster.hierarchy.dendrogram``.
    """
    merges = model.children_
    nummerges = merges.shape[0]

    # Assemble the four linkage columns: child 1, child 2, height, count
    linkage_matrix = np.empty((nummerges, 4), dtype=float)
    linkage_matrix[:, :2] = merges
    # Uniformly increasing placeholder heights
    linkage_matrix[:, 2] = np.arange(nummerges)
    # Placeholder observation counts, one more at each merge level
    linkage_matrix[:, 3] = np.arange(2, nummerges + 2)

    dendrogram(linkage_matrix, **kwargs)


def save_sparse_csr(filename, array):
    """Write the pieces of a CSR sparse matrix to an ``.npz`` archive.

    Parameters
    ----------
    filename : str or file-like
        Destination handed to ``np.savez``.
    array : scipy.sparse.csr_matrix
        Matrix to store; its ``data``, ``indices``, ``indptr`` and ``shape``
        attributes are saved under those same names.
    """
    pieces = {
        "data": array.data,
        "indices": array.indices,
        "indptr": array.indptr,
        "shape": array.shape,
    }
    np.savez(filename, **pieces)


def load_sparse_csr(filename):
    """Rebuild a CSR sparse matrix from an archive written by ``save_sparse_csr``.

    Parameters
    ----------
    filename : str or file-like
        Source handed to ``np.load``.  The archive must contain the arrays
        ``data``, ``indices``, ``indptr`` and ``shape``.

    Returns
    -------
    scipy.sparse.csr_matrix
        The reconstructed matrix.
    """
    # Use the archive as a context manager so the underlying file is closed
    # as soon as the arrays have been read.
    with np.load(filename) as loader:
        # The shape is stored as an array; hand scipy a plain tuple of ints.
        theshape = tuple(int(thedim) for thedim in loader["shape"])
        return ss.csr_matrix(
            (loader["data"], loader["indices"], loader["indptr"]), shape=theshape
        )


def N2one(loc, strides):
    """Convert an N-dimensional location to a flat index.

    Parameters
    ----------
    loc : sequence of int
        Coordinates, slowest-varying axis first.
    strides : sequence of int
        Cumulative strides ordered from the smallest to the largest, i.e.
        ``strides[0]`` is the stride of the second-to-last axis.  The last
        axis always has stride 1, so ``len(strides)`` must be at least
        ``len(loc) - 1``.

    Returns
    -------
    int
        ``loc[-1]`` plus the sum of each leading coordinate times its stride.
    """
    ndims = len(loc)
    retval = 0
    for i in range(ndims - 1):
        # Axis i pairs with the (ndims - 2 - i)th stride: the first axis uses
        # the largest stride.  (The previous "1 - i" was only right for 3D.)
        retval += loc[i] * strides[ndims - 2 - i]
    retval += loc[-1]
    return retval


def three2one(loc, strides):
    """Convert a 3D location ``[x, y, z]`` to a flat index.

    ``strides[1]`` is the stride of the first axis and ``strides[0]`` the
    stride of the second axis; the third axis has stride 1.
    """
    x, y, z = loc[0], loc[1], loc[2]
    planestride, rowstride = strides[1], strides[0]
    return x * planestride + y * rowstride + z


def one2N(index, strides):
    """Convert a flat index to an N-dimensional location.

    Parameters
    ----------
    index : int
        Flat index to decompose.
    strides : sequence of int
        Cumulative strides ordered from the smallest to the largest, i.e.
        ``strides[0]`` is the stride of the second-to-last axis.  The last
        axis always has stride 1.

    Returns
    -------
    list of int
        ``len(strides) + 1`` coordinates, slowest-varying axis first.
    """
    coords = []
    remainder = index
    # Peel off the coordinates from the largest stride down to the smallest.
    # (The previous "strides[i - 1]" indexing visited the strides in the wrong
    # order whenever there were more than two of them.)
    for thestride in reversed(strides):
        thecoord, remainder = divmod(remainder, thestride)
        coords.append(int(thecoord))
    # Whatever is left is the position along the fastest-varying axis
    coords.append(int(remainder))
    return coords


def one2three(index, strides):
    """Convert a flat index to a 3D location ``[x, y, z]``.

    ``strides[1]`` is the stride of the first axis and ``strides[0]`` the
    stride of the second axis; the third axis has stride 1.
    """
    planestride = strides[1]
    rowstride = strides[0]

    x = int(np.floor(index / planestride))
    # Offset of the point within its x plane
    inplane = index - x * planestride
    y = int(np.floor(inplane / rowstride))
    z = int(np.mod(index, rowstride))
    return [x, y, z]


def mkconnectivity(ptlist, radius, theshape, dodiag=False, dims=None):
    """Build a sparse connectivity matrix linking points that lie close together.

    Parameters
    ----------
    ptlist : sequence of int
        Flat indices (C order, last axis varying fastest) of the points to
        connect, within a grid of shape ``theshape``.
    radius : float
        Two points are connected when their squared Euclidean distance in
        grid coordinates is strictly less than ``radius ** 2``.  Must be at
        least 1.
    theshape : sequence of int
        Shape of the grid the flat indices refer to (any number of axes).
    dodiag : bool, optional
        Accepted for interface compatibility; currently unused.
    dims : optional
        Accepted for interface compatibility; currently unused.

    Returns
    -------
    scipy.sparse.csr_matrix
        Integer matrix of shape ``(len(ptlist), len(ptlist))`` with a 1 at
        ``[i, j]`` when points i and j are connected.  Every point is
        connected to itself.

    Notes
    -----
    Exits the program (``sys.exit``) when ``radius`` is less than 1.  Progress
    is printed every 1000 rows, and the partial matrix is saved to
    ``connectivity_checkpoint`` every 20000 rows.
    """
    # convert the 1d index list to Nd locations, one row of coordinates per
    # point.  np.unravel_index handles any number of axes, which replaces the
    # hand-rolled 3D/ND dispatch (whose 3D special-case test was inverted).
    flatindices = np.asarray(ptlist, dtype=np.intp).reshape(-1)
    gridshape = tuple(int(thedim) for thedim in theshape)
    ptlistnp = np.column_stack(np.unravel_index(flatindices, gridshape))
    numpts = flatindices.shape[0]

    # now make the connectivity matrix for nearest neighbors with given radius
    print("len ptlist =", numpts)
    sm = ss.lil_matrix((numpts, numpts), dtype="int")

    # sanity check first
    rangeinpts = int(np.floor(radius))
    if rangeinpts < 1:
        print("radius too small - no points in connectivity matrix")
        sys.exit()

    # now iterate over every row, matching it against all points at once
    nonzeroelems = 0
    radsq = radius * radius
    reportstep = 1000
    checkpointstep = 20000
    for ridx in range(numpts):
        if ridx % reportstep == 0:
            print(
                "row:",
                ridx,
                ", nonzero elements:",
                nonzeroelems,
                ", loc:",
                ptlistnp[ridx].tolist(),
            )
        # skip row 0: there is nothing worth checkpointing yet
        if ridx > 0 and ridx % checkpointstep == 0:
            print("checkpoint file")
            save_sparse_csr("connectivity_checkpoint", sm.tocsr())
        distarray = np.sum(np.square(ptlistnp - ptlistnp[ridx]), axis=1)
        # NOTE(review): the comparison is strict, so radius=1.0 connects each
        # point only to itself - confirm whether "<=" was intended.
        matchlocs = np.where(distarray < radsq)[0]
        sm[ridx, matchlocs] = 1
        nonzeroelems += len(matchlocs)
    print(nonzeroelems, "nonzero elements")

    return sm.tocsr()


# Default values for the command line options defined in _get_parser()
DEFAULT_NCLUSTERS = 8  # --n_clusters
DEFAULT_NCOMPONENTS = 8  # --n_components
DEFAULT_MINCLUSTERSIZE = 50  # --min_cluster_size
DEFAULT_REPEATS = 1  # --repeats
DEFAULT_NINIT = 100  # --n_init


def _get_parser():
    """Build the argparse parser for the clusternifti command line.

    Returns
    -------
    argparse.ArgumentParser
        Parser with two positional arguments (``datafilename`` and
        ``outputrootname``) plus options controlling masking, prescaling,
        dimensionality reduction and the clustering algorithm.
    """
    # get the command line parameters
    parser = argparse.ArgumentParser(
        prog="clusternifti",
        description="Clusters the voxels of a 4 dimensional NIFTI file.",
        usage="%(prog)s  datafile outputroot",
    )

    # required arguments
    parser.add_argument(
        "datafilename", help="The name of the 4 dimensional nifti file to cluster."
    )
    parser.add_argument("outputrootname", help="The root of the output file names.")

    # masking
    parser.add_argument(
        "--maskfile",
        dest="maskfilename",
        type=lambda x: pf.is_valid_file(parser, x),
        metavar="MASK",
        help="Only process voxels within the 3D mask MASK.",
        default=None,
    )
    parser.add_argument(
        "--maskthresh",
        dest="maskthresh",
        action="store",
        type=lambda x: pf.is_float(parser, x),
        metavar="THRESH",
        help="Mask cutoff value.  Default is 0.5",
        default=0.5,
    )

    # preprocessing
    parser.add_argument(
        "--datareduction",
        dest="datareduction",
        action="store",
        type=str,
        choices=["ica", "pca"],
        help=(
            "Select an initial data reduction step. "
            "Default is not to perform dimensionality reduction."
        ),
        default=None,
    )
    parser.add_argument(
        "--n_components",
        dest="n_pca",
        action="store",
        type=lambda x: pf.is_float(parser, x),
        metavar="NCOMPONENTS",
        help=f"Number of components to use for dimensionality reduction.  Default is {DEFAULT_NCOMPONENTS}.",
        default=DEFAULT_NCOMPONENTS,
    )
    parser.add_argument(
        "--scaler",
        dest="scaler",
        action="store",
        type=str,
        choices=["robust", "standard"],
        help=("Type of prescaler to use. Default is None."),
        default=None,
    )
    parser.add_argument(
        "--normalize",
        dest="normalize",
        action="store_true",
        help=("Normalize each vector along the time dimension."),
        default=False,
    )

    # clustering
    parser.add_argument(
        "--clustertype",
        dest="clustertype",
        action="store",
        type=str,
        choices=clusteringmethods,
        help=("Select clustering type.  Default is kmeans. "),
        default="kmeans",
    )
    parser.add_argument(
        "--linkage",
        dest="linkage",
        action="store",
        type=str,
        choices=["ward"],
        help=("Select linkage type.  Default is ward. "),
        default="ward",
    )
    # The default must be one of the choices: argparse does not validate
    # defaults, so a misspelled default would be passed on unchecked.
    parser.add_argument(
        "--affinity",
        dest="affinity",
        action="store",
        type=str,
        choices=["euclidean"],
        help=("Select affinity type.  Default is euclidean. "),
        default="euclidean",
    )
    parser.add_argument(
        "--n_clusters",
        dest="n_clusters",
        action="store",
        type=lambda x: pf.is_int(parser, x),
        metavar="NCLUSTERS",
        help=f"Number of clusters.  Default is {DEFAULT_NCLUSTERS}.",
        default=DEFAULT_NCLUSTERS,
    )
    parser.add_argument(
        "--repeats",
        dest="repeats",
        action="store",
        type=lambda x: pf.is_int(parser, x),
        metavar="REPS",
        help=f"Number of times to perform clustering.  Default is {DEFAULT_REPEATS}.",
        default=DEFAULT_REPEATS,
    )
    parser.add_argument(
        "--display",
        dest="display",
        action="store_true",
        help=("Display initial tSNE map."),
        default=False,
    )
    parser.add_argument(
        "--min_samples",
        dest="min_samples",
        action="store",
        type=lambda x: pf.is_int(parser, x),
        metavar="MIN_SAMPLES",
        help="Minimum number of samples.  Default is 100",
        default=100,
    )
    parser.add_argument(
        "--eps",
        dest="eps",
        action="store",
        type=lambda x: pf.is_float(parser, x),
        metavar="EPS",
        help="Epsilon.  Default is 0.3",
        default=0.3,
    )
    parser.add_argument(
        "--alpha",
        dest="alpha",
        action="store",
        type=lambda x: pf.is_float(parser, x),
        metavar="ALPHA",
        help="Alpha.  Default is 1.0",
        default=1.0,
    )
    parser.add_argument(
        "--min_cluster_size",
        dest="min_cluster_size",
        action="store",
        type=lambda x: pf.is_int(parser, x),
        metavar="MIN_CLUSTER_SIZE",
        help=f"Minimum cluster size.  Default is {DEFAULT_MINCLUSTERSIZE}",
        default=DEFAULT_MINCLUSTERSIZE,
    )
    parser.add_argument(
        "--nominibatch",
        dest="minibatch",
        action="store_false",
        help=("Disable minibatches for kmeans clustering."),
        default=True,
    )
    parser.add_argument(
        "--batch_size",
        dest="batch_size",
        action="store",
        type=lambda x: pf.is_int(parser, x),
        metavar="BATCH_SIZE",
        help="Minibatch size.  Default is 1000",
        default=1000,
    )
    parser.add_argument(
        "--max_iter",
        dest="max_iter",
        action="store",
        type=lambda x: pf.is_int(parser, x),
        metavar="MAX_ITER",
        help="Maximum number of iterations.  Default is 250",
        default=250,
    )
    parser.add_argument(
        "--n_init",
        dest="n_init",
        action="store",
        type=lambda x: pf.is_int(parser, x),
        metavar="N_INIT",
        help=f"N initial.  Default is {DEFAULT_NINIT}",
        default=DEFAULT_NINIT,
    )
    return parser


def usage():
    """Print a fixed, hand-written usage summary and return an empty tuple.

    NOTE(review): the options listed here (--dmask, --nclusters, --type, ...)
    do not match the options defined by _get_parser() in this file; this text
    appears to describe an older command line interface.
    """
    helplines = (
        "usage: clusternifti datafile outputroot",
        "",
        "required arguments:",
        "    datafile      - the name of the 4 dimensional nifti file to cluster",
        "    outputroot    - the root name of the output nifti files",
        "",
        "optional arguments:",
        "    --dmask=DATAMASK            - use DATAMASK to specify which voxels in the data to use",
        "    --prescale                  - prescale data prior to clustering",
        "    --nclusters=NCLUSTERS       - set the number of clusters to NCLUSTERS (default is 8)",
        "    --type=CLUSTERTYPE          - set the clustering type (options are agglomerative, kmeans,",
        "                                  dbscan (and hdbscan if installed). Default is kmeans)",
        "    --connfile=CONNECTIVITY     - use a precomputed connectivity file)",
        "    --eps=EPS                   - set eps to EPS",
        "    --alpha=ALPHA               - set alpha to ALPHA",
        "    --min_samples=MINSAMPLES    - set min_samples to MINSAMPLES",
        "    --min_cluster_size=SIZE     - set min_cluster_size to SIZE",
        "    --affinity=AFFINITY         - set affinity to AFFINITY",
        "    --linkage=LINKAGE           - set linkage to LINKAGE",
        "    --radius=RADIUS             - set connectivity radius to RADIUS",
        "    --noconn                    - do not use a connectivity matrix",
        "    --display                   - display a 2d representation of the input data",
        "    --prescale                  - prescale input data prior to clustering",
        "    --scaleintervals=R1,R2,...  - apply scaling to ranges R1, R2,... independently",
        "                                  NOTE: turns on prescaling",
        "",
    )
    # One print call per line, exactly as listed above
    for theline in helplines:
        print(theline)
    return ()


def main():
    """Cluster the voxels of a 4D NIFTI file and save the labels as a NIFTI file.

    Steps, all driven by the options from ``_get_parser()``:

    1. Read the data file and, if given, the mask file.  Without a mask, only
       voxels whose timecourse is not constant are processed.
    2. Optionally prescale each timepoint across voxels (``--scaler``) and
       normalize each voxel's vector (``--normalize``).
    3. Optionally reduce the data with PCA or FastICA (``--datareduction``),
       writing the components and the reduced data to disk.
    4. Cluster with the method selected by ``--clustertype``.
    5. Write the command line and the region labels (cluster label + 1, with
       0 for unprocessed voxels) using ``outputrootname`` as the filename root.

    Exits via ``sys.exit`` when the mask does not match the data or the
    clustering mode is not recognized.
    """
    # get the command line parameters
    try:
        args = _get_parser().parse_args()
    except SystemExit:
        _get_parser().print_help()
        raise

    print("Will perform", args.clustertype, "clustering")

    # read in data
    print("reading in data array")
    (
        datafile_img,
        datafile_data,
        datafile_hdr,
        datafiledims,
        datafilesizes,
    ) = tide_io.readfromnifti(args.datafilename)

    if args.maskfilename is not None:
        print("reading in mask array")
        (
            datamask_img,
            datamask_data,
            datamask_hdr,
            datamaskdims,
            datamasksizes,
        ) = tide_io.readfromnifti(args.maskfilename)

    xsize, ysize, numslices, timepoints = tide_io.parseniftidims(datafiledims)
    xdim, ydim, slicethickness, tr = tide_io.parseniftisizes(datafilesizes)

    # check dimensions
    if args.maskfilename is not None:
        print("checking mask dimensions")
        if not tide_io.checkspacematch(datafile_hdr, datamask_hdr):
            print("input mask spatial dimensions do not match image")
            sys.exit()
        if datamaskdims[4] != 1:
            print("input mask time must have time dimension of 1")
            sys.exit()

    # allocating arrays
    print("reshaping arrays")
    numspatiallocs = int(xsize) * int(ysize) * int(numslices)
    print(f"there are {numspatiallocs} voxels")
    rs_datafile = datafile_data.reshape((numspatiallocs, timepoints))

    print("masking arrays")
    if args.maskfilename is not None:
        proclocs = np.where(datamask_data.reshape((numspatiallocs)) > args.maskthresh)
    else:
        # no mask given: keep every voxel whose timecourse is not constant
        themaxes = np.max(rs_datafile, axis=1)
        themins = np.min(rs_datafile, axis=1)
        thediffs = (themaxes - themins).reshape((numspatiallocs))
        proclocs = np.where(thediffs > 0.0)
    # proclocs is the 1-tuple returned by np.where; indexing with it selects rows
    procdata = rs_datafile[proclocs]
    print(f"unmasked shape: {rs_datafile.shape}, masked shape: {procdata.shape}")

    # set the initial methodname
    methodname = args.clustertype

    # scale every timepoint, if selected
    if args.scaler is not None:
        print("prescaling each timepoint")
        coefficients = procdata * 0.0
        if args.scaler == "standard":
            normfunc = tide_math.stdnormalize
        else:
            normfunc = tide_math.madnormalize
        for coff in range(timepoints):
            print(f"normalizing {coff}")
            coefficients[:, coff] = normfunc(procdata[:, coff])
        print("After prescaling...")
        print()
        methodname += f"_{args.scaler}"
    else:
        coefficients = procdata

    # normalize if selected
    if args.normalize:
        print("normalizing each spatial location")
        print("calculating moduli")
        moduli = np.linalg.norm(coefficients, axis=1)
        # avoid dividing all-zero vectors by zero
        moduli[moduli == 0.0] = 1.0
        print("normalizing")
        # not in place: an in-place true divide fails on integer data
        coefficients = coefficients / moduli[:, None]
        methodname += "_normalize"

    if args.datareduction == "pca":
        print("running PCA")
        methodname += "_pca"
        print("shape going in:", coefficients.shape)

        if args.n_pca <= 0.0:
            thepca = PCA(n_components="mle", svd_solver="full").fit(coefficients)
        else:
            # values of 1 or more are a component count; values between 0 and 1
            # are passed to PCA unchanged
            numpcacomponents = int(args.n_pca) if args.n_pca >= 1.0 else args.n_pca
            thepca = PCA(n_components=numpcacomponents).fit(coefficients)

        # the feature count attribute was renamed in newer scikit-learn
        numfeatures = getattr(thepca, "n_features_in_", None)
        if numfeatures is None:
            numfeatures = getattr(thepca, "n_features_", None)
        print(f"n_components found: {thepca.n_components_}")
        print(f"n_samples: {thepca.n_samples_}")
        print(f"n_features: {numfeatures}")
        coefficients = thepca.transform(coefficients)
        print("shape coming out:", coefficients.shape)
        for i in range(thepca.n_components_):
            print(
                "component",
                i,
                "explained variance:",
                thepca.explained_variance_[i],
                "explained variance %:",
                100.0 * thepca.explained_variance_ratio_[i],
            )
        tide_io.writenpvecs(
            thepca.components_,
            args.outputrootname + "_" + methodname + "_pcacomponents.txt",
        )
        tide_io.writenpvecs(
            np.transpose(thepca.components_),
            args.outputrootname + "_" + methodname + "_pcacomponents_transpose.txt",
        )
        print("data reduction done")

        # save the reduced data, one volume per component
        datafile_hdr["dim"][4] = thepca.n_components_
        tempout = np.zeros((numspatiallocs, thepca.n_components_), dtype="float")
        tempout[proclocs] = coefficients
        tide_io.savetonifti(
            tempout.reshape((xsize, ysize, numslices, thepca.n_components_)),
            datafile_hdr,
            args.outputrootname + "_" + methodname + "_pcareduced",
        )

    elif args.datareduction == "ica":
        print("running FastICA")
        methodname += "_ica"
        # a count below 1 means "no limit": FastICA rejects n_components=0
        numicarequested = int(args.n_pca) if args.n_pca >= 1.0 else None
        theica = FastICA(n_components=numicarequested, algorithm="deflation").fit(
            coefficients
        )

        # cluster on the data reconstructed from the retained components
        thetransform = theica.transform(coefficients)
        coefficients = theica.inverse_transform(thetransform)
        # FastICA has no n_components_ attribute; count the fitted components
        numicacomponents = theica.components_.shape[0]
        tide_io.writenpvecs(
            theica.components_,
            args.outputrootname + "_" + methodname + "_icacomponents.txt",
        )
        tide_io.writenpvecs(
            np.transpose(theica.components_),
            args.outputrootname + "_" + methodname + "_icacomponents_transpose.txt",
        )
        print("data reduction done")

        # save the reduced (transformed) data, one volume per component
        datafile_hdr["dim"][4] = numicacomponents
        tempout = np.zeros((numspatiallocs, numicacomponents), dtype="float")
        tempout[proclocs] = thetransform
        tide_io.savetonifti(
            tempout.reshape((xsize, ysize, numslices, numicacomponents)),
            datafile_hdr,
            args.outputrootname + "_" + methodname + "_icareduced",
        )

    # take a look at it
    if args.display:
        print("calculating TSNE projection...")
        projection = TSNE().fit_transform(coefficients)
        print("done")
        plt.scatter(*projection.T)
        plt.show()

    # now cluster the data
    print(f"clustering: {args.clustertype}")
    if args.clustertype == "agglomerative":
        print("computing connectivity matrix")
        connectivity = kneighbors_graph(coefficients, n_neighbors=5, n_jobs=-1)
        print("Done")
        methodname += f"_{args.linkage}_{args.affinity}_{str(args.n_clusters).zfill(2)}"
        aggargs = {
            "n_clusters": args.n_clusters,
            "memory": os.getcwd(),
            "compute_full_tree": False,
            "linkage": args.linkage,
            "connectivity": connectivity,
        }
        # newer scikit-learn names this parameter "metric"; older releases
        # only accept "affinity" and raise TypeError for the unknown keyword
        try:
            agg = AgglomerativeClustering(metric=args.affinity, **aggargs)
        except TypeError:
            agg = AgglomerativeClustering(affinity=args.affinity, **aggargs)
        agg.fit(coefficients)
        theregionlabels = agg.labels_
        # the connected component count attribute was also renamed
        numconnected = getattr(agg, "n_connected_components_", None)
        if numconnected is None:
            numconnected = getattr(agg, "n_components_", None)
        print("there are", numconnected, "components and", agg.n_leaves_, "leaves")
        print(agg.children_)

    elif args.clustertype == "dbscan":
        db = DBSCAN(eps=args.eps, min_samples=args.min_samples, n_jobs=-1).fit(coefficients)
        theregionlabels = db.labels_
        print(theregionlabels)

        # Number of clusters in labels, ignoring noise if present.
        n_clusters_ = len(set(theregionlabels)) - (1 if -1 in theregionlabels else 0)

        print("Estimated number of clusters: %d" % n_clusters_)
        methodname += f"_{str(n_clusters_).zfill(2)}"

    elif args.clustertype == "hdbscan":
        hdb = hdbs.HDBSCAN(
            min_samples=args.min_samples,
            min_cluster_size=args.min_cluster_size,
            alpha=args.alpha,
            memory=os.getcwd(),
        ).fit(coefficients)
        theregionlabels = hdb.labels_
        print(theregionlabels)

        # Number of clusters in labels, ignoring noise if present.
        n_clusters_ = len(set(theregionlabels)) - (1 if -1 in theregionlabels else 0)

        print("Estimated number of clusters: %d" % n_clusters_)
        methodname += (
            f"_minsamples_{args.min_samples}_" + f"minclustersize_{args.min_cluster_size}"
        )
        methodname += f"_{str(n_clusters_).zfill(2)}"

    elif args.clustertype == "kmeans":
        print("coefficients shape:", coefficients.shape)
        # one column of labels and one Davies Bouldin score per repeat
        theregionlabels = np.zeros((coefficients.shape[0], args.repeats), dtype="int")
        dbscores = np.zeros((args.repeats), dtype="float")
        for therepeat in range(args.repeats):
            if args.minibatch:
                kmeans = MiniBatchKMeans(
                    n_clusters=args.n_clusters,
                    batch_size=args.batch_size,
                    max_iter=args.max_iter,
                ).fit(coefficients)
            else:
                kmeans = KMeans(
                    n_clusters=args.n_clusters,
                    max_iter=args.max_iter,
                    n_init=args.n_init,
                ).fit(coefficients)

            theclusters = np.transpose(kmeans.cluster_centers_)
            print("cluster_centers shape", theclusters.shape)
            theregionlabels[:, therepeat] = kmeans.labels_ + 0
            dbscores[therepeat] = davies_bouldin_score(coefficients, kmeans.labels_)
            print(f"Davies Bouldin score for repeat {therepeat} = {dbscores[therepeat]}")
        methodname += f"_{str(args.n_clusters).zfill(2)}"
        tide_io.writevec(
            dbscores,
            args.outputrootname + "_" + methodname + "_dbscores.txt",
        )

    else:
        print("illegal clustering mode")
        sys.exit()

    # save the command line
    tide_io.writevec(
        [" ".join(sys.argv)],
        args.outputrootname + "_" + methodname + "_commandline.txt",
    )

    # Only kmeans produces one label column per repeat; the other methods
    # return a single 1D label vector, so promote it to a single column.
    theregionlabels = np.asarray(theregionlabels)
    if theregionlabels.ndim == 1:
        theregionlabels = theregionlabels.reshape((-1, 1))
    numlabelsets = theregionlabels.shape[1]

    print("theregionlabels shape", theregionlabels.shape)
    print("clustering done")
    datafile_hdr["dim"][4] = numlabelsets
    # labels are shifted up by one so that 0 marks voxels that were not processed
    tempout = np.zeros((numspatiallocs, numlabelsets), dtype="int")
    tempout[proclocs] = theregionlabels + 1
    tide_io.savetonifti(
        tempout.reshape((xsize, ysize, numslices, numlabelsets)),
        datafile_hdr,
        args.outputrootname + "_" + methodname + "_regions",
    )


if __name__ == "__main__":
    main()
