import torch
from tqdm import tqdm
import sys
sys.path.append("./MusicTransformer_Pytorch")

from utilities.constants import *
from .projector import vectorize

def count_parameters(model):
    return sum(p.numel() for p in model.parameters())


def parameters_to_vector(parameters):
    """
    Same as https://pytorch.org/docs/stable/generated/torch.nn.utils.parameters_to_vector.html
    but with :code:`reshape` instead of :code:`view` to avoid a pesky error.
    """
    vec = []
    for param in parameters:
        vec.append(param.reshape(-1))
    return torch.cat(vec)


def grad_calculator(data_loader,
                    model,
                    parameters,
                    func,
                    normalize_factor,
                    device,
                    projector,
                    checkpoint_id):
    # model.disable_dropout()
    all_grads = []
    for _, data in enumerate(data_loader):
        model_output = func(data, model)
        if torch.isinf(model_output):
            # TODO: handle numerical problem better
            print("numerical problem happens, model output function equals to inf")
            grads = torch.zeros(count_parameters(model), dtype=torch.float32).to(device)
            grads_p = projector.project(grads.clone().detach().unsqueeze(0), model_id=checkpoint_id, is_grads_dict=False)
            all_grads.append(grads_p)
        else:
            grads = parameters_to_vector(torch.autograd.grad(model_output, parameters, retain_graph=True))
            grads /= normalize_factor
            grads_p = projector.project(grads.clone().detach().unsqueeze(0), model_id=checkpoint_id, is_grads_dict=False)
            all_grads.append(grads_p)
    all_grads = torch.cat(all_grads, dim=0)
    return all_grads

def grad_calculator_list(data_loader,
                        model,
                        parameters,
                        func,
                        normalize_factor,
                        device,
                        projector,
                        checkpoint_id_list):
    model.disable_dropout()
    all_grads = [[] for _ in range(len(checkpoint_id_list))]
    # print("checkpoint id list: ", checkpoint_id_list)

    for _, data in enumerate(data_loader):
        model_output = func(data, model)
        if torch.isinf(model_output):
            # TODO: handle numerical problem better
            print("numerical problem happens, model output function equals to inf")
            grads = torch.zeros(count_parameters(model), dtype=torch.float32).to(device)
            for idx, checkpoint_id in enumerate(checkpoint_id_list):
                all_grads[idx].append(projector.project(grads.clone().detach().unsqueeze(0), model_id=checkpoint_id.item(), is_grads_dict=False))
        else:
            grads = parameters_to_vector(torch.autograd.grad(model_output, parameters, retain_graph=True))
            grads /= normalize_factor
            for idx, checkpoint_id in enumerate(checkpoint_id_list):
                all_grads[idx].append(projector.project(grads.clone().detach().unsqueeze(0), model_id=checkpoint_id.item(), is_grads_dict=False))
        
    return [torch.cat(grads, dim=0) for grads in all_grads]




def out_to_loss_grad_calculator(data_loader,
                                model,
                                func,
                                ):
    # model.enable_dropout()
    out_to_loss_grads = []
    for _, data in enumerate(data_loader):
        out_to_loss_grad = func(data, model)
        out_to_loss_grads.append(out_to_loss_grad)
    return torch.diag(torch.cat(out_to_loss_grads).reshape(-1))