"""Common ABCs and functionality for saving gradients to disk."""
import abc
import dataclasses
from typing import Dict, Optional

import h5py
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import PreTrainedModel

from npeff_torch.models import model_utils
from npeff_torch.peis import position_selectors
from npeff_torch.peis.gradients import gradient_computers
from npeff_torch.peis.gradients import logit_functions


###############################################################################


@dataclasses.dataclass
class ExampleDenseGradientComputationInfo:
    """Per-example result of a dense gradient computation.

    Bundles the selected position and label (when present) with the
    log-probabilities the gradient was computed from, the dense gradient
    itself and the value of the differentiated function.
    """

    # Selected position; scalar tensor (shape [], dtype int32).
    # None when no position was selected for the example.
    position: Optional[torch.Tensor]

    # Label used for the example; scalar tensor (shape [], dtype int32).
    # None when the example carries no label.
    label: Optional[torch.Tensor]

    # Log-probabilities over all classes; shape [n_classes_total].
    log_probs: torch.Tensor

    # Flattened gradient over the parameters; shape [n_parameters].
    dense_gradient: torch.Tensor

    # Value of the function that was differentiated; scalar (shape []).
    fn_value: torch.Tensor


###############################################################################


class StreamingGradientSaverAbc(abc.ABC):
    """Base class for savers that stream per-example gradients into an HDF5 file.

    Subclasses implement the file layout (`_initialize_file`,
    `_finalize_file`) and the per-example write (`_process_batch`). This base
    class provides the shared gradient computation and the driver loop.

    The methods here read `self.model`, `self.gradient_computer`,
    `self.position_selector`, `self.label_key` and `self.device`; the
    constructor of the concrete subclass is expected to set them (see
    `create`, which passes them as keyword arguments).
    """

    @classmethod
    def create(
        cls,
        model: PreTrainedModel,
        gradient_computer: 'gradient_computers.GradientComputer',
        position_selector: 'position_selectors.PositionSelectorAbc',
        label_key: Optional[str],
        device: torch.device,
        # Subclass-specific options (the original note lists save_logits and
        # save_examples as examples) are forwarded untouched.
        **kwargs,
    ):
        """Build a saver by forwarding all arguments to the class constructor.

        Args:
            model: Model whose gradients are computed.
            gradient_computer: Object providing `compute_dense_gradient`.
            position_selector: Object providing `select_positions`.
            label_key: Key of the labels in each batch dict, or None.
            device: Device that batches are moved to.
            **kwargs: Extra keyword arguments passed verbatim to `cls(...)`.

        Returns:
            A new instance of `cls`.
        """
        return cls(
            model=model,
            gradient_computer=gradient_computer,
            position_selector=position_selector,
            label_key=label_key,
            device=device,
            **kwargs,
        )

    def compute_dense_gradient_info(self, batch: Dict[str, torch.Tensor]) -> Optional[ExampleDenseGradientComputationInfo]:
        """Compute the dense gradient information for a single-example batch.

        Args:
            batch: Batch dict; must contain 'attention_mask'. The leading
                (batch) dimension is squeezed away, so a batch size of 1 is
                expected.

        Returns:
            The gradient information for the example, or None when the
            position selector returns a negative position (skip signal).
        """
        logits = model_utils.compute_logits(self.model, batch, self.device)
        log_probs = torch.log_softmax(logits, dim=-1)
        # "nb" = no batch dimension.
        nb_log_probs = torch.squeeze(log_probs, dim=0)

        device_batch = {k: v.to(self.device) for k, v in batch.items()}

        # A missing key (including label_key=None) simply yields no labels.
        labels = device_batch.get(self.label_key)
        nb_labels = None if labels is None else torch.squeeze(labels, dim=0)

        n_non_paddings = torch.sum(device_batch['attention_mask'], dim=-1)

        # NOTE(review): the selector's `logits` field receives
        # log-probabilities here; kept as in the original.
        position_selector_input = position_selectors.PositionSelectorInput(
            examples=device_batch,
            n_non_paddings=n_non_paddings,
            logits=log_probs,
            labels=labels,
        )
        position = self.position_selector.select_positions(position_selector_input)

        if position is not None:
            # Remove the batch dim.
            position = torch.squeeze(position, dim=0)
            # A negative position indicates to skip this example.
            if position < 0:
                return None

            nb_position_log_probs = nb_log_probs[position]

            if nb_labels is None:
                nb_position_labels = None
            elif nb_labels.ndim == 0:
                # A single label for the whole example: nothing to index.
                nb_position_labels = nb_labels
            else:
                nb_position_labels = nb_labels[position]

        else:
            # No position selection: use the example's values as they are.
            nb_position_log_probs = nb_log_probs
            nb_position_labels = nb_labels

        gradient_computer_input = logit_functions.LogitFunctionInput(
            log_probs=nb_position_log_probs,
            label=nb_position_labels,
        )

        grad_info = self.gradient_computer.compute_dense_gradient(gradient_computer_input)

        return ExampleDenseGradientComputationInfo(
            position=position,
            label=nb_position_labels,
            log_probs=nb_position_log_probs,
            dense_gradient=grad_info.dense_gradient,
            fn_value=grad_info.fn_value,
        )

    #######################################################

    @abc.abstractmethod
    def _initialize_file(self, file: h5py.File, n_examples: int):
        """Prepare `file` before any example is processed.

        Called once by `compute_and_save_gradients`, right after the file is
        opened for writing, with the requested number of examples.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def _finalize_file(self, file: h5py.File):
        """Finish writing `file`.

        Called once by `compute_and_save_gradients` after the processing
        loop ends, while the file is still open.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def _process_batch(self, batch: Dict[str, torch.Tensor]) -> bool:
        """Process a particular example.

        Returns:
            True if the example was successfully processed.
            False if the example was skipped.
        """
        raise NotImplementedError

    # #######################################################

    def compute_and_save_gradients(self, filepath: str, dataloader: DataLoader, n_examples: int):
        """Stream examples from `dataloader` and save their gradients to `filepath`.

        Stops after `n_examples` examples have been processed successfully or
        when the dataloader is exhausted, whichever comes first. Skipped
        examples (where `_process_batch` returns False) do not count.

        Args:
            filepath: Path of the HDF5 file; opened in "w" mode, so an
                existing file is overwritten.
            dataloader: Source of examples. It must be batched with a
                batch_size of 1.
            n_examples: Maximum number of examples to process. When this is
                zero or negative the dataloader is not consumed at all.
        """
        self.model.eval()

        with h5py.File(filepath, "w") as file:
            self._initialize_file(file, n_examples)

            n_examples_processed = 0

            # Without this guard the first example would be processed (and
            # written) before the limit is ever checked.
            batches = dataloader if n_examples > 0 else ()

            # The bar is advanced manually, so it is built from a total. The
            # context manager closes it even if processing a batch raises.
            with tqdm(total=max(n_examples, 0)) as progress_bar:
                for batch in batches:
                    if not self._process_batch(batch):
                        continue

                    n_examples_processed += 1
                    progress_bar.update(1)

                    if n_examples_processed >= n_examples:
                        break

            self._finalize_file(file)
