"""
Visualization module based on matplotlib

"""
# Copyright 2018 European Union
# This file is part of pyposeidon.
# Licensed under the EUPL, Version 1.2 or – as soon they will be approved by the European Commission - subsequent versions of the EUPL (the "Licence").
# Unless required by applicable law or agreed to in writing, software distributed under the Licence is distributed on an "AS IS" basis, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the Licence for the specific language governing permissions and limitations under the Licence.


import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import xarray as xr
import geopandas as gp
import shapely
from pyposeidon.utils.quads2tr import quads_to_tris

import sys
import os

# Path of the ffmpeg executable under the running interpreter's prefix (sys.exec_prefix),
# exported as FFMPEG_BINARY. It is set before `animation` is imported, presumably so that
# code imported afterwards sees it.
# NOTE(review): matplotlib itself locates ffmpeg through rcParams["animation.ffmpeg_path"];
# confirm which consumer actually reads FFMPEG_BINARY.
ffmpeg = sys.exec_prefix + "/bin/ffmpeg"
os.environ["FFMPEG_BINARY"] = ffmpeg
from matplotlib import animation


# Notebook representation of animations. The "html5" value set by matplotlib.rc is
# immediately overridden by "jshtml" on the following line, so "jshtml" is what applies.
matplotlib.rc("animation", html="html5")
plt.rcParams["animation.html"] = "jshtml"
# Size limit for embedded animations (matplotlib interprets this rcParam in MB).
plt.rcParams["animation.embed_limit"] = "200."


def __init__(dark_background=False):
    """Optionally switch matplotlib to the ``dark_background`` style.

    This is a plain module-level function: it is not executed on import and has to
    be called explicitly. It always returns None.
    """
    if not dark_background:
        return
    plt.style.use("dark_background")


