import torch
from torch import nn
from torch import optim
import numpy as np
from math import floor
from matplotlib import pyplot as plt

from model import Net
#from torch.nn import functional as F 


def trans_to_tensor(x, time_len, lag):
    """Build a lagged design matrix and the matching one-step-ahead targets.

    Row ``i`` of the design matrix holds the ``lag`` observations preceding
    ``x[lag + i]``, most recent first: ``[x[lag+i-1], ..., x[i]]``.

    Args:
        x: 1-D sequence of observations; it must contain ``time_len`` values.
        time_len: number of observations in ``x``.
        lag: number of past values used as predictors for each target.

    Returns:
        Tuple ``(x_tensor, y_tensor)`` of float32 tensors with shapes
        ``(time_len - lag, lag)`` and ``(time_len - lag, 1)``.
    """
    series = np.asarray(x)
    n_rows = time_len - lag
    # A plain ndarray is used instead of np.mat (removed in NumPy 2.0).
    # Entry (i, j) of the index grid is i + lag - 1 - j, so each row is the
    # window x[i:i+lag] read backwards (most recent observation first).
    row_start = np.arange(n_rows).reshape(-1, 1)
    back_offset = np.arange(lag - 1, -1, -1).reshape(1, -1)
    data = series[row_start + back_offset]
    x_tensor = torch.tensor(data).float().view(n_rows, lag)
    y_tensor = torch.tensor(series[lag:]).float().view(n_rows, 1)
    return x_tensor, y_tensor

def AIC(x, y, d_dim):
    """Pick the regression order with the smallest AIC-type score.

    For every k = 1, ..., d_dim, ``y`` is regressed on the first k columns of
    ``x`` by ordinary least squares (normal equations with an explicit
    inverse) and scored as ``n * log(SSE) + 2 * (k - 1)``, where n is the
    number of rows of ``x``.

    Args:
        x: 2-D tensor of predictors, one column per candidate lag.
        y: 2-D tensor of responses with the same number of rows as ``x``.
        d_dim: largest number of leading columns to consider.

    Returns:
        Zero-based index of the best candidate, i.e. (selected k) - 1.
    """
    n_obs = x.shape[0]
    scores = np.zeros(d_dim)
    for k in range(1, d_dim + 1):
        design = x[:, :k]
        gram_inv = torch.linalg.inv(torch.mm(design.T, design))
        cross = torch.mm(design.T, y)
        fitted = torch.mm(design, torch.mm(gram_inv, cross))
        sse = ((fitted - y) ** 2).sum().numpy()
        # The penalty uses k - 1 rather than k; the constant shift does not
        # change which candidate attains the minimum.
        scores[k - 1] = n_obs * np.log(sse) + 2 * (k - 1)
    return np.argmin(scores)


def gen1(A_mat, time_len, d_dim, lag_AR):
    """Simulate a linear autoregressive series and return lagged tensors.

    The series follows ``x_t = A_mat @ [x_{t-d_dim}, ..., x_{t-1}]^T + e_t``
    with standard normal noise ``e_t``; the first ``d_dim`` values are set to
    the noise itself. Note that the coefficients in ``A_mat`` are applied to
    the window oldest value first.

    Args:
        A_mat: tensor multiplied (via ``torch.mm``) with a float32 column of
            shape ``(d_dim, 1)``; the callers in this file pass shape
            ``(1, d_dim)``.
        time_len: number of (predictor, target) rows to return.
        d_dim: order of the autoregression.
        lag_AR: number of lagged values per predictor row.

    Returns:
        ``(x_ts, y_ts, y_no_epi, epi_y_ts)``: the lagged predictors and
        targets from ``trans_to_tensor`` (float32), the targets with the
        noise removed, and the noise itself (both built directly from float64
        arrays), each with ``time_len`` rows.
    """
    total = time_len + d_dim + lag_AR
    noise = np.random.normal(loc=0.0, scale=1.0, size=total)
    series = np.zeros(total)
    series[:d_dim] = noise[:d_dim]
    for t in range(d_dim, total):
        past = torch.tensor(series[t - d_dim:t], dtype=torch.float32).reshape(d_dim, 1)
        series[t] = torch.mm(A_mat, past) + noise[t]

    # Discard the d_dim start-up values before building the regression data.
    observed = series[d_dim:]
    signal = observed - noise[d_dim:]
    x_ts, y_ts = trans_to_tensor(observed, len(observed), lag_AR)
    noise_ts = torch.tensor(noise[(d_dim + lag_AR):total]).view(-1, 1)
    signal_ts = torch.tensor(signal[lag_AR:]).view(-1, 1)
    return x_ts, y_ts, signal_ts, noise_ts

