import os
os.environ["CUDA_VISIBLE_DEVICES"] = '3'
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import math
from fftconv_no_skip import FFTConv
import copy
from set_opt import setup_optimizer
import matplotlib.font_manager as fm



# Horizon T sampled with step delta_T gives L time points per sequence.
T, delta_T = 1000, 1
# round() rather than int(): int() truncates, so e.g. T=0.3, delta_T=0.1
# (0.3 / 0.1 == 2.9999999999999996) would silently yield 2 instead of 3.
L = round(T / delta_T)
B = 100            # number of training sequences
n_test = 1000      # number of test sequences
global_lr = 1e-2   # learning rate for 'C' and the one passed to setup_optimizer
local_lr = 1e-3    # learning rate for 'dt', 'A', 'B'
weight_decay = 0.01


# Covariance kernel used to sample the Gaussian-process input sequences
def delta(t, b):
    """Return the normalised Gaussian bump exp(-(t/b)^2) / (sqrt(pi) * b).

    Evaluated elementwise on the tensor ``t``; ``b`` controls the width.
    """
    scaled = t / b
    norm_const = math.sqrt(math.pi)
    return torch.exp(-scaled ** 2) / norm_const / b

# Time grid t = (1, 2, ..., L) * T / L and two broadcast (L, L) views of it:
# expanded_tensor[i, j] == t[j] and transposed_tensor[i, j] == t[i], so
# expanded_tensor - transposed_tensor is the matrix of pairwise time lags.
t = torch.arange(1, L + 1) / L * T
expanded_tensor = t.expand(L, L)
transposed_tensor = t[:, None].expand(L, L)



def get_data(b):
    """Sample train/test loaders of Gaussian-process sequences with kernel width ``b``.

    Each sequence is drawn from a zero-mean multivariate normal over the ``L``
    grid points with covariance ``delta(t_j - t_i, b)``. Inputs have shape
    ``(N, 1, L)``. The target (delay task) is ``sin`` of the input at index
    ``-int(L / 2)``, giving shape ``(N, 1)``. The train set has ``B`` samples
    and the test set ``n_test``. Both loaders use ``batch_size=100`` and the
    DataLoader default (sequential) ordering.

    Returns:
        ``(train_loader, test_loader)``.
    """
    covariance = delta(expanded_tensor - transposed_tensor, b)
    gaussian = torch.distributions.MultivariateNormal(torch.zeros(L), covariance_matrix=covariance)
    delay_index = -int(L / 2)

    def _split(n_samples):
        # One dataset of n_samples sequences and their delayed sin targets.
        xs = gaussian.sample((n_samples,)).unsqueeze(1)
        ys = xs[..., delay_index].sin()
        return torch.utils.data.TensorDataset(xs, ys)

    # The train split is sampled before the test split.
    train_dataset = _split(B)
    test_dataset = _split(n_test)
    return (
        torch.utils.data.DataLoader(train_dataset, batch_size=100),
        torch.utils.data.DataLoader(test_dataset, batch_size=100),
    )



