"""Common ABCs and functionality for saving PEFs to disk."""
import abc
import dataclasses
import collections
import itertools
import json
import os
from typing import Dict, List, Optional, Tuple

import h5py
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import PreTrainedModel

from npeff_torch.models import model_utils
from npeff_torch.models import parameter_infos as parameter_infos_module
from npeff_torch.peis import position_selectors
from npeff_torch.peis.fishers import class_selectors
from npeff_torch.peis.fishers.computers import pef_computer_common
from npeff_torch.util import hdf5_utils


###############################################################################


@dataclasses.dataclass
class ExampleDensePefComputationInfo:
    """Per-example values gathered while computing a single dense PEF.

    The `Optional` fields are None when the corresponding value was not
    produced for the example.
    """

    # Selected position for the example.
    # shape = [], dtype=int32
    position: Optional[torch.Tensor]

    # Label of the example.
    # shape = [], dtype=int32
    label: Optional[torch.Tensor]

    # Log probabilities over all of the classes.
    # shape = [n_classes_total]
    log_probs: torch.Tensor

    # Indices of the selected subset of classes.
    # shape = [n_selected_classes], dtype=int32
    class_indices: Optional[torch.Tensor]

    # The dense PEF computed for the example.
    dense_pef: torch.Tensor


###############################################################################


