import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import matplotlib.pyplot as plt
import random
import os
from scipy.spatial.distance import pdist, squareform
from matplotlib import rcParams

# ---- Client population ----------------------------------------------------
NUM_CLIENTS = 60        # total number of simulated clients
NUM_BINARY = 50         # clients whose data is dominated by labels 0 and 1
NUM_FULL = 10           # clients whose data is dominated by labels 2-9
NOISY_CLIENT_NUM = 25   # how many of the binary clients are selected as noisy

# ---- Optimisation settings --------------------------------------------------
EPOCHS = 1              # local passes over a client's data per round
ROUNDS = 100            # communication rounds per experiment
BATCH_SIZE = 64
LR = 0.01
# Use the GPU when one is visible to torch, otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ---- Checkpointing ----------------------------------------------------------
# NOTE(review): SAVE_ROUND is not referenced anywhere else in the visible file.
SAVE_ROUND = 30
SAVE_PATH = "./saved_models"
os.makedirs(SAVE_PATH, exist_ok=True)  # created eagerly at import time

# ---- Reproducibility --------------------------------------------------------
SEED = 60


def set_seed(seed: int):
    """Seed every random number generator this script draws from.

    Covers Python's ``random``, NumPy's legacy global generator, torch's CPU
    generator and torch's generators on all CUDA devices, in that order.

    Args:
        seed: Value handed unchanged to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


class AddGaussianNoise(object):
    """Transform that adds Gaussian noise to a tensor.

    The output is ``tensor + N(0, 1) * std + mean``; values are not clamped.
    """

    def __init__(self, mean=0.0, std=0.2):
        # Offset and scale applied to the standard-normal sample.
        self.mean, self.std = mean, std

    def __call__(self, tensor):
        """Return a new tensor equal to *tensor* plus scaled, shifted noise."""
        standard_normal = torch.randn_like(tensor)
        return tensor + standard_normal * self.std + self.mean


class AddSaltPepperNoise(object):
    """Transform that overwrites random elements with 0.0 or 1.0.

    A uniform sample in [0, 1) is drawn per element: samples below ``prob / 2``
    set the element to 0.0, samples above ``1 - prob / 2`` set it to 1.0.
    """

    def __init__(self, prob=0.05):
        self.prob = prob

    def __call__(self, tensor):
        """Return a corrupted copy of *tensor*; the input is left untouched."""
        half = self.prob / 2
        draw = torch.rand_like(tensor)
        pepper_mask = draw < half
        salt_mask = draw > 1 - half

        corrupted = tensor.clone()
        # Pepper first, then salt, so salt wins if both masks ever overlap.
        corrupted.masked_fill_(pepper_mask, 0.0)
        corrupted.masked_fill_(salt_mask, 1.0)
        return corrupted


class RandomPixelFlip(object):
    """Transform that inverts a random fraction of a tensor's elements.

    ``int(numel * ratio)`` distinct elements are chosen uniformly at random and
    each chosen value ``v`` is replaced by ``1.0 - v``.
    """

    def __init__(self, ratio=0.05):
        self.ratio = ratio

    def __call__(self, tensor):
        """Return a copy of *tensor* with the selected elements inverted.

        The input is never modified. The copy is forced into contiguous memory
        so the flat view below is valid for any input layout (a plain
        ``clone()`` keeps the input's strides, and ``view(-1)`` raises on e.g.
        a transposed tensor).
        """
        out = tensor.clone(memory_format=torch.contiguous_format)
        total = out.numel()
        k = int(total * self.ratio)
        if k <= 0:
            return out
        idx = torch.randperm(total)[:k]
        # ``flat`` shares storage with ``out``, so this edits ``out`` in place.
        flat = out.view(-1)
        flat[idx] = 1.0 - flat[idx]
        return out


def flip_label(y, p=0.5, num_classes=10):
    """Randomly replace a class label.

    With probability ``p`` the label is replaced by an integer drawn uniformly
    from ``0 .. num_classes - 1`` (the draw may coincide with ``y``); otherwise
    ``y`` is returned unchanged. Uses Python's global ``random`` generator.

    Args:
        y: Original label.
        p: Probability of replacing the label.
        num_classes: Number of classes to draw from. Defaults to 10, which
            matches the previous hard-coded range.

    Returns:
        Either ``y`` or a randomly drawn label.
    """
    if random.random() < p:
        return random.randint(0, num_classes - 1)
    return y


class CNN_CIFAR10(nn.Module):
    """Small CNN: three conv/ReLU/max-pool stages, then a two-layer MLP head.

    The first linear layer expects 128 * 4 * 4 features, i.e. a 4x4 map after
    the three 2x poolings, and the head outputs 10 logits.
    """

    def __init__(self):
        super().__init__()

        # Each stage is Conv(3x3, padding 1) -> ReLU -> MaxPool(2).
        stages = []
        in_channels = 3
        for out_channels in (32, 64, 128):
            stages.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
            stages.append(nn.ReLU())
            stages.append(nn.MaxPool2d(2))
            in_channels = out_channels
        self.features = nn.Sequential(*stages)

        head = [
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 256),
            nn.ReLU(),
            nn.Linear(256, 10),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return the logits produced by the feature stack and the head."""
        return self.classifier(self.features(x))


