import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import numpy as np
import pandas as pd
import random
import json
import os
from PIL import Image
import torchvision.models as tor_models
from kmeans_pytorch import kmeans
from data.datasets_clean import input_dataset

class IMBALANCECIFAR10(torchvision.datasets.CIFAR10):
    """CIFAR-10 with a long-tailed training split, injected label noise and group labels.

    On the training split the constructor:
      1. subsamples each class according to ``imb_type`` / ``imb_factor``,
      2. attaches a per-sample group label (cached as JSON under
         ``data_longtail/``), and
      3. replaces ``targets`` with noisy labels (cached as JSON under
         ``root``); the pre-noise labels are kept in ``clean_targets``.
    The test split is left as loaded by the parent ``CIFAR10`` class.
    """
    cls_num = 10

    def __init__(self, root, imb_type='exp', imb_factor=0.01, g_idx=2, noise_mode='sym', type='knn', noise_ratio=0,
                 rand_number=0, train=True, transform=None, target_transform=None, download=True):
        """
        Args:
            root: dataset root passed to ``CIFAR10``; the noisy-label cache
                file is also written here.
            imb_type: ``'exp'``, ``'step'``, or anything else for a balanced split.
            imb_factor: size of the last class relative to the first one.
            g_idx: number of groups for the non-``'knn'`` grouping; also part
                of the group cache file name.
            noise_mode: ``'sym'`` (flip uniformly to another class) or
                ``'imb'`` (flip proportionally to the other classes' sizes).
            type: ``'knn'`` for k-means grouping, anything else for grouping
                by a randomly initialised linear head.
            noise_ratio: fraction of training samples whose label is flipped.
            rand_number: seed for both ``numpy`` and ``random``.
        """
        super(IMBALANCECIFAR10, self).__init__(root, train, transform, target_transform, download)
        # numpy drives class subsampling and 'imb' noise; `random` drives the
        # noise-index shuffle and 'sym' noise.
        np.random.seed(rand_number)
        random.seed(rand_number)

        self.g_idx = g_idx
        self.is_train = train
        self.transform = transform
        self.imb_factor = imb_factor
        self.type = type

        if train:
            img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imb_factor)
            self.gen_imbalanced_data(img_num_list, self.type)

            noise_file = os.path.join(root, 'cifar' + str(self.cls_num) + '_' + imb_type + '_' + str(imb_factor) + '_' + noise_mode + '_' + str(noise_ratio))
            self.get_noisy_data(self.cls_num, noise_file, noise_mode, noise_ratio)

        # Alias (same list object) for code that reads `.labels`.
        self.labels = self.targets

    def __getitem__(self, index):
        """Return ``(img, target, group, index)`` on train, ``(img, target, index)`` on test."""
        if self.is_train:
            img, target = self.data[index], self.targets[index]
            img = Image.fromarray(img)
            # Mirror the parent's handling so both splits treat the transforms alike.
            if self.transform is not None:
                img = self.transform(img)
            if self.target_transform is not None:
                target = self.target_transform(target)
            g = self.group_label[index]
            return img, target, g, index
        img, target = super(IMBALANCECIFAR10, self).__getitem__(index)
        return img, target, index

    def get_img_num_per_cls(self, cls_num, imb_type, imb_factor):
        """Return the number of images to keep for each of ``cls_num`` classes.

        ``'exp'``: class k keeps ``img_max * imb_factor ** (k / (cls_num - 1))``.
        ``'step'``: the first ``cls_num // 2`` classes keep ``img_max``, the
        rest keep ``img_max * imb_factor``. Otherwise every class keeps ``img_max``.
        ``img_max`` is the current dataset size divided by ``cls_num``.
        """
        img_max = len(self.data) / cls_num
        if imb_type == 'exp':
            return [int(img_max * (imb_factor ** (cls_idx / (cls_num - 1.0))))
                    for cls_idx in range(cls_num)]
        if imb_type == 'step':
            head = cls_num // 2
            # `cls_num - head` so an odd class count still yields cls_num entries.
            return [int(img_max)] * head + [int(img_max * imb_factor)] * (cls_num - head)
        return [int(img_max)] * cls_num

    def gen_imbalanced_data(self, img_num_per_cls, type='knn'):
        """Subsample each class and attach a group label to every kept sample.

        Args:
            img_num_per_cls: images to keep per class, in the order of
                ``np.unique(self.targets)``.
            type: ``'knn'`` clusters ResNet18 outputs with k-means; any other
                value takes the argmax of a fresh ``g_idx``-way linear head.

        Sets ``self.data``, ``self.targets``, ``self.num_per_cls_dict`` and
        ``self.group_label``. Group labels are cached as JSON and reused when
        the cache file already exists.
        """
        import collections

        new_data = []
        new_targets = []
        targets_np = np.array(self.targets, dtype=np.int64)
        classes = np.unique(targets_np)
        self.num_per_cls_dict = dict()
        for the_class, the_img_num in zip(classes, img_num_per_cls):
            idx = np.where(targets_np == the_class)[0]
            np.random.shuffle(idx)
            selec_idx = idx[:the_img_num]
            # Use the number actually selected so data and targets stay aligned
            # even when more images are requested than the class has.
            kept = len(selec_idx)
            self.num_per_cls_dict[the_class] = kept
            new_data.append(self.data[selec_idx, ...])
            new_targets.extend([the_class] * kept)
        self.data = np.vstack(new_data)
        self.targets = new_targets

        suffix = '_knn' if type == 'knn' else ''
        group_file = 'data_longtail/' + f'cifar{len(classes)}_tr{self.imb_factor}_group_{self.g_idx}{suffix}'

        if not os.path.exists(group_file):
            if type == 'knn':
                group_label = self._cluster_group_labels(len(classes))
            else:
                group_label = self._predict_group_labels()
            os.makedirs(os.path.dirname(group_file), exist_ok=True)
            with open(group_file, "w") as f:
                json.dump(group_label, f)
        with open(group_file, "r") as f:
            self.group_label = json.load(f)
        print(collections.Counter(self.group_label))

    def _frozen_resnet18(self):
        """Return a pretrained torchvision ResNet18 with every parameter frozen, moved to CUDA."""
        model = tor_models.resnet18(pretrained=True)
        for param in model.parameters():
            param.requires_grad = False
        return model

    def _forward_all(self, model):
        """Run every image in ``self.data`` through ``model`` one at a time; return outputs stacked on dim 0.

        Images pass through ``self.transform`` first, so any random
        augmentation it contains is applied here as well.
        """
        outputs = []
        with torch.no_grad():
            for img in self.data:
                x = self.transform(Image.fromarray(img)).cuda().unsqueeze(0)
                outputs.append(model(x))
        return torch.cat(outputs, dim=0)

    def _predict_group_labels(self):
        """Group samples by argmax of a randomly initialised ``g_idx``-way head on frozen ResNet18."""
        model = self._frozen_resnet18()
        model.fc = nn.Linear(model.fc.in_features, self.g_idx)
        model.cuda()
        model.eval()
        out = self._forward_all(model)
        return torch.max(out, 1)[1].tolist()

    def _cluster_group_labels(self, num_clusters):
        """Group samples by k-means over frozen ResNet18 outputs; return the last of 10 runs.

        NOTE(review): the ImageNet classifier head is kept, so clustering runs
        on the 1000-way outputs rather than penultimate features, and the
        cluster count is the class count rather than ``g_idx`` even though
        ``g_idx`` is in the cache file name -- confirm both are intended.
        """
        model = self._frozen_resnet18()
        model.cuda()
        model.eval()
        # Features are extracted once and reused by every k-means run.
        features = self._forward_all(model)
        group_label = []
        for _ in range(10):
            cluster_ids_x, cluster_centers = kmeans(X=features, num_clusters=num_clusters, distance='euclidean', device='cuda')
            norm_center = F.normalize(cluster_centers, dim=1)
            similarity = torch.matmul(norm_center, norm_center.T)
            print(f'The similarify matrix of cluster centers are\n{similarity}')
            percentage = [np.round((torch.sum(cluster_ids_x == i) / cluster_ids_x.shape[0]).item() * 100, 3)
                          for i in range(num_clusters)]
            print(f'Each cluster has {percentage}% data')
            group_label = cluster_ids_x.tolist()
        return group_label

    def get_noisy_data(self, cls_num, noise_file, noise_mode, noise_ratio):
        """Replace ``self.targets`` with noisy labels loaded from, or written to, ``noise_file``.

        The pre-noise labels are stored in ``self.clean_targets`` and
        ``self.num_per_cls_dict`` is updated to count the noisy labels.

        Raises:
            ValueError: if labels must be generated and ``noise_mode`` is
                neither ``'sym'`` nor ``'imb'``.
        """
        train_label = self.targets

        if os.path.exists(noise_file):
            with open(noise_file, "r") as f:
                noise_label = json.load(f)
        else:
            # Fail fast: an unknown mode would otherwise produce (and cache)
            # an empty label list.
            if noise_mode not in ('sym', 'imb'):
                raise ValueError("noise_mode must be 'sym' or 'imb', got %r" % (noise_mode,))
            num_train = len(self.targets)
            idx = list(range(num_train))
            random.shuffle(idx)
            cls_num_list = self.get_cls_num_list()
            # The first `noise_ratio * num_train` shuffled indices get a flipped
            # label; a set keeps the per-sample membership test O(1).
            noise_idx = set(idx[:int(noise_ratio * num_train)])

            noise_label = []
            if noise_mode == 'sym':
                for i in range(num_train):
                    if i in noise_idx:
                        # Offset of 1..cls_num-1 (mod cls_num) always lands on another class.
                        newlabel = (random.randint(1, cls_num - 1) + train_label[i]) % cls_num
                        assert newlabel != train_label[i]
                        noise_label.append(newlabel)
                    else:
                        noise_label.append(train_label[i])
            else:  # 'imb'
                # Row c: flip distribution from class c, proportional to the
                # other classes' sizes with the self-transition zeroed.
                p = np.array([cls_num_list for _ in range(cls_num)])
                for i in range(cls_num):
                    p[i][i] = 0
                p = p / p.sum(axis=1, keepdims=True)
                for i in range(num_train):
                    if i in noise_idx:
                        newlabel = np.random.choice(cls_num, p=p[train_label[i]])
                        assert newlabel != train_label[i]
                        noise_label.append(newlabel)
                    else:
                        noise_label.append(train_label[i])

            # int64 (not int8) so class indices >= 128 cannot overflow; tolist()
            # yields plain ints for JSON.
            noise_label = np.array(noise_label, dtype=np.int64).tolist()
            print("save noisy labels to %s ..." % noise_file)
            with open(noise_file, "w") as f:
                json.dump(noise_label, f)

        self.clean_targets = self.targets[:]
        self.targets = noise_label

        # Keep per-class counts in sync with the noisy labels.
        for c1, c0 in zip(self.targets, self.clean_targets):
            if c1 != c0:
                self.num_per_cls_dict[c1] += 1
                self.num_per_cls_dict[c0] -= 1

    def get_cls_num_list(self):
        """Return ``num_per_cls_dict`` as a list indexed by class id ``0..cls_num-1``."""
        return [self.num_per_cls_dict[i] for i in range(self.cls_num)]



class IMBALANCECIFAR100(IMBALANCECIFAR10):
    """Long-tailed, optionally noisy CIFAR-100 built on ``IMBALANCECIFAR10``.

    Only the class count and the CIFAR-100 archive/checksum metadata are
    overridden here; subsampling, grouping and noise logic are inherited.
    See https://www.cs.toronto.edu/~kriz/cifar.html.
    """
    cls_num = 100

    # Archive location and integrity checksums for the CIFAR-100 python release.
    base_folder = 'cifar-100-python'
    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'

    train_list = [['train', '16019d7e3df5f24257cddd939b257f8d']]
    test_list = [['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc']]

    meta = dict(
        filename='meta',
        key='fine_label_names',
        md5='7973b15100ade9c7d40fb424638fde48',
    )


if __name__ == '__main__':
    # Smoke test: build the long-tailed CIFAR-100 training split and show its first sample.
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    trainset = IMBALANCECIFAR100(root='./data', train=True,
                                 download=True, transform=transform)
    # The training split yields four items per sample: (img, target, group, index).
    data, label, group, index = trainset[0]
    print(f'{len(trainset)} training samples; first: shape={tuple(data.shape)}, '
          f'label={label}, group={group}, index={index}')
