import torch
import matplotlib.pyplot as plt
import sys
import os
import copy
import numpy as np
import argparse
sys.path.append("../src")
from networks import AutoGenNet
from functions_to_generate import simple_cosine, rectified_cosine
# Apply the figure style from ../plot_params.dms and define factors that
# convert centimetres / millimetres to inches (the unit of matplotlib's figsize).
plt.style.use('../plot_params.dms')
_inches_per_cm = 1 / 2.54
units_convert = {'cm': _inches_per_cm, 'mm': _inches_per_cm / 10}

"""
Description:
-----------
"""
# Arguments parsing
parser = argparse.ArgumentParser()
# The seed is an integer (it is fed to torch.manual_seed and used in folder names).
parser.add_argument("--seed", type=int, help="seed", default=1)
parser.add_argument("--noise-ic", type=float, help="noise in initial conditions", default=1)
# CPU is the default. A `store_true` flag with `default=True` can never become
# False, so `--use-gpu` is provided to clear it and allow CUDA/MPS to be used.
parser.add_argument('--use-cpu', action='store_true', default=True,
                    help='run on the CPU even if CUDA/MPS is available (default)')
parser.add_argument('--use-gpu', action='store_false', dest='use_cpu',
                    help='use CUDA or MPS when available instead of the CPU')
args = parser.parse_args()


# ======================  MAIN CODE  ====================== #
# Select the compute device (accelerators only when --use-cpu is off) and seed
# the random number generators.
seed = int(args.seed)
if torch.cuda.is_available() and not args.use_cpu:
    device = torch.device('cuda')
    torch.cuda.manual_seed(seed)
elif torch.backends.mps.is_available() and not args.use_cpu:
    device = torch.device('mps')
    torch.mps.manual_seed(seed)
else:
    device = torch.device('cpu')
# Always seed the default (CPU) generator: the network is constructed on the
# CPU and only afterwards moved to `device`, so its initialisation draws from
# the CPU RNG even when training on a GPU.
torch.manual_seed(seed)
print(f"Device used: {device}")

# Parameters
bias_learning = True  # if True, only the biases are trained (weights stay fixed)
bias_init = 'uniform'
batch_size = 64
f = 1./25       # frequency in (0, 1) -- discrete-time periodic signal (must be rational to be exactly periodic)
n_periods = 5   # number of signal periods in one episode
# Length of one episode in time steps. round() rather than int(): a float
# quotient can land just below the integer (e.g. 0.3 / 0.1 == 2.9999999999999996),
# and int() would then silently drop one time step.
L = round(n_periods / f)
network_size = (1, 200, 1)  # (input, hidden, output) sizes
lr = 1.e-1 if bias_learning else 1.e-3  # larger step size when only biases are learned

# Define network (sizes taken from network_size, ReLU units).
net = AutoGenNet(network_size=network_size,
                 nonlinearity='relu',
                 bias_learning=bias_learning,
                 bias_init=bias_init)
net.to(device)

# Dataset: an all-zero input of shape (L, batch_size, n_in) and the cosine
# target, replicated identically for every element of the batch.
x = torch.zeros((L, batch_size, network_size[0]), device=device)
y = simple_cosine(torch.arange(L), f)
y = (y.reshape((len(y), 1, 1))   # L, 1, 1
      .repeat(1, batch_size, 1)  # L, batch_size, 1
      .to(device))

# Snapshot of the untrained network, used later to measure parameter changes.
net_before_training = copy.deepcopy(net)

# Train
loss_fn = torch.nn.MSELoss()
# Only parameters with requires_grad are optimised; presumably just the biases
# when bias_learning is True (TODO confirm in AutoGenNet).
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=lr)
epochs = 500
# Loss values reach the host through .item(), so keep the history on the CPU
# instead of copying each float back to `device`.
losses = torch.empty(epochs, requires_grad=False)
for t in range(epochs):
    # Fresh random initial hidden state every epoch, scaled by --noise-ic.
    h0 = args.noise_ic * torch.rand((1, batch_size, network_size[1]), device=device)
    pred, _ = net(x, h0)
    loss = loss_fn(pred, y)
    # One device->host synchronisation per epoch, reused for logging below.
    loss_value = loss.item()
    losses[t] = loss_value

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if t % 100 == 0:
        print(f"Epoch {t}: loss = {loss_value:>10e}")

# TEST
# Create simulation-specific results and data folders. Both use the same run
# name so the figures and the raw data of one run can be matched up.
suffix = f"before-bifurcation-epoch{epochs}"
run_name = (f"sine-generation-hiddensize{network_size[1]}-lr{lr}"
            f"-biaslearning{bias_learning}-biasinit{bias_init}-seed{args.seed}-{suffix}")
result_folder = f"../results/{run_name}"
data_folder = f"../data/{run_name}"
# exist_ok=True replaces the racy "check if it exists, then create" pattern.
os.makedirs(result_folder, exist_ok=True)
os.makedirs(data_folder, exist_ok=True)

net.eval()
with torch.no_grad():
    # One test trajectory (batch of 1) from a fresh random initial hidden
    # state, again driven by an all-zero input.
    h0 = args.noise_ic * torch.rand((1, 1, network_size[1]), device=device)
    x = torch.zeros((L, 1, network_size[0]), device=device)
    pred, h = net(x, h0)

    # Plot the target against the network output for that trajectory.
    plt.figure(figsize=(45*units_convert['mm'], 45/1.25*units_convert['mm']))
    plt.plot(y.to('cpu')[:, 0, 0], color='k', label='target')
    plt.plot(pred.to('cpu')[:, 0, 0], '--', color='orange', label='network')
    plt.xlabel('Time')
    plt.ylabel('Output')
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(result_folder, 'Output.png'))
    plt.close()

    # Save a few parameters and network observables while we're at it.
    # Tensors are detached and converted to NumPy explicitly, so the saves
    # never depend on the parameters' requires_grad flag or the grad mode.
    # NOTE(review): `h` is the second output of AutoGenNet's forward pass,
    # presumably the hidden-state trajectory -- confirm its shape in networks.py.
    np.save(os.path.join(data_folder, 'activity.npy'), h.squeeze().detach().cpu().numpy())
    param_pairs = zip(net.named_parameters(), net_before_training.named_parameters())
    for (name, param), (_, param_old) in param_pairs:
        delta = param - param_old
        if name == "recurrent_layer.bias_ih_l0" and bias_learning:
            np.save(os.path.join(data_folder, 'delta_bias.npy'), delta.detach().cpu().numpy())
            np.save(os.path.join(data_folder, 'bias_post_training.npy'), param.detach().cpu().numpy())
        if name == "recurrent_layer.weight_hh_l0" and bias_learning:
            np.save(os.path.join(data_folder, 'recurrent_weights.npy'), param.detach().cpu().numpy())
        # Sanity check: with bias_learning, only the biases should have changed.
        print(f"{name}: {torch.linalg.norm(delta)}")
# Save the per-epoch training losses and the trained model's state dict.
loss_path = os.path.join(data_folder, 'loss.npy')
weights_path = os.path.join(data_folder, 'model_weights.pth')
np.save(loss_path, losses.cpu())
torch.save(net.state_dict(), weights_path)
