# Data-processing helpers: daily date frames, holiday/calendar dummy features, and plotting utilities #

import pandas as pd
from datetime import datetime
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

def create_date_df(startDate = '20200101', endDate = '20210101'):
  """Return a one-column DataFrame with one row per calendar day.

  Args:
    startDate: first day of the range; any value ``pd.date_range`` accepts,
      e.g. ``'20200101'``.
    endDate: last day of the range (included when it falls on a day step).

  Returns:
    DataFrame with a single datetime64 column ``'DATE'``, ascending, with any
    time-of-day component truncated to midnight. When ``startDate`` is after
    ``endDate`` the result is empty but still has the ``'DATE'`` column.
  """
  days = pd.date_range(start=startDate, end=endDate)
  # The old strftime('%Y-%m-%d') round-trip kept the local wall-clock date and
  # dropped any timezone; reproduce that without formatting/re-parsing strings.
  if days.tz is not None:
    days = days.tz_localize(None)
  # date_range is already ascending; normalize() truncates times to midnight.
  return pd.DataFrame({'DATE': days.normalize()})

def sg_holiday_feature(holiday_df, startDate = '20200101', endDate = '20210101'):
  """One-hot encode holiday names over a daily date range.

  Args:
    holiday_df: DataFrame with a ``'Holiday'`` column and a date column named
      ``'Date'`` (renamed to ``'DATE'``) or already ``'DATE'``. It is not
      modified.
    startDate: first day of the range, forwarded to ``create_date_df``.
    endDate: last day of the range, forwarded to ``create_date_df``.

  Returns:
    DataFrame with ``'DATE'`` followed by one ``Holiday_<name>`` indicator
    column per holiday name; days without a matching holiday are encoded
    under ``Holiday_Non-Holiday``.
  """
  date_pd = create_date_df(startDate = startDate, endDate = endDate)
  # rename() without inplace works on a copy; the previous inplace rename
  # silently changed the caller's column names.
  holidays = holiday_df.rename(columns = {'Date': 'DATE'})
  # infer_datetime_format is deprecated since pandas 2.0 (strict format
  # inference is now the default), so it is no longer passed.
  holidays['DATE'] = pd.to_datetime(holidays['DATE'])
  # A left merge keeps every day of the range in the left frame's order, so
  # sorting the holidays beforehand is unnecessary.
  # NOTE(review): a date listed twice in holiday_df produces two rows for that day.
  df = date_pd.merge(holidays[['DATE', 'Holiday']], on='DATE', how='left')
  df['Holiday'] = df['Holiday'].fillna('Non-Holiday')
  return pd.get_dummies(df, columns = ['Holiday'])

def get_dummy_value(df, dummy_columns):
  """One-hot encode the listed columns of ``df``.

  Thin wrapper over ``pd.get_dummies``: the result keeps the non-encoded
  columns first, then one ``<column>_<value>`` indicator per distinct value.
  """
  return pd.get_dummies(df, columns=dummy_columns)

def get_date_dummy(df, date_column = 'DATE'):
  """One-hot encode calendar features derived from a datetime column.

  Derives month (1-12), ISO week of year and day of week (1 = Monday ...
  7 = Sunday) from ``df[date_column]`` and replaces each with indicator
  columns named ``month_<m>``, ``weekofyear_<w>`` and ``dayofweek_<d>``
  (one per value present in the data).

  Args:
    df: DataFrame containing ``date_column``; it is not modified.
    date_column: name of a datetime64 column without missing values.

  Returns:
    A new DataFrame: the original columns (date column included) followed by
    the indicator columns.
  """
  dates = df[date_column].dt
  # assign() returns a copy, so the caller's frame no longer gains the helper
  # columns as a side effect; vectorised .dt accessors replace per-row lambdas.
  features = df.assign(
    month=dates.month,
    # isocalendar().week is nullable UInt32; cast to int64 like the old map().
    weekofyear=dates.isocalendar().week.astype('int64'),
    # pandas dayofweek is Monday=0, shifted so Monday=1 ... Sunday=7.
    dayofweek=dates.dayofweek + 1,
  )
  return pd.get_dummies(features, columns=['month', 'weekofyear', 'dayofweek'])

def display_heatmap(df, show_font = False, show_square = True, picture_size = (18, 18)):
  """Draw a heatmap of ``df.corr()`` on a new figure.

  The colour scale is capped at ``vmax=0.3``.

  Args:
    df: DataFrame whose pairwise column correlations are plotted.
    show_font: write each cell's value (2 decimals) into the cell; only used
      when ``show_square`` is True.
    show_square: despite its name, True draws the full matrix with default
      cell aspect, while False masks the upper triangle (diagonal included)
      and draws square cells with thin separators.
    picture_size: figure size in inches as (width, height).

  Returns:
    The matplotlib Axes returned by ``sns.heatmap``.
  """
  plt.figure(figsize=picture_size, dpi=100)
  # Compute the correlation matrix once instead of once per use.
  corr = df.corr()
  if show_square:
    return sns.heatmap(data=corr, vmax=0.3, annot=show_font, fmt='.2f')
  # np.bool was removed in NumPy 1.24; the builtin bool is the supported dtype.
  mask = np.triu(np.ones_like(corr, dtype=bool))
  return sns.heatmap(data=corr, vmax=0.3, annot=False, fmt=".2f", mask=mask,
                     square=True, linewidths=.1)

def show_draft_plot(datas, x_label, title, legend, picture_size=(18, 5), shape=None):
    """Plot several series against shared x values and show the figure.

    Args:
        datas: sequence of y-value sequences, one per series.
        x_label: x values shared by every series (passed as ``x`` to
            ``plt.plot``).
        title: figure title.
        legend: legend label per series, aligned with ``datas``.
        picture_size: figure size in inches as (width, height); applied only
            for this call.
        shape: per-series style aligned with ``datas``: ``0``/``'line'`` draws
            a line, ``1``/``'dot'`` draws circle markers only; any other value
            skips that series. ``None`` or empty means every series is a line.
    """
    # None replaces the old mutable [] default; len() also works for numpy
    # arrays, where `shape == []` is an elementwise comparison, not a bool.
    if shape is None or len(shape) == 0:
        shape = np.zeros(len(datas), dtype=np.int64)
    # Scope the figure size to this plot rather than overwriting the global
    # rcParams for every later figure in the session.
    with plt.rc_context({"figure.figsize": picture_size}):
        for i, series in enumerate(datas):
            kind = shape[i]
            if kind == 0 or kind == 'line':
                plt.plot(x_label, series, label=legend[i])
            elif kind == 1 or kind == 'dot':
                plt.plot(x_label, series, 'o', label=legend[i])
        plt.title(title)
        plt.legend(loc="best", shadow=True)
        plt.xticks(rotation=45)
        plt.grid()
        plt.show()

def switch_y_column(df, column_name):
    """Return a copy of ``df`` with ``column_name`` moved to the last position.

    Typically used to put the target (y) column at the end; ``df`` itself is
    left unchanged.
    """
    target = df[column_name]
    rest = df.drop(columns=column_name)
    rest.insert(len(rest.columns), column_name, target)
    return rest