import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import matplotlib.pyplot as plt
import random
import os
from scipy.spatial.distance import pdist, squareform
from matplotlib import rcParams


# ---- Client population -------------------------------------------------
# The first NUM_BINARY clients are the "binary" group and the last NUM_FULL
# clients are the "full" group; the KDE weighting slices the client list at
# NUM_BINARY, so the ordering matters.
NUM_BINARY = 50
NUM_FULL = 10
# Derived rather than hard-coded so the total can never drift away from the
# two group sizes that the dataset setup and the aggregation rely on.
NUM_CLIENTS = NUM_BINARY + NUM_FULL
# Number of binary-group clients drawn (without replacement) as noisy clients.
NOISY_CLIENT_NUM = 25

# ---- Optimisation ------------------------------------------------------
EPOCHS = 1  # local epochs per client per round
ROUNDS = 30  # communication rounds
BATCH_SIZE = 64
LR = 0.01  # SGD learning rate used for local training
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- Checkpointing -----------------------------------------------------
# NOTE(review): SAVE_ROUND is not referenced by the visible code; only the
# directory is created here.
SAVE_ROUND = 30
SAVE_PATH = "./saved_models"
os.makedirs(SAVE_PATH, exist_ok=True)

# NOTE(review): K_KDE is not referenced by the visible code; each run passes
# its own k value explicitly.
K_KDE = 5

# Seed passed to set_seed() before the datasets are indexed.
SEED = 60


def set_seed(seed: int):
    """Seed Python's, NumPy's and PyTorch's (CPU and all-CUDA) generators.

    Args:
        seed: Integer handed unchanged to every seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


class AddGaussianNoise(object):
    """Callable transform that adds Gaussian noise to a tensor.

    The result is ``tensor + z * std + mean`` where ``z`` is drawn from a
    standard normal distribution with the same shape as ``tensor``. The input
    tensor itself is not modified.

    Args:
        mean: Constant offset added to every element.
        std: Scale applied to the standard-normal noise.
    """

    def __init__(self, mean=0.0, std=1.0):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """Return a noisy copy of ``tensor``."""
        # Allocate the noise on the input's device; the previous default-device
        # allocation made the addition fail for tensors living elsewhere.
        noise = torch.randn(tensor.size(), device=tensor.device)
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return f"{type(self).__name__}(mean={self.mean}, std={self.std})"


class AddSaltPepperNoise(object):
    """Callable transform that forces random elements to 0.0 or 1.0.

    One uniform draw in [0, 1) is made per element. Elements whose draw falls
    below ``prob / 2`` are set to 0.0 and those whose draw exceeds
    ``1 - prob / 2`` are set to 1.0. The input tensor is left untouched; a
    modified clone is returned.

    Args:
        prob: Total fraction of the unit interval mapped to either extreme.
    """

    def __init__(self, prob=0.3):
        self.prob = prob

    def __call__(self, tensor):
        """Return a salt-and-pepper corrupted copy of ``tensor``."""
        draws = torch.rand(tensor.size())
        half_prob = self.prob / 2

        corrupted = tensor.clone()
        pepper_mask = draws < half_prob
        salt_mask = draws > 1 - half_prob
        corrupted[pepper_mask] = 0.0
        corrupted[salt_mask] = 1.0
        return corrupted


class RandomPixelFlip(object):
    """Callable transform that inverts a random subset of elements.

    ``int(numel * ratio)`` distinct positions are chosen with
    ``torch.randperm`` and each chosen value ``v`` is replaced by ``1 - v``.

    The input tensor is never modified and the result has the same shape as
    the input (previously the input was mutated through a view and the output
    shape was fixed to ``(1, 28, 28)``).

    Args:
        ratio: Fraction of elements to invert.
    """

    def __init__(self, ratio=0.2):
        self.ratio = ratio

    def __call__(self, tensor):
        """Return a copy of ``tensor`` with a random subset of values inverted."""
        total = tensor.numel()
        idx = torch.randperm(total)[: int(total * self.ratio)]
        # reshape() also accepts non-contiguous input; clone() guarantees we
        # write into private storage rather than the caller's tensor.
        flat = tensor.reshape(-1).clone()
        flat[idx] = 1 - flat[idx]
        return flat.view(tensor.shape)


def flip_label(y, p=0.5, num_classes=10):
    """Replace a label by a uniformly random class with probability ``p``.

    The replacement is drawn from ``0 .. num_classes - 1`` inclusive and may
    coincide with the original label.

    Args:
        y: Original label; returned unchanged when no flip happens.
        p: Probability of drawing a replacement label.
        num_classes: Number of classes to draw the replacement from
            (defaults to 10, the previously hard-coded value).

    Returns:
        Either ``y`` or a random integer class id.
    """
    if random.random() < p:
        return random.randint(0, num_classes - 1)
    return y


class CNN(nn.Module):
    """Small two-convolution classifier producing 10 logits.

    Layout: conv(1->10, 5x5) -> 2x2 max-pool -> ReLU -> conv(10->20, 5x5) ->
    2x2 max-pool -> ReLU -> flatten to 320 features -> linear(320->50) ->
    ReLU -> linear(50->10). For 1x28x28 inputs the second pooling stage
    yields 20 x 4 x 4 = 320 values per sample, which matches ``fc1``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5)
        self.fc1 = nn.Linear(in_features=320, out_features=50)
        self.fc2 = nn.Linear(in_features=50, out_features=10)

    def forward(self, x):
        """Return unnormalised class scores for the batch ``x``."""
        stage1 = torch.max_pool2d(self.conv1(x), 2).relu()
        stage2 = torch.max_pool2d(self.conv2(stage1), 2).relu()
        # Collapse everything into rows of 320 features for the dense head.
        flattened = stage2.view(-1, 320)
        hidden = self.fc1(flattened).relu()
        return self.fc2(hidden)


