#!/usr/bin/env python
# -*- coding: latin-1 -*-
#
#   Copyright 2016-2021 Blaise Frederick
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
#
#
import argparse
import os
import subprocess
import sys

import numpy as np
import rapidtide.filter as tide_filt
import rapidtide.io as tide_io
import rapidtide.workflows.parser_funcs as pf


def _get_parser():
    """
    Build the command line argument parser for statlasgen.

    Two required arguments (``inputname`` and ``outputtemplatename``) are
    registered through the shared rapidtide parser helpers, followed by the
    optional output format, spatial filtering, resampling and debug switches.

    Returns
    -------
    argparse.ArgumentParser
        The configured parser.
    """
    cmdparser = argparse.ArgumentParser(
        prog="statlasgen",
        description="A tool to generate atlases from value histograms",
        usage="%(prog)s inputname outputtemplatename [options]",
    )

    # Required arguments
    pf.addreqinputniftifile(
        cmdparser,
        "inputname",
        addedtext="Must be a 4D file with each timepoint being a sample segmentation. ",
    )
    pf.addreqoutputniftifile(cmdparser, "outputtemplatename")

    # Optional arguments, registered in this order.  --3d and --4d write to
    # the same destination; the default (False) selects the 3d output.
    spatialfilthelp = (
        "Spatially filter value histogram using GAUSSSIGMA in mm.  "
        "Set GAUSSSIGMA negative to have it set to half the mean voxel "
        "dimension (a rule of thumb for a good value)."
    )
    optionspecs = [
        (
            "--3d",
            dict(
                dest="volumeperregion",
                action="store_false",
                help="Return a 3d file with regions encoded as integers",
                default=False,
            ),
        ),
        (
            "--4d",
            dict(
                dest="volumeperregion",
                action="store_true",
                help="Return a 4d file with one region per volume",
            ),
        ),
        (
            "--spatialfilt",
            dict(
                dest="gausssigma",
                action="store",
                type=float,
                metavar="GAUSSSIGMA",
                help=spatialfilthelp,
                default=0.0,
            ),
        ),
        (
            "--targetfile",
            dict(
                dest="targetfile",
                action="store",
                type=lambda x: pf.is_valid_file(cmdparser, x),
                metavar="TARGET",
                help="Match the resolution of TARGET",
                default=None,
            ),
        ),
        (
            "--debug",
            dict(
                dest="debug",
                action="store_true",
                help="Output additional debugging information.",
                default=False,
            ),
        ),
    ]
    for theflag, theoptions in optionspecs:
        cmdparser.add_argument(theflag, **theoptions)

    return cmdparser


def _voxelhistograms(inputvoxels, numregions, debug=False):
    """
    Compute, for every voxel, the histogram of region labels across samples.

    Parameters
    ----------
    inputvoxels : 2D array
        Label values arranged as (numvoxels, numtimepoints).
    numregions : int
        The largest region label; labels 1..numregions are histogrammed.
    debug : bool
        If True, print the bin edges and histogram of every processed voxel.

    Returns
    -------
    outputhistogram : 2D float array, shape (numvoxels, numregions)
        Normalized histogram of the nonzero labels for each voxel.
    histmask : 1D int array, shape (numvoxels,)
        1 for voxels that had at least one positive label, 0 otherwise.
    """
    numvoxels = inputvoxels.shape[0]
    outputhistogram = np.zeros((numvoxels, numregions), dtype=float)
    histmask = np.zeros((numvoxels), dtype=int)

    # only voxels with at least one positive label need a histogram
    activevoxels = np.flatnonzero(np.max(inputvoxels, axis=1) > 0)
    for thevoxel in activevoxels:
        # numregions + 1 bins spanning [0, numregions]: integer label k lands
        # in bin k (the last bin is closed on the right).  density=True
        # rescales the counts, which does not change which region is largest.
        thehist, thebins = np.histogram(
            inputvoxels[thevoxel, :],
            bins=(numregions + 1),
            range=[0, numregions],
            density=True,
        )
        if debug:
            print(len(thebins), thebins)
            print(len(thehist), thehist)
        # drop bin 0 (the "no region" label)
        outputhistogram[thevoxel, :] = thehist[1:]
        histmask[thevoxel] = 1
    return outputhistogram, histmask


