import logging
import random

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms

from .datasets import (
    ImageNet100, ImageNet1000, MiniImageNet, iCIFAR10, iCIFAR100, iCUB200
)

logger = logging.getLogger(__name__)


class IncrementalDataset:
    """Incremental (domain-by-domain) generator of data loaders.

    Each task corresponds to one domain id of the underlying dataset. The
    training data of every domain is cut into consecutive blocks according to
    ``block_split``; at task ``t`` the training pool contains block 0 of
    domain ``t``, block 1 of domain ``t - 1``, block 2 of domain ``t - 2``, ...

    Training samples carry a label flag: ``0`` = unlabelled, ``1`` = labelled
    (set by ``update_train_loader``), ``2`` = memory / rehearsal sample.

    :param dataset_name: Name resolved by ``_get_dataset`` (e.g. ``"cifar100"``).
    :param run_id: Accepted for API compatibility; not used by this class.
    :param random_order: Accepted for API compatibility; not used by this class.
    :param shuffle: Stored as ``self._shuffle``; the loaders built here take
                    their own ``shuffle`` argument instead.
    :param workers: Number of DataLoader worker processes.
    :param batch_size: Default batch size.
    :param seed: Stored as ``self._seed``; not read by this class.
    :param validation_split: Accepted for API compatibility; not used.
    :param block_split: ``'-'``-separated percentages (e.g. ``"50-50"``)
                        defining how each domain's training data is split
                        into blocks revealed over successive tasks.
    :param onehot: If true, ``new_task`` one-hot encodes the training targets.
    :param sampler: Optional batch-sampler factory used for training loaders,
                    called as ``sampler(targets, label_flags, batch_size=...,
                    **sampler_config)``.
    :param sampler_config: Extra keyword arguments for ``sampler``.
    :param data_path: Root directory passed to the dataset's ``base_dataset``.
    :param dataset_transforms: Forwarded to ``dataset.set_custom_transforms``.
    :param all_test_domain: If ``True``, test loaders cover every domain
                            instead of only the domains seen so far.
    :param metadata_path: If given, set as ``metadata_path`` on the dataset
                          object returned by ``_get_dataset`` before loading.
    """

    # Domain id given to memory samples whose real domain is not provided.
    _MEMORY_DOMAIN = -1

    def __init__(
        self,
        dataset_name,
        run_id=0,
        random_order=False,
        shuffle=True,
        workers=10,
        batch_size=128,
        seed=1,
        validation_split=0.,
        block_split='100',
        onehot=False,
        sampler=None,
        sampler_config=None,
        data_path="data",
        dataset_transforms=None,
        all_test_domain=False,
        metadata_path=None
    ):
        dataset = _get_dataset(dataset_name)
        if metadata_path:
            logger.info("Adding metadata path {}".format(metadata_path))
            dataset.metadata_path = metadata_path

        self._setup_data(dataset, data_path=data_path)

        dataset.set_custom_transforms(dataset_transforms)
        self.train_transforms = dataset.train_transforms  # FIXME handle multiple datasets
        self.test_transforms = dataset.test_transforms
        self.common_transforms = dataset.common_transforms
        self.two_transform = dataset.two_transform
        self.open_image = dataset.open_image

        self.n_tasks = dataset.n_tasks
        self.n_classes = dataset.n_classes

        self._current_task = 0

        self._seed = seed
        self._batch_size = batch_size
        self._workers = workers
        self._shuffle = shuffle
        self._onehot = onehot
        self._sampler = sampler
        self._sampler_config = sampler_config
        self._all_test_domain = all_test_domain
        # "50-50" -> [0.5, 0.5]: fraction of each domain revealed per block.
        self._block_split = [float(part) / 100 for part in block_split.split('-')]

    def new_task(self):
        """Advance to the next task and build its loaders.

        :return: ``(task_info, train_loader, test_loader)``. ``train_loader``
                 covers every training sample revealed for this task
                 (labelled or not) with test-time transforms and no
                 shuffling; ``test_loader`` covers the test data of the
                 domains seen so far (all domains if ``all_test_domain``).
        :raises Exception: if all ``n_tasks`` tasks were already produced.
        """
        if self._current_task >= self.n_tasks:
            raise Exception("No more tasks.")

        n_blocks = len(self._block_split)
        idx_train = []
        for i in range(n_blocks):
            task_id = self._current_task - n_blocks + i + 1
            if task_id < 0:
                # Fewer domains have been seen than there are blocks.
                continue
            idx_train.append(self._select(
                self.domain_train,
                low_range=task_id, high_range=task_id + 1,
                block_id=self._current_task - task_id,
            ))

        self.cur_idx_train = np.concatenate(idx_train)

        x_train = self.data_train[self.cur_idx_train]
        y_train = self.targets_train[self.cur_idx_train]
        z_train = self.domain_train[self.cur_idx_train]
        flags_train = self.flags_train[self.cur_idx_train]

        if self._all_test_domain is True:
            logger.info("Testing on all domains!")
            high_range = self.n_tasks
        else:
            high_range = self._current_task + 1
        self.cur_idx_test = self._select(self.domain_test, high_range=high_range, block_id=None)

        x_test = self.data_test[self.cur_idx_test]
        y_test = self.targets_test[self.cur_idx_test]
        z_test = self.domain_test[self.cur_idx_test]
        flags_test = self.flags_test[self.cur_idx_test]

        if self._onehot:
            def to_onehot(x):
                n = np.max(x) + 1
                return np.eye(n)[x]

            y_train = to_onehot(y_train)

        # NOTE(review): the "train" loader deliberately(?) uses test transforms and
        # no shuffling over labelled + unlabelled samples; presumably it is used for
        # scoring / sample selection. The actual training loader comes from
        # update_train_loader. Confirm with callers.
        train_loader = self._get_loader(x_train, y_train, z_train, flags_train, mode="test", shuffle=False)
        test_loader = self._get_loader(x_test, y_test, z_test, flags_test, mode="test", shuffle=False)

        task_info = {
            "n_classes": self.n_classes,
            "task": self._current_task,
            "max_task": self.n_tasks,
            "n_train_data": x_train.shape[0],
            "n_test_data": x_test.shape[0]
        }

        self._current_task += 1

        return task_info, train_loader, test_loader

    def update_train_loader(self, label_indices, data_memory=None, targets_memory=None):
        """Flag samples as labelled and return a training loader over labelled data.

        :param label_indices: Positions within ``self.cur_idx_train`` (i.e. within
                              the current task's training pool) to set to flag 1.
        :param data_memory: Optional memory inputs appended to the loader.
        :param targets_memory: Targets for ``data_memory``.
        :return: The loader built by ``get_custom_loader`` (``None`` if empty).
        """
        logger.info("Number of labelled samples (Before): {}.".format((self.flags_train[self.cur_idx_train]==1).sum()))
        if label_indices is not None and len(label_indices) != 0:
            self.flags_train[self.cur_idx_train[label_indices]] = 1
        logger.info("Number of labelled samples (After): {}.".format((self.flags_train[self.cur_idx_train]==1).sum()))
        return self.get_custom_loader(data_memory=data_memory, targets_memory=targets_memory, shuffle=True, mode="train")[1]

    def get_memory_loader(self, data, targets, domains=None):
        """Build a shuffled training loader over memory (rehearsal) samples.

        Memory samples get label flag ``2``, the same value ``_add_memory`` uses.

        :param data: Memory inputs.
        :param targets: Memory targets.
        :param domains: Optional per-sample domain ids; when omitted every sample
                        gets ``self._MEMORY_DOMAIN`` (-1, i.e. unknown).
        """
        n = data.shape[0]
        flags = 1 + np.ones((n,))
        return self._get_loader(
            data, targets, self._memory_domains(n, domains), flags, shuffle=True, mode="train"
        )

    def get_custom_loader(self, class_indexes=None, data_memory=None, targets_memory=None, shuffle=True, mode="train", batch_size=None, domains_memory=None):
        """Build a loader over the labelled samples of the current training pool.

        :param class_indexes: Optional class id or list of class ids; samples are
                              grouped in the order the classes are listed.
        :param data_memory: Optional memory inputs appended after the selection.
        :param targets_memory: Targets for ``data_memory``.
        :param shuffle: Forwarded to ``_get_loader``.
        :param mode: Loader mode (``"train"``, ``"test"``, ``"flip"``, optionally
                     containing ``"balanced"``).
        :param batch_size: Optional batch size override.
        :param domains_memory: Optional domain ids for ``data_memory``; defaults to
                               ``self._MEMORY_DOMAIN`` for every memory sample.
        :return: ``((x, y, flags), loader)``; ``loader`` is ``None`` when no sample
                 remains.
        """
        cur = self.cur_idx_train
        x_train = self.data_train[cur]
        y_train = self.targets_train[cur]
        z_train = self.domain_train[cur]
        flags_train = self.flags_train[cur]

        # Only samples with a non-zero flag (labelled or memory) are trainable.
        labelled = flags_train != 0
        x_train, y_train = x_train[labelled], y_train[labelled]
        z_train, flags_train = z_train[labelled], flags_train[labelled]

        if class_indexes is not None:
            if not isinstance(class_indexes, list):
                class_indexes = [class_indexes]
            per_class = [np.where(y_train == c)[0] for c in class_indexes]
            # Empty-array slicing keeps numpy types so `.shape` below still works.
            sel = np.concatenate(per_class) if per_class else np.array([], dtype=int)
            x_train, y_train = x_train[sel], y_train[sel]
            z_train, flags_train = z_train[sel], flags_train[sel]

        if data_memory is not None:
            logger.info("Set memory of size: {}.".format(data_memory.shape[0]))
            x_train, y_train, flags_train = self._add_memory(x_train, y_train, flags_train, data_memory, targets_memory)
            z_train = np.concatenate(
                (z_train, self._memory_domains(data_memory.shape[0], domains_memory))
            )

        logger.info("X size: {}, Y size: {}".format(x_train.shape, y_train.shape))
        if x_train.shape[0] == 0:
            return (x_train, y_train, flags_train), None
        loader = self._get_loader(
            x_train, y_train, z_train, flags_train, mode=mode, shuffle=shuffle, batch_size=batch_size
        )
        return (x_train, y_train, flags_train), loader

    @classmethod
    def _memory_domains(cls, n, domains=None):
        """Return ``n`` domain ids for memory samples (``_MEMORY_DOMAIN`` if unknown)."""
        if domains is None:
            return np.full((n,), cls._MEMORY_DOMAIN, dtype=int)
        domains = np.asarray(domains)
        if domains.shape[0] != n:
            raise ValueError("Expected {} memory domain ids, got {}.".format(n, domains.shape[0]))
        return domains

    def _add_memory(self, x, y, z, data_memory, targets_memory):
        """Append memory samples to ``(x, y, z)``, where ``z`` holds label flags.

        Memory samples receive flag ``2``.
        """
        if self._onehot:  # Need to add dummy zeros to match the number of targets:
            # NOTE(review): `self.increments` is never assigned in this class, so this
            # branch raises AttributeError unless set externally; also the scalar
            # targets built by get_custom_loader cannot be concatenated with 2-D
            # one-hot memory targets. Confirm how one-hot mode is meant to work.
            targets_memory = np.concatenate(
                (
                    targets_memory,
                    np.zeros((targets_memory.shape[0], self.increments[self._current_task]))
                ),
                axis=1
            )

        x = np.concatenate((x, data_memory))
        y = np.concatenate((y, targets_memory))
        z = np.concatenate((z, 1+np.ones((data_memory.shape[0],))))

        return x, y, z

    def _select(self, z, low_range=0, high_range=0, block_id=0):
        """Return indices of samples whose domain id lies in ``[low_range, high_range)``.

        :param z: Per-sample domain ids.
        :param block_id: If not ``None``, keep only block ``block_id`` of each
                         domain, as delimited by the cumulative ``_block_split``
                         fractions; if ``None``, keep the whole domain.
        :return: Integer index array (empty when the range is empty).
        """
        idxes = []
        for task_id in range(low_range, high_range):
            cur_idx = np.where(z == task_id)[0]
            if block_id is None:
                idxes.append(cur_idx)
                continue
            ttl_size = len(cur_idx)
            start_idx = int(ttl_size * sum(self._block_split[:block_id]))
            end_idx = int(ttl_size * sum(self._block_split[:block_id + 1]))
            idxes.append(cur_idx[start_idx:end_idx])

        if not idxes:
            return np.array([], dtype=int)
        return np.concatenate(idxes)

    def _get_loader(self, x, y, z, label_flags, shuffle=True, mode="train", sampler=None, batch_size=None):
        """Wrap arrays in a ``DummyDataset`` and return a ``DataLoader``.

        :param x: Inputs (arrays, or paths when ``self.open_image``).
        :param y: Targets.
        :param z: Domain ids.
        :param label_flags: Label flags.
        :param mode: Contains ``"train"`` or ``"test"`` (optionally ``"balanced"``
                     to class-balance labelled samples first), or equals ``"flip"``.
        :param sampler: Batch-sampler factory overriding ``self._sampler``; only
                        used in train modes.
        :param batch_size: Batch size override (defaults to ``self._batch_size``).
        :raises NotImplementedError: for an unknown ``mode``.
        """
        if "balanced" in mode:
            x, y, z, label_flags = construct_balanced_subset(x, y, z, label_flags)

        if "train" in mode:
            trsf = transforms.Compose([*self.train_transforms, *self.common_transforms])
        elif "test" in mode:
            trsf = transforms.Compose([*self.test_transforms, *self.common_transforms])
        elif mode == "flip":
            trsf = transforms.Compose(
                [
                    transforms.RandomHorizontalFlip(p=1.), *self.test_transforms,
                    *self.common_transforms
                ]
            )
        else:
            raise NotImplementedError("Unknown mode {}.".format(mode))

        two_transform = self.two_transform if "train" in mode else False
        sampler = sampler or self._sampler
        if sampler is not None and "train" in mode:
            logger.info("Using sampler {}".format(sampler))
            sampler = sampler(
                y, label_flags,
                batch_size=batch_size if batch_size is not None else self._batch_size,
                **(self._sampler_config or {})
            )
            # DataLoader requires batch_size=1 when a batch_sampler is supplied.
            batch_size = 1
        else:
            sampler = None
            if batch_size is None:
                batch_size = self._batch_size

        return DataLoader(
            DummyDataset(x, y, z, label_flags, trsf, open_image=self.open_image, two_transform=two_transform),
            batch_size=batch_size,
            shuffle=shuffle if sampler is None else False,
            num_workers=self._workers,
            batch_sampler=sampler
        )

    def _setup_data(
        self,
        dataset,
        data_path="data",
    ):
        """Load train/test splits, shuffle the training split, and init label flags.

        Training samples start unlabelled (flag 0); test samples are all flag 1.
        """
        train_dataset = dataset().base_dataset(data_path, train=True, download=True)
        test_dataset = dataset().base_dataset(data_path, train=False, download=True)

        self.data_train, self.targets_train, self.domain_train = train_dataset.data, np.array(train_dataset.targets), np.array(train_dataset.domain)
        self.data_test, self.targets_test, self.domain_test = test_dataset.data, np.array(test_dataset.targets), np.array(test_dataset.domain)

        # One permutation applied to data, targets and domains keeps them aligned.
        shuffled_indexes = np.random.permutation(self.data_train.shape[0])
        self.data_train = self.data_train[shuffled_indexes]
        self.targets_train = self.targets_train[shuffled_indexes]
        self.domain_train = self.domain_train[shuffled_indexes]

        self.flags_train = np.zeros(len(self.targets_train))
        self.flags_test = np.ones(len(self.targets_test))


