from abc import abstractmethod
from argparse import Namespace
from torch import nn as nn
#from torchvision.transforms import transforms
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
#from typing import Tuple
from torchvision import datasets
import numpy as np
from torch.utils.data._utils.collate import default_collate


class ContinualDataset:
    """
    Base description of a continual-learning evaluation setting.

    Concrete datasets fill in the class-level constants and implement the hook
    methods below. The class does not derive from ``abc.ABC`` / use ``ABCMeta``,
    so the ``@abstractmethod`` markers are documentation only: the base class
    can still be instantiated, and every hook on it simply returns ``None``.
    """
    # Identifier of the concrete dataset (None on the base class).
    NAME = None
    # NOTE(review): presumably the name of the continual-learning scenario — confirm.
    SETTING = None
    # Width of each task's label range; the store_* helpers in this module
    # advance ``self.i`` by this amount after each task.
    N_CLASSES_PER_TASK = None
    # NOTE(review): presumably the total number of tasks — confirm.
    N_TASKS = None
    # Default transform of the dataset (None on the base class).
    TRANSFORM = None

    def __init__(self, args: Namespace):
        """
        Stores the hyperparameters and prepares empty per-task loader registries.
        :param args: the arguments which contains the hyperparameters
        """
        self.args = args
        # Lowest label of the next task to be stored.
        self.i = 0
        # Loader of the most recently stored task (None until one is stored).
        self.train_loader = None
        # Loader lists that the store_* helpers of this module append to.
        self.train_loaders = []
        self.test_loaders = []
        self.memory_loaders = []

    @abstractmethod
    def get_data_loaders(self):
        """
        Builds the training and test loaders of the current task.
        Implementations keep the current training loader and all test loaders on self.
        :return: the current training and test loaders
        """

    @abstractmethod
    def not_aug_dataloader(self, batch_size: int):
        """
        Builds a loader over the current task that skips data augmentation.
        :param batch_size: the batch size of the loader
        :return: the current training loader
        """

    @staticmethod
    @abstractmethod
    def get_backbone():
        """
        Returns the backbone network used with this dataset.
        """

    @staticmethod
    @abstractmethod
    def get_transform():
        """
        Returns the transform used with this dataset.
        """

    @staticmethod
    @abstractmethod
    def get_loss():
        """
        Returns the loss used with this dataset.
        """

    @staticmethod
    @abstractmethod
    def get_normalization_transform():
        """
        Returns the transform that normalizes this dataset.
        """

    @staticmethod
    @abstractmethod
    def get_denormalization_transform():
        """
        Returns the transform that undoes this dataset's normalization.
        """

def collate_fn(batch):
    """
    Collates a list of ``(input, label)`` samples into a batch.

    When the first sample's input is not a tuple, the whole batch is handed to
    ``default_collate`` unchanged. When it is a tuple (e.g. several items per
    sample — NOTE(review): presumably multiple views, confirm with the datasets),
    each tuple position is collated separately across the batch and the
    collated labels are appended last: ``[item_0, item_1, ..., labels]``.
    A position that is ``None`` in the FIRST sample is emitted as ``None``
    instead of being collated; other samples are not inspected.

    :param batch: non-empty list of ``(input, label)`` pairs
    :return: ``default_collate(batch)`` for plain inputs, otherwise a list
    """
    first_input = batch[0][0]
    if not isinstance(first_input, tuple):
        return default_collate(batch)

    inputs = [sample[0] for sample in batch]
    collated = []
    for pos, first_value in enumerate(first_input):
        if first_value is None:
            # Optional slot left empty: keep a placeholder so positions stay aligned.
            collated.append(None)
        else:
            collated.append(default_collate([items[pos] for items in inputs]))
    collated.append(default_collate([sample[1] for sample in batch]))
    return collated

def _current_task_mask(targets, setting) -> np.ndarray:
    """Boolean mask of the entries of *targets* in ``[setting.i, setting.i + setting.N_CLASSES_PER_TASK)``."""
    targets = np.array(targets)
    return np.logical_and(targets >= setting.i,
                          targets < setting.i + setting.N_CLASSES_PER_TASK)


