"""Computes low-rank per-example Fishers (PEFs) by randomly projecting them and then using an SVD to
reduce their rank. The expectation over the output classes is also approximated with a random projection.
"""
import dataclasses
import json
import math
from typing import List, Optional

import h5py
import torch
from transformers import PreTrainedModel

from npeff_torch.peis import random_projectors
from npeff_torch.util import flat_pack
from . import pef_computer_common


###############################################################################
# Module-level alias for pef_computer_common.PefComputerInput.
PefComputerInput = pef_computer_common.PefComputerInput
###############################################################################
# Allowed values of BatchedEvrpRpSvdLrmComputer.evrp_type (checked in __post_init__).
_VALID_EVRP_TYPES = (
    'bernoulli',
)
###############################################################################


@dataclasses.dataclass
class BatchedEvrpRpSvdLrmComputer(pef_computer_common.PefComputerAbc):
    """Computes low-rank, randomly projected per-example Fishers (PEFs).

    For each example:
      1. The log-probabilities, weighted by the (detached) square roots of the
         class probabilities, are projected along the class axis with a fixed
         random sign matrix (the "EVRP" matrix), giving ``evrp_d_projection``
         differentiable scalars.
      2. Each scalar is backpropagated to obtain a flattened gradient over
         ``parameters``. Gradients are buffered ``random_projection_batch_size``
         at a time and randomly projected together to improve performance.
      3. An SVD of the stacked projected gradients reduces them to at most
         ``output_rank`` rows.
    """

    model: PreTrainedModel

    # List of parameters to compute Fishers for.
    parameters: List[torch.nn.Parameter]

    # Parameters used to build the random projector applied to flattened gradients.
    random_projection_params: 'random_projectors.RandomProjectionParams'

    # Number of rows kept from the SVD of the projected gradients.
    output_rank: int

    # Number of gradients that are buffered and randomly projected together.
    random_projection_batch_size: int

    # Since we need to differentiate through the random projection of the expectation,
    # we use a different way to specify the random projection used for them.
    # evrp_seed seeds the EVRP matrix; evrp_d_projection is its number of rows.
    evrp_seed: int
    evrp_d_projection: int
    evrp_type: str = 'bernoulli'

    # If True, raise when a parameter in `parameters` receives no gradient; otherwise
    # that parameter's slice of the flattened gradient is left as zeros.
    error_on_null_grads: bool = True

    def __post_init__(self):
        """Validates the configuration and builds the flat packer and random projector.

        Raises:
            ValueError: If ``evrp_type`` is unsupported, ``evrp_d_projection`` is
                smaller than ``output_rank``, or ``random_projection_batch_size``
                is not positive.
        """
        if self.evrp_type not in _VALID_EVRP_TYPES:
            raise ValueError(f'Invalid evrp_type: {self.evrp_type}')
        if self.evrp_d_projection < self.output_rank:
            raise ValueError('evrp_d_projection must be greater than or equal to output_rank')
        if self.random_projection_batch_size <= 0:
            # Otherwise this only surfaces later as an obscure range()/torch.cat error.
            raise ValueError('random_projection_batch_size must be positive')

        # Maps each parameter to a contiguous slice of one flat vector.
        self.flat_packer = flat_pack.FlatPacker(p.shape for p in self.parameters)
        self.n_parameters = self.flat_packer.flat_size

        self.random_projector = random_projectors.RandomProjector.create(params=self.random_projection_params)
        self.d_projection = self.random_projection_params.d_projection

        # Tensors depending on the number of classes / device are created lazily.
        self._tensors_created = False

    def is_output_projected(self) -> bool:
        """The returned PEFs live in the randomly projected space."""
        return True

    #######################################################

    @torch.no_grad()
    def _maybe_create_tensors(self, n_classes_total: int, device: torch.device):
        """Lazily allocates the gradient buffer and the EVRP matrix on first use.

        The EVRP matrix depends on the number of classes, which is only known once
        the first example is seen. Later calls are no-ops but must pass the same
        number of classes.

        Raises:
            ValueError: If ``n_classes_total`` differs from the value seen on the
                first call.
        """
        if self._tensors_created:
            if n_classes_total != self._n_classes_total:
                raise ValueError(
                    f'Number of classes changed from {self._n_classes_total} to {n_classes_total}; '
                    'the EVRP matrix was built for the original number of classes.')
            return

        # One row per gradient in a random projection batch.
        pef_vectors = torch.zeros([self.random_projection_batch_size, self.n_parameters], dtype=torch.float32, device=device)

        generator = torch.Generator(device=device)
        generator.manual_seed(self.evrp_seed)

        # Random sign matrix with entries +-1/sqrt(n_classes_total), shape [evrp_d_projection, n_classes_total].
        # NOTE(review): JL-style unbiasedness (E[R^T R] = I) would scale by 1/sqrt(evrp_d_projection);
        # this scaling multiplies the Fisher estimate by evrp_d_projection / n_classes_total in
        # expectation — confirm this normalization is intended.
        evrp_matrix = torch.rand((self.evrp_d_projection, n_classes_total), generator=generator, dtype=torch.float32, device=device)
        evrp_matrix = (evrp_matrix > 0.5).type(torch.float32)
        evrp_matrix = 2.0 * evrp_matrix - 1.0
        evrp_matrix /= math.sqrt(float(n_classes_total))

        # Only mark as created once every allocation has succeeded, so that a failure
        # above (e.g. out of memory) does not leave the object half-initialized.
        self._pef_vectors = pef_vectors
        self._evrp_matrix = evrp_matrix
        self._n_classes_total = n_classes_total
        self._tensors_created = True

    @torch.no_grad()
    def _reset_shared_tensors(self):
        """Zeroes the shared gradient buffer."""
        self._pef_vectors.zero_()

    #######################################################

    def _compute_projected_pef_vectors_for_batch(
        self,
        # shape = [actual_batch_size]
        batch_expectation_projection: torch.Tensor,
        *,
        retain_graph: bool,
    ) -> torch.Tensor:
        """Backpropagates each scalar in the batch and randomly projects the gradients.

        Row ``i`` of the shared buffer receives the flattened gradient of
        ``batch_expectation_projection[i]`` w.r.t. ``self.parameters``. The filled rows
        are then passed to ``self.random_projector.project``.

        Args:
            batch_expectation_projection: 1-D tensor of scalars attached to the autograd graph.
            retain_graph: Whether the graph must survive after the last element of this
                batch (i.e. whether more batches will be backpropagated afterwards).

        Returns:
            The result of ``random_projector.project`` on the ``actual_batch_size`` filled rows.

        Raises:
            ValueError: If a parameter gets no gradient and ``error_on_null_grads`` is set.
        """
        actual_batch_size = batch_expectation_projection.numel()

        # NOTE(review): the buffer is zeroed and reused by the next batch; this assumes
        # project() returns a new tensor rather than a view of its input — confirm.
        with torch.no_grad():
            self._pef_vectors.zero_()

        for i, target in enumerate(batch_expectation_projection):
            with torch.no_grad():
                self.model.zero_grad()
                # The graph is shared by all targets, so keep it until the very last one.
                target.backward(retain_graph=retain_graph or i + 1 < actual_batch_size)

                for j, p in enumerate(self.parameters):
                    if p.grad is not None:
                        p_start, p_end = self.flat_packer.get_range_for_tensor_by_index(j)
                        self._pef_vectors[i, p_start:p_end] = torch.reshape(p.grad, [-1])
                    elif self.error_on_null_grads:
                        raise ValueError(f'Received null gradient for parameter with index {j}')

        with torch.no_grad():
            return self.random_projector.project(self._pef_vectors[:actual_batch_size])

    #######################################################

    def compute_dense_pef(self, example_info: PefComputerInput) -> torch.Tensor:
        """Computes the low-rank, randomly projected PEF for a single example.

        Args:
            example_info: Only ``example_info.log_probs`` is used. It must be a 1-D
                tensor over all classes that is differentiable w.r.t. ``parameters``
                (otherwise gradients are None). ``example_info.class_indices`` is ignored.

        Returns:
            ``Vh[:output_rank] * S[:output_rank, None]`` from the SVD of the stacked
            projected gradients ``A``, i.e. a factor ``R`` such that ``R.T @ R`` is the best
            rank-``output_rank`` approximation of ``A.T @ A``. It has at most
            ``output_rank`` rows (fewer if the projected dimension is smaller).
        """
        log_probs = example_info.log_probs

        self._maybe_create_tensors(log_probs.shape[-1], log_probs.device)

        with torch.no_grad():
            # sqrt_probs.shape = [n_classes_total]
            sqrt_probs = torch.sqrt(torch.softmax(log_probs, dim=-1))
            # Be super sure that we do not differentiate through the sqrt_probs.
            sqrt_probs = sqrt_probs.detach()

        # shape = [evrp_d_projection]; differentiable only through log_probs.
        expectation_projection = torch.einsum('v,v,pv->p', sqrt_probs, log_probs, self._evrp_matrix)

        # No explicit buffer reset is needed: each batch zeroes the buffer before filling it.
        batch_size = self.random_projection_batch_size
        projected_pef_vectors = []
        for start in range(0, self.evrp_d_projection, batch_size):
            end = start + batch_size
            batch_projected_pef_vectors = self._compute_projected_pef_vectors_for_batch(
                expectation_projection[start:end],
                # Keep the graph alive until the final batch has been backpropagated.
                retain_graph=end < self.evrp_d_projection,
            )
            projected_pef_vectors.append(batch_projected_pef_vectors)

        with torch.no_grad():
            # Presumably shape [evrp_d_projection, d_projection]; depends on random_projector.project.
            projected_pef_vectors = torch.cat(projected_pef_vectors, dim=0)

            # NOTE: The SVD is applied even when evrp_d_projection == output_rank. Returning
            # projected_pef_vectors directly would give a factor with the same Gram matrix,
            # but not rotated onto the right singular vectors.
            _, S, Vh = torch.linalg.svd(projected_pef_vectors, full_matrices=False)
            return Vh[:self.output_rank] * S[:self.output_rank, None]
