import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import matplotlib.pyplot as plt
import random
import os
from scipy.spatial.distance import pdist, squareform
from matplotlib import rcParams

# ---- Federation layout -----------------------------------------------------
# FL_training iterates over NUM_CLIENTS client datasets, while
# clients_datasets_setup builds NUM_BINARY + NUM_FULL of them, so NUM_CLIENTS
# is expected to equal NUM_BINARY + NUM_FULL.
NUM_CLIENTS = 60
NUM_BINARY = 50  # first client group; the noisy clients are chosen from it
NUM_FULL = 10  # second client group; shards drawn from digits 2-9 (+ a few 0/1)
NOISY_CLIENT_NUM = 25  # how many of the NUM_BINARY clients are marked noisy

# ---- Optimisation settings -------------------------------------------------
EPOCHS = 1  # local passes over a client's data per round
ROUNDS = 30  # number of federated rounds per experiment
BATCH_SIZE = 64  # batch size used by get_loader
LR = 0.01  # learning rate of the local SGD optimiser
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- Checkpointing ---------------------------------------------------------
# NOTE(review): SAVE_ROUND and SAVE_PATH are not referenced anywhere else in
# the visible code; only the directory creation below has an effect.
SAVE_ROUND = 30
SAVE_PATH = "./saved_models"
os.makedirs(SAVE_PATH, exist_ok=True)  # side effect at import time

# NOTE(review): K_KDE is not referenced in the visible code; the experiment
# calls pass their neighbour count explicitly.
K_KDE = 5

SEED = 60  # passed to set_seed before the data is loaded and partitioned