# Seed all generators first: the per-class shuffles below (and every later
# sampling step) draw from the global NumPy RNG.
set_seed(SEED)


# Clean pipeline: tensor conversion only, no added noise.
clean_transform = transforms.Compose([transforms.ToTensor()])

train_dataset_clean = datasets.FashionMNIST(
    "./data",
    train=True,
    download=True,
    transform=clean_transform,
)
test_dataset = datasets.FashionMNIST(
    "./data",
    train=False,
    download=True,
    transform=clean_transform,
)

# Map each class id (0..9) to a shuffled array of the training-set positions
# carrying that label. Classes are shuffled in ascending order of id.
labels = np.array(train_dataset_clean.targets)
label_indices = {}
for class_id in range(10):
    positions = np.where(labels == class_id)[0]
    np.random.shuffle(positions)
    label_indices[class_id] = positions


def noisy_clients_setup():
    """Draw and announce the ids of the noisy clients.

    Returns:
        Array of NOISY_CLIENT_NUM distinct integers sampled from
        ``range(NUM_BINARY)`` using the global NumPy RNG.
    """
    selected_ids = np.random.choice(
        NUM_BINARY,
        size=NOISY_CLIENT_NUM,
        replace=False,
    )
    print("🎨 noisy 客户端:", selected_ids)
    return selected_ids


def label_skew_subset_indices():
    """Sample training indices dominated by labels 0 and 1.

    The subset consists of a contiguous window of 600 indices from each of the
    (pre-shuffled) label-0 and label-1 pools, plus a small window of 20 to 80
    indices from each of 1 to 5 randomly chosen labels in 2..9. The combined
    indices are shuffled before being returned.

    Returns:
        1-D NumPy array of dataset indices.
    """
    major_count = 600

    # Both offsets are drawn before slicing, keeping the RNG call sequence.
    offset0 = np.random.randint(0, len(label_indices[0]) - major_count)
    offset1 = np.random.randint(0, len(label_indices[1]) - major_count)
    window0 = label_indices[0][offset0 : offset0 + major_count]
    window1 = label_indices[1][offset1 : offset1 + major_count]

    minor_label_count = np.random.randint(1, 6)
    minor_labels = np.random.choice(
        range(2, 10), size=minor_label_count, replace=False
    )

    minor_indices = []
    for lbl in minor_labels:
        pool = label_indices[lbl]
        pool_size = len(pool)

        # Clamp the 20..80 window size to what the pool can actually provide.
        upper = min(80, pool_size)
        lower = min(20, upper)
        take = np.random.randint(lower, upper + 1)

        # Only draw an offset when there is room to move the window.
        start = np.random.randint(0, pool_size - take) if pool_size - take > 1 else 0
        minor_indices.extend(pool[start : start + take])

    subset_indices = np.concatenate([window0, window1, np.array(minor_indices)])
    np.random.shuffle(subset_indices)
    return subset_indices


