import numpy as np
import torch

from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_sequence

class FastDataLoader:
    """Lightweight batch iterator that fetches whole batches from the dataset.

    Every step hands a slice of an index array to ``dataset.get_batch`` instead
    of collating single items. The index array can be shuffled, limited to a
    fraction of the dataset (``sample_size``) and class-balanced by over- or
    under-sampling (``balance``).

    Shuffling and sampling draw from the global ``np.random`` state.
    """

    def __init__(self, dataset: Dataset, batch_size: int, shuffle=True, sample_size=1.0, balance = None) -> None:
        """Set up the loader.

        :param dataset: Object providing ``__len__`` and ``get_batch(indices)``.
            ``targets`` is read when ``balance`` is set, and
            ``set_features_for_epoch()`` is called at the start of every epoch
            if the dataset defines it.
        :param batch_size: Maximum number of indices per batch.
        :param shuffle: Shuffle the indices at the start of every epoch.
        :param sample_size: Fraction of the dataset used per epoch, in (0, 1].
        :param balance: ``None``, ``"over"`` or ``"under"``.
        :raises AssertionError: If an argument is outside its allowed range.
        """
        assert 0 < sample_size <= 1, f"Sample size hast to be within (0, 1] but got {sample_size=}"
        assert balance in [None, "over", "under"], f"balance has to be None, 'over', or 'under' but got {balance=}"
        assert batch_size > 0, f"batch_size has to be positive but got {batch_size=}"
        self.dataset = dataset

        self.dataset_len = int(len(self.dataset)*sample_size)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.balance = balance

        self.indices = np.arange(len(self.dataset))
        self.i = 0

        # Estimate before the first epoch; refreshed in __iter__ once the
        # epoch's index array is final.
        self.n_batches = self._count_batches(self.dataset_len)

    def _count_batches(self, n_items: int) -> int:
        """Return how many batches are needed to cover ``n_items`` indices."""
        return (n_items + self.batch_size - 1) // self.batch_size

    def _epoch_size(self) -> int:
        """Return the number of indices served in the current epoch."""
        # Under-sampling can leave fewer indices than dataset_len, while an
        # unshuffled index array can hold more; never go past either bound.
        return min(self.dataset_len, len(self.indices))

    def __iter__(self):
        """Prepare the indices for a new epoch and return the loader itself."""
        # for over and under sampling
        if self.balance:
            self.__sampling()

        if self.shuffle:
            np.random.shuffle(self.indices)
            # The truncated array replaces self.indices, so without balancing
            # later epochs reshuffle this same subset.
            self.indices = self.indices[:self.dataset_len]
        self.i = 0
        # Keep len(loader) in step with the batches this epoch really yields.
        self.n_batches = self._count_batches(self._epoch_size())

        # Dataset-specific per-epoch setup.
        if hasattr(self.dataset, "set_features_for_epoch"):
            self.dataset.set_features_for_epoch()
        return self

    def __next__(self):
        """Return the next batch as produced by ``dataset.get_batch``.

        :raises StopIteration: When all indices of the epoch were served.
        """
        epoch_size = self._epoch_size()
        if self.i >= epoch_size:
            raise StopIteration
        # Clamp the slice so the final batch cannot run past the epoch size.
        stop = min(self.i + self.batch_size, epoch_size)
        batch = self.dataset.get_batch(self.indices[self.i:stop])
        self.i += self.batch_size
        return batch

    def __len__(self):
        """Return the number of batches per epoch."""
        return self.n_batches

    def __sampling(self) -> None:
        """Rebuild ``self.indices`` with the same number of samples per class.

        ``"over"`` draws with replacement so that the classes together fill
        ``dataset_len`` indices. ``"under"`` draws without replacement as many
        indices per class as the smallest class holds (at most
        ``dataset_len``). The result is grouped by class.

        :raises ValueError: If ``dataset.targets`` is not one label per sample.
        """
        targets = np.asarray(self.dataset.targets)
        if targets.ndim > 1 and targets.size == len(targets):
            # Accept column-shaped labels such as (N, 1).
            targets = targets.reshape(-1)
        if targets.ndim != 1:
            raise ValueError("balance requires one class label per sample in dataset.targets")

        # np.unique compares by value, which also works for tensors and lists.
        classes = np.unique(targets)
        if len(classes) == 0:
            self.indices = np.array([], dtype=int)
            return

        members_per_class = [np.flatnonzero(targets == cls) for cls in classes]

        if self.balance == "under":
            smallest = min([self.dataset_len] + [len(members) for members in members_per_class])
            sizes = [smallest] * len(members_per_class)
        else:
            # Spread dataset_len evenly; the first classes absorb the remainder.
            per_class, remainder = divmod(self.dataset_len, len(members_per_class))
            sizes = [per_class + (1 if rank < remainder else 0) for rank in range(len(members_per_class))]

        with_replacement = self.balance == "over"
        self.indices = np.concatenate([
            np.random.choice(members, size, replace=with_replacement)
            for members, size in zip(members_per_class, sizes)
        ])