class DummyDataset(torch.utils.data.Dataset):
    """Map-style dataset over pre-loaded per-sample arrays.

    Each item is a dict with keys ``"targets"``, ``"label_flags"``,
    ``"domain_id"`` and ``"inputs"``. When ``two_transform`` is truthy, it
    also has ``"inputs2"``, a second call of ``trsf`` on the same image.

    :param x: Image arrays (decoded with ``Image.fromarray`` after a cast to
              uint8) or, when ``open_image`` is truthy, paths for ``Image.open``.
    :param y: Targets, one per sample.
    :param z: Domain ids, one per sample.
    :param label_flags: Label flags, one per sample.
    :param trsf: Callable applied to each PIL image.
    :param open_image: Read ``x`` entries from disk instead of from arrays.
    :param two_transform: Also return a second transformed view.
    """

    def __init__(self, x, y, z, label_flags, trsf, open_image=False, two_transform=False):
        # Every per-sample array must be aligned along its first axis.
        assert x.shape[0] == y.shape[0] == z.shape[0] == label_flags.shape[0]
        self.x = x
        self.y = y
        self.z = z
        self.label_flags = label_flags
        self.trsf = trsf
        self.open_image = open_image
        self.two_transform = two_transform

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        source = self.x[idx]
        target = self.y[idx]
        domain = self.z[idx]
        flag = self.label_flags[idx]

        if not self.open_image:
            img = Image.fromarray(source.astype("uint8"))
        else:
            img = Image.open(source).convert("RGB")

        sample = {
            "targets": target,
            "label_flags": flag.astype(int),
            "domain_id": domain.astype(int),
        }
        sample["inputs"] = self.trsf(img)
        if self.two_transform:
            sample["inputs2"] = self.trsf(img)
        return sample

