import torch
import numpy as np

from torch.utils.data import Dataset as tDataset
from torch.utils.data import DataLoader as tDataLoader
from torch.utils.data import Subset as tSubset

from collections.abc import Iterable, Mapping

from typing import Callable, Union, List, Dict, Tuple, Set, Iterable

from copy import copy
from tqdm import tqdm


def any_to_device_nested(whatev, device=None, dtype=None, respect_floats=False):
    """Recursively send every tensor inside a nested container through ``Tensor.to``.

    Args:
        whatev: A tensor, or a dict / list / tuple (nested to any depth) whose
            leaves are all tensors.
        device: Target device, forwarded to ``Tensor.to``.
        dtype: Target dtype, forwarded to ``Tensor.to`` (see ``respect_floats``).
        respect_floats: If true, only floating-point tensors are cast to
            ``dtype``; every other tensor is moved to ``device`` with its dtype
            untouched. If false, every tensor is cast.

    Returns:
        The converted tensor, or a container of the same nesting. Dicts keep
        their keys; lists and tuples both come back as lists.

    Raises:
        TypeError: If a value is neither a tensor nor a dict / list / tuple.
    """
    def convert(value):
        # same settings at every nesting level
        return any_to_device_nested(value, device=device, dtype=dtype, respect_floats=respect_floats)

    if isinstance(whatev, torch.Tensor):
        should_cast = whatev.dtype.is_floating_point or not respect_floats
        if should_cast:
            return whatev.to(device=device, dtype=dtype)
        return whatev.to(device=device)

    if isinstance(whatev, Dict):
        return {key: convert(value) for key, value in whatev.items()}

    if isinstance(whatev, List) or isinstance(whatev, Tuple):
        return [convert(value) for value in whatev]

    raise TypeError(f'Cannot attempt to put stuff ona device, stuff in question: {type(whatev)}')


class tDatasetWrapper(tDataset):
    '''Base class for dataset decorators that can be stacked on top of each other.

    The wrapped dataset is stored in ``self.data``. Attributes that the wrapper
    itself does not define are looked up on the wrapped dataset, so a stack of
    wrappers still exposes the attributes of the datasets below it.
    '''
    def __init__(self, ds: tDataset):
        self.data = ds

    @property
    def unwrapped(self) -> tDataset:
        """The innermost dataset, with every wrapper layer stripped."""
        return self.data.unwrapped if isinstance(self.data, tDatasetWrapper) else self.data

    @property
    def unwrap_self(self):
        """The directly wrapped object (one layer down). This is a property, not a method."""
        return self.data

    @property
    def list_wrappers(self) -> List:
        """Wrapper types in the stack, outermost first, including this one."""
        inner = self.unwrap_self
        # the conditional must only guard the recursive part; this wrapper's
        # own type always belongs in the list
        nested = inner.list_wrappers if isinstance(inner, tDatasetWrapper) else []
        return [type(self)] + nested

    def has_wrapper(self, which_type) -> bool:
        """Return True if a wrapper of exactly ``which_type`` is in the stack."""
        return which_type in self.list_wrappers

    def require_wrapper(self, which_type, check_original_ds = False) -> bool:
        """Like ``has_wrapper``; optionally also accepts the innermost dataset being a ``which_type``."""
        return (which_type in self.list_wrappers) or (check_original_ds and isinstance(self.unwrapped, which_type))

    def __len__(self):
        return len(self.unwrap_self)  # default len function to spare the annoyance

    def __getattr__(self, name): # getattr is called when the attribute is not found, getattribute before looking
        # Special (dunder) names are never forwarded: copy/pickle probe for
        # them on the instance and would otherwise receive the wrapped
        # dataset's hooks instead of the wrapper's defaults.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        # Read ``data`` straight from the instance dict: going through
        # ``self.data`` would re-enter __getattr__ and recurse forever when
        # the attribute has not been set yet (e.g. while unpickling).
        try:
            inner = self.__dict__['data']
        except KeyError:
            raise AttributeError(name) from None
        return getattr(inner, name)

    def _str_my_parameters(self):
        """Hook for subclasses: text placed between the parentheses in ``__str__``."""
        return ''

    def __str__(self):
        return f'{type(self).__name__}({self._str_my_parameters()})->{str(self.unwrap_self)}'