def label_nonskew_subset_indices(i):
    """Sample training indices dominated by labels 2..9 for chunk ``i``.

    All indices with labels 2..9 are collected and shuffled, and the ``i``-th
    chunk of size ``len(pool) // NUM_FULL`` is taken. Between 10 and 30 indices
    with label 0 or 1 are then added and the result is shuffled.

    Note that the pools are rebuilt and reshuffled on every call, so chunks
    obtained from different calls are not guaranteed to be disjoint, and a
    value of ``i >= NUM_FULL`` slices past the pool and yields an empty main
    chunk.

    Args:
        i: Chunk number used to slice the shuffled label-2..9 pool.

    Returns:
        1-D NumPy array of dataset indices.
    """
    targets = train_dataset_clean.targets

    indices_2_9 = np.array(
        [idx for label in range(2, 10) for idx in np.where(targets == label)[0]]
    )
    np.random.shuffle(indices_2_9)

    indices_0_1 = np.concatenate(
        [np.where(targets == label)[0] for label in (0, 1)]
    )
    np.random.shuffle(indices_0_1)

    print(f"标签 2~9 样本数量: {len(indices_2_9)}")
    print(f"标签 0,1 样本数量: {len(indices_0_1)}")

    chunk = len(indices_2_9) // NUM_FULL
    main_part = indices_2_9[i * chunk : (i + 1) * chunk]

    # Never request more label-0/1 samples than exist.
    extra_count = min(np.random.randint(10, 31), len(indices_0_1))
    extra_part = np.random.choice(indices_0_1, size=extra_count, replace=False)

    subset = np.concatenate([main_part, extra_part])
    np.random.shuffle(subset)
    return subset


def clients_datasets_setup(noisy_clients, is_feature_skew=True, is_label_skew=True):
    """Build one dataset per client: NUM_BINARY binary clients, then NUM_FULL full clients.

    Args:
        noisy_clients: Iterable of binary-client ids whose samples receive
            Gaussian noise (std 0.2) when ``is_feature_skew`` is true.
        is_feature_skew: Enable the noisy transform for the listed clients.
        is_label_skew: When true, binary clients draw their indices from
            ``label_skew_subset_indices``; otherwise from
            ``label_nonskew_subset_indices``.

    Returns:
        List of ``Subset`` objects, binary clients first, full clients last.
    """
    # Membership is only consulted when feature skew is on, as before.
    noisy_ids = {int(c) for c in noisy_clients} if is_feature_skew else set()

    # The noisy dataset does not depend on the client, so it is built at most
    # once and shared instead of re-loading FashionMNIST per noisy client.
    noisy_dataset = None

    client_datasets = []

    for i in range(NUM_BINARY):
        if is_label_skew:
            subset_indices = label_skew_subset_indices()
        else:
            # NOTE(review): label_nonskew_subset_indices chunks its pool by
            # NUM_FULL, so binary clients with i >= NUM_FULL only receive the
            # 10-30 label-0/1 extras. Confirm this is intended.
            subset_indices = label_nonskew_subset_indices(i)

        if is_feature_skew and i in noisy_ids:
            if noisy_dataset is None:
                noisy_transform = transforms.Compose(
                    [
                        transforms.ToTensor(),
                        AddGaussianNoise(std=0.2),
                    ]
                )
                noisy_dataset = datasets.FashionMNIST(
                    "./data", train=True, download=True, transform=noisy_transform
                )
            source = noisy_dataset
        else:
            source = train_dataset_clean

        client_datasets.append(Subset(source, subset_indices))

    for i in range(NUM_FULL):
        subset = label_nonskew_subset_indices(i)
        client_datasets.append(Subset(train_dataset_clean, subset))

    return client_datasets


def get_loader(dataset):
    """Wrap ``dataset`` in a shuffling DataLoader with BATCH_SIZE-sized batches."""
    loader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE)
    return loader