def set_seed(seed: int):
    """Seed every random number generator this script draws from.

    Covers Python's ``random`` module, NumPy's global generator, and
    PyTorch's CPU and CUDA generators, all with the same ``seed``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

class AddGaussianNoise:
    """Transform that adds element-wise Gaussian noise to a tensor.

    Args:
        mean: constant offset added to every element.
        std: scale applied to the standard-normal draw.
    """

    def __init__(self, mean=0.0, std=1.0):
        self.mean, self.std = mean, std

    def __call__(self, tensor):
        """Return ``tensor + randn * std + mean`` as a new tensor."""
        standard_normal = torch.randn(tensor.size())
        scaled = standard_normal * self.std
        return tensor + scaled + self.mean

class AddSaltPepperNoise:
    """Transform that forces a random fraction of elements to 0.0 or 1.0.

    Args:
        prob: total corruption probability; half of it sets elements to 0.0
            ("pepper") and the other half sets elements to 1.0 ("salt").
    """

    def __init__(self, prob=0.3):
        self.prob = prob

    def __call__(self, tensor):
        """Return a corrupted copy of ``tensor``; the input is left untouched."""
        half = self.prob / 2
        draws = torch.rand(tensor.size())
        pepper_mask = draws < half
        salt_mask = draws > 1 - half
        corrupted = tensor.clone()
        # Pepper is applied first so that salt wins if the masks ever overlap.
        corrupted[pepper_mask] = 0.0
        corrupted[salt_mask] = 1.0
        return corrupted

class RandomPixelFlip(object):
    """Transform that inverts (``x -> 1 - x``) a random subset of elements.

    Args:
        ratio: fraction of the tensor's elements to invert; the number of
            flipped elements is ``int(numel * ratio)``.
    """

    def __init__(self, ratio=0.2):
        self.ratio = ratio

    def __call__(self, tensor):
        """Return a copy of ``tensor`` with randomly chosen elements inverted.

        The result has the same shape as the input. The input tensor itself
        is not modified.
        """
        total = tensor.numel()
        idx = torch.randperm(total)[:int(total * self.ratio)]
        # Work on a private flat copy: flattening with view() would alias the
        # caller's storage (mutating the input) and fails on non-contiguous
        # tensors.
        flat = tensor.reshape(-1).clone()
        flat[idx] = 1 - flat[idx]
        # Restore the caller's shape instead of assuming a 1x28x28 image.
        return flat.reshape(tensor.shape)

def flip_label(y, p=0.5, num_classes=10):
    """Replace a label with a uniformly random class with probability ``p``.

    Args:
        y: the original label.
        p: probability of drawing a replacement label.
        num_classes: number of classes; replacements are drawn uniformly from
            ``0 .. num_classes - 1``. Defaults to 10.

    Returns:
        ``y`` unchanged, or a random label. The random label is drawn from
        all classes, so it may coincide with ``y``.
    """
    if random.random() < p:
        return random.randint(0, num_classes - 1)
    return y

class CNN(nn.Module):
    """Small two-convolution classifier producing 10 logits.

    The flatten step is fixed at 320 features, which is what the two
    5x5 convolutions with 2x2 max-pooling produce for a 1x28x28 input
    (20 channels of 4x4).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return the raw class scores (no softmax) for a batch ``x``."""
        features = x
        # Each stage is convolution -> 2x2 max-pool -> ReLU.
        for conv in (self.conv1, self.conv2):
            features = torch.relu(torch.max_pool2d(conv(features), 2))
        flat = features.view(-1, 320)
        hidden = torch.relu(self.fc1(flat))
        return self.fc2(hidden)

set_seed(SEED)

# Both MNIST splits use the same tensor-only transform (no normalisation).
clean_transform = transforms.Compose([transforms.ToTensor()])
train_dataset_clean = datasets.MNIST(
    "./data", train=True, download=True, transform=clean_transform
)
test_dataset = datasets.MNIST(
    "./data", train=False, download=True, transform=clean_transform
)

# Per-class pools of training-set positions, each shuffled once up front so
# that contiguous windows taken from a pool are random samples of that class.
labels = np.array(train_dataset_clean.targets)
label_indices = {}
for l in range(10):
    class_positions = np.where(labels == l)[0]
    np.random.shuffle(class_positions)
    label_indices[l] = class_positions

def noisy_clients_setup():
    """Pick which clients receive noisy images.

    Returns:
        An array of ``NOISY_CLIENT_NUM`` distinct client ids drawn without
        replacement from ``0 .. NUM_BINARY - 1``. The selection is also
        printed.
    """
    chosen = np.random.choice(NUM_BINARY, size=NOISY_CLIENT_NUM, replace=False)
    print("🎨 noisy 客户端:", chosen)
    return chosen

def label_skew_subset_indices():
    """Build a label-skewed client partition dominated by digits 0 and 1.

    The partition holds 600 samples of digit 0, 600 samples of digit 1, and
    20-80 samples from each of 1-5 randomly chosen digits in 2-9. All samples
    are contiguous windows of the pre-shuffled ``label_indices`` pools.

    Returns:
        A shuffled array of training-set indices.
    """
    per_class = 600
    offset0 = np.random.randint(0, len(label_indices[0]) - per_class)
    offset1 = np.random.randint(0, len(label_indices[1]) - per_class)
    chunks = [
        label_indices[0][offset0:offset0 + per_class],
        label_indices[1][offset1:offset1 + per_class],
    ]

    extra_count = np.random.randint(1, 6)
    extra_labels = np.random.choice(range(2, 10), size=extra_count, replace=False)
    for lbl in extra_labels:
        pool = label_indices[lbl]
        available = len(pool)
        # Clamp the 20..80 sample range to what the pool actually holds.
        upper = min(80, available)
        lower = min(20, upper)
        take = np.random.randint(lower, upper + 1)
        if available - take <= 1:
            offset = 0
        else:
            offset = np.random.randint(0, available - take)
        chunks.append(pool[offset:offset + take])

    subset_indices = np.concatenate(chunks)
    np.random.shuffle(subset_indices)
    return subset_indices

def label_nonskew_subset_indices(i):
    """Build a client partition drawn mainly from digits 2-9.

    The digit 2-9 indices are shuffled and cut into ``NUM_FULL`` equally sized
    shards; the client receives one shard plus 10-30 randomly chosen samples
    of digits 0/1.

    Args:
        i: client position. It is wrapped modulo ``NUM_FULL``. Without the
            wrap, any ``i >= NUM_FULL`` (as passed for the ``NUM_BINARY``
            clients when label skew is off) slices past the end of the pool
            and silently yields an empty main shard.

    Returns:
        A shuffled array of training-set indices.

    NOTE(review): the pool is reshuffled on every call, so shards handed to
    different clients are independent random samples and may overlap. That is
    left unchanged here so existing seeded runs draw the same random stream.
    """
    # Convert once instead of comparing the raw targets for every label.
    targets = np.asarray(train_dataset_clean.targets)

    # Keep the label-by-label concatenation order: the shuffle result depends
    # on the order of the array being shuffled.
    indices_2_9 = np.concatenate(
        [np.where(targets == label)[0] for label in range(2, 10)]
    )
    np.random.shuffle(indices_2_9)

    indices_0_1 = np.concatenate(
        [np.where(targets == 0)[0], np.where(targets == 1)[0]]
    )
    np.random.shuffle(indices_0_1)

    samples_per_client = len(indices_2_9) // NUM_FULL
    shard = i % NUM_FULL
    subset_main = indices_2_9[shard * samples_per_client:(shard + 1) * samples_per_client]

    extra_num = np.random.randint(10, 31)
    extra_num = min(extra_num, len(indices_0_1))
    extra_indices = np.random.choice(indices_0_1, size=extra_num, replace=False)

    subset = np.concatenate([subset_main, extra_indices])
    np.random.shuffle(subset)
    return subset

def clients_datasets_setup(noisy_clients, is_feature_skew=True, is_label_skew=True):
    """Build the per-client datasets for one experiment.

    Args:
        noisy_clients: ids (within ``0 .. NUM_BINARY - 1``) of the clients
            whose images go through the noise transforms.
        is_feature_skew: when False, no client receives noisy images.
        is_label_skew: when True the first ``NUM_BINARY`` clients get the
            0/1-dominated partition, otherwise they get the same kind of
            partition as the ``NUM_FULL`` clients.

    Returns:
        A list of ``NUM_BINARY + NUM_FULL`` ``Subset`` objects, the
        ``NUM_BINARY`` clients first.
    """
    noisy_ids = {int(c) for c in noisy_clients} if is_feature_skew else set()

    # All noisy clients share one transform pipeline, so build the noisy view
    # of MNIST once instead of constructing a new dataset per noisy client.
    noisy_dataset = None
    if noisy_ids:
        noisy_transform = transforms.Compose([
            transforms.ToTensor(),
            AddGaussianNoise(std=0.4),
            AddSaltPepperNoise(prob=0.1),
            RandomPixelFlip(ratio=0.1),
        ])
        noisy_dataset = datasets.MNIST("./data", train=True, transform=noisy_transform)

    client_datasets = []
    for i in range(NUM_BINARY):
        subset_indices = label_skew_subset_indices() if is_label_skew else label_nonskew_subset_indices(i)
        source = noisy_dataset if i in noisy_ids else train_dataset_clean
        client_datasets.append(Subset(source, subset_indices))

    for i in range(NUM_FULL):
        client_datasets.append(Subset(train_dataset_clean, label_nonskew_subset_indices(i)))
    return client_datasets

def get_loader(dataset):
    """Wrap ``dataset`` in a shuffling ``DataLoader`` with ``BATCH_SIZE`` batches."""
    loader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE)
    return loader

def local_train(model, loader):
    """Train ``model`` in place on one client's data.

    Runs ``EPOCHS`` passes of plain SGD (learning rate ``LR``) with
    cross-entropy loss over ``loader``.

    Returns:
        The model's ``state_dict()`` after training.
    """
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=LR)
    criterion = nn.CrossEntropyLoss()
    for _epoch in range(EPOCHS):
        for inputs, targets in loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            optimizer.zero_grad()
            logits = model(inputs)
            batch_loss = criterion(logits, targets)
            batch_loss.backward()
            optimizer.step()
    return model.state_dict()

def test(model, loader):
    """Return the top-1 accuracy of ``model`` over ``loader`` as a fraction.

    The model is switched to eval mode and gradients are disabled for the
    whole pass.
    """
    model.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            predictions = model(inputs).argmax(dim=1)
            hits += (predictions == targets).sum().item()
            seen += targets.size(0)
    return hits / seen

def compute_kde_weights(local_weights, k, num_binary=None):
    """Compute per-client aggregation weights from a k-NN density estimate.

    Every client's state dict is flattened into one vector. A client's
    density is the mean Gaussian kernel value over its ``k`` nearest other
    clients, with the median of the full pairwise distance matrix as
    bandwidth. Weights are normalised inside each of two client groups (the
    first ``num_binary`` clients and the rest), and each group is then scaled
    by its share of the mean density, so the returned weights sum to 1.

    Args:
        local_weights: list of client state dicts (name -> tensor).
        k: number of nearest neighbours used for the density; ``0`` selects
            uniform weights.
        num_binary: size of the first client group. Defaults to the
            module-level ``NUM_BINARY``.

    Returns:
        A 1-D numpy array with one weight per client.

    Raises:
        ValueError: if ``k != 0`` and either client group would be empty.
            The group normalisation would otherwise yield NaN weights.
    """
    n_clients = len(local_weights)

    # Uniform averaging needs no distances. Returning early also avoids the
    # mean-of-empty-slice NaN and RuntimeWarning the k == 0 path used to
    # produce before its result was discarded.
    if k == 0:
        return np.ones(n_clients) / n_clients

    if num_binary is None:
        num_binary = NUM_BINARY
    if not 0 < num_binary < n_clients:
        raise ValueError(
            f"num_binary must split {n_clients} clients into two non-empty "
            f"groups, got {num_binary}"
        )

    # State dicts may live on the GPU; Tensor.numpy() only works on host
    # tensors, so move each flattened vector to the CPU first.
    flat = np.array([
        torch.cat([p.flatten() for p in w.values()]).detach().cpu().numpy()
        for w in local_weights
    ])
    dist_matrix = squareform(pdist(flat))
    h = np.median(dist_matrix) + 1e-12  # bandwidth; epsilon guards h == 0

    kde_vals = np.zeros(n_clients)
    for i in range(n_clients):
        # Position 0 of the sorted row is the client itself (distance 0).
        idx = np.argsort(dist_matrix[i])[1:k + 1]
        d = dist_matrix[i, idx]
        kde_vals[i] = np.mean(np.exp(-d * d / (2 * h * h)))
    kde_vals += 1e-12  # keep every weight strictly positive

    binary_kde_vals = kde_vals[:num_binary] / kde_vals[:num_binary].sum()
    full_kde_vals = kde_vals[num_binary:] / kde_vals[num_binary:].sum()
    bw = np.mean(kde_vals[:num_binary])
    fw = np.mean(kde_vals[num_binary:])
    binary_kde_vals *= bw / (bw + fw)
    full_kde_vals *= fw / (bw + fw)
    return np.concatenate([binary_kde_vals, full_kde_vals])

def fed_kde_avg(local_weights, kde_weights):
    """Combine client state dicts into one weighted-sum state dict.

    Args:
        local_weights: list of client state dicts sharing the same keys.
        kde_weights: one scalar weight per client, in the same order.

    Returns:
        A dict mapping each key of the first state dict to
        ``sum_i kde_weights[i] * local_weights[i][key]``.
    """
    aggregated = {}
    for name in local_weights[0].keys():
        aggregated[name] = sum(
            kde_weights[idx] * state[name]
            for idx, state in enumerate(local_weights)
        )
    return aggregated

# Evaluation loader over the MNIST test split. get_loader shuffles it; the
# accuracy is a plain hit ratio, so the order does not change the result.
test_loader = get_loader(test_dataset)
# The noisy-client ids are drawn once here and reused by every FL_training
# call, so all experiments share the same set of noisy clients.
noisy_clients = noisy_clients_setup()

def FL_training(k_value, is_feature_skew, is_label_skew):
    """Run one federated experiment and record the global test accuracy.

    Args:
        k_value: neighbour count passed to ``compute_kde_weights``
            (``0`` gives uniform averaging).
        is_feature_skew: whether the noisy clients receive noisy images.
        is_label_skew: whether the ``NUM_BINARY`` clients get the
            0/1-dominated partition.

    Returns:
        A list with the global model's test accuracy after each of the
        ``ROUNDS`` rounds.
    """
    client_datasets = clients_datasets_setup(noisy_clients, is_feature_skew, is_label_skew)
    global_model = CNN().to(DEVICE)
    accuracy_history = []
    for _round in range(ROUNDS):
        client_states = []
        for client_id in range(NUM_CLIENTS):
            # Every client starts the round from the current global weights.
            client_model = CNN().to(DEVICE)
            client_model.load_state_dict(global_model.state_dict())
            client_loader = get_loader(client_datasets[client_id])
            client_states.append(local_train(client_model, client_loader))
        aggregation_weights = compute_kde_weights(client_states, k_value)
        merged_state = fed_kde_avg(client_states, aggregation_weights)
        global_model.load_state_dict(merged_state)
        round_accuracy = test(global_model, test_loader)
        accuracy_history.append(round_accuracy)
    return accuracy_history

# Run the five compared configurations. The arguments are
# (k_value, is_feature_skew, is_label_skew). The random generators are not
# re-seeded between runs, so each run continues the random stream left by the
# previous one (different initial weights and data partitions per run).
# NOTE(review): the variable names match the plot labels; whether these
# settings faithfully reproduce the named methods cannot be told from this file.

# k=0: uniform averaging of the client models.
acc_FedAvg = FL_training(0, True, True)
# k=60 with 60 clients: each client's density uses all 59 other clients.
acc_LoMar = FL_training(60, True, True)
# Same k, but the NUM_BINARY clients get the non-skewed label partition.
acc_FedLC = FL_training(60, True, False)
# Same k, label skew kept, but no client receives noisy images.
acc_FedRDN = FL_training(60, False, True)
# NOTE(review): K_KDE is defined as 5 near the top of the file, yet this run
# passes 8 — confirm which neighbour count is intended.
acc_FedKde = FL_training(8, True, True)

# ---- Accuracy-per-round comparison plot ------------------------------------
plt.figure(figsize=(6, 4), facecolor='white')
rcParams['font.family'] = 'Times New Roman'
ax = plt.gca()
ax.set_facecolor((230/255, 230/255, 238/255))
# White spines blend the axes frame into the figure background.
for spine in ax.spines.values():
    spine.set_color('white')

# One RGB colour (given as 0-255 channels) per plotted method.
colors = [
    tuple(channel / 255 for channel in rgb)
    for rgb in (
        (128, 149, 192),
        (215, 164, 133),
        (141, 185, 149),
        (197, 128, 131),
        (158, 149, 192),
    )
]

curves = [
    ('FedAvg', acc_FedAvg),
    ('LoMar', acc_LoMar),
    ('FedLC', acc_FedLC),
    ('FedRDN', acc_FedRDN),
    ('FedKde', acc_FedKde),
]
for (method_name, history), color in zip(curves, colors):
    plt.plot(history, label=method_name, color=color)

plt.xlabel("Round")
plt.ylabel("Accuracy")
plt.title("(a) MNIST")
plt.grid(True, color='white')
plt.legend()
plt.show()