# Measure output and gradient norms at initialization (no training steps are taken)
def get_init_result(model, train_loader, length, normalization=False, criterion=None):
    """Measure output and gradient norms of ``model`` at initialization.

    The model is deep-copied, so neither the weights nor the ``.grad`` fields
    of ``model`` are modified. Each input is truncated to its first ``length``
    time steps before the forward pass.

    Args:
        model: module whose forward returns a 4-tuple
            ``(output, _, measure_std, measure_mean)`` and which exposes
            ``model.kernel.C``.
        train_loader: iterable of ``(input, target)`` batches.
        length: number of leading time steps of each input to keep.
        normalization: if True, first rescale the copy's ``kernel.C`` by
            ``1 / sqrt(reg)``. Here ``reg = mean(std_L + |mean_L|) ** 2`` is
            measured at the last time step of the first batch.
        criterion: loss used for the gradient. Defaults to ``nn.MSELoss()``,
            which is the loss the script uses.

    Returns:
        ``(output_norm, grad_norm)``. ``output_norm`` is the per-batch average
        of ``||y[..., -1]||_1 / batch_size``. ``grad_norm`` is the square root
        of the per-batch average squared norm of the parameter gradient.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    if criterion is None:
        criterion = nn.MSELoss()
    c_model = copy.deepcopy(model)
    # Send batches to wherever the model lives instead of hard-coding .cuda().
    first_param = next(c_model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    output_norm = 0
    grad_norm = 0
    n_batches = 0
    for i, (inputs, target) in enumerate(train_loader):
        inputs = inputs[..., :length].to(device)
        target = target.to(device)
        if i == 0 and normalization:
            with torch.no_grad():
                _, _, measure_std, measure_mean = c_model(inputs)
                m_std = measure_std[:, :, -1]
                m_mean = measure_mean[:, :, -1].abs()
                reg = (m_std + m_mean).mean() ** 2
                c_model.kernel.C.data = c_model.kernel.C.data / reg.sqrt()
        # Clear gradients left by the previous batch. Without this, backward()
        # accumulates and batch i reports the summed gradient of batches 0..i.
        c_model.zero_grad()
        output, *_ = c_model(inputs)
        output = output[..., -1]
        output_norm += (output.norm(p=1) / output.shape[0]).detach().cpu()
        loss = criterion(output, target)
        loss.backward()
        grad = [p.grad.reshape(-1) for p in c_model.parameters() if p.grad is not None]
        grad = torch.cat(grad).detach().cpu()
        grad_norm += grad.norm() ** 2
        n_batches += 1

    if n_batches == 0:
        raise ValueError('train_loader yielded no batches')
    output_norm = output_norm / n_batches
    grad_norm = grad_norm / n_batches
    return output_norm, grad_norm.sqrt()





def train(model, train_loader, criterion, optimizer, lam=0.0):
    """Run one training epoch and return the mean per-batch loss.

    The loss for each batch is ``criterion(y[..., -1], target) + lam * reg``.
    Here ``reg = mean(std_L + |mean_L|) ** 2`` is built from the
    ``measure_std`` / ``measure_mean`` outputs at the last time step. The
    returned value averages this full loss, including the penalty term.

    Args:
        model: module whose forward returns
            ``(output, _, measure_std, measure_mean)``.
        train_loader: iterable of ``(input, target)`` batches.
        criterion: loss comparing the last-step output with the target.
        optimizer: optimizer stepped once per batch.
        lam: weight of the ``reg`` penalty (0 disables it).

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()
    # Use the model's own device instead of hard-coding .cuda().
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    total_loss = 0.0
    n_batches = 0
    for inputs, target in train_loader:
        inputs, target = inputs.to(device), target.to(device)
        optimizer.zero_grad()
        output, _, measure_std, measure_mean = model(inputs)
        output = output[..., -1]
        m_std = measure_std[:, :, -1]
        m_mean = measure_mean[:, :, -1].abs()
        reg = (m_std + m_mean).mean() ** 2
        loss = criterion(output, target) + lam * reg
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_batches += 1

    if n_batches == 0:
        raise ValueError('train_loader yielded no batches')
    return total_loss / n_batches



def test(model, test_loader, criterion):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for _, (input, target) in enumerate(test_loader):
            input, target = input.cuda(), target.cuda()
            output, *_ = model(input)
            output = output[...,-1]
            loss = criterion(output, target)
            test_loss += loss.item()
    
    test_loss = test_loss / len(test_loader)
    
    return test_loss