def _make_distributed_loader(dataset, batch_size: int) -> DataLoader:
    """DataLoader sharded over the process group with the settings store_masked_loaders uses for every split."""
    sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(),
                                 rank=dist.get_rank(), shuffle=True)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler,
                      num_workers=8, pin_memory=True, drop_last=True,
                      collate_fn=collate_fn)


def store_masked_loaders(train_dataset: datasets, validation_dataset: datasets, test_dataset: datasets,
                    setting: ContinualDataset, labels_to_paths=False):
    """
    Restricts the datasets to the current task's classes and builds their loaders.

    The current task covers labels ``[setting.i, setting.i + setting.N_CLASSES_PER_TASK)``.
    The datasets are filtered IN PLACE, the new loaders are appended to
    ``setting.train_loaders`` / ``setting.test_loaders`` (and
    ``setting.memory_loaders`` for the validation split), ``setting.train_loader``
    is replaced, and ``setting.i`` is advanced to the next task.

    :param train_dataset: train dataset (filtered in place)
    :param validation_dataset: optional dataset filtered with the TRAIN mask, so it
        must be index-aligned with ``train_dataset``; ``None`` skips it
    :param test_dataset: test dataset (filtered in place)
    :param setting: continual learning setting
    :param labels_to_paths: seq-imagenet only; overwrite column 1 of every sample
        row with column 0
    :return: ``(train_loader, validation_loader or None, test_loader)``
    """
    if setting.args.dataset.name == 'seq-imagenet':
        # ImageFolder-style datasets: filter the rows of ``samples``.
        # NOTE(review): assumes ``samples`` is a list of (path, label) pairs and
        # ``targets`` is aligned with it; ``targets`` itself is left unfiltered.
        train_mask = _current_task_mask(train_dataset.targets, setting)
        test_mask = _current_task_mask(test_dataset.targets, setting)
        train_dataset.samples = np.array(train_dataset.samples)[train_mask]
        test_dataset.samples = np.array(test_dataset.samples)[test_mask]

        if labels_to_paths:
            train_dataset.samples[:, 1] = train_dataset.samples[:, 0]
            test_dataset.samples[:, 1] = test_dataset.samples[:, 0]
        elif setting.args.cl_model == 'DER':
            # DER: replace each train label with "<column 0>,<column 1>".
            # The string dtype must be wide enough for the combined value; a fixed
            # '<U80' silently truncated longer entries (cutting off the label).
            samples = train_dataset.samples
            combined = [f'{first},{second}' for first, second in
                        zip(samples[:, 0].tolist(), samples[:, 1].tolist())]
            width = max([80] + [len(text) for text in combined])
            samples = samples.astype(f'<U{width}')
            samples[:, 1] = combined
            train_dataset.samples = samples
    else:
        # Array-backed datasets: filter ``data`` and ``targets`` together.
        train_mask = _current_task_mask(train_dataset.targets, setting)
        test_mask = _current_task_mask(test_dataset.targets, setting)

        train_dataset.data = train_dataset.data[train_mask]
        test_dataset.data = test_dataset.data[test_mask]

        train_dataset.targets = np.array(train_dataset.targets)[train_mask]
        test_dataset.targets = np.array(test_dataset.targets)[test_mask]

        if setting.args.few_shot:
            # Keep only the first ``data_size_per_task`` training samples of this task.
            keep = setting.args.data_size_per_task
            train_dataset.data = train_dataset.data[:keep]
            train_dataset.targets = train_dataset.targets[:keep]

    # NOTE(review): every split, including test, uses a shuffling sampler with
    # drop_last=True, so up to batch_size - 1 samples per rank are skipped on each
    # pass; confirm this is intended for evaluation.
    train_loader = _make_distributed_loader(train_dataset, setting.args.batch_size)
    test_loader = _make_distributed_loader(test_dataset, setting.args.batch_size)

    if validation_dataset is None:
        validation_loader = None
    else:
        # NOTE(review): filters ``.data``/``.targets``, which the seq-imagenet branch
        # does not use; this path presumably only runs for array-backed datasets.
        # Few-shot truncation above is not applied here.
        validation_dataset.data = validation_dataset.data[train_mask]
        validation_dataset.targets = np.array(validation_dataset.targets)[train_mask]
        validation_loader = _make_distributed_loader(validation_dataset, setting.args.batch_size)
        setting.memory_loaders.append(validation_loader)

    setting.train_loaders.append(train_loader)
    setting.train_loader = train_loader
    setting.test_loaders.append(test_loader)

    setting.i += setting.N_CLASSES_PER_TASK
    return train_loader, validation_loader, test_loader


