import torch
from models.base import ContinualLearning

# NOTE(review): this expression only constructs a ``Warning`` instance and
# immediately discards it, so nothing is reported when the module is imported.
# If a runtime warning is intended, ``warnings.warn(...)`` would be needed.
Warning("This function is not usable")


def compute_empirical_fisher_information(
    model, dataloader, criterion, task_name, task_id, device="cpu"
):
    """
    Computes the empirical Fisher Information Matrix (FIM) for the whole dataset.

    For every batch the loss is back-propagated, the gradients of all
    parameters that received one are flattened into a single vector ``g`` and
    the outer product ``g g^T`` is accumulated. The accumulated matrix is
    finally divided by the total number of samples seen.

    Notes:
    - The result is a dense ``P x P`` matrix, where ``P`` is the number of
      parameter entries that received a gradient, so memory grows
      quadratically with the model size.
    - One outer product is formed per *batch* gradient, not per sample.

    Parameters:
    - model: The PyTorch model (``ContinualLearning`` or ``torch.nn.Module``).
      It is moved to ``device`` and switched to evaluation mode.
    - dataloader: Iterable of dict-like samples holding an ``"image"`` entry
      and a target entry under ``task_name``.
    - criterion: The loss function used to compute gradients.
    - task_name: Key of the target tensor inside each sample.
    - task_id: Task identifier forwarded to ``compute_loss_on_task_id`` for
      ``ContinualLearning`` models (unused for plain modules).
    - device: The device on which to perform computations ('cpu' or 'cuda').

    Returns:
    - fisher_information_matrix: The empirical FIM as a torch tensor.

    Raises:
    - NotImplementedError: If ``model`` is neither a ``ContinualLearning``
      instance nor a ``torch.nn.Module``.
    - ValueError: If ``dataloader`` yields no batches.
    """
    model = model.to(device)
    model.eval()  # Set the model to evaluation mode

    fisher_information_matrix = None
    total_samples = 0

    for sample in dataloader:
        image = sample["image"].to(device)
        cur_task_y = sample[task_name].type(torch.LongTensor).to(device)
        if isinstance(model, ContinualLearning):
            # Forward pass
            result = model.compute_loss_on_task_id(
                image, cur_task_y, criterion, task_id
            )
            # compute_fisher_information in this module unpacks this call as
            # a tuple whose last element is the loss; accept both shapes.
            loss = result[-1] if isinstance(result, (tuple, list)) else result
        elif isinstance(model, torch.nn.Module):
            # Forward pass
            outputs = model(image)
            loss = criterion(outputs, cur_task_y)
        else:
            raise NotImplementedError(
                "Model should be an instance of ContinualLearning or torch.nn.Module"
            )

        # Backward pass to compute gradients
        model.zero_grad()
        loss.backward()

        # Flatten gradients into a single vector; parameters that did not
        # take part in the loss have no gradient and are skipped.
        grads = torch.cat(
            [
                param.grad.view(-1)
                for param in model.parameters()
                if param.grad is not None
            ]
        )

        # Outer product of the gradient vector with itself
        # (torch.outer replaces the deprecated torch.ger alias).
        outer_product = torch.outer(grads, grads)

        # Initialize the FIM or accumulate
        if fisher_information_matrix is None:
            fisher_information_matrix = outer_product
        else:
            fisher_information_matrix += outer_product

        # Count the samples of the current batch.
        total_samples += image.size(0)

    if fisher_information_matrix is None:
        raise ValueError(
            "dataloader yielded no batches; cannot compute the Fisher information"
        )

    # Average by the number of samples
    fisher_information_matrix /= total_samples

    return fisher_information_matrix