# Seed all RNGs once, before any data partitioning or model creation.
set_seed(SEED)

# Clean pipeline: only convert the image to a tensor, no augmentation.
clean_transform = transforms.Compose([
    transforms.ToTensor(),
])

# Both splits are read from (and downloaded to, if missing) ./data.
train_dataset_clean = datasets.CIFAR10("./data", train=True, download=True, transform=clean_transform)
test_dataset = datasets.CIFAR10("./data", train=False, download=True, transform=clean_transform)

# For each label 0-9, the training-set indices carrying that label, shuffled
# in place so contiguous slices taken later are random samples of the class.
labels = np.array(train_dataset_clean.targets)
label_indices = {l: np.where(labels == l)[0] for l in range(10)}
for l in label_indices:
    np.random.shuffle(label_indices[l])

# One dataset per client; binary clients are appended first, then full clients.
client_datasets = []

# Ids (within 0..NUM_BINARY-1) of the binary clients that get noisy data.
NOISY_CLIENTS = np.random.choice(NUM_BINARY, size=NOISY_CLIENT_NUM, replace=False)
print("🎨 noisy 客户端:", NOISY_CLIENTS)

# Build the NUM_BINARY "binary" clients. Each one receives 400 samples of
# label 0, 400 samples of label 1, and 20-60 samples from each of 1-5 randomly
# chosen labels in 2..9. Windows are drawn independently per client, so
# different clients may share samples.

# Image corruption applied on access for noisy clients.
noisy_transform = transforms.Compose([
    transforms.ToTensor(),
    AddGaussianNoise(std=0.25),
    AddSaltPepperNoise(prob=0.05),
    RandomPixelFlip(ratio=0.05),
])
# Created on first use and shared by every noisy client. The transform runs
# per item access, so one dataset object serves all noisy subsets and avoids
# loading the whole training set once per noisy client.
noisy_dataset = None