class SubsetWrapper(tDatasetWrapper):
    '''Expose only a chosen set of indices of the wrapped dataset.

    A "real" id is an index into the directly wrapped dataset. Positions inside
    this wrapper run from 0 to ``len(self) - 1`` and map onto the real ids in
    ascending order.
    '''
    def __init__(self, ds, subset:List=None):
        """
        Args:
            ds: Dataset to take the subset of.
            subset: Real ids to keep (any iterable of sortable ids); ``None``
                keeps every index of ``ds``.
        """
        super(SubsetWrapper, self).__init__(ds)

        # kept as a sorted list (or a range when everything is selected) so the
        # position -> real id mapping is deterministic
        self.subset = sorted(subset) if subset is not None else range(len(self.data))

    def get_real_sample_id(self, item: int):
        """Translate a position in this subset into a real id."""
        # somewhat important to have it consistent-ish
        return self.get_real_sample_ids_all()[item]

    def get_real_sample_ids(self, items: Iterable[int]):
        """Translate several positions in this subset into real ids."""
        return [self.get_real_sample_id(i) for i in items]

    def get_real_sample_ids_all(self):
        """All real ids in this subset, in position order."""
        # the index is kept sorted on every modification, not on access
        return self.subset

    def get_sample_id_by_parent_id(self, parent_idx):
        """Position of a real id inside this subset; raises ValueError if absent."""
        return self.get_real_sample_ids_all().index(parent_idx)

    def add(self, new_item_real_id: int, error_on_existing=True):
        """Add one real id to the subset.

        Raises:
            ValueError: If the id is already present and ``error_on_existing`` is true.
        """
        if error_on_existing and new_item_real_id in self.subset:
            raise ValueError(f'{new_item_real_id} is already in the subset! ')
        # set.add() returns None, so the set has to be built and updated
        # before it is handed to sorted()
        members = set(self.subset)
        members.add(new_item_real_id)
        self.subset = sorted(members)

    def remove(self, delete_item_real_id: int, error_on_missing=True):
        """Remove one real id from the subset.

        Raises:
            ValueError: If the id is absent and ``error_on_missing`` is true.
        """
        if error_on_missing and delete_item_real_id not in self.subset:
            raise ValueError(f'Cannot remove {delete_item_real_id} as it is not present in the subset')
        self.subset = sorted(set(self.subset) - {delete_item_real_id})

    def removes(self, delete_items_real_ids: List[int], error_on_missing=True):
        """Remove several real ids, one after another."""
        for real_id in delete_items_real_ids:
            self.remove(real_id, error_on_missing=error_on_missing)

    def adds(self, new_item_real_ids: Iterable[int], error_on_existing=True):
        """Add several real ids, one after another."""
        for real_id in new_item_real_ids:
            self.add(real_id, error_on_existing=error_on_existing)

    def inverse(self):
        """Return a new subset holding exactly the real ids this one leaves out."""
        inverse_set = list(set(range(len(self.unwrap_self))) - set(self.subset))
        return SubsetWrapper(self.unwrap_self, inverse_set)

    def subset_fraction(self):
        """Share of the wrapped dataset that is covered by this subset."""
        return  len(self) / len(self.unwrap_self)

    def peel(self):
        """Collapse two directly stacked subsets into a single one.

        Only valid when the wrapped dataset is itself a ``SubsetWrapper``. The
        result wraps that inner subset's dataset and addresses it directly.
        """
        assert isinstance(self.unwrap_self, SubsetWrapper), f"Trying to Peel a wrapper that is not SubetWrapper but rather {type(self.unwrap_self)}"
        # this wrapper's real ids are positions inside the intermediate subset
        real_ids = self.get_real_sample_ids_all()
        # translate those positions into the intermediate subset's own real ids
        unwrap_self_ids = self.unwrap_self.get_real_sample_ids_all()
        translated_ids = [unwrap_self_ids[rid] for rid in real_ids]
        return SubsetWrapper(self.unwrap_self.unwrap_self, subset=set(translated_ids))

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, item):
        return self.data[self.get_real_sample_id(item)]

    @staticmethod
    def reallocate_sample_between_subsets(source, dest, source_index: int):
        """Move the sample at position ``source_index`` of ``source`` into ``dest``.

        Returns:
            The real id of the moved sample.
        """
        assert source.unwrap_self == dest.unwrap_self, "Cannot move between subsets of different dataset"
        unwrapped_id = source.get_real_sample_id(source_index)
        # delete from the source and add to the destination
        source.remove(unwrapped_id)
        dest.add(unwrapped_id)
        return unwrapped_id

    @staticmethod
    def reallocate_samples_between_subsets(source, dest, source_index: Iterable[int]):
        '''
        Move several samples, given by their positions in ``source``, into ``dest``.

        Always use this for multiple reallocations! Calling the single-sample
        method repeatedly misbehaves, because every removal shifts the
        positions of the samples that remain in ``source``.

        Returns the real ids of the moved samples.
        '''
        assert isinstance(source, SubsetWrapper) and isinstance(dest, SubsetWrapper), "Can only apply to subset wrappers!"
        assert source.unwrap_self == dest.unwrap_self, "Cannot move between subsets of different dataset"
        # resolve every position to its real id first, before anything is removed
        real_ids = source.get_real_sample_ids(source_index)
        source.removes(real_ids)
        dest.adds(real_ids)

        return real_ids


