import torch
from tqdm import tqdm


def count_parameters(model):
    """Return the total number of elements over all of ``model.parameters()``."""
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total

def count_parameters_list(paramters):
    """Return the total number of elements over an iterable of tensors."""
    total = 0
    for tensor in paramters:
        total += tensor.numel()
    return total

def parameters_to_vector(parameters):
    """
    Flatten every tensor in :code:`parameters` and concatenate the pieces into
    a single 1-D tensor, preserving iteration order.

    Mirrors https://pytorch.org/docs/stable/generated/torch.nn.utils.parameters_to_vector.html
    except that each tensor is flattened with :code:`reshape` rather than
    :code:`view`, which sidesteps the error :code:`view` can raise.
    """
    flattened = [tensor.reshape(-1) for tensor in parameters]
    return torch.cat(flattened)


def grad_calculator(data_loader,
                    model,
                    parameters,
                    func,
                    normalize_factor,
                    device,
                    projector,
                    checkpoint_id,
                    disable_dropout=False):
    """
    Compute projected gradients of :code:`func(data, model)` with respect to
    :code:`parameters`, one row group per item yielded by :code:`data_loader`.

    For each item the single-element output of :code:`func` is differentiated
    w.r.t. :code:`parameters`; the flattened gradient is divided by
    :code:`normalize_factor` and passed through :code:`projector.project` once
    per entry of :code:`checkpoint_id`. If the output is infinite, an all-zero
    vector with the same number of entries as the gradient is projected
    instead.

    Args:
        data_loader: iterable of items, each passed as the first argument of
            :code:`func`.
        model: object passed as the second argument of :code:`func`; must
            provide :code:`disable_dropout()` when :code:`disable_dropout` is
            true.
        parameters: tensor, or iterable of tensors, to differentiate with
            respect to.
        func: callable :code:`func(data, model)` returning a single-element
            tensor attached to the autograd graph of :code:`parameters`.
        normalize_factor: divisor applied to the flattened gradient.
        device: device the zero placeholder gradient is moved to.
        projector: object exposing
            :code:`project(grads, model_id=..., is_grads_dict=False)`.
        checkpoint_id: a single id or a list of ids; each is forwarded as
            :code:`model_id` to the projector.
        disable_dropout: when true, :code:`model.disable_dropout()` is called
            before any gradient is computed.

    Returns:
        The projections concatenated along dim 0 in :code:`data_loader` order.
        A single result is returned when there is exactly one checkpoint id
        (including a one-element list); otherwise a list with one result per
        checkpoint id.

    Raises:
        RuntimeError: from :code:`torch.cat` when :code:`data_loader` yields
            no items.
    """
    if disable_dropout:
        model.disable_dropout()
    if not isinstance(checkpoint_id, list):
        checkpoint_id = [checkpoint_id]

    # ``parameters`` is reused for every batch, so materialize it once; a
    # one-shot iterator would otherwise be empty from the second batch on.
    # The placeholder size is derived from ``parameters`` (not from the whole
    # model) so it always matches the length of a real flattened gradient.
    if isinstance(parameters, torch.Tensor):
        num_grad_entries = parameters.numel()
    else:
        parameters = tuple(parameters)
        num_grad_entries = sum(p.numel() for p in parameters)

    res = [[] for _ in checkpoint_id]
    for data in tqdm(data_loader):
        model_output = func(data, model)
        if torch.isinf(model_output):
            # TODO: handle numerical problem better
            # An infinite output has no usable gradient; project zeros so the
            # result keeps one row group per data_loader item.
            grads = torch.zeros(num_grad_entries, dtype=torch.float32).to(device)
        else:
            grads = parameters_to_vector(
                torch.autograd.grad(model_output, parameters, retain_graph=True)
            )
            grads /= normalize_factor

        for projected, random_id in zip(res, checkpoint_id):
            # Hand the projector its own detached copy for every checkpoint.
            grads_p = projector.project(
                grads.clone().detach().unsqueeze(0),
                model_id=random_id,
                is_grads_dict=False,
            )
            projected.append(grads_p)

    res = [torch.cat(projected, dim=0) for projected in res]
    if len(res) == 1:
        return res[0]
    return res


def out_to_loss_grad_calculator(data_loader,
                                model,
                                func):
    """
    Evaluate :code:`func(data, model)` for every item of :code:`data_loader`,
    concatenate the results, flatten them to 1-D and return the square matrix
    that has those values on its diagonal.
    """
    per_batch = [func(batch, model) for batch in tqdm(data_loader)]
    flat = torch.cat(per_batch).reshape(-1)
    return torch.diag(flat)
