#!/usr/bin/env python
#
#   Copyright 2016 Blaise Frederick
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
#
#       $Author: frederic $
#       $Date: 2016/01/25 20:17:55 $
#       $Id: pixelcomp,v 1.4 2016/01/25 20:17:55 frederic Exp $
#
from __future__ import division, print_function

import getopt
import string
import sys

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import rapidtide.io as tide_io
from pylab import *


def bland_altman_plot(data1, data2, *args, **kwargs):
    """Draw a Bland-Altman (mean versus difference) scatter plot on the current axes.

    Each point is plotted at the elementwise mean of the two inputs (x) against
    their elementwise difference ``data1 - data2`` (y).  Three dashed gray
    horizontal lines are added: the mean difference, and the mean difference
    plus and minus two standard deviations of the difference.

    Parameters
    ----------
    data1, data2 : array_like
        The two sets of measurements to compare; converted with ``np.asarray``.
    *args, **kwargs
        Passed through unchanged to ``plt.scatter``.
    """
    first = np.asarray(data1)
    second = np.asarray(data2)

    pairmeans = np.mean([first, second], axis=0)
    differences = first - second
    bias = np.mean(differences)
    spread = np.std(differences, axis=0)

    plt.scatter(pairmeans, differences, *args, **kwargs)

    # center line first, then the upper and lower two-sigma limits
    for level in (bias, bias + 2 * spread, bias - 2 * spread):
        plt.axhline(level, color="gray", linestyle="--")


def usage():
    """Print the command line synopsis for pixelcomp to standard output.

    Returns
    -------
    tuple
        Always the empty tuple.
    """
    helplines = (
        "usage: pixelcomp inputfile1 maskfile1 inputfile2 maskfile2 outputfile",
        "",
        "required arguments:",
        "\tinputfile1\t- the name of the first input nifti file",
        "\tmaskfile1\t- the name of the first input nifti mask file",
        "\tinputfile2\t- the name of the second input nifti file",
        "\tmaskfile2\t- the name of the second input nifti mask file",
        "\toutputfile\t- the name of the output text file",
        "",
        "optional arguments:",
        "\t-l\t\t- perform linear fit only",
        "",
    )
    for helpline in helplines:
        print(helpline)
    return ()


# set default variable values
# NOTE: displayplots, fastfilter and order are not referenced anywhere in this
# script; they are kept only so the module-level names remain defined.
displayplots = False
fastfilter = True
linfitonly = False
order = 3

# number of leading argv entries that are positional: program name + 5 filenames
NUMREQUIREDARGS = 6


def parse_command_line(argv):
    """Split an argv-style list into the required filenames and the option flags.

    Parameters
    ----------
    argv : sequence of str
        The full argument vector, program name first (i.e. ``sys.argv``).

    Returns
    -------
    filenames : tuple of str
        ``(inputfile1, maskfile1, inputfile2, maskfile2, outputfile)``.
    fitonly : bool
        True if ``-l`` was given, otherwise the module default ``linfitonly``.

    Notes
    -----
    Prints the usage message and exits with status 2 if too few arguments are
    given or an unknown option is seen; ``--help`` prints the usage message and
    exits with status 0.
    """
    if len(argv) < NUMREQUIREDARGS:
        usage()
        sys.exit(2)

    filenames = tuple(argv[1:NUMREQUIREDARGS])

    # getopt stops at the first non-option argument, so it must be handed only
    # what follows the five required filenames.
    try:
        opts, _ = getopt.getopt(argv[NUMREQUIREDARGS:], "l", ["help"])
    except getopt.GetoptError as err:
        # print help information and exit:
        print(str(err))  # will print something like "option -x not recognized"
        usage()
        sys.exit(2)

    fitonly = linfitonly
    for opt, _ in opts:
        if opt == "-l":
            fitonly = True
        elif opt == "--help":
            usage()
            sys.exit(0)

    return filenames, fitonly