def local_train(model, loader, noisy=False):
    """Train ``model`` in place with plain SGD and return its state dict.

    Runs EPOCHS passes over ``loader`` using cross-entropy loss and learning
    rate LR, moving each batch to DEVICE.

    Args:
        model: Network to optimise; it is modified in place.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        noisy: Accepted for interface compatibility; the body does not use it.

    Returns:
        ``model.state_dict()`` after training.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=LR)

    model.train()
    for _epoch in range(EPOCHS):
        for inputs, targets in loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)

            optimizer.zero_grad()
            logits = model(inputs)
            batch_loss = criterion(logits, targets)
            batch_loss.backward()
            optimizer.step()

    return model.state_dict()


def test(model, loader):
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly.

    The model is switched to eval mode and gradients are disabled. Predictions
    are the arg-max over dimension 1 of the model output.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            predicted = model(inputs).argmax(dim=1)
            hits += (predicted == targets).sum().item()
            seen += targets.size(0)
    return hits / seen


def compute_kde_weights(local_weights, k):
    """Compute per-client aggregation weights from a k-nearest-neighbour density estimate.

    Every client's state dict is flattened into one vector. For each client the
    Gaussian kernel (bandwidth = median of the pairwise distance matrix) is
    averaged over its ``k`` nearest other clients. The resulting densities are
    normalised separately inside the binary group (first NUM_BINARY clients)
    and the full group (remaining clients); each group is then scaled by its
    share of the two groups' mean densities, so all weights sum to 1.

    Args:
        local_weights: Sequence of state dicts (name -> tensor), binary-group
            clients first.
        k: Number of nearest neighbours used per client. ``k == 0`` selects
            uniform weights ``1 / N``.

    Returns:
        1-D NumPy array with one weight per client.
    """
    N = len(local_weights)

    # Uniform averaging needs none of the density machinery. Returning early
    # also avoids averaging an empty neighbour set (NaN plus RuntimeWarning)
    # and skips the pairwise distance computation.
    if k == 0:
        return np.ones(N) / N

    flat = np.array(
        [torch.cat([p.flatten() for p in w.values()]).cpu().numpy() for w in local_weights]
    )

    dist_matrix = squareform(pdist(flat))
    h = np.median(dist_matrix) + 1e-12  # epsilon keeps the bandwidth non-zero

    kde_vals = np.zeros(N)
    for i in range(N):
        # Position 0 of the sorted row is the client itself (distance 0).
        idx = np.argsort(dist_matrix[i])[1 : k + 1]
        d = dist_matrix[i, idx]
        kde_vals[i] = np.mean(np.exp(-d * d / (2 * h * h)))

    kde_vals += 1e-12  # guard the normalisations below against all-zero densities

    binary_raw = kde_vals[:NUM_BINARY]
    full_raw = kde_vals[NUM_BINARY:]

    binary_average_kde_val = np.mean(binary_raw)
    full_average_kde_val = np.mean(full_raw)
    average_total = binary_average_kde_val + full_average_kde_val

    binary_weight = binary_average_kde_val / average_total
    full_weight = full_average_kde_val / average_total

    # Normalise within each group, then scale by the group's share.
    binary_kde_vals = binary_raw / binary_raw.sum() * binary_weight
    full_kde_vals = full_raw / full_raw.sum() * full_weight

    # Always return an ndarray so both branches share one return type.
    return np.concatenate([binary_kde_vals, full_kde_vals])


def fed_kde_avg(local_weights, kde_weights):
    """Combine client state dicts into one weighted sum.

    Args:
        local_weights: Non-empty sequence of state dicts sharing the keys of
            the first entry.
        kde_weights: Indexable collection holding one scalar per client.

    Returns:
        Dict mapping each key of the first state dict to
        ``sum_i kde_weights[i] * local_weights[i][key]``.
    """
    aggregated = {}
    for name in local_weights[0].keys():
        running = 0
        for client_idx, state in enumerate(local_weights):
            running = running + kde_weights[client_idx] * state[name]
        aggregated[name] = running
    return aggregated


# Evaluation loader over the held-out test split (built with get_loader, so it
# uses the same batch size and shuffling as the training loaders).
test_loader = get_loader(test_dataset)
# Drawn once at module level: FL_training reads this global, so every run
# shares the same set of noisy client ids.
noisy_clients = noisy_clients_setup()


def FL_training(k_value, is_feature_skew, is_label_skew):
    """Run ROUNDS rounds of federated training and record global test accuracy.

    Each round, every one of the NUM_CLIENTS clients trains a fresh copy of
    the current global model on its own dataset; the resulting state dicts are
    merged with the weights from ``compute_kde_weights`` and loaded back into
    the global model, which is then evaluated on ``test_loader``.

    Args:
        k_value: Neighbour count forwarded to ``compute_kde_weights``.
        is_feature_skew: Forwarded to ``clients_datasets_setup``.
        is_label_skew: Forwarded to ``clients_datasets_setup``; also combined
            with noisy-client membership for the ``noisy`` flag of
            ``local_train``.

    Returns:
        List with the global test accuracy after each round.
    """
    client_datasets = clients_datasets_setup(
        noisy_clients, is_feature_skew, is_label_skew
    )

    print("k_value = ", k_value)

    global_model = CNN().to(DEVICE)
    global_acc = []

    for r in range(ROUNDS):
        print(f"\n===== Round {r+1}/{ROUNDS} =====")

        local_weights = []
        for client_id in range(NUM_CLIENTS):
            # Each client starts from a new model initialised to the global weights.
            client_model = CNN().to(DEVICE)
            client_model.load_state_dict(global_model.state_dict())

            client_loader = get_loader(client_datasets[client_id])
            is_noisy = client_id in noisy_clients and is_label_skew
            local_weights.append(
                local_train(client_model, client_loader, noisy=is_noisy)
            )

        kde_weights = compute_kde_weights(local_weights, k_value)
        global_model.load_state_dict(fed_kde_avg(local_weights, kde_weights))

        acc = test(global_model, test_loader)
        global_acc.append(acc)
        print("🌍 Global Acc =", acc)

    return global_acc


# Per-method settings, each as (k value, feature skew, label skew).
# k = 0 makes compute_kde_weights return uniform weights.
k_value_FedAvg, is_feature_skew_FedAvg, is_label_skew_FedAvg = 0, True, True

# Defined for completeness; no LoMar run is launched below.
k_value_LoMar, is_feature_skew_LoMar, is_label_skew_LoMar = 60, True, True

k_value_FedLC, is_feature_skew_FedLC, is_label_skew_FedLC = 60, True, False

k_value_FedRDN, is_feature_skew_FedRDN, is_label_skew_FedRDN = 60, False, True

# Defined for completeness; no FedKde run is launched below.
k_value_FedKde, is_feature_skew_FedKde, is_label_skew_FedKde = 13, True, True

# Runs execute in this order; they share the global RNG state.
acc_FedAvg = FL_training(
    k_value=k_value_FedAvg,
    is_feature_skew=is_feature_skew_FedAvg,
    is_label_skew=is_label_skew_FedAvg,
)
acc_FedLC = FL_training(
    k_value=k_value_FedLC,
    is_feature_skew=is_feature_skew_FedLC,
    is_label_skew=is_label_skew_FedLC,
)
acc_FedRDN = FL_training(
    k_value=k_value_FedRDN,
    is_feature_skew=is_feature_skew_FedRDN,
    is_label_skew=is_label_skew_FedRDN,
)


# ---- Accuracy-per-round figure ----
plt.figure(figsize=(6, 4), facecolor="white")
rcParams["font.family"] = "Times New Roman"

# Light panel with white spines so the frame blends into the figure background.
ax = plt.gca()
ax.set_facecolor((230 / 255, 230 / 255, 238 / 255))
for side in ("bottom", "top", "right", "left"):
    ax.spines[side].set_color("white")

# Palette given as 0-255 RGB triples and scaled to matplotlib's 0-1 range.
colors = [
    tuple(channel / 255 for channel in rgb)
    for rgb in (
        (128, 149, 192),
        (215, 164, 133),
        (141, 185, 149),
        (197, 128, 131),
        (158, 149, 192),
        (171, 155, 140),
    )
]
color_1, color_2, color_3, color_4, color_5, color_6 = colors

linestyles = ["-", "--", ":", "-.", (0, (3, 1, 1, 1)), (0, (8, 2))]

# Each curve keeps its fixed slot in the palette / line-style tables.
ax.plot(acc_FedAvg, label="FedAvg", color=colors[0], linestyle=linestyles[0])
ax.plot(acc_FedLC, label="FedLC", color=colors[2], linestyle=linestyles[2])
ax.plot(acc_FedRDN, label="FedRDN", color=colors[3], linestyle=linestyles[3])

ax.set_xlabel("Round")
ax.set_ylabel("Accuracy")
ax.set_title("(b) Fashion-MNIST")
ax.grid(True, color="white", zorder=0)
legend = ax.legend()
legend.get_frame().set_facecolor((230 / 255, 230 / 255, 238 / 255))
plt.show()
