from __future__ import division
import seaborn as sns
import six

from matplotlib import colors, cm
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import warnings
# NOTE(review): executed at import time with no category/module restriction,
# so importing this module installs an "ignore" filter for every warning --
# including deprecation warnings raised by the libraries used below.
warnings.filterwarnings("ignore")

def _get_grid_points(x, num_grid_points):
    """
    Choose the grid of values at which ICE curves are evaluated.

    :param x: the observed values of the column being varied
    :type x: ``pandas`` ``Series``

    :param num_grid_points: the number of evenly spaced quantile levels
        (from 0 to 1 inclusive) to evaluate; ``None`` selects every distinct
        value of ``x`` instead
    :type num_grid_points: ``None`` or ``int``

    :return: the distinct grid values
    """
    if num_grid_points is not None:
        quantile_levels = np.linspace(0, 1, num_grid_points)
        # De-duplicate: when far more grid points are requested than there
        # are observations, neighbouring quantiles coincide (even with
        # interpolation).
        return x.quantile(quantile_levels).unique()

    return x.unique()


def _get_point_x_ilocs(grid_index, data_index):
    """
    Locate, for every data point, the grid position nearest to its own value.

    :param grid_index: the grid values; its ``name`` identifies which level
        of ``data_index`` (``'data_<name>'``) holds the observed values
    :param data_index: an index exposing ``get_level_values`` with a
        ``'data_<name>'`` level

    :return: for each entry of ``data_index``, the integer position in
        ``grid_index`` with the smallest absolute difference
    """
    level_name = 'data_{}'.format(grid_index.name)
    observed_values = data_index.get_level_values(level_name)

    # distances[i, j] == |grid_index[i] - observed_values[j]|
    distances = np.abs(np.subtract.outer(grid_index, observed_values))

    return distances.argmin(axis=0)


def _get_quantiles(x):
    """
    Compute the empirical quantile of every element of ``x``.

    :param x: a one-dimensional collection of comparable values

    :return: for each element, the fraction of the elements of ``x`` that
        are strictly smaller than it
    """
    # is_greater[i, j] is True when x[i] > x[j]
    is_greater = np.greater.outer(x, x)
    smaller_counts = is_greater.sum(axis=1)
    return smaller_counts / x.size

def _ax_hist(ax, x, **kwargs):
    """
    Draw small tick marks along the axis showing how ``x`` is distributed.

    :param ax: the ``Axes`` to draw the rug plot on
    :param x: the values to mark

    Additional keyword arguments are accepted but are not used.
    """
    rug_options = {'ax': ax, 'alpha': 0.2}
    sns.rugplot(x, **rug_options)

def ice(data, column, predict, num_grid_points=None):
    """
    Generate individual conditional expectation (ICE) curves for a model.

    Every row of ``data`` is replicated once per grid value of ``column``,
    the model is evaluated on the replicated rows, and the predictions are
    pivoted so that each original row becomes one curve.

    :param data: the sample data from which to generate ICE curves
    :type data: ``pandas`` ``DataFrame``

    :param column: the name of the column in ``data`` that will be varied to
        generate ICE curves
    :type column: ``str``

    :param predict: the function that generates predictions from the model.
        It is called once with the ``.values`` of the expanded frame, i.e. a
        two-dimensional array whose columns are in the same order as
        ``data.columns``.
    :type predict: callable

    :param num_grid_points: the number of grid points to use for the
        independent variable of the ICE curves; the grid values are then
        quantiles of ``data[column]``.

        If ``None``, the grid consists of the unique values of
        ``data[column]``.
    :type num_grid_points: ``None`` or ``int``

    :return: A ``DataFrame`` whose columns are ICE curves.  The row index is
        the independent variable, and the column index identifies the
        original data point (its observed value of ``column``, stored under
        ``'data_<column>'``, followed by its other column values).
    :rtype: ``pandas`` ``DataFrame``
    """
    observed_label = 'data_{}'.format(column)

    grid_values = _get_grid_points(data[column], num_grid_points)
    expanded, observed_values = _to_ice_data(data, column, grid_values)

    # Predict before any bookkeeping columns are attached, so the model only
    # sees the original feature columns.
    expanded['ice_y'] = predict(expanded.values)
    expanded[observed_label] = observed_values

    curve_keys = [observed_label]
    curve_keys.extend(name for name in data.columns if name != column)

    pivoted = expanded.pivot_table(values='ice_y', index=curve_keys, columns=column)

    return pivoted.T