def gen2(time_len, d_dim, lag_AR):
    """Simulate a nonlinear (square-root) autoregression and return lagged tensors.

    The series follows ``x_t = 0.5 * sum_j sqrt(|x_{t-j}|) + e_t`` for
    j = 1, ..., d_dim, with standard normal noise ``e_t``; the first ``d_dim``
    values are set to the noise itself.

    Args:
        time_len: number of (predictor, target) rows to return.
        d_dim: number of past values entering the recursion.
        lag_AR: number of lagged values per predictor row.

    Returns:
        ``(x_ts, y_ts, y_no_epi, epi_y_ts)``: the lagged predictors and
        targets from ``trans_to_tensor`` (float32), the targets with the
        noise removed, and the noise itself (both built directly from float64
        arrays), each with ``time_len`` rows.
    """
    total = time_len + d_dim + lag_AR
    noise = np.random.normal(loc=0.0, scale=1.0, size=total)
    series = np.zeros(total)
    series[:d_dim] = noise[:d_dim]
    for t in range(d_dim, total):
        past = series[t - d_dim:t].reshape(d_dim, 1)
        series[t] = 0.5 * (np.abs(past) ** 0.5).sum() + noise[t]

    # Discard the d_dim start-up values before building the regression data.
    observed = series[d_dim:]
    signal = observed - noise[d_dim:]
    x_ts, y_ts = trans_to_tensor(observed, len(observed), lag_AR)
    noise_ts = torch.tensor(noise[(d_dim + lag_AR):total]).view(-1, 1)
    signal_ts = torch.tensor(signal[lag_AR:]).view(-1, 1)
    return x_ts, y_ts, signal_ts, noise_ts


def gen3(time_len, d_dim, lag_AR):
    """Simulate an absolute-value autoregression and return lagged tensors.

    The series follows ``x_t = sum_j 0.5 * |x_{t-j}| + e_t`` for
    j = 1, ..., d_dim, with standard normal noise ``e_t``; the first ``d_dim``
    values are set to the noise itself.

    Args:
        time_len: number of (predictor, target) rows to return.
        d_dim: number of past values entering the recursion.
        lag_AR: number of lagged values per predictor row.

    Returns:
        ``(x_ts, y_ts, y_no_epi, epi_y_ts)``: the lagged predictors and
        targets from ``trans_to_tensor`` (float32), the targets with the
        noise removed, and the noise itself (both built directly from float64
        arrays), each with ``time_len`` rows.
    """
    total = time_len + d_dim + lag_AR
    noise = np.random.normal(loc=0.0, scale=1.0, size=total)
    series = np.zeros(total)
    series[:d_dim] = noise[:d_dim]
    for t in range(d_dim, total):
        past = series[t - d_dim:t].reshape(d_dim, 1)
        series[t] = (0.5 * np.abs(past)).sum() + noise[t]

    # Discard the d_dim start-up values before building the regression data.
    observed = series[d_dim:]
    signal = observed - noise[d_dim:]
    x_ts, y_ts = trans_to_tensor(observed, len(observed), lag_AR)
    noise_ts = torch.tensor(noise[(d_dim + lag_AR):total]).view(-1, 1)
    signal_ts = torch.tensor(signal[lag_AR:]).view(-1, 1)
    return x_ts, y_ts, signal_ts, noise_ts



