# **************************************************************************
# * Authors:    Mohamad Harastani            (mohamad.harastani@igbmc.fr)
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 2 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307  USA
# *
# *  All comments concerning this program package may be sent to the
# *  e-mail address 'scipion@cnb.csic.es'
# *
# **************************************************************************

from continuousflex.protocols.protocol_tomoflow import FlexProtHeteroFlow
from pwem.viewers import EmProtocolViewer
from pyworkflow.protocol.params import LabelParam, IntParam, EnumParam, StringParam
from pyworkflow.viewer import ProtocolViewer, DESKTOP_TKINTER, WEB_DJANGO
from pwem.viewers import ObjectView
import numpy as np
import matplotlib.pyplot as plt
from continuousflex.protocols.utilities.OF_plots import plot_quiver_3d, plot_quiver_2d
from continuousflex.protocols.utilities.spider_files3 import open_volume, open_image
import pyworkflow.protocol.params as params
from pyworkflow.utils.process import runJob
from pyworkflow.utils.path import makePath


# Indices of the projection-plane choices offered by the 'displayFlow2'
# EnumParam of FlexHeteroFlowViewer (x-y, x-z and z-y planes, in that order).
XY, XZ, ZY = range(3)


class FlexHeteroFlowViewer(EmProtocolViewer):
    """ Visualization of results from the HeteroFlow protocol

    Offers quiver plots (3D, and projected to 2D) of the optical flow computed
    for a chosen input volume, a view of the warped reference volumes, and
    histograms of the similarity scores stored in extra/cc_msd_mad.txt.
    """
    _label = 'viewer heteroflow'
    _targets = [FlexProtHeteroFlow]
    _environments = [DESKTOP_TKINTER, WEB_DJANGO]

    def __init__(self, **kwargs):
        # NOTE(review): this calls ProtocolViewer.__init__ directly rather
        # than EmProtocolViewer/super(); confirm that skipping any
        # EmProtocolViewer initialisation is intended.
        ProtocolViewer.__init__(self, **kwargs)
        self._data = None

    def _defineParams(self, form):
        """ Declare the visualization form (optical flows, volumes, histograms). """
        form.addSection(label='Visualization')
        group = form.addGroup('Optical flows')
        group.addParam('FlowNumber', IntParam, default=1,
                       label='Optical flow for volume number')
        group.addParam('DownSample', IntParam, default=2,
                       expertLevel=params.LEVEL_ADVANCED,
                       label='Downsample the 3D quiver plot',
                       help='Loading the 3D quiver plot can take time, we can downsample it by this number')
        group.addParam('displayFlow', LabelParam,
                       label="Display 3D optical flow",
                       help="Display the calculated optical flow of volume of specified number vs the reference")
        # The (rot, tilt, psi) shown in each label must match the angles that
        # _projection_angles uses for XY / XZ / ZY: (0,0,0), (90,90,0), (0,90,0).
        group.addParam('displayFlow2', EnumParam,
                       choices=['     x-y plane: Euler( 0, 0, 0)', '     x-z plane: Euler(90,90, 0)',
                                '     z-y plane: Euler( 0,90, 0)'],
                       default=XY,
                       display=params.EnumParam.DISPLAY_COMBO,
                       label='Display projected 3D-to-2D optical flow',
                       help="Display the projected 3D-2D optical flow on one of the planes")
        group.addParam('RotTiltPsi', StringParam, default=None,
                       expertLevel=params.LEVEL_ADVANCED,
                       label='rot tilt psi',
                       help='Project the 3D optical flow using specific Euler angles')
        form.addParam('displayVolumes', LabelParam,
                      label="Display warped volumes",
                      help="Display the volumes that are generated by applying the calculated optical flow of each of "
                           "the input volumes on the reference")
        form.addParam('displayHistCC', LabelParam,
                      label="Histogram of normalized cross correlation",
                      help="Histogram of the normalized cross correlation between the input volumes and "
                           "the warped reference")
        form.addParam('displayHistmsd', LabelParam,
                      label="Histogram of mean square distance",
                      help="Histogram of the mean square distance between the input volumes and "
                           "the warped reference")
        form.addParam('displayHistmad', LabelParam,
                      label="Histogram of normalized mean absolute distance",
                      help="Histogram of the mean absolute distance between the input volumes and "
                           "the warped reference")

    def _getVisualizeDict(self):
        """ Map each form parameter name to the method that visualizes it. """
        return {'displayVolumes': self._viewVolumes,
                'displayHistCC': self._viewParam,
                'displayHistmsd': self._viewParam,
                'displayHistmad': self._viewParam,
                'displayFlow': self._viewFlow,
                'displayFlow2': self._viewFlow2,
                'RotTiltPsi': self._viewFlow2
                }

    def _viewVolumes(self, paramName):
        """ Open the protocol's WarpedRefByFlows output in an ObjectView. """
        volumes = self.protocol.WarpedRefByFlows
        return [ObjectView(self._project, volumes.strId(), volumes.getFileName())]

    def _viewParam(self, paramName):
        """ Plot a histogram of one column of extra/cc_msd_mad.txt.

        Column 0 holds the normalized cross correlation, column 1 the mean
        square distance and column 2 the mean absolute distance, each between
        a warped reference and the corresponding input volume.
        Unknown parameter names are ignored.
        """
        # paramName -> (column, title, x-axis label)
        hist_specs = {
            'displayHistCC': (
                0,
                'Histogram of normalized cross correlation between warped\n reference by '
                'optical flows (estimated volumes) and the input volumes',
                'Cross correlation'),
            'displayHistmsd': (
                1,
                'Histogram of normalized mean square distance between warped\n reference by '
                'optical flows (estimated volumes) and the input volumes',
                'normalized mean square distance'),
            'displayHistmad': (
                2,
                'Histogram of normalized mean absolute distance between warped\n reference by '
                'optical flows (estimated volumes) and the input volumes',
                'normalized mean absolute distance'),
        }
        if paramName not in hist_specs:
            return
        column, title, xlabel = hist_specs[paramName]
        datamat_fn = self.protocol._getExtraPath('cc_msd_mad.txt')
        # ndmin=2 keeps a single-row file (only one input volume) indexable
        # as datamat[:, column]; without it loadtxt returns a 1-D array.
        datamat = np.loadtxt(datamat_fn, delimiter=' ', ndmin=2)
        plt.figure()
        plt.hist(datamat[:, column])
        plt.title(title)
        plt.xlabel(xlabel)
        plt.ylabel('Number of volumes')
        plt.show()

    def _viewFlow(self, paramName):
        """ Show a 3D quiver plot of the optical flow of the selected volume. """
        flow_number = self.FlowNumber.get()
        flow = self.read_optical_flow_by_number(flow_number)
        title = '3D optical flow for input volume number %d' % flow_number
        plot_quiver_3d(flow, downsample=self.DownSample.get(), title=title)

    def _viewFlow2(self, paramName):
        """ Show a 2D quiver plot of the selected optical flow projected on a plane.

        The plane comes from 'displayFlow2' unless the user triggered
        'RotTiltPsi', in which case the typed Euler angles are used.
        """
        flow_number = self.FlowNumber.get()
        rot, tilt, psi, title = self._projection_angles(paramName, flow_number)
        flow3D = self.read_optical_flow_by_number(flow_number)
        flow2D = self._project_flow(flow_number, rot, tilt, psi)
        flow2D = self._match_max_magnitude(flow2D, flow3D)
        plot_quiver_2d(flow2D, title=title)

    def _projection_angles(self, paramName, flow_number):
        """ Return (rot, tilt, psi, title) for the requested 2D projection. """
        if paramName == 'RotTiltPsi':
            rot, tilt, psi = self._parse_rot_tilt_psi()
            title = ('Projected optical flow of input volume number %d \n using Euler angles'
                     ' (%.1f, %.1f, %.1f)' % (flow_number, rot, tilt, psi))
            return rot, tilt, psi, title
        plane = self.displayFlow2.get()
        if plane == XZ:
            return 90, 90, 0, 'Projected optical flow of input volume number %d on XZ plane' % flow_number
        if plane == ZY:
            return 0, 90, 0, 'Projected optical flow of input volume number %d on ZY plane' % flow_number
        return 0, 0, 0, 'Projected optical flow of input volume number %d on XY plane' % flow_number

    def _parse_rot_tilt_psi(self):
        """ Parse the 'RotTiltPsi' field into three floats (rot, tilt, psi).

        Raises ValueError when the field is empty or does not contain exactly
        three whitespace-separated numbers.
        """
        text = self.RotTiltPsi.get()
        fields = text.split() if text else []
        if len(fields) != 3:
            raise ValueError("'rot tilt psi' must contain three space-separated angles, got %r" % (text,))
        rot, tilt, psi = (float(value) for value in fields)
        return rot, tilt, psi

    def _project_flow(self, flow_number, rot, tilt, psi):
        """ Project each flow component with xmipp_phantom_project and rotate into the image frame.

        Returns an array of shape (rows, cols, 2) holding the first two rows
        of euler_matrix(rot, tilt, psi) applied to the projected (x, y, z)
        components at every pixel.
        """
        makePath(self.protocol._getTmpPath())
        angles = '%s %s %s' % (rot, tilt, psi)
        projections = []
        for axis, path_flow in zip('xyz', self._optical_flow_paths(flow_number)):
            proj = self.protocol._getTmpPath('proj_%s.spi' % axis)
            runJob(None, 'xmipp_phantom_project', '-i %s -o %s --angles %s' % (path_flow, proj, angles))
            projections.append(proj)
        p = np.array([open_image(proj) for proj in projections], dtype=float)
        rows, cols = p.shape[1], p.shape[2]
        T = self.euler_matrix(rot, tilt, psi)
        pn = np.matmul(T, p.reshape(3, rows * cols)).reshape(3, rows, cols)
        # Keep only the two in-plane components after the rotation.
        return np.stack([pn[0], pn[1]], axis=-1)

    @staticmethod
    def _match_max_magnitude(flow2D, flow3D):
        """ Scale flow2D so its largest vector length equals that of flow3D.

        If every projected vector has zero length (or the max is not a
        positive number), flow2D is returned unscaled to avoid dividing by 0.
        """
        max_3D = np.linalg.norm(flow3D, axis=0).max()
        max_2D = np.linalg.norm(flow2D, axis=-1).max()
        if max_2D > 0:
            return (max_3D / max_2D) * flow2D
        return flow2D

    def _optical_flow_paths(self, num):
        """ Return the (x, y, z) component files extra/optical_flows/NNNNNN_opflow?.spi for *num*. """
        prefix = self.protocol._getExtraPath() + '/optical_flows/' + str(num).zfill(6)
        return tuple(prefix + '_opflow%s.spi' % axis for axis in 'xyz')

    def read_optical_flow_by_number(self, num):
        """ Read the optical flow of input volume *num* (zero-padded to 6 digits in file names). """
        return self.read_optical_flow(*self._optical_flow_paths(num))

    def read_optical_flow(self, path_flowx, path_flowy, path_flowz):
        """ Stack the three flow-component volumes into one float array of shape (3, *volume_shape). """
        components = [open_volume(path) for path in (path_flowx, path_flowy, path_flowz)]
        return np.stack(components).astype(float)

    def euler_matrix(self, rot, tilt, psi):
        """ Return the 3x3 rotation matrix for Euler angles (rot, tilt, psi) in degrees.

        Its third row is (cos(rot)sin(tilt), sin(rot)sin(tilt), cos(tilt)),
        presumably the projection direction of xmipp_phantom_project's
        --angles convention (NOTE(review): confirm against Xmipp docs).
        """
        from math import sin, cos, radians
        t1 = -radians(psi)
        t2 = -radians(tilt)
        t3 = -radians(rot)
        a11 = cos(t1) * cos(t2) * cos(t3) - sin(t1) * sin(t3)
        a12 = -cos(t3) * sin(t1) - cos(t1) * cos(t2) * sin(t3)
        a13 = cos(t1) * sin(t2)
        a21 = cos(t1) * sin(t3) + cos(t2) * cos(t3) * sin(t1)
        a22 = cos(t1) * cos(t3) - cos(t2) * sin(t1) * sin(t3)
        a23 = sin(t1) * sin(t2)
        a31 = -cos(t3) * sin(t2)
        a32 = sin(t2) * sin(t3)
        a33 = cos(t2)
        return np.array([[a11, a12, a13], [a21, a22, a23], [a31, a32, a33]])