@xr.register_dataset_accessor("gplot")
class gplot(object):
    """xarray ``Dataset`` accessor (``ds.gplot``) for animated maps of gridded output.

    Coordinate variables may be 2-D, or have more than two dimensions, in which case
    the first slice along their leading axis is used.
    """

    def __init__(self, xarray_obj):
        self._obj = xarray_obj

    @staticmethod
    def _frame_artists(cs):
        """Return the list of artists making up one contour frame.

        From matplotlib 3.8 a ``ContourSet`` is itself an Artist and its
        ``collections`` attribute is deprecated (removed in 3.10); older versions
        only expose the per-level ``collections``.
        """
        if isinstance(cs, matplotlib.artist.Artist):
            return [cs]
        return list(cs.collections)

    def _grid_xy(self, x, y):
        """Return 2-D coordinate arrays for the variables named ``x`` and ``y``.

        If ``x`` has more than two dimensions, the first slice along the leading axis
        of both ``x`` and ``y`` is taken.
        """
        if len(self._obj[x].shape) > 2:
            return self._obj[x].values[0, :, :], self._obj[y].values[0, :, :]
        return self._obj[x].values, self._obj[y].values

    def contourf(self, x=None, y=None, z=None, tname="time", **kwargs):
        """Animate filled contours of ``z`` on a PlateCarree map, one frame per time step.

        Parameters
        ----------
        x, y : str
            Names of the longitude/latitude coordinate variables.
        z : str
            Name of the plotted variable, indexed as ``z[i, :, :]`` for frame ``i``.
        tname : str
            Name of the time variable.
        **kwargs
            ``vmin``/``vmax`` (default: data min/max), ``nv`` (number of levels,
            default 10) and ``title``.

        Returns
        -------
        matplotlib.animation.ArtistAnimation

        Raises
        ------
        ValueError
            If the time variable is empty (there is nothing to animate).
        """
        fig = plt.figure(figsize=(12, 8))

        grid_x, grid_y = self._grid_xy(x, y)
        z_ = self._obj[z].values
        t = self._obj[tname].values

        vmin = kwargs.get("vmin", z_.min())
        vmax = kwargs.get("vmax", z_.max())
        nv = kwargs.get("nv", 10)
        title = kwargs.get("title", None)

        vrange = np.linspace(vmin, vmax, nv, endpoint=True)

        # Alternative projection: ccrs.Orthographic(grid_x.mean(), grid_y.mean())
        ax = plt.axes(projection=ccrs.PlateCarree())
        ax.set_aspect("equal")

        if len(t) == 0:
            plt.close(fig)
            raise ValueError("no time steps to animate")

        ims = []
        for i, ti in enumerate(t):
            im = ax.contourf(grid_x, grid_y, z_[i, :, :], vrange, vmin=vmin, vmax=vmax, transform=ccrs.PlateCarree())
            an = ax.annotate("time={}".format(ti), xy=(0.05, 1.05), xycoords="axes fraction")
            ims.append(self._frame_artists(im) + [an])

        if title:
            ax.set_title(title)
        ax.coastlines("50m")
        ax.set_extent([grid_x.min(), grid_x.max(), grid_y.min(), grid_y.max()])
        ax.gridlines(draw_labels=True)

        # The colorbar is built from the last frame; all frames share the same levels.
        fig.colorbar(im, ticks=vrange, orientation="vertical")

        anim = animation.ArtistAnimation(fig, ims, interval=200, blit=False, repeat=False)

        # Presumably closed so the static figure is not displayed in addition to the animation.
        plt.close(fig)

        return anim

    @staticmethod
    def update_quiver(num, Q, U, V, step):
        """Set the arrows of ``Q`` to frame ``num`` of ``U``/``V``, taking every ``step``-th point.

        Returns
        -------
        tuple
            ``(Q,)``, the artists modified for this frame.
        """
        Q.set_UVC(U[num, ::step, ::step], V[num, ::step, ::step])
        return (Q,)

    def quiver(self, x=None, y=None, z=None, tname="time", **kwargs):
        """Animate a vector field on a PlateCarree map with land/ocean backgrounds.

        Parameters
        ----------
        x, y : str
            Names of the coordinate variables.
        z : str
            Name of the vector variable, indexed as ``z[time, :, :, component]`` with
            component 0 = u and 1 = v.
        tname : str
            Name of the time variable; one frame is produced per entry.
        **kwargs
            ``scale`` (default 1.0), ``step`` (arrow subsampling, default 1), ``title``.

        Returns
        -------
        matplotlib.animation.FuncAnimation
        """
        vec = self._obj[z].values
        U = vec[:, :, :, 0]
        V = vec[:, :, :, 1]
        X, Y = self._grid_xy(x, y)
        t = self._obj[tname].values

        crs = ccrs.PlateCarree()
        fig = plt.figure(figsize=(12, 8))
        ax = plt.axes(projection=crs)
        ax.set_aspect("equal")

        land_50m = cfeature.NaturalEarthFeature(
            "physical", "land", "50m", edgecolor="face", facecolor=cfeature.COLORS["land"], zorder=0
        )
        sea_50m = cfeature.NaturalEarthFeature(
            "physical", "ocean", "50m", edgecolor="face", facecolor=cfeature.COLORS["water"], zorder=0
        )

        ax.coastlines("50m")
        ax.add_feature(land_50m)
        ax.add_feature(sea_50m)

        title = kwargs.get("title", None)
        scale = kwargs.get("scale", 1.0)
        step = kwargs.get("step", 1)

        Q = ax.quiver(
            X[::step, ::step],
            Y[::step, ::step],
            U[0, ::step, ::step],
            V[0, ::step, ::step],
            pivot="mid",
            color="k",
            angles="xy",
            scale_units="xy",
            scale=scale,
            transform=crs,
        )

        ax.set_xlim(X.min(), X.max())
        ax.set_ylim(Y.min(), Y.max())
        ax.set_title(title)

        # blit=False: otherwise the first set of arrows never gets cleared on later frames.
        anim = animation.FuncAnimation(
            fig, self.update_quiver, fargs=(Q, U, V, step), frames=range(np.size(t)), blit=False, repeat=False
        )

        plt.close(fig)

        return anim


def _video_tag(fname, mimetype):
    """Return an HTML5 ``<video>`` element embedding the file `fname` as a base64 data URI.

    Parameters
    ----------
    fname : str or path-like
        Video file to embed; it is read in binary mode.
    mimetype : str
        Subtype written after ``video/`` in the data URI (e.g. ``"mp4"``).
    """
    import base64

    # `with` guarantees the file handle is closed once the bytes are read.
    with open(fname, "rb") as video_file:
        video_encoded = base64.b64encode(video_file.read()).decode("ascii")
    return '<video controls alt="test" src="data:video/{0};base64,{1}"></video>'.format(mimetype, video_encoded)


