import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import matplotlib.pyplot as plt
import random
import os
from scipy.spatial.distance import pdist, squareform
from matplotlib import rcParams

# ---- Federation layout ----
# Total number of simulated clients; the training loop iterates range(NUM_CLIENTS).
NUM_CLIENTS = 60
# Number of clients built mainly from labels 0 and 1; they occupy client
# indices 0 .. NUM_BINARY-1.
NUM_BINARY = 50
# Number of clients built mainly from labels 2-9; they are appended after the
# binary clients.
NUM_FULL = 10
# How many of the binary clients are drawn as "noisy" clients.
NOISY_CLIENT_NUM = 25

# ---- Local training hyper-parameters ----
# Passes over the local data per client per round.
EPOCHS = 1
# Communication rounds per experiment.
ROUNDS = 30
# Mini-batch size used by every DataLoader created through get_loader().
BATCH_SIZE = 64
# SGD learning rate used in local_train().
LR = 0.01
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- Checkpointing ----
# NOTE(review): SAVE_ROUND is not read anywhere in the visible code; only the
# directory below is created.
SAVE_ROUND = 30
SAVE_PATH = "./saved_models"
os.makedirs(SAVE_PATH, exist_ok=True)

# NOTE(review): K_KDE is not read in the visible code; the experiment loop
# sweeps its own list of k values instead.
K_KDE = 5

# Seed passed to set_seed() before the data is partitioned.
SEED = 60

def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Args:
        seed: Value handed unchanged to every seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    # Same call order as always: random, numpy, torch, torch.cuda.
    for apply_seed in seeders:
        apply_seed(seed)

class AddGaussianNoise(object):
    """Transform that adds Gaussian noise to a tensor.

    The result is ``tensor + N(0, 1) * std + mean`` with the noise sampled
    independently for every element. The output is not clamped.
    """

    def __init__(self, mean=0.0, std=1.0):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        # Standard-normal sample with the same size as the input.
        unit_noise = torch.randn(tensor.size())
        scaled_noise = unit_noise * self.std
        return tensor + scaled_noise + self.mean

class AddSaltPepperNoise(object):
    """Transform that applies salt-and-pepper noise to a copy of a tensor.

    A uniform sample in [0, 1) is drawn per element. Elements whose sample is
    below ``prob / 2`` become 0.0 and elements whose sample is above
    ``1 - prob / 2`` become 1.0. The input tensor is left untouched.
    """

    def __init__(self, prob=0.3):
        self.prob = prob

    def __call__(self, tensor):
        draw = torch.rand(tensor.size())
        corrupted = tensor.clone()
        # "Pepper" first, then "salt", matching the assignment order.
        corrupted.masked_fill_(draw < self.prob/2, 0.0)
        corrupted.masked_fill_(draw > 1 - self.prob/2, 1.0)
        return corrupted

class RandomPixelFlip(object):
    """Transform that inverts a random fraction of a tensor's elements.

    ``int(numel * ratio)`` distinct element positions are chosen with
    ``torch.randperm`` and each chosen value ``v`` is replaced by ``1 - v``.

    The transform works on a copy, so the caller's tensor is never modified,
    and the result has the same shape as the input (a ``(1, 28, 28)`` input
    therefore still yields a ``(1, 28, 28)`` output).
    """

    def __init__(self, ratio=0.2):
        self.ratio = ratio

    def __call__(self, tensor):
        total = tensor.numel()
        idx = torch.randperm(total)[:int(total * self.ratio)]
        # reshape(-1) may return a view of the input, so clone before writing;
        # the clone is a contiguous 1-D tensor that can be viewed back safely.
        flat = tensor.reshape(-1).clone()
        flat[idx] = 1 - flat[idx]
        return flat.view(tensor.shape)

def flip_label(y, p=0.5):
    """Return a random label in 0..9 with probability ``p``, otherwise ``y``.

    The replacement is drawn uniformly from all ten classes, so it can
    coincide with the original label.
    """
    # The condition is evaluated first; randint only runs when it holds.
    return random.randint(0, 9) if random.random() < p else y

