from operator import itemgetter
import torch
import random
from open_fl_net import net, CNN_SVHN, SimpleVGG, ImprovedCNN_SVHN, ResNet10
import math
import copy
import numpy as np
import json
import itertools
from torch import linalg as LA
import torchvision
from torch.utils.data import DataLoader, Subset
from collections import defaultdict
from numpy.random import dirichlet
from collections import defaultdict
from torchvision.models import resnet18, efficientnet_b0, densenet121, squeezenet1_1, resnet34
from torchvision import datasets

class client:
    """A federated-learning client.

    Holds the client's data loaders and local-training settings and offers
    local-training routines that start from a given ``state_dict`` and return
    either the updated weights or accumulated (weighted) gradients.

    Attributes:
        training_dataloader: Iterable of ``(images, labels)`` batches used by
            ``local_training_sync``.
        grad_cal_dataloader: Iterable of ``(images, labels)`` batches used by
            ``local_training_sync_model_diff``.
        num_SGD_training: Maximum number of SGD steps in ``local_training_sync``.
        num_SGD_grad_cal: Maximum number of SGD steps in
            ``local_training_sync_model_diff``.
        set_of_classes: Stored as given; not read by the methods of this class.
        cifar_model: Architecture name used for "cifar10" / "cifar100".
    """

    def __init__(self, training_dataloader, grad_cal_dataloader, num_SGD_training, num_SGD_grad_cal, set_of_classes, cifar_model) -> None:
        self.training_dataloader = training_dataloader
        self.grad_cal_dataloader = grad_cal_dataloader
        self.num_SGD_training = num_SGD_training
        self.set_of_classes = set_of_classes
        self.num_SGD_grad_cal = num_SGD_grad_cal
        self.cifar_model = cifar_model
        # NOTE(review): the pilot-gradient methods also read ``self.dataset`` and
        # ``self.local_SGD_batch_size_ll``, which are not set here; presumably
        # the caller assigns them after construction -- confirm.

    # ------------------------------------------------------------------ helpers

    def _build_local_model(self, gpu_cpu_device, model_parameter_dict, dataset_name, lr, momentum):
        """Create the model for ``dataset_name``, load ``model_parameter_dict`` and
        return ``(model, optimizer)``.

        CIFAR models are optimized with SGD + ``weight_decay=5e-4``; every other
        dataset uses plain SGD (``weight_decay=0``, PyTorch's default).

        Raises:
            ValueError: If ``self.cifar_model`` is not supported for the dataset.
        """
        weight_decay = 0.0
        if dataset_name == "SVHN":
            local_model = CNN_SVHN()
        elif dataset_name == "cifar10":
            # weights=None: the strict load_state_dict below overwrites every
            # tensor, so fetching the pretrained checkpoint was wasted I/O.
            if self.cifar_model == "resnet18":
                local_model = resnet18(weights=None)
                local_model.fc = torch.nn.Linear(local_model.fc.in_features, 10)
            elif self.cifar_model == "simpleVGG":
                local_model = SimpleVGG(num_classes=10)
            elif self.cifar_model == "squeezenet":
                local_model = squeezenet1_1(weights=None)
                local_model.classifier[1] = torch.nn.Conv2d(512, 10, kernel_size=(1, 1), stride=(1, 1))
                local_model.num_classes = 10
            else:
                raise ValueError("cifar10 only accepts resnet18 or simpleVGG or squeezenet")
            weight_decay = 5e-4
        elif dataset_name == "cifar100":
            if self.cifar_model == "resnet18":
                local_model = resnet18(weights=None)
                local_model.fc = torch.nn.Linear(local_model.fc.in_features, 100)
            elif self.cifar_model == "resnet34":
                local_model = resnet34(weights=None)
                local_model.fc = torch.nn.Linear(local_model.fc.in_features, 100)
            elif self.cifar_model == "densenet121":
                local_model = densenet121(weights=None)
                local_model.classifier = torch.nn.Linear(1024, 100)
            else:
                raise ValueError("cifar100 only accepts resnet18 or resnet34 or densenet121")
            weight_decay = 5e-4
        else:
            local_model = net(dataset_name)

        local_model.to(gpu_cpu_device)
        local_model.load_state_dict(model_parameter_dict)
        optimizer = torch.optim.SGD(local_model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
        return local_model, optimizer

    def _build_pilot_model(self, gpu_cpu_device, pilot_parameter_dict, dataset_name, lr, momentum):
        """Create the pilot model (``CNN_SVHN`` for SVHN, else ``net``), load the
        weights, put it in train mode and return ``(model, optimizer)``."""
        pilot_model = CNN_SVHN() if dataset_name == "SVHN" else net(dataset_name)
        pilot_model.to(gpu_cpu_device)
        pilot_model.load_state_dict(pilot_parameter_dict)
        optimizer = torch.optim.SGD(pilot_model.parameters(), lr=lr, momentum=momentum)
        pilot_model.train()
        return pilot_model, optimizer

    @staticmethod
    def _reference_parameters(model, parameter_dict):
        """Detached copies of the starting weights, keyed by parameter name and
        placed on the device of the matching model parameter."""
        return {
            name: parameter_dict[name].detach().clone().to(param.device)
            for name, param in model.named_parameters()
        }

    @staticmethod
    def _proximal_term(model, reference_parameters):
        """Differentiable regularizer ``sum_k ||0.5 * (w_k - ref_k) ** 2||_2``.

        The functional form is unchanged. It is evaluated on ``named_parameters()``,
        which carry autograd history. The previous version used
        ``state_dict()``, whose tensors are detached, so the regularizer never
        produced a gradient.
        NOTE(review): the textbook FedProx term is ``0.5 * ||w - w_ref||_2 ** 2``.
        """
        return sum(
            LA.norm(0.5 * (param - reference_parameters[name]) ** 2)
            for name, param in model.named_parameters()
        )

    def _run_local_sgd(self, local_model, optimizer, dataloader, max_steps, gpu_cpu_device, model_parameter_dict, regularized_or_not):
        """Train ``local_model`` on at most ``max_steps`` batches of ``dataloader``
        (printing each step's loss) and return its ``state_dict()``."""
        criterion = torch.nn.CrossEntropyLoss()
        reference = self._reference_parameters(local_model, model_parameter_dict) if regularized_or_not == 1 else None
        local_model.train()

        for sgd_idx, (images, labels) in enumerate(dataloader):
            if sgd_idx >= max_steps:
                break
            images, labels = images.to(gpu_cpu_device), labels.to(gpu_cpu_device)

            optimizer.zero_grad()
            logits = local_model(images)  # CrossEntropyLoss expects raw logits
            loss = criterion(logits, labels)
            if reference is not None:
                loss = loss + self._proximal_term(local_model, reference)
            loss.backward()
            optimizer.step()
            print(f' Local SGD index : {sgd_idx} ,\tLoss: {loss.item():.6f}')

        return local_model.state_dict()

    def _sample_batch(self, batch_size, gpu_cpu_device):
        """Draw ``batch_size`` distinct samples uniformly from ``self.dataset`` and
        return ``(images, labels)`` on ``gpu_cpu_device``."""
        # Each sample is fetched once (it was fetched twice, for image and label).
        samples = [self.dataset[i] for i in random.sample(range(len(self.dataset)), batch_size)]
        images = torch.stack([sample[0] for sample in samples], dim=0).to(gpu_cpu_device)
        labels = torch.tensor([sample[1] for sample in samples]).to(gpu_cpu_device)
        return images, labels

    @staticmethod
    def _accumulate_weighted_grads(model, name_to_grad, weight):
        """Add ``weight * param.grad`` into ``name_to_grad[name]`` (in place) for every trainable parameter."""
        for name, param in model.named_parameters():
            if param.requires_grad:
                name_to_grad[name] += param.grad * weight

    # ----------------------------------------------------------- public methods

    def local_training_sync(self, gpu_cpu_device, model_parameter_dict, dataset_name, lr, momentum, regularized_or_not):
        """Run up to ``num_SGD_training`` SGD steps on ``training_dataloader``.

        Args:
            gpu_cpu_device: Device the model and batches are moved to.
            model_parameter_dict: Starting ``state_dict`` (e.g. the global model).
            dataset_name: "SVHN", "cifar10", "cifar100" or a name accepted by ``net``.
            lr, momentum: SGD hyper-parameters.
            regularized_or_not: When 1, add the proximal term towards
                ``model_parameter_dict`` to the loss.

        Returns:
            The trained local model's ``state_dict()``.

        Raises:
            ValueError: If ``self.cifar_model`` is unsupported for a CIFAR dataset.
        """
        local_model, optimizer = self._build_local_model(gpu_cpu_device, model_parameter_dict, dataset_name, lr, momentum)
        return self._run_local_sgd(local_model, optimizer, self.training_dataloader, self.num_SGD_training,
                                   gpu_cpu_device, model_parameter_dict, regularized_or_not)

    def local_training_sync_model_diff(self, gpu_cpu_device, model_parameter_dict, dataset_name, lr, momentum, regularized_or_not):
        """Like ``local_training_sync`` but trains on ``grad_cal_dataloader`` for up
        to ``num_SGD_grad_cal`` steps.

        The step limit now uses ``>=`` like ``local_training_sync``; the previous
        ``>`` check ran one extra step.

        Returns:
            The trained local model's ``state_dict()``.
        """
        local_model, optimizer = self._build_local_model(gpu_cpu_device, model_parameter_dict, dataset_name, lr, momentum)
        return self._run_local_sgd(local_model, optimizer, self.grad_cal_dataloader, self.num_SGD_grad_cal,
                                   gpu_cpu_device, model_parameter_dict, regularized_or_not)

    def local_training_sync_train_pilot_grad(self, local_training_idx, client_idx, gpu_cpu_device, pilot_parameter_dict, pilot_parameter_name_to_grad_dict, dataset_name, lr, momentum, regularized_or_not, local_weight_dict_type):
        """Train a pilot model with this round's batch-size schedule and accumulate
        the gradient of its *last* step, weighted by
        ``local_weight_dict_type[client_idx]``, into
        ``pilot_parameter_name_to_grad_dict``.

        The batch sizes come from ``self.local_SGD_batch_size_ll[local_training_idx]``.
        Each batch is sampled at random from ``self.dataset``.

        Returns:
            ``pilot_parameter_name_to_grad_dict`` (mutated in place).
        """
        criterion = torch.nn.CrossEntropyLoss()
        pilot_model, optimizer = self._build_pilot_model(gpu_cpu_device, pilot_parameter_dict, dataset_name, lr, momentum)
        reference = self._reference_parameters(pilot_model, pilot_parameter_dict) if regularized_or_not == 1 else None
        batch_sizes = self.local_SGD_batch_size_ll[local_training_idx]
        last_step = len(batch_sizes) - 1

        for sgd_idx, batch_size_each_SGD in enumerate(batch_sizes):
            images, labels = self._sample_batch(batch_size_each_SGD, gpu_cpu_device)
            optimizer.zero_grad()
            loss = criterion(pilot_model(images), labels)
            if reference is not None:
                loss = loss + self._proximal_term(pilot_model, reference)
            loss.backward()
            if sgd_idx == last_step:
                # Read the gradient before optimizer.step() changes the weights.
                self._accumulate_weighted_grads(pilot_model, pilot_parameter_name_to_grad_dict, local_weight_dict_type[client_idx])
            optimizer.step()

        return pilot_parameter_name_to_grad_dict

    def local_training_sync_fixed_pilot_grad(self, args, client_idx, gpu_cpu_device, pilot_parameter_dict, pilot_parameter_name_to_grad_dict, dataset_name, lr, momentum, regularized_or_not, local_weight_dict_type):
        """Compute one mini-batch gradient (size ``args.batch_size``) of the pilot
        model at ``pilot_parameter_dict`` without updating it, and accumulate it,
        weighted by ``local_weight_dict_type[client_idx]``, into
        ``pilot_parameter_name_to_grad_dict``.

        Returns:
            ``pilot_parameter_name_to_grad_dict`` (mutated in place).
        """
        criterion = torch.nn.CrossEntropyLoss()
        pilot_model, optimizer = self._build_pilot_model(gpu_cpu_device, pilot_parameter_dict, dataset_name, lr, momentum)
        images, labels = self._sample_batch(args.batch_size, gpu_cpu_device)

        optimizer.zero_grad()
        loss = criterion(pilot_model(images), labels)
        if regularized_or_not == 1:
            # Weights still equal pilot_parameter_dict here, so the term is zero;
            # kept for parity with the other training routines.
            loss = loss + self._proximal_term(pilot_model, self._reference_parameters(pilot_model, pilot_parameter_dict))
        loss.backward()
        self._accumulate_weighted_grads(pilot_model, pilot_parameter_name_to_grad_dict, local_weight_dict_type[client_idx])

        return pilot_parameter_name_to_grad_dict

def moving_average(original_list, odd_window):
    """Return a copy of ``original_list`` smoothed by a centred moving average.

    Each interior element is replaced by the mean of the ``odd_window`` elements
    centred on it. The ``(odd_window - 1) // 2`` elements at each end have no
    full window and are copied unchanged.

    Args:
        original_list: Sequence of numbers supporting slice assignment (e.g. a list).
        odd_window: Window length; must be a positive odd integer.

    Returns:
        A deep copy of ``original_list`` whose interior has been smoothed.

    Raises:
        ValueError: If ``odd_window`` is not a positive odd integer, or if the
            list is shorter than one full window.
    """
    if odd_window < 1 or odd_window % 2 == 0:
        raise ValueError(f"odd_window must be a positive odd integer, got {odd_window}")
    half = (odd_window - 1) // 2
    if 2 * half >= len(original_list):
        raise ValueError(
            f"odd_window={odd_window} needs at least {odd_window} elements, got {len(original_list)}"
        )

    # Window starting at `start` is centred on element `start + half`.
    smoothed = [
        sum(original_list[start:start + odd_window]) / odd_window
        for start in range(len(original_list) - 2 * half)
    ]
    result = copy.deepcopy(original_list)
    result[half:half + len(smoothed)] = smoothed
    return result

def moving_average_reduce(original_list, window):
    """Return the moving averages of every full window of ``original_list``.

    Element ``i`` of the result is the mean of ``original_list[i:i + window]``,
    so the result has ``len(original_list) - window + 1`` elements.

    Raises:
        ValueError: If ``window`` is not positive or exceeds the list length.
    """
    if window < 1:
        raise ValueError(f"window must be a positive integer, got {window}")
    if window > len(original_list):
        raise ValueError("The window is larger than the size of the original list")
    return [
        sum(original_list[start:start + window]) / window
        for start in range(len(original_list) - window + 1)
    ]


def test_inference(args, device, model, testloader):
    """Evaluate ``model`` on ``testloader`` and return ``(accuracy, mean_loss)``.

    Args:
        args: Unused; kept for interface compatibility.
        device: Device the batches are moved to.
        model: Classifier returning logits; it is left in ``eval()`` mode.
        testloader: Iterable of ``(images, labels)`` batches.

    Returns:
        ``accuracy`` (fraction of correct top-1 predictions) and the mean
        cross-entropy per evaluated sample. Previously the per-batch *mean*
        losses were summed and divided by the sample count, which shrank the
        reported loss by roughly the batch size.
    """
    model.eval()
    criterion = torch.nn.CrossEntropyLoss(reduction="sum")
    loss_sum, correct, total = 0.0, 0.0, 0.0

    # No autograd graph is needed for evaluation.
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device).float(), labels.to(device)

            outputs = model(images)
            loss_sum += criterion(outputs, labels).item()

            _, pred_labels = torch.max(outputs, 1)
            correct += torch.sum(torch.eq(pred_labels.view(-1), labels)).item()
            total += len(labels)

    return correct / total, loss_sum / total


