from argparse import Namespace
from typing import Tuple
import logging
from copy import deepcopy

import numpy as np
import torch.nn as nn
import torch.optim
from torch.utils.data import DataLoader, Dataset, ConcatDataset, Subset
from .datasets import (
    KITCHEN, AVE, DKD, UESTC_MMEA
)
from .samplers import MemoryMultiplerSampler

logger = logging.getLogger(__name__)


# +
class IncrementalDataset:
    """
    Continual learning evaluation setting.

    Hands out the tasks of a benchmark one at a time (``new_task``) and keeps
    an exemplar memory of samples retained from the tasks seen so far
    (``update_exemplar`` / ``update_exemplar_by_herding``). The benchmark type
    is read from the dataset's ``TYPE`` attribute; 'CIL' (ordered by
    ``CLASS_ORDER``) and 'DIL' (ordered by ``DOMAIN_ORDER``) are handled.
    """

    def __init__(self, args, run_id):
        """
        Initializes the incremental state; no data is loaded until ``new_task``.

        :param args: dict-like hyperparameters. Keys read by this class are
            "dataset", "batch_size", "workers", "all_test_data" and the
            optional "memory_size" (default 0) and "memory_sample_batch"
            (default 1). The whole mapping is also forwarded to the dataset.
        :param run_id: index of the run; selects the learning order that is
            logged and is passed to the dataset as the first element of
            ``task_id``.
        """
        self.args = args
        self.task = 0
        self.seen = 0
        self.dataset = _get_dataset(self.args["dataset"])
        self.run_id = run_id

        self.type = self.dataset.TYPE
        self.n_classes = self.dataset.N_CLASSES

        self.n_tasks = self.dataset.N_TASKS
        self.n_classes_per_task = self.dataset.N_CLASSES_PER_TASK

        if self.type == 'CIL':
            logger.info(f'{self.type} learning order: {self.dataset.CLASS_ORDER[run_id]}')
        else:
            logger.info(f'{self.type} learning order: {self.dataset.DOMAIN_ORDER[run_id]}')
        self.memory_size = args.get("memory_size", 0)
        # None until the first exemplar update, afterwards a list with one
        # dataset per stored task.
        self.exemplar_dataset = None
        self.exemplar_sampler = MemoryMultiplerSampler

    def new_task(self):
        """
        Builds the datasets and loaders of the next task and advances ``self.task``.

        The training dataset is joined with every stored exemplar set; a deep
        copy taken *before* that join is kept as ``self.cur_train_dataset``.
        The test dataset either covers all tasks (``args["all_test_data"]``)
        or accumulates the test splits of the tasks seen so far.

        :return: ``(task_info, train_loader, test_loader, test_train_loader)``
            where ``task_info`` holds the task index ('task') and the largest
            training target plus one ('max_class'), and ``test_train_loader``
            iterates the current training split without augmentation.
        """
        task_id = (self.run_id, self.task)
        train_dataset = self.dataset(self.args, task_id=task_id, train=True, augment=True)
        test_train_dataset = self.dataset(self.args, task_id=task_id, train=True, augment=False)
        self.cur_train_dataset = deepcopy(train_dataset)

        if self.args['all_test_data']:
            if self.task == 0:
                # Built once on the first task and reused unchanged afterwards.
                self.cur_test_dataset = ConcatDataset([
                    self.dataset(self.args, task_id=(self.run_id, t), train=False, augment=False)
                    for t in range(self.n_tasks)
                ])
            test_dataset = self.cur_test_dataset
        else:
            test_dataset = self.dataset(self.args, task_id=task_id, train=False, augment=False)
            if self.task > 0:
                test_dataset = ConcatDataset([self.cur_test_dataset, test_dataset])
            self.cur_test_dataset = test_dataset

        batch_size = self.args["batch_size"]
        workers = self.args["workers"]

        sampler = None
        if isinstance(self.exemplar_dataset, list):
            for exemplars in self.exemplar_dataset:
                train_dataset.join(exemplars)

            if self.exemplar_sampler is not None:
                sampler = self.exemplar_sampler(
                    train_dataset.task_id,
                    batch_size=batch_size,
                    cur_task=self.task,
                    mem_multiplier=self.args.get("memory_sample_batch", 1),
                    drop_last=True,
                )

        if sampler is None:
            train_loader = DataLoader(
                train_dataset,
                batch_size=batch_size,
                shuffle=True,
                num_workers=workers,
                drop_last=True,
            )
        else:
            # batch_sampler is mutually exclusive with batch_size, shuffle and
            # drop_last, so those stay at their DataLoader defaults.
            train_loader = DataLoader(
                train_dataset,
                batch_sampler=sampler,
                num_workers=workers,
            )
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=workers)
        test_train_loader = DataLoader(test_train_dataset, batch_size=batch_size, shuffle=False, num_workers=workers)

        task_info = {'task': self.task, 'max_class': train_dataset.targets.max()+1}
        self.task += 1

        return task_info, train_loader, test_loader, test_train_loader

    def get_cur_train_loader(self, shuffle=False, num_workers=0, drop_last=False):
        """
        Returns a loader over ``self.cur_train_dataset`` using ``args["batch_size"]``.

        :param shuffle: forwarded to ``DataLoader``.
        :param num_workers: forwarded to ``DataLoader``.
        :param drop_last: forwarded to ``DataLoader``.
        """
        return DataLoader(
            self.cur_train_dataset,
            batch_size=self.args["batch_size"],
            shuffle=shuffle,
            num_workers=num_workers,
            drop_last=drop_last,
        )

    def get_cur_exemplar_loader(self, batch_size=None, drop_last=True, shuffle=False):
        """
        Returns a loader over all stored exemplars.

        :param batch_size: requested batch size; defaults to
            ``args["batch_size"]`` and is capped at the number of exemplars.
        :param drop_last: forwarded to ``DataLoader``.
        :param shuffle: forwarded to ``DataLoader``.
        :raises ValueError: if no exemplars have been stored yet.
        """
        if self.exemplar_dataset is None:
            raise ValueError(
                "No exemplars are stored yet; call update_exemplar() or "
                "update_exemplar_by_herding() first."
            )

        if isinstance(self.exemplar_dataset, list):
            dataset = ConcatDataset(self.exemplar_dataset)
        else:
            dataset = self.exemplar_dataset

        if batch_size is None:
            batch_size = self.args["batch_size"]
        # Capping keeps drop_last=True from discarding a memory that is
        # smaller than one batch.
        batch_size = min(batch_size, len(dataset))
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=self.args["workers"])

    def update_cur_train_predictions(self, cur_train_predictions):
        """
        Stores ``cur_train_predictions`` as ``past_logits`` of the current train dataset.

        :param cur_train_predictions: one entry per sample of
            ``self.cur_train_dataset``.
        """
        assert len(self.cur_train_dataset) == len(cur_train_predictions), \
            "cur_train_predictions must have one entry per current training sample"
        self.cur_train_dataset.past_logits = cur_train_predictions

    def update_exemplar(self, cur_dataset_save_index=None):
        """
        Rebalances the exemplar memory using random selection for the current task.

        Stored exemplar sets are trimmed to the new per-class quota, then up
        to that many randomly chosen samples per class of the current train
        dataset are kept and stored as a new exemplar set. The current train
        dataset is resampled in place.

        :param cur_dataset_save_index: optional indices into the current
            train dataset that are placed ahead of all other samples of their
            class when the quota is filled.
        :raises RuntimeError: if called before ``new_task``.
        :raises ValueError: if ``self.type`` is neither 'CIL' nor 'DIL'.
        """
        memory_per_class = self._memory_per_class()
        ttl_exemplars = self._shrink_stored_exemplars(memory_per_class)

        labels = self.cur_train_dataset.targets
        new_indices = []
        for cls in range(self.n_classes):
            cls_inds = np.where(labels == cls)[0]
            if len(cls_inds) == 0:
                continue

            if cur_dataset_save_index is not None:
                # intersect1d promotes to float64 when the preferred index
                # list is empty; cast back so the result stays a valid index.
                overlap = np.intersect1d(cur_dataset_save_index, cls_inds)
                overlap = overlap.astype(cls_inds.dtype, copy=False)
                np.random.shuffle(overlap)

                cls_inds = np.setdiff1d(cls_inds, overlap)
                np.random.shuffle(cls_inds)

                # Preferred samples first, the remaining ones fill the quota.
                cls_inds = np.concatenate([overlap, cls_inds])
            else:
                np.random.shuffle(cls_inds)

            new_indices.append(cls_inds[:memory_per_class])

        ttl_exemplars += self._store_current_exemplars(new_indices)
        logger.info(f"No. of exemplars: {ttl_exemplars}")

    def update_exemplar_by_herding(self, cur_features, cur_targets):
        """
        Rebalances the exemplar memory using herding selection for the current task.

        Same bookkeeping as ``update_exemplar``, but the samples kept from the
        current train dataset are chosen by ``herding.icarl_selection``.

        :param cur_features: per-sample features, indexable by an index array
            and aligned with ``cur_targets``.
        :param cur_targets: array of per-sample class labels of the current
            train dataset.
        :raises RuntimeError: if called before ``new_task``.
        :raises ValueError: if ``self.type`` is neither 'CIL' nor 'DIL'.
        """
        from inclearn.lib import herding

        memory_per_class = self._memory_per_class()
        ttl_exemplars = self._shrink_stored_exemplars(memory_per_class)

        new_indices = []
        # NOTE(review): iterates the classes seen so far
        # (task * n_classes_per_task) whereas update_exemplar iterates all
        # n_classes -- confirm the difference is intended.
        for cls in range(self.task * self.n_classes_per_task):
            cls_inds = np.where(cur_targets == cls)[0]
            if len(cls_inds) == 0:
                continue

            cls_features = cur_features[cls_inds]
            # Presumably returns positions within cls_features; they are
            # mapped back to dataset indices through cls_inds.
            cls_selected_inds = herding.icarl_selection(cls_features, memory_per_class)
            new_indices.append(cls_inds[cls_selected_inds])

        ttl_exemplars += self._store_current_exemplars(new_indices)
        logger.info(f"No. of exemplars: {ttl_exemplars}")

    def _memory_per_class(self):
        """Returns how many exemplars each class may keep after the current task."""
        if self.task < 1:
            raise RuntimeError(
                "The exemplar memory can only be updated after new_task() has been called."
            )
        if self.type == 'DIL':
            memory_per_task = self.memory_size // self.task
            return memory_per_task // self.n_classes
        if self.type == 'CIL':
            return self.memory_size // (self.task * self.n_classes_per_task)
        raise ValueError(
            f"Unsupported incremental type {self.type!r}; expected 'CIL' or 'DIL'."
        )

    @staticmethod
    def _concat_indices(index_arrays):
        """Concatenates per-class index arrays; an empty list gives an empty index array."""
        if not index_arrays:
            # np.concatenate rejects an empty sequence.
            return np.empty(0, dtype=np.intp)
        return np.concatenate(index_arrays)

    def _shrink_stored_exemplars(self, memory_per_class):
        """Trims every stored exemplar set to the quota and returns the number of samples kept."""
        if self.exemplar_dataset is None:
            return 0

        kept = 0
        for exemplars in self.exemplar_dataset:
            labels = exemplars.targets
            per_class = []
            for cls in range(self.n_classes):
                cls_inds = np.where(labels == cls)[0]
                if len(cls_inds) > 0:
                    # Keeps the leading entries, so earlier selections survive.
                    per_class.append(cls_inds[:memory_per_class])

            keep = self._concat_indices(per_class)
            exemplars.resample(keep)
            kept += len(keep)
        return kept

    def _store_current_exemplars(self, new_indices):
        """Resamples the current train dataset to ``new_indices``, stores it and returns its size."""
        keep = self._concat_indices(new_indices)
        self.cur_train_dataset.resample(keep)

        if self.exemplar_dataset is None:
            self.exemplar_dataset = [deepcopy(self.cur_train_dataset)]
        else:
            # NOTE(review): unlike the first task this stores the dataset
            # object itself rather than a copy, so later changes to
            # cur_train_dataset are visible in the memory until new_task()
            # rebinds it -- confirm whether a deepcopy is wanted here too.
            self.exemplar_dataset.append(self.cur_train_dataset)
        return len(keep)
        


# -

def _get_dataset(dataset_name):
    dataset_name = dataset_name.lower().strip()

    if dataset_name == "kitchen":
        return KITCHEN
    elif dataset_name == "dkd":
        return DKD
    elif dataset_name == "ave":
        return AVE
    elif dataset_name == "uestc_mmea":
        return UESTC_MMEA
    else:
        raise NotImplementedError("Unknown dataset {}.".format(dataset_name))