def ice_plot(ice_data, column, frac_to_plot=1.,
             plot_points=False, point_kwargs=None,
             x_quantile=False, plot_pdp=False,
             centered=False, centered_quantile=0.,
             color_by=None, cmap=None,
             ax=None, pdp_kwargs=None, **kwargs):
    """
    Plot the ICE curves

    :param ice_data: the ICE data generated by :func:`pycebox.ice.ice`
    :type ice_data: ``pandas`` ``DataFrame``

    :param column: the label written on the x-axis
    :type column: ``str``

    :param frac_to_plot: the fraction of ICE curves to plot.  If less than one,
        randomly samples ``int(frac_to_plot * n_curves)`` columns of
        ``ice_data`` (without replacement) to plot.
    :type frac_to_plot: ``float``

    :param plot_points: whether or not to plot the original data points on the
        ICE curves.  In this case, ``point_kwargs`` is passed as keyword
        arguments to ``scatter``.
    :type plot_points: ``bool``

    :param point_kwargs: keyword arguments for the data point markers; they
        override the defaults (``zorder=10``, ``s=5``, ``c='black'``)
    :type point_kwargs: ``None`` or ``dict``

    :param x_quantile: if ``True``, the plotted x-coordinates are the quantiles of
        ``ice_data.index``
    :type x_quantile: ``bool``

    :param plot_pdp: if ``True``, plot the partial dependence plot.  In this
        case, ``pdp_kwargs`` is passed as keyword arguments to ``plot``.
    :type plot_pdp: ``bool``

    :param centered: if ``True``, each ICE curve is centered to zero at the
        percentile closest to ``centered_quantile``.
    :type centered: ``bool``

    :param centered_quantile: the quantile of ``ice_data.index`` at which the
        curves are centered when ``centered`` is ``True``
    :type centered_quantile: ``float``

    :param color_by:  If a string, color the ICE curve by that level of the
        column index.

        If callable, color the ICE curve by its return value when applied to a
        ``DataFrame`` of the column index of ``ice_data``
    :type color_by: ``None``, ``str``, or callable

    :param cmap: the colormap used to turn ``color_by`` values into colors
    :type cmap: ``matplotlib`` ``Colormap``

    :param ax: the ``Axes`` on which to plot the ICE curves; a new figure is
        created when ``None``
    :type ax: ``None`` or ``matplotlib`` ``Axes``

    :param pdp_kwargs: keyword arguments for the partial dependence curve;
        they override the defaults (``color='yellow'``, ``linestyle='solid'``,
        ``linewidth=2``)
    :type pdp_kwargs: ``None`` or ``dict``

    Other keyword arguments are passed to ``plot`` for the ICE curves.

    :raises ValueError: if ``color_by`` is neither ``None``, a string, nor
        callable

    :return: the ``Axes`` that was drawn on.  Note that ``plt.show()`` is
        called before returning.
    """
    if not ice_data.index.is_monotonic_increasing:
        ice_data = ice_data.sort_index()

    if centered:
        quantiles = _get_quantiles(ice_data.index)
        centered_quantile_iloc = np.abs(quantiles - centered_quantile).argmin()
        # Subtracting a row aligns on the columns, so every curve is shifted
        # by its own value at the centering position.
        ice_data = ice_data - ice_data.iloc[centered_quantile_iloc]

    if frac_to_plot < 1.:
        n_cols = ice_data.shape[1]
        # np.random.choice only accepts an integer sample size.
        n_sampled = int(frac_to_plot * n_cols)
        icols = np.random.choice(n_cols, size=n_sampled, replace=False)
        plot_ice_data = ice_data.iloc[:, icols]
    else:
        plot_ice_data = ice_data

    if x_quantile:
        x = _get_quantiles(ice_data.index)
    else:
        x = ice_data.index

    if plot_points:
        point_x_ilocs = _get_point_x_ilocs(plot_ice_data.index, plot_ice_data.columns)
        point_x = x[point_x_ilocs]
        point_y = plot_ice_data.values[point_x_ilocs, np.arange(point_x_ilocs.size)]

    if ax is None:
        _, ax = plt.subplots()

    if color_by is not None:
        if isinstance(color_by, six.string_types):
            colors_raw = plot_ice_data.columns.get_level_values(color_by).values
        elif callable(color_by):
            col_df = pd.DataFrame(list(plot_ice_data.columns.values),
                                  columns=plot_ice_data.columns.names)
            colors_raw = color_by(col_df)
        else:
            raise ValueError('color_by must be a string or function')

        norm = colors.Normalize(colors_raw.min(), colors_raw.max())
        m = cm.ScalarMappable(norm=norm, cmap=cmap)

        # Defaults first so that caller-supplied keywords take precedence
        # instead of raising a duplicate-keyword TypeError.
        curve_kwargs = {'zorder': 0}
        curve_kwargs.update(kwargs)

        # ``items`` is available across pandas versions; ``iteritems`` was
        # removed in pandas 2.0.
        for color_raw, (_, ice_curve) in zip(colors_raw, plot_ice_data.items()):
            ax.plot(x, ice_curve, c=m.to_rgba(color_raw), **curve_kwargs)
    else:
        curve_kwargs = {'zorder': 0, 'c': 'grey', 'linewidth': 0.5}
        curve_kwargs.update(kwargs)
        ax.plot(x, plot_ice_data, **curve_kwargs)

    if plot_points:
        scatter_kwargs = {'zorder': 10, 's': 5, 'c': 'black'}
        scatter_kwargs.update(point_kwargs or {})
        ax.scatter(point_x, point_y, **scatter_kwargs)

    if plot_pdp:
        pdp_style = {'color': 'yellow', 'linestyle': 'solid', 'linewidth': 2}
        pdp_style.update(pdp_kwargs or {})
        ax.plot(x, pdp(ice_data), **pdp_style)

    ax.set_ylabel("ICE value")
    ax.set_xlabel(column)
    # Pass the values positionally: the name of seaborn's first rugplot
    # parameter has changed between releases (``a`` is deprecated).
    sns.rugplot(x, ax=ax, alpha=0.2)

    plt.show()

    return ax