def _get_dataset(dataset_name):
    """Return the dataset class registered under ``dataset_name``.

    Matching ignores case and surrounding whitespace.

    :raises NotImplementedError: if the name is not recognised.
    """
    key = dataset_name.lower().strip()
    registry = {
        "cifar10": iCIFAR10,
        "cifar100": iCIFAR100,
        "imagenet100": ImageNet100,
        "imagenet1000": ImageNet1000,
        "miniimagenet": MiniImageNet,
        "cub200": iCUB200,
    }
    if key not in registry:
        raise NotImplementedError("Unknown dataset {}.".format(key))
    return registry[key]

def construct_balanced_subset(x, y, z, flag):
    """Return a random class-balanced subset of the labelled samples.

    Only samples with ``flag != 0`` are considered. Every class that has at
    least one such sample contributes ``minsize`` randomly chosen samples
    (without replacement), where ``minsize`` is the size of the smallest such
    class. Classes without labelled samples are ignored rather than forcing
    the subset to be empty. Groups are concatenated in ascending class order.
    Randomness comes from the global ``np.random`` state.

    :param x: Inputs, indexable by integer arrays along the first axis.
    :param y: 1-D targets.
    :param z: Domain ids.
    :param flag: Label flags; ``0`` marks unlabelled samples.
    :return: ``(x, y, z, flag)`` restricted to the balanced subset; empty
             slices when there are no labelled samples.
    """
    labelled = flag != 0
    classes = np.unique(y[labelled])
    if classes.size == 0:
        return x[:0], y[:0], z[:0], flag[:0]

    per_class = [np.flatnonzero(labelled & (y == cls_)) for cls_ in classes]
    minsize = min(len(idx) for idx in per_class)

    chosen = []
    for idx in per_class:
        np.random.shuffle(idx)  # idx is a fresh array, safe to shuffle in place
        chosen.append(idx[:minsize])
    chosen = np.concatenate(chosen)

    return x[chosen], y[chosen], z[chosen], flag[chosen]