def to_html5(fname, mimetype):
    """Load the video in the file `fname`, with given mimetype, and display as HTML5 video.

    Returns
    -------
    IPython.display.HTML
        Object wrapping a self-contained ``<video>`` element.
    """
    from IPython.display import HTML

    return HTML(data=_video_tag(fname, mimetype))


def update_qframes(num, Q, U, V):
    """Frame callback: point the arrows of quiver ``Q`` at time step ``num``.

    ``U[num, :]`` and ``V[num, :]`` become the new vector components.

    Returns
    -------
    tuple
        ``(Q,)``, the artists modified for this frame.
    """
    frame_u, frame_v = U[num, :], V[num, :]
    Q.set_UVC(frame_u, frame_v)
    return Q,


@xr.register_dataset_accessor("pplot")
# @xr.register_dataarray_accessor('pplot')
class pplot(object):
    """xarray ``Dataset`` accessor (``ds.pplot``) for plots on a SCHISM unstructured grid.

    Unless overridden through keyword arguments (``x``, ``y``, ``tes``, ``t``), node
    coordinates come from ``SCHISM_hgrid_node_x``/``SCHISM_hgrid_node_y``, element
    connectivity from the first four columns of ``SCHISM_hgrid_face_nodes`` and times
    from ``time``. Dataset variables are only read when the keyword is not supplied.
    """

    def __init__(self, xarray_obj):
        self._obj = xarray_obj

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _get(kwargs, key, default_factory):
        """Return ``kwargs[key]`` if present, otherwise ``default_factory()``.

        Unlike ``kwargs.get(key, default)`` the default is computed only when needed,
        so a dataset variable is not required when the caller supplies the value.
        """
        if key in kwargs:
            return kwargs[key]
        return default_factory()

    def _nodes(self, kwargs):
        """Return node coordinates ``(x, y)`` from ``kwargs`` or the dataset."""
        x = self._get(kwargs, "x", lambda: self._obj.SCHISM_hgrid_node_x[:].values)
        y = self._get(kwargs, "y", lambda: self._obj.SCHISM_hgrid_node_y[:].values)
        return x, y

    def _elements(self, kwargs):
        """Return the element/node table from ``kwargs['tes']`` or the dataset."""
        return self._get(kwargs, "tes", lambda: self._obj.SCHISM_hgrid_face_nodes.values[:, :4])

    def _times(self, kwargs):
        """Return ``kwargs['t']``, else the dataset ``time`` values, else None if there is no ``time``."""
        if "t" in kwargs:
            return kwargs["t"]
        try:
            return self._obj.time.values
        except AttributeError:
            return None

    @staticmethod
    def _triangles(tes):
        """Return an integer triangle table built from element table ``tes``.

        Rows whose 4th column is NaN are triangles (first three columns kept); the
        other rows are quads and are split with ``quads_to_tris``. A table with fewer
        than four columns is used as-is.
        """
        tes = np.asarray(tes)
        if tes.ndim != 2 or tes.shape[1] < 4:
            return tes.astype(int)
        is_tri = np.isnan(tes[:, 3])
        tri3 = tes[is_tri][:, :3]
        quads = tes[~is_tri]
        if len(quads) > 0:
            split = quads_to_tris(quads)
            # len() works whether quads_to_tris returns a list or an array.
            if len(split) > 0:
                tri3 = np.append(tri3, split, axis=0)
        return tri3.astype(int)

    @staticmethod
    def _strip(kwargs, keys):
        """Remove this module's option keys so the remainder can be forwarded to matplotlib."""
        for key in keys:
            kwargs.pop(key, None)

    @staticmethod
    def _annotate_time(ax, t, it, xy):
        """Write ``time=<t[it]>`` at axes-fraction ``xy`` when a step (including 0) is selected."""
        if it is None or t is None:
            return None
        return ax.annotate("time={}".format(t[it]), xy=xy, xycoords="axes fraction")

    @staticmethod
    def _frame_artists(cs):
        """Return the list of artists making up one contour frame.

        From matplotlib 3.8 a ``ContourSet`` is itself an Artist and its
        ``collections`` attribute is deprecated (removed in 3.10); older versions
        only expose the per-level ``collections``.
        """
        if isinstance(cs, matplotlib.artist.Artist):
            return [cs]
        return list(cs.collections)

    # ------------------------------------------------------------------ plots

    def contour(self, it=None, **kwargs):
        """Draw line contours of a node variable on the unstructured grid.

        Parameters
        ----------
        it : int, optional
            Time index used as ``var[it, :]``; when None the whole variable is flattened.
        **kwargs
            ``x``, ``y``, ``tes``, ``t``: overrides for grid and time arrays.
            ``var`` (default ``"depth"``) or ``z``: variable name or explicit values.
            ``vmin``, ``vmax`` (default: data min/max), ``nv`` (default 10): levels.
            ``mask``: masked entries of ``z`` are replaced by -99999.
            ``title``, ``xy`` (time-annotation position, axes fraction).
            Remaining keywords are forwarded to ``tricontour``.

        Returns
        -------
        matplotlib.tri.TriContourSet
        """
        x, y = self._nodes(kwargs)
        t = self._times(kwargs)
        tri3 = self._triangles(self._elements(kwargs))

        var = kwargs.get("var", "depth")
        z = self._get(kwargs, "z", lambda: self._obj[var].values[it, :].flatten())

        fig, ax = plt.subplots(figsize=(12, 8))
        vmin = kwargs.get("vmin", z.min())
        vmax = kwargs.get("vmax", z.max())
        nv = kwargs.get("nv", 10)
        xy = kwargs.get("xy", (0.3, 1.05))
        title = kwargs.get("title", "contour plot for {}".format(var))

        vrange = np.linspace(vmin, vmax, nv, endpoint=True)

        # Masked values get a fill far below the levels, so no contour is drawn for them.
        if "mask" in kwargs:
            z = np.ma.masked_array(z, kwargs["mask"]).filled(fill_value=-99999)

        self._strip(kwargs, ["x", "y", "t", "it", "vmin", "vmax", "title", "nv", "tes", "mask", "xy", "z", "var"])

        ax.set_aspect("equal")

        p = ax.tricontour(x, y, tri3, z, vrange, vmin=vmin, vmax=vmax, **kwargs)
        fig.colorbar(p, ticks=vrange, orientation="vertical")
        self._annotate_time(ax, t, it, xy)

        ax.set_title(title, pad=30)
        ax.set_xlabel("Longitude (degrees)")
        ax.set_ylabel("Latitude (degrees)")

        return p

    def contourf(self, it=None, **kwargs):
        """Draw filled contours of a node variable on the unstructured grid.

        Parameters
        ----------
        it : int, optional
            Time index used as ``var[it, :]``; when None the whole variable is flattened.
        **kwargs
            Same options as ``contour`` (``figure`` is accepted and ignored).
            Remaining keywords are forwarded to ``tricontourf``.

        Returns
        -------
        matplotlib.tri.TriContourSet
        """
        x, y = self._nodes(kwargs)
        t = self._times(kwargs)
        tri3 = self._triangles(self._elements(kwargs))

        var = kwargs.get("var", "depth")
        z = self._get(kwargs, "z", lambda: self._obj[var].values[it, :].flatten())

        vmin = kwargs.get("vmin", z.min())
        vmax = kwargs.get("vmax", z.max())
        nv = kwargs.get("nv", 10)
        title = kwargs.get("title", "contourf plot for {}".format(var))
        xy = kwargs.get("xy", (0.3, 1.05))

        vrange = np.linspace(vmin, vmax, nv, endpoint=True)

        fig, ax = plt.subplots(figsize=(12, 8))

        # Masked values get a fill far below the levels, so those areas stay unfilled.
        if "mask" in kwargs:
            z = np.ma.masked_array(z, kwargs["mask"]).filled(fill_value=-99999)

        self._strip(
            kwargs, ["x", "y", "t", "it", "z", "vmin", "vmax", "title", "nv", "tes", "mask", "xy", "var", "figure"]
        )

        ax.set_aspect("equal")

        p = ax.tricontourf(x, y, tri3, z, vrange, vmin=vmin, vmax=vmax, **kwargs)
        fig.colorbar(p, ticks=vrange, orientation="vertical")
        self._annotate_time(ax, t, it, xy)

        ax.set_title(title, pad=30)
        ax.set_xlabel("Longitude (degrees)")
        ax.set_ylabel("Latitude (degrees)")

        return p

    def quiver(self, it=None, u=None, v=None, title=None, scale=0.1, color="k", **kwargs):
        """Draw a vector field at the grid nodes.

        Parameters
        ----------
        it : int, optional
            Time index used only for the ``time=`` annotation.
        u, v : array-like
            Vector components at the nodes.
        title : str, optional
            Name inserted into the plot title ``"vector plot for <title>"``.
        scale, color
            Passed to ``quiver`` (with ``angles="xy"``, ``scale_units="xy"``).
        **kwargs
            ``x``, ``y``, ``t`` overrides; ``xy`` (annotation position);
            ``mask``: masked vectors are passed on as a masked array, and matplotlib
            does not draw masked arrows. Remaining keywords go to ``quiver``.

        Returns
        -------
        matplotlib.quiver.Quiver
        """
        x, y = self._nodes(kwargs)
        t = self._times(kwargs)

        plt.figure(figsize=(12, 8))
        # ``title`` is a named parameter, so it can never arrive through **kwargs.
        title = "vector plot for {}".format(title)
        xy = kwargs.get("xy", (0.05, -0.1))

        ax = plt.gca()

        # Keep the masked arrays: filling them with a sentinel would draw huge arrows.
        if "mask" in kwargs:
            mask = kwargs["mask"]
            u = np.ma.masked_array(u, mask)
            v = np.ma.masked_array(v, mask)

        self._strip(kwargs, ["x", "y", "t", "it", "u", "v", "title", "tes", "xy", "scale", "mask", "color", "var"])

        ax.set_aspect("equal")

        p = ax.quiver(x, y, u, v, angles="xy", scale_units="xy", scale=scale, color=color, **kwargs)
        ax.set_xlabel("Longitude (degrees)")
        ax.set_ylabel("Latitude (degrees)")
        ax.set_title(title, pad=30)

        self._annotate_time(ax, t, it, xy)

        return p

    def grid(self, **kwargs):
        """Plot the mesh: triangles with ``triplot`` and quads as unfilled polygons.

        Parameters
        ----------
        **kwargs
            ``x``, ``y``, ``tes`` overrides; ``title`` (default ``"Grid plot"``).
            Remaining keywords are forwarded to ``triplot``; ``lw`` is also used for
            the quad outlines.

        Returns
        -------
        list of matplotlib.lines.Line2D
            The artists returned by ``triplot``.
        """
        x, y = self._nodes(kwargs)
        tes = np.asarray(self._elements(kwargs))

        # Rows whose 4th column is NaN are triangles; other rows are quads.
        if tes.ndim == 2 and tes.shape[1] >= 4:
            is_tri = np.isnan(tes[:, 3])
            tri3 = tes[is_tri][:, :3].astype(int).tolist()
            quads = tes[~is_tri].astype(int).tolist()
        else:
            tri3 = tes.astype(int)
            quads = []

        # ``title`` is consumed here so it is not forwarded to triplot (not a Line2D property).
        title = kwargs.pop("title", "Grid plot")
        self._strip(kwargs, ["x", "y", "tes"])

        plt.figure(figsize=(12, 8))
        ax = plt.gca()
        ax.set_aspect("equal")

        g = ax.triplot(x, y, tri3, "go-", **kwargs)

        lw = kwargs.get("lw", plt.rcParams["lines.linewidth"])
        # https://stackoverflow.com/questions/52202014/how-can-i-plot-2d-fem-results-using-matplotlib
        for element in quads:
            ax.fill([x[n] for n in element], [y[n] for n in element], edgecolor="green", fill=False, lw=lw)

        ax.set_title(title, pad=30)
        ax.set_xlabel("Longitude (degrees)")
        ax.set_ylabel("Latitude (degrees)")

        return g

    def qframes(self, u=None, v=None, scale=0.01, color="k", **kwargs):
        """Animate node vectors over time, one frame per entry of ``t``.

        Parameters
        ----------
        u, v : array-like
            Vector components indexed as ``[time, node]``.
        scale, color
            Passed to ``quiver`` (with ``pivot="mid"``, ``angles="xy"``, ``scale_units="xy"``).
        **kwargs
            ``x``, ``y``, ``t`` overrides; ``title``.

        Returns
        -------
        matplotlib.animation.FuncAnimation
        """
        x, y = self._nodes(kwargs)
        t = self._get(kwargs, "t", lambda: self._obj.time.values)
        title = kwargs.get("title", None)
        # TODO: coastline overlay (``coastlines`` / ``coastlines_attrs`` keywords) is not implemented.

        fig = plt.figure(figsize=(12, 8))
        ax = plt.gca()
        ax.set_aspect("equal")

        Q = ax.quiver(x, y, u[0, :], v[0, :], pivot="mid", color=color, angles="xy", scale_units="xy", scale=scale)

        ax.set_xlim(x.min(), x.max())
        ax.set_ylim(y.min(), y.max())
        ax.set_title(title)

        # blit=False: otherwise the first set of arrows never gets cleared on later frames.
        anim = animation.FuncAnimation(
            fig, update_qframes, fargs=(Q, u, v), blit=False, repeat=False, frames=range(np.size(t))
        )

        plt.close()

        return anim

    def frames(self, **kwargs):
        """Animate filled contours of a node variable, one frame per time step.

        Parameters
        ----------
        **kwargs
            ``x``, ``y``, ``t``, ``tes`` overrides; ``var`` (default ``"depth"``) or
            ``z`` (indexed as ``z[i, :]`` for frame ``i``); ``vmin``, ``vmax``, ``nv``
            (default 10); ``mask`` (masked entries replaced by -99999); ``title``.

        Returns
        -------
        matplotlib.animation.ArtistAnimation

        Raises
        ------
        ValueError
            If there are no time steps to animate.
        """
        x, y = self._nodes(kwargs)
        t = self._get(kwargs, "t", lambda: self._obj.time.values)
        tri3 = self._triangles(self._elements(kwargs))

        var = kwargs.get("var", "depth")
        z = self._get(kwargs, "z", lambda: self._obj[var].values)

        # Figure is 12 inches wide; height follows the domain's aspect ratio (rounded up, >= 1).
        x_span = x.max() - x.min()
        y_span = y.max() - y.min()
        ratio = y_span / x_span if x_span > 0 else 1.0
        fig_h = max(1, int(np.ceil(12 * ratio)))

        fig = plt.figure(figsize=(12, fig_h))
        vmin = kwargs.get("vmin", z.min())
        vmax = kwargs.get("vmax", z.max())
        nv = kwargs.get("nv", 10)
        title = kwargs.get("title", None)

        vrange = np.linspace(vmin, vmax, nv, endpoint=True)

        # Masked values get a fill far below the levels, so those areas stay unfilled.
        if "mask" in kwargs:
            z = np.ma.masked_array(z, kwargs["mask"]).filled(fill_value=-99999)

        ax = plt.axes()
        ax.set_aspect("equal")

        if len(t) == 0:
            plt.close(fig)
            raise ValueError("no time steps to animate")

        ims = []
        for i, ti in enumerate(t):
            im = ax.tricontourf(x, y, tri3, z[i, :], vrange, vmin=vmin, vmax=vmax)
            an = ax.annotate("time={}".format(ti), xy=(0.05, -0.1), xycoords="axes fraction")
            ims.append(self._frame_artists(im) + [an])

        # TODO: coastline overlay (``coastlines`` / ``coastlines_attrs`` keywords) is not implemented.

        if title:
            ax.set_title(title)

        # The colorbar is built from the last frame; all frames share the same levels.
        fig.colorbar(im, ticks=vrange, orientation="vertical", fraction=0.017, pad=0.04)

        anim = animation.ArtistAnimation(fig, ims, interval=200, blit=False, repeat=False)

        plt.close()

        return anim