# The name starts with "test_": stop pytest from collecting it as a test when imported.
test_inference.__test__ = False

def Fedavg_local_weight(client_list, current_device_idx_list):
    """Return the FedAvg aggregation weight of every participating client.

    A client's weight is the size of its training dataset divided by the
    total training-set size of all clients listed in
    ``current_device_idx_list`` (the weights sum to 1 when indices are distinct).

    Args:
        client_list: Indexable collection of ``client`` objects.
        current_device_idx_list: Indices into ``client_list`` of this round's clients.

    Returns:
        Dict mapping each listed client index to its weight.
    """
    def dataset_size(client_idx):
        return len(client_list[client_idx].training_dataloader.dataset)

    total_num_data = sum(dataset_size(client_idx) for client_idx in current_device_idx_list)
    return {client_idx: dataset_size(client_idx) / total_num_data for client_idx in current_device_idx_list}

def create_distinct_half_SVHN(dataset, num_users = 20):
    """Split an SVHN-style dataset between two client groups with disjoint labels.

    Clients ``0 .. num_users//2 - 1`` receive samples of labels 0-4 and clients
    ``num_users//2 .. num_users - 1`` receive labels 5-9. Within each group,
    every label's indices are shuffled (after ``random.seed(0)``), truncated to
    the size of the group's rarest label and cut into equal consecutive
    slices, one per client. Every client thus holds the same number of
    samples of each of its five labels.

    Note: resets Python's global ``random`` state via ``random.seed(0)``.

    Args:
        dataset: Indexable dataset exposing a ``labels`` array.
        num_users: Total number of clients; must be a positive even number.

    Returns:
        Dict mapping client id to a list of ``dataset[i]`` samples, ordered by
        ascending label.

    Raises:
        ValueError: If ``num_users`` is odd or not positive, or a group has
            fewer samples of some label than it has clients.
    """
    if num_users % 2 != 0:
        raise ValueError("num_users should be an even number.")
    if num_users <= 0:
        raise ValueError("num_users should be a positive even number.")
    clients_per_half = num_users // 2
    labels = np.asarray(dataset.labels)

    random.seed(0)
    label_to_index = {}
    for label in range(10):
        label_to_index[label] = np.where(labels == label)[0].tolist()
        random.shuffle(label_to_index[label])

    # Per-client slice length of each half, limited by that half's rarest label.
    per_client = {
        half: min(len(label_to_index[label]) for label in range(5 * half, 5 * half + 5)) // clients_per_half
        for half in (0, 1)
    }
    if 0 in per_client.values():
        raise ValueError(
            f"not enough samples to give each of the {clients_per_half} clients per half one sample of every label"
        )

    dict_user_data = {}
    for device in range(num_users):
        half, slot = divmod(device, clients_per_half)
        size = per_client[half]
        data_indices = []
        for label in range(5 * half, 5 * half + 5):
            data_indices.extend(label_to_index[label][slot * size:(slot + 1) * size])
        dict_user_data[device] = [dataset[i] for i in data_indices]

    return dict_user_data

