import logging.config
from typing import Callable, List, Optional

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from torchvision import transforms


# Module-wide logger; getLogger() is called without a name, so this is the
# root logger rather than a per-module one.
logger = logging.getLogger()

class ImageDataset(Dataset):
    """In-memory dataset that serves stored ``(image, label)`` pairs as-is.

    ``data`` is unpacked into ``(inputs, gt)`` and the two are paired with
    ``zip``, so the dataset is as long as the shorter of them. Every other
    constructor argument is only stored on the instance; nothing in this class
    reads it back (``__getitem__`` does not apply ``transform``).
    """

    def __init__(self, data, transform=None, cls_list=None, data_dir=None,
                 preload=False, device=None, transform_on_gpu=False):
        inputs, gt = data
        paired = list(zip(inputs, gt))
        self.images = [picture for picture, _ in paired]
        self.labels = [target for _, target in paired]

        # Configuration kept verbatim for whoever consumes the dataset.
        self.transform, self.cls_list = transform, cls_list
        self.data_dir, self.preload = data_dir, preload
        self.device, self.transform_on_gpu = device, transform_on_gpu

    def __len__(self):
        """Return the number of stored labels."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the untransformed ``(image, label)`` tuple at ``idx``."""
        picture = self.images[idx]
        target = self.labels[idx]
        return picture, target

class StreamDataset(Dataset):
    """Flattens batched stream samples into per-item image and label lists.

    ``sample`` is iterated as ``(image, label)`` pairs. Every element of
    ``image`` becomes one stored image, and every element of ``label`` is
    converted with ``.item()`` and stored as its position in ``cls_list``.
    """

    def __init__(self, sample, transform :Optional[Callable]=None, cls_list=None):
        self.cls_list = cls_list
        self.transform = transform
        self.images = []
        self.labels = []

        for image_batch, label_batch in sample:
            self.images.extend(single for single in image_batch)
            # Labels are remapped to class indices while being flattened.
            self.labels.extend(
                self.cls_list.index(raw.item()) for raw in label_batch
            )

    def __len__(self):
        """Return the number of stored images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``{"image", "label"}`` for ``idx``, transforming the image if a transform is set."""
        position = idx.tolist() if torch.is_tensor(idx) else idx
        picture = self.images[position]
        target = self.labels[position]
        if self.transform:
            picture = self.transform(picture)
        return {"image": picture, "label": target}

    @torch.no_grad()
    def get_data(self):
        """Return all items as one dict of a stacked image tensor and a ``LongTensor`` of labels.

        Each stored image goes through ``ToPILImage`` and then ``self.transform``.
        """
        to_pil = transforms.ToPILImage()
        pairs = [
            (self.transform(to_pil(raw)), self.labels[position])
            for position, raw in enumerate(self.images)
        ]
        return {
            'image': torch.stack([picture for picture, _ in pairs]),
            'label': torch.LongTensor([target for _, target in pairs]),
        }