class StreamingPefSaverAbc(abc.ABC):
    """Abstract base for savers that compute PEFs one example at a time and write them to an HDF5 file.

    This class defines no `__init__`. Its methods read the attributes `model`, `fisher_computer`,
    `position_selector`, `class_subset_selector`, `label_key` and `device`, which `create` passes
    to the constructor as keyword arguments, so concrete subclasses are expected to accept and
    store them under those names.

    Subclasses implement `_initialize_file`, `_process_batch` and `_finalize_file`.
    """

    @classmethod
    def create(
        cls,
        model: PreTrainedModel,
        fisher_computer: 'pef_computer_common.PefComputerAbc',
        position_selector: 'position_selectors.PositionSelectorAbc',
        class_subset_selector: 'class_selectors.ClassSubsetSelectorAbc',
        label_key: Optional[str],
        device: torch.device,
        # Subclass-specific options (presumably e.g. save_logits, save_examples).
        **kwargs,
    ):
        """Construct an instance of `cls`, forwarding every argument as a keyword argument.

        Args:
            model: Model whose logits are computed for each example.
            fisher_computer: Computes the dense PEF of an example.
            position_selector: Chooses the position to use for each example.
            class_subset_selector: Chooses the subset of classes to use for each example.
            label_key: Key of the labels within a batch, or None.
            device: Device that the batches are moved to.
            **kwargs: Extra keyword arguments passed through to the constructor unchanged.

        Returns:
            The newly constructed instance of `cls`.
        """
        return cls(
            model=model,
            fisher_computer=fisher_computer,
            position_selector=position_selector,
            class_subset_selector=class_subset_selector,
            label_key=label_key,
            device=device,
            **kwargs,
        )

    #######################################################

    def compute_dense_pef_info(self, batch: Dict[str, torch.Tensor]) -> Optional[ExampleDensePefComputationInfo]:
        """Compute the dense PEF, and the values it was derived from, for a single example.

        The leading dimension of every tensor is squeezed out, so `batch` is expected to hold a
        single example (batch size of 1). The `nb_` prefix below marks values whose batch
        dimension has been removed.

        Args:
            batch: Tensors of the example, keyed by name. Must contain 'attention_mask'. Labels
                are read from `self.label_key` when that key is present.

        Returns:
            The computed information, or None when the position selector returns a negative
            position, which indicates to skip this example.
        """
        logits = model_utils.compute_logits(self.model, batch, self.device)
        # `dim` rather than the NumPy-style `axis` alias, matching the other torch calls here.
        log_probs = torch.log_softmax(logits, dim=-1)
        nb_log_probs = torch.squeeze(log_probs, dim=0)

        device_batch = {k: v.to(self.device) for k, v in batch.items()}
        nb_device_batch = {k: torch.squeeze(v, dim=0) for k, v in device_batch.items()}

        # `labels` is None when the batch has no entry under `label_key` (including a None key).
        labels = device_batch.get(self.label_key, None)
        nb_labels = None if labels is None else torch.squeeze(labels, dim=0)

        # Sum of the attention mask over the last dimension.
        n_non_paddings = torch.sum(device_batch['attention_mask'], dim=-1)

        # NOTE: The `logits` field of the selector input receives the log probabilities.
        position_selector_input = position_selectors.PositionSelectorInput(
            examples=device_batch,
            n_non_paddings=n_non_paddings,
            logits=log_probs,
            labels=labels,
        )
        position = self.position_selector.select_positions(position_selector_input)

        if position is not None:
            # Remove the batch dim.
            position = torch.squeeze(position, dim=0)
            # A negative position indicates to skip this example.
            if position < 0:
                return None

            nb_position_log_probs = nb_log_probs[position]

            if nb_labels is None:
                nb_position_labels = None
            elif nb_labels.ndim == 0:
                # A scalar label applies to the example as a whole, so there is nothing to index.
                nb_position_labels = nb_labels
            else:
                nb_position_labels = nb_labels[position]

        else:
            # No position was selected: use the log probabilities and labels as they are.
            nb_position_log_probs = nb_log_probs
            nb_position_labels = nb_labels

        class_subset_selector_input = class_selectors.ClassSubsetSelectorInput(
            example=nb_device_batch,
            logits=nb_position_log_probs,
            label=nb_position_labels,
        )
        selected_classes = self.class_subset_selector.select_classes(class_subset_selector_input)

        fisher_computer_input = pef_computer_common.PefComputerInput(
            log_probs=nb_position_log_probs,
            class_indices=selected_classes,
        )

        dense_pef = self.fisher_computer.compute_dense_pef(fisher_computer_input)

        return ExampleDensePefComputationInfo(
            position=position,
            label=nb_position_labels,
            log_probs=nb_position_log_probs,
            class_indices=selected_classes,
            dense_pef=dense_pef,
        )

    #######################################################

    @abc.abstractmethod
    def _initialize_file(self, file: h5py.File, n_examples: int):
        """Prepare `file` before any example is processed.

        Called once by `compute_and_save_pefs`, right after the file is opened for writing.

        Args:
            file: The open output file.
            n_examples: The `n_examples` value given to `compute_and_save_pefs`.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def _finalize_file(self, file: h5py.File):
        """Finish writing `file`.

        Called once by `compute_and_save_pefs`, after the last example has been processed and
        after the fisher computer has written its additional information.

        Args:
            file: The open output file.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def _process_batch(self, batch: Dict[str, torch.Tensor]) -> bool:
        """Process a particular example.

        Args:
            batch: The batch produced by the dataloader for this example.

        Returns:
            True if the example was successfully processed.
            False if the example was skipped.
        """
        raise NotImplementedError

    #######################################################

    def compute_and_save_pefs(self, filepath: str, dataloader: DataLoader, n_examples: int):
        """Process examples from `dataloader` and write the results to a new HDF5 file.

        Batches are handed to `_process_batch` until `n_examples` of them have been successfully
        processed or the dataloader is exhausted. Skipped examples do not count towards
        `n_examples`. The model is put into eval mode first.

        Args:
            filepath: Path of the file to create. An existing file is overwritten.
            dataloader: Source of the examples. Must be batched with a batch_size of 1.
            n_examples: Maximum number of examples to process. When it is not positive, no
                example is processed, but the file is still initialized and finalized.
        """
        # The dataloader must be batched with a batch_size of 1.
        self.model.eval()

        with h5py.File(filepath, "w") as file:
            self._initialize_file(file, n_examples)

            n_examples_processed = 0
            # The limit is only checked after a successful example, so without this guard a
            # non-positive `n_examples` would still process one example.
            batches = dataloader if n_examples > 0 else ()

            # The context manager closes the progress bar even if processing a batch raises.
            with tqdm(range(n_examples)) as progress_bar:
                for batch in batches:
                    if not self._process_batch(batch):
                        continue

                    n_examples_processed += 1
                    progress_bar.update(1)

                    if n_examples_processed >= n_examples:
                        break

            self.fisher_computer.write_additional_information_to_pefs_file(file)
            self._finalize_file(file)


###############################################################################
###############################################################################