def create_distinct_labels_for_10_clients_SVHN(dataset, num_users = 10):
    """Give each of 10 clients the samples of exactly one SVHN label.

    Client ``k`` receives only label ``k``. After ``random.seed(0)`` every
    label's indices are shuffled and truncated to the size of the rarest
    label, so all clients receive the same number of samples.

    Args:
        dataset: Indexable dataset exposing a ``labels`` array.
        num_users: Must be 10.

    Returns:
        Dict mapping client id (0-9) to a list of ``dataset[i]`` samples.

    Raises:
        ValueError: If ``num_users`` is not 10.
    """
    if num_users != 10:
        raise ValueError("create_distinct_labels_for_10_clients_SVHN method is only called for 10 clients scenario.")

    all_labels = dataset.labels
    random.seed(0)

    shuffled_by_label = []
    for label in range(10):
        positions = np.where(all_labels == label)[0].tolist()
        random.shuffle(positions)
        shuffled_by_label.append(positions)

    # Truncate every label to the rarest one so the clients stay balanced.
    keep = min(len(positions) for positions in shuffled_by_label)
    return {
        client_id: [dataset[position] for position in shuffled_by_label[client_id][:keep]]
        for client_id in range(num_users)
    }

def create_distinct_labels_for_10_clients_mnist_fmnist(dataset, num_users = 10):
    """Give each of 10 clients the samples of exactly one MNIST/Fashion-MNIST label.

    Client ``k`` receives only label ``k``. After ``random.seed(0)`` every
    label's indices are shuffled and truncated to the size of the rarest
    label, so all clients receive the same number of samples.

    Args:
        dataset: Indexable dataset whose ``targets`` is a tensor (``.numpy()`` is called).
        num_users: Must be 10.

    Returns:
        Dict mapping client id (0-9) to a list of ``dataset[i]`` samples.

    Raises:
        ValueError: If ``num_users`` is not 10.
    """
    if num_users != 10:
        raise ValueError("create_distinct_labels_for_10_clients_mnist_fmnist method is only called for 10 clients scenario.")

    all_labels = dataset.targets.numpy()
    random.seed(0)

    shuffled_by_label = []
    for label in range(10):
        positions = np.where(all_labels == label)[0].tolist()
        random.shuffle(positions)
        shuffled_by_label.append(positions)

    # Truncate every label to the rarest one so the clients stay balanced.
    keep = min(len(positions) for positions in shuffled_by_label)
    return {
        client_id: [dataset[position] for position in shuffled_by_label[client_id][:keep]]
        for client_id in range(num_users)
    }

def distinct_half(dataset, num_clients=10):
    """Partition ``dataset`` so the two halves of the clients see disjoint classes.

    Clients ``0 .. num_clients//2 - 1`` share classes ``0 .. num_classes//2 - 1``.
    The remaining clients share the other classes. Within a group, every
    class is divided into equal, non-overlapping, consecutive slices, one per
    client in that group. A remainder of fewer than group-size samples per
    class is left unassigned.

    Previously each client's share was ``len(remaining) // half_clients``,
    computed after earlier clients' shares had been removed. Shares therefore
    decayed geometrically and much of every class was never assigned.

    Args:
        dataset: torchvision MNIST, FashionMNIST, SVHN, CIFAR10 or CIFAR100.
        num_clients: Total number of clients (default 10, the former hard-coded value).

    Returns:
        ``(clients_data_ids, client_classes)``: client id -> list of dataset
        indices, and client id -> list of class ids assigned to that client.

    Raises:
        ValueError: If the dataset type is unsupported.
    """
    if isinstance(dataset, torchvision.datasets.SVHN):
        raw_labels = dataset.labels
    elif isinstance(dataset, (torchvision.datasets.MNIST, torchvision.datasets.FashionMNIST,
                              torchvision.datasets.CIFAR10, torchvision.datasets.CIFAR100)):
        # Reading `targets` avoids decoding/transforming every image just to get
        # its label. With a target_transform, `dataset[i][1]` is the transformed
        # label, so fall back to per-sample access in that case.
        raw_labels = dataset.targets if dataset.target_transform is None else None
    else:
        raise ValueError("Unsupported dataset type")

    if raw_labels is None:
        labels = [dataset[idx][1] for idx in range(len(dataset))]
    else:
        labels = raw_labels.tolist() if hasattr(raw_labels, "tolist") else list(raw_labels)

    class_indices = defaultdict(list)
    for idx, label in enumerate(labels):
        class_indices[label].append(idx)

    num_classes = len(class_indices)
    half_clients = num_clients // 2
    half_classes = num_classes // 2
    groups = (
        (range(0, half_clients), range(0, half_classes)),
        (range(half_clients, num_clients), range(half_classes, num_classes)),
    )

    clients_data_ids = {}
    client_classes = {}
    for client_ids, class_ids in groups:
        group_size = len(client_ids)
        for position, client_id in enumerate(client_ids):
            client_indices = []
            for class_id in class_ids:
                members = class_indices[class_id]
                share = len(members) // group_size  # fixed per class, not per remaining pool
                client_indices.extend(members[position * share:(position + 1) * share])
            clients_data_ids[client_id] = client_indices
            client_classes[client_id] = list(class_ids)

    return clients_data_ids, client_classes


