import pandas as pd
import numpy as np

def multi_assets(prz:'df', w_tar:'df', dates_reb:'arr', fee = 0, ret_ext = False):
    """Backtest a multi-asset portfolio that is rebalanced to target weights.

    Between rebalance dates the held weights drift with the asset returns; on
    a rebalance date the portfolio is reset to ``w_tar.loc[date]`` and a
    proportional turnover fee is charged.

    Parameters
    ----------
    prz : DataFrame
        Asset prices, one column per asset, indexed by date.
    w_tar : DataFrame
        Target weights, one row per rebalance date (looked up with
        ``w_tar.loc[date]``). Missing (NaN) target weights are treated as 0.
        Weights summing to less than 1 leave the remainder in cash, which
        earns no return.
    dates_reb : array-like
        Labels of ``prz.index`` on which to rebalance.
    fee : float
        Cost per unit of turnover; on a rebalance date the charged fee is
        ``fee * sum(|w_target - w_drifted|)``.
    ret_ext : bool
        If True, return an extended frame whose two-level columns hold
        ``prz``, ``ret``, ``w_tar``, ``w_curr`` and ``rst`` side by side;
        otherwise return only the result frame.

    Returns
    -------
    DataFrame
        Result frame indexed like ``prz`` with columns ``reb`` (1 on
        rebalance days, else 0), ``ret_gross``, ``fee``, ``ret_net`` and
        ``nav`` (cumulative product of ``1 + ret_net``), or the extended
        frame when ``ret_ext`` is True.
    """
    ret = prz.pct_change()
    # Missing returns (e.g. an asset that is not listed yet) are treated as 0
    # when drifting the weights; otherwise 0 * NaN turns a held weight into
    # NaN and later turnover fees on that asset are silently dropped.
    ret_drift = ret.fillna(0.0)
    noa = prz.shape[1]   # number of assets
    n = len(prz.index)
    # Vectorised membership test, computed once (also compares datetime
    # labels correctly and uses the values, not the index, of a Series).
    is_reb = prz.index.isin(dates_reb)

    w = pd.Series(0.0, index=w_tar.columns)    # the current weight
    reb = np.zeros(n, dtype=int)
    ret_gross = np.zeros(n)
    fees = np.zeros(n)
    w_hist = np.full((n, noa), np.nan)

    for i, date in enumerate(prz.index):
        if i > 0:
            # Return earned over the bar by the weights held at the previous
            # close; the first bar has no return.
            r = ret_drift.iloc[i]
            ret_gross[i] = (r * w).sum()
            # Drift the weights by dividing by the growth of the whole
            # portfolio (not by the sum of the weights), so cash and net
            # short exposure are preserved instead of re-normalised away.
            growth = 1.0 + ret_gross[i]
            w = w * (1 + r)
            if growth != 0:
                w = w / growth
        if is_reb[i]:
            w1 = w_tar.loc[date].fillna(0.0)  # the target weight
            fees[i] = (abs(w1 - w) * fee).sum()
            reb[i] = 1
            w = w1
        # Record the weights held at the close of this bar, aligned on assets.
        w_hist[i] = w.reindex(prz.columns).to_numpy(dtype=float)

    rst = pd.DataFrame({'reb': reb, 'ret_gross': ret_gross, 'fee': fees},
                       index=prz.index)
    rst['ret_net'] = rst['ret_gross'] - rst['fee']
    rst['nav'] = (1 + rst['ret_net']).cumprod()
    w_curr = pd.DataFrame(w_hist, index=prz.index, columns=prz.columns)

    if ret_ext:
        # Inputs and intermediate results side by side under a two-level
        # column index: (section, asset) for the per-asset frames and
        # ('rst', column) for the result frame.
        level1 = np.array(['prz'] * noa + ['ret'] * noa + ['w_tar'] * noa + ['w_curr'] * noa + ['rst'] * len(rst.columns))
        level2 = list(prz.columns) * 4 + rst.columns.tolist()
        df = pd.DataFrame(columns = [level1,level2], index = prz.index)
        df['prz'] = prz
        df['ret'] = ret
        df['w_tar'] = w_tar
        df['rst'] = rst
        df['w_curr'] = w_curr
        return df
    else:
        return rst

def time_series(prz:'arr', pos:'arr', fee = 0, ret_ext = False):
    """Backtest a single-asset strategy from prices and positions.

    The position held after bar t-1 earns the return of bar t, so the
    strategy return on bar t is ``pos[t-1] * ret[t]``. The position before
    the first bar is taken to be 0. Every change of position costs
    ``fee * |pos[t] - pos[t-1]|`` in return units.

    Parameters
    ----------
    prz : array-like
        Prices, one per bar (array, list or Series).
    pos : array-like
        Position held after each bar, same length as ``prz``.
    fee : float
        Cost per unit of position change, in return units.
    ret_ext : bool
        If True, return the full working DataFrame; otherwise only the NAV.

    Returns
    -------
    numpy.ndarray or DataFrame
        Net-of-fee NAV per bar, or the working frame with price, return,
        fee, NAV and PnL columns.
    """
    df = pd.DataFrame(prz)
    df['price'] = prz
    df['position'] = pos
    df['return'] = df['price'].pct_change()
    df['value'] = df['position'] * df['price']
    # Position carried into each bar, flat before the first one. fill_value
    # avoids label-based `[0]` indexing, which misbehaves on date indexes or
    # integer indexes that do not start at 0.
    position_last = df['position'].shift(1, fill_value=0)
    df['fee'] = (df['position'] - position_last).abs() * fee
    # Assign back rather than chained `fillna(inplace=True)`, which does not
    # update `df` under pandas Copy-on-Write.
    df['strategy_return_before_fee'] = (position_last * df['return']).fillna(0)
    df['strategy_return'] = df['strategy_return_before_fee'] - df['fee']
    df['nav'] = (df['strategy_return'] + 1).cumprod()
    df['nav_before_fee'] = (df['strategy_return_before_fee'] + 1).cumprod()
    df['nav_price'] = df['price'] / df['price'].iloc[0]
    df['benchmark'] = df['nav_price']

    # Absolute (price-unit) PnL, fees converted at the current price.
    df['fee_amt'] = df['fee'] * df['price']
    df['pnl_before_fee'] = df['price'].diff().fillna(0) * position_last
    df['pnl'] = df['pnl_before_fee'] - df['fee_amt']
    df['pnl_accm_before_fee'] = df['pnl_before_fee'].cumsum()
    df['pnl_accm'] = df['pnl'].cumsum()
    df['pnl_accm_benchmark'] = df['price'] - df['price'].values[0]

    if ret_ext:
        return df
    else:
        return df['nav'].values


if __name__ == '__main__':
    # Smoke test on the sample data: backtest with a 0.2% turnover fee and
    # write the extended (inputs + results) frame to disk.
    # A fee-free run returning only the result frame would be:
    #     multi_assets(prices, targets, targets.index.values, fee=0).nav.plot()
    prices = pd.read_csv('Data/prz.csv', index_col='index', encoding='gbk')
    targets = pd.read_csv('Data/w_tar.csv', index_col='index')
    rebalance_dates = targets.index.values
    extended = multi_assets(prices, targets, rebalance_dates,
                            fee=2 / 1000, ret_ext=True)
    extended['rst'].reset_index()['nav'].plot()
    extended.to_csv('tmp.csv')