class RandomSubsetWrapper(SubsetWrapper):
    '''Subset made of indices drawn at random, without replacement, from a seeded generator.'''
    def __init__(self, ds, setlen, seed=42):
        """
        Args:
            ds: Dataset to sample from.
            setlen: Number of samples to keep; a float is read as a fraction
                of ``len(ds)`` and truncated to an integer.
            seed: Seed for the ``numpy`` random state that draws the indices.
        """
        total = len(ds)
        if isinstance(setlen, float):
            # fraction of the dataset -> absolute number of samples
            setlen = int(total*setlen)

        generator = np.random.RandomState(seed=seed)
        drawn = generator.choice(total, setlen, replace=False)
        super().__init__(ds, subset=set(drawn))


class ConsistentlyShuffledDataset(tDatasetWrapper):
    '''Present the wrapped dataset in a fixed pseudo-random order determined by a seed.'''
    def __init__(self, original_data: tDataset, rng_seed: int = 42):
        """
        Args:
            original_data: Dataset to shuffle.
            rng_seed: Seed of the ``numpy`` random state that draws the permutation.
        """
        super(ConsistentlyShuffledDataset, self).__init__(original_data)

        # generate the dataset index as per the seed
        rng_state = np.random.RandomState(seed=rng_seed)
        self.indices = rng_state.choice(np.arange(0, len(self.data), dtype=np.uint64), size=len(self.data), replace=False) # sample indices without replacement into an array

    def __getitem__(self, item):
        # unwrap_self is a property: index its value, do not call it.
        # .item() turns the numpy scalar into a plain Python int.
        return self.unwrap_self[self.indices[item].item()]

    def __len__(self):
        return len(self.unwrap_self)


class EncodedDatasetWrapper(tDatasetWrapper):
    '''Serve samples of the wrapped dataset after passing them through an encoder.

    Two modes are available: ``encode_data`` encodes every sample once and
    stores the results, ``encode_data_lazy`` stores the encoder and applies it
    on each access. Stored results take precedence when both are set.
    '''
    def __init__(self, original_data: tDataset):
        super(EncodedDatasetWrapper, self).__init__(original_data)

        self.transformed_data = None
        # if a sample is a tuple, the 'encoder' callable should be aware of that
        self.encoder = None

    def encode_data(
        self,
        encoder: Callable
    ):
        """Encode every sample now and keep the results in memory."""
        # Read from the directly wrapped dataset, not from the innermost one,
        # so that intermediate wrappers (e.g. subsets) are respected.
        source = self.unwrap_self
        self.transformed_data = [
            encoder(source[i]) for i in range(len(source))
        ]

    def encode_data_lazy(
        self,
        encoder: Callable
    ):
        """Store ``encoder`` and apply it to each sample when it is accessed."""
        self.encoder = encoder

    def __getitem__(self, item):
        # `is not None` rather than truthiness: an empty encoded list is still
        # a valid, already-encoded dataset
        if self.transformed_data is not None:
            return self.transformed_data[item]
        elif self.encoder is not None:
            return self.encoder(self.unwrap_self[item])
        else:
            raise AttributeError('The data is neither statically transformed nor a transformation function is specified.')

    def __len__(self):
        if self.transformed_data is not None:
            return len(self.transformed_data)
        elif self.encoder is not None:
            return len(self.unwrap_self)
        else:
            raise AttributeError(
                'The data is neither statically transformed nor a transformation function is specified.')