def distinct_class_each_device(dataset):
    """Give each of 10 clients every sample of a disjoint, contiguous block of classes.

    Client ``k`` receives all samples of classes ``k*c .. (k+1)*c - 1``, where
    ``c`` is 10 for CIFAR-100 and 1 for MNIST, Fashion-MNIST, SVHN and CIFAR-10.

    Args:
        dataset: torchvision MNIST, FashionMNIST, SVHN, CIFAR10 or CIFAR100.

    Returns:
        ``(clients_data_ids, client_classes)``: client id -> list of dataset
        indices, and client id -> ``range`` of the client's class ids.

    Raises:
        ValueError: If the dataset type is unsupported (now raised up front,
            so an empty unsupported dataset no longer hits UnboundLocalError).
    """
    # torchvision's CIFAR100 subclasses CIFAR10, so it must be tested first; the
    # previous order gave every CIFAR-100 client a single class.
    if isinstance(dataset, torchvision.datasets.CIFAR100):
        num_classes_per_client = 10
    elif isinstance(dataset, (torchvision.datasets.CIFAR10, torchvision.datasets.MNIST,
                              torchvision.datasets.FashionMNIST, torchvision.datasets.SVHN)):
        num_classes_per_client = 1
    else:
        raise ValueError("Unsupported dataset type")
    num_clients = 10

    # Read labels without decoding every image. `dataset[i][1]` applies
    # target_transform (except for SVHN, which always used `labels`), so fall
    # back to per-sample access when one is set.
    if isinstance(dataset, torchvision.datasets.SVHN):
        raw_labels = dataset.labels
    elif dataset.target_transform is None:
        raw_labels = dataset.targets
    else:
        raw_labels = None

    if raw_labels is None:
        labels = [dataset[idx][1] for idx in range(len(dataset))]
    else:
        labels = raw_labels.tolist() if hasattr(raw_labels, "tolist") else list(raw_labels)

    class_indices = defaultdict(list)
    for idx, label in enumerate(labels):
        class_indices[label].append(idx)

    clients_data_ids = {}
    client_classes = {}
    for client_id in range(num_clients):
        client_classes_ids = range(client_id * num_classes_per_client, (client_id + 1) * num_classes_per_client)
        clients_data_ids[client_id] = [idx for class_id in client_classes_ids for idx in class_indices[class_id]]
        client_classes[client_id] = client_classes_ids

    return clients_data_ids, client_classes

def Dirichlet_disbuted_classes(dataset, num_clients, alpha):
    """Split every class among ``num_clients`` clients with Dirichlet(``alpha``) proportions.

    One proportion vector is drawn per class. That class's indices (shuffled
    with ``np.random``) are cut at the rounded cumulative proportions.

    Args:
        dataset: torchvision MNIST, FashionMNIST, SVHN, CIFAR10 or CIFAR100.
        num_clients: Number of clients.
        alpha: Dirichlet concentration; smaller values give more skewed splits.

    Returns:
        ``(clients_data_ids, client_classes)``: client id -> list of dataset
        indices, and client id -> set of classes the client actually received
        at least one sample of. Clients were previously credited with classes
        they got zero samples of.

    Raises:
        ValueError: If the dataset type is unsupported.
    """
    if isinstance(dataset, (torchvision.datasets.MNIST, torchvision.datasets.FashionMNIST)):
        targets = np.array(dataset.targets)
    elif isinstance(dataset, torchvision.datasets.SVHN):
        targets = np.array(dataset.labels)
    elif isinstance(dataset, (torchvision.datasets.CIFAR10, torchvision.datasets.CIFAR100)):
        targets = np.array(dataset.targets)
    else:
        raise ValueError("Unsupported dataset type")

    # Use the labels that actually occur instead of assuming 0..num_classes-1.
    class_labels = np.unique(targets)
    num_classes = len(class_labels)
    dirichlet_dist = np.array([dirichlet([alpha] * num_clients) for _ in range(num_classes)])

    class_indices = defaultdict(list)
    for idx, label in enumerate(targets):
        class_indices[label].append(idx)

    clients_data_ids = {i: [] for i in range(num_clients)}
    client_classes = {i: set() for i in range(num_clients)}

    for class_label, dist in zip(class_labels, dirichlet_dist):
        class_idxs = class_indices[class_label]
        np.random.shuffle(class_idxs)

        split_points = np.round(np.cumsum(dist) * len(class_idxs)).astype(int)
        for client_id, idxs in enumerate(np.split(class_idxs, split_points[:-1])):
            clients_data_ids[client_id].extend(idxs)
            if len(idxs) > 0:
                client_classes[client_id].add(int(class_label))

    return clients_data_ids, client_classes

def create_distinct_half_mnist_fmnist(dataset, num_users = 20):
    """Split an MNIST/Fashion-MNIST-style dataset between two client groups with disjoint labels.

    Clients ``0 .. num_users//2 - 1`` receive samples of labels 0-4 and clients
    ``num_users//2 .. num_users - 1`` receive labels 5-9. Within each group,
    every label's indices are shuffled (after ``random.seed(0)``), truncated to
    the size of the group's rarest label and cut into equal consecutive
    slices, one per client.

    Note: resets Python's global ``random`` state via ``random.seed(0)``.

    Args:
        dataset: Indexable dataset exposing ``targets`` (tensor, array or list).
        num_users: Total number of clients; must be a positive even number.

    Returns:
        Dict mapping client id to a list of ``dataset[i]`` samples, ordered by
        ascending label.

    Raises:
        ValueError: If ``num_users`` is odd or not positive, or a group has
            fewer samples of some label than it has clients.
    """
    if num_users % 2 != 0:
        raise ValueError("num_users should be an even number.")
    if num_users <= 0:
        raise ValueError("num_users should be a positive even number.")
    clients_per_half = num_users // 2
    targets = dataset.targets
    labels = targets.numpy() if isinstance(targets, torch.Tensor) else np.asarray(targets)

    random.seed(0)
    label_to_index = {}
    for label in range(10):
        label_to_index[label] = np.where(labels == label)[0].tolist()
        random.shuffle(label_to_index[label])

    # Per-client slice length of each half, limited by that half's rarest label.
    per_client = {
        half: min(len(label_to_index[label]) for label in range(5 * half, 5 * half + 5)) // clients_per_half
        for half in (0, 1)
    }
    if 0 in per_client.values():
        raise ValueError(
            f"not enough samples to give each of the {clients_per_half} clients per half one sample of every label"
        )

    dict_user_data = {}
    for device in range(num_users):
        half, slot = divmod(device, clients_per_half)
        size = per_client[half]
        data_indices = []
        for label in range(5 * half, 5 * half + 5):
            data_indices.extend(label_to_index[label][slot * size:(slot + 1) * size])
        dict_user_data[device] = [dataset[i] for i in data_indices]

    return dict_user_data

