import os
import random

import numpy as np

try:
    from . import utils
    from . import generator
except (ImportError, SystemError):
    # This module was not imported as part of a package (e.g. it is run directly
    # as a script), so the relative imports fail. Python 3.6+ raises ImportError
    # here; older Python 3 releases raised SystemError. Fall back to absolute
    # imports from this file's own directory. The *directory* must go on
    # sys.path; a path to the .py file itself cannot be used to find modules.
    import sys
    sys.path.append(os.path.dirname(os.path.abspath(__file__)))
    import utils
    import generator

def default_collate(data_list):
    """Merge a sequence of ``(input, label)`` pairs into one batch.

    Args:
        data_list: Iterable of 2-item ``(input, label)`` pairs.

    Returns:
        ``(inputs, labels)``, each produced by ``np.stack(..., axis=0)``, so the
        new leading axis indexes the samples in their original order.

    Raises:
        ValueError: from ``np.stack`` when ``data_list`` is empty or when the
            inputs (or labels) cannot be stacked because their shapes differ.
    """
    pairs = [(inp, label) for inp, label in data_list]
    inputs = np.stack([pair[0] for pair in pairs], axis=0)
    labels = np.stack([pair[1] for pair in pairs], axis=0)
    return inputs, labels

class TaskContinuumDataset:
    """A sequence of tasks, each holding a train split and a test split.

    On construction the continuum is built via ``generator.build_dataset`` when
    ``utils.exists_continuum_data`` reports it missing under ``root_path``. It is
    then loaded with ``utils.get_continuum_data`` and split into train and test
    task lists with ``utils.decode_data``.

    Every task is indexed as ``task[0]`` (the task's transform), ``task[1]``
    (inputs) and ``task[2]`` (targets).

    Args:
        raw_data: Raw source dataset identifier, forwarded to ``utils``/``generator``.
        transform: Transformation identifier, forwarded to ``utils``/``generator``.
        num_tasks: Number of tasks, forwarded to ``utils``/``generator``.
        seed: Seed forwarded to ``utils``/``generator``. NOTE: the task shuffle
            below uses the global ``random`` module state, not this seed.
        root_path: Location passed to ``utils``/``generator``; defaults to
            ``utils.root_path``.
        shuffle_tasks: If true, reorder the tasks with one random permutation,
            applied identically to the train and test lists so they stay paired.
        train_samples_per_task: If given, keep only the first N training
            inputs/targets of every task.
        data_type: Reserved; only ``None`` is supported.

    Raises:
        ValueError: if ``data_type`` is not ``None``, if the train and test task
            counts differ, or if ``train_samples_per_task`` is negative or exceeds
            the number of training inputs in any task.
    """

    def __init__(self, raw_data, transform, num_tasks, seed, root_path=utils.root_path,
                 shuffle_tasks=True, train_samples_per_task=None, data_type=None):
        # Fail fast, before any (potentially expensive) generation or loading.
        if data_type is not None:
            raise ValueError("Not implemented yet.")

        self.root_path = root_path
        self.transform = transform

        # Build the continuum the first time this configuration is requested.
        if not utils.exists_continuum_data(self.root_path,
                                           raw_data,
                                           transform,
                                           num_tasks,
                                           seed):
            arguments = generator.conver_to_argument(root_path=self.root_path,
                                                     raw_data=raw_data,
                                                     transform=transform,
                                                     num_tasks=num_tasks,
                                                     seed=seed)
            generator.build_dataset(arguments)

        data = utils.get_continuum_data(self.root_path,
                                        raw_data,
                                        transform,
                                        num_tasks,
                                        seed)
        train_task_continuum = utils.decode_data(data, 'train')
        test_task_continuum = utils.decode_data(data, 'test')

        if len(train_task_continuum) != len(test_task_continuum):
            raise ValueError("Wrong dataset, the numbers of train and test tasks are not matched.")

        if shuffle_tasks:
            # One shared permutation keeps train task i paired with test task i.
            permutation = list(range(len(train_task_continuum)))
            random.shuffle(permutation)

            train_task_continuum = [train_task_continuum[i] for i in permutation]
            test_task_continuum = [test_task_continuum[i] for i in permutation]

        if train_samples_per_task is not None:
            self._truncate_train_tasks(train_task_continuum, train_samples_per_task)

        # TODO: sanity-check each train/test pair (matching transform, non-empty data).
        self.train_tasks = train_task_continuum
        self.test_tasks = test_task_continuum

    @staticmethod
    def _truncate_train_tasks(tasks, num_samples):
        """Keep only the first ``num_samples`` inputs and targets of every task, in place.

        Raises:
            ValueError: if ``num_samples`` is negative or larger than the number
                of inputs in the smallest task.
        """
        if num_samples < 0:
            raise ValueError(
                "train_samples_per_task must be non-negative, got {!r}".format(num_samples))
        # Validate against every task, not just the first, so no task is silently
        # left with fewer samples than requested.
        shortest = min((len(task[1]) for task in tasks), default=None)
        if shortest is not None and num_samples > shortest:
            raise ValueError(
                "train_samples_per_task ({}) exceeds the number of training samples "
                "in the smallest task ({})".format(num_samples, shortest))
        for task in tasks:
            # Item assignment requires each task to be a mutable sequence (e.g. a list).
            task[1] = task[1][:num_samples]
            task[2] = task[2][:num_samples]

    def __getitem__(self, idx):
        """Return ``(transform, train_data, test_data)`` for task ``idx``.

        ``train_data`` and ``test_data`` are lists of ``(input, target)`` pairs.
        An out-of-range ``idx`` raises IndexError from the underlying list, which
        also terminates ``for ... in dataset`` loops.
        """
        train_task = self.train_tasks[idx]
        test_task = self.test_tasks[idx]

        train_data = list(zip(train_task[1], train_task[2]))
        test_data = list(zip(test_task[1], test_task[2]))

        return train_task[0], train_data, test_data

    def __len__(self):
        """Number of tasks in the continuum."""
        return len(self.train_tasks)

