"""Computes low-rank PEFs constructed via a streaming SVD-based approximation."""
import dataclasses

from typing import List, Optional

import torch
from transformers import PreTrainedModel

from npeff_torch.util import flat_pack
from . import pef_computer_common


###############################################################################
# Local alias for the shared input type; used in the method signatures below.
PefComputerInput = pef_computer_common.PefComputerInput
###############################################################################


@dataclasses.dataclass
class StreamingSvdLrmComputer(pef_computer_common.PefComputerAbc):
    """Computes low-rank PEFs by streaming per-class gradients into an SVD.

    For each example, one vector per class in ``example_info.class_indices`` is
    built as ``sqrt(p_c) * grad(log p_c)`` over ``parameters`` (flattened). The
    vectors are folded in one at a time as rank-one updates to an
    eigendecomposition of ``sum_c g_c g_c^T``. The decomposition is truncated to
    at most ``max_intermediate_rank`` components, always keeping the largest.
    The returned PEF has rows ``sqrt(sigma_k) * u_k``, so ``pef.T @ pef``
    approximates that sum.
    """

    model: PreTrainedModel

    # List of parameters to compute Fishers for.
    parameters: List[torch.nn.Parameter]

    # The maximum rank of the SVD that will be achieved during the decomposition.
    # Must be at least 1.
    max_intermediate_rank: int

    # Use this to enforce a fixed shape of the output of the LRM PEFs. If provided,
    # the output LRM PEF always has exactly this many rows (zero-padded when fewer
    # components are available), so it is an upper bound on its rank. Must not
    # exceed max_intermediate_rank.
    #
    # If not provided, the number of output rows varies with the input, up to a
    # maximum of max_intermediate_rank.
    output_rank: Optional[int]

    # If True, raise when a parameter in `parameters` received no gradient;
    # otherwise that parameter's slice of the PEF vector is left as zeros.
    error_on_null_grads: bool = True

    def __post_init__(self):
        self.flat_packer = flat_pack.FlatPacker(p.shape for p in self.parameters)
        self.n_parameters = self.flat_packer.flat_size

        # Explicit checks (not asserts) so they survive `python -O`.
        if self.max_intermediate_rank < 1:
            raise ValueError(
                f'max_intermediate_rank must be >= 1, got {self.max_intermediate_rank}')
        if self.output_rank is not None and self.output_rank > self.max_intermediate_rank:
            raise ValueError(
                f'output_rank ({self.output_rank}) must not exceed '
                f'max_intermediate_rank ({self.max_intermediate_rank})')

        self._tensors_created = False

    def is_output_projected(self) -> bool:
        return False

    #######################################################

    def _maybe_create_tensors(self, device):
        """Lazily allocates the working buffers on `device` (only once)."""
        if self._tensors_created:
            return
        self._tensors_created = True

        zeros = lambda s: torch.zeros(s, dtype=torch.float32, device=device)

        # Scratch buffer for the current per-class gradient vector.
        self._pef_vector = zeros([self.n_parameters])

        # NOTE: Technically, these are the transpose of the U: row k is basis
        # vector k. One extra row holds the candidate component of each update
        # before truncation.
        self._U = zeros([self.max_intermediate_rank + 1, self.n_parameters])
        self._Up = zeros([self.max_intermediate_rank + 1, self.n_parameters])

        # Eigenvalues (squared singular values) matching the rows of _U.
        self._sigma = zeros([self.max_intermediate_rank + 1])

    def _reset_shared_tensors(self):
        """Clears the working buffers and decomposition state for a new example."""
        self._current_rank = 0

        self._pef_vector.zero_()

        self._U.zero_()
        self._Up.zero_()
        self._sigma.zero_()

    @staticmethod
    def _nonzero_or_one(x: torch.Tensor) -> torch.Tensor:
        # Returns `x` where positive and 1 elsewhere. Used as a divisor so that
        # normalizing an all-zero vector (norm == 0) leaves it zero instead of
        # producing 0/0 = NaN, without a host sync.
        return torch.where(x > 0, x, torch.ones_like(x))

    #######################################################

    def _compute_pef_vector_for_class(
        self,
        example_info: PefComputerInput,
        # sqrt_probs.shape = [n_classes_total]
        sqrt_probs: torch.Tensor,
        class_index: int,
        *,
        retain_graph: bool,
    ) -> torch.Tensor:
        """Fills and returns self._pef_vector = sqrt_probs[c] * grad(log_probs[c]).

        The returned tensor is the shared scratch buffer, not a copy.

        Raises:
            ValueError: if a parameter has no gradient and error_on_null_grads.
        """
        log_prob = example_info.log_probs[class_index]

        with torch.no_grad():
            self._pef_vector.zero_()
            self.model.zero_grad()
            log_prob.backward(retain_graph=retain_graph)

            for j, p in enumerate(self.parameters):
                if p.grad is not None:
                    p_start, p_end = self.flat_packer.get_range_for_tensor_by_index(j)
                    self._pef_vector[p_start:p_end] = torch.reshape(p.grad, [-1])
                elif self.error_on_null_grads:
                    raise ValueError(f'Received null gradient for parameter with index {j}')

            self._pef_vector.mul_(sqrt_probs[class_index])

        return self._pef_vector

    #######################################################

    def _process_first_pef_vector(self, pef_vector: torch.Tensor):
        """Initializes the decomposition from the first vector (rank one)."""
        assert self._current_rank == 0

        norm = torch.linalg.vector_norm(pef_vector)
        self._U[0] = pef_vector
        # A zero vector (e.g. a class whose probability underflowed to 0) keeps
        # a zero basis row rather than becoming NaN.
        self._U[0] /= self._nonzero_or_one(norm)
        self._sigma[0] = torch.square(norm)

    def _process_pef_vector(self, pef_vector: torch.Tensor):
        """Folds `pef_vector` into the decomposition as a rank-one update.

        `pef_vector` is modified in place. After the call, the first
        self._current_rank rows of _U / entries of _sigma hold the retained
        components in descending eigenvalue order.
        """
        rank = self._current_rank
        if rank == 0:
            self._process_first_pef_vector(pef_vector)
            self._current_rank = 1
            return

        U = self._U[:rank]
        # Coordinates of the new vector in the current basis.
        m = torch.matmul(U, pef_vector)

        # Residual orthogonal to the current basis. `m @ U` == `U.t() @ m`.
        pef_vector -= torch.matmul(m, U)
        R = torch.linalg.vector_norm(pef_vector)
        pef_vector /= self._nonzero_or_one(R)

        mR = torch.cat([m, R.reshape(1)], dim=-1)

        # K = diag(sigma_0..sigma_{rank-1}, 0) + [m; R] [m; R]^T
        self._sigma[rank].zero_()
        K = torch.diagflat(self._sigma[:rank + 1])
        K += torch.outer(mR, mR)

        sigma2, U2 = torch.linalg.eigh(K)
        # eigh returns eigenvalues in ascending order. Flip to descending so
        # that truncation (dropping the last row) discards the smallest
        # component and the leading rows hold the largest ones.
        sigma2 = torch.flip(sigma2, dims=[0])
        U2 = torch.flip(U2, dims=[1])
        # K is PSD in exact arithmetic; clamp roundoff negatives so that the
        # later sqrt does not produce NaNs.
        self._sigma[:rank + 1] = torch.clamp(sigma2, min=0)

        Up = self._Up[:rank + 1]
        Up[:rank] = U
        Up[rank] = pef_vector

        # Rotate the extended basis into the eigenbasis of K.
        torch.matmul(U2.t(), Up, out=self._U[:rank + 1])

        self._current_rank = min(rank + 1, self.max_intermediate_rank)

    #######################################################

    def compute_dense_pef(self, example_info: PefComputerInput) -> torch.Tensor:
        """Computes the low-rank PEF for one example.

        Returns:
            Tensor of shape [output_rank, n_parameters] (or
            [current_rank, n_parameters] when output_rank is None), with rows
            sqrt(sigma_k) * u_k in descending order of sigma_k.

        NOTE: The returned tensor is a view into this computer's internal
        buffer _U, which is reset on the next call. Copy it if it must outlive
        that call.
        """
        log_probs = example_info.log_probs
        class_indices = example_info.class_indices
        n_classes = len(class_indices)

        with torch.no_grad():
            self._maybe_create_tensors(log_probs.device)
            self._reset_shared_tensors()

            # sqrt_probs.shape = [n_classes_total]
            sqrt_probs = torch.sqrt(torch.softmax(log_probs, dim=-1))

        for i, class_index in enumerate(class_indices):
            # Keep the autograd graph alive for all but the last backward pass.
            pef_vector = self._compute_pef_vector_for_class(
                example_info, sqrt_probs, class_index, retain_graph=i + 1 < n_classes)
            with torch.no_grad():
                self._process_pef_vector(pef_vector)

        with torch.no_grad():
            output_rank = self.output_rank if self.output_rank is not None else self._current_rank
            pef = self._U[:output_rank]
            pef.mul_(torch.sqrt(self._sigma[:output_rank])[:, None])

        return pef


# TODO: Consider allowing an output rank of max_intermediate_rank + 1 (the working buffers already hold that many rows), and/or renaming max_intermediate_rank.
