import torch
import matplotlib.pyplot as plt
import sys
import os
import copy
import numpy as np
import argparse
sys.path.append("../src")
from networks import AutoGenNet
from functions_to_generate import vanderpol_system
from scipy.optimize import fsolve
# Shared matplotlib style sheet (path is resolved relative to the working directory).
plt.style.use('../plot_params.dms')
# Factors converting centimetres / millimetres to inches, the unit of matplotlib's figsize.
_CM_TO_INCH = 1 / 2.54
units_convert = {'cm': _CM_TO_INCH, 'mm': _CM_TO_INCH / 10}

"""
Description:
-----------
Generate the attractor in the van der Pol system.
"""
parser = argparse.ArgumentParser()
parser.add_argument("--gainrec", type=float, help="gain of recurrent weights", default=1.)
# Parsed as an int and converted with bool() below: 1 = bias learning, 0 = full learning.
parser.add_argument("--bias_learning", type=int, default=1,
                    help="whether to use bias learning (1) or full learning (0)")
parser.add_argument("--seed", type=int, help="seed", default=1)
parser.add_argument("--resultdir", type=str, help="directory prefix for results", default="")
parser.add_argument("--epochs", type=int, help="number of epochs", default=1000)
parser.add_argument("--datadir", type=str, help="directory prefix for data", default="")
parser.add_argument("--noise-ic", type=float, help="noise in initial conditions", default=0)
# --use-cpu defaults to True, so on its own it could never be switched off and the
# CUDA/MPS branches below were unreachable. --use-gpu writes the same destination
# (args.use_cpu) and re-enables accelerators while keeping CPU as the default.
parser.add_argument('--use-cpu', action='store_true', default=True,
                    help='disables CUDA/MPS training (default)')
parser.add_argument('--use-gpu', dest='use_cpu', action='store_false',
                    help='use CUDA or MPS when available instead of the CPU')
args = parser.parse_args()


# ======================  MAIN CODE  ====================== #
# torch.manual_seed seeds the CPU generator as well as the accelerator ones.
# It must run whatever device is picked: the network below is constructed
# before being moved to `device`, so its initial parameters presumably come
# from the CPU generator. The original script seeded only the accelerator on
# CUDA/MPS, which left those runs unreproducible.
torch.manual_seed(args.seed)
if torch.cuda.is_available() and not args.use_cpu:
    device = torch.device('cuda')
    torch.cuda.manual_seed(args.seed)
elif torch.backends.mps.is_available() and not args.use_cpu:
    device = torch.device('mps')
    torch.mps.manual_seed(args.seed)
else:
    device = torch.device('cpu')
print(f"Device used: {device}")

# ---------------------------------------------------------------- parameters
# With bias learning only the biases are trained; the weights stay fixed.
bias_learning = bool(args.bias_learning)
bias_init = 'uniform'
batch_size = 1
time_step = 0.1
total_time = 15    # length of one episode
network_size = (1, 25, 1)  # (input, hidden, output) sizes

if bias_learning:
    # Widen the hidden layer from n to n**2 + 2*n units.
    # NOTE(review): presumably so the number of trainable biases matches the
    # parameter count of an n-unit fully trained network (n**2 recurrent + n input
    # + n output weights) — confirm. An earlier variant went the other way and used
    # scipy.optimize.fsolve to pick hidden size ceil(r), with r solving r**2 + 2*r = n.
    network_size = (1, network_size[1] * (network_size[1] + 2), 1)
    lr = 1.e-1
else:
    lr = 1.e-4

# ------------------------------------------------------------------- network
gain_rec = args.gainrec  # gain of the recurrent weights (see --gainrec)
gain_readout = 1.
net_kwargs = dict(
    network_size=network_size,
    nonlinearity='relu',
    bias_learning=bias_learning,
    bias_init=bias_init,
    # NOTE(review): the meaning of the third gain entry is defined by AutoGenNet — confirm.
    gain_init=(gain_rec, gain_readout, 1),
)
net = AutoGenNet(**net_kwargs)
net.to(device)

# ------------------------------------------------------------------- dataset
# The network is driven by zero input and must reproduce the first component of
# the van der Pol solution, affinely rescaled so its minimum maps to -1 and its
# maximum to +1.
t, y = vanderpol_system(total_time, time_step)
L = len(t)
x = torch.zeros((L, batch_size, network_size[0]), device=device)
target = y[0]
target_range = target.max() - target.min()
y = 2 * target / target_range - (target.max() + target.min()) / target_range
# Shape (L, batch_size, 1): one copy of the target per batch element.
y = y.reshape((len(y), 1, 1)).repeat(1, batch_size, 1).to(device)