class CNN(nn.Module):
    """Two-convolution classifier producing 10 logits per sample.

    Each convolution is followed by 2x2 max-pooling and ReLU. The pooled
    feature map is flattened to 320 features per row, which is what a
    1x28x28 input produces (20 channels of 4x4).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5)
        self.fc1 = nn.Linear(in_features=320, out_features=50)
        self.fc2 = nn.Linear(in_features=50, out_features=10)

    def forward(self, x):
        # conv -> max-pool(2) -> ReLU, applied twice.
        for conv in (self.conv1, self.conv2):
            x = torch.relu(torch.max_pool2d(conv(x), 2))
        features = x.view(-1, 320)
        hidden = torch.relu(self.fc1(features))
        return self.fc2(hidden)

# Fix all RNGs before any data is sampled.
set_seed(SEED)

# Clean pipeline: image -> tensor, nothing else.
clean_transform = transforms.Compose([transforms.ToTensor()])
train_dataset_clean = datasets.MNIST(
    "./data", train=True, download=True, transform=clean_transform
)
test_dataset = datasets.MNIST(
    "./data", train=False, download=True, transform=clean_transform
)

# For each digit 0..9, the training-set positions carrying that label,
# shuffled in place (digits are processed in ascending order).
labels = np.array(train_dataset_clean.targets)
label_indices = {}
for digit in range(10):
    positions = np.flatnonzero(labels == digit)
    np.random.shuffle(positions)
    label_indices[digit] = positions

# One dataset per client: binary clients are appended here, in index order.
client_datasets = []

# Indices (within the binary clients) that receive corrupted images.
NOISY_CLIENTS = np.random.choice(NUM_BINARY, size=NOISY_CLIENT_NUM, replace=False)
print("🎨 noisy 客户端:", NOISY_CLIENTS)

# The corrupted view of the training set is the same for every noisy client,
# so it is built once here instead of being reloaded inside the loop. Building
# it draws no random numbers, so the partition below is unaffected.
noisy_transform = transforms.Compose([
    transforms.ToTensor(),
    AddGaussianNoise(std=0.4),
    AddSaltPepperNoise(prob=0.1),
    RandomPixelFlip(ratio=0.1),
])
noisy_dataset = datasets.MNIST("./data", train=True, transform=noisy_transform)

for i in range(NUM_BINARY):

    # 600 samples of label 0 and 600 of label 1, each taken as a contiguous
    # window at a random offset of the pre-shuffled per-label index list.
    rand0 = rand1 = 600

    start0 = np.random.randint(0, len(label_indices[0]) - rand0)
    start1 = np.random.randint(0, len(label_indices[1]) - rand1)

    idx0 = label_indices[0][start0:start0 + rand0]
    idx1 = label_indices[1][start1:start1 + rand1]

    # Add a small number of samples from 1 to 5 distinct labels in 2..9.
    EXTRA_LABEL_NUM = np.random.randint(1, 6)
    extra_labels = np.random.choice(range(2, 10), size=EXTRA_LABEL_NUM, replace=False)

    extra_indices = []
    for lbl in extra_labels:
        total_lbl = len(label_indices[lbl])

        # Between 20 and 80 samples per extra label, capped by availability.
        max_take = min(80, total_lbl)
        min_take = min(20, max_take)

        take = np.random.randint(min_take, max_take + 1)

        # randint needs a non-empty range; start at 0 when there is no slack.
        if total_lbl - take <= 1:
            start_lbl = 0
        else:
            start_lbl = np.random.randint(0, total_lbl - take)

        extra_indices.extend(label_indices[lbl][start_lbl:start_lbl + take])

    extra_indices = np.array(extra_indices)

    subset_indices = np.concatenate([idx0, idx1, extra_indices])
    np.random.shuffle(subset_indices)

    # Noisy clients read the same samples through the corrupting transform.
    if i in NOISY_CLIENTS:
        client_datasets.append(Subset(noisy_dataset, subset_indices))
    else:
        client_datasets.append(Subset(train_dataset_clean, subset_indices))

# All training positions with labels 2..9, grouped by label and then shuffled.
indices_2_9 = np.concatenate(
    [np.where(train_dataset_clean.targets == label)[0] for label in range(2, 10)]
)
np.random.shuffle(indices_2_9)

# All training positions with label 0 or 1, shuffled as one pool.
indices_0, indices_1 = (
    np.where(train_dataset_clean.targets == digit)[0] for digit in (0, 1)
)
indices_0_1 = np.concatenate((indices_0, indices_1))
np.random.shuffle(indices_0_1)

print(f"标签 2~9 样本数量: {len(indices_2_9)}")
print(f"标签 0,1 样本数量: {len(indices_0_1)}")

# Equal, disjoint share of the 2..9 pool per full client (remainder unused).
samples_per_client = len(indices_2_9) // NUM_FULL

for i in range(NUM_FULL):

    first = i * samples_per_client
    subset_main = indices_2_9[first:first + samples_per_client]

    # 10 to 30 extra samples from the 0/1 pool, never more than the pool holds.
    extra_num = min(np.random.randint(10, 31), len(indices_0_1))
    extra_indices = np.random.choice(indices_0_1, size=extra_num, replace=False)

    subset = np.concatenate([subset_main, extra_indices])
    np.random.shuffle(subset)

    client_datasets.append(Subset(train_dataset_clean, subset))

    print(f"Full Client {i} → 主标签 2-9 样本 = {len(subset_main)}, 加入 0-1 样本 = {extra_num}")

def get_loader(dataset):
    """Wrap ``dataset`` in a shuffling DataLoader with batches of BATCH_SIZE."""
    loader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE)
    return loader

def local_train(model, loader, noisy=False):
    """Train ``model`` in place on ``loader`` and return its state dict.

    Runs EPOCHS passes of plain SGD (learning rate LR) with cross-entropy
    loss on DEVICE.

    Args:
        model: Module to optimise; it is modified in place.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        noisy: When true, every label of every batch is passed through
            ``flip_label(..., p=0.5)`` before the loss is computed.

    Returns:
        ``model.state_dict()`` after training.
    """
    model.train()
    opt = optim.SGD(model.parameters(), lr=LR)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(EPOCHS):
        for x, y in loader:
            if noisy:
                # Flip labels before the transfer to DEVICE: reading the
                # labels one by one from a CUDA tensor would force a
                # device-to-host synchronisation per element.
                y = torch.tensor([flip_label(int(t), p=0.5) for t in y.tolist()])
            x, y = x.to(DEVICE), y.to(DEVICE)
            opt.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            opt.step()
    return model.state_dict()

def test(model, loader):
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly.

    The model is switched to eval mode and run without gradient tracking on
    DEVICE; the prediction is the arg-max over dimension 1 of the output.
    """
    model.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            predicted = model(inputs).argmax(dim=1)
            hits += (predicted == targets).sum().item()
            seen += targets.size(0)
    return hits / seen