class CachingDatasetWrapper(tDatasetWrapper):
    '''Cache samples of the wrapped dataset on first access, converted to a device/dtype.

    The cache is a plain dict keyed by the index that was requested; nothing
    is ever evicted. ``max_cache_size`` and ``verbose`` are accepted but
    currently unused.
    '''
    def __init__(self, original_data: tDataset, device = None, dtype = None, max_cache_size = -1, verbose=False):
        super(CachingDatasetWrapper, self).__init__(original_data)

        self.cache = {}
        self.device = device
        self.dtype = dtype

    def __len__(self):
        # unwrap_self is a property, so it must not be called
        return len(self.unwrap_self)

    def __getitem__(self, item):
        """Return the cached sample, fetching and converting it on the first request.

        Raises:
            NotImplementedError: If the sample is a mapping, or is neither a
                tensor nor an iterable.
        """
        if item not in self.cache:
            # fetch the entry, convert to the required dtype and store on the required device
            dentry = self.unwrap_self[item]
            if isinstance(dentry, torch.Tensor):
                self.cache[item] = dentry.to(device=self.device, dtype=self.dtype)
            elif isinstance(dentry, Mapping):
                raise NotImplementedError()
            elif isinstance(dentry, Iterable):
                # tensors are converted (all of them to self.dtype), other
                # elements are stored unchanged; the sample comes back as a list
                self.cache[item] = [e.to(device=self.device, dtype=self.dtype) if isinstance(e, torch.Tensor) else e for e in dentry]
            else:
                raise NotImplementedError()
        return self.cache[item]


import pickle


class CheckpointingWrapper(tDatasetWrapper):
    '''Materialise a dataset into a list and persist it to disk with pickle.

    With a dataset, every sample is read once and written to ``path``. With
    ``ds=None``, the samples are read back from ``path`` instead.
    '''
    def __init__(self, ds, path):
        super().__init__(ds)
        wrapped = self.unwrap_self
        if wrapped is None:
            # nothing to materialise: restore a previously written checkpoint
            self.load_checkpoint(path)
            return

        collected = []
        for entry in wrapped:
            collected.append(entry)
        self.items = collected
        self.dump_checkpoint(path)

    def dump_checkpoint(self, path):
        """Write the materialised samples to ``path``."""
        with open(path, 'wb') as handle:
            pickle.dump(self.items, file=handle)

    def load_checkpoint(self, path):
        """Read samples from ``path``; a pickled dict is reduced to the list of its values."""
        # NOTE(review): pickle.load can execute arbitrary code embedded in the
        # file -- only load checkpoints that come from a trusted source.
        with open(path, 'rb') as handle:
            loaded = pickle.load(handle)
        self.items = list(loaded.values()) if isinstance(loaded, dict) else loaded

    def __getitem__(self, item):
        # shallow copy, to prevent the subsequent wrappers from modifying the stored data
        stored = self.items[item]
        return copy(stored)

    def __len__(self):
        return len(self.items)


class CondSubsetWrapper(SubsetWrapper):
    '''Subset holding the samples for which a filter callable returns a truthy value.'''
    def __init__(self, ds: tDataset, filt_fn: Callable):
        """
        Args:
            ds: Dataset to filter; every sample is fetched once, in index order.
            filt_fn: Called with a sample; the sample's index is kept when the
                result is truthy.
        """
        kept = set()
        for position in range(len(ds)):
            if filt_fn(ds[position]):
                kept.add(position)
        super().__init__(ds, kept)


