from typing import List

import torch
from torch.utils.data import Dataset as TorchDataset

from src.utils import json_load


class JsonDataset(TorchDataset):
    """ Load dataset from json file.

    The loaded records are stored in ``self.datalist`` and exposed through the
    map-style ``Dataset`` protocol (``len(ds)`` / ``ds[i]``).
    """
    def __init__(self, filename):
        # NOTE(review): presumably json_load returns a list of records (dicts);
        # __len__/__getitem__/__add__ below rely on list semantics — confirm.
        self.datalist = json_load(filename)

    def __len__(self):
        return len(self.datalist)

    def __getitem__(self, i):
        return self.datalist[i]

    def shuffle(self, generator=None) -> TorchDataset:
        """Return a randomly permuted view of this dataset.

        Args:
            generator: optional ``torch.Generator`` used to draw the
                permutation, for a reproducible order. ``None`` uses torch's
                global RNG.

        Returns:
            A ``torch.utils.data.Subset`` over ``self``. Records are shared
            with this dataset, not copied.
        """
        # Plain Python ints keep Subset indexing independent of tensor
        # __index__ behaviour and make the indices cheap to pickle.
        indices = torch.randperm(len(self), generator=generator).tolist()
        return torch.utils.data.Subset(self, indices)

    def __add__(self, other):
        """Append ``other``'s records to this dataset in place and return it.

        ``self`` is mutated and ``other`` is left unchanged. ``other`` must
        expose a ``datalist`` list. Returning ``self`` means both ``a + b``
        and ``a += b`` evaluate to the merged dataset instead of ``None``.
        """
        self.datalist.extend(other.datalist)
        return self


class LogitsData(str):
    """ List-like wrapper around a sequence of logits records.

    NOTE(review): subclassing ``str`` looks deliberate. Presumably PyTorch's
    ``default_collate`` should pass instances through untouched instead of
    trying to collate their contents; confirm before changing the base class.
    Because ``str.__new__`` receives ``data``, the string value of an instance
    is ``str(data)``.
    """
    def __init__(self, data: List[dict]):
        super().__init__()
        self.data = data
        # Upper bound for the legacy stateful ``__next__`` cursor.
        self.limit = len(self.data)
        # Cursor used only by direct ``next(obj)`` calls, not by for-loops.
        self.pointer = 0

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def __iter__(self):
        # Hand out an independent iterator so that nested loops, and loops
        # that exit early (break / exception), cannot disturb each other or
        # leave a stale position behind for the next loop.
        return iter(self.data)

    def __next__(self):
        """Advance the legacy shared cursor and return the next record.

        Kept for callers that invoke ``next(obj)`` directly. When the cursor
        reaches ``self.limit`` it is reset to 0 and ``StopIteration`` is
        raised.
        """
        if self.pointer < self.limit:
            val = self.data[self.pointer]
            self.pointer += 1
            return val
        self.pointer = 0
        raise StopIteration


class DistillingDataset(JsonDataset):
    """ Dataset for collecting logits data.

    Every record loaded from ``filename`` must contain a ``"logits"`` key.
    Each element of that value is wrapped in ``LogitsData`` in place.

    Raises:
        ValueError: if any record lacks a ``"logits"`` key.
    """
    def __init__(self, filename):
        super().__init__(filename)
        # Validate every record before mutating any of them. An explicit raise
        # (unlike ``assert``) survives ``python -O``, and an empty file no
        # longer fails with an opaque IndexError.
        missing = [
            index for index, data in enumerate(self.datalist)
            if "logits" not in data
        ]
        if missing:
            raise ValueError(
                f"{len(missing)} record(s) in {filename!r} have no 'logits' "
                f"field (first at index {missing[0]})."
            )
        for data in self.datalist:
            data["logits"] = [LogitsData(item) for item in data["logits"]]