#-----------------------------------------------------------------------------------
# Simulation 2: neural network
# For every sample size in n_grid and every replicate: simulate a series,
# select the lag order by AIC, train Net with a penalised MSE loss and early
# stopping on a validation split, then record the test-set MSE between the
# prediction and the noise-free target in R_ts.
lag = 50  # AIC selects the best lag from the set {1, 2, ..., 50}
n_rep = 200  # number of replicates per sample size
n_grid = np.array([100,400,1600,6400])
ls_p = np.zeros((len(n_grid), n_rep))  # stores the lag selected in each replicate
lr = 1e-3
l1_lambda = 0.1
R_ts = np.zeros((n_rep, len(n_grid)))  # R_ts[i, n_ind]: test loss of replicate i
for n_ind, time_len in enumerate(n_grid):
    sp1 = floor(time_len*1/2)
    sp2 = floor(time_len*3/4)
    print(time_len)
    for i in range(n_rep):
        #model 1:
        x_ts_full, y_ts, y_no_epi, epi_y_ts = gen1(torch.tensor([[0.6]]),time_len, 1, lag)
        #model 2:
        #x_ts_full, y_ts, y_no_epi, epi_y_ts = gen1(torch.tensor([[0.6,-0.4,0.2]]),time_len, 3, lag)
        #model3:
        #x_ts_full, y_ts, y_no_epi, epi_y_ts = gen2(time_len, 1,lag)
        #model 4:
        #x_ts_full, y_ts, y_no_epi, epi_y_ts = gen3(time_len, 1, lag)

        # AIC returns a zero-based index, hence the +1 to get the lag count.
        p_dim = AIC(x_ts_full, y_ts, lag) + 1
        x_ts = x_ts_full[:,0:p_dim]
        ls_p[n_ind, i] = p_dim

        model_ts = Net(p_dim)
        criteon = nn.MSELoss()
        optimizer = optim.Adam(model_ts.parameters(), lr)

        # train the neural network
        # training set: 0~sp1, validation set: sp1~sp2 , testing set: sp2~time_len
        # Start from +inf so the first validation check can never trigger the
        # early stop, whatever the scale of the loss.
        loss_valid = float('inf')
        for epoch in range(4000):
            # train()/eval() propagate the mode to every sub-module, unlike
            # assigning the `training` attribute on the top-level module only.
            model_ts.train()
            output_ts = model_ts(x_ts[0:sp1,:])
            l1_reg = torch.tensor(0.)
            # Penalise every second parameter tensor (presumably the weights,
            # with biases in between -- verify against Net's definition).
            # NOTE(review): for a 2-D parameter torch.linalg.norm(param, 1) is
            # the induced matrix 1-norm (max absolute column sum), not the
            # element-wise L1 sum -- confirm this is the intended penalty.
            for it, param in enumerate(model_ts.parameters()):
                if it % 2 == 0:
                    l1_reg += torch.linalg.norm(param, 1)
            loss_ts_train = criteon(output_ts, y_ts[0:sp1,]) + l1_reg*l1_lambda
            #backprop
            optimizer.zero_grad()
            loss_ts_train.backward()
            optimizer.step()

            # Early stopping: every 20 epochs compare the validation loss with
            # the previous check and stop as soon as it gets worse.
            if epoch % 20 == 0:
                model_ts.eval()
                with torch.no_grad():
                    valid_ts = model_ts(x_ts[sp1:sp2,:])
                    loss_valid_temp = criteon(valid_ts, y_ts[sp1:sp2,]).item()
                if loss_valid_temp > loss_valid:
                    break
                loss_valid = loss_valid_temp

        # Test loss against the noise-free target (observation minus noise).
        model_ts.eval()
        with torch.no_grad():
            output_ts = model_ts(x_ts[sp2:time_len,:])
            loss_ts_ne = criteon(output_ts, (y_ts-epi_y_ts)[sp2:time_len,])

        R_ts[i,n_ind] = loss_ts_ne.item()

        
    