from concurrent.futures import ThreadPoolExecutor

class CondSubsetWrapperParallel(SubsetWrapper):
    '''Select a subset based on a callable filter function, evaluated on a thread pool.'''
    def __init__(self, ds: tDataset, filt_fn: Callable, n_threads = 8):
        """
        Args:
            ds: Dataset to filter.
            filt_fn: Called with a sample; the sample's index is kept when the
                result is truthy.
            n_threads: Number of worker threads.

        ``ds[i]`` and ``filt_fn`` are both invoked from the worker threads, so
        they have to tolerate being called concurrently.
        """
        def keep(index):
            return bool(filt_fn(ds[index]))

        with ThreadPoolExecutor(max_workers=n_threads) as tp:
            # Executor.map needs the iterable of inputs and yields plain
            # results (not futures), in input order.
            verdicts = list(tp.map(keep, range(len(ds))))

        idxs = [i for i, selected in enumerate(verdicts) if selected]
        super(CondSubsetWrapperParallel, self).__init__(ds, idxs)


#############################
####
#############################


class CachedDataLoader(tDataLoader):
    '''Iterate a data loader once, keep every converted batch in memory, and replay them.

    The source loader is consumed completely inside ``__init__``. Every later
    iteration yields the same stored batches in the same order. Attributes
    this object does not define are looked up on the source loader.
    '''
    def __init__(self, dl: tDataLoader, device = None, dtype = None, verbose=False):
        """
        Args:
            dl: Loader (or any iterable of batches) to cache.
            device: Device the cached tensors are moved to.
            dtype: Dtype that floating-point tensors are cast to; other
                tensors keep their dtype.
            verbose: Show a tqdm progress bar while the cache is filled.
        """
        self.unwrapped = dl
        self.device = device
        self.dtype = dtype
        batches = tqdm(self.unwrapped) if verbose else self.unwrapped
        self.cache = [
            any_to_device_nested(e, respect_floats=True, device=device, dtype=dtype)
            for e in batches
        ]

    def __iter__(self):
        return iter(self.cache)

    def __len__(self):
        # number of batches actually cached
        return len(self.cache)

    def __getattr__(self, item):
        # Read the source loader from the instance dict so that a lookup made
        # before `unwrapped` is assigned raises AttributeError instead of
        # recursing through __getattr__.
        try:
            inner = self.__dict__['unwrapped']
        except KeyError:
            raise AttributeError(item) from None
        # getattr (unlike a direct __getattribute__ call) also honours the
        # source's own __getattr__, which keeps stacked wrappers working
        return getattr(inner, item)


class InfiniteDataLoaderWrapper(tDataLoader):
    '''
    Restart the source loader whenever it is exhausted, so that iteration runs
    for ``max_n`` batches, or without end when ``max_n`` is negative (the default).

    DO NOT STACK BEFORE CACHING WRAPPER! CachedDataLoader consumes its source
    completely on construction, which never finishes for an unbounded loader.
    '''
    def __init__(self, source_dl: tDataLoader, max_n = -1):
        """
        Args:
            source_dl: Loader (or any re-iterable of batches) to cycle through.
            max_n: Total number of batches to yield per iteration; a negative
                value means no limit.
        """
        self.dl = source_dl
        self.max_n = max_n

    def __iter__(self):
        return self.gen()

    def __len__(self):
        if self.max_n < 0:
            # len() may not return a negative number; TypeError is the signal
            # for an object without a length
            raise TypeError('InfiniteDataLoaderWrapper without max_n has no length')
        return self.max_n

    def gen(self):
        """Yield batches as lists, starting a new pass over the source when one ends."""
        unbounded = self.max_n < 0
        iters_made = 0
        while unbounded or iters_made < self.max_n:
            produced_any = False
            for batch in self.dl:
                if not unbounded and iters_made >= self.max_n:
                    return
                yield list(batch)
                iters_made += 1
                produced_any = True
            if not produced_any:
                # an empty source would otherwise be restarted forever
                return


