import math
import torch
from cfgs.process_cfg import set_cfgs
from utils.set_seed import setup_seed
from utils.loss import MaskedCELoss
from utils.ctx_single_dataset import generateData

# get cfgs
# Load the run configuration and echo it so the settings used are visible in the log.
cfgs = set_cfgs()
print("---cfgs---")
print(cfgs)

# set device
# Use the GPU index from cfgs['gpu_id'] when cfgs['device'] is "cuda"; any other value selects the CPU.
device = torch.device("cuda:{}".format(cfgs['gpu_id']) if cfgs['device'] == "cuda" else "cpu")

# set seed
setup_seed(cfgs['seed'])

# define model
# numpy is imported explicitly so that np.load below does not rely on `np`
# leaking out of the wildcard import.
import numpy as np
# The wildcard import supplies customSNN; later code in this script also uses
# `functional`, which is not imported anywhere else, so the wildcard is kept.
from model.customSNN_TAUN_THR import *
model = customSNN(cfgs, device=device)
model.to(device)

# Print the name, shape and requires_grad flag of every model parameter.
trainable_var = dict(model.named_parameters())
for name, param in trainable_var.items():
    print(name, param.shape, param.requires_grad)

loss_fn = MaskedCELoss()

# load imgs
# Indexed as all_imgs[p, :, :] by the training loop, i.e. the first axis selects the problem.
all_imgs = np.load('../../data/cddms/all_imgs.npy')

# train
# The same model is trained on each problem in turn. Every problem gets a fresh
# Adam optimizer and its own loss history; training on a problem stops when the
# mean of the last 50 losses drops below cfgs['thr'], when the iteration index
# reaches cfgs['max_iters'], when the loss becomes NaN, or when cfgs['iters']
# iterations have run.
model.train()
for p in range(cfgs['problems']):

    images = all_imgs[p, :, :]
    loss_p = []

    # set optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=cfgs['learning_rate'],
                                 weight_decay=cfgs['weight_decay'],
                                 betas=(cfgs['beta1'], cfgs['beta2']))

    for step_idx in range(cfgs['iters']):

        # get data
        # generateData returns `images` again; that value is fed into the next call.
        inputs, labels, label_mask, images, stims = generateData(cfgs,
                                                                 images=images,
                                                                 test=False,
                                                                 stim=None)
        inputs = inputs.to(device=device, dtype=torch.float)
        labels = labels.to(device=device, dtype=torch.float)
        label_mask = label_mask.to(device=device, dtype=torch.float)

        # feed-forward
        out, final_rnn_outputs, final_rnn_inputs, _ = model(cfgs, inputs)

        # calculate loss
        optimizer.zero_grad()
        loss = loss_fn(cfgs, model, out, final_rnn_outputs, labels, label_mask)
        # Read the scalar once; it is reused for the NaN check, logging and the history.
        loss_value = loss.item()

        if math.isnan(loss_value):
            # Stop before the backward pass and optimizer step so that a NaN
            # update is never written into the weights, which later problems reuse.
            functional.reset_net(model)
            print("problem {}/ iter {} / label {} / loss {}".format(p, step_idx, stims, loss_value))
            break

        loss.backward()

        # Gradient, state post-processing
        torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=1., norm_type=2)
        # Clamp initS to be non-negative.
        # NOTE(review): the clamp runs before optimizer.step(), so the step can push
        # initS below zero again until the next iteration - confirm this order is intended.
        model.initS.data = torch.nn.functional.relu(model.initS).data

        optimizer.step()
        # `functional` is not imported explicitly in this file; it is expected to come
        # from the wildcard model import.
        functional.reset_net(model)

        print("problem {}/ iter {} / label {} / loss {}".format(p, step_idx, stims, loss_value))
        loss_p.append(loss_value)

        # Convergence: needs more than 50 recorded losses, then compares the mean
        # of the most recent 50 against the threshold.
        recent_losses = loss_p[-50:]
        if len(loss_p) > 50 and (sum(recent_losses) / len(recent_losses) < cfgs['thr']):
            # Imported only when a problem converges.
            from utils.params import testAndSaveParams2
            testAndSaveParams2(cfgs, model, images, taskIndex=p, iter=step_idx)
            loss_fn.lhTargVar = loss_fn.hNorm.item()
            break
        # Only reachable when cfgs['max_iters'] is smaller than cfgs['iters'].
        if step_idx == cfgs['max_iters']:
            break
