from torch.utils.data import DataLoader
from torch import optim
import pickle
import os
import time
import gc

from src.signal_embed import *

def num_params(model):
    """Print and return the number of trainable parameters in *model*.

    Only parameters whose ``requires_grad`` flag is set are counted.

    Args:
        model: object exposing ``parameters()`` that yields items with a
            ``numel()`` method and a ``requires_grad`` attribute
            (e.g. a ``torch.nn.Module``).

    Returns:
        int: the trainable-parameter count (also printed to stdout).
    """
    nums = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Number of Parameters: {}".format(nums))
    return nums

def get_and_save_embed(model, dataset, ds_name, model_name, batch_size=512, is_on_cluster=False):
    """Run *model* over *dataset* and pickle one embedding per sample to disk.

    The embedding is the second element of the pair returned by ``model(...)``.
    Each row is converted to a float16 NumPy array and written with ``pickle``
    to ``data/<ds_name>/<model_name>/<basename of the sample's fname>``.

    Args:
        model: network called as ``model(samples)`` or, when ``model_name``
            contains ``'msitf'``, as ``model(samples, query=query)``; must
            return a ``(prediction, embedding)`` pair.
        dataset: dataset whose batches are dicts with ``"samples"`` and
            ``"fnames"`` keys (plus ``"query"`` for ``'msitf'`` models).
        ds_name: dataset name, used as the first output directory level.
        model_name: model name, used as the second output directory level.
        batch_size: inference batch size.
        is_on_cluster: when True, the output root is prefixed with ``../../``.

    Note:
        Only the basename of each fname is kept, so two samples whose fnames
        share a basename overwrite each other's output file.
    """
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    save_root = "data/{}/{}".format(ds_name, model_name)
    if is_on_cluster:
        save_root = "../../" + save_root
    os.makedirs(save_root, exist_ok=True)

    model.eval()
    with torch.no_grad():
        for data_pack in tqdm(data_loader):
            if 'msitf' in model_name:
                _, embed = model(data_pack["samples"], query=data_pack['query'])
            else:
                _, embed = model(data_pack["samples"])
            if torch.is_tensor(embed):
                # One device->host copy per batch instead of one per sample;
                # the per-row .cpu() below is then a no-op.
                embed = embed.detach().cpu()
            for i in range(len(embed)):
                # Normalise Windows separators, then keep only the basename.
                curr_fn = data_pack["fnames"][i].replace("\\", "/").split("/")[-1]
                # Reuse save_root so the directory (and cluster prefix) is
                # defined in exactly one place.
                save_path = "{}/{}".format(save_root, curr_fn)
                with open(save_path, 'wb') as f:
                    pickle.dump(embed[i].detach().cpu().float().numpy().astype(np.float16), f)

    # clear cache
    torch.cuda.empty_cache()
    gc.collect()

def _fit_forward(model, model_name, data_pack):
    """Run *model* on one batch, passing ``query`` only for 'msitf' models."""
    if 'msitf' in model_name:
        return model(data_pack["samples"], query=data_pack['query'])
    return model(data_pack["samples"])


def _fit_evaluate(model, model_name, eval_loader, loss_f):
    """Evaluate *model* on *eval_loader* and collect its predictions.

    Returns:
        tuple: ``(eval_loss, y_preds, y_trues)``.
        - ``eval_loss`` is each batch loss weighted by batch size and
          averaged over all samples, or ``nan`` if the loader yields no
          samples.
        - ``y_preds`` holds the predictions; for ``model.task == 'class'``
          these are softmax probabilities.
        - ``y_trues`` holds the labels cast to float.
    """
    model.eval()
    loss_sum = 0.0
    total_val_num = 0
    y_preds, y_trues = list(), list()
    with torch.no_grad():
        for data_pack in tqdm(eval_loader):
            pred, _ = _fit_forward(model, model_name, data_pack)
            loss = loss_f(pred, data_pack["labels"])

            # Weight each batch loss by its size so a smaller final batch
            # does not skew the average (assumes loss_f averages over the
            # batch, as the default-reduction losses in fit_model do).
            batch_n = len(data_pack["labels"])
            loss_sum += loss.detach().cpu().item() * batch_n
            total_val_num += batch_n

            pred = pred.detach().cpu().float()
            if model.task == 'class':
                pred = torch.softmax(pred, dim=1)
            y_preds += pred.numpy().tolist()
            y_trues += data_pack["labels"].detach().cpu().float().numpy().tolist()

    # Guard against an empty eval set instead of dividing by zero.
    eval_loss = loss_sum / total_val_num if total_val_num else float("nan")
    return eval_loss, y_preds, y_trues


