import torch
from torch import nn
from torch import optim
from collections import OrderedDict
import numpy as np
from matplotlib import pyplot as plt
from math import floor

from ts_to_iid import ts_iid
from model import Net

# functions for generating time series, independent observations and the response
def gen_y(x, epi):
    """Generate the response y = epi + sum_j 2*cos(x[:, j]).

    Args:
        x: (n, p) tensor (or array-like) of covariates.
        epi: (n, 1) tensor (or array-like) of additive noise; it is not modified.

    Returns:
        A new (n, 1) tensor with the same dtype as ``epi``.
    """
    x = torch.as_tensor(x)
    epi = torch.as_tensor(epi)
    # torch.cos (rather than np.cos) keeps the computation inside torch, so it
    # also works for CUDA tensors and for tensors that require grad.
    signal = 2 * torch.cos(x).sum(dim=1, keepdim=True)
    # Cast to epi's dtype so the result dtype matches the noise, as the
    # original in-place accumulation into a clone of epi did.
    return epi + signal.to(epi.dtype)
    
def gen_x_ts(A_mat, p_dim, d_dim, time_len, Sigma):
    """Simulate a p_dim-dimensional VAR(d_dim) time series.

    The recursion is x_t = A_1 x_{t-1} + ... + A_d x_{t-d} + e_t, with
    e_t ~ N(0, Sigma) drawn by NumPy. ``A_mat`` is the (p_dim, p_dim*d_dim)
    block matrix [A_1, A_2, ..., A_d]. The first d_dim states are set to pure
    noise and then dropped from the output.

    Args:
        A_mat: coefficient block matrix (tensor or array-like, any float dtype).
        p_dim: dimension of each observation.
        d_dim: lag order.
        time_len: number of observations to return.
        Sigma: (p_dim, p_dim) innovation covariance passed to NumPy.

    Returns:
        (x, epi): the simulated observations (float32 tensor of shape
        (time_len, p_dim)) and the innovations that produced them (float64
        tensor of the same shape).
    """
    total_len = time_len + d_dim
    epi = torch.tensor(np.random.multivariate_normal(np.zeros(p_dim), Sigma, size=total_len))
    x_seq = torch.zeros(total_len, p_dim)

    # torch.mm requires both operands to share a dtype; match x_seq so float64
    # (or NumPy) coefficient matrices work too.
    A_src = torch.as_tensor(A_mat, dtype=x_seq.dtype)
    # The lag window below is stacked oldest-first (x_{t-d}, ..., x_{t-1}), so
    # reverse the coefficient blocks: block i multiplies x_{t-d+i}, i.e. A_{d-i}.
    A_mat_new = A_src.clone()
    for i in range(d_dim):
        A_mat_new[:, i * p_dim:(i + 1) * p_dim] = A_src[:, (d_dim - i - 1) * p_dim:(d_dim - i) * p_dim]

    x_seq[:d_dim, :] = epi[:d_dim, :]
    for t in range(d_dim, total_len):
        window = x_seq[t - d_dim:t, :].reshape(p_dim * d_dim, 1)
        x_seq[t, :] = torch.mm(A_mat_new, window).reshape(-1) + epi[t, :]
    return x_seq[d_dim:, :], epi[d_dim:, :]


def gen_x_iid(A_mat, p_dim, d_dim, time_len, Sigma):
    """Return the covariate sample produced by ``ts_iid`` as a float32 tensor.

    Thin wrapper around the project helper ``ts_iid`` (from ``ts_to_iid``),
    called with ``Sigma`` as the innovation covariance. Presumably it draws
    ``time_len`` independent observations matched to the VAR model described
    by ``A_mat``/``Sigma`` -- confirm in ts_to_iid.py.
    """
    draws = ts_iid(A_mat, p_dim, d_dim, Sigma, time_len)
    return torch.tensor(draws, dtype=torch.float32)


# Shared simulation settings.
lr = 1e-3            # Adam learning rate
rho = 0.2            # non-zero entries of the VAR(1) coefficient matrix
p_dim, d_dim = 4, 1  # series dimension and lag order
# rho on the first super-diagonal: coordinate j of x_t is driven by
# coordinate j+1 of x_{t-1}.
A_mat = torch.diag(torch.full((p_dim - 1,), rho), diagonal=1)
Sigma = np.eye(p_dim)  # innovation covariance