#---------------------------------------------------------------------------
# Simulation 2: linear regression
# For every sample size in n_grid and every replicate: simulate a series,
# select the lag order by AIC, fit ordinary least squares on the first 3/4 of
# the rows and record the MSE against the noise-free target on the rest.
# Uses the module-level `lag` defined in the neural-network section.
R_lm = np.zeros((200,4))  # R_lm[i, n_ind]: test loss of replicate i
n_grid = np.array([100,400,1600,6400])
# Defined here so this section does not depend on a variable that only exists
# as a leftover of the neural-network training loop.
criteon = nn.MSELoss()
for n_ind, time_len in enumerate(n_grid):
    sp = floor(time_len*3/4)
    print(time_len)
    for i in range(200):
        #model 1:
        x_ts_full, y_ts, y_no_epi, epi_y_ts = gen1(torch.tensor([[0.6]]),time_len, 1, lag)
        #model 2:
        #x_ts_full, y_ts, y_no_epi, epi_y_ts = gen1(torch.tensor([[0.6,-0.4,0.2]]),time_len, 3, lag)
        #model3:
        #x_ts_full, y_ts, y_no_epi, epi_y_ts = gen2(time_len, 1,lag)
        #model 4:
        #x_ts_full, y_ts, y_no_epi, epi_y_ts = gen3(time_len, 1, lag)

        # AIC returns a zero-based index, hence the +1 to get the lag count.
        p_dim = AIC(x_ts_full, y_ts, lag)+1
        x_ts = x_ts_full[:,0:p_dim]

        # Least-squares fit on rows 0..sp-1, prediction for every row.
        xx_inv = torch.linalg.inv(torch.mm(x_ts[0:sp,].T,x_ts[0:sp,]))
        xy = torch.mm(x_ts[0:sp,].T,y_ts[0:sp,])
        y_pred = torch.mm(x_ts,torch.mm(xx_inv,xy))

        # The held-out rows start at index sp (the fit used 0..sp-1); the
        # previous slice started at sp+1 and silently dropped one test row,
        # unlike the neural-network section which tests on sp2..time_len-1.
        R_lm[i,n_ind] = criteon(y_pred[sp:time_len,], y_no_epi[sp:time_len,]).item()
        
        
#--------------------------------------------------------------------------------------
# boxplot for R(hat(f),f_0)
# sample size : 100, 400, 1600, 6400
# Columns of loss_df alternate between the two methods for each sample size:
# even columns hold the neural-network losses, odd columns the linear-model ones.
loss_df = np.zeros((200,8))
loss_df[:, 0::2] = R_ts
loss_df[:, 1::2] = R_lm

fig = plt.figure(figsize =(7, 5))
ax = fig.add_subplot(111)
bp = ax.boxplot(np.log(loss_df), patch_artist = True, vert = True)

# Yellow boxes: neural network, green boxes: linear model.
colors = ['#FFFF00', '#00FF00'] * 4

for patch, color in zip(bp['boxes'], colors):
    patch.set_facecolor(color)

for whisker in bp['whiskers']:
    whisker.set(color ='#8B008B',
                linewidth = 2,
                linestyle =":")

for cap in bp['caps']:
    cap.set(color ='#8B008B',
            linewidth = 2)

# Each sample size labels two adjacent boxes (one per method).
ax.set_xticklabels([size for size in ('100', '400', '1600', '6400') for _ in range(2)])

ax.get_xaxis().tick_bottom()
ax.get_yaxis().tick_left()
ax.legend([bp["boxes"][0],bp["boxes"][1]], ['neural network', 'linear model'], loc='upper right',
          prop={'size':12})
plt.xlabel('sample size')
# Raw string: "\w" is not a valid escape sequence in a normal string literal.
plt.ylabel(r'logarithm of $\widehat{R}(\widehat{f},f_0)$')
plt.show()
#------------------------------------------------------------------------------------------
