#!/usr/bin/env python
#
#   Copyright 2016 Blaise Frederick
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
#
#
#       $Author: frederic $
#       $Date: 2016/06/14 12:04:50 $
#       $Id: histnifti,v 1.8 2016/06/14 12:04:50 frederic Exp $
#
from __future__ import print_function

import sys

import numpy as np
import rapidtide.io as tide_io
from pylab import *


def usage():
    """Print the command line synopsis and argument descriptions for maptoroi.

    The argument names in the synopsis line and in the "required arguments"
    list are kept identical so that the help text is self-consistent.

    Returns
    -------
    tuple
        Always the empty tuple (retained for backward compatibility; callers
        in this file ignore the return value).
    """
    print("usage: maptoroi inputfile templatefile outputroot")
    print("")
    print("required arguments:")
    print(
        "    inputfile        - the name of the file with the roi values to be mapped back to image space"
    )
    print("    templatefile     - the name of the template region file")
    # NOTE(review): named "outputroot" to match the synopsis line; this assumes
    # tide_io.savetonifti appends the file extension itself - confirm.
    print("    outputroot       - the root name of the output nifti file")
    print("")
    return ()


# Map per-region values back into image space.
#
# Reads a table of values (one row per region, one column per cluster) and a
# template NIFTI file, builds one output volume per cluster, and writes the
# result as a 4D NIFTI file.  The template is treated either as an ROI file
# (a single volume whose voxel values are 1-based region numbers) or, when it
# has more than one volume, as a network file (one weight map per region).

# get the command line parameters
if len(sys.argv) < 4:
    usage()
    # sys.exit() rather than the interactive-only exit() helper; exit status unchanged
    sys.exit()

# handle required args first
inputfilename = sys.argv[1]
templatefile = sys.argv[2]
outputfile = sys.argv[3]

print("loading data")
theclustercenters = tide_io.readvecs(inputfilename)
template_img, template_data, template_hdr, thedims, thesizes = tide_io.readfromnifti(templatefile)

# first axis of theclustercenters indexes regions, second axis indexes clusters
print(theclustercenters.shape)
numregions = theclustercenters.shape[0]
numclusters = theclustercenters.shape[1]
print(numregions, "regions, ", numclusters, "clusters")

# convert the header dimensions to plain ints once, up front
xsize = int(thedims[1])
ysize = int(thedims[2])
numslices = int(thedims[3])
numpatterns = int(thedims[4])
numvoxels = xsize * ysize * numslices
output_data = np.zeros((numvoxels, numclusters), dtype="float")

# check to see if the template file has ROIs or networks
isroifile = True
if numpatterns > 1:
    # handle multipattern files: one column of voxel weights per pattern
    print("treating template as a network file")
    if numregions != numpatterns:
        # each pattern needs exactly one row of values to scale it
        print(
            "input file has",
            numregions,
            "regions but the template has",
            numpatterns,
            "patterns - exiting",
        )
        sys.exit(1)
    # NOTE(review): weights are kept in the template's own dtype (no integer
    # cast) on the assumption that network maps may be non-integer - confirm.
    templatevoxels = np.reshape(template_data, (numvoxels, numpatterns))
    isroifile = False
else:
    print("treating template as an ROI file")
    # flatten to 1D so that a boolean mask selects voxels only; indexing with
    # the (rows, cols) tuple that where() returns for a 2D array would also
    # write to voxel 0 for every non-empty region
    templatevoxels = np.reshape(template_data, (numvoxels,))

for thecluster in range(numclusters):
    print("processing cluster", thecluster)
    if isroifile:
        # region numbers in the template are 1-based; rows of the input are 0-based
        for theregion in range(1, numregions + 1):
            output_data[templatevoxels == theregion, thecluster] = theclustercenters[
                theregion - 1, thecluster
            ]
    else:
        # weighted sum of the pattern maps, one weight per pattern for this cluster
        for thepattern in range(numpatterns):
            output_data[:, thecluster] += (
                theclustercenters[thepattern, thecluster] * templatevoxels[:, thepattern]
            )

# reuse the template header, updating only the length of the fourth dimension
theheader = template_hdr
theheader["dim"][4] = numclusters
tide_io.savetonifti(
    output_data.reshape((xsize, ysize, numslices, numclusters)), theheader, outputfile
)