def create_labels_dirichlet(dataset, dataset_name, n_parties, beta, seed):
    """Partition ``dataset`` across ``n_parties`` clients with a label-wise Dirichlet(``beta``) prior.

    For every class, the shuffled indices are split among parties in
    proportions drawn from Dirichlet(beta). Parties already holding at least
    ``len(dataset) / n_parties`` samples are excluded from further draws. The
    whole draw is repeated until every party holds at least 10 samples.
    Seeds both ``numpy.random`` and ``random`` with ``seed``.

    Fixes: the class count is now chosen from ``dataset_name``. The dataset
    *object* used to be compared with the names, so it was always 10. List
    targets (e.g. CIFAR) are also accepted.

    Args:
        dataset: Indexable dataset with ``labels`` (SVHN) or ``targets``.
        dataset_name: e.g. "SVHN", "cifar10", "cifar100".
        n_parties: Number of clients.
        beta: Dirichlet concentration.
        seed: Seed for ``numpy.random`` and ``random``.

    Returns:
        ``dict_user_data``: party -> list of ``dataset[i]`` samples.
        ``device_label_idx_dict``: party -> sorted list of the labels it holds.
        ``label_remaining_idx_dict``: class -> set of held-out indices. It is
        always empty, since zero indices are held out.
        ``device_data_idx``: party -> list of dataset indices, aligned with
        ``dict_user_data``.
    """
    np.random.seed(seed)
    random.seed(seed)

    n_train = len(dataset)
    raw_targets = dataset.labels if dataset_name == "SVHN" else dataset.targets
    y_train = raw_targets.numpy() if isinstance(raw_targets, torch.Tensor) else np.asarray(raw_targets)

    min_size = 0
    min_require_size = 10
    n_classes = 10
    if dataset_name in ('celeba', 'covtype', 'a9a', 'rcv1', 'SUSY'):
        n_classes = 2
    if dataset_name == 'cifar100':
        n_classes = 100
    elif dataset_name == 'tinyimagenet':
        n_classes = 200

    device_data_idx = {}
    dict_user_data = {i: [] for i in range(n_parties)}
    device_label_idx_dict = {i: [] for i in range(n_parties)}
    # NOTE(review): loops forever if min_require_size can never be reached
    # (e.g. n_parties * 10 > n_train).
    while min_size < min_require_size:
        idx_batch = [[] for _ in range(n_parties)]
        label_remaining_idx_dict = {i: set() for i in range(n_classes)}
        for k in range(n_classes):
            idx_k = np.where(y_train == k)[0]
            # Draws nothing, but is kept unchanged so seeded partitions stay identical.
            deducted_idx = np.random.choice(idx_k, 0, replace = False)
            label_remaining_idx_dict[k] = label_remaining_idx_dict[k].union(set(deducted_idx))
            idx_k = np.setdiff1d(idx_k, deducted_idx)
            np.random.shuffle(idx_k)
            proportions = np.random.dirichlet(np.repeat(beta, n_parties))
            # Stop giving data to parties that already hold their fair share.
            proportions = np.array([p * (len(idx_j) < n_train / n_parties) for p, idx_j in zip(proportions, idx_batch)])
            proportions = proportions / proportions.sum()
            proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
            idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions))]
            min_size = min([len(idx_j) for idx_j in idx_batch])

    for j in range(n_parties):
        np.random.shuffle(idx_batch[j])
        device_data_idx[j] = idx_batch[j]
        dict_user_data[j] = [dataset[m] for m in idx_batch[j]]
        # One pass per party (was a set rebuild per sample).
        device_label_idx_dict[j] = sorted(set(y_train[idx_batch[j]].tolist()))

    return dict_user_data, device_label_idx_dict, label_remaining_idx_dict, device_data_idx

def create_labels_dirichelt_another(dataset, dataset_name, n_client, unbalanced_sgm, beta, num_data_to_be_added=600):
    """Assign samples to clients one at a time using per-client Dirichlet class priors.

    Client quotas are equal (``len // n_client``) when ``unbalanced_sgm == 0``.
    Otherwise they are drawn from a log-normal distribution and rescaled to
    sum to the dataset size. Each slot is filled by sampling a random client,
    then a class from that client's Dirichlet(``beta``) prior, redrawing when the
    class is exhausted. Finally, ``num_data_to_be_added`` random samples per
    client (Python ``random``, unseeded) are moved to a separate pool.
    ``np.random`` is seeded with 0.

    Fixes: per-client buffers are sized by each client's own quota. They used
    to be sized ``len // n_client``, which raised IndexError or left ``0``
    placeholders whenever ``unbalanced_sgm != 0``. One or zero remaining
    samples are no longer mangled by ``itemgetter``.

    Args:
        dataset: Indexable dataset with ``labels`` (SVHN) or ``targets``.
        dataset_name: "SVHN" selects ``labels``; anything else uses ``targets``.
        n_client: Number of clients.
        unbalanced_sgm: Log-normal sigma for quota sizes (0 = balanced).
        beta: Dirichlet concentration of each client's class prior.
        num_data_to_be_added: Samples per client moved to the held-out pool
            (default 600, the former hard-coded value).

    Returns:
        ``(dict_user_data_dict, device_data_idx_dict, device_data_to_be_added_dict)``.

    Raises:
        ValueError: If a client has fewer than ``num_data_to_be_added`` samples
            (raised by ``random.sample``).
    """
    np.random.seed(0)
    raw_targets = dataset.labels if dataset_name == "SVHN" else dataset.targets
    y_train = raw_targets.numpy() if isinstance(raw_targets, torch.Tensor) else np.asarray(raw_targets)

    n_class = len(set(y_train))
    n_data_per_clnt = int((len(y_train)) / n_client)
    if unbalanced_sgm != 0:
        # Draw from lognormal distribution
        clnt_data_list = (np.random.lognormal(mean=np.log(n_data_per_clnt), sigma=unbalanced_sgm, size=n_client))
        clnt_data_list = (clnt_data_list/np.sum(clnt_data_list)*len(y_train)).astype(int)
        diff = np.sum(clnt_data_list) - len(y_train)

        # Add/Subtract the excess number starting from first client
        if diff != 0:
            for clnt_i in range(n_client):
                if clnt_data_list[clnt_i] > diff:
                    clnt_data_list[clnt_i] -= diff
                    break
    else:
        clnt_data_list = (np.ones(n_client) * n_data_per_clnt).astype(int)

    # Size each client's buffers by its own quota (fixes the unbalanced case).
    dict_user_data_dict = {i: [0] * int(clnt_data_list[i]) for i in range(n_client)}
    device_data_idx_dict = {i: [0] * int(clnt_data_list[i]) for i in range(n_client)}
    device_data_to_be_added_dict = {}

    cls_priors = np.random.dirichlet(alpha=[beta]*n_class, size=n_client)
    prior_cumsum = np.cumsum(cls_priors, axis=1)

    class_to_idx_list = [np.where(y_train == i)[0] for i in range(n_class)]
    cls_amount = [len(class_to_idx_list[i]) for i in range(n_class)]

    while np.sum(clnt_data_list) != 0:
        curr_clnt = np.random.randint(n_client)
        # If current node is full resample a client
        if clnt_data_list[curr_clnt] <= 0:
            continue
        clnt_data_list[curr_clnt] -= 1
        curr_prior = prior_cumsum[curr_clnt]
        while True:
            cls_label = np.argmax(np.random.uniform() <= curr_prior)
            # Redraw class label if that class is exhausted
            if cls_amount[cls_label] <= 0:
                continue
            cls_amount[cls_label] -= 1
            sample_idx = class_to_idx_list[cls_label][cls_amount[cls_label]]
            dict_user_data_dict[curr_clnt][clnt_data_list[curr_clnt]] = dataset[sample_idx]
            device_data_idx_dict[curr_clnt][clnt_data_list[curr_clnt]] = sample_idx
            break

    for i in range(n_client):
        local_data = dict_user_data_dict[i]
        selected_data_idx = random.sample(range(len(local_data)), num_data_to_be_added)
        remaining_data_idx = set(range(len(local_data))) - set(selected_data_idx)
        device_data_to_be_added_dict[i] = [local_data[k] for k in selected_data_idx]
        dict_user_data_dict[i] = [local_data[k] for k in remaining_data_idx]
        # NOTE(review): these are positions within the client's local list, not
        # dataset indices (the dataset indices are lost here) -- confirm intent.
        device_data_idx_dict[i] = remaining_data_idx

    return dict_user_data_dict, device_data_idx_dict, device_data_to_be_added_dict


