import sys, os
base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(base_path)

import numpy as np
import torch
import torchvision
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from PIL import Image
import os

def prepare_cifar_data(args):
    """Build the evaluation (validation and test) loaders for CIFAR-10.

    The validation set is the tail of the CIFAR-10 training split: every
    sample whose index is at or beyond ``int(len(split) * 0.7)``.  The test
    set is the CIFAR-10 test split.  No augmentation is applied; images are
    only converted to tensors and normalized with the fixed per-channel
    constants below.  Both datasets are read from (and, if missing,
    downloaded to) ``../data``.

    Args:
        args: Object exposing ``workers``, the number of DataLoader worker
            processes.

    Returns:
        A ``(val_loaders, test_loaders)`` tuple.  Each element is a one-item
        list holding an unshuffled DataLoader with batch size 512.
    """
    eval_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    def _eval_loader(dataset):
        # Evaluation loaders share the same settings: fixed order, batch 512.
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=512,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
        )

    test_split = torchvision.datasets.CIFAR10(
        root='../data', train=False, download=True, transform=eval_transform)
    test_loaders = [_eval_loader(test_split)]

    full_train_split = torchvision.datasets.CIFAR10(
        root='../data', train=True, download=True, transform=eval_transform)
    total = len(full_train_split)
    # Hold out the last 30% (by index) of the training split for validation;
    # IMBALANCECIFAR10 samples its training data from the first 70%.
    held_out_indices = list(range(int(total * 0.7), total))
    val_subset = torch.utils.data.Subset(full_train_split, held_out_indices)
    val_loaders = [_eval_loader(val_subset)]

    return val_loaders, test_loaders

def shuffle_cifar_data(args, imbalace_factor):
    """Create one shuffled, augmented training loader per imbalance factor.

    For every position ``i`` in ``imbalace_factor`` a fresh
    ``IMBALANCECIFAR10`` dataset is built with an exponential imbalance
    profile of strength ``imbalace_factor[i]`` and wrapped in a DataLoader.

    Args:
        args: Object exposing ``batch`` (batch size), ``local_iter`` and
            ``workers`` (number of DataLoader worker processes).
        imbalace_factor: Indexable, sized collection of imbalance factors,
            one per client.

    Returns:
        A list of DataLoaders, in the same order as ``imbalace_factor``.
    """
    augmentation = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    def _client_loader(client_idx):
        # The dataset keeps roughly ``percent`` of the loaded split.  The
        # constant 50000 is presumably the CIFAR-10 training-set size, which
        # would give each client about ``batch * local_iter`` samples --
        # TODO confirm.
        client_dataset = IMBALANCECIFAR10(
            root='../data',
            client_idx=client_idx,
            percent=args.batch*args.local_iter/50000,
            imb_type="exp",
            imb_factor=imbalace_factor[client_idx],
            train=True,
            download=True,
            transform=augmentation,
        )
        return torch.utils.data.DataLoader(
            client_dataset,
            batch_size=args.batch,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True,
        )

    return [_client_loader(client_idx) for client_idx in range(len(imbalace_factor))]

def balanced_cifar(args):
    """Create a shuffled, augmented training loader with equal class counts.

    The dataset is an ``IMBALANCECIFAR10`` built with ``imb_factor=1.0``;
    with the exponential profile that makes every class weight identical, so
    all classes are requested in equal numbers.  ``percent=0.0248`` sets the
    fraction of the loaded split that is kept.

    Args:
        args: Object exposing ``batch`` (batch size) and ``workers`` (number
            of DataLoader worker processes).

    Returns:
        A single DataLoader over the balanced subset.
    """
    augmentation = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    balanced_dataset = IMBALANCECIFAR10(
        root='../data',
        client_idx=0,
        percent=0.0248,
        imb_type="exp",
        imb_factor=1.0,
        train=True,
        download=True,
        transform=augmentation,
    )
    loader = torch.utils.data.DataLoader(
        balanced_dataset,
        batch_size=args.batch,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
    )
    return loader