def compute_fisher_information_sum(
    model, dataloader, criterion, task_name, task_id, device="cpu"
):
    """
    Computes the sum of eigenvalues of the empirical Fisher Information Matrix (FIM) over a dataset.

    The sum of the eigenvalues of a square matrix equals its trace, so the
    value is read from the diagonal instead of running a full
    eigendecomposition of the ``P x P`` matrix.

    Parameters:
    - model: The PyTorch model.
    - dataloader: DataLoader providing the training data.
    - criterion: The loss function used to compute gradients.
    - task_name: Key of the target tensor inside each sample.
    - task_id: Task identifier forwarded to the FIM computation.
    - device: The device on which to perform computations ('cpu' or 'cuda').

    Returns:
    - scalar_metric: The sum of eigenvalues of the FIM as a scalar.
    """
    # Compute the empirical FIM
    fim = compute_empirical_fisher_information(
        model, dataloader, criterion, task_name, task_id, device
    )

    # trace(F) == sum of eigenvalues of F
    return torch.trace(fim).item()


def _check_param_device(param, old_param_device):
    r"""Verify that ``param`` lives on the same device as the parameters
    seen so far.

    Devices are encoded as an integer: the CUDA device index for GPU tensors
    and ``-1`` for CPU tensors. Mixing devices is rejected.

    Arguments:
        param ([Tensor]): a Tensor of a parameter of a model
        old_param_device (int): device code of the first parameter that was
                                inspected, or ``None`` if ``param`` is the
                                first one.

    Returns:
        old_param_device (int): the device code of the first parameter

    Raises:
        TypeError: if ``param`` is on a different device than the recorded one.
    """
    # Encode the device of the current parameter (-1 stands for CPU).
    current_device = param.get_device() if param.is_cuda else -1

    # First parameter: just record where it lives.
    if old_param_device is None:
        return current_device

    if current_device != old_param_device:
        raise TypeError(
            "Found two parameters on different devices, "
            "this is currently not supported."
        )
    return old_param_device


def vector_to_parameter_list(vec, parameters):
    r"""Split a flat vector into tensors shaped like the given parameters.

    Arguments:
        vec (Tensor): a single vector represents the parameters of a model.
        parameters (Iterable[Tensor]): an iterator of Tensors that are the
            parameters of a model.

    Returns:
        list[Tensor]: one tensor per parameter, each a reshaped slice of
        ``vec`` (taken through ``.data``), in iteration order.

    Raises:
        TypeError: if ``vec`` is not a Tensor, or if the parameters are
            spread over different devices.
    """
    if not isinstance(vec, torch.Tensor):
        raise TypeError(
            "expected torch.Tensor, but got: {}".format(torch.typename(vec))
        )

    device_code = None  # device of the first parameter, set on first use
    offset = 0  # start of the current parameter's slice inside ``vec``
    chunks = []
    for p in parameters:
        # All parameters must live on one device.
        device_code = _check_param_device(p, device_code)

        stop = offset + p.numel()
        chunks.append(vec[offset:stop].view_as(p).data)
        offset = stop

    return chunks


def Rop(ys, xs, vs):
    """Right operator: Jacobian-vector product ``(d ys / d xs) @ vs``.

    Implemented with two reverse-mode passes: first the vector-Jacobian
    product ``g(w) = J^T w`` is built for a dummy ``w`` (which is linear in
    ``w``), then ``g`` is differentiated with respect to ``w`` with ``vs`` as
    the incoming gradient, which yields ``J vs``.

    Arguments:
        ys (Tensor or tuple/list of Tensors): outputs of the computation.
        xs (Tensor or sequence of Tensors): inputs the outputs depend on.
        vs (Tensor or sequence of Tensors): one tensor per entry of ``xs``,
            with matching shapes.

    Returns:
        tuple of Tensors: one detached tensor per entry of ``ys``. Outputs
        that do not depend on any of ``xs`` contribute zeros.

    Raises:
        ValueError: if ``vs`` does not provide exactly one tensor per input.
    """
    multiple_outputs = isinstance(ys, (tuple, list))
    # Dummy cotangents; zeros_like(..., requires_grad=True) avoids the
    # copy-construct warning of torch.tensor(tensor).
    if multiple_outputs:
        ws = [torch.zeros_like(y, requires_grad=True) for y in ys]
    else:
        ws = torch.zeros_like(ys, requires_grad=True)
    ws_seq = tuple(ws) if multiple_outputs else (ws,)

    gs = torch.autograd.grad(
        ys,
        xs,
        grad_outputs=ws,
        create_graph=True,
        retain_graph=True,
        allow_unused=True,
    )

    vs_seq = (vs,) if isinstance(vs, torch.Tensor) else tuple(vs)
    if len(vs_seq) != len(gs):
        raise ValueError(
            "vs must provide one tensor per input: got {} for {} inputs".format(
                len(vs_seq), len(gs)
            )
        )

    # allow_unused=True yields None for inputs the outputs do not depend on;
    # those contribute nothing to J v and must not reach the second pass.
    used = [
        (g, v)
        for g, v in zip(gs, vs_seq)
        if g is not None and g.requires_grad
    ]
    if not used:
        return tuple(torch.zeros_like(w) for w in ws_seq)
    used_gs, used_vs = zip(*used)

    re = torch.autograd.grad(
        used_gs,
        ws,
        grad_outputs=used_vs,
        create_graph=True,
        retain_graph=True,
        allow_unused=True,
    )
    return tuple(
        torch.zeros_like(w) if j is None else j.detach()
        for j, w in zip(re, ws_seq)
    )