class MemoryDataset(Dataset):
    """Replay memory of images with class-index labels and per-class bookkeeping.

    Labels are stored as positions in ``cls_list`` (resolved through
    ``cls_dict``). ``cls_count`` and ``cls_idx`` hold, per class index, the
    number of stored samples and the memory slots they occupy; both gain one
    entry on every ``add_new_class`` call. ``others_loss_decrease`` keeps one
    float per memory slot and ``score`` is a list maintained by
    ``update_gss_score``.
    """

    def __init__(self, transform=None, test_transform=None, cls_list=None, save_test=None, keep_history=False):
        """Create an empty memory.

        Args:
            transform: Callable applied to images by ``__getitem__`` (when
                set) and by the batch-building methods.
            test_transform: Callable used for ``device_img`` entries and as
                the fallback transform in ``get_two_batches``.
            cls_list: Raw class labels known so far; ``None`` means none yet.
            save_test: When truthy, ``replace_sample`` also keeps a
                test-transformed copy of each image in ``device_img``.
            keep_history: When true, ``get_batch`` records the sampled slots
                in ``previous_idx`` for ``update_loss_history``.
        """
        # The default ``None`` previously crashed on ``len(None)``.
        if cls_list is None:
            cls_list = []

        self.datalist = []
        self.labels = []
        self.images = []

        self.transform = transform
        self.cls_list = cls_list
        self.cls_dict = {raw: pos for pos, raw in enumerate(cls_list)}
        self.cls_count = []
        self.cls_idx = []
        self.cls_train_cnt = np.array([])
        self.score = []
        self.others_loss_decrease = np.array([])
        self.previous_idx = np.array([], dtype=int)
        self.test_transform = test_transform
        self.keep_history = keep_history

        self.save_test = save_test
        if self.save_test is not None:
            self.device_img = []

    def __len__(self):
        """Return the number of stored images."""
        return len(self.images)

    @staticmethod
    def _pack(images, labels):
        """Stack ``images`` and wrap ``labels`` into the ``{'image', 'label'}`` batch dict."""
        return {'image': torch.stack(images), 'label': torch.LongTensor(labels)}

    def add_new_class(self, cls_list):
        """Adopt ``cls_list`` as the class list and open bookkeeping for one more class."""
        self.cls_list = cls_list
        self.cls_count.append(0)
        self.cls_idx.append([])
        self.cls_dict = {raw: pos for pos, raw in enumerate(self.cls_list)}
        self.cls_train_cnt = np.append(self.cls_train_cnt, 0)

    def __getitem__(self, idx):
        """Return ``{"image", "label"}`` for slot ``idx``, transforming the image if a transform is set."""
        if torch.is_tensor(idx):
            # Tensors have no ``.value()``; ``tolist()`` turns a scalar index
            # tensor into a plain int, as StreamDataset does.
            idx = idx.tolist()
        label = self.labels[idx]
        image = self.images[idx]
        if self.transform:
            image = self.transform(image)
        return {"image": image, "label": label}

    def update_gss_score(self, score, idx=None):
        """Append ``score`` to ``self.score``, or overwrite entry ``idx`` when given."""
        if idx is None:
            self.score.append(score)
        else:
            self.score[idx] = score

    def replace_sample(self, sample, idx=None):
        """Store ``sample`` in a new slot, or overwrite slot ``idx``.

        Args:
            sample: ``(x, y)`` pair; ``y.item()`` must be a key of
                ``cls_dict``.
            idx: Slot to overwrite; ``None`` appends a new slot.

        Raises:
            KeyError: If the label is not in ``cls_dict``.
        """
        x, y = sample
        cls = self.cls_dict[y.item()]
        self.cls_count[cls] += 1

        if idx is None:
            self.cls_idx[cls].append(len(self.images))
            self.datalist.append({'image': x, 'label': cls})
            self.images.append(x)
            self.labels.append(cls)
            if self.save_test:
                self.device_img.append(self.test_transform(transforms.ToPILImage()(x)).unsqueeze(0))
            if self.cls_count[cls] == 1:
                # First sample of its class starts at zero.
                initial = 0
            else:
                # Otherwise start from the mean over the class's other slots
                # ([:-1] excludes the slot that was just appended).
                initial = np.mean(self.others_loss_decrease[self.cls_idx[cls][:-1]])
            self.others_loss_decrease = np.append(self.others_loss_decrease, initial)
        else:
            evicted = self.labels[idx]
            self.cls_count[evicted] -= 1
            self.cls_idx[evicted].remove(idx)
            self.datalist[idx] = {'image': x, 'label': cls}
            self.cls_idx[cls].append(idx)
            self.images[idx] = x
            self.labels[idx] = cls
            if self.save_test:
                self.device_img[idx] = self.test_transform(transforms.ToPILImage()(x)).unsqueeze(0)
            if self.cls_count[cls] == 1:
                # Only sample of its class: fall back to the mean over all slots.
                self.others_loss_decrease[idx] = np.mean(self.others_loss_decrease)
            else:
                self.others_loss_decrease[idx] = np.mean(self.others_loss_decrease[self.cls_idx[cls][:-1]])

    def get_weight(self):
        """Return one weight per slot, equal to ``1 / cls_count`` of the slot's class."""
        weight = np.zeros(len(self.images))
        for cls, slots in enumerate(self.cls_idx):
            count = self.cls_count[cls]
            # A class without stored samples owns no slots; skipping it avoids
            # the ZeroDivisionError that ``1 / 0`` used to raise.
            if count > 0:
                weight[slots] = 1 / count
        return weight

    @torch.no_grad()
    def get_batch(self, batch_size, use_weight=False, transform=None):
        """Sample ``batch_size`` distinct slots and return them as a transformed batch.

        Args:
            batch_size: Number of slots to draw without replacement.
            use_weight: Draw with probabilities proportional to
                ``get_weight()`` instead of uniformly.
            transform: Transform to use instead of ``self.transform``.

        Returns:
            Dict with a stacked ``'image'`` tensor and a ``'label'``
            ``LongTensor``. Each drawn slot also increments
            ``cls_train_cnt`` for its class.
        """
        population = range(len(self.images))
        if use_weight:
            weight = self.get_weight()
            indices = np.random.choice(population, size=batch_size, p=weight/np.sum(weight), replace=False)
        else:
            indices = np.random.choice(population, size=batch_size, replace=False)
        if transform is None:
            transform = self.transform
        to_pil = transforms.ToPILImage()
        images = []
        labels = []
        for i in indices:
            images.append(transform(to_pil(self.images[i])))
            labels.append(self.labels[i])
            self.cls_train_cnt[self.labels[i]] += 1
        data = self._pack(images, labels)
        if self.keep_history:
            self.previous_idx = np.append(self.previous_idx, indices)
        return data

    def update_loss_history(self, loss, prev_loss, ema_ratio=0.90, dropped_idx=None):
        """Update ``others_loss_decrease`` for the slots recorded in ``previous_idx``, then clear it.

        Args:
            loss: Array of current losses.
            prev_loss: Array of previous losses.
            ema_ratio: The update is scaled by ``1 - ema_ratio``.
            dropped_idx: Positions excluded from the loss difference; when
                given, only the first ``len(prev_loss)`` losses are compared.
        """
        if dropped_idx is None:
            loss_diff = np.mean(loss - prev_loss)
        elif len(prev_loss) > 0:
            mask = np.ones(len(loss), bool)
            mask[dropped_idx] = False
            loss_diff = np.mean((loss[:len(prev_loss)] - prev_loss)[mask[:len(prev_loss)]])
        else:
            loss_diff = 0
        if len(self.previous_idx) == 0:
            # No recorded slots: the update below would only do NaN arithmetic
            # on empty selections, so there is nothing to change.
            self.previous_idx = np.array([], dtype=int)
            return
        # NOTE(review): the mean is divided by len(previous_idx) before being
        # subtracted from loss_diff -- kept as written, confirm it is intended.
        difference = loss_diff - np.mean(self.others_loss_decrease[self.previous_idx]) / len(self.previous_idx)
        self.others_loss_decrease[self.previous_idx] -= (1 - ema_ratio) * difference
        self.previous_idx = np.array([], dtype=int)

    def get_two_batches(self, batch_size, test_transform):
        """Draw ``batch_size`` distinct slots and return them under two transforms.

        Args:
            batch_size: Number of slots to draw without replacement.
            test_transform: Transform for the second batch; ``None`` falls
                back to ``self.test_transform``.

        Returns:
            ``(data_1, data_2)``: the same slots transformed with the training
            transform and with the test transform respectively.
        """
        if test_transform is None:
            test_transform = self.test_transform
        # ``transform_on_gpu``, ``transform_gpu`` and ``device`` are not set by
        # ``__init__``; the GPU path is only taken when something else has set
        # ``transform_on_gpu`` to a truthy value on the instance.
        on_gpu = getattr(self, "transform_on_gpu", False)
        indices = np.random.choice(range(len(self.images)), size=batch_size, replace=False)
        labels = [self.labels[i] for i in indices]
        if on_gpu:
            train_images = [self.transform_gpu(self.images[i].to(self.device)) for i in indices]
        else:
            train_images = [self.transform(self.images[i]) for i in indices]
        data_1 = self._pack(train_images, labels)
        data_2 = self._pack([test_transform(self.images[i]) for i in indices], labels)
        return data_1, data_2

    def make_cls_dist_set(self, labels, transform=None):
        """Return a batch with one randomly chosen stored sample per entry of ``labels``.

        Args:
            labels: Iterable of class indices used to index ``cls_idx``.
            transform: Transform to use instead of ``self.transform``.
        """
        if transform is None:
            transform = self.transform
        indices = np.array([np.random.choice(self.cls_idx[label]) for label in labels])
        return self._pack(
            [transform(self.images[i]) for i in indices],
            [self.labels[i] for i in indices],
        )

    def make_val_set(self, size=None, transform=None):
        """Return a batch with ``size // len(cls_list)`` samples drawn from every class.

        Args:
            size: Total requested size; defaults to 10% of the memory.
            transform: Transform to use instead of ``self.transform``.
        """
        if size is None:
            size = int(0.1*len(self.images))
        if transform is None:
            transform = self.transform
        size_per_cls = size//len(self.cls_list)
        indices = []
        for members in self.cls_idx:
            # Sample with replacement only when the class has too few slots.
            indices.append(np.random.choice(members, size=size_per_cls, replace=len(members) < size_per_cls))
        indices = np.concatenate(indices)
        return self._pack(
            [transform(self.images[i]) for i in indices],
            [self.labels[i] for i in indices],
        )

    def is_balanced(self):
        """Return True when every class count is ``len(images) // len(cls_list)`` or one more."""
        mem_per_cls = len(self.images)//len(self.cls_list)
        return all(mem_per_cls <= count <= mem_per_cls + 1 for count in self.cls_count)


