import os
import torch
import random
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import TensorDataset, DataLoader
from utils import set_seed
from torch.utils.data import Subset
from torchvision.datasets import CIFAR10, CIFAR100, ImageFolder,FashionMNIST

def dirichlet_distribution(data_name, dataroot, num_clients, seed, least_nums, alpha):
    """Split a training set across clients with Dirichlet label skew.

    For every class a proportion vector over the clients is drawn from
    ``Dirichlet(alpha, ..., alpha)`` and that class's sample indices are cut
    according to the cumulative proportions. The draw is repeated until the
    smallest client holds strictly more than ``least_nums`` samples.

    Args:
        data_name: One of ``'fashionmnist'``, ``'cifar10'``, ``'cifar100'``,
            ``'tiny_imageNet'``.
        dataroot: Directory the dataset lives in (or is downloaded to). For
            ``'tiny_imageNet'`` it must contain ``tiny-imagenet-200/train``
            and ``tiny-imagenet-200/test`` in ``ImageFolder`` layout.
        num_clients: Number of client partitions to create.
        seed: Value handed to ``set_seed`` before any random draw.
        least_nums: Every client must end up with more than this many samples.
        alpha: Dirichlet concentration parameter shared by all clients.

    Returns:
        Tuple ``(clients_data, mean, std, test_loader)``: a list with one
        ``Subset`` of the training set per client, the normalization mean and
        std lists, and a non-shuffling ``DataLoader`` (batch size 256, 4
        workers) over the test set.

    Raises:
        ValueError: If ``data_name`` is unknown, or if ``least_nums`` cannot be
            met because ``num_clients * least_nums`` is not smaller than the
            training-set size (the retry loop would never terminate).
    """
    set_seed(seed)

    # data_name -> (dataset class, per-channel mean, per-channel std)
    dataset_specs = {
        'cifar10': (CIFAR10, [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
        'cifar100': (CIFAR100, [0.5071, 0.4865, 0.4409], [0.2673, 0.2564, 0.2762]),
        'tiny_imageNet': (ImageFolder, [0.4802, 0.4481, 0.3975], [0.2770, 0.2691, 0.2821]),
        'fashionmnist': (FashionMNIST, [0.2861], [0.3530]),
    }
    if data_name not in dataset_specs:
        raise ValueError("choose data_name from ['fashionmnist', 'cifar10' , 'cifar100', 'tiny_imageNet']")
    data_obj, mean, std = dataset_specs[data_name]
    normalization = transforms.Normalize(mean, std)

    # Tiny ImageNet and FashionMNIST are resized to 32x32; CIFAR is used as is.
    steps = [transforms.ToTensor(), normalization]
    if data_name in ('tiny_imageNet', 'fashionmnist'):
        steps.insert(0, transforms.Resize((32, 32)))
    transform = transforms.Compose(steps)

    if data_name == 'tiny_imageNet':
        train_data = ImageFolder(os.path.join(dataroot, 'tiny-imagenet-200/train'), transform=transform)
        # NOTE(review): ImageFolder needs one sub-directory per class; confirm
        # the test directory under dataroot is laid out that way.
        test_data = ImageFolder(os.path.join(dataroot, 'tiny-imagenet-200/test'), transform=transform)
    else:
        train_data = data_obj(root=dataroot, train=True, download=True, transform=transform)
        # The test split is loaded for every dataset (FashionMNIST previously
        # skipped it, which made the DataLoader construction below fail).
        test_data = data_obj(root=dataroot, train=False, download=True, transform=transform)

    if num_clients * least_nums >= len(train_data):
        raise ValueError(
            f"least_nums={least_nums} cannot be exceeded by all {num_clients} clients "
            f"with only {len(train_data)} training samples"
        )

    # The per-class index lists do not depend on the random draw, so they are
    # built once rather than on every retry.
    targets = np.asarray(train_data.targets)
    n_classes = len(train_data.classes)
    class_idcs = [np.argwhere(targets == y).flatten() for y in range(n_classes)]

    # At least one partition is always produced, even for negative least_nums.
    while True:
        # Row k holds the share of class k that goes to each client.
        label_distribution = np.random.dirichlet([alpha] * num_clients, n_classes)
        per_client_chunks = [[] for _ in range(num_clients)]

        for c, fracs in zip(class_idcs, label_distribution):
            # Cumulative shares scaled to the class size give the cut points.
            cut_points = (np.cumsum(fracs)[:-1] * len(c)).astype(int)
            for i, idcs in enumerate(np.split(c, cut_points)):
                per_client_chunks[i].append(idcs)

        clients_idcs = [np.concatenate(chunks) for chunks in per_client_chunks]
        min_value = min(len(idcs) for idcs in clients_idcs)
        if min_value > least_nums:
            break

    clients_data = [Subset(train_data, idcs) for idcs in clients_idcs]
    test_loader = DataLoader(test_data, batch_size=256, shuffle=False, num_workers=4)
    return clients_data, mean, std, test_loader


def get_classes(data_name, dataroot, num_clients, num_classes_per_client, seed):
    """Split a training set so each client only sees a few random classes.

    Every client is assigned ``num_classes_per_client`` distinct classes and a
    random weight in ``[0.4, 0.6)`` per assigned class. Weights are normalized
    per class across the clients that share it and converted into integer
    sample quotas; the rounding remainder of a class goes to the first client
    that holds it. Samples of each class are shuffled and handed out to the
    clients in client order according to those quotas.

    Args:
        data_name: One of ``'fashionmnist'``, ``'cifar10'``, ``'cifar100'``,
            ``'tiny_imageNet'``.
        dataroot: Directory the dataset lives in (or is downloaded to). For
            ``'tiny_imageNet'`` it must contain ``tiny-imagenet-200/train``
            and ``tiny-imagenet-200/test`` in ``ImageFolder`` layout.
        num_clients: Number of client partitions to create.
        num_classes_per_client: How many distinct classes each client receives.
        seed: Value handed to ``set_seed`` before any random draw.

    Returns:
        Tuple ``(clients_data, mean, std, test_loader)``: a list with one
        ``Subset`` of the training set per client, the normalization mean and
        std lists, and a non-shuffling ``DataLoader`` (batch size 256, 4
        workers) over the test set.

    Raises:
        ValueError: If ``data_name`` is not one of the supported datasets.
    """
    set_seed(seed)

    # data_name -> (dataset class, number of classes, train samples per class,
    #               per-channel mean, per-channel std)
    dataset_specs = {
        'cifar10': (CIFAR10, 10, 5000, [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
        'cifar100': (CIFAR100, 100, 500, [0.5071, 0.4865, 0.4409], [0.2673, 0.2564, 0.2762]),
        'tiny_imageNet': (ImageFolder, 200, 500, [0.4802, 0.4481, 0.3975], [0.2770, 0.2691, 0.2821]),
        'fashionmnist': (FashionMNIST, 10, 6000, [0.2861], [0.3530]),
    }
    if data_name not in dataset_specs:
        raise ValueError("choose data_name from ['fashionmnist', 'cifar10', 'tiny_imageNet', 'cifar100']")
    data_obj, total_classes, samples_per_class, mean, std = dataset_specs[data_name]
    class_nums = [samples_per_class] * total_classes
    normalization = transforms.Normalize(mean, std)

    # Tiny ImageNet and FashionMNIST are resized to 32x32; CIFAR is used as is.
    steps = [transforms.ToTensor(), normalization]
    if data_name in ('tiny_imageNet', 'fashionmnist'):
        steps.insert(0, transforms.Resize((32, 32)))
    transform = transforms.Compose(steps)

    if data_name == 'tiny_imageNet':
        train_data = ImageFolder(os.path.join(dataroot, 'tiny-imagenet-200/train'), transform=transform)
        # NOTE(review): ImageFolder needs one sub-directory per class; confirm
        # the test directory under dataroot is laid out that way.
        test_data = ImageFolder(os.path.join(dataroot, 'tiny-imagenet-200/test'), transform=transform)
    else:
        train_data = data_obj(root=dataroot, train=True, download=True, transform=transform)
        # The test split is loaded for every dataset (FashionMNIST previously
        # skipped it, which made the DataLoader construction below fail).
        test_data = data_obj(root=dataroot, train=False, download=True, transform=transform)
    targets = np.asarray(train_data.targets)

    # The random draws stay in two separate passes (all class choices first,
    # then all weights) so a given seed keeps producing the same partition.
    client_classes = {}
    for client_id in range(num_clients):
        client_classes[client_id] = np.random.choice(
            total_classes, size=num_classes_per_client, replace=False
        )

    client_weights = {}
    for client_id, classes in client_classes.items():
        client_weights[client_id] = {label: np.random.uniform(0.4, 0.6) for label in classes}

    # Total raw weight each class received across all clients.
    weights_sum = {label: 0 for label in range(total_classes)}
    for weights in client_weights.values():
        for label, weight in weights.items():
            weights_sum[label] += weight

    # Replace each raw weight by an integer sample quota: the client's share
    # of the class, rounded down.
    client_label_sum = {label: 0 for label in range(total_classes)}
    for weights in client_weights.values():
        for label, weight in weights.items():
            weights[label] = int(weight / weights_sum[label] * class_nums[label])
            client_label_sum[label] += weights[label]

    # Rounding down can leave part of a class unassigned; the first client
    # that holds the class takes the remainder. Classes whose quotas sum to
    # zero are left untouched.
    for label, assigned in client_label_sum.items():
        remainder = class_nums[label] - assigned
        if assigned == 0 or remainder == 0:
            continue
        for weights in client_weights.values():
            if label in weights:
                weights[label] += remainder
                break

    # Group sample indices by label in a single pass, then shuffle each group
    # in ascending label order (the order the RNG is consumed in).
    label_idcs = {label: [] for label in range(total_classes)}
    for sample_idx, label in enumerate(targets.tolist()):
        if label in label_idcs:
            label_idcs[label].append(sample_idx)
    for label in range(total_classes):
        random.shuffle(label_idcs[label])

    # Hand out consecutive slices of each shuffled group; `cursor` remembers
    # how far into every class the previous clients have already taken.
    clients_idcs = []
    cursor = {label: 0 for label in range(total_classes)}
    for client_id, weights in client_weights.items():
        client_idcs = []
        for label, quota in weights.items():
            start = cursor[label]
            client_idcs += label_idcs[label][start:start + quota]
            cursor[label] = start + quota
        print(client_id,'\t',weights)
        clients_idcs.append(client_idcs)

    clients_data = [Subset(train_data, idcs) for idcs in clients_idcs]
    test_loader = DataLoader(test_data, batch_size=256, shuffle=False, num_workers=4)
    return clients_data, mean, std, test_loader









    
    