# Untrained snapshot, used after training to measure how much the biases moved.
net_before_training = copy.deepcopy(net)

# Train
print(f"Training {'bias' if bias_learning else 'all parameters'} with recurrent gain {gain_rec} for {args.epochs} epochs")
loss_fn = torch.nn.MSELoss()
# Optimise only the parameters AutoGenNet left trainable (with bias learning,
# presumably just the biases).
optimizer = torch.optim.Adam([p for p in net.parameters() if p.requires_grad], lr=lr)
epochs = args.epochs
log_every = 1000  # print the loss every `log_every` epochs, and on the last one
# Per-epoch loss, kept on `device` so recording it does not force a host sync.
losses = torch.empty(epochs, device=device)
# The loop index is `epoch` (not `t`) so the time vector `t` is not shadowed.
for epoch in range(epochs):
    # Initial hidden state: torch.rand noise scaled by --noise-ic (zeros by default).
    h0 = args.noise_ic * torch.rand((1, batch_size, network_size[1]), device=device)
    pred, _ = net(x, h0)
    loss = loss_fn(pred, y)
    losses[epoch] = loss.detach()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % log_every == 0 or epoch == epochs - 1:
        print(f"Epoch {epoch}: loss = {loss.item():>10e}")

# TEST
# Create simulation-specific results and data folders. The run name encodes
# every hyper-parameter that distinguishes one run from another.
suffix = f"epochs{epochs}-gain{gain_rec}-gainout{gain_readout}"
run_name = (f"hiddensize{network_size[1]}-lr{lr}"
            f"-biaslearning{bias_learning}-biasinit{bias_init}-{suffix}-seed{args.seed}")
if args.resultdir == "":
    # NOTE: "vandderpol" (sic) is kept so existing result/data paths stay valid.
    result_folder = f"../results/vandderpol-{run_name}"
    data_folder = f"../data/vandderpol-{run_name}"
else:
    # NOTE(review): --datadir is only honoured when --resultdir is also given;
    # with --resultdir alone, data goes to a bare run-name folder in the cwd.
    result_folder = f"{args.resultdir}{run_name}"
    data_folder = f"{args.datadir}{run_name}"

# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(result_folder, exist_ok=True)
os.makedirs(data_folder, exist_ok=True)

net.eval()
with torch.no_grad():
    # Free-running test: a single trajectory (batch of 1) driven by zero input,
    # from an initial hidden state scaled by --noise-ic (zeros by default).
    h0 = args.noise_ic * torch.rand((1, 1, network_size[1]), device=device)
    x = torch.zeros((L, 1, network_size[0]), device=device)
    pred, h = net(x, h0)

    # Target vs. network output in a 45 mm wide figure (width/height = 1.25).
    plt.figure(figsize=(45*units_convert['mm'], 45/1.25*units_convert['mm']))
    plt.plot(y.to('cpu')[:, 0, 0], color='k', label='target')
    plt.plot(pred.to('cpu')[:, 0, 0], '--', color='orange', label='network')
    plt.xlabel('Time')
    plt.ylabel('Output')
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(result_folder, 'Output.png'))
    plt.close()

    # Save a few parameters and network observables while we're at it.
    # Tensors are detached and converted explicitly. np.save goes through
    # Tensor.numpy(), which raises for a tensor with requires_grad=True even
    # inside torch.no_grad(); the trained bias presumably is such a Parameter.
    np.save(os.path.join(data_folder, 'activity.npy'), h.squeeze().detach().cpu().numpy())
    n_learned_params = 0  # count of trainable parameters; not used further in this script
    for (name, param), (_, param_old) in zip(net.named_parameters(),
                                             net_before_training.named_parameters()):
        if param.requires_grad:
            n_learned_params += param.numel()
        if name == "recurrent_layer.bias_ih_l0" and bias_learning:
            np.save(os.path.join(data_folder, 'delta_bias.npy'),
                    (param - param_old).detach().cpu().numpy())
            np.save(os.path.join(data_folder, 'bias_post_training.npy'),
                    param.detach().cpu().numpy())
        if name == "recurrent_layer.weight_hh_l0" and bias_learning:
            np.save(os.path.join(data_folder, 'recurrent_weights.npy'),
                    param.detach().cpu().numpy())
# Persist the per-epoch training loss and the trained parameters next to the other data files.
loss_path = os.path.join(data_folder, 'loss.npy')
weights_path = os.path.join(data_folder, 'model_weights.pth')
np.save(loss_path, losses.cpu().numpy())
torch.save(net.state_dict(), weights_path)