def filter_dataset_by_classes(dataset, class_set, batch_size):
    """Return a shuffled DataLoader over the samples of `dataset` whose label is in `class_set`.

    Labels are read from `dataset.targets` (MNIST, Fashion-MNIST, CIFAR-10/100 style)
    or, failing that, from `dataset.labels` (SVHN style).

    Args:
    - dataset: indexable dataset exposing `targets` or `labels`.
    - class_set: container of labels to keep (tested with `in`).
    - batch_size: batch size of the returned DataLoader.

    Returns:
    - DataLoader(Subset(dataset, kept_indices), batch_size=batch_size, shuffle=True),
      with the kept indices in ascending order.

    Raises:
    - ValueError: if the dataset has neither a `targets` nor a `labels` attribute.
    - TypeError: if the label container is not a list, torch.Tensor or numpy.ndarray.
    """
    for attr_name in ('targets', 'labels'):
        if hasattr(dataset, attr_name):
            raw_targets = getattr(dataset, attr_name)
            break
    else:
        raise ValueError("Dataset does not have 'targets' or 'labels' attribute")

    # Normalise tensors / arrays to a plain Python list so `in class_set` compares scalars.
    if isinstance(raw_targets, (torch.Tensor, np.ndarray)):
        raw_targets = raw_targets.tolist()
    elif not isinstance(raw_targets, list):
        raise TypeError("The variable is neither a list, tensor, nor a NumPy array.")

    kept_indices = [position for position, lbl in enumerate(raw_targets) if lbl in class_set]
    return DataLoader(Subset(dataset, kept_indices), batch_size=batch_size, shuffle=True)

def generate_list(length):
    """
    Build a list of `length` weights that sum to 1, alternating non-zero values with zeros.

    - length <= 0: returns [] (no weights).
    - length == 1: returns [1.0].
    - even length: 1 / (length / 2) at even indices, 0.0 at odd indices.
    - odd length > 1: 0.0 at even indices, 1 / ((length - 1) / 2) at odd indices.
    """
    # Guard before the even branch: 0 % 2 == 0 would otherwise divide by length / 2 == 0.
    if length <= 0:
        return []
    if length == 1:
        return [1.0]
    if length % 2 == 0:
        # length / 2 even indices share the unit mass.
        value = 1.0 / (length / 2)
        return [value if i % 2 == 0 else 0.0 for i in range(length)]
    # (length - 1) / 2 odd indices share the unit mass.
    value = 1.0 / ((length - 1) / 2)
    return [0.0 if i % 2 == 0 else value for i in range(length)]

def _can_finish_batch_assignment(client_batch_counts, batches_left, max_batches_per_client):
    """Return True if `batches_left` batches can still each go to two distinct clients.

    Within one batch a client can take at most one share. Over all batches it can take
    at most `max_batches_per_client`. So client i can absorb at most
    min(max_batches_per_client - count_i, batches_left) further shares, and
    2 * batches_left shares are required in total.
    """
    usable_shares = sum(
        min(max_batches_per_client - count, batches_left)
        for count in client_batch_counts.values()
    )
    return usable_shares >= 2 * batches_left


def distribute_labels_in_batches(dataset, num_clients):
    """
    Distribute labels in batches for datasets like MNIST, Fashion-MNIST, SVHN, CIFAR-10, and CIFAR-100.
    For 10-label datasets (MNIST, Fashion-MNIST, SVHN, CIFAR-10), each batch is one label.
    For 100-label datasets (CIFAR-100), each batch is 10 labels.

    Every batch's samples are split in half (per label). One half goes to each of two
    distinct, randomly chosen clients. No client receives more than 2 batches. So with
    10 clients, each client gets 2 labels (10-label datasets) or 20 labels (CIFAR-100).

    Clients are drawn with np.random.choice. A drawn pair is redrawn only if accepting
    it would make it impossible to place the remaining batches on two distinct clients
    each.

    Args:
    - dataset: The dataset to be partitioned; must expose `targets` or `labels` and `len()`.
    - num_clients: The number of clients (e.g., 10 clients).

    Returns:
    - clients_data_ids: A dictionary where keys are device IDs, and values are lists of data indices.
    - client_classes: A dictionary where keys are device IDs, and values are sets of assigned labels.

    Raises:
    - ValueError: if the dataset has no `targets`/`labels`, if the number of distinct labels
      is neither 10 nor 100, if `num_clients` is too small to give every batch to two
      clients, or if some label has fewer than 2 samples (a half would be empty).
    """
    max_batches_per_client = 2

    # Step 1: Extract labels based on dataset type
    if hasattr(dataset, 'targets'):
        # For MNIST, Fashion-MNIST, CIFAR-10, CIFAR-100, we convert to numpy for indexing
        targets = np.array(dataset.targets)
    elif hasattr(dataset, 'labels'):
        # For SVHN
        targets = np.array(dataset.labels)
    else:
        raise ValueError("Dataset does not have 'targets' or 'labels' attribute")

    num_classes = len(np.unique(targets))

    # Step 2: Group sample indices by label, then shuffle the label order
    class_indices = defaultdict(list)
    for idx in range(len(dataset)):
        class_indices[targets[idx]].append(idx)

    available_labels = list(class_indices.keys())
    np.random.shuffle(available_labels)

    # Step 3: Group the shuffled labels into non-overlapping batches
    if num_classes == 10:
        label_batches = [[label] for label in available_labels]
    elif num_classes == 100:
        num_labels_per_batch = 10
        num_batches = 10
        label_batches = [
            available_labels[i * num_labels_per_batch:(i + 1) * num_labels_per_batch]
            for i in range(num_batches)
        ]
    else:
        raise ValueError(f"Unsupported dataset with {num_classes} labels.")

    clients_data_ids = {i: [] for i in range(num_clients)}
    client_classes = {i: set() for i in range(num_clients)}
    client_batch_counts = {i: 0 for i in range(num_clients)}  # batches assigned so far per client

    if not _can_finish_batch_assignment(client_batch_counts, len(label_batches), max_batches_per_client):
        raise ValueError(
            f"Cannot give each of the {len(label_batches)} label batches to two distinct clients "
            f"with {num_clients} clients (at most {max_batches_per_client} batches per client)."
        )

    # Step 4: Split each batch into halves and give them to two distinct clients
    for batch_no, batch in enumerate(label_batches):
        half_1 = []
        half_2 = []
        for label in batch:
            indices = class_indices[label]
            mid_point = len(indices) // 2
            half_1.extend(indices[:mid_point])
            half_2.extend(indices[mid_point:])

        # Ensure clients get non-empty batches
        if not half_1 or not half_2:
            raise ValueError("One of the clients received no data!")

        available_clients = [
            c for c, count in client_batch_counts.items() if count < max_batches_per_client
        ]
        batches_left = len(label_batches) - batch_no - 1
        while True:
            first, second = (int(c) for c in np.random.choice(available_clients, 2, replace=False))
            client_batch_counts[first] += 1
            client_batch_counts[second] += 1
            if _can_finish_batch_assignment(client_batch_counts, batches_left, max_batches_per_client):
                break
            # This pair would strand the remaining capacity on too few clients (the
            # original greedy draw crashed on the last batch in that case): undo and redraw.
            client_batch_counts[first] -= 1
            client_batch_counts[second] -= 1

        clients_data_ids[first].extend(half_1)
        clients_data_ids[second].extend(half_2)
        client_classes[first].update(batch)
        client_classes[second].update(batch)

    return clients_data_ids, client_classes

