"""Method of finding examples from a paper.

https://openreview.net/pdf?id=djmjglxOZ7
"""
from typing import Any, Dict, Sequence

# import datasets
from transformers import PreTrainedModel
import torch

from npeff_torch.icl import icl_datasets_common

###############################################################################


# TODO: Maybe move to a common place.
class ExamplesCollection:
    """Ordered collection of example dicts, stored internally as a tuple.

    Supports ``len()``, iteration, and indexing by integer, slice, or 1-d
    integer tensor. Slice and tensor indexing return a new collection of the
    same class; integer indexing returns the single example.
    """

    def __init__(
        self,
        examples: Sequence[Dict[str, Any]]
    ):
        """Copy `examples` into a tuple and record how many there are."""
        # Snapshot into a tuple so later changes to the caller's sequence
        # (appends, removals) do not affect this collection.
        self.examples = tuple(examples)
        self.n_examples = len(self.examples)

    def __len__(self) -> int:
        return self.n_examples

    def __iter__(self):
        # Same sequence the implicit `__getitem__` iteration fallback would
        # produce, without going through the type dispatch for every index.
        return iter(self.examples)

    def __repr__(self) -> str:
        return f'{type(self).__name__}(n_examples={self.n_examples})'

    def __getitem__(self, key):
        """Look up one example or build a sub-collection.

        Args:
            key: One of
                - a slice: returns a new collection with the sliced examples;
                - an int, or any object implementing `__index__`: returns the
                  example at that position (negative indices allowed);
                - a 1-d `torch.Tensor` of dtype int64 or int32: returns a new
                  collection with the examples at those positions, in the
                  order given by the tensor.

        Raises:
            TypeError: If `key` is of an unsupported type, or is a tensor
                whose dtype is not int64/int32.
            ValueError: If `key` is a tensor that is not 1-dimensional.
            IndexError: If an index is out of range.
        """
        if isinstance(key, slice):
            return self.__class__(self.examples[key])

        if isinstance(key, int):
            return self.examples[key]

        if isinstance(key, torch.Tensor):
            # Raise explicitly rather than `assert`: asserts are removed when
            # Python runs with -O, which would let bad keys through silently.
            if key.dtype not in (torch.int64, torch.int32):
                raise TypeError(
                    f'Tensor index must have dtype int64 or int32, got {key.dtype}'
                )
            if key.dim() != 1:
                raise ValueError(
                    f'Tensor index must be 1-dimensional, got shape {tuple(key.shape)}'
                )
            # `tolist()` yields plain Python ints.
            return self.__class__([self.examples[i] for i in key.tolist()])

        # Checked after the tensor branch because tensors also define
        # `__index__`. Covers integer-like scalars that are not `int`
        # subclasses (e.g. numpy integer scalars); the tuple performs the
        # actual conversion.
        if hasattr(key, '__index__'):
            return self.examples[key]

        raise TypeError(f'Invalid argument type: {type(key).__name__}')


###############################################################################


class ProgressiveFilter:
    """Progressive example filtering.

    Follows "Algorithm 1 Progressive Example Filtering" from
    https://openreview.net/pdf?id=djmjglxOZ7

    Only the constructor is defined in this class so far; it records its
    arguments on the instance.
    """

    def __init__(
        self, *,
        model: PreTrainedModel,
        icl_example_helper: 'icl_datasets_common.IclExampleHelperAbc',
        examples_collection: 'ExamplesCollection',

        batch_size: int,

        # Parameters that are not wired in yet:
        # desired_candidate_size: int,
        # progressive_factor: float,
        # initial_score_dataset_size: int,
    ):
        """Store the keyword-only arguments as same-named attributes.

        Args:
            model: Stored as `self.model`.
            icl_example_helper: Stored as `self.icl_example_helper`.
            examples_collection: Stored as `self.examples_collection`.
            batch_size: Stored as `self.batch_size`.
        """
        # Run configuration.
        self.batch_size = batch_size

        # Collaborating objects.
        self.examples_collection = examples_collection
        self.icl_example_helper = icl_example_helper
        self.model = model


###############################################################################

# icl_datasets_common.IclExampleHelperAbc


R"""

- Need some representation/structure for the dataset.
    - Probably text.
    - Need a way to handle templating and tokenization.

"""