class IMBALANCECIFAR10(torchvision.datasets.CIFAR10):
    """CIFAR-10 subset whose class frequencies follow an imbalance profile.

    After the parent class has loaded the data, the samples are replaced by a
    random subset drawn from the first 70% (by index) of the loaded split.
    The per-class counts follow the profile chosen by ``imb_type`` and
    ``imb_factor`` and are rescaled so that they sum to roughly
    ``len(data) * percent``.  Which class receives which count is decided by
    a random permutation of the class labels.  All randomness comes from
    NumPy's global random state, so seed ``np.random`` for reproducibility.

    Args:
        root: Dataset root directory, forwarded to ``CIFAR10``.
        client_idx: Client identifier.  It currently does not influence the
            sampling; it is accepted so callers can keep passing it.
        percent: Fraction of the loaded split to keep (before integer
            truncation of the per-class counts).
        imb_type: ``'exp'`` for an exponentially decaying profile, ``'step'``
            for a two-level profile; any other value requests equal counts.
        imb_factor: Ratio between the smallest and the largest requested
            class count (before integer truncation).
        train: Forwarded to ``CIFAR10``.
        transform: Forwarded to ``CIFAR10``.
        target_transform: Forwarded to ``CIFAR10``.
        download: Forwarded to ``CIFAR10``.
    """
    cls_num = 10

    def __init__(self, root, client_idx, percent=0.1, imb_type='exp', imb_factor=0.01, train=True,
                 transform=None, target_transform=None,
                 download=False):
        super(IMBALANCECIFAR10, self).__init__(root, train, transform, target_transform, download)
        img_num_list = np.array(self.get_img_num_per_cls(self.cls_num, imb_type, imb_factor))
        # Rescale the profile so the counts add up to ~len(data) * percent;
        # astype(int) truncates, so the realized total can be slightly lower.
        img_num_list = (img_num_list * len(self.data) * percent / img_num_list.sum()).astype(int)
        self.gen_imbalanced_data(img_num_list, client_idx)

    def get_img_num_per_cls(self, cls_num, imb_type, imb_factor):
        """Return the unscaled per-class sample counts for a profile.

        Args:
            cls_num: Number of classes.
            imb_type: ``'exp'``, ``'step'`` or anything else for a flat
                profile.
            imb_factor: Imbalance strength; the last class of the ``'exp'``
                profile and the second half of the ``'step'`` profile get
                ``imb_factor`` times the maximum count.

        Returns:
            A list of integer counts.  It has ``cls_num`` entries, except
            for ``'step'`` with an odd ``cls_num``, which yields
            ``2 * (cls_num // 2)`` entries.
        """
        img_max = len(self.data) / cls_num
        if imb_type == 'exp':
            # Geometric decay from img_max (class 0) to img_max * imb_factor
            # (last class).
            return [
                int(img_max * (imb_factor ** (cls_idx / (cls_num - 1.0))))
                for cls_idx in range(cls_num)
            ]
        if imb_type == 'step':
            half = cls_num // 2
            return [int(img_max)] * half + [int(img_max * imb_factor)] * half
        return [int(img_max)] * cls_num

    def gen_imbalanced_data(self, img_num_per_cls, client_idx):
        """Replace ``self.data`` / ``self.targets`` with the sampled subset.

        Only the first 70% (by index) of the loaded split is sampled from.
        Class labels are randomly permuted and paired, in order, with the
        counts in ``img_num_per_cls``; for each class that many samples are
        picked at random.  If a class has fewer samples available than
        requested, all of its available samples are taken.

        ``self.num_per_cls_dict`` maps each class label to the number of
        samples actually selected for it.

        Args:
            img_num_per_cls: Requested sample count for each (permuted)
                class.
            client_idx: Unused; classes are assigned by random permutation
                rather than by client index.
        """
        new_data = []
        new_targets = []

        train_targets = self.targets[:int(len(self.targets)*0.7)]
        train_data = self.data[:int(len(self.data)*0.7)]
        targets_np = np.array(train_targets, dtype=np.int64)
        classes = np.unique(targets_np)
        np.random.shuffle(classes)  # random class-to-count assignment
        self.num_per_cls_dict = dict()
        for the_class, the_img_num in zip(classes, img_num_per_cls):
            idx = np.where(targets_np == the_class)[0]
            np.random.shuffle(idx)
            selec_idx = idx[:the_img_num]
            # Use the number of samples actually drawn, not the requested
            # count, so that data and targets always have the same length.
            selected_num = len(selec_idx)
            self.num_per_cls_dict[the_class] = selected_num
            new_data.append(train_data[selec_idx, ...])
            new_targets.extend([the_class] * selected_num)
        new_data = np.vstack(new_data)
        self.data = new_data
        self.targets = new_targets

    def get_cls_num_list(self):
        """Return the per-class sample counts ordered by class label.

        Returns:
            A list whose ``i``-th entry is the count recorded for class
            ``i`` in ``self.num_per_cls_dict``.

        Raises:
            KeyError: If some class in ``range(cls_num)`` has no entry in
                ``self.num_per_cls_dict``.
        """
        return [self.num_per_cls_dict[cls_idx] for cls_idx in range(self.cls_num)]