def compute_kde_weights(local_weights, k, num_binary=None):
    """Compute per-client aggregation weights from a k-NN Gaussian density score.

    Every client's state dict is flattened into one vector. For each client
    the score is the mean of ``exp(-d^2 / (2 h^2))`` over its ``k`` nearest
    other clients, where ``d`` is the Euclidean distance and ``h`` is the
    median of the full pairwise distance matrix. Scores are normalised inside
    each of the two client groups (the first ``num_binary`` clients and the
    rest), and each group is then scaled by its share of the two group means,
    so the returned weights sum to 1.

    Args:
        local_weights: Sequence of state dicts (name -> tensor), one per client.
        k: Number of nearest neighbours used for the score; must be >= 1.
        num_binary: Size of the first client group. Defaults to the
            module-level NUM_BINARY.

    Returns:
        A list with one weight per client, in client order.

    Raises:
        ValueError: If ``k`` is smaller than 1 or either group would be empty
            (both cases would otherwise yield NaN weights).
    """
    if num_binary is None:
        num_binary = NUM_BINARY

    n_clients = len(local_weights)
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    if not 0 < num_binary < n_clients:
        raise ValueError(
            f"num_binary must satisfy 0 < num_binary < {n_clients}, got {num_binary}"
        )

    # .cpu() is required because Tensor.numpy() refuses CUDA tensors.
    flat = np.array([
        torch.cat([p.flatten() for p in w.values()]).detach().cpu().numpy()
        for w in local_weights
    ])

    dist_matrix = squareform(pdist(flat))
    h = np.median(dist_matrix) + 1e-12

    # Column 0 of each sorted row is the zero self-distance; skip it.
    nearest = np.sort(dist_matrix, axis=1)[:, 1:k + 1]
    kde_vals = np.mean(np.exp(-nearest * nearest / (2 * h * h)), axis=1)
    kde_vals += 1e-12  # keep every score strictly positive

    binary_scores = kde_vals[:num_binary]
    full_scores = kde_vals[num_binary:]

    binary_mean = np.mean(binary_scores)
    full_mean = np.mean(full_scores)
    binary_weight = binary_mean / (binary_mean + full_mean)
    full_weight = full_mean / (binary_mean + full_mean)

    weights = np.concatenate([
        binary_scores / binary_scores.sum() * binary_weight,
        full_scores / full_scores.sum() * full_weight,
    ])
    return list(weights)