# ---- experiment configuration ------------------------------------------------
epochs = 100
# Single-channel FFTConv in 's4' mode with per-parameter lr / wd dicts:
# 'dt', 'A', 'B' get local_lr and no weight decay; 'C' gets global_lr and
# weight_decay. (Presumably these are the S4 state-space parameters; see FFTConv.)
_ssm_keys = ('dt', 'A', 'B')
model = FFTConv(
    d_model=1,
    activation=None,
    mode='s4',
    lr={**{key: local_lr for key in _ssm_keys}, 'C': global_lr},
    wd={**{key: 0.0 for key in _ssm_keys}, 'C': weight_decay},
).cuda()
criterion = nn.MSELoss()
# Prefix lengths for the initialisation statistics: 10, 20, ..., excluding L.
ls = torch.arange(10, L, 10)
# Final-epoch losses, one tensor per repetition, for each (split, variant).
(final_train_loss_wo_normal, final_test_loss_wo_normal,
 final_train_loss_normal, final_test_loss_normal) = ([] for _ in range(4))
(final_train_loss_reg, final_test_loss_reg,
 final_train_loss_reg_normal, final_test_loss_reg_normal) = ([] for _ in range(4))
# Repeat the whole experiment 3 times. No explicit seed is set in this script,
# so repetitions differ only through the global RNG state (e.g. the data sampled
# in get_data). Every variant starts from a deepcopy of the same `model`.
for _ in range(3):
    # Per-repetition accumulators. Init statistics get one entry per (b, length)
    # pair; loss lists get one entry per (b, epoch) pair. Both are reshaped to
    # (len(bs), -1) after the b-loop.
    output_norm_wo_normal, grad_norm_wo_normal, output_norm_normal, grad_norm_normal = [], [], [], []
    train_loss_wo_normal, test_loss_wo_normal, train_loss_normal, test_loss_normal = [], [], [], []
    train_loss_reg, test_loss_reg, train_loss_reg_normal, test_loss_reg_normal = [], [], [], []
    # Kernel widths passed to get_data. Also read after this loop (reshape, plotting).
    bs = [0.01, 0.1, 1]
    for b in bs:
        train_loader, test_loader = get_data(b)
        # Statistics at initialization (no training) on each input prefix,
        # without and with the C rescaling.
        for length in ls:
            output_norm_wo_normal_, grad_norm_wo_normal_ = get_init_result(model, train_loader, length, normalization=False)
            output_norm_normal_, grad_norm_normal_ = get_init_result(model, train_loader, length, normalization=True)
            output_norm_normal.append(output_norm_normal_)
            grad_norm_normal.append(grad_norm_normal_)
            output_norm_wo_normal.append(output_norm_wo_normal_)
            grad_norm_wo_normal.append(grad_norm_wo_normal_)
        # Four training variants built from the same initial model:
        #   wo_normal  - plain training
        #   normal     - C rescaled by 1/sqrt(reg) (reg measured below)
        #   reg        - plain init, loss + 1e-2 * reg penalty (lam in train)
        #   reg_normal - rescaled C and the penalty
        wo_normal_model = copy.deepcopy(model)
        normal_model = copy.deepcopy(model)
        reg_model = copy.deepcopy(model)
        reg_normal_model = copy.deepcopy(model)
        # setup_optimizer returns (optimizer, scheduler), judging by the unpacking.
        wo_normal_opt, wo_normal_scheduler = setup_optimizer(wo_normal_model, global_lr, weight_decay, epochs)
        reg_opt, reg_scheduler = setup_optimizer(reg_model, global_lr, weight_decay, epochs)
        # Measure reg = mean(std_L + |mean_L|)**2 on the first full-length
        # training batch only (the loop breaks at i == 1).
        # NOTE(review): if train_loader is empty, `reg` is not assigned here and
        # keeps its previous value (or is undefined on the first pass).
        for i, (input, _) in enumerate(train_loader):
            if i == 0:
                input = input.cuda()
                with torch.no_grad():
                    _, _, measure_std, measure_mean = normal_model(input)
                    m_std = measure_std[:,:,-1]
                    m_mean = measure_mean[:,:,-1].abs()
                    reg = (m_std + m_mean).mean()**2
            else:
                break
        # Rescale C of the two "normal" variants before building their optimizers.
        normal_model.kernel.C.data = normal_model.kernel.C.data / reg.sqrt()
        reg_normal_model.kernel.C.data = reg_normal_model.kernel.C.data / reg.sqrt()
        normal_opt, normal_scheduler = setup_optimizer(normal_model, global_lr, weight_decay, epochs)
        reg_normal_opt, reg_normal_scheduler = setup_optimizer(reg_normal_model, global_lr, weight_decay, epochs)

        # Each epoch: train + evaluate every variant, step every scheduler once,
        # and log the losses. (This `_` shadows the repetition counter; neither is read.)
        for _ in range(epochs):
            train_loss_wo_normal_ = train(wo_normal_model, train_loader, criterion, wo_normal_opt)
            test_loss_wo_normal_ = test(wo_normal_model, test_loader, criterion)
            train_loss_normal_ = train(normal_model, train_loader, criterion, normal_opt)
            test_loss_normal_ = test(normal_model, test_loader, criterion)
            train_loss_reg_ = train(reg_model, train_loader, criterion, reg_opt, lam=1e-2)
            test_loss_reg_ = test(reg_model, test_loader, criterion)
            train_loss_reg_normal_ = train(reg_normal_model, train_loader, criterion, reg_normal_opt, lam=1e-2)
            test_loss_reg_normal_ = test(reg_normal_model, test_loader, criterion)
            wo_normal_scheduler.step()
            normal_scheduler.step()
            reg_scheduler.step()
            reg_normal_scheduler.step()
            train_loss_wo_normal.append(train_loss_wo_normal_)
            test_loss_wo_normal.append(test_loss_wo_normal_)
            train_loss_normal.append(train_loss_normal_)
            test_loss_normal.append(test_loss_normal_)
            train_loss_reg.append(train_loss_reg_)
            test_loss_reg.append(test_loss_reg_)
            train_loss_reg_normal.append(train_loss_reg_normal_)
            test_loss_reg_normal.append(test_loss_reg_normal_)
            print(f'no normal: {train_loss_wo_normal_}, {test_loss_wo_normal_}')
            print(f'normal: {train_loss_normal_}, {test_loss_normal_}')
            print(f'reg: {train_loss_reg_}, {test_loss_reg_}')
            print(f'reg + normal: {train_loss_reg_normal_}, {test_loss_reg_normal_}')

    # Rows index b. Columns index prefix length (init stats) or epoch (losses).
    output_norm_wo_normal = torch.tensor(output_norm_wo_normal).reshape((len(bs), -1))
    grad_norm_wo_normal = torch.tensor(grad_norm_wo_normal).reshape((len(bs), -1))
    output_norm_normal = torch.tensor(output_norm_normal).reshape((len(bs), -1))
    grad_norm_normal = torch.tensor(grad_norm_normal).reshape((len(bs), -1))
    train_loss_wo_normal = torch.tensor(train_loss_wo_normal).reshape((len(bs), -1))
    test_loss_wo_normal = torch.tensor(test_loss_wo_normal).reshape((len(bs), -1))
    train_loss_normal = torch.tensor(train_loss_normal).reshape((len(bs), -1))
    test_loss_normal = torch.tensor(test_loss_normal).reshape((len(bs), -1))
    train_loss_reg = torch.tensor(train_loss_reg).reshape((len(bs), -1))
    test_loss_reg = torch.tensor(test_loss_reg).reshape((len(bs), -1))
    train_loss_reg_normal = torch.tensor(train_loss_reg_normal).reshape((len(bs), -1))
    test_loss_reg_normal = torch.tensor(test_loss_reg_normal).reshape((len(bs), -1))

    # Keep this repetition's last-epoch losses: one value per b.
    final_train_loss_wo_normal.append(train_loss_wo_normal[:, -1])
    final_test_loss_wo_normal.append(test_loss_wo_normal[:, -1])
    final_train_loss_normal.append(train_loss_normal[:, -1])
    final_test_loss_normal.append(test_loss_normal[:, -1])
    final_train_loss_reg.append(train_loss_reg[:, -1])
    final_test_loss_reg.append(test_loss_reg[:, -1])
    final_train_loss_reg_normal.append(train_loss_reg_normal[:, -1])
    final_test_loss_reg_normal.append(test_loss_reg_normal[:, -1])


