"""Computes low-rank PEFs by computing their random projection and then using SVD to reduce rank."""
import dataclasses
import json
from typing import List, Optional

import h5py
import torch
from transformers import PreTrainedModel

from npeff_torch.peis import random_projectors
from npeff_torch.util import flat_pack
from . import pef_computer_common


###############################################################################
# Module-level alias of pef_computer_common.PefComputerInput; used in the method
# annotations of the PEF computers defined in this file.
PefComputerInput = pef_computer_common.PefComputerInput
###############################################################################


@dataclasses.dataclass
class RpSvdLrmComputer(pef_computer_common.PefComputerAbc):
    """Computes a low-rank, randomly-projected PEF for a single example.

    For every requested class c, the gradient of log p_c with respect to
    `parameters` is flattened, scaled by sqrt(p_c), and passed through the
    random projector. The projected vectors are stacked.

    - If there are at most `output_rank` of them, they are used as-is.
    - Otherwise a thin SVD keeps the top `output_rank` right-singular
      directions, each scaled by its singular value.

    Either way, the result R satisfies R^T R ~= A^T A, where A is the stack of
    projected vectors. The result is zero-padded to `[output_rank, d_projection]`.
    """

    model: PreTrainedModel

    # List of parameters to compute Fishers for.
    parameters: List[torch.nn.Parameter]

    # Configuration used to build the random projector applied to each flattened per-class vector.
    random_projection_params: 'random_projectors.RandomProjectionParams'

    # Number of rows of the returned dense PEF.
    output_rank: int

    # If True, raise when a parameter in `parameters` receives no gradient; otherwise that
    # parameter's slice of the flattened vector is left as zeros.
    error_on_null_grads: bool = True

    def __post_init__(self):
        self.flat_packer = flat_pack.FlatPacker(p.shape for p in self.parameters)
        self.n_parameters = self.flat_packer.flat_size

        self.random_projector = random_projectors.RandomProjector.create(params=self.random_projection_params)
        self.d_projection = self.random_projection_params.d_projection

        # The flat gradient buffer is allocated lazily on the first call to compute_dense_pef,
        # on the device of that example's log_probs.
        self._tensors_created = False

    def is_output_projected(self) -> bool:
        return True

    #######################################################

    def _maybe_create_tensors(self, device):
        """Allocates the reusable flat gradient buffer on `device` the first time it is called.

        NOTE: later calls reuse the existing buffer regardless of the `device` passed.
        """
        if self._tensors_created:
            return
        self._tensors_created = True

        self._pef_vector = torch.zeros([self.n_parameters], dtype=torch.float32, device=device)

    def _reset_shared_tensors(self):
        """Zeroes the flat gradient buffer."""
        self._pef_vector.zero_()

    #######################################################

    def _compute_projected_pef_vector_for_class(
        self,
        example_info: PefComputerInput,
        # sqrt_probs.shape = [n_classes_total]
        sqrt_probs: torch.Tensor,
        class_index: int,
        *,
        retain_graph: bool,
    ) -> torch.Tensor:
        """Returns project(sqrt(p_c) * flatten(grad log p_c)) for c = `class_index`.

        Side effects: clears the model's gradients, runs a backward pass from
        `example_info.log_probs[class_index]`, and overwrites `self._pef_vector`.

        Args:
            retain_graph: Forwarded to `backward`. It must be True when another
                backward pass through the same graph follows.

        Raises:
            ValueError: If a tracked parameter gets no gradient and
                `error_on_null_grads` is set.
        """
        log_prob = example_info.log_probs[class_index]

        with torch.no_grad():
            self._pef_vector.zero_()
            self.model.zero_grad()
            log_prob.backward(retain_graph=retain_graph)

            # Copy each parameter's gradient into its slice of the flat buffer.
            for j, p in enumerate(self.parameters):
                if p.grad is not None:
                    p_start, p_end = self.flat_packer.get_range_for_tensor_by_index(j)
                    self._pef_vector[p_start:p_end] = torch.reshape(p.grad, [-1])
                elif self.error_on_null_grads:
                    raise ValueError(f'Received null gradient for parameter with index {j}')

            self._pef_vector.mul_(sqrt_probs[class_index])

            return self.random_projector.project(self._pef_vector)

    #######################################################

    def compute_dense_pef(self, example_info: PefComputerInput) -> torch.Tensor:
        """Returns the low-rank projected PEF for one example.

        Args:
            example_info: Supplies `log_probs` and `class_indices`. `log_probs`
                is indexed by class and must carry an autograd graph back to
                `parameters`; it is presumably 1-D over all classes (TODO
                confirm). `class_indices` lists the classes to include.

        Returns:
            A float32 tensor of shape [output_rank, d_projection]. Rows beyond
            the number of available directions are zero. This includes the
            case where no classes are requested, and the case where
            d_projection < output_rank.
        """
        log_probs = example_info.log_probs
        class_indices = example_info.class_indices

        device = log_probs.device

        with torch.no_grad():
            self._maybe_create_tensors(device)
            self._reset_shared_tensors()

            # sqrt_probs.shape = [n_classes_total]
            sqrt_probs = torch.sqrt(torch.softmax(log_probs, dim=-1))

        n_classes = len(class_indices)
        projected_pef_vectors = []
        for i, class_index in enumerate(class_indices):
            # The graph must survive until the final backward pass.
            projected_pef_vectors.append(self._compute_projected_pef_vector_for_class(
                example_info, sqrt_probs, class_index, retain_graph=i + 1 < n_classes))

        with torch.no_grad():
            dense_pef = torch.zeros([self.output_rank, self.d_projection], dtype=torch.float32, device=device)
            if not projected_pef_vectors:
                return dense_pef

            stacked = torch.stack(projected_pef_vectors, dim=0)
            if stacked.shape[0] <= self.output_rank:
                # Already within the target rank, so no SVD is needed.
                low_rank = stacked
            else:
                _, S, Vh = torch.linalg.svd(stacked, full_matrices=False)
                low_rank = Vh[:self.output_rank] * S[:self.output_rank, None]

            # Vh has only min(n_classes, d_projection) rows, so this also pads
            # the case d_projection < output_rank.
            dense_pef[:low_rank.shape[0]] = low_rank
            return dense_pef

    #######################################################

    def write_additional_information_to_pefs_file(self, file: h5py.File):
        """Stores the random projection configuration as JSON attributes on the file's `data` group.

        The group is created if it does not exist yet.
        """
        data_grp = file.require_group('data')
        # TODO: I'm not sure if the random projection used in this PEF computer can be shown
        # to be effectively the same as doing a random projection on a regular dense PEF.
        params_json = json.dumps(self.random_projection_params.to_json())
        # Written under both the generic key and an rp-svd-specific key.
        data_grp.attrs['random_projection_params'] = params_json
        data_grp.attrs['rp_svd_random_projection_params'] = params_json