def Lop(ys, xs, ws):
    """Left operator: vector-Jacobian product ``ws^T (d ys / d xs)``.

    Arguments:
        ys (Tensor or sequence of Tensors): outputs of the computation.
        xs (Tensor or sequence of Tensors): inputs to differentiate against.
        ws (Tensor or sequence of Tensors): cotangents, one per output.

    Returns:
        tuple of Tensors: one detached tensor per entry of ``xs``. Inputs the
        outputs do not depend on get a zero tensor of their own shape.
    """
    # Materialise the inputs once so they can be paired with the results.
    xs_seq = (xs,) if isinstance(xs, torch.Tensor) else tuple(xs)
    vJ = torch.autograd.grad(
        ys,
        xs_seq,
        grad_outputs=ws,
        create_graph=True,
        retain_graph=True,
        allow_unused=True,
    )
    # allow_unused=True returns None for unused inputs; calling .detach() on
    # it would fail, so report the (mathematically correct) zero instead.
    return tuple(
        torch.zeros_like(x) if j is None else j.detach()
        for j, x in zip(vJ, xs_seq)
    )


def HesssianVectorProduct(f, x, v):
    """Hessian-vector product of ``f`` with respect to ``x`` applied to ``v``.

    The gradient of ``f`` with respect to ``x`` is taken with the graph kept,
    and ``Rop`` is then applied to that gradient. The results are returned as
    a tuple of detached tensors.
    """
    gradient = torch.autograd.grad(f, x, retain_graph=True, create_graph=True)
    return tuple(term.detach() for term in Rop(gradient, x, v))


def FisherVectorProduct(loss, output, model, vp):
    """Compute ``J^T H J v``, with ``J`` the Jacobian of ``output`` with
    respect to the model parameters and ``H`` the Hessian of ``loss`` with
    respect to ``output``.

    Arguments:
        loss (Tensor): scalar loss computed from ``output``.
        output (Tensor): model output; indexed as ``(batch, dims)``.
        model (torch.nn.Module): model whose parameters define ``J``.
        vp (Tensor or sequence of Tensors): the vector ``v``, either as one
            flat vector covering all parameters or already split into one
            tensor per parameter.

    Returns:
        Tensor: the product as one flat vector over all model parameters.
    """
    params = list(model.parameters())

    # Rop pairs one tensor with each parameter, so a flat vector has to be
    # split into parameter-shaped pieces before it is used.
    if isinstance(vp, torch.Tensor):
        vp = vector_to_parameter_list(vp, params)

    Jv = Rop(output, params, vp)
    batch, dims = output.size(0), output.size(1)

    # Current PyTorch names the node "NllLossBackward0"; match by prefix so
    # both the old and the suffixed names take the closed-form branch.
    if loss.grad_fn.__class__.__name__.startswith("NllLossBackward"):
        # Closed-form Hessian diag(p) - p p^T of the softmax cross-entropy
        # with respect to the logits.
        # NOTE(review): this treats `output` as logits and the division by
        # `batch` assumes a mean-reduced loss - confirm against the criterion.
        outputsoftmax = torch.nn.functional.softmax(output, dim=1)
        M = torch.zeros(
            batch,
            dims,
            dims,
            device=outputsoftmax.device,
            dtype=outputsoftmax.dtype,
        )
        # Stride dims + 1 over the flattened matrices addresses the diagonals.
        M.view(batch, -1)[:, :: dims + 1] = outputsoftmax
        H = M - torch.einsum("bi,bj->bij", outputsoftmax, outputsoftmax)
        # Only drop the trailing singleton so a batch of size 1 keeps its
        # (batch, dims) shape.
        HJv = [(H @ Jv[0].unsqueeze(-1)).squeeze(-1) / batch]
    else:
        HJv = HesssianVectorProduct(loss, output, Jv)

    JHJv = Lop(output, params, HJv)

    return torch.cat([torch.flatten(v) for v in JHJv])


