"""Computes low-rank PEFs by computing their random projection and then using SVD to reduce rank.

Approximates the expectation via sampling from the model's predictive distribution.
"""
import dataclasses
import json
import math
from typing import List, Optional

import h5py
import torch
from transformers import PreTrainedModel

from npeff_torch.peis import random_projectors
from npeff_torch.util import flat_pack
from . import pef_computer_common


###############################################################################
# Module-level alias of pef_computer_common.PefComputerInput; it is the input
# type accepted by the computer defined in this module.
PefComputerInput = pef_computer_common.PefComputerInput
###############################################################################


@dataclasses.dataclass
class BatchedSamplingRpSvdLrmComputer(pef_computer_common.PefComputerAbc):
    """Computes randomly projected, low-rank PEFs from sampled-label gradients.

    For every example, ``n_samples`` class indices are drawn with replacement
    from ``softmax(example_info.log_probs)``. For each sampled class, the
    gradient of its log-probability with respect to ``parameters`` is
    flattened, scaled by ``1 / sqrt(n_samples)`` and passed through the random
    projector. At most ``random_projection_batch_size`` flattened gradients are
    materialized at a time. The projected vectors are stacked into ``A`` and
    reduced with the thin SVD ``A = U S Vh`` to ``L = S[:r, None] * Vh[:r]``
    (``r = output_rank``), so that ``L^T L`` is ``A^T A`` truncated to its top
    ``r`` eigen-directions.
    """

    # Model whose gradients are cleared with `zero_grad()` before every
    # backward pass.
    model: PreTrainedModel

    # List of parameters to compute Fishers for. Presumably a subset of
    # `model`'s parameters, since only `model.zero_grad()` is used to clear
    # gradients -- confirm with callers.
    parameters: List[torch.nn.Parameter]

    # Configuration passed to `RandomProjector.create`; its `d_projection`
    # is the width of every projected vector.
    random_projection_params: 'random_projectors.RandomProjectionParams'

    # Number of rows of each returned dense PEF (zero-padded if needed).
    output_rank: int

    # Number of sampled gradients flattened and projected together.
    random_projection_batch_size: int

    # Number of samples to take per example.
    n_samples: int

    # Seed for the generator used to sample class indices.
    sampler_seed: int

    # If True, raise when a tracked parameter receives no gradient; otherwise
    # that parameter's segment of the flattened gradient is left as zeros.
    error_on_null_grads: bool = True

    def __post_init__(self):
        # Fail fast: these values would otherwise surface later as obscure
        # errors (multinomial / range() / division) or silently wrong slices.
        if self.n_samples < 1:
            raise ValueError(f'n_samples must be a positive integer, got {self.n_samples}.')
        if self.random_projection_batch_size < 1:
            raise ValueError(
                'random_projection_batch_size must be a positive integer, '
                f'got {self.random_projection_batch_size}.')
        if self.output_rank < 0:
            raise ValueError(f'output_rank must be non-negative, got {self.output_rank}.')

        self.flat_packer = flat_pack.FlatPacker(p.shape for p in self.parameters)
        self.n_parameters = self.flat_packer.flat_size

        self.random_projector = random_projectors.RandomProjector.create(params=self.random_projection_params)
        self.d_projection = self.random_projection_params.d_projection

        # Created lazily, once the device of the inputs is known.
        self.generator = None

        self._tensors_created = False

    def is_output_projected(self) -> bool:
        """Returns True: dense PEFs produced here live in the projected space."""
        return True

    #######################################################

    @torch.no_grad()
    def _maybe_create_tensors(self, device):
        """Lazily allocates the gradient buffer and the seeded sampling generator.

        NOTE: runs only once; later calls with a different `device` reuse the
        tensors and generator created on the first call.
        """
        if self._tensors_created:
            return
        self._tensors_created = True

        # One flattened gradient per row of a batch.
        self._pef_vectors = torch.zeros([self.random_projection_batch_size, self.n_parameters], dtype=torch.float32, device=device)

        self.generator = torch.Generator(device=device)
        self.generator.manual_seed(self.sampler_seed)

    @torch.no_grad()
    def _reset_shared_tensors(self):
        """Zeroes the shared flattened-gradient buffer."""
        self._pef_vectors.zero_()

    #######################################################

    @torch.no_grad()
    def _copy_grads_into_row(self, row: int):
        """Writes every tracked parameter's flattened `.grad` into `_pef_vectors[row]`.

        Raises:
            ValueError: If a parameter has no gradient and `error_on_null_grads`
                is set.
        """
        pef_row = self._pef_vectors[row]
        for j, p in enumerate(self.parameters):
            if p.grad is not None:
                p_start, p_end = self.flat_packer.get_range_for_tensor_by_index(j)
                pef_row[p_start:p_end] = torch.reshape(p.grad, [-1])
            elif self.error_on_null_grads:
                raise ValueError(f'Received null gradient for parameter with index {j}')

    def _compute_projected_pef_vectors_for_classes_batch(
        self,
        example_info: PefComputerInput,
        # shape = [actual_batch_size]
        batch_class_indices: torch.Tensor,
        *,
        retain_graph: bool,
    ) -> torch.Tensor:
        """Projects the scaled log-prob gradients for one batch of sampled classes.

        Args:
            example_info: Provides `log_probs`, whose autograd graph must still
                be alive.
            batch_class_indices: 1-D tensor of at most
                `random_projection_batch_size` sampled class indices.
            retain_graph: Whether the graph must survive this batch's last
                backward pass (True when more batches follow).

        Returns:
            The output of `random_projector.project` applied to the
            `[actual_batch_size, n_parameters]` block of flattened gradients,
            each scaled by `1 / sqrt(n_samples)`.

        Raises:
            ValueError: If `error_on_null_grads` is set and a tracked parameter
                has no gradient.
        """
        actual_batch_size = batch_class_indices.numel()

        # The buffer is reused across batches; clearing it makes parameters
        # without a gradient (allowed when error_on_null_grads is False)
        # contribute zeros rather than stale values from an earlier batch.
        self._reset_shared_tensors()

        for i, class_index in enumerate(batch_class_indices):
            log_prob = example_info.log_probs[class_index]

            with torch.no_grad():
                self.model.zero_grad()
                # Keep the graph for the next class of this batch or for later batches.
                log_prob.backward(retain_graph=retain_graph or i + 1 < actual_batch_size)
                self._copy_grads_into_row(i)

        with torch.no_grad():
            batch_rows = self._pef_vectors[:actual_batch_size]
            # With rows scaled by 1/sqrt(n_samples), A^T A averages g g^T over samples.
            batch_rows.mul_(1.0 / math.sqrt(self.n_samples))
            return self.random_projector.project(batch_rows)

    #######################################################

    def compute_dense_pef(self, example_info: PefComputerInput) -> torch.Tensor:
        """Computes the low-rank, randomly projected PEF for a single example.

        Side effects: advances the sampling generator, leaves the gradients of
        the last backward pass on the model's parameters, and frees the
        autograd graph behind `example_info.log_probs` (the final backward pass
        does not retain it).

        Args:
            example_info: Must provide `log_probs` (shape `[n_classes_total]`)
                that can be backpropagated to `parameters`. Its
                `class_indices` are ignored; classes are sampled instead.

        Returns:
            A `[output_rank, d_projection]` tensor `L` from the SVD of the
            stacked projected gradients. If `output_rank` exceeds the number of
            singular values, the extra rows are zero.

        Raises:
            ValueError: If `error_on_null_grads` is set and a tracked parameter
                receives no gradient.
        """
        # NOTE: example_info.class_indices are NOT used here.
        log_probs = example_info.log_probs
        self._maybe_create_tensors(log_probs.device)

        with torch.no_grad():
            # probs.shape = [n_classes_total]
            probs = torch.softmax(log_probs, dim=-1)
            class_indices = torch.multinomial(probs, self.n_samples, replacement=True, generator=self.generator)

        n_sampled = len(class_indices)
        batch_size = self.random_projection_batch_size

        projected_batches = []
        for start in range(0, n_sampled, batch_size):
            batch_class_indices = class_indices[start:start + batch_size]
            projected_batches.append(self._compute_projected_pef_vectors_for_classes_batch(
                example_info,
                batch_class_indices,
                # Later batches still need the graph behind `log_probs`.
                retain_graph=start + batch_size < n_sampled,
            ))

        with torch.no_grad():
            projected_pef_vectors = torch.cat(projected_batches, dim=0)
            _, S, Vh = torch.linalg.svd(projected_pef_vectors, full_matrices=False)
            dense_pef = Vh[:self.output_rank] * S[:self.output_rank, None]

            # Zero-pad when the SVD yields fewer than `output_rank` components.
            # `new_zeros` keeps the SVD result's dtype/device so that the padded
            # and unpadded paths return the same kind of tensor.
            if dense_pef.shape[0] < self.output_rank:
                padded = dense_pef.new_zeros([self.output_rank, self.d_projection])
                padded[:dense_pef.shape[0], :] = dense_pef
                dense_pef = padded

            return dense_pef