# simulation 1 --- time series
n_grid = [100, 200, 400, 800, 1600, 3200, 6400, 12800]
n_rep = 200  # Monte Carlo replications per sample size
ls_loss_ts_train = []  # mean square error for training set
ls_loss_ts_valid = []  # mean square error for validation set
ls_loss_ts_test = []  # mean square error for testing set
ls_loss_ts_test_ne = []  # empirical mean of R(hat(f),f_0)
l1_lambda = 0.1
# One row per replication, one column per sample size in n_grid
# (previously hard-coded to 4 columns, which raised IndexError at n_ind == 4).
R_ts = np.zeros((n_rep, len(n_grid)))
criteon = nn.MSELoss()

for n_ind, time_len in enumerate(n_grid):
    # Chronological split: training 0~sp1, validation sp1~sp2, testing sp2~time_len.
    sp1 = floor(time_len * 1 / 2)
    sp2 = floor(time_len * 3 / 4)
    loss1, loss2, loss3, loss4 = 0.0, 0.0, 0.0, 0.0
    print(time_len)
    for i in range(n_rep):
        # generate VAR covariates and a response with standard normal noise
        x_ts, epi_x_ts = gen_x_ts(A_mat, p_dim, d_dim, time_len, Sigma)
        epi_y_ts = torch.randn(time_len, 1)
        y_ts = gen_y(x_ts, epi_y_ts)

        model_ts = Net(p_dim)
        optimizer = optim.Adam(model_ts.parameters(), lr)

        # Train with early stopping: every 100 epochs, stop as soon as the
        # validation MSE increases. NOTE(review): the model from the stopping
        # epoch (not the best one seen) is what gets evaluated below.
        loss_valid = float("inf")
        for epoch in range(20000):
            output_ts = model_ts(x_ts[0:sp1, :])
            # Penalise every other parameter tensor -- presumably the weight
            # matrices of Net's layers, skipping biases; confirm in model.py.
            # NOTE(review): for a 2-D tensor, torch.linalg.norm(param, 1) is the
            # matrix 1-norm (max absolute column sum), not the entrywise L1
            # norm param.abs().sum(); confirm which penalty is intended.
            l1_reg = torch.tensor(0.)
            for it, param in enumerate(model_ts.parameters()):
                if it % 2 == 0:
                    l1_reg += torch.linalg.norm(param, 1)
            loss_ts_train = criteon(output_ts, y_ts[0:sp1, ]) + l1_reg * l1_lambda
            # backprop
            optimizer.zero_grad()
            loss_ts_train.backward()
            optimizer.step()

            if epoch % 100 == 0:
                # The validation check needs no autograd graph.
                with torch.no_grad():
                    valid_ts = model_ts(x_ts[sp1:sp2, :])
                    loss_valid_temp = criteon(valid_ts, y_ts[sp1:sp2, ]).item()
                if loss_valid_temp > loss_valid:
                    break
                loss_valid = loss_valid_temp

        with torch.no_grad():
            # Plain MSE of the final model on each split (the training
            # objective above additionally contains the L1 penalty).
            train_ts = model_ts(x_ts[0:sp1, :])
            valid_ts = model_ts(x_ts[sp1:sp2, :])
            output_ts = model_ts(x_ts[sp2:time_len, :])
            loss_ts_train_mse = criteon(train_ts, y_ts[0:sp1, ]).item()
            loss_ts_valid = criteon(valid_ts, y_ts[sp1:sp2, ]).item()
            loss_ts = criteon(output_ts, y_ts[sp2:time_len, ]).item()
            # Test error against the noiseless signal y - noise = f_0(x).
            loss_ts_ne = criteon(output_ts, (y_ts - epi_y_ts)[sp2:time_len, ]).item()

        R_ts[i, n_ind] = loss_ts_ne
        loss1 += loss_ts_train_mse
        loss2 += loss_ts_valid
        loss3 += loss_ts
        loss4 += loss_ts_ne

    ls_loss_ts_train.append(loss1 / n_rep)
    ls_loss_ts_valid.append(loss2 / n_rep)
    ls_loss_ts_test.append(loss3 / n_rep)
    ls_loss_ts_test_ne.append(loss4 / n_rep)
    