def calculate_dynamic_similarity_threshold(client_classes):
    """
    Automatically calculate a similarity threshold from how labels are spread over devices.

    expected_overlap = (average labels per device / number of distinct labels) / 2,
    where the factor 1/2 reflects label batches being shared by two clients. The
    threshold is 1.5 * expected_overlap, clamped to [0.5, 1].

    Args:
    - client_classes: A dictionary where each key is a device ID and the value is a set of labels for that device.

    Returns:
    - similarity_threshold: a value in [0.5, 1]. Degenerate input (no devices, or no
      labels on any device) returns the 0.5 floor instead of dividing by zero.
    """
    lower_bound = 0.5
    upper_bound = 1

    total_labels = set()
    labels_per_device = []

    # Collect all unique labels and the number of labels per device
    for labels in client_classes.values():
        total_labels.update(labels)
        labels_per_device.append(len(labels))

    num_total_labels = len(total_labels)
    if num_total_labels == 0:
        # Covers an empty dict too; nothing to compare, so use the floor.
        return lower_bound

    avg_labels_per_device = sum(labels_per_device) / len(labels_per_device)

    # Expected overlap: each client receives data from a batch of labels shared with one other client
    num_shared_batches = 2  # Each batch is shared between two clients
    expected_overlap = avg_labels_per_device / num_total_labels * (1 / num_shared_batches)

    # Allow some flexibility; the factor can be adjusted based on experimentation
    similarity_threshold = expected_overlap * 1.5

    # Clamp to [0.5, 1]
    return min(max(similarity_threshold, lower_bound), upper_bound)

def generate_device_lists(client_classes, num_groups, min_group_size=1, max_group_size=10, max_attempts=50, initial_similarity_threshold=0.5):
    """
    Randomly build up to `num_groups` device groups whose label coverage differs from
    the group immediately before them and resembles some older group.

    Similarity is the asymmetric ratio |A & B| / max(|A|, 1). A is the label union of
    the candidate group, and B is the label union of the group it is compared against.

    Args:
    - client_classes: dict mapping device ID -> set of labels held by that device.
    - num_groups: number of group-generation rounds. This is NOT a guaranteed output
      length (see Returns).
    - min_group_size / max_group_size: inclusive bounds passed to random.randint for each
      group's size. random.sample raises ValueError if the drawn size exceeds the
      number of devices.
    - max_attempts: redraws allowed per threshold level before the "differ from the
      previous group" threshold is relaxed by 0.05.
    - initial_similarity_threshold: starting maximum allowed similarity to the previous group.

    Returns:
    - groups: list of device-ID lists. Past-group check: from the third round on, a
      candidate must reach calculate_dynamic_similarity_threshold(client_classes)
      similarity with at least one group older than the previous one. A round whose
      candidate fails this check is dropped rather than retried, so len(groups) can be
      smaller than num_groups.
    """
    groups = []

    # Asymmetric similarity: fraction of set1's labels that also appear in set2
    # (max(..., 1) avoids dividing by zero when set1 is empty).
    def calculate_similarity(set1, set2):
        return len(set1.intersection(set2)) / max(len(set1), 1)

    # Shuffle the list of all devices to start with randomness
    all_devices = list(client_classes.keys())
    random.shuffle(all_devices)

    # Threshold for the "similar to some older group" check, derived from the label distribution
    dynamic_similarity_threshold = calculate_dynamic_similarity_threshold(client_classes)

    for _ in range(num_groups):
        # Draw a candidate group; its size stays fixed for every redraw in this round
        group_size = random.randint(min_group_size, max_group_size)
        group = random.sample(all_devices, group_size)

        # Calculate the union of labels for the current group
        current_union = set()
        for device in group:
            current_union.update(client_classes[device])

        # Check uniqueness against the previous group
        attempt = 0
        similarity_threshold = initial_similarity_threshold
        max_similarity_increase = 0.05  # Threshold relaxation step after each max_attempts redraws

        if groups:
            previous_union = set()
            for device in groups[-1]:
                previous_union.update(client_classes[device])

            # Redraw until the candidate is sufficiently different from the previous group.
            # Similarity never exceeds 1, so relaxing the threshold guarantees termination
            # at the latest once it reaches 1.
            while calculate_similarity(current_union, previous_union) > similarity_threshold:
                if attempt >= max_attempts:
                    # Relax the similarity threshold if max attempts are reached
                    similarity_threshold += max_similarity_increase
                    attempt = 0  # Reset attempts and try again with a relaxed threshold
                    if similarity_threshold >= 1.0:  # Cap the threshold to avoid excessive relaxation
                        break

                group = random.sample(all_devices, group_size)
                current_union = set()
                for device in group:
                    current_union.update(client_classes[device])
                attempt += 1

        # Check similarity with all past groups (excluding the previous one).
        # NOTE(review): `attempt` is reset to 0 on every relaxation, so
        # `attempt < max_attempts` is False only when the accepted candidate was exactly
        # the max_attempts-th redraw of a round; confirm this skip is intended.
        if len(groups) > 1 and attempt < max_attempts:
            found_similar_past_group = False
            for past_group in groups[:-1]:
                past_union = set()
                for device in past_group:
                    past_union.update(client_classes[device])

                # Use dynamic similarity threshold here for the check
                if calculate_similarity(current_union, past_union) >= dynamic_similarity_threshold:
                    found_similar_past_group = True
                    break  # Stop checking once we find a similar past group

            # If no older group is similar enough, this round's candidate is discarded and
            # the outer loop moves on; there is no retry, so fewer than num_groups groups
            # may be returned.
            # NOTE(review): confirm that dropping (rather than retrying) is intended.
            if not found_similar_past_group:
                continue  # Drop this round's candidate and move to the next round

        # Add the group to the list of groups
        groups.append(group)

    return groups

def calculate_dynamic_similarity_threshold_distinct(client_classes):
    """
    Automatically calculate a similarity threshold for the case where each device holds one label.

    The average number of labels per device is taken to be 1 (not measured). The threshold is
    1.5 / (number of distinct labels), clamped to [0.5, 1].

    Args:
    - client_classes: A dictionary where each key is a device ID and the value is a set of labels for that device.

    Returns:
    - similarity_threshold: a value in [0.5, 1]. If no device has any label (including an
      empty dict), returns the 0.5 floor instead of dividing by zero.
    """
    lower_bound = 0.5
    upper_bound = 1

    # Collect all unique labels
    total_labels = set()
    for labels in client_classes.values():
        total_labels.update(labels)

    num_total_labels = len(total_labels)
    if num_total_labels == 0:
        return lower_bound

    # Assumed, not measured: each device only has one label
    avg_labels_per_device = 1

    # Expected overlap: minimal since each device only has one unique label
    expected_overlap = avg_labels_per_device / num_total_labels

    # Allow some flexibility; adjust based on experimentation
    similarity_threshold = expected_overlap * 1.5

    # Clamp to [0.5, 1] since overlap is expected to be very low
    return min(max(similarity_threshold, lower_bound), upper_bound)

def calculate_jaccard_similarity(set1, set2):
    """
    Return the Jaccard index of two sets: size of their intersection divided by size of their union.

    Two empty sets have no union, so 0 is returned for them instead of dividing by zero.
    """
    union_size = len(set1.union(set2))
    if not union_size:
        return 0
    shared = set1.intersection(set2)
    return len(shared) / union_size

