"""Contexts around evaluating a model with given parameters over a set of examples."""
import abc
import dataclasses
from typing import Dict, Iterator, Optional

import torch
from transformers import PreTrainedModel

from npeff_torch.models import model_utils


# Design overview:
#   - `Examples` stores the examples.
#       - It optionally stores per-example token positions, which are used to
#         pick a single position out of per-token logits.
#   - Evaluators produce the logits for batches of examples.
#       - `StoredLogitsEvaluator` reads them from a saved tensor of logits.
#       - `ModelEvaluator` computes them by running a model.

###############################################################################


class ExampleBatch:
    """A batch of examples taken from a larger `Examples` instance.

    Attributes:
        offset: The offset of the first example in the batch in the larger
            `Examples` instance.
        examples: Named tensors with examples[*].shape = [batch_size, ...].
        token_positions: Optional tensor with shape=[batch_size], dtype=torch.int64.
        batch_size: Size of the first dimension of the first tensor in `examples`.
    """

    def __init__(
        self,
        *,
        offset: int,
        examples: Dict[str, torch.Tensor],
        token_positions: Optional[torch.Tensor],
    ):
        self.offset = offset
        self.examples = examples
        self.token_positions = token_positions

        # Every tensor is documented to share its leading dimension, so the
        # first one is used to read off the batch size.
        first_tensor = tuple(examples.values())[0]
        self.batch_size = first_tensor.shape[0]


class Examples:
    """A list of examples.

    The examples are stored as a dict of named tensors that all share their
    first dimension (one row per example), optionally accompanied by one token
    position per example.
    """

    def __init__(
        self, *,

        # examples[*].shape = [n_examples, ...]
        examples: Dict[str, torch.Tensor],

        # shape=[n_examples], dtype=torch.int32 (stored internally as torch.int64)
        token_positions: Optional[torch.Tensor] = None,
    ):
        self.examples = examples
        # Positions are later used as indices, so they are kept as int64.
        self.token_positions = token_positions.type(torch.int64) if token_positions is not None else None

        # The number of examples is read off the first tensor in the dict.
        self.n_examples = list(self.examples.values())[0].shape[0]

    def to(self, device: torch.device) -> 'Examples':
        """Moves the tensors attached to the instance to the device and returns self."""
        self.examples = {k: v.to(device) for k, v in self.examples.items()}
        if self.token_positions is not None:
            self.token_positions = self.token_positions.to(device)
        return self

    @property
    def sequence_length(self) -> int:
        """Size of the last dimension of `examples['input_ids']`."""
        return self.examples['input_ids'].shape[-1]

    @torch.no_grad()
    def get_batches(self, batch_size: int) -> Iterator['ExampleBatch']:
        """Yields consecutive batches covering all of the examples in order.

        Every batch holds `batch_size` examples except possibly the last one,
        which holds whatever remains.

        Args:
            batch_size: Maximum number of examples per batch. Must be positive.

        Yields:
            `ExampleBatch` instances whose `offset` is the index of their first
            example within this instance.

        Raises:
            ValueError: If `batch_size` is not positive. Since this is a
                generator, the error surfaces when iteration starts.
        """
        # A negative step would make `range` empty and silently produce no
        # batches at all, so reject it the same way a zero step is rejected.
        if batch_size <= 0:
            raise ValueError(f'batch_size must be positive, got {batch_size}.')

        for offset in range(0, self.n_examples, batch_size):
            end = offset + batch_size
            yield ExampleBatch(
                offset=offset,
                examples={k: v[offset:end] for k, v in self.examples.items()},
                token_positions=self.token_positions[offset:end] if self.token_positions is not None else None
            )

    @torch.no_grad()
    def gather_examples(self, example_indices: torch.Tensor) -> 'Examples':
        """Returns a new instance holding only the examples at the given indices.

        Args:
            example_indices: Vector of example indices.
                example_indices.shape = [n_selected], dtype=torch.int64

        Returns:
            A new instance of the same class with the selected examples (and
            their token positions, if present) in the order given.
        """
        assert len(example_indices.shape) == 1, 'example_indices must be a vector.'

        return self.__class__(
            examples={k: v[example_indices] for k, v in self.examples.items()},
            token_positions=self.token_positions[example_indices] if self.token_positions is not None else None
        )

    def as_torch_dataset(self) -> 'ExamplesTorchDataset':
        """Wraps this instance in an `ExamplesTorchDataset`."""
        return ExamplesTorchDataset(self)


###############################################################################


class ExamplesTorchDataset(torch.utils.data.Dataset):
    """Exposes an `Examples` instance through the torch `Dataset` interface."""

    def __init__(self, examples: 'Examples'):
        self._examples = examples

    def __len__(self) -> int:
        """Number of examples in the wrapped instance."""
        return self._examples.n_examples

    @torch.no_grad()
    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Returns the `idx`-th entry of every tensor, keyed by the tensor's name."""
        # TODO: See if want to include the token positions in the dataset.
        item = {}
        for name, tensor in self._examples.examples.items():
            item[name] = tensor[idx]
        return item


###############################################################################
###############################################################################


