import copy
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler

class AverageMeter:
    """Track the latest value and the running weighted mean of a metric.

    Attributes:
        val: The value passed to the most recent ``update`` call.
        sum: Accumulated ``val * n`` over all updates since the last reset.
        count: Accumulated ``n`` over all updates since the last reset.
        avg: ``sum / count`` as of the most recent update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every tracked statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running mean."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
        

def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy percentages for each k in ``topk``.

    Args:
        output: Score tensor indexed as (sample, class); the top-k search
            runs along dimension 1.
        target: Tensor of true class indices, one per sample.
        topk: Iterable of k values to evaluate.

    Returns:
        A list with one single-element tensor per requested k, holding the
        percentage of samples whose target is among the k highest scores.
    """
    with torch.no_grad():
        largest_k = max(topk)
        num_samples = target.size(0)

        # Indices of the largest_k best-scoring classes, best first.
        top_classes = output.topk(largest_k, dim=1, largest=True, sorted=True)[1]
        ranked = top_classes.t()  # row r holds every sample's rank-r guess
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / num_samples)
            for k in topk
        ]

def setup_seed(seed):
    """Seed every random number generator this module relies on.

    Seeds PyTorch (CPU and all CUDA devices), NumPy and Python's ``random``
    with the same value, then sets ``torch.backends.cudnn.deterministic``
    to ``True``.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True

def replace_with_samples(arr1, arr2, alpha, capacity):
    """Refresh a bounded buffer ``arr1`` with elements drawn from ``arr2``.

    While ``arr1`` holds fewer than ``capacity`` elements, it is topped up
    with elements of ``arr2`` (all of them if ``arr2`` cannot fill the free
    slots, otherwise a random sample without replacement). Once ``arr1`` is
    at or above ``capacity``, a fraction ``alpha`` of its elements is removed
    at random and the same number of elements sampled from ``arr2`` without
    replacement is appended.

    Args:
        arr1: Current buffer contents (array-like, 1-D).
        arr2: Candidate elements to draw from (array-like, 1-D).
        alpha: Fraction of ``arr1`` to replace when the buffer is full.
        capacity: Target buffer size.

    Returns:
        A new NumPy array; the inputs are not modified.
    """
    arr1 = np.array(arr1)
    arr2 = np.array(arr2)

    if arr2.size == 0:
        # Nothing to draw from. Returning early also avoids concatenating
        # with an empty float64 array, which would turn integer ids into
        # floats.
        return arr1
    if arr1.size == 0:
        # An empty list converts to float64; adopt the candidates' dtype so
        # the result keeps the dtype of the real data.
        arr1 = arr1.astype(arr2.dtype)

    if len(arr1) < capacity:
        free_slots = capacity - len(arr1)
        if len(arr2) < free_slots:
            sampled_elements = arr2
        else:
            sampled_elements = np.random.choice(arr2, free_slots, replace=False)
        return np.concatenate((arr1, sampled_elements))

    # Sampling without replacement cannot return more elements than arr2
    # holds, so replace at most len(arr2) entries instead of raising.
    num_replace = min(int(len(arr1) * alpha), len(arr2))
    remove_indices = np.random.choice(len(arr1), num_replace, replace=False)
    sampled_elements = np.random.choice(arr2, num_replace, replace=False)

    kept = np.delete(arr1, remove_indices)
    return np.concatenate((kept, sampled_elements))

def label_distribution(train_labels, data_id):
    """Return the empirical class frequencies of the selected samples.

    Args:
        train_labels: Array of integer class labels; the number of classes
            is taken as ``train_labels.max() + 1``.
        data_id: Iterable of indices into ``train_labels``.

    Returns:
        An array with one entry per class giving the fraction of the
        selected samples that carry that label.
    """
    num_classes = train_labels.max() + 1
    counts = np.zeros(num_classes)

    for sample_idx in data_id:
        label = train_labels[sample_idx]
        counts[label] += 1

    return counts / len(data_id)

def jensen_shannon_divergence(p, q):
    """Return the Jensen-Shannon divergence (base-2 logs) between p and q.

    Both inputs are normalised to sum to one and then clamped from below at
    1e-10 so the logarithms stay finite. The clamped vectors are not
    renormalised.
    """
    floor = 1e-10
    p_safe = np.maximum(p / p.sum(), floor)
    q_safe = np.maximum(q / q.sum(), floor)
    mixture = 0.5 * (p_safe + q_safe)

    def _kl_to_mixture(dist):
        # KL(dist || mixture)
        return np.sum(dist * np.log2(dist / mixture))

    # JSD = 0.5 * (KL(P || M) + KL(Q || M))
    return 0.5 * (_kl_to_mixture(p_safe) + _kl_to_mixture(q_safe))

def kl_divergence(p, q):
    """Return KL(p || q) in bits.

    Both inputs are normalised to sum to one and then clamped from below at
    1e-10 so neither the ratio nor the logarithm blows up. The clamped
    vectors are not renormalised.
    """
    floor = 1e-10
    p_safe = np.maximum(p / p.sum(), floor)
    q_safe = np.maximum(q / q.sum(), floor)
    return np.sum(p_safe * np.log2(p_safe / q_safe))

def l2_distance(p, q):
    """Return the 2-norm of the difference between p and q after each is
    normalised to sum to one."""
    p_norm = p / p.sum()
    q_norm = q / q.sum()
    return np.linalg.norm(p_norm - q_norm, ord=2)

def l1_distance(p, q):
    """Return the 1-norm of the difference between p and q after each is
    normalised to sum to one."""
    p_norm = p / p.sum()
    q_norm = q / q.sum()
    return np.linalg.norm(p_norm - q_norm, ord=1)

def distance(metric, *params):
    """Dispatch to the distance function selected by ``metric``.

    Args:
        metric: One of ``'js'``, ``'kl'``, ``'l2'`` or ``'l1'``.
        *params: Positional arguments forwarded unchanged to the selected
            function.

    Returns:
        Whatever the selected distance function returns.

    Raises:
        ValueError: If ``metric`` is not one of the supported names.
    """
    if metric == 'js':
        measure = jensen_shannon_divergence
    elif metric == 'kl':
        measure = kl_divergence
    elif metric == 'l2':
        measure = l2_distance
    elif metric == 'l1':
        measure = l1_distance
    else:
        raise ValueError(f"Unknown distance type: {metric}")
    return measure(*params)
    
def relu(x):
    """Return ``x`` unless it is negative, in which case return ``0``."""
    return 0 if x < 0 else x

def client_probabilities(N):
    """Draw ``N`` values from a normal distribution with mean 0.2 and
    standard deviation 0.01.

    The samples are returned as drawn; no clipping is applied.
    """
    mean, std = 0.2, 0.01
    return np.random.normal(loc=mean, scale=std, size=N)

def initialize_state_probabilities(N, M, S, set):
    """Create a random state distribution for each of ``N`` clients.

    Args:
        N: Number of clients.
        M: Number of states.
        S: Number of states with non-zero probability per client; only used
            when ``set`` is not ``'full'``.
        set: ``'full'`` gives every client a random distribution over all
            ``M`` states. Any other value gives each client exactly ``S``
            supported states: clients with index below ``N / 2`` pick them
            from the last third of the state range (indices
            ``M * 2 // 3`` to ``M - 1``), the remaining clients pick them
            from all ``M`` states.

    Returns:
        A tuple ``(pi_list, data_map_list, nonzero_indices_list)`` of three
        length-``N`` lists: the normalised probability vectors, an empty
        integer array per client, and the index array of each client's
        supported states.
    """
    pi_list, data_map_list, nonzero_indices_list = [], [], []
    full_access = set == 'full'

    for client in range(N):
        if full_access:
            pi = np.random.rand(M)
            support = np.arange(M)
        else:
            pi = np.zeros(M)
            # partial-access: the first half of the clients only reach the
            # high-heterogeneity states at the top of the index range.
            pool = np.arange(M * 2 // 3, M) if client < N / 2 else M
            support = np.random.choice(pool, S, replace=False)
            pi[support] = np.random.rand(S)

        pi /= pi.sum()
        pi_list.append(pi)
        data_map_list.append(np.empty(0, dtype=int))
        nonzero_indices_list.append(support)

    return pi_list, data_map_list, nonzero_indices_list

def compute_alpha(N, M, pi_list, distance_list, nonzero_indices_list, a1, b1, alpha):
    """Compute a per-client, per-state alpha value, each capped at 1.

    For every client ``n`` each state with non-zero ``pi_list[n][m]`` gets a
    score built from ``distance_list[m]``, ``a1``, ``b1`` and ``alpha``.
    The budget ``alpha`` is then shared among the client's states in
    proportion to the positive part of their scores, so that the
    pi-weighted sum of the assigned values equals the budget. Whenever a
    value reaches the cap of 1 the state is fixed at 1, its probability is
    taken out of the budget, and the remaining states are recomputed.

    Args:
        N: Number of clients.
        M: Number of states.
        pi_list: Per-client state probabilities, indexed ``[n][m]``.
        distance_list: One distance value per state.
        nonzero_indices_list: Per-client NumPy index arrays of the states
            to allocate over (``.copy().tolist()`` is called on each).
            NOTE(review): presumably the indices where ``pi_list[n]`` is
            non-zero -- confirm against the caller.
        a1: Coefficient applied to the distance in the score.
        b1: Constant offset in the score.
        alpha: Overall budget per client; must be non-zero because it is
            used as a divisor in the score.

    Returns:
        An ``(N, M)`` array of alpha values. Entries for states that are
        not in ``nonzero_indices_list[n]`` stay 0.
    """
    alpha_matrix = np.zeros((N, M))
    for n in range(N):
        alpha_list = np.zeros(M)
        scores_list = np.zeros(M)
        # Score every state the client can be in; unreachable states score 0.
        for m in range(M):
            if pi_list[n][m] != 0:
                scores_list[m] = (1 / M - a1 * distance_list[m] + b1) / (1 + (1 - alpha) * pi_list[n][m] / alpha)
            else:
                scores_list[m] = 0
        converged = False
        # Work on a plain-list copy so states can be removed as they get capped.
        remaining_indices = nonzero_indices_list[n].copy().tolist()
        remaining_alpha = alpha
        # Each pass recomputes the shares of all remaining states and stops
        # at the first one that hits the cap; a pass without a cap ends the loop.
        while not converged and len(remaining_indices) > 0:
            converged = True
            # Probability-weighted sum of relu(scores) over the remaining indices
            sum_alpha = sum(pi_list[n][m] * relu(scores_list[m]) for m in remaining_indices)

            if sum_alpha == 0:
                # No remaining state has a positive weighted score: give every
                # remaining state the same value, remaining_alpha / sum_pi.
                # This value is not capped at 1.
                sum_pi = sum(pi_list[n][m] for m in remaining_indices)
                for m in remaining_indices:
                    alpha_list[m] = remaining_alpha / sum_pi
                break
            else:
                for m in remaining_indices:
                    # Share of the remaining budget proportional to the
                    # positive part of the score, capped at 1
                    alpha_val = min(remaining_alpha * relu(scores_list[m]) / sum_alpha, 1)
                    alpha_list[m] = alpha_val
                    
                    # A capped state stays at 1: remove its probability from
                    # the budget, drop it, and restart the pass. The immediate
                    # break makes removing from the list being iterated safe.
                    if alpha_val >= 1:
                        converged = False
                        remaining_alpha -= pi_list[n][m]
                        remaining_indices.remove(m)
                        break
        alpha_matrix[n] = alpha_list
    return alpha_matrix

def compute_weights(N, M, pi_list, distance_list, alpha_matrix, gamma, client_probs, a2, b2, alpha, metric):
    """Compute one aggregation weight per client.

    For each client ``n`` a penalty ``S`` is assembled from three terms
    (a distance term scaled by ``(1 - gamma) / alpha`` and two terms scaled
    by ``2 * gamma * G``), and the weight is
    ``1 / client_probs[n] - a2 * S + b2``.

    Args:
        N: Number of clients.
        M: Number of states.
        pi_list: Per-client state probabilities, indexed ``[n][m]``.
        distance_list: One distance value per state.
        alpha_matrix: Per-client, per-state alpha values, indexed ``[n][m]``.
        gamma: Mixing coefficient between the distance term and the
            ``G``-scaled terms.
        client_probs: One probability per client; must be non-zero.
        a2: Coefficient applied to the penalty ``S``.
        b2: Constant offset added to every weight.
        alpha: Overall alpha budget; used as a divisor, so must be non-zero.
        metric: ``'kl'``, ``'l1'``, ``'l2'`` or ``'js'``; selects the
            constant ``G``.

    Returns:
        A list of ``N`` weights.

    Raises:
        ValueError: If ``metric`` is not one of the supported names.
    """
    metric_constants = {'kl': 1.0, 'l1': 0.5, 'l2': 0.2, 'js': 0.1}
    if metric not in metric_constants:
        # Fail with a clear message instead of hitting an unbound ``G`` below.
        raise ValueError(f"Unknown distance type: {metric}")
    G = metric_constants[metric]  # depends only on the metric, so look it up once

    states = range(M)
    client_weights = []
    for n in range(N):
        pi_n = pi_list[n]
        alpha_n = alpha_matrix[n]

        weighted_distance = sum(pi_n[m] * alpha_n[m] * distance_list[m] for m in states)
        pi_alpha_sq = sum(pi_n[m] * alpha_n[m] ** 2 for m in states)
        pi_sq_alpha_sq = sum(pi_n[m] ** 2 * alpha_n[m] ** 2 for m in states)
        deviation_sq = sum((pi_n[m] * alpha_n[m] / alpha - 1 / M) ** 2 for m in states)

        S = (
            (1 - gamma) / alpha * weighted_distance
            + 2 * gamma * G / (1 - (1 - alpha) ** 2) * (2 * pi_alpha_sq - pi_sq_alpha_sq + alpha ** 2)
            + 2 * gamma * G * deviation_sq
        )
        client_weights.append(1 / client_probs[n] - a2 * S + b2)
    return client_weights



if __name__ == "__main__":
    # Manual smoke check: print one draw of 30 client probabilities.
    sampled_probs = client_probabilities(30)
    print(sampled_probs)