# Minimal data-loading utilities (sampler, batch sampler, loader) for iterating over one task's data
class RandomSampler:
    """Yield indices into ``data_source`` in random order.

    Args:
        data_source: Sized collection; only ``len(data_source)`` is used.
        num_samples: How many indices one pass yields. Defaults to
            ``len(data_source)``, which gives exactly one random permutation.
            A smaller value truncates that permutation. A larger value chains
            additional independent permutations, so indices may repeat.

    Randomness comes from the global ``random`` module, so ``random.seed``
    controls the order.

    Raises:
        ValueError: if ``num_samples`` is negative.
    """

    def __init__(self, data_source, num_samples=None):
        if num_samples is not None and num_samples < 0:
            raise ValueError(
                "num_samples should be a non-negative integer, got {!r}".format(num_samples))
        self.data_source = data_source
        self._num_samples = num_samples

    def num_samples(self):
        """Return how many indices one pass of ``__iter__`` is meant to yield."""
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self):
        n = len(self.data_source)
        remaining = self.num_samples()
        # Keep drawing fresh permutations until `remaining` indices have been
        # produced, so the pass length matches __len__. An empty data_source
        # yields nothing, because there is no index to draw.
        while remaining > 0 and n > 0:
            permutation = list(range(n))
            random.shuffle(permutation)
            chunk = permutation[:remaining]
            yield from chunk
            remaining -= len(chunk)

    def __len__(self):
        return self.num_samples()


class BatchSampler:
    """Group the indices produced by ``sampler`` into lists of ``batch_size``.

    Args:
        sampler: Sized iterable of indices.
        batch_size: Number of indices per batch; must be positive.
        drop_last: If true, discard a final batch shorter than ``batch_size``.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """

    def __init__(self, sampler, batch_size, drop_last):
        # With batch_size <= 0, __iter__ would never close a batch (emitting one
        # unbounded batch) and __len__ would divide by zero or return nonsense.
        if batch_size <= 0:
            raise ValueError(
                "batch_size should be a positive integer, got {!r}".format(batch_size))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        """Number of batches one pass yields."""
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        # Ceiling division: a trailing partial batch counts as one.
        return (len(self.sampler) + self.batch_size - 1) // self.batch_size