###############################################################################


@dataclasses.dataclass
class BatchedRpSvdLrmComputer(pef_computer_common.PefComputerAbc):
    """Batched variant of RpSvdLrmComputer that randomly projects several vectors at once for performance.

    Per-class vectors sqrt(p_c) * flatten(grad log p_c) are buffered in a
    `[random_projection_batch_size, n_parameters]` tensor and projected one
    batch at a time. The projected vectors are concatenated and always reduced
    with a thin SVD, even when there are at most `output_rank` of them. The
    result is zero-padded to `[output_rank, d_projection]`.
    """

    model: PreTrainedModel

    # List of parameters to compute Fishers for.
    parameters: List[torch.nn.Parameter]

    # Configuration used to build the random projector applied to the flattened per-class vectors.
    random_projection_params: 'random_projectors.RandomProjectionParams'

    # Number of rows of the returned dense PEF.
    output_rank: int

    # Number of per-class vectors buffered and projected together. The buffer holds
    # random_projection_batch_size * n_parameters float32 values.
    random_projection_batch_size: int

    # If True, raise when a parameter in `parameters` receives no gradient; otherwise that
    # parameter's slice of the flattened vector is left as zeros.
    error_on_null_grads: bool = True

    def __post_init__(self):
        self.flat_packer = flat_pack.FlatPacker(p.shape for p in self.parameters)
        self.n_parameters = self.flat_packer.flat_size

        self.random_projector = random_projectors.RandomProjector.create(params=self.random_projection_params)
        self.d_projection = self.random_projection_params.d_projection

        # The gradient buffer is allocated lazily on the first call to compute_dense_pef,
        # on the device of that example's log_probs.
        self._tensors_created = False

    def is_output_projected(self) -> bool:
        return True

    #######################################################

    @torch.no_grad()
    def _maybe_create_tensors(self, device):
        """Allocates the reusable batched gradient buffer on `device` the first time it is called.

        NOTE: later calls reuse the existing buffer regardless of the `device` passed.
        """
        if self._tensors_created:
            return
        self._tensors_created = True

        self._pef_vectors = torch.zeros([self.random_projection_batch_size, self.n_parameters], dtype=torch.float32, device=device)

    @torch.no_grad()
    def _reset_shared_tensors(self):
        """Zeroes the batched gradient buffer."""
        self._pef_vectors.zero_()

    #######################################################

    def _compute_projected_pef_vectors_for_classes_batch(
        self,
        example_info: PefComputerInput,
        # sqrt_probs.shape = [n_classes_total]
        sqrt_probs: torch.Tensor,
        # shape = [actual_batch_size]
        batch_class_indices: torch.Tensor,
        *,
        retain_graph: bool,
    ) -> torch.Tensor:
        """Fills one buffer row per class in the batch and projects the filled rows together.

        Args:
            batch_class_indices: At most `random_projection_batch_size` class
                indices. A tensor or any sized sequence is accepted.
            retain_graph: Whether the graph must survive the batch's last
                backward pass, i.e. whether more batches follow.

        Returns:
            `random_projector.project` applied to the filled rows. The output
            is presumably shaped [len(batch_class_indices), d_projection]
            (TODO confirm).

        Raises:
            ValueError: If a tracked parameter gets no gradient and
                `error_on_null_grads` is set.
        """
        # len() works for both tensors and plain sequences; .numel() would not.
        actual_batch_size = len(batch_class_indices)

        with torch.no_grad():
            self._pef_vectors.zero_()

        for i, class_index in enumerate(batch_class_indices):
            log_prob = example_info.log_probs[class_index]

            with torch.no_grad():
                self.model.zero_grad()
                # Keep the graph for the remaining classes in this batch, or for later batches.
                log_prob.backward(retain_graph=retain_graph or i + 1 < actual_batch_size)

                # Copy each parameter's gradient into its slice of row i.
                for j, p in enumerate(self.parameters):
                    if p.grad is not None:
                        p_start, p_end = self.flat_packer.get_range_for_tensor_by_index(j)
                        self._pef_vectors[i, p_start:p_end] = torch.reshape(p.grad, [-1])
                    elif self.error_on_null_grads:
                        raise ValueError(f'Received null gradient for parameter with index {j}')

                self._pef_vectors[i].mul_(sqrt_probs[class_index])

        with torch.no_grad():
            return self.random_projector.project(self._pef_vectors[:actual_batch_size])

    #######################################################

    def compute_dense_pef(self, example_info: PefComputerInput) -> torch.Tensor:
        """Returns the low-rank projected PEF for one example.

        Args:
            example_info: Supplies `log_probs` and `class_indices`. `log_probs`
                is indexed by class and must carry an autograd graph back to
                `parameters`. `class_indices` lists the classes to include.

        Returns:
            A tensor of shape [output_rank, d_projection], computed as
            Vh[:output_rank] * S[:output_rank] from a thin SVD of the stacked
            projected vectors. It is zero-padded when fewer directions are
            available, and is all zeros when no classes are requested.
        """
        log_probs = example_info.log_probs
        class_indices = example_info.class_indices

        device = log_probs.device

        self._maybe_create_tensors(device)
        self._reset_shared_tensors()

        with torch.no_grad():
            # sqrt_probs.shape = [n_classes_total]
            sqrt_probs = torch.sqrt(torch.softmax(log_probs, dim=-1))

        n_classes = len(class_indices)
        batch_size = self.random_projection_batch_size

        projected_pef_vectors = []
        for start in range(0, n_classes, batch_size):
            batch_class_indices = class_indices[start : start + batch_size]
            projected_pef_vectors.append(self._compute_projected_pef_vectors_for_classes_batch(
                example_info, sqrt_probs, batch_class_indices,
                retain_graph=start + batch_size < n_classes))

        with torch.no_grad():
            if not projected_pef_vectors:
                # No classes requested; torch.cat would raise on an empty list.
                return torch.zeros([self.output_rank, self.d_projection], dtype=torch.float32, device=device)

            stacked = torch.cat(projected_pef_vectors, dim=0)
            _, S, Vh = torch.linalg.svd(stacked, full_matrices=False)
            dense_pef = Vh[:self.output_rank] * S[:self.output_rank, None]

            # The SVD yields only min(n_classes, d_projection) directions; pad to output_rank rows.
            if dense_pef.shape[0] < self.output_rank:
                padded = torch.zeros([self.output_rank, self.d_projection], dtype=torch.float32, device=device)
                padded[:dense_pef.shape[0], :] = dense_pef
                dense_pef = padded

            return dense_pef

    #######################################################

    def write_additional_information_to_pefs_file(self, file: h5py.File):
        """Stores the random projection configuration as JSON attributes on the file's `data` group.

        This mirrors RpSvdLrmComputer. The output of this computer is also
        projected, so the projection config needs to be recorded as well.
        """
        data_grp = file.require_group('data')
        params_json = json.dumps(self.random_projection_params.to_json())
        # Written under both the generic key and an rp-svd-specific key.
        data_grp.attrs['random_projection_params'] = params_json
        data_grp.attrs['rp_svd_random_projection_params'] = params_json