# simulation 1 --- iid
n_grid = np.array([100, 200, 400, 800, 1600, 3200, 6400, 12800])
n_rep = 200  # Monte Carlo replications per sample size
ls_loss_iid_train = []  # mean square error for training set
ls_loss_iid_valid = []  # mean square error for validation set
ls_loss_iid_test = []  # mean square error for testing set
ls_loss_iid_test_ne = []  # empirical mean of R(hat(f),f_0)
l1_lambda = 0.1
# One row per replication, one column per sample size in n_grid.
R_iid = np.zeros((n_rep, len(n_grid)))
criteon = nn.MSELoss()

# .tolist() yields plain Python ints for slicing / tensor sizes.
for n_ind, time_len in enumerate(n_grid.tolist()):
    # Split: training 0~sp1, validation sp1~sp2, testing sp2~time_len.
    sp1 = floor(time_len * 1 / 2)
    sp2 = floor(time_len * 3 / 4)
    loss1, loss2, loss3, loss4 = 0.0, 0.0, 0.0, 0.0
    print(time_len)
    for i in range(n_rep):
        # generate covariates and a response with standard normal noise
        x_iid = gen_x_iid(A_mat, p_dim, d_dim, time_len, Sigma)
        epi_y_iid = torch.randn(time_len, 1)
        y_iid = gen_y(x_iid, epi_y_iid)

        model_iid = Net(p_dim)
        optimizer = optim.Adam(model_iid.parameters(), lr)

        # Train with early stopping: every 100 epochs, stop as soon as the
        # validation MSE increases. NOTE(review): the model from the stopping
        # epoch (not the best one seen) is what gets evaluated below.
        loss_valid = float("inf")
        for epoch in range(20000):
            output_iid = model_iid(x_iid[0:sp1, :])
            # Penalise every other parameter tensor -- presumably the weight
            # matrices of Net's layers, skipping biases; confirm in model.py.
            # NOTE(review): for a 2-D tensor, torch.linalg.norm(param, 1) is the
            # matrix 1-norm (max absolute column sum), not the entrywise L1
            # norm param.abs().sum(); confirm which penalty is intended.
            l1_reg = torch.tensor(0.)
            for it, param in enumerate(model_iid.parameters()):
                if it % 2 == 0:
                    l1_reg += torch.linalg.norm(param, 1)
            loss_iid_train = criteon(output_iid, y_iid[0:sp1, ]) + l1_reg * l1_lambda
            # backprop
            optimizer.zero_grad()
            loss_iid_train.backward()
            optimizer.step()

            if epoch % 100 == 0:
                # The validation check needs no autograd graph.
                with torch.no_grad():
                    valid_iid = model_iid(x_iid[sp1:sp2, :])
                    loss_valid_temp = criteon(valid_iid, y_iid[sp1:sp2, ]).item()
                if loss_valid_temp > loss_valid:
                    break
                loss_valid = loss_valid_temp

        with torch.no_grad():
            # Plain MSE of the final model on each split (the training
            # objective above additionally contains the L1 penalty).
            train_iid = model_iid(x_iid[0:sp1, :])
            valid_iid = model_iid(x_iid[sp1:sp2, :])
            output_iid = model_iid(x_iid[sp2:time_len, :])
            loss_iid_train_mse = criteon(train_iid, y_iid[0:sp1, ]).item()
            loss_iid_valid = criteon(valid_iid, y_iid[sp1:sp2, ]).item()
            loss_iid = criteon(output_iid, y_iid[sp2:time_len, ]).item()
            # Test error against the noiseless signal y - noise = f_0(x).
            loss_iid_ne = criteon(output_iid, (y_iid - epi_y_iid)[sp2:time_len, ]).item()

        R_iid[i, n_ind] = loss_iid_ne
        loss1 += loss_iid_train_mse
        loss2 += loss_iid_valid
        loss3 += loss_iid
        loss4 += loss_iid_ne

    ls_loss_iid_train.append(loss1 / n_rep)
    ls_loss_iid_valid.append(loss2 / n_rep)
    ls_loss_iid_test.append(loss3 / n_rep)
    ls_loss_iid_test_ne.append(loss4 / n_rep)


