from abc import abstractmethod
from argparse import Namespace
from torch import nn as nn

import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler

from torchvision import datasets
import numpy as np
from torch.utils.data._utils.collate import default_collate


class ContinualDataset:
    """
    Base description of a continual-learning evaluation setting.

    Concrete settings override the class-level metadata below and implement
    the hook methods. This class does not derive from ``abc.ABC``, so the
    ``@abstractmethod`` markers are documentation only: the base class can be
    instantiated, and any hook a subclass leaves un-overridden returns ``None``.
    """
    # Class-level metadata, expected to be overridden by concrete settings.
    NAME = None
    SETTING = None
    N_CLASSES_PER_TASK = None
    N_TASKS = None
    TRANSFORM = None

    def __init__(self, args: Namespace):
        """
        Set up empty loader bookkeeping for a fresh run.

        :param args: parsed hyperparameters; stored unchanged on ``self.args``
        """
        self.args = args
        # Offset of the first class label of the upcoming task; starts at 0.
        self.i = 0
        self.train_loaders = []
        self.validation_loaders = []
        self.train_loader = None

    @abstractmethod
    def get_data_loaders(self):
        """
        Build the loaders for the current task.

        Implementations are expected to record the current training loader
        and the evaluation loaders on ``self``.
        :return: the loaders for the current task
        """
        ...

    @abstractmethod
    def not_aug_dataloader(self, batch_size: int):
        """
        Build a loader over the current task's training data without
        data augmentation.
        :param batch_size: number of samples per batch
        :return: the non-augmented training loader
        """
        ...

    @staticmethod
    @abstractmethod
    def get_backbone():
        """
        Return the network backbone used for this setting.
        """
        ...

    @staticmethod
    @abstractmethod
    def get_transform():
        """
        Return the input transform used for this setting.
        """
        ...

    @staticmethod
    @abstractmethod
    def get_loss():
        """
        Return the loss function used for this setting.
        """
        ...

    @staticmethod
    @abstractmethod
    def get_normalization_transform():
        """
        Return the transform that normalizes inputs for this setting.
        """
        ...

    @staticmethod
    @abstractmethod
    def get_denormalization_transform():
        """
        Return the transform that undoes the normalization for this setting.
        """
        ...

def collate_fn(batch):
    """
    Collate ``(inputs, label)`` samples where ``inputs`` may be a tuple with
    optional ``None`` slots.

    If the first sample's ``inputs`` is not a tuple, the whole batch is passed
    to ``default_collate`` unchanged. Otherwise every tuple slot is collated
    on its own: a slot that is ``None`` in the FIRST sample is emitted as
    ``None`` (later samples are not inspected for that slot), and the labels
    are collated last.

    :param batch: non-empty list of ``(inputs, label)`` pairs
    :return: the ``default_collate`` result, or in the tuple case a list
        ``[slot_0, ..., slot_k, labels]``
    """
    first_inputs = batch[0][0]
    if not isinstance(first_inputs, tuple):
        return default_collate(batch)

    # Collate slot-by-slot across samples; None slots cannot be collated, so
    # they are forwarded as a placeholder to keep slot positions stable.
    collated = [
        None if slot is None
        else default_collate([sample[0][slot_idx] for sample in batch])
        for slot_idx, slot in enumerate(first_inputs)
    ]
    collated.append(default_collate([sample[1] for sample in batch]))
    return collated

def _task_class_mask(targets, setting: ContinualDataset) -> np.ndarray:
    """Return a boolean mask of the entries of *targets* in the current task's label range.

    The range is the half-open interval
    ``[setting.i, setting.i + setting.N_CLASSES_PER_TASK)``.
    """
    # Convert once instead of once per comparison.
    labels = np.asarray(targets)
    return np.logical_and(labels >= setting.i,
                          labels < setting.i + setting.N_CLASSES_PER_TASK)


def _distributed_loader(dataset, batch_size: int, num_workers: int) -> DataLoader:
    """Wrap *dataset* in a shuffling DataLoader sharded across the default process group."""
    sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(),
                                 rank=dist.get_rank(), shuffle=True)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler,
                      num_workers=num_workers, pin_memory=True, drop_last=True,
                      collate_fn=collate_fn)


def store_masked_loaders(train_dataset: datasets, validation_dataset: datasets,
                         setting: ContinualDataset, labels_to_paths=False,
                         *, num_workers: int = 8):
    """
    Restrict both datasets to the current task's classes and build distributed loaders.

    A sample is kept when its entry in ``<dataset>.targets`` lies in
    ``[setting.i, setting.i + setting.N_CLASSES_PER_TASK)``.

    Side effects:
      * ``train_dataset.samples`` and ``validation_dataset.samples`` are
        replaced by the filtered rows, as numpy arrays.
      * The new loaders are recorded on *setting* (``train_loaders``,
        ``train_loader``, ``validation_loaders``).
      * ``setting.i`` is advanced by ``setting.N_CLASSES_PER_TASK``.

    Requires an initialized default ``torch.distributed`` process group,
    because ``dist.get_world_size()`` and ``dist.get_rank()`` are queried.

    :param train_dataset: dataset exposing ``samples`` and ``targets``
    :param validation_dataset: dataset exposing ``samples`` and ``targets``
    :param setting: the continual setting whose current task is being loaded
    :param labels_to_paths: if true, overwrite column 1 of ``samples`` with
        column 0 (presumably label -> path for ``(path, label)`` rows; confirm)
    :param num_workers: worker processes per DataLoader (default 8, the
        previously hard-coded value)
    :return: ``(train_loader, validation_loader)``
    """
    train_mask = _task_class_mask(train_dataset.targets, setting)
    validation_mask = _task_class_mask(validation_dataset.targets, setting)
    # NOTE(review): only ``samples`` is filtered; ``targets`` keeps its full
    # length and no longer lines up with ``samples``.
    # NOTE(review): np.array over mixed (str, int) rows picks one common dtype
    # (typically a unicode string), so labels may come back as strings.
    train_dataset.samples = np.array(train_dataset.samples)[train_mask]
    validation_dataset.samples = np.array(validation_dataset.samples)[validation_mask]

    if labels_to_paths:
        train_dataset.samples[:, 1] = train_dataset.samples[:, 0]
        validation_dataset.samples[:, 1] = validation_dataset.samples[:, 0]

    batch_size = setting.args.batch_size
    train_loader = _distributed_loader(train_dataset, batch_size, num_workers)
    # NOTE(review): the validation loader also shuffles and uses drop_last=True,
    # so up to batch_size - 1 samples per rank are skipped during evaluation.
    # DistributedSampler only reshuffles across epochs if set_epoch() is called.
    validation_loader = _distributed_loader(validation_dataset, batch_size, num_workers)

    setting.train_loaders.append(train_loader)
    setting.train_loader = train_loader
    setting.validation_loaders.append(validation_loader)

    setting.i += setting.N_CLASSES_PER_TASK
    return train_loader, validation_loader