# Stack the per-repetition final-epoch losses into (n_repetitions, len(bs))
# tensors and report mean / std across repetitions for each kernel width b.
# NOTE(review): repetitions are not explicitly seeded in the visible script.
final_train_loss_wo_normal = torch.stack(final_train_loss_wo_normal)
final_test_loss_wo_normal = torch.stack(final_test_loss_wo_normal)
final_train_loss_normal = torch.stack(final_train_loss_normal)
final_test_loss_normal = torch.stack(final_test_loss_normal)
final_train_loss_reg = torch.stack(final_train_loss_reg)
final_test_loss_reg = torch.stack(final_test_loss_reg)
final_train_loss_reg_normal = torch.stack(final_train_loss_reg_normal)
final_test_loss_reg_normal = torch.stack(final_test_loss_reg_normal)

for tag, losses in (
    ('train wo', final_train_loss_wo_normal),
    ('train w', final_train_loss_normal),
    ('train reg', final_train_loss_reg),
    ('train reg + normal', final_train_loss_reg_normal),
    ('test wo', final_test_loss_wo_normal),
    ('test w', final_test_loss_normal),
    ('test reg', final_test_loss_reg),
    ('test reg + normal', final_test_loss_reg_normal),
):
    print(tag, losses.mean(0), losses.std(0))