#--------------------------------------------------------------------------------------
# boxplot for R(hat(f),f_0)
# sample size : 100, 400, 1600, 6400
# Columns of R_ts / R_iid follow n_grid, so these indices select the sizes
# above. Boxes alternate: time-series fit, then i.i.d. fit, for each size.
box_cols = [0, 2, 4, 6]
loss_df = np.zeros((R_ts.shape[0], 2 * len(box_cols)))
loss_df[:, 0::2] = R_ts[:, box_cols]
loss_df[:, 1::2] = R_iid[:, box_cols]

fig = plt.figure(figsize=(7, 5))
ax = fig.add_subplot(111)
bp = ax.boxplot(np.log(loss_df), patch_artist=True)

# yellow = time series, green = i.i.d.
colors = ['#FFFF00', '#00FF00'] * len(box_cols)

for patch, color in zip(bp['boxes'], colors):
    patch.set_facecolor(color)

for whisker in bp['whiskers']:
    whisker.set(color='#8B008B',
                linewidth=2,
                linestyle=":")

for cap in bp['caps']:
    cap.set(color='#8B008B',
            linewidth=2)

ax.set_xticklabels([str(int(n_grid[c])) for c in box_cols for _ in range(2)])

ax.get_xaxis().tick_bottom()
ax.get_yaxis().tick_left()
# Both boxes come from the same network; they differ only in the data used
# (the original 'neural network' / 'linear model' labels were wrong).
ax.legend([bp["boxes"][0], bp["boxes"][1]], ['time series data', 'i.i.d. data'], loc='upper right',
          prop={'size': 12})
plt.xlabel('sample size')
# Raw string: '\w' is an invalid escape in a normal string literal.
plt.ylabel(r'logarithm of $\widehat{R}(\widehat{f},f_0)$')
plt.show()
#------------------------------------------------------------------------------------------

#------------------------------------------------------------------------------------------
# plot MSE of training, validation and testing set in one figure.
# time series case:
log_n = np.log(n_grid)
fig = plt.figure(figsize=(10, 7))
# Raw strings: '\w' is an invalid escape in a normal string literal.
ts_curves = [
    (ls_loss_ts_test_ne, r"$\widehat{R}(\widehat{f},f_0)$: estimator of $R(\widehat{f},f_0)$"),
    (ls_loss_ts_train, "MSE of training set"),
    (ls_loss_ts_valid, "MSE of validation set"),
    (ls_loss_ts_test, "MSE of testing set"),
]
for values, label in ts_curves:
    # Explicit float conversion also accepts lists of 0-d tensors.
    plt.plot(log_n, np.log(np.asarray(values, dtype=float)), label=label)
plt.legend(loc='best', prop={'size': 11})
plt.xlabel('logarithm of sample size')
plt.ylabel('logarithm of MSE')
plt.show()

#------------------------------------------------------------------------------------------
# plot MSE of training, validation and testing set in one figure.
# iid case:
log_n = np.log(n_grid)
fig = plt.figure(figsize=(10, 7))
# Raw strings: '\w' is an invalid escape in a normal string literal.
iid_curves = [
    (ls_loss_iid_test_ne, r"$\widehat{R}(\widehat{f},f_0)$: estimator of $R(\widehat{f},f_0)$"),
    (ls_loss_iid_train, "MSE of training set"),
    (ls_loss_iid_valid, "MSE of validation set"),
    (ls_loss_iid_test, "MSE of testing set"),
]
for values, label in iid_curves:
    # Explicit float conversion also accepts lists of 0-d tensors.
    plt.plot(log_n, np.log(np.asarray(values, dtype=float)), label=label)
plt.legend(loc='best', prop={'size': 11})
plt.xlabel('logarithm of sample size')
plt.ylabel('logarithm of MSE')
plt.show()

