import numpy as np
import torch
import torch.utils.data as data
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
import torchvision
import sys
import os
from PIL import Image
# Per-channel statistics shared by every pipeline below: Normalize maps each
# ToTensor() value v (in [0, 1]) to (v - 0.5) / 0.5, i.e. into [-1, 1].
_CHANNEL_MEAN = (0.5, 0.5, 0.5)
_CHANNEL_STD = (0.5, 0.5, 0.5)

# Training augmentation: random 32x32 crop taken from the image padded by 4
# pixels on each side, random horizontal flip, then tensor + normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
])

# Evaluation pipeline: no augmentation, only tensor conversion + normalization.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
])

# Tiny ImageNet evaluation pipeline: resize to 32x32 (the crop size used by
# transform_train) before the same tensor conversion + normalization.
transform_test_tiny_imagenet = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
])

class TinyImageNet(Dataset):
    """Tiny ImageNet split read from an extracted ``tiny-imagenet-200`` style tree.

    Expected layout under ``root``:

    * ``train/<wnid>/.../*.JPEG`` -- one sub-directory per class id (wnid);
    * ``val/images/*.JPEG`` plus ``val/val_annotations.txt``, a tab-separated
      file whose first two columns are ``<image file name>`` and ``<wnid>``.

    ``self.images`` holds ``(path, target)`` pairs, where ``target`` is the
    position of the image's wnid in the sorted list of wnids found for the
    split (train: class directories; val: wnids named in the annotations).
    The two splits therefore agree on target indices only when they contain
    the same set of wnids.

    Args:
        root: dataset root directory containing ``train/`` and/or ``val/``.
        train: read the train split when true, otherwise the val split.
        transform: optional callable applied to each loaded RGB PIL image.
    """

    def __init__(self, root, train=True, transform=None):
        self.Train = train
        self.root_dir = root
        self.transform = transform
        self.train_dir = os.path.join(self.root_dir, "train")
        self.val_dir = os.path.join(self.root_dir, "val")
        if self.Train:
            self._create_class_idx_dict_train()
        else:
            self._create_class_idx_dict_val()
        self._make_dataset(self.Train)

    def _create_class_idx_dict_train(self):
        """Index the class directories under ``train/`` in sorted order."""
        classes = sorted(
            entry.name for entry in os.scandir(self.train_dir) if entry.is_dir()
        )
        self.tgt_idx_to_class = dict(enumerate(classes))
        self.class_to_tgt_idx = {wnid: i for i, wnid in enumerate(classes)}

    def _create_class_idx_dict_val(self):
        """Read ``val/val_annotations.txt`` into image->wnid and wnid->index maps."""
        annotations_file = os.path.join(self.val_dir, "val_annotations.txt")
        self.val_img_to_class = {}
        with open(annotations_file, "r") as fo:
            for line in fo:
                fields = line.rstrip("\r\n").split("\t")
                if len(fields) < 2:
                    continue  # blank or malformed line: nothing to map
                self.val_img_to_class[fields[0]] = fields[1]
        classes = sorted(set(self.val_img_to_class.values()))
        self.tgt_idx_to_class = dict(enumerate(classes))
        self.class_to_tgt_idx = {wnid: i for i, wnid in enumerate(classes)}

    def _make_dataset(self, Train=True):
        """Fill ``self.images`` with ``(path, target)`` pairs for the chosen split."""
        self.images = []
        if Train:
            for wnid, target in sorted(self.class_to_tgt_idx.items()):
                class_dir = os.path.join(self.train_dir, wnid)
                # os.walk picks up images in "<wnid>/images/" as well as any
                # placed directly under "<wnid>/"; sorting keeps order stable.
                for dirpath, _, files in sorted(os.walk(class_dir)):
                    for fname in sorted(files):
                        if fname.endswith(".JPEG"):
                            self.images.append((os.path.join(dirpath, fname), target))
        else:
            image_dir = os.path.join(self.val_dir, "images")
            for fname in sorted(os.listdir(image_dir)):
                wnid = self.val_img_to_class.get(fname)
                # Images with no annotation line have no label and are skipped.
                if fname.endswith(".JPEG") and wnid is not None:
                    self.images.append(
                        (os.path.join(image_dir, fname), self.class_to_tgt_idx[wnid])
                    )

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply ``transform`` if set, return ``(sample, target)``."""
        img_path, tgt = self.images[idx]
        # Decode through our own handle so the file is closed when the block
        # exits; convert() forces the pixel data to load while it is open.
        with open(img_path, "rb") as f:
            sample = Image.open(f).convert("RGB")
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, tgt

def cifar10_global(args, root):
    """Return ``(train_loader, test_loader)`` over the full CIFAR-10 dataset.

    Data is downloaded into ``root`` if missing. The train loader uses
    ``transform_train`` and shuffles; the test loader uses ``transform_test``
    and keeps dataset order. Both batch with ``args.local_batch_size``.
    """
    train_set = datasets.CIFAR10(root, train=True, transform=transform_train, download=True)
    test_set = datasets.CIFAR10(root, train=False, transform=transform_test, download=True)
    batch_size = args.local_batch_size
    train_loader = data.DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
    test_loader = data.DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

def cifar10_noiid(args,root):
    """Build per-client CIFAR-10 loaders from a Dirichlet label split.

    The train and test splits are each partitioned with ``set_split`` in two
    separate calls, so a client's test subset is drawn with class proportions
    sampled independently of its train subset.

    Args:
        args: needs ``all_client`` and ``local_batch_size``, plus whatever
            ``set_split`` reads (``num_classes``, ``alpha``).
        root: directory for the torchvision CIFAR-10 files (downloaded if absent).

    Returns:
        ``(dataloader_train_dict, dataloader_test_dict, train_len_dict,
        test_len_dict)``: dicts keyed by client id ``0 .. args.all_client - 1``
        holding each client's loaders and subset sizes.
    """
    dataset_train = datasets.CIFAR10(root, train=True, transform=transform_train, download=True)
    dataset_test = datasets.CIFAR10(root, train=False, transform=transform_test, download=True)

    # Only the labels are needed to partition; the raw image arrays are not.
    data_train_dict = set_split(args=args, y=np.array(dataset_train.targets), train=True)
    data_test_dict = set_split(args=args, y=np.array(dataset_test.targets), train=False)

    train_len_dict, test_len_dict = dict(), dict()
    dataloader_train_dict, dataloader_test_dict = dict(), dict()
    for idx in range(args.all_client):
        train_indices = data_train_dict[idx]
        test_indices = data_test_dict[idx]
        train_len_dict[idx] = len(train_indices)
        test_len_dict[idx] = len(test_indices)

        dataloader_train_dict[idx] = data.DataLoader(
            data.Subset(dataset_train, train_indices),
            batch_size=args.local_batch_size, shuffle=True, num_workers=0,
        )
        # NOTE(review): the per-client test loaders shuffle (kept as-is), while
        # the *_global helpers evaluate with shuffle=False — confirm intent.
        dataloader_test_dict[idx] = data.DataLoader(
            data.Subset(dataset_test, test_indices),
            batch_size=args.local_batch_size, shuffle=True, num_workers=0,
        )

    return dataloader_train_dict, dataloader_test_dict, train_len_dict, test_len_dict

def cifar100_noiid(args,root):
    """Build per-client CIFAR-100 loaders from a Dirichlet label split.

    The train and test splits are each partitioned with ``set_split`` in two
    separate calls, so a client's test subset is drawn with class proportions
    sampled independently of its train subset.

    Args:
        args: needs ``all_client`` and ``local_batch_size``, plus whatever
            ``set_split`` reads (``num_classes``, ``alpha``).
        root: directory for the torchvision CIFAR-100 files (downloaded if absent).

    Returns:
        ``(dataloader_train_dict, dataloader_test_dict, train_len_dict,
        test_len_dict)``: dicts keyed by client id ``0 .. args.all_client - 1``
        holding each client's loaders and subset sizes.
    """
    dataset_train = datasets.CIFAR100(root, train=True, transform=transform_train, download=True)
    dataset_test = datasets.CIFAR100(root, train=False, transform=transform_test, download=True)

    # Only the labels are needed to partition; the raw image arrays are not.
    data_train_dict = set_split(args=args, y=np.array(dataset_train.targets), train=True)
    data_test_dict = set_split(args=args, y=np.array(dataset_test.targets), train=False)

    train_len_dict, test_len_dict = dict(), dict()
    dataloader_train_dict, dataloader_test_dict = dict(), dict()
    for idx in range(args.all_client):
        train_indices = data_train_dict[idx]
        test_indices = data_test_dict[idx]
        train_len_dict[idx] = len(train_indices)
        test_len_dict[idx] = len(test_indices)

        dataloader_train_dict[idx] = data.DataLoader(
            data.Subset(dataset_train, train_indices),
            batch_size=args.local_batch_size, shuffle=True, num_workers=0,
        )
        # NOTE(review): the per-client test loaders shuffle (kept as-is), while
        # the *_global helpers evaluate with shuffle=False — confirm intent.
        dataloader_test_dict[idx] = data.DataLoader(
            data.Subset(dataset_test, test_indices),
            batch_size=args.local_batch_size, shuffle=True, num_workers=0,
        )

    return dataloader_train_dict, dataloader_test_dict, train_len_dict, test_len_dict
def cifar100_distill(args, root):
    """Return a shuffling loader over the CIFAR-100 *test* split for distillation.

    Images go through ``transform_test`` (no augmentation); batches hold
    ``args.distill_batch_size`` samples. Data is downloaded into ``root`` if missing.
    """
    return data.DataLoader(
        dataset=datasets.CIFAR100(root=root, transform=transform_test, train=False, download=True),
        batch_size=args.distill_batch_size,
        shuffle=True,
    )

def cifar100_global(args, root):
    """Return ``(train_loader, test_loader)`` over the full CIFAR-100 dataset.

    Data is downloaded into ``root`` if missing. The train loader uses
    ``transform_train`` and shuffles; the test loader uses ``transform_test``
    and keeps dataset order. Both batch with ``args.local_batch_size``.
    """
    splits = [
        datasets.CIFAR100(root, train=True, transform=transform_train, download=True),
        datasets.CIFAR100(root, train=False, transform=transform_test, download=True),
    ]
    train_loader, test_loader = (
        data.DataLoader(dataset=split, batch_size=args.local_batch_size, shuffle=is_train)
        for split, is_train in zip(splits, (True, False))
    )
    return train_loader, test_loader

def tiny_imagenet_distill(args, root):
    """Return a shuffling loader over the Tiny ImageNet validation split.

    Images go through ``transform_test_tiny_imagenet``; batches hold
    ``args.distill_batch_size`` samples.
    """
    val_split = TinyImageNet(root, train=False, transform=transform_test_tiny_imagenet)
    return DataLoader(val_split, batch_size=args.distill_batch_size, shuffle=True)


def set_split(args, y, train=True, min_require_size=128, max_tries=None):
    """Partition sample indices among clients with a per-class Dirichlet split.

    For each class ``k`` in ``range(args.num_classes)`` the indices whose label
    equals ``k`` are shuffled and cut into ``args.all_client`` consecutive
    pieces sized by proportions drawn from ``Dirichlet(args.alpha, ...)``;
    piece ``i`` goes to client ``i``. The whole draw is repeated until every
    client holds at least ``min_require_size`` indices, then each client's
    list is shuffled. The sequence of ``np.random`` calls is the same as in
    the original fixed-threshold version, so seeded partitions are unchanged.

    Args:
        args: object exposing ``num_classes``, ``all_client`` and ``alpha``.
        y: numpy array of labels (compared with ``==`` against ``0..num_classes-1``).
        train: accepted for call-site compatibility; not used by the split.
        min_require_size: minimum number of indices per client (default 128,
            the previously hard-coded value).
        max_tries: maximum number of redraws before giving up; ``None``
            retries indefinitely, as before.

    Returns:
        dict mapping client id ``0 .. all_client-1`` to a list of indices into ``y``.

    Raises:
        ValueError: if there are no clients, or fewer labelled samples than
            ``min_require_size * all_client`` (the redraw loop could never stop).
        RuntimeError: if ``max_tries`` redraws all miss the size requirement.
    """
    n_clients = args.all_client
    if n_clients < 1:
        raise ValueError(f"set_split needs at least one client, got all_client={n_clients}")

    # Per-class indices in ascending order. Each redraw shuffles a fresh copy,
    # exactly as shuffling a new np.where(...) result would.
    class_indices = [np.where(y == k)[0] for k in range(args.num_classes)]
    n_assignable = sum(len(idx) for idx in class_indices)
    if n_assignable < min_require_size * n_clients:
        raise ValueError(
            f"cannot give each of {n_clients} clients {min_require_size} samples: "
            f"only {n_assignable} labelled samples are available"
        )

    attempt = 0
    while True:
        attempt += 1
        idx_batch = [[] for _ in range(n_clients)]
        for base in class_indices:
            idx_k = base.copy()
            np.random.shuffle(idx_k)
            proportions = np.random.dirichlet(np.repeat(args.alpha, n_clients))
            # Cumulative proportions -> cut points into idx_k; the final cut
            # (== len(idx_k)) is dropped so np.split yields n_clients pieces.
            cuts = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
            idx_batch = [part + piece.tolist() for part, piece in zip(idx_batch, np.split(idx_k, cuts))]

        if min(len(part) for part in idx_batch) >= min_require_size:
            break
        if max_tries is not None and attempt >= max_tries:
            raise RuntimeError(
                f"no split with >= {min_require_size} samples per client after {attempt} tries"
            )

    client_indices = dict()
    for i in range(n_clients):
        np.random.shuffle(idx_batch[i])
        client_indices[i] = idx_batch[i]
    return client_indices