def collect_masked_pairs(data1, data2, mask1, mask2):
    """Return the values of both images at every voxel inside both masks.

    A voxel is selected when the product of the two masks is greater than zero.
    Only the first three index axes of the combined mask are used to address
    the data arrays, and voxels are returned in C (row-major) order.

    Parameters
    ----------
    data1, data2 : array_like
        The two image arrays to sample.
    mask1, mask2 : array_like
        The corresponding mask arrays.

    Returns
    -------
    values1, values2 : numpy.ndarray
        Samples of ``data1`` and ``data2`` at the selected voxels.
    """
    data1 = np.asarray(data1)
    data2 = np.asarray(data2)
    totalmask = np.asarray(mask1) * np.asarray(mask2)
    voxels = tuple(np.where(totalmask > 0)[:3])
    return data1[voxels], data2[voxels]


def fit_line(xvals, yvals, fitorder=1):
    """Least-squares polynomial fit of ``yvals`` against ``xvals``.

    Returns the coefficient array from ``np.polyfit`` (highest power first).
    """
    return np.polyfit(np.asarray(xvals), np.asarray(yvals), fitorder)


def main():
    """Compare two NIFTI images voxel by voxel within the intersection of their masks.

    Shows a contour plot of the 2D histogram of the paired values, writes the
    linear fit coefficients to ``<outputfile>_linfit`` and, unless ``-l`` was
    given, writes the tab separated value pairs to ``<outputfile>`` and shows a
    Bland-Altman plot.  Exits with status 1 if the images and masks do not
    match spatially or if no voxel lies inside both masks.
    """
    (
        (inputfilename1, maskfilename1, inputfilename2, maskfilename2, outputfilename),
        fitonly,
    ) = parse_command_line(sys.argv)

    # readfromnifti returns five values; only the data array and header are used here
    _, input1_data, input1_hdr, _, _ = tide_io.readfromnifti(inputfilename1)
    _, mask1_data, mask1_hdr, _, _ = tide_io.readfromnifti(maskfilename1)
    if not tide_io.checkspacematch(input1_hdr, mask1_hdr):
        print("input mask 1 does not match spatial coverage of image 1")
        sys.exit(1)

    _, input2_data, input2_hdr, _, _ = tide_io.readfromnifti(inputfilename2)
    _, mask2_data, mask2_hdr, _, _ = tide_io.readfromnifti(maskfilename2)
    if not tide_io.checkspacematch(input2_hdr, mask2_hdr):
        print("input mask 2 does not match spatial coverage of image 2")
        sys.exit(1)

    if not tide_io.checkspacematch(input1_hdr, input2_hdr):
        print("input images 1 and 2 do not have matching spatial coverage")
        sys.exit(1)

    values1, values2 = collect_masked_pairs(input1_data, input2_data, mask1_data, mask2_data)
    if len(values1) == 0:
        print("no voxels are nonzero in both masks - nothing to compare")
        sys.exit(1)

    # one row per voxel: column 0 is image 1, column 1 is image 2
    thearray = np.stack([values1, values2], axis=1)
    print(thearray[:, 0])

    # construct a 2d histogram
    numbins = 100
    H, xedges, yedges = np.histogram2d(
        thearray[:, 0], thearray[:, 1], bins=numbins, density=True
    )
    # H is indexed [bin of image 1, bin of image 2], so image 2 runs along the
    # horizontal axis of the contour plot and image 1 along the vertical axis.
    extent = [yedges[0], yedges[-1], xedges[0], xedges[-1]]
    plt.contour(H, extent=extent)
    plt.show()

    # now fit the line: image 2 values as a function of image 1 values
    thecoffs = fit_line(thearray[:, 0], thearray[:, 1], 1)
    print("thecoffs=", thecoffs)
    with open(outputfilename + "_linfit", "w") as outfile:
        outfile.write(str(thecoffs))

    if not fitonly:
        with open(outputfilename, "w") as outfile:
            outfile.writelines(
                str(val1) + "\t" + str(val2) + "\n" for val1, val2 in zip(values1, values2)
            )
        bland_altman_plot(thearray[:, 0], thearray[:, 1])
        plt.title("Bland-Altman Plot")
        plt.show()


if __name__ == "__main__":
    main()