def get_train_datalist(dataset, n_tasks, m, n, rnd_seed, cur_iter: int) -> List:
    """Load one task's training records from its JSON collection file.

    If ``n == 100`` or ``m == 0``, both are normalised to ``n=100, m=0``
    before the file name is built. The file is read with ``pandas.read_json``
    and returned as a list of per-row dicts.
    """
    if n == 100 or m == 0:
        n, m = 100, 0
    json_path = f"collections/{dataset}/{dataset}_split{n_tasks}_n{n}_m{m}_rand{rnd_seed}_task{cur_iter}.json"
    frame = pd.read_json(json_path)
    return frame.to_dict(orient="records")

def get_test_datalist(dataset) -> List:
    """Load the validation records of ``dataset`` from its JSON collection file as a list of row dicts."""
    json_path = f"collections/{dataset}/{dataset}_val.json"
    frame = pd.read_json(json_path)
    return frame.to_dict(orient="records")


def get_statistics(dataset: str):
    """
    Look up the fixed statistics of a dataset by name.

    Returns the tuple ``(mean, std, classes, inp_size, in_channels)``. The
    name ``'imagenet'`` is treated as ``'imagenet1000'``; any other unknown
    name fails the assertion. To add a new dataset, add one row to the table
    below.
    """
    # Normalisation constants shared by several datasets.
    gray_mean, gray_std = (0.1307,), (0.3081,)
    imagenet_mean, imagenet_std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
    cifar100_mean, cifar100_std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)

    # name: (mean, std, classes, inp_size, in_channels)
    table = {
        "mnist": (gray_mean, gray_std, 10, 28, 1),
        "KMNIST": (gray_mean, gray_std, 10, 28, 1),
        "EMNIST": (gray_mean, gray_std, 49, 28, 1),
        "FashionMNIST": (gray_mean, gray_std, 10, 28, 1),
        "SVHN": ((0.4377, 0.4438, 0.4728), (0.1969, 0.1999, 0.1958), 10, 32, 3),
        "cifar10": ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010), 10, 32, 3),
        "cifar100": (cifar100_mean, cifar100_std, 100, 32, 3),
        "CINIC10": (
            (0.47889522, 0.47227842, 0.43047404),
            (0.24205776, 0.23828046, 0.25874835),
            10, 32, 3,
        ),
        "imagenet100": (imagenet_mean, imagenet_std, 100, 224, 3),
        "imagenet900": (imagenet_mean, imagenet_std, 900, 224, 3),
        "imagenetsub": (imagenet_mean, imagenet_std, 611, 224, 3),
        "imagenet1000": (imagenet_mean, imagenet_std, 1000, 224, 3),
        "tinyimagenet": ((0.4802, 0.4481, 0.3975), (0.2302, 0.2265, 0.2262), 200, 64, 3),
        "imagenet-r": (imagenet_mean, imagenet_std, 200, 224, 3),
        "nch": (imagenet_mean, imagenet_std, 9, 224, 3),
        "cub200": (cifar100_mean, cifar100_std, 200, 224, 3),
        "cars196": (
            (0.46951303, 0.45906875, 0.45407656),
            (0.29279903, 0.2917511, 0.2999349),
            196, 224, 3,
        ),
        "cub175": (cifar100_mean, cifar100_std, 175, 224, 3),
        "cubrandom": (cifar100_mean, cifar100_std, 175, 224, 3),
        "places365": (imagenet_mean, imagenet_std, 365, 224, 3),
        "gtsrb": (imagenet_mean, imagenet_std, 43, 224, 3),
        "wikiart": (imagenet_mean, imagenet_std, 27, 224, 3),
    }

    if dataset == 'imagenet':
        dataset = 'imagenet1000'
    # Membership is checked by equality against the known names.
    known_names = tuple(table)
    assert dataset in known_names

    return table[dataset]