def _resampletotarget(outputhistogram, input_hdr, spaceshape, numregions, targetfile):
    """
    Resample the region histograms to the grid of targetfile using FSL flirt.

    The histograms are written to "temppre", resampled with an identity
    transform into "temppost", and read back.  The temporary files are left
    in the current working directory.

    Parameters
    ----------
    outputhistogram : 2D float array, shape (numvoxels, numregions)
    input_hdr : NIfTI header of the input data (dim[4] is set to numregions)
    spaceshape : tuple of int
        (xsize, ysize, numslices) of the current grid.
    numregions : int
    targetfile : str
        File whose resolution should be matched.

    Returns
    -------
    outputhistogram, histmask, header, spaceshape, numregions
        The same quantities at the target resolution.  histmask marks voxels
        with any positive histogram value after resampling.

    Notes
    -----
    Exits with status 1 if FSLDIR is not set or flirt fails.
    """
    fsldir = os.environ.get("FSLDIR")
    if fsldir is None:
        print("FSL directory not found - aborting")
        sys.exit(1)

    # first write out a temp file with the data
    input_hdr["dim"][4] = numregions
    tide_io.savetonifti(
        outputhistogram.reshape(tuple(spaceshape) + (numregions,)),
        input_hdr,
        "temppre",
    )

    thecommand = [
        os.path.join(fsldir, "bin", "flirt"),
        "-in",
        "temppre",
        "-ref",
        targetfile,
        "-applyxfm",
        "-init",
        os.path.join(fsldir, "data", "atlases", "bin", "eye.mat"),
        "-out",
        "temppost",
    ]
    returncode = subprocess.call(thecommand)
    if returncode != 0:
        print(f"flirt failed with return code {returncode} - aborting")
        sys.exit(1)

    _, resampled_data, resampled_hdr, resampleddims, _ = tide_io.readfromnifti("temppost")
    xsize = int(resampleddims[1])
    ysize = int(resampleddims[2])
    numslices = int(resampleddims[3])
    numregions = int(resampleddims[4])
    numvoxels = xsize * ysize * numslices

    # keep the interpolated values as floats - rounding would destroy the
    # fractional histogram values
    outputhistogram = np.reshape(resampled_data, (numvoxels, numregions)).astype(float)
    # the old mask is at the old resolution, so rebuild it on the new grid
    histmask = (np.max(outputhistogram, axis=1) > 0).astype(int)
    return outputhistogram, histmask, resampled_hdr, (xsize, ysize, numslices), numregions


def _collapsetolabels(outputhistogram, histmask):
    """
    Collapse per-region histograms to a single integer label per voxel.

    Each masked voxel gets the (1-based) index of its largest histogram
    entry; unmasked voxels get 0.  The result has the dtype of
    outputhistogram.
    """
    thelabels = np.argmax(outputhistogram, axis=1) + 1
    return np.where(histmask > 0, thelabels, 0).astype(outputhistogram.dtype)


def main():
    """
    Command line entry point for statlasgen.

    Reads a 4D file of sample segmentations, builds a per-voxel histogram of
    the region labels, optionally smooths it spatially and/or resamples it to
    the resolution of a target file, and writes either the 4D histograms
    (--4d) or a 3D map of the most frequent region per voxel (default).

    Exits with status 1 if the input is not 4D, contains no positive region
    labels, or if resampling is requested and FSL is unavailable or fails.
    """
    # get the command line parameters
    try:
        args = _get_parser().parse_args()
    except SystemExit as theexit:
        # argparse has already printed the help for -h (exit code 0); only
        # add the full help text when the arguments were actually invalid
        if theexit.code not in (0, None):
            _get_parser().print_help()
        raise

    print("loading input data")
    _, input_data, input_hdr, inputdims, inputsizes = tide_io.readfromnifti(args.inputname)

    print("reshaping")
    xsize, ysize, numslices, numtimepoints = tide_io.parseniftidims(inputdims)
    xdim, ydim, slicethickness, _ = tide_io.parseniftisizes(inputsizes)
    xsize, ysize, numslices = int(xsize), int(ysize), int(numslices)
    numvoxels = xsize * ysize * numslices

    if numtimepoints <= 1:
        print("input file must be 4 dimensional.  Exiting.")
        sys.exit(1)

    # the largest label determines the number of regions; use a Python int so
    # that large label values cannot wrap around
    maxlabel = np.max(input_data)
    if not np.isfinite(maxlabel) or maxlabel < 1:
        print("input file contains no positive region labels.  Exiting.")
        sys.exit(1)
    numregions = int(np.floor(maxlabel))

    # array is already 4d, just reshape it
    inputvoxels = np.reshape(input_data, (numvoxels, int(numtimepoints)))

    # make histograms for all voxels
    outputhistogram, histmask = _voxelhistograms(inputvoxels, numregions, debug=args.debug)

    # do spatial filtering if requested
    gausssigma = args.gausssigma
    if gausssigma < 0.0:
        # set gausssigma automatically
        gausssigma = np.mean([xdim, ydim, slicethickness]) / 2.0
    if gausssigma > 0.0:
        print(f"applying gaussian spatial filter with sigma={gausssigma}")
        rs_outputhistogram = outputhistogram.reshape((xsize, ysize, numslices, numregions))
        for i in range(0, numregions):
            rs_outputhistogram[:, :, :, i] = tide_filt.ssmooth(
                xdim,
                ydim,
                slicethickness,
                gausssigma,
                rs_outputhistogram[:, :, :, i],
            )
        outputhistogram = rs_outputhistogram.reshape((numvoxels, numregions))

    if args.targetfile is not None:
        # do the resampling here
        print("resampling to new resolution")
        (
            outputhistogram,
            histmask,
            input_hdr,
            (xsize, ysize, numslices),
            numregions,
        ) = _resampletotarget(
            outputhistogram,
            input_hdr,
            (xsize, ysize, numslices),
            numregions,
            args.targetfile,
        )

    if args.volumeperregion:
        input_hdr["dim"][4] = numregions
        tide_io.savetonifti(
            outputhistogram.reshape((xsize, ysize, numslices, numregions)),
            input_hdr,
            args.outputtemplatename,
        )
    else:
        print("collapsing back to 3d")
        outputvoxels = _collapsetolabels(outputhistogram, histmask)
        input_hdr["dim"][4] = 1
        tide_io.savetonifti(
            outputvoxels.reshape((xsize, ysize, numslices)),
            input_hdr,
            args.outputtemplatename,
        )


# Run the command line workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