def generate_device_lists_distinct(client_classes, num_groups, min_group_size=1, max_group_size=10, low_similarity_threshold=0.3, high_similarity_threshold=0.7, max_attempts=100):
    """
    Generate device lists such that each group has low similarity with the previous group and high similarity with at least one past group.

    Similarity is the Jaccard index (calculate_jaccard_similarity) of the label unions of two groups.

    Args:
    - client_classes: A dictionary where each key is a device ID and the value is a set of labels for that device.
    - num_groups: The number of groups to generate.
    - min_group_size: Minimum size of each group (inclusive, passed to random.randint).
    - max_group_size: Maximum size of each group (inclusive, passed to random.randint).
    - low_similarity_threshold: Maximum allowed similarity with the previous group.
    - high_similarity_threshold: Minimum required similarity with a past group (not the previous one).
    - max_attempts: Maximum number of attempts to generate a valid group.

    Returns:
    - A list of `num_groups` groups, where each group is a list of device IDs.

    Raises:
    - RuntimeError: if no valid group is found within `max_attempts` draws.
    - ValueError: raised by random.sample if a drawn group size exceeds the number of devices.
    """
    def label_union(group):
        """Return the union of the label sets of all devices in `group`."""
        union = set()
        for device in group:
            union.update(client_classes[device])
        return union

    groups = []
    # group_unions[k] caches label_union(groups[k]); accepted groups never change, so their
    # unions are computed once instead of on every retry.
    group_unions = []
    all_devices = list(client_classes.keys())

    for _group_index in range(num_groups):
        previous_union = group_unions[-1] if group_unions else None
        # Groups older than the previous one; non-empty from the third group on.
        older_unions = group_unions[:-1]

        for _ in range(max_attempts):
            # Randomly select a group size and devices
            group_size = random.randint(min_group_size, max_group_size)
            group = random.sample(all_devices, group_size)
            current_union = label_union(group)

            # Too similar to the previous group -> retry
            if previous_union is not None and calculate_jaccard_similarity(current_union, previous_union) > low_similarity_threshold:
                continue

            # Not similar enough to any older group -> retry
            if older_unions and not any(
                calculate_jaccard_similarity(current_union, past_union) >= high_similarity_threshold
                for past_union in older_unions
            ):
                continue

            # Both conditions satisfied: accept the group
            groups.append(group)
            group_unions.append(current_union)
            break
        else:
            # If no valid group was created after max_attempts, raise an error
            raise RuntimeError(f"Failed to create a valid group after {max_attempts} attempts")

    return groups

def distribute_labels_slight_overlap_10_clients(dataset):
    """
    Distribute a labelled dataset across 10 clients in two groups whose label sets partly overlap.

    Sorted distinct labels are cut into two sets of int(0.6 * num_labels) labels each:
    - labels_set1 is the first int(0.6 * num_labels) labels;
    - labels_set2 is the last int(0.2 * num_labels) labels of labels_set1 (the overlap),
      followed by the next unused labels.

    Labels only in labels_set1 go to clients 0-4, and labels only in labels_set2 go to
    clients 5-9. In both cases each label's samples are split as evenly as possible with
    np.array_split. For each overlap label, the first half of its samples is split among
    clients 0-4 and the second half among clients 5-9. Each sample index is given to at
    most one client.

    Args:
    - dataset: The dataset object (e.g., MNIST, CIFAR-100, etc.); must expose `targets` or `labels`.

    Returns:
    - clients_data_ids: A defaultdict(list) where keys are device IDs (0-9), and values are lists of data indices.
    - client_classes: A defaultdict(set) where keys are device IDs, and values are sets of assigned labels.

    Raises:
    - ValueError: if the dataset has neither a `targets` nor a `labels` attribute.
    """

    # Step 1: Extract labels based on dataset type
    if hasattr(dataset, 'targets'):
        # For MNIST, Fashion-MNIST, CIFAR-10, CIFAR-100, we convert to numpy for indexing
        targets = np.array(dataset.targets)
    elif hasattr(dataset, 'labels'):
        # For SVHN
        targets = np.array(dataset.labels)
    else:
        raise ValueError("Dataset does not have 'targets' or 'labels' attribute")

    # Sorted distinct labels as Python scalars. For the usual 0..num_labels-1 labelling this
    # equals range(num_labels); it also works when labels are not contiguous from 0.
    all_labels = np.unique(targets).tolist()
    num_labels = len(all_labels)

    # Step 2: Select two sets of labels with 20% overlap
    num_labels_per_set = int(0.6 * num_labels)  # 60% of the total labels
    overlap_size = int(0.2 * num_labels)  # Overlap size (20% of total labels)

    labels_set1 = all_labels[:num_labels_per_set]
    # Slice from an explicit start index: labels_set1[-0:] would select the whole list
    # when overlap_size == 0 (fewer than 5 labels).
    overlap_labels = labels_set1[len(labels_set1) - overlap_size:]
    remaining_labels = all_labels[num_labels_per_set:]
    labels_set2 = overlap_labels + remaining_labels[:num_labels_per_set - overlap_size]

    # Group data by label
    class_indices = defaultdict(list)
    for idx, label in enumerate(targets):
        class_indices[label].append(idx)

    # Prepare data structures to store client data
    clients_data_ids = defaultdict(list)
    client_classes = defaultdict(set)

    # Split every sample of each label across clients start_client..end_client (inclusive)
    def distribute_non_overlap(labels, start_client, end_client):
        for label in labels:
            total_data = class_indices[label]
            clients_data = np.array_split(total_data, end_client - start_client + 1)
            for i in range(start_client, end_client + 1):
                clients_data_ids[i].extend(clients_data[i - start_client])
                client_classes[i].add(label)

    # Split the first half of each label's samples across the first client range and the
    # second half across the second client range
    def distribute_overlap(labels, start_client_first_half, end_client_first_half, start_client_second_half, end_client_second_half):
        for label in labels:
            total_data = class_indices[label]
            half_data = len(total_data) // 2
            clients_data_1 = np.array_split(total_data[:half_data], end_client_first_half - start_client_first_half + 1)
            clients_data_2 = np.array_split(total_data[half_data:], end_client_second_half - start_client_second_half + 1)

            for i in range(start_client_first_half, end_client_first_half + 1):
                clients_data_ids[i].extend(clients_data_1[i - start_client_first_half])
                client_classes[i].add(label)

            for i in range(start_client_second_half, end_client_second_half + 1):
                clients_data_ids[i].extend(clients_data_2[i - start_client_second_half])
                client_classes[i].add(label)

    # Step 3: Labels unique to the first set go to clients 0-4
    distribute_non_overlap([label for label in labels_set1 if label not in overlap_labels], 0, 4)

    # The overlap labels are shared by both client groups. Distribute them exactly once so
    # that no sample index is given twice to the same client.
    distribute_overlap(overlap_labels, 0, 4, 5, 9)

    # Step 4: Labels unique to the second set go to clients 5-9
    distribute_non_overlap([label for label in labels_set2 if label not in overlap_labels], 5, 9)

    return clients_data_ids, client_classes

def print_client_data_distribution(clients_data_ids, client_classes, dataset):
    """
    Print, for every client, each assigned label and how many of the client's samples carry it.

    Args:
    - clients_data_ids: dict mapping client ID -> iterable of data indices into `dataset`.
    - client_classes: dict mapping client ID -> set of assigned labels (printed in sorted order).
    - dataset: object exposing `targets` (preferred) or `labels`, used to map indices back to labels.

    Raises:
    - ValueError: if the dataset has neither a `targets` nor a `labels` attribute.
    """
    if hasattr(dataset, 'targets'):
        raw_labels = dataset.targets
    elif hasattr(dataset, 'labels'):
        raw_labels = dataset.labels
    else:
        raise ValueError("Dataset does not have 'targets' or 'labels' attribute")
    label_lookup = np.array(raw_labels)

    for cid, indices in clients_data_ids.items():
        print(f"\nClient {cid}:")

        # Tally how many of this client's samples fall under each label.
        tally = defaultdict(int)
        for sample_idx in indices:
            tally[label_lookup[sample_idx]] += 1

        # Assigned labels with no samples are reported with a count of 0.
        for lbl in sorted(client_classes[cid]):
            print(f"  Label {lbl}: {tally[lbl]} datapoints")