def store_masked_label_loaders(train_dataset: datasets, memory_dataset: datasets, test_dataset: datasets,
                    setting: ContinualDataset):
    """
    Restricts ``labels``-based datasets to the current task's classes and builds their loaders.

    The current task covers labels ``[setting.i, setting.i + setting.N_CLASSES_PER_TASK)``.
    ``data`` and ``labels`` are filtered IN PLACE and kept in sync. Datasets whose
    ``__getitem__`` reads ``self.labels[index]`` (as torchvision's SVHN and STL10 do)
    would otherwise pair filtered data with unfiltered labels. ``targets`` is also
    set to the filtered labels for callers that read it.

    :param train_dataset: train dataset (filtered in place)
    :param memory_dataset: dataset filtered with the TRAIN mask, so it must be
        index-aligned with ``train_dataset``
    :param test_dataset: test dataset (filtered in place)
    :param setting: continual learning setting
    :return: ``(train_loader, memory_loader, test_loader)``
    """
    low = setting.i
    high = setting.i + setting.N_CLASSES_PER_TASK

    train_labels = np.array(train_dataset.labels)
    test_labels = np.array(test_dataset.labels)
    train_mask = np.logical_and(train_labels >= low, train_labels < high)
    test_mask = np.logical_and(test_labels >= low, test_labels < high)

    train_dataset.data = train_dataset.data[train_mask]
    train_dataset.labels = train_labels[train_mask]
    train_dataset.targets = train_dataset.labels

    test_dataset.data = test_dataset.data[test_mask]
    test_dataset.labels = test_labels[test_mask]
    test_dataset.targets = test_dataset.labels

    memory_dataset.data = memory_dataset.data[train_mask]
    memory_dataset.labels = np.array(memory_dataset.labels)[train_mask]
    memory_dataset.targets = memory_dataset.labels

    train_loader = DataLoader(train_dataset,
                              batch_size=setting.args.batch_size, shuffle=True, num_workers=4)
    test_loader = DataLoader(test_dataset,
                             batch_size=setting.args.batch_size, shuffle=False, num_workers=4)
    memory_loader = DataLoader(memory_dataset,
                               batch_size=setting.args.batch_size, shuffle=False, num_workers=4)

    setting.test_loaders.append(test_loader)
    setting.train_loaders.append(train_loader)
    setting.memory_loaders.append(memory_loader)
    setting.train_loader = train_loader

    setting.i += setting.N_CLASSES_PER_TASK
    return train_loader, memory_loader, test_loader


def get_previous_train_loader(train_dataset: datasets, batch_size: int,
                              setting: ContinualDataset):
    """
    Builds a shuffled loader over the classes of the previous task.

    Selected labels are ``[setting.i - N_CLASSES_PER_TASK, setting.i)``, i.e. the
    task just before ``setting.i`` (NOTE(review): assumes ``setting.i`` was already
    advanced, e.g. by ``store_masked_loaders``). ``train_dataset`` is filtered IN PLACE.
    :param train_dataset: the entire training set
    :param batch_size: the desired batch size
    :param setting: the continual dataset at hand
    :return: a dataloader
    """
    per_task = setting.N_CLASSES_PER_TASK
    first_label = setting.i - per_task
    end_label = first_label + per_task

    labels = np.array(train_dataset.targets)
    keep = (labels >= first_label) & (labels < end_label)

    train_dataset.data = train_dataset.data[keep]
    train_dataset.targets = labels[keep]

    return DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