class IndexedDatasetWrapper(tDataset):
    '''Attach the dataset index to every sample of the wrapped dataset.

    List and tuple samples come back as a list with the index appended; dict
    samples come back as a copy with the index stored under ``'ds_index'``.
    '''
    def __init__(self, original_data: tDataset):
        self.unwrapped = original_data

    def __len__(self):
        return len(self.unwrapped)

    def __getitem__(self, item):
        """
        Raises:
            NotImplementedError: If the sample is not a list, tuple or dict.
        """
        # get the item from the parent dataset
        dpoint = self.unwrapped[item]

        if isinstance(dpoint, List) or isinstance(dpoint, Tuple):
            return list(dpoint)+[item,]
        elif isinstance(dpoint, Dict):
            # work on a shallow copy: writing into the dict returned by the
            # parent would modify the parent's stored sample when it hands
            # out its own objects
            indexed = copy(dpoint)
            indexed['ds_index'] = item
            return indexed
        else:
            raise NotImplementedError("only datasets providing dictionaries or lists/tuples are supported")



from collections import Counter


def count_label_balance(dataset):
    """
    Count the label balance in a Dataset class.

    The label of a tuple or list sample is its last element; the label of a
    dict sample is read from the 'target' key, falling back to 'label'.
    Labels may be Python ints, numpy integers, or zero-dimensional integer
    tensors / arrays.

    Args:
        dataset (Dataset): The dataset to count the label balance for.

    Returns:
        Counter: A Counter object with the label values as keys and their corresponding counts as values.

    Raises:
        ValueError: If a sample has an unsupported type, a dict sample carries
            no label, or a label is not an integer.
    """
    import numbers  # local import: only this function needs it

    labels = []
    for item in dataset:
        if isinstance(item, (tuple, list)):
            label = item[-1]
        elif isinstance(item, dict):
            label = item.get('target', item.get('label'))
            if label is None:
                raise ValueError("Dictionary item must have either 'target' or 'label' key")
        else:
            raise ValueError("Unsupported item type")
        # zero-dimensional tensors / arrays hold exactly one value: unwrap it
        if isinstance(label, (torch.Tensor, np.ndarray)) and label.ndim == 0:
            label = label.item()
        if not isinstance(label, numbers.Integral):
            raise ValueError("Label must be an integer")
        # numpy integers are normalised to plain ints; Python ints pass through untouched
        labels.append(label if isinstance(label, int) else int(label))
    return Counter(labels)


def stratified_sample_subset(dataset, n, seed=None):
    """
    Pick up to n samples of each class from a PyTorch Dataset in a stratified fashion.

    The label of a tuple or list sample is its last element; the label of a
    dict sample is read from the 'target' key, falling back to 'label'.
    Zero-dimensional tensor / array labels are unwrapped to plain Python
    values so that equal labels end up in the same class.

    Args:
        dataset (torch.utils.data.Dataset): The dataset to sample from.
        n (int): The number of samples to pick from each class; classes with
            fewer samples contribute all of them.
        seed (int): seed for shuffling (optional, no shuffling if not specified)
    Returns:
        SubsetWrapper: A subset of the original dataset with the sampled
        indices (the wrapper stores them in ascending order).
    Raises:
        ValueError: If a sample is not a tuple, list or dict.
    """
    class_indices = {}
    for i, item in enumerate(dataset):
        if isinstance(item, (tuple, list)):
            label = item[-1]
        elif isinstance(item, dict):
            label = item.get('target', item.get('label'))
        else:
            # without this branch `label` would be unbound, or silently
            # carried over from the previous sample
            raise ValueError("Unsupported item type")
        # tensors hash by identity, so scalar tensor labels must be unwrapped
        # before they can be used as grouping keys
        if isinstance(label, (torch.Tensor, np.ndarray)) and label.ndim == 0:
            label = label.item()
        class_indices.setdefault(label, []).append(i)

    # Shuffle the lists of indices for each class, make it reproducible
    if seed is not None:
        for indices in class_indices.values():
            # a fresh state per class: every class is shuffled from the same seed
            state = np.random.RandomState(seed=seed)
            state.shuffle(indices)

    # Pick top n samples from each class and merge them into one
    sampled_indices = []
    for indices in class_indices.values():
        sampled_indices.extend(indices[:n])

    return SubsetWrapper(dataset, sampled_indices)