class SimpleDataset(Dataset):
    """Dataset over a pair of aligned feature and target containers.

    Items and batches are returned as ``(features, targets)`` tensor pairs.
    """

    def __init__(self, X, y) -> None:
        """Store features and targets.

        :param X: Features, indexable by an int and by an index array
            (presumably a NumPy array -- confirm against callers).
        :param y: Targets with the same length as ``X``.
        :raises AssertionError: If ``X`` and ``y`` differ in length.
        """
        assert len(X) == len(y), "Lenght of X and y is not equal!"

        self.features = X
        self.targets = y
        self.__len = len(self.features)

    def __len__(self):
        """Return the number of samples."""
        return self.__len

    def __getitem__(self, index):
        """Return the sample at ``index`` as ``(features, target)`` tensors."""
        # as_tensor also accepts the NumPy scalar that indexing a 1-D array
        # yields (from_numpy raises TypeError there); for arrays it still
        # shares memory like from_numpy.
        return torch.as_tensor(self.features[index]), torch.as_tensor(self.targets[index])

    def get_batch(self, indices):
        """Return the samples selected by ``indices`` as ``(features, targets)``."""
        return torch.as_tensor(self.features[indices]), torch.as_tensor(self.targets[indices])
    

class PackSequenceDataset(Dataset):
    """Dataset of variable-length sequences that are packed per batch."""

    def __init__(self, X, y) -> None:
        """Store the sequences and their targets.

        :param X: Features for the packed sequence
        :type X: list[torch.Tensor]
        :param y: Targets
        :type y: torch.Tensor
        :raises AssertionError: If ``X`` and ``y`` differ in length.
        """
        self.features, self.targets = X, y

        n_features = len(self.features)
        assert n_features == len(self.targets), "Lenght of X and y is not equal!"

        self.__len = n_features

    def __len__(self):
        """Return the number of sequences."""
        return self.__len

    def __getitem__(self, index):
        """Return the unpacked ``(sequence, target)`` pair at ``index``."""
        sequence = self.features[index]
        target = self.targets[index]
        return sequence, target

    def get_batch(self, indices):
        """Return ``(packed_sequences, targets)`` for the given indices.

        The sequences are packed with ``enforce_sorted=False``, so they do not
        have to be ordered by length.
        """
        selected = [self.features[position] for position in indices]
        packed = pack_sequence(selected, enforce_sorted=False)
        return packed, self.targets[indices]

class SequenceDataset(Dataset):
    """Dataset with optional extra targets and per-epoch feature masking.

    Items and batches are returned as ``(features, targets, extended_targets)``
    where the last element is a dict of tensors. Features are read from a
    per-epoch copy in which a random fraction of the samples is overwritten
    with ``unknown_value``.
    """

    def __init__(self, X, y, extended_y: "dict | None" = None, unknown_sample: float = 0.0, unknown_value: int = -1) -> None:
        """Store the data and create the initial per-epoch feature copy.

        :param X: Features; must support ``copy()`` and assignment through an
            index array (presumably a NumPy array -- confirm against callers).
        :param y: Targets with the same length as ``X``.
        :param extended_y: Additional targets by name, each indexable like
            ``y``; defaults to an empty dict.
        :type extended_y: dict, optional
        :param unknown_sample: Fraction of samples whose features are replaced
            by ``unknown_value`` in ``set_features_for_epoch``, defaults to 0.0
        :type unknown_sample: float, optional
        :param unknown_value: Value written into the selected samples,
            defaults to -1
        :type unknown_value: int, optional
        :raises AssertionError: If ``X`` and ``y`` differ in length.
        """
        assert len(X) == len(y), "Lenght of X and y is not equal!"

        self.features = X
        self.features_for_epoch = self.features.copy()
        self.targets = y
        # A fresh dict per instance: a dict() default in the signature would be
        # one object shared by every dataset created without extended_y.
        self.extended_targets = {} if extended_y is None else extended_y
        self.unknown_sample = unknown_sample
        self.unknown_value = unknown_value

        self.__len = len(self.features)

    def __len__(self):
        """Return the number of samples."""
        return self.__len

    def __getitem__(self, index):
        """Return ``(features, target, extended_targets)`` for one sample."""
        # as_tensor also accepts the NumPy scalar that indexing a 1-D array
        # yields (from_numpy raises TypeError there); for arrays it still
        # shares memory like from_numpy.
        return (
            torch.as_tensor(self.features_for_epoch[index]),
            torch.as_tensor(self.targets[index]),
            {key: torch.as_tensor(v[index]) for key, v in self.extended_targets.items()},
        )

    def get_batch(self, indices):
        """Return ``(features, targets, extended_targets)`` for ``indices``."""
        return (
            torch.as_tensor(self.features_for_epoch[indices]),
            torch.as_tensor(self.targets[indices]),
            {key: torch.as_tensor(v[indices]) for key, v in self.extended_targets.items()},
        )

    def set_features_for_epoch(self):
        """Re-draw which samples are overwritten with ``unknown_value``.

        Starts from a fresh copy of the original features, so the selection of
        an earlier epoch does not accumulate.
        """
        n_unknown = int(self.__len*self.unknown_sample)
        unknown_idx = np.random.choice(np.arange(self.__len), n_unknown, replace=False)
        self.features_for_epoch = self.features.copy()
        self.features_for_epoch[unknown_idx] = self.unknown_value