@dataclasses.dataclass
class PefExtraInfos:
    """Non-PEF related example information that might be included in saved PEF files.

    Every field is optional. The array fields, and the values of `examples`, are indexed by
    example along their first axis; see `n_examples`, `get_slice` and `concat`.
    """

    # Names of the per-example array fields, in the order `n_examples` consults them.
    # Not annotated, so this is a plain class attribute rather than a dataclass field.
    # NOTE: `examples` is handled separately and is deliberately not listed here.
    _ARRAY_FIELD_NAMES = (
        'token_positions',
        'labels',
        'logits',
        'top_log_probs_class_indices',
        'top_log_probs_values',
    )

    # Read from the 'parameter_infos' attribute of the 'data' group. Shared by all examples.
    parameter_infos: Optional[Tuple['parameter_infos_module.ParameterInfo', ...]] = None

    # Read from 'data/token_positions'.
    token_positions: Optional[np.ndarray] = None

    # One array per dataset of the 'data/examples' group, keyed by dataset name.
    examples: Optional[Dict[str, np.ndarray]] = None
    # Read from 'data/labels'.
    labels: Optional[np.ndarray] = None
    # Read from 'data/logits'.
    logits: Optional[np.ndarray] = None

    # Read from 'data/top_log_probs/class_indices' and 'data/top_log_probs/values'.
    top_log_probs_class_indices: Optional[np.ndarray] = None
    top_log_probs_values: Optional[np.ndarray] = None

    @property
    def n_examples(self) -> Optional[int]:
        """Number of examples, taken from the first axis of the first available array.

        Returns:
            The number of examples, or None when no array field is set and `examples` is
            None or empty.
        """
        for field_name in self._ARRAY_FIELD_NAMES:
            field_value = getattr(self, field_name)
            if field_value is not None:
                return field_value.shape[0]
        # Truthiness also covers an empty dict, which has no array to take a length from.
        if self.examples:
            return next(iter(self.examples.values())).shape[0]
        return None

    def get_slice(self, start: int, end: int) -> 'PefExtraInfos':
        """Returns an instance of this class corresponding to the subset of examples.

        Args:
            start: Index of the first example to keep.
            end: Index one past the last example to keep.

        Returns:
            A new instance whose arrays are sliced as `[start:end]` along the first axis.
            Fields that are None stay None and `parameter_infos` is carried over as is.
        """
        def maybe_slice(array):
            return None if array is None else array[start:end]

        if self.examples is None:
            examples = None
        else:
            examples = {k: v[start:end] for k, v in self.examples.items()}

        sliced_fields = {
            field_name: maybe_slice(getattr(self, field_name))
            for field_name in self._ARRAY_FIELD_NAMES
        }

        return PefExtraInfos(
            parameter_infos=self.parameter_infos,
            examples=examples,
            **sliced_fields,
        )

    @classmethod
    def concat(cls, pef_extra_infos: List['PefExtraInfos'], *, ensure_same_fields: bool = True) -> 'PefExtraInfos':
        """Concatenate several instances along the example axis.

        Args:
            pef_extra_infos: The instances to concatenate, in order. Must not be empty.
            ensure_same_fields: If True, then all of the peis must have the same set of non-None
                fields and example keys. If False, then fields and example keys not present on
                every pei will be set to None in the concatenated pei.

        Returns:
            The concatenated instance. `examples` is None when no example key survives.

        Raises:
            ValueError: If `pef_extra_infos` is empty, if `ensure_same_fields` is True and a
                field or example key is present on only some of the peis, or if the peis do
                not all have the same `parameter_infos`.
        """
        n_peis = len(pef_extra_infos)
        if n_peis == 0:
            raise ValueError('Cannot concatenate an empty list of PefExtraInfos.')

        # NOTE: Examples are handled separately and are not in the field_values.
        field_values = {field_name: [] for field_name in cls._ARRAY_FIELD_NAMES}
        pei_examples = collections.defaultdict(list)

        for pei in pef_extra_infos:
            for field_name in cls._ARRAY_FIELD_NAMES:
                field_value = getattr(pei, field_name)
                if field_value is not None:
                    field_values[field_name].append(field_value)
            if pei.examples is not None:
                for k, v in pei.examples.items():
                    pei_examples[k].append(v)

        def concat_fields(values_by_name, description):
            # Concatenates the entries present on every pei. The rest become None, or raise
            # when they are present on only some peis and `ensure_same_fields` is set.
            concatenated = {}
            for name, values in values_by_name.items():
                if len(values) == n_peis:
                    concatenated[name] = np.concatenate(values, axis=0)
                elif values and ensure_same_fields:
                    raise ValueError(
                        f'The {description} {name!r} is present on only {len(values)} of the '
                        f'{n_peis} PefExtraInfos being concatenated.'
                    )
                else:
                    concatenated[name] = None
            return concatenated

        fields = concat_fields(field_values, 'field')

        examples = concat_fields(pei_examples, 'example key')
        examples = {k: v for k, v in examples.items() if v is not None}
        fields['examples'] = examples or None

        # Handle the parameter_infos. Check that they are the same and always throw an exception
        # if they are not. The check builds a set, so the parameter_infos must be hashable.
        if len(set(pei.parameter_infos for pei in pef_extra_infos)) != 1:
            raise ValueError('All of the PefExtraInfos must have the same parameter_infos.')
        # These are all the same, so pick one arbitrarily.
        fields['parameter_infos'] = pef_extra_infos[0].parameter_infos

        return cls(**fields)

    @classmethod
    def read_from_file(cls, filepath: str) -> 'PefExtraInfos':
        """Read the extra infos stored in a saved PEF file.

        Args:
            filepath: Path of the HDF5 file. A leading '~' is expanded.

        Returns:
            An instance whose fields are None wherever the file lacks the corresponding
            dataset, group or attribute.
        """
        with h5py.File(os.path.expanduser(filepath), "r") as f:
            if 'data/examples' in f:
                examples = {
                    k: hdf5_utils.load_h5_ds(ds)
                    for k, ds in f['data/examples'].items()
                }
            else:
                examples = None

            # Stored as a JSON-encoded list with one entry per ParameterInfo.
            parameter_infos = f['data'].attrs.get('parameter_infos', None)
            if parameter_infos is not None:
                parameter_infos = tuple(
                    parameter_infos_module.ParameterInfo.from_json(p)
                    for p in json.loads(parameter_infos)
                )

            return cls(
                parameter_infos=parameter_infos,
                #
                token_positions=hdf5_utils.load_h5_ds_if_exists(f, 'data/token_positions'),
                labels=hdf5_utils.load_h5_ds_if_exists(f, 'data/labels'),
                logits=hdf5_utils.load_h5_ds_if_exists(f, 'data/logits'),
                top_log_probs_class_indices=hdf5_utils.load_h5_ds_if_exists(f, 'data/top_log_probs/class_indices'),
                top_log_probs_values=hdf5_utils.load_h5_ds_if_exists(f, 'data/top_log_probs/values'),
                #
                examples=examples,
            )

    @classmethod
    def read_from_files(cls, filepaths: List[str], n_examples_per_pef: Optional[List[Optional[int]]] = None) -> 'PefExtraInfos':
        """Read the extra infos of several saved PEF files and concatenate them.

        Args:
            filepaths: Paths of the files to read, in concatenation order.
            n_examples_per_pef: Optional number of leading examples to keep from each file. A
                None entry keeps every example of that file. When provided, it must have one
                entry per filepath.

        Returns:
            The concatenation of the (possibly truncated) extra infos of every file. All of the
            files must have the same set of fields and example keys.

        Raises:
            ValueError: If the length of `n_examples_per_pef` does not match that of
                `filepaths`, if more examples are requested from a file than it contains, or
                if the concatenation fails (see `concat`).
        """
        # Validate n_examples_per_pef.
        if n_examples_per_pef is not None and len(filepaths) != len(n_examples_per_pef):
            raise ValueError('If `n_examples_per_pef` is provided, its number of entries must match that of the `filepaths`.')

        if n_examples_per_pef is None:
            n_examples_per_pef = [None] * len(filepaths)

        pef_extra_infos = []
        for pef_filepath, n_examples in zip(filepaths, n_examples_per_pef):
            pei = cls.read_from_file(pef_filepath)
            pei_n_examples = pei.n_examples
            assert pei_n_examples is not None, f'No per-example data found in {pef_filepath!r}.'

            if n_examples is not None:
                if n_examples < pei_n_examples:
                    pei = pei.get_slice(0, n_examples)
                elif n_examples > pei_n_examples:
                    raise ValueError(
                        f'Requested {n_examples} examples from {pef_filepath!r}, '
                        f'which only contains {pei_n_examples}.'
                    )

            pef_extra_infos.append(pei)

        return cls.concat(pef_extra_infos, ensure_same_fields=True)