class DataLoader:
    """Iterate over ``dataset`` in randomly ordered mini-batches.

    Args:
        dataset: Indexable, sized collection of samples.
        batch_size: Number of samples per batch.
        sampler: Sampling strategy; only ``'random'`` is supported.
        drop_last: If true, drop the final incomplete batch.
        collate_fn: Callable that merges a list of samples into one batch.

    Raises:
        ValueError: if ``sampler`` is anything other than ``'random'``.
    """

    def __init__(self, dataset, batch_size=1, sampler='random', drop_last=False, collate_fn=default_collate):
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.collate_fn = collate_fn

        if sampler != 'random':
            raise ValueError("No other implementation except for random sampler")

        self.sampler = RandomSampler(dataset)
        self.batch_sampler = BatchSampler(self.sampler, batch_size, drop_last)

    def __iter__(self):
        """Start a new pass; the sampler reshuffles the order on each pass."""
        return _DataLoaderIterator(self)

    def __len__(self):
        """Number of batches per pass."""
        return len(self.batch_sampler)


class _DataLoaderIterator:
    """Single-pass iterator yielding collated batches from a ``DataLoader``."""

    def __init__(self, loader):
        self.dataset, self.batch_sampler, self.collate_fn = (
            loader.dataset, loader.batch_sampler, loader.collate_fn)
        # One fresh pass over the batch sampler per iterator.
        self.sample_iter = iter(self.batch_sampler)

    def __len__(self):
        return len(self.batch_sampler)

    def __next__(self):
        # When the batch sampler is exhausted, its StopIteration ends iteration.
        batch_indices = next(self.sample_iter)
        samples = [self.dataset[index] for index in batch_indices]
        return self.collate_fn(samples)

    def __iter__(self):
        return self

     
if __name__ == "__main__":
    import argparse

    def _str2bool(value):
        """Parse a command-line boolean such as 'true'/'false', 'yes'/'no', '1'/'0'.

        ``type=bool`` cannot be used for this: ``bool('False')`` is True, because
        every non-empty string is truthy.
        """
        lowered = str(value).strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError(
            "expected a boolean value, got {!r}".format(value))

    parser = argparse.ArgumentParser()

    # NOTE: --i, --min_angle and --max_angle are parsed but not read below.
    parser.add_argument('--i', default='raw', help='raw datasets for continual learning')
    parser.add_argument('--num_tasks', default=3, type=int, help='number of different tasks')
    parser.add_argument('--seed', default=0, type=int, help='random seed')
    parser.add_argument('--raw_data', default='mnist', help='raw dataset for using continual learning')
    parser.add_argument('--transform', default='rotation', help='transform transformation')
    parser.add_argument('--min_angle', default=0, type=float, help='minimum rotation angle')
    parser.add_argument('--max_angle', default=90, type=float, help='maximum rotation angle')

    parser.add_argument('--batch_size', default=128, type=int, help="batch size")
    parser.add_argument('--shuffle_tasks', default=True, type=_str2bool, help="shuffle tasks")

    args = parser.parse_args()

    dataset = TaskContinuumDataset(args.raw_data, args.transform,
                                   args.num_tasks, args.seed,
                                   shuffle_tasks=args.shuffle_tasks,
                                   train_samples_per_task=20000)

    for tf, tr_data, te_data in dataset:
        print(tf)
        train_loader = DataLoader(tr_data, batch_size=args.batch_size)
        for _epoch in range(5):
            for i, (inputs, labels) in enumerate(train_loader):
                if i == 20:
                    # NOTE(review): argmax over the last axis assumes one-hot labels — confirm.
                    print(np.argmax(labels, axis=-1))

    