class EvaluationBatchInfo:
    """Pairs a batch of examples with the logits computed for it.

    Attributes:
        example_batch: The batch the logits belong to.
        logits: Tensor with shape = [batch_size, n_classes].
        offset: Copied from `example_batch.offset`.
        batch_size: Copied from `example_batch.batch_size`.
    """

    def __init__(self, *, example_batch: ExampleBatch, logits: torch.Tensor):
        # Mirror the batch's bookkeeping fields so they can be read directly
        # from this object.
        self.offset, self.batch_size = example_batch.offset, example_batch.batch_size

        self.example_batch = example_batch
        self.logits = logits


class EvaluatorAbc(abc.ABC):
    """Base class for objects that produce logits for batches of examples.

    Subclasses implement `_compute_raw_batch_logits`. This class reduces
    per-token logits to one row per example and wraps the result in an
    `EvaluationBatchInfo`.
    """

    @abc.abstractmethod
    def _compute_raw_batch_logits(self, batch: 'ExampleBatch') -> torch.Tensor:
        """Returns the raw logits for the given batch.

        The result is expected to have shape [batch_size, n_classes] or
        [batch_size, sequence, n_classes].
        """
        raise NotImplementedError

    #######################################################

    @staticmethod
    def _select_token_logits(logits: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
        """Picks, for every example, the logits at that example's token position.

        Args:
            logits: Tensor of shape [batch_size, sequence, n_classes].
            token_positions: Integer tensor of shape [batch_size].

        Returns:
            A tensor of shape [batch_size, n_classes] whose i-th row is
            `logits[i, token_positions[i]]`.

        Raises:
            ValueError: If `token_positions` does not hold exactly one position
                per row of `logits`.
        """
        n_rows = logits.shape[0]
        # Fail loudly on a mismatch instead of silently truncating or
        # broadcasting one of the two tensors.
        if token_positions.shape != (n_rows,):
            raise ValueError(
                f'Expected token_positions of shape ({n_rows},), got {tuple(token_positions.shape)}.'
            )

        # Index tensors must be integral and live on a device compatible with
        # the indexed tensor.
        positions = token_positions.to(device=logits.device, dtype=torch.int64)
        rows = torch.arange(n_rows, device=logits.device)

        # Pairing row i with positions[i] gathers one [n_classes] vector per example.
        return logits[rows, positions]

    @torch.no_grad()
    def compute_batch_info(self, batch: 'ExampleBatch') -> 'EvaluationBatchInfo':
        """Computes the logits for the given batch.

        Returns:
            An `EvaluationBatchInfo` for `batch` whose `logits` tensor has shape
            [batch_size, n_classes].
        """
        logits = self._compute_raw_batch_logits(batch)

        # Select logits from the given token positions if logits.shape = [batch_size, sequence, n_classes].
        if batch.token_positions is not None and len(logits.shape) == 3:
            logits = self._select_token_logits(logits, batch.token_positions)

        # assert logits.shape = [batch_size, n_classes]
        assert len(logits.shape) == 2

        return EvaluationBatchInfo(example_batch=batch, logits=logits)

    def compute_batch_infos(self, examples: 'Examples', batch_size: int) -> Iterator['EvaluationBatchInfo']:
        """Yields an `EvaluationBatchInfo` for every batch of `examples`, in order."""
        for batch in examples.get_batches(batch_size):
            yield self.compute_batch_info(batch)


###############################################################################


class StoredLogitsEvaluator(EvaluatorAbc):
    """Evaluator that serves logits out of a fixed, precomputed tensor.

    Attributes:
        logits: Tensor with logits.shape = [n_examples, sequence?, n_classes].
    """

    def __init__(self, *, logits: torch.Tensor):
        self.logits = logits

    def _compute_raw_batch_logits(self, batch: ExampleBatch) -> torch.Tensor:
        """Returns the rows of the stored logits covered by the batch."""
        # The batch's offset and size locate its rows within the stored tensor.
        start = batch.offset
        stop = start + batch.batch_size
        return self.logits[start:stop]


class ModelEvaluator(EvaluatorAbc):
    """Evaluator that obtains logits by running a model on each batch.

    Attributes:
        model: The model passed to `model_utils.compute_logits`.
        device: Optional device forwarded to `model_utils.compute_logits`.
    """

    def __init__(self, *, model: PreTrainedModel, device: Optional[torch.device] = None):
        self.model = model
        self.device = device

    def _compute_raw_batch_logits(self, batch: ExampleBatch) -> torch.Tensor:
        """Returns the logits computed by `model_utils.compute_logits` for the batch's examples."""
        batch_inputs = batch.examples
        raw_logits = model_utils.compute_logits(self.model, batch_inputs, self.device)
        return raw_logits


###############################################################################
###############################################################################


# class PerturbationBatchInfo:

#     def __init__(
#         self, *,
#         # Assumes that the examples are consistent amongst these.
#         original_eval_info: EvaluationBatchInfo,
#         perturbed_eval_info: EvaluationBatchInfo,
#     ):
#         self.original_eval_info = original_eval_info
#         self.perturbed_eval_info = perturbed_eval_info