def compute_fisher_information(
    model, data_loader, criterion, task_name, task_id, device="cpu"
):
    """
    Estimates the sum of eigenvalues (trace) of the Fisher Information Matrix
    using Fisher-vector products, without materialising the matrix.

    For every batch a random Gaussian vector ``v`` is drawn and ``v . (F v)``
    is accumulated, where ``F v`` comes from ``FisherVectorProduct``. For
    ``v`` with identity covariance the expectation of ``v^T F v`` is
    ``trace(F)``, i.e. the sum of the eigenvalues. The per-batch estimates are
    averaged over the number of batches.

    Arguments:
    - model: The PyTorch model (``ContinualLearning`` or ``torch.nn.Module``).
      It is moved to ``device`` and switched to evaluation mode.
    - data_loader: Iterable of dict-like samples holding an ``"image"`` entry
      and a target entry under ``task_name``.
    - criterion: The loss function used to compute the loss.
    - task_name: Key of the target tensor inside each sample.
    - task_id: The specific task ID passed to ``compute_loss_on_task_id`` for
      ``ContinualLearning`` models.
    - device: The device to run the computation on ('cpu' or 'cuda').

    Returns:
    - fisher_information_sum: Stochastic estimate of the sum of eigenvalues
      of the FIM, averaged over batches.

    Raises:
    - NotImplementedError: If ``model`` is neither a ``ContinualLearning``
      instance nor a ``torch.nn.Module``.
    - ValueError: If ``data_loader`` yields no batches.
    """
    # Keep the model on the same device as the batches moved below.
    model = model.to(device)
    model.eval()

    # Only the layout (size, dtype, device) of the flattened parameters is
    # needed to draw probe vectors; it does not change between batches.
    with torch.no_grad():
        flat_params = torch.cat(
            [param.reshape(-1) for param in model.parameters()]
        )

    fisher_information_sum = 0.0
    num_batches = 0

    for sample in data_loader:
        inputs = sample["image"].to(device)
        targets = sample[task_name].type(torch.LongTensor).to(device)

        if isinstance(model, ContinualLearning):
            # Forward pass
            _, outputs, loss = model.compute_loss_on_task_id(
                inputs, targets, criterion, task_id
            )
        elif isinstance(model, torch.nn.Module):
            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, targets)
        else:
            raise NotImplementedError(
                "Model should be an instance of ContinualLearning or torch.nn.Module"
            )

        # No loss.backward() here: FisherVectorProduct differentiates through
        # torch.autograd.grad itself, and backward(create_graph=True) would
        # only keep a parameter <-> gradient reference cycle alive.

        # Random vector for Fisher information approximation
        vp = torch.randn_like(flat_params)
        fisher_vp = FisherVectorProduct(loss, outputs, model, vp)

        # v . (F v) estimates trace(F); (F v) . (F v) would estimate
        # trace(F^2) instead.
        fisher_information_sum += vp.dot(fisher_vp).item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError(
            "data_loader yielded no batches; cannot estimate the Fisher information"
        )

    return fisher_information_sum / num_batches
