"""Computes low-rank PEFs with columns aligned to classes."""
import dataclasses
from typing import List

import torch
from transformers import PreTrainedModel

from npeff_torch.util import flat_pack
from . import pef_computer_common


###############################################################################


@dataclasses.dataclass
class ClassAlignedLrmComputer(pef_computer_common.PefComputerAbc):
    """PEF computer producing one flattened gradient row per requested class.

    For an example with log-probabilities ``log_probs`` and requested classes
    ``class_indices``, row ``i`` of the result is
    ``sqrt(p[c]) * d log_probs[c] / d theta`` with ``c = class_indices[i]`` and
    ``p = softmax(log_probs)``. Here ``theta`` is the concatenation (in order)
    of the tensors in ``parameters``.
    """

    model: PreTrainedModel

    # List of parameters to compute Fishers for.
    parameters: List[torch.nn.Parameter]

    # If True, raise a ValueError when a parameter in `parameters` receives no
    # gradient; if False, its slice of the row is left as zeros.
    error_on_null_grads: bool = True

    def __post_init__(self):
        # Packer used below to map parameter index -> [start, end) range in
        # the flat per-class gradient row.
        self.flat_packer = flat_pack.FlatPacker(p.shape for p in self.parameters)
        self.n_parameters = self.flat_packer.flat_size

    def is_output_projected(self) -> bool:
        """Return False: the PEFs produced here are not projected."""
        return False

    def compute_dense_pef(self, example_info: pef_computer_common.PefComputerInput) -> torch.Tensor:
        """Compute the dense, class-aligned PEF rows for a single example.

        Args:
            example_info: Supplies ``log_probs`` and ``class_indices``.
                ``log_probs`` is indexed per class, so it is presumably a 1-D
                tensor for one example that is still attached to the autograd
                graph of ``self.model``. TODO confirm against callers.
                ``class_indices`` lists the classes to produce rows for.

        Returns:
            Tensor of shape ``[len(class_indices), self.n_parameters]`` with
            the dtype and device of ``log_probs``. Row ``i`` holds
            ``sqrt(softmax(log_probs)[c]) * grad(log_probs[c])`` flattened
            over ``self.parameters``, where ``c = class_indices[i]``.

        Raises:
            ValueError: If ``error_on_null_grads`` is set and some parameter
                in ``self.parameters`` receives no gradient.
        """
        log_probs = example_info.log_probs
        class_indices = example_info.class_indices
        n_rows = len(class_indices)

        # Index outside of `no_grad` so each selected log-prob keeps its
        # grad_fn and can be backpropagated below.
        indexed_log_probs = [log_probs[class_index] for class_index in class_indices]

        with torch.no_grad():
            pef = torch.zeros([n_rows, self.n_parameters], dtype=log_probs.dtype, device=log_probs.device)

            try:
                for i, log_prob in enumerate(indexed_log_probs):
                    # Start each class from clean grads so rows don't accumulate.
                    self.model.zero_grad()
                    # Retain the graph for every backward pass except the last one.
                    log_prob.backward(retain_graph=i + 1 < n_rows)

                    for j, p in enumerate(self.parameters):
                        if p.grad is not None:
                            p_start, p_end = self.flat_packer.get_range_for_tensor_by_index(j)
                            pef[i, p_start:p_end] = torch.reshape(p.grad, [-1])
                        elif self.error_on_null_grads:
                            raise ValueError(f'Received null gradient for parameter with index {j}')
            finally:
                # The last (or a failed) backward pass leaves its grads on the
                # model; clear them on every exit path so they don't leak into
                # other code that uses the model.
                self.model.zero_grad()

            # Scale row i by sqrt(p[class_indices[i]]).
            sqrt_probs = torch.sqrt(torch.softmax(log_probs, dim=-1))
            sqrt_probs = sqrt_probs[class_indices]
            pef.mul_(sqrt_probs[:, None])

        return pef
