"""Utilities for computing flattened per-example gradients of a logit function with respect to model parameters."""
import dataclasses
from typing import Dict, List, Optional

import torch
from transformers import PreTrainedModel

from npeff_torch.peis.gradients import logit_functions
from npeff_torch.util import flat_pack


###############################################################################


@dataclasses.dataclass
class ExampleGradientInfo:
    """Result of differentiating the logit function for a single example.

    Attributes:
        dense_gradient: Gradient of ``fn_value`` with respect to the selected
            parameters, packed into a single flat tensor.
        fn_value: Value of the logit function for the example, as a scalar
            tensor (shape ``[]``).
    """

    # Flattened (1-D) gradient vector over all selected parameters.
    dense_gradient: torch.Tensor

    # Scalar tensor: shape = [].
    fn_value: torch.Tensor


###############################################################################


@dataclasses.dataclass
class GradientComputer:
    """Computes flattened per-example gradients of a scalar logit function.

    For one example, ``logit_fn`` is evaluated and back-propagated through
    ``model``. The resulting gradients of ``parameters`` are then packed into a
    single flat 1-D tensor. A ``flat_pack.FlatPacker`` built from the parameter
    shapes defines the tensor's layout.
    """

    model: PreTrainedModel

    # List of parameters to compute Fishers for. Their order determines the
    # layout of the flat gradient vector.
    parameters: List[torch.nn.Parameter]

    # Maps one example's input to the value that gets differentiated.
    logit_fn: 'logit_functions.LogitFunctionType'

    # If True, raise when a parameter receives no gradient. Otherwise that
    # parameter's slice of the flat gradient is left as zeros.
    error_on_null_grads: bool = True

    def __post_init__(self):
        # Materialize first. If an iterator such as ``model.parameters()`` were
        # passed, the generator given to FlatPacker would exhaust it, and
        # compute_dense_gradient would silently iterate over nothing.
        self.parameters = list(self.parameters)
        self.flat_packer = flat_pack.FlatPacker(p.shape for p in self.parameters)
        self.n_parameters = self.flat_packer.flat_size

    def compute_dense_gradient(self, example_info: 'logit_functions.LogitFunctionInput') -> ExampleGradientInfo:
        """Compute the flattened gradient of ``logit_fn`` for a single example.

        Args:
            example_info: Passed unchanged to ``self.logit_fn``.

        Returns:
            An ``ExampleGradientInfo`` with two fields. ``dense_gradient`` is a
            1-D tensor of length ``self.n_parameters``, on the device and with
            the dtype of the function value. ``fn_value`` is the function value,
            detached from the (already freed) autograd graph.

        Raises:
            ValueError: If ``error_on_null_grads`` is set and some parameter
                received no gradient.

        Side effects:
            Clears (sets to ``None``) the ``.grad`` of the model's parameters
            before and after the backward pass.
        """
        # NOTE(review): backward() is called without a gradient argument, so
        # logit_fn is assumed to return a scalar tensor.
        value = self.logit_fn(example_info)

        # Pass set_to_none=True explicitly. Before PyTorch 2.0 the default was
        # False, which leaves zero tensors instead of None. That would make the
        # null-gradient check below never fire after the first call.
        self.model.zero_grad(set_to_none=True)
        try:
            value.backward(retain_graph=False)

            with torch.no_grad():
                grad = torch.zeros([self.n_parameters], dtype=value.dtype, device=value.device)

                for j, p in enumerate(self.parameters):
                    if p.grad is not None:
                        p_start, p_end = self.flat_packer.get_range_for_tensor_by_index(j)
                        grad[p_start:p_end] = torch.reshape(p.grad, [-1])
                    elif self.error_on_null_grads:
                        raise ValueError(f'Received null gradient for parameter with index {j}')
        finally:
            # Release the per-example gradient buffers, even when the null-grad
            # check raised. This keeps them from lingering on the model and
            # being seen by other code.
            self.model.zero_grad(set_to_none=True)

        return ExampleGradientInfo(
            dense_gradient=grad,
            # The graph was freed by backward(retain_graph=False). Detaching
            # drops the dangling grad_fn reference.
            fn_value=value.detach(),
        )