for i in range(NUM_BINARY):

    rand0 = rand1 = 400

    # Random contiguous windows into the (already shuffled) per-label indices.
    start0 = np.random.randint(0, len(label_indices[0]) - rand0)
    start1 = np.random.randint(0, len(label_indices[1]) - rand1)

    idx0 = label_indices[0][start0:start0 + rand0]
    idx1 = label_indices[1][start1:start1 + rand1]

    # Pick 1-5 distinct minority labels from 2..9 for this client.
    EXTRA_LABEL_NUM = np.random.randint(1, 6)
    extra_labels = np.random.choice(range(2, 10), size=EXTRA_LABEL_NUM, replace=False)

    extra_indices = []
    for lbl in extra_labels:
        total_lbl = len(label_indices[lbl])

        # Take between 20 and 60 samples, capped by what the label offers.
        max_take = min(60, total_lbl)
        min_take = min(20, max_take)
        take = np.random.randint(min_take, max_take + 1)

        # randint needs a non-empty range; fall back to the start otherwise.
        if total_lbl - take <= 1:
            start_lbl = 0
        else:
            start_lbl = np.random.randint(0, total_lbl - take)

        extra_indices.extend(label_indices[lbl][start_lbl:start_lbl + take])

    extra_indices = np.array(extra_indices)

    subset_indices = np.concatenate([idx0, idx1, extra_indices])
    np.random.shuffle(subset_indices)

    if i in NOISY_CLIENTS:
        if noisy_dataset is None:
            noisy_dataset = datasets.CIFAR10("./data", train=True, download=True, transform=noisy_transform)
        client_datasets.append(Subset(noisy_dataset, subset_indices))
    else:
        client_datasets.append(Subset(train_dataset_clean, subset_indices))

# Pool of all training indices whose label is in 2..9, shuffled.
indices_2_9 = []
for label in range(2, 10):
    indices_2_9.extend(np.where(labels == label)[0])
indices_2_9 = np.array(indices_2_9)
np.random.shuffle(indices_2_9)

# Pool of all training indices whose label is 0 or 1, shuffled.
indices_0 = np.where(labels == 0)[0]
indices_1 = np.where(labels == 1)[0]
indices_0_1 = np.concatenate([indices_0, indices_1])
np.random.shuffle(indices_0_1)

print(f"标签 2~9 样本数量: {len(indices_2_9)}")
print(f"标签 0,1 样本数量: {len(indices_0_1)}")

# The 2..9 pool is split into NUM_FULL equal, non-overlapping shards; any
# remainder after integer division is left unused.
samples_per_client = len(indices_2_9) // NUM_FULL

for i in range(NUM_FULL):
    subset_main = indices_2_9[i * samples_per_client: (i + 1) * samples_per_client]

    # Add 30-80 samples of labels 0/1, drawn without replacement within this
    # client (the same sample may still be drawn for several clients).
    extra_num = np.random.randint(30, 81)
    extra_num = min(extra_num, len(indices_0_1))
    extra_indices = np.random.choice(indices_0_1, size=extra_num, replace=False)

    subset = np.concatenate([subset_main, extra_indices])
    np.random.shuffle(subset)

    # Full clients always read from the clean (noise-free) training set.
    client_datasets.append(Subset(train_dataset_clean, subset))
    print(f"Full Client {i} → 主标签 2-9 样本 = {len(subset_main)}, 加入 0-1 样本 = {extra_num}")

# Sanity check: binary clients plus full clients must equal NUM_CLIENTS.
assert len(client_datasets) == NUM_CLIENTS, f"client_datasets={len(client_datasets)} != {NUM_CLIENTS}"


def get_loader(dataset):
    """Wrap *dataset* in a shuffling ``DataLoader`` with ``BATCH_SIZE`` batches."""
    loader_options = {"batch_size": BATCH_SIZE, "shuffle": True}
    return DataLoader(dataset, **loader_options)