def fed_kde_avg(local_weights, kde_weights):
    """Return the weighted sum of the clients' state dicts.

    Args:
        local_weights: Sequence of state dicts sharing the keys of the first one.
        kde_weights: One scalar weight per client, indexed like ``local_weights``.

    Returns:
        A dict with the keys of ``local_weights[0]``; each value is
        ``sum_i kde_weights[i] * local_weights[i][key]``.
    """
    client_ids = range(len(local_weights))
    return {
        name: sum(kde_weights[c] * local_weights[c][name] for c in client_ids)
        for name in local_weights[0].keys()
    }

# Evaluation loader; it is built with get_loader(), so it shuffles as well.
test_loader = get_loader(test_dataset)

# Local test accuracy of one randomly chosen client per group, recorded
# once per round.
# NOTE(review): both lists keep growing across all k values and are not read
# anywhere in the visible code.
binary_acc, full_acc = [], []
binary_monitor = random.randint(0, NUM_BINARY-1)
full_monitor = random.randint(NUM_BINARY, NUM_CLIENTS-1)

# One full federated run per neighbour count k; global_accs collects one
# per-round accuracy curve for each k, in the order of k_values.
# NOTE(review): the RNGs are not re-seeded between runs, so each k starts
# from a different random state — confirm this is intended for the comparison.
k_values = [5, 8, 10, 13, 15, 20]
global_accs = []
for k_value in k_values:

    print("k_value = ", k_value)
    
    global_acc = []

    # Fresh global model for every k.
    global_model = CNN().to(DEVICE)

    for r in range(ROUNDS):
        print(f"\n===== Round {r+1}/{ROUNDS} =====")

        local_weights = []
        for i in range(NUM_CLIENTS):
            # Each client starts the round from a copy of the global weights.
            local_model = CNN().to(DEVICE)
            local_model.load_state_dict(global_model.state_dict())
            loader = get_loader(client_datasets[i])
            # Clients listed in NOISY_CLIENTS also train with label flipping.
            w = local_train(local_model, loader, noisy=(i in NOISY_CLIENTS))
            local_weights.append(w)

            if i == binary_monitor:
                binary_acc.append(test(local_model, test_loader))
            if i == full_monitor:
                full_acc.append(test(local_model, test_loader))

        # Aggregate the client models with KDE-derived weights and install
        # the result as the new global model.
        kde_weights = compute_kde_weights(local_weights, k_value)
        aggregated = fed_kde_avg(local_weights, kde_weights)
        global_model.load_state_dict(aggregated)

        acc = test(global_model, test_loader)
        global_acc.append(acc)
        print("🌍 Global Acc =", acc)
    
    global_accs.append(global_acc)

# Create the figure first, then set the font family (same order as before).
plt.figure(figsize=(6, 4), facecolor='white')
rcParams['font.family'] = 'Times New Roman'

# Light grey-blue panel used for both the axes and the legend background.
panel_rgb = (230/255, 230/255, 238/255)

ax = plt.gca()
ax.set_facecolor(panel_rgb)
for side in ('bottom', 'top', 'right', 'left'):
    ax.spines[side].set_color('white')

# One colour per curve, given as 0-255 RGB and scaled to 0-1.
palette_rgb = [
    (128, 149, 192),
    (215, 164, 133),
    (141, 185, 149),
    (197, 128, 131),
    (158, 149, 192),
    (171, 155, 140),
]
colors = [tuple(channel / 255 for channel in rgb) for rgb in palette_rgb]
color_1, color_2, color_3, color_4, color_5, color_6 = colors

linestyles = ['-', '--', ':', '-.', (0, (3, 1, 1, 1)), (0, (8, 2))]

# One accuracy-per-round curve for every k.
for idx, k_value in enumerate(k_values):
    plt.plot(
        global_accs[idx],
        label=f'k={k_value}',
        color=colors[idx],
        linestyle=linestyles[idx],
    )

plt.xlabel("Round")
plt.ylabel("Accuracy")
plt.title("(a) MNIST")
plt.grid(True, color='white', zorder=0)
legend = plt.legend()
legend.get_frame().set_facecolor(panel_rgb)
plt.show()