def pdp(ice_data):
    """
    Compute the partial dependence curve from ICE data.

    The curve is the mean over all ICE curves (the columns of ``ice_data``)
    at each value of the row index.

    :param ice_data: the ICE data generated by :func:`pycebox.ice.ice`
    :type ice_data: ``pandas`` ``DataFrame``

    :return: the partial dependence plot curve
    :rtype: ``pandas`` ``Series``
    """
    mean_curve = ice_data.mean(axis=1)
    return mean_curve


def _to_ice_data(data, column, x_s):
    """
    Build the expanded frame on which the model is evaluated.

    Each row of ``data`` is repeated ``x_s.size`` times in a row, and within
    every such group ``column`` is overwritten with the grid values ``x_s``.

    :param data: the sample data
    :type data: ``pandas`` ``DataFrame``

    :param column: the name of the column to overwrite with grid values

    :param x_s: the grid values (must expose ``size``)

    :return: a tuple of the expanded frame and a copy of the repeated
        original values of ``column``
    """
    n_rows = data.shape[0]

    repeated_rows = np.repeat(data.values, x_s.size, axis=0)
    expanded = pd.DataFrame(repeated_rows, columns=data.columns)

    # Keep the observed values before they are replaced by the grid.
    observed = expanded[column].copy()
    expanded[column] = np.tile(x_s, n_rows)

    return expanded, observed