def local_train(model, loader, noisy=False):
    """Train *model* in place on one client's data and return its weights.

    Runs ``EPOCHS`` passes over *loader* with SGD (learning rate ``LR``,
    momentum 0.9) and cross-entropy loss on ``DEVICE``.

    Args:
        model: Network to train; it is modified in place.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        noisy: When true, every label of every batch is passed through
            ``flip_label`` with ``p=0.5`` before the loss is computed.

    Returns:
        A dict mapping each ``state_dict`` key to a detached CPU copy of the
        corresponding tensor.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9)
    model.train()

    for _epoch in range(EPOCHS):
        for inputs, targets in loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)

            if noisy:
                # Label corruption is re-drawn on every batch, every epoch.
                corrupted = [flip_label(int(label), p=0.5) for label in targets]
                targets = torch.tensor(corrupted, device=DEVICE)

            optimizer.zero_grad()
            logits = model(inputs)
            batch_loss = criterion(logits, targets)
            batch_loss.backward()
            optimizer.step()

    snapshot = {}
    for name, value in model.state_dict().items():
        snapshot[name] = value.detach().cpu().clone()
    return snapshot


def test(model, loader):
    """Return the top-1 accuracy of *model* over every batch in *loader*.

    The model is switched to eval mode and evaluated without gradient
    tracking on ``DEVICE``. The result is ``correct / total`` as a float.
    """
    model.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            predicted = model(inputs).argmax(dim=1)
            hits += predicted.eq(targets).sum().item()
            seen += targets.size(0)
    return hits / seen


def compute_kde_weights(local_weights, k, num_binary=None):
    """Compute per-client aggregation weights from a k-NN kernel density score.

    Each client's parameters are flattened into one vector. A client's score is
    the mean Gaussian kernel value over its ``k`` nearest other clients
    (Euclidean distance), with the bandwidth set to the median of the full
    pairwise distance matrix (zero diagonal included). Scores are normalised
    within the first ``num_binary`` clients and within the remaining clients
    separately, then each group is scaled by its share of the two groups' mean
    scores, so the returned weights sum to 1.

    Args:
        local_weights: List of state dicts (name -> CPU tensor), one per client.
        k: Number of nearest neighbours to average over; must be >= 1.
        num_binary: How many leading clients form the first group. Defaults to
            the module-level ``NUM_BINARY``.

    Returns:
        A list with one weight per client, in the order of ``local_weights``.

    Raises:
        ValueError: If there are no clients, ``k < 1``, or ``num_binary`` is
            outside ``0 .. len(local_weights)``.
    """
    if num_binary is None:
        num_binary = NUM_BINARY

    N = len(local_weights)
    if N == 0:
        raise ValueError("local_weights must contain at least one client")
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    if not 0 <= num_binary <= N:
        raise ValueError(f"num_binary={num_binary} is outside 0..{N}")
    if N == 1:
        # A lone client has no neighbours to score against; it gets all weight.
        return [1.0]

    flat = np.array([
        torch.cat([p.flatten() for p in w.values()]).numpy()
        for w in local_weights
    ])

    dist_matrix = squareform(pdist(flat))
    # Small epsilon keeps the bandwidth non-zero when all models coincide.
    h = np.median(dist_matrix) + 1e-12

    kde_vals = np.zeros(N)
    for i in range(N):
        # Position 0 of the sorted row is the client itself (distance 0).
        idx = np.argsort(dist_matrix[i])[1:k + 1]
        d = dist_matrix[i, idx]
        kde_vals[i] = np.mean(np.exp(-d * d / (2 * h * h)))

    # Keep every score strictly positive so the normalisations are defined.
    kde_vals += 1e-12

    binary_scores = kde_vals[:num_binary]
    full_scores = kde_vals[num_binary:]

    # With only one populated group there is nothing to balance between
    # groups; a plain normalisation avoids the NaNs an empty mean would give.
    if binary_scores.size == 0 or full_scores.size == 0:
        return list(kde_vals / kde_vals.sum())

    binary_avg = np.mean(binary_scores)
    full_avg = np.mean(full_scores)
    binary_weight = binary_avg / (binary_avg + full_avg)
    full_weight = full_avg / (binary_avg + full_avg)

    binary_part = binary_scores / binary_scores.sum() * binary_weight
    full_part = full_scores / full_scores.sum() * full_weight
    return list(binary_part) + list(full_part)


def fed_kde_avg(local_weights, kde_weights):
    """Combine client state dicts into one weighted sum per parameter.

    Args:
        local_weights: List of state dicts sharing the keys of the first one.
        kde_weights: Sequence with one scalar weight per client.

    Returns:
        A dict with the keys of ``local_weights[0]``; each value is
        ``sum_i kde_weights[i] * local_weights[i][key]``.
    """
    return {
        name: sum(
            kde_weights[client] * state[name]
            for client, state in enumerate(local_weights)
        )
        for name in local_weights[0].keys()
    }


# Loader over the CIFAR-10 test split, reused for every evaluation below.
test_loader = get_loader(test_dataset)

# Test accuracy of two randomly chosen clients (one binary, one full) after
# their local training. These lists are created once, so they keep growing
# across all k values rather than being reset per experiment.
binary_acc, full_acc = [], []
binary_monitor = random.randint(0, NUM_BINARY - 1)
full_monitor = random.randint(NUM_BINARY, NUM_CLIENTS - 1)

# Neighbourhood sizes to compare; one full training run is done per value.
k_values = [5, 8, 10, 13, 15, 20]
global_accs = []  # one per-round accuracy curve per entry of k_values

for k_value in k_values:
    print("k_value =", k_value)

    global_acc = []
    # Fresh, randomly initialised global model for each k.
    # NOTE(review): the seed is not reset here, so each k starts from a
    # different initialisation and RNG state - confirm this is intended.
    global_model = CNN_CIFAR10().to(DEVICE)

    for r in range(ROUNDS):
        print(f"\n===== Round {r + 1}/{ROUNDS} =====")

        local_weights = []
        for i in range(NUM_CLIENTS):
            # Every client starts the round from the current global weights.
            local_model = CNN_CIFAR10().to(DEVICE)
            local_model.load_state_dict(global_model.state_dict())

            loader = get_loader(client_datasets[i])
            # Noisy clients additionally get their labels flipped in training.
            w = local_train(local_model, loader, noisy=(i in NOISY_CLIENTS))
            local_weights.append(w)

            if i == binary_monitor:
                binary_acc.append(test(local_model, test_loader))
            if i == full_monitor:
                full_acc.append(test(local_model, test_loader))

        # Weight each client by its KDE score and aggregate into the global model.
        kde_weights = compute_kde_weights(local_weights, k_value)
        aggregated = fed_kde_avg(local_weights, kde_weights)

        global_model.load_state_dict(aggregated)
        acc = test(global_model, test_loader)

        global_acc.append(acc)
        print("🌍 Global Acc =", acc)

    global_accs.append(global_acc)


# Plot one global-accuracy curve per k value on a single styled axes.
plt.figure(figsize=(6, 4), facecolor='white')
rcParams['font.family'] = 'Times New Roman'

# Light grey-blue panel shared by the axes background and the legend frame.
panel_color = (230 / 255, 230 / 255, 238 / 255)

ax = plt.gca()
ax.set_facecolor(panel_color)
for side in ('bottom', 'top', 'right', 'left'):
    ax.spines[side].set_color('white')

# Curve colours, given as 0-255 RGB triples and scaled to matplotlib's 0-1 range.
colors = [
    (red / 255, green / 255, blue / 255)
    for red, green, blue in (
        (128, 149, 192),
        (215, 164, 133),
        (141, 185, 149),
        (197, 128, 131),
        (158, 149, 192),
        (171, 155, 140),
    )
]
color_1, color_2, color_3, color_4, color_5, color_6 = colors

linestyles = ['-', '--', ':', '-.', (0, (3, 1, 1, 1)), (0, (8, 2))]

for idx, k_value in enumerate(k_values):
    ax.plot(
        global_accs[idx],
        label=f'k={k_value}',
        color=colors[idx],
        linestyle=linestyles[idx],
    )

ax.set_xlabel("Round")
ax.set_ylabel("Accuracy")
ax.set_title("CIFAR-10")
ax.grid(True, color='white', zorder=0)

legend = ax.legend()
legend.get_frame().set_facecolor(panel_color)
plt.show()