# from https://github.com/drimpossible/GDumb/blob/74a5e814afd89b19476cd0ea4287d09a7df3c7a8/src/utils.py#L102:5
def cutmix_data(x, y, alpha=1.0, cutmix_prob=0.5):
    """Apply CutMix to a batch and return the mixed batch with both label sets.

    A random box is pasted into every image from another image of the same
    batch, chosen by a random permutation. ``x`` is modified in place and is
    also returned.

    Args:
        x: Batch tensor whose dimensions 2 and 3 are sliced by the box.
        y: Labels, indexable by a permutation tensor.
        alpha: Parameter of the ``Beta(alpha, alpha)`` draw; must be > 0.
        cutmix_prob: Accepted for interface compatibility; never read here.

    Returns:
        ``(x, y_a, y_b, lam)`` where ``y_a`` are the original labels, ``y_b``
        the labels of the pasted-from images, and ``lam`` is one minus the
        box area divided by the area of the last two dimensions of ``x``.
    """
    assert alpha > 0
    # generate mixed sample
    lam = np.random.beta(alpha, alpha)

    batch_size = x.size()[0]
    # Keep the permutation on the same device as ``x``. Moving it to CUDA
    # whenever CUDA merely exists breaks indexing of CPU batches.
    index = torch.randperm(batch_size).to(x.device)

    y_a, y_b = y, y[index]
    bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
    # The indexed right-hand side is a copy, so the in-place write is safe.
    x[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2]

    # adjust lambda to exactly match pixel ratio
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2]))
    return x, y_a, y_b, lam


def rand_bbox(size, lam):
    """Draw a random box whose side lengths scale with ``sqrt(1 - lam)``.

    Args:
        size: Indexable shape; ``size[2]`` is used as ``W`` and ``size[3]``
            as ``H``.
        lam: Mixing ratio; the box sides are ``W`` and ``H`` scaled by
            ``sqrt(1 - lam)`` and truncated to integers.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)``: box corners around a uniformly drawn
        centre, clipped to ``[0, W]`` and ``[0, H]``.
    """
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1.0 - lam)
    # ``np.int`` was only an alias of the builtin ``int`` and was removed in
    # NumPy 1.24; the builtin gives the same truncation.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2