# ---- figure ------------------------------------------------------------------
# Three panels, all drawn from the LAST repetition of the experiment loop:
#   0: output norm at initialization vs. prefix length,
#   1: gradient norm at initialization vs. prefix length,
#   2: per-epoch training loss at full length.
# Dashed curves: C not rescaled. Solid curves: C rescaled at initialization.
title = [
    'Output norm at initialization',
    'Gradient norm at initialization',
    rf'Training loss for $L = {L}$',
]
f, axes = plt.subplots(1, 3, figsize=(18, 6))
linewidth = 2.0
# (curves without rescale, curves with rescale, x values or None, x label, y label or None)
panels = [
    (output_norm_wo_normal, output_norm_normal, ls, 'length', r'$\mathbb{E}_x[|y_L|]$'),
    (grad_norm_wo_normal, grad_norm_normal, ls, 'length', r'$\|\nabla R_n(\theta)\|$'),
    (train_loss_wo_normal, train_loss_normal, None, 'epoch', None),
]
for panel_idx, ax in enumerate(axes):
    no_rescale, rescale, x_vals, x_label, y_label = panels[panel_idx]
    # Same plotting order on every axes keeps colours consistent across panels.
    for curves, prefix, extra in (
        (no_rescale, 'No rescale', {'linestyle': '--'}),
        (rescale, '+ rescale', {}),
    ):
        for b_idx, b_val in enumerate(bs):
            style = dict(extra, linewidth=linewidth)
            if panel_idx == 0:
                # Only panel 0 is labelled; its labels feed the shared figure legend.
                style['label'] = fr'{prefix}, $b = {b_val}$'
            args = (curves[b_idx],) if x_vals is None else (x_vals, curves[b_idx])
            ax.plot(*args, **style)
    ax.grid()
    ax.set_title(title[panel_idx], fontsize=20)
    ax.set_xlabel(x_label, fontsize=20)
    if y_label is not None:
        ax.set_ylabel(y_label, fontsize=20)
    ax.set_yscale('log', base=2)
    plt.setp(ax.get_xticklabels(), fontproperties=fm.FontProperties(size=20))
    plt.setp(ax.get_yticklabels(), fontproperties=fm.FontProperties(size=20))

# One shared legend centred below all panels.
handles, labels = axes[0].get_legend_handles_labels()
f.legend(handles, labels, loc='lower center', ncol=2, prop={'size': 20})
# rect = (left, bottom, right, top) reserves the bottom quarter for the legend.
plt.tight_layout(rect=[0, 0.25, 1, 1])
# No extension given: matplotlib appends the default savefig format (png by default).
plt.savefig('./result', dpi=300)