def fit_model(
    model_name,
    model,
    train_dataset,
    eval_dataset,
    batch_size=512,
    epochs=10,
    lr=1e-3,
    weight_decay=1e-5,
    step_size=10,
    gamma=0.997,
    eval_every=5,
):
    """Train *model* with Adam and periodically evaluate it.

    Args:
        model_name: model identifier. If it contains ``'msitf'``, batches'
            ``"query"`` entry is passed to the model as ``query=``.
        model: network returning a ``(prediction, embedding)`` pair and
            exposing ``task`` (``"reg"`` -> L1 loss, ``"class"`` ->
            cross-entropy loss).
        train_dataset: dataset yielding dicts with ``"samples"`` and
            ``"labels"``. It is shuffled each epoch.
        eval_dataset: dataset with the same batch layout, evaluated in order.
        batch_size: batch size for both loaders.
        epochs: number of training epochs.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        step_size: StepLR step size. The scheduler is stepped after every
            batch, so this counts optimizer steps, not epochs.
        gamma: StepLR decay factor.
        eval_every: evaluate after epoch ``e`` whenever ``e % eval_every == 0``
            (epoch 0 included), and always after the final epoch.

    Returns:
        dict with these keys:
        - ``train_losses``: per-batch training losses.
        - ``train_i``: iteration numbers matching ``train_losses``.
        - ``eval_losses``: validation losses.
        - ``eval_i``: iteration numbers matching ``eval_losses``.
        - ``last_pred``: ``{"y_preds", "y_trues"}`` from the last evaluation,
          or ``None`` if no evaluation ran.

    Raises:
        ValueError: if ``model.task`` is neither ``"reg"`` nor ``"class"``,
            or if ``eval_every < 1``.
    """
    if eval_every < 1:
        raise ValueError("eval_every must be >= 1, got {}".format(eval_every))

    # Fail fast: an unsupported task used to leave loss_f unbound and crash
    # with UnboundLocalError only at the first training batch.
    if model.task == "reg":
        loss_f = nn.L1Loss()
    elif model.task == "class":
        loss_f = nn.CrossEntropyLoss()
    else:
        raise ValueError("Unsupported model.task: {!r}".format(model.task))

    # main train pipeline
    print("Construct data loader")
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    eval_loader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)

    optimizer = optim.Adam(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay
    )
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

    # stored variables
    train_losses, eval_losses = list(), list()
    train_i, eval_i = list(), list()
    iteration = 0
    last_pred = None
    print("Start training")
    for e in tqdm(range(epochs)):
        model.train()
        for data_pack in tqdm(train_loader):
            iteration += 1
            pred, _ = _fit_forward(model, model_name, data_pack)
            loss = loss_f(pred, data_pack["labels"])

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Per-batch scheduler step: LR decays by `gamma` every
            # `step_size` optimizer steps.
            scheduler.step()

            train_losses.append(loss.detach().cpu().item())
            train_i.append(iteration)

        # validation
        if e % eval_every == 0 or e == epochs - 1:
            print("Eval")
            eval_loss, y_preds, y_trues = _fit_evaluate(model, model_name, eval_loader, loss_f)
            last_pred = {"y_preds": y_preds, "y_trues": y_trues}
            eval_losses.append(eval_loss)
            eval_i.append(iteration)
            print("Check Val Loss:", eval_losses[-1])

    return {
        "train_losses": train_losses,  # list
        "train_i": train_i,
        "eval_losses": eval_losses,  # list
        "eval_i": eval_i,
        "last_pred": last_pred  # dict{list, list}
    }