import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import matplotlib.pyplot as plt
import random
import os
from scipy.spatial.distance import pdist, squareform
from matplotlib import rcParams


# Client population. clients_datasets_setup builds NUM_BINARY clients from
# label_skew_subset_indices followed by NUM_FULL clients from
# label_nonskew_subset_indices; FL_training trains NUM_CLIENTS clients per round.
NUM_CLIENTS = 60
NUM_BINARY = 50
NUM_FULL = 10
# Number of client ids that noisy_clients_setup draws from range(NUM_BINARY).
NOISY_CLIENT_NUM = 25

# Local-training / federation hyper-parameters.
EPOCHS = 1  # passes over a client's loader inside local_train
ROUNDS = 100  # communication rounds in FL_training
BATCH_SIZE = 64  # batch size used by get_loader
LR = 0.01  # SGD learning rate used by local_train
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Corruption strengths applied to the noisy clients' images.
GAUSSION_NOISE_P = 0.2  # std for AddGaussianNoise (the "GAUSSION" spelling is referenced elsewhere)
SALT_PEPPER_NOISE_P = 0.1  # prob for AddSaltPepperNoise
RANDOM_PIXEL_FLIP_P = 0.1  # ratio for RandomPixelFlip

# Seed handed to set_seed before the datasets are prepared.
SEED = 60


def set_seed(seed: int):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    Args:
        seed: Value passed unchanged to each library's seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)


class AddGaussianNoise:
    """Transform that adds Gaussian noise to a tensor.

    The result is ``tensor + N(0, 1) * std + mean``; the input tensor is not
    modified and the output is not clamped.
    """

    def __init__(self, mean=0.0, std=0.2):
        self.mean, self.std = mean, std

    def __call__(self, tensor):
        # One standard-normal sample per element, scaled to the requested std.
        scaled_noise = torch.randn_like(tensor) * self.std
        return tensor + scaled_noise + self.mean


class AddSaltPepperNoise:
    """Transform that sets random elements to 0.0 ("pepper") or 1.0 ("salt").

    One uniform sample in [0, 1) is drawn per element. Elements whose sample
    is below ``prob / 2`` become 0.0 and elements whose sample is above
    ``1 - prob / 2`` become 1.0. A new tensor is returned; the input is left
    untouched.
    """

    def __init__(self, prob=0.05):
        self.prob = prob

    def __call__(self, tensor):
        draws = torch.rand_like(tensor)
        pepper = draws < self.prob / 2
        salt = draws > 1 - self.prob / 2
        # Salt is applied last, so it wins if the two masks ever overlap.
        return tensor.masked_fill(pepper, 0.0).masked_fill(salt, 1.0)


class RandomPixelFlip(object):
    """Transform that inverts (``v -> 1 - v``) a random fraction of elements.

    ``int(numel * ratio)`` distinct element positions are chosen with
    ``torch.randperm`` and inverted.

    The input tensor is never modified: the flip is applied to a copy, so a
    caller that keeps a reference to the input does not see it change. Any
    tensor shape and memory layout is accepted; the output has the input's
    shape.
    """

    def __init__(self, ratio=0.02):
        self.ratio = ratio

    def __call__(self, tensor):
        total = tensor.numel()
        num = int(total * self.ratio)
        if num <= 0:
            # Nothing to flip: hand the (unmodified) input straight back.
            return tensor
        idx = torch.randperm(total)[:num]
        # reshape() copes with non-contiguous inputs; clone() guarantees we
        # write into private storage rather than the caller's tensor.
        flat = tensor.reshape(-1).clone()
        flat[idx] = 1 - flat[idx]
        return flat.view(tensor.shape)


class FeatureDropout:
    """Transform that zeroes each element independently.

    An element is kept when its uniform [0, 1) sample is strictly greater
    than ``dropout_rate`` and zeroed otherwise. Surviving values are not
    rescaled.
    """

    def __init__(self, dropout_rate=0.2):
        self.dropout_rate = dropout_rate

    def __call__(self, x):
        keep = torch.rand_like(x) > self.dropout_rate
        return x * keep.float()


def flip_label(y, p=0.5):
    """Return a random label in 0..9 with probability ``p``, otherwise ``y``.

    The replacement is drawn uniformly with ``random.randint(0, 9)`` and may
    coincide with the original label.
    """
    # The probability draw happens first; randint is only consumed on a flip.
    return random.randint(0, 9) if random.random() < p else y


class CNN_CIFAR(nn.Module):
    """Small CNN with three conv/ReLU/max-pool stages and a two-layer head.

    The channel widths are 3 -> 32 -> 64 -> 128, each stage halving the
    spatial size. The classifier expects ``128 * 4 * 4`` flattened features
    (i.e. 32x32 inputs) and produces 10 logits.
    """

    def __init__(self):
        super().__init__()

        # Build the conv stack stage by stage; module indices (0, 3, 6 for
        # the convolutions) match a hand-written Sequential.
        stages = []
        in_channels = 3
        for out_channels in (32, 64, 128):
            stages.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
            stages.append(nn.ReLU())
            stages.append(nn.MaxPool2d(2))
            in_channels = out_channels
        self.features = nn.Sequential(*stages)

        head = [nn.Linear(128 * 4 * 4, 256), nn.ReLU(), nn.Linear(256, 10)]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        feature_maps = self.features(x)
        batch = feature_maps.size(0)
        return self.classifier(feature_maps.view(batch, -1))


# Seed every RNG before any shuffling or sampling below consumes randomness.
set_seed(SEED)


# Clean pipeline: images are only converted to tensors.
clean_transform = transforms.Compose([transforms.ToTensor()])

# CIFAR-10 train and test splits stored under ./data (download=True is passed
# for both).
train_dataset_clean = datasets.CIFAR10(
    "./data", train=True, download=True, transform=clean_transform
)
test_dataset = datasets.CIFAR10("./data", train=False, download=True, transform=clean_transform)

# Map each class id (0..9) to the training-set indices of that class, then
# shuffle every per-class index array in place once. The client-partitioning
# functions below slice these arrays.
labels = np.array(train_dataset_clean.targets)
label_indices = {l: np.where(labels == l)[0] for l in range(10)}
for l in label_indices:
    np.random.shuffle(label_indices[l])


def noisy_clients_setup():
    """Draw, print and return the ids of the noisy clients.

    Returns:
        Array of ``NOISY_CLIENT_NUM`` distinct ids sampled without
        replacement from ``range(NUM_BINARY)``.
    """
    chosen_ids = np.random.choice(NUM_BINARY, NOISY_CLIENT_NUM, replace=False)
    print("🎨 noisy 客户端:", chosen_ids)
    return chosen_ids


def label_skew_subset_indices(is_label_skew=True):
    """Sample training indices for one client dominated by classes 0 and 1.

    A contiguous slice is taken from the shuffled index array of class 0 and
    of class 1, plus smaller contiguous slices from a random selection of
    classes 2..9.

    Args:
        is_label_skew: When true, 600 samples each of class 0 and 1 are
            taken, with 1-5 extra classes contributing 20-80 samples each.
            When it compares equal to ``False``, 350-500 samples each of
            class 0 and 1 are taken, with 4-8 extra classes contributing
            50-150 samples each.

    Returns:
        1-D array of dataset indices in shuffled order.
    """
    # Evaluated with `==` (not `not ...`) so the same argument values as
    # before select the milder setting.
    milder = is_label_skew == False  # noqa: E712

    if milder:
        size0 = np.random.randint(350, 501)
        size1 = np.random.randint(350, 501)
    else:
        size0 = size1 = 600

    pool0, pool1 = label_indices[0], label_indices[1]
    begin0 = np.random.randint(0, len(pool0) - size0)
    begin1 = np.random.randint(0, len(pool1) - size1)
    part0 = pool0[begin0 : begin0 + size0]
    part1 = pool1[begin1 : begin1 + size1]

    # The first draw is always consumed, even when it is then replaced, so
    # the sequence of random numbers stays the same in both modes.
    n_extra = np.random.randint(1, 6)
    if milder:
        n_extra = np.random.randint(4, 9)

    # Per-class lower/upper bounds on how many extra samples to take.
    floor_cap, ceil_cap = (50, 150) if milder else (20, 80)

    picked = []
    for lbl in np.random.choice(range(2, 10), size=n_extra, replace=False):
        pool = label_indices[lbl]
        available = len(pool)
        upper = min(ceil_cap, available)
        lower = min(floor_cap, upper)
        count = np.random.randint(lower, upper + 1)

        slack = available - count
        offset = np.random.randint(0, slack) if slack > 1 else 0
        picked.extend(pool[offset : offset + count])

    result = np.concatenate([part0, part1, np.array(picked)])
    np.random.shuffle(result)
    return result


def label_nonskew_subset_indices(i):
    """Return training indices for the ``i``-th "full" client.

    The client receives the ``i``-th of ``NUM_FULL`` equally sized shards of
    all class 2..9 samples, plus 10-30 randomly chosen samples of classes
    0 and 1.

    The class 2..9 pool is permuted with a private generator seeded with
    ``SEED``, so every call sees the same ordering and the shards for
    different ``i`` are disjoint. (Reshuffling the pool with the global RNG
    on every call, as was done before, made each "shard" an independent
    random sample, so different clients' shards overlapped.)

    Args:
        i: Shard number, expected in ``range(NUM_FULL)``.

    Returns:
        1-D array of dataset indices in shuffled order.
    """
    pool_2_9 = np.concatenate([label_indices[label] for label in range(2, 10)])
    # Same permutation on every call -> slices by ``i`` partition the pool.
    pool_2_9 = np.random.RandomState(SEED).permutation(pool_2_9)

    samples_per_client = len(pool_2_9) // NUM_FULL
    subset_main = pool_2_9[i * samples_per_client : (i + 1) * samples_per_client]

    # A few class-0/1 samples; choice(replace=False) is already a uniform
    # draw, so the pool needs no shuffling of its own.
    pool_0_1 = np.concatenate([label_indices[0], label_indices[1]])
    extra_num = min(np.random.randint(10, 31), len(pool_0_1))
    extra_indices = np.random.choice(pool_0_1, size=extra_num, replace=False)

    subset = np.concatenate([subset_main, extra_indices])
    np.random.shuffle(subset)
    return subset


def clients_datasets_setup(noisy_clients, is_feature_skew=True, is_label_skew=True):
    """Build the per-client datasets for one experiment.

    The first ``NUM_BINARY`` clients get indices from
    ``label_skew_subset_indices``; the following ``NUM_FULL`` clients get
    indices from ``label_nonskew_subset_indices``.

    Args:
        noisy_clients: Iterable of client ids (within the first
            ``NUM_BINARY``) whose images are corrupted.
        is_feature_skew: When true, clients listed in ``noisy_clients`` read
            their images through the noisy transform pipeline; otherwise
            every client uses the clean training set.
        is_label_skew: Forwarded to ``label_skew_subset_indices``.

    Returns:
        List of ``NUM_BINARY + NUM_FULL`` ``Subset`` objects, in client order.
    """
    noisy_ids = {int(client_id) for client_id in noisy_clients}

    # The noisy transform and its dataset are the same for every noisy
    # client, so they are created once (on first use) and shared by all
    # Subsets instead of constructing a new CIFAR10 object per client.
    noisy_dataset = None

    client_datasets = []
    for i in range(NUM_BINARY):
        subset_indices = label_skew_subset_indices(is_label_skew)

        if is_feature_skew and i in noisy_ids:
            if noisy_dataset is None:
                noisy_transform = transforms.Compose(
                    [
                        transforms.ToTensor(),
                        AddGaussianNoise(std=GAUSSION_NOISE_P),
                        AddSaltPepperNoise(prob=SALT_PEPPER_NOISE_P),
                        RandomPixelFlip(ratio=RANDOM_PIXEL_FLIP_P),
                        FeatureDropout(dropout_rate=0.2),
                    ]
                )
                noisy_dataset = datasets.CIFAR10(
                    "./data", train=True, download=False, transform=noisy_transform
                )
            client_datasets.append(Subset(noisy_dataset, subset_indices))
        else:
            client_datasets.append(Subset(train_dataset_clean, subset_indices))

    for i in range(NUM_FULL):
        subset = label_nonskew_subset_indices(i)
        client_datasets.append(Subset(train_dataset_clean, subset))

    return client_datasets


def get_loader(dataset):
    """Wrap ``dataset`` in a shuffling DataLoader with ``BATCH_SIZE`` batches.

    Loading happens in the main process (``num_workers=0``) without pinned
    memory.
    """
    loader_options = {
        "batch_size": BATCH_SIZE,
        "shuffle": True,
        "num_workers": 0,
        "pin_memory": False,
    }
    return DataLoader(dataset, **loader_options)


def local_train(model, loader, noisy=False):
    """Train ``model`` in place on ``loader`` and return its state dict.

    Runs ``EPOCHS`` passes of SGD (learning rate ``LR``, momentum 0.9) with
    cross-entropy loss, moving each batch to ``DEVICE``.

    Args:
        model: Network to train; updated in place.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        noisy: Accepted for callers that pass it; not used by the training
            loop.

    Returns:
        ``model.state_dict()`` after training.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9)

    model.train()
    for _epoch in range(EPOCHS):
        for inputs, targets in loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)

            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()

    return model.state_dict()


def test(model, loader):
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly.

    The model is switched to eval mode and evaluated without gradient
    tracking; the predicted class is the arg-max over dimension 1.
    """
    model.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            predictions = model(inputs).argmax(dim=1)
            hits += (predictions == targets).sum().item()
            seen += targets.size(0)
    return hits / seen


def compute_kde_weights(local_weights, k):
    """Compute per-client aggregation weights from a k-NN kernel density score.

    Each client's state dict is flattened into one vector. A client's score
    is the mean Gaussian kernel value over the distances to its ``k``
    nearest other clients, using the median of the pairwise-distance matrix
    as bandwidth. Scores are normalised separately within the first
    ``NUM_BINARY`` clients and within the remaining clients; each group is
    then scaled by its share of the two groups' mean scores, so the weights
    sum to 1.

    Args:
        local_weights: Sequence of client state dicts (name -> tensor).
        k: Number of nearest neighbours. ``k == 0`` selects uniform weights
            and returns immediately, without computing any distances.

    Returns:
        1-D ``numpy.ndarray`` with one weight per client, in client order.
    """
    n_clients = len(local_weights)

    if k == 0:
        # Uniform averaging needs no density estimate. Returning here avoids
        # the pairwise-distance computation and the mean-of-empty-slice
        # NaNs/warnings that an empty neighbour set would produce.
        return np.ones(n_clients) / n_clients

    flat = np.array(
        [
            torch.cat([p.flatten().detach().cpu() for p in w.values()]).numpy()
            for w in local_weights
        ]
    )

    dist_matrix = squareform(pdist(flat))
    h = np.median(dist_matrix) + 1e-12  # bandwidth; epsilon guards against h == 0

    kde_vals = np.zeros(n_clients)
    for i in range(n_clients):
        # Position 0 of the sorted row is skipped: it is the client's own
        # zero distance.
        idx = np.argsort(dist_matrix[i])[1 : k + 1]
        d = dist_matrix[i, idx]
        kde_vals[i] = np.mean(np.exp(-d * d / (2 * h * h)))

    kde_vals += 1e-12  # keep the normalisations below away from 0 / 0

    binary_vals = kde_vals[:NUM_BINARY]
    full_vals = kde_vals[NUM_BINARY:]

    binary_mean = np.mean(binary_vals)
    full_mean = np.mean(full_vals)
    binary_weight = binary_mean / (binary_mean + full_mean)
    full_weight = full_mean / (binary_mean + full_mean)

    return np.concatenate(
        [
            binary_vals / binary_vals.sum() * binary_weight,
            full_vals / full_vals.sum() * full_weight,
        ]
    )


def fed_kde_avg(local_weights, kde_weights):
    """Combine client state dicts into one weighted-sum state dict.

    Args:
        local_weights: Sequence of state dicts sharing the keys of the first
            entry.
        kde_weights: Indexable weights, one per entry of ``local_weights``.

    Returns:
        Dict mapping each key to ``sum_i kde_weights[i] * local_weights[i][key]``.
    """
    aggregated = {}
    for name in local_weights[0].keys():
        # Accumulate from 0 in client order, one client at a time.
        total = 0
        for client_idx, state in enumerate(local_weights):
            total = total + kde_weights[client_idx] * state[name]
        aggregated[name] = total
    return aggregated


# Evaluation loader over the CIFAR-10 test split. get_loader shuffles, which
# is harmless here because test() only counts correct predictions.
test_loader = get_loader(test_dataset)
# The noisy client ids are drawn once and reused by every FL_training call.
noisy_clients = noisy_clients_setup()


def FL_training(k_value, is_feature_skew, is_label_skew):
    """Run ``ROUNDS`` rounds of federated training and return the accuracy curve.

    Each round, every one of the ``NUM_CLIENTS`` clients starts from the
    current global weights, trains locally, and the resulting state dicts
    are merged with the weights from ``compute_kde_weights``. The merged
    model is then evaluated on ``test_loader``.

    Args:
        k_value: Neighbour count passed to ``compute_kde_weights``.
        is_feature_skew: Forwarded to ``clients_datasets_setup``.
        is_label_skew: Forwarded to ``clients_datasets_setup``; also part of
            the ``noisy`` flag handed to ``local_train``.

    Returns:
        List with the global test accuracy after each round.
    """
    client_datasets = clients_datasets_setup(noisy_clients, is_feature_skew, is_label_skew)

    global_model = CNN_CIFAR().to(DEVICE)
    accuracy_history = []

    for round_idx in range(ROUNDS):
        print(f"\n===== Round {round_idx + 1}/{ROUNDS} =====")

        client_states = []
        for client_id in range(NUM_CLIENTS):
            # A new model object per client, overwritten with the current
            # global weights before training.
            client_model = CNN_CIFAR().to(DEVICE)
            client_model.load_state_dict(global_model.state_dict())

            client_loader = get_loader(client_datasets[client_id])
            is_noisy = client_id in noisy_clients and is_label_skew
            client_states.append(local_train(client_model, client_loader, noisy=is_noisy))

        round_weights = compute_kde_weights(client_states, k_value)
        global_model.load_state_dict(fed_kde_avg(client_states, round_weights))

        round_acc = test(global_model, test_loader)
        accuracy_history.append(round_acc)
        print("🌍 Global Acc =", round_acc)

    return accuracy_history


# Experiment configurations. Every run below goes through the same
# FL_training routine; the runs differ only in the neighbour count k given to
# compute_kde_weights (k == 0 yields uniform weights) and in the two data
# flags forwarded to clients_datasets_setup.
# NOTE(review): the names suggest separate baselines (FedAvg, LoMar, FedLC,
# FedRDN), but no method-specific logic is visible in this file - confirm
# that these settings are the intended stand-ins.
k_value_FedAvg = 0
is_feature_skew_FedAvg = True
is_label_skew_FedAvg = True

k_value_LoMar = 60
is_feature_skew_LoMar = True
is_label_skew_LoMar = True

# is_label_skew=False selects the milder class-0/1 sampling in
# label_skew_subset_indices.
k_value_FedLC = 0
is_feature_skew_FedLC = True
is_label_skew_FedLC = False

# is_feature_skew=False gives every client the clean (noise-free) images.
k_value_FedRDN = 0
is_feature_skew_FedRDN = False
is_label_skew_FedRDN = True

k_value_FedKde = 8
is_feature_skew_FedKde = True
is_label_skew_FedKde = True

# The runs execute sequentially and share the global RNG state, so each run
# samples its client data and initial model from where the previous run left
# the random streams.
print("FedAvg:")
acc_FedAvg = FL_training(k_value_FedAvg, is_feature_skew_FedAvg, is_label_skew_FedAvg)
print("LoMar:")
acc_LoMar = FL_training(k_value_LoMar, is_feature_skew_LoMar, is_label_skew_LoMar)
print("FedLC:")
acc_FedLC = FL_training(k_value_FedLC, is_feature_skew_FedLC, is_label_skew_FedLC)
print("FedRDN:")
acc_FedRDN = FL_training(k_value_FedRDN, is_feature_skew_FedRDN, is_label_skew_FedRDN)
print("FedKde:")
acc_FedKde = FL_training(k_value_FedKde, is_feature_skew_FedKde, is_label_skew_FedKde)


# Accuracy-per-round comparison plot for the five runs.
plt.figure(figsize=(6, 4), facecolor="white")
rcParams["font.family"] = "Times New Roman"

# Grey-blue tone shared by the axes background and the legend frame.
panel_color = (230 / 255, 230 / 255, 238 / 255)

ax = plt.gca()
ax.set_facecolor(panel_color)
for spine_side in ("bottom", "top", "right", "left"):
    ax.spines[spine_side].set_color("white")

# One colour and one line style per curve, in plotting order.
color_1 = (128 / 255, 149 / 255, 192 / 255)
color_2 = (215 / 255, 164 / 255, 133 / 255)
color_3 = (141 / 255, 185 / 255, 149 / 255)
color_4 = (197 / 255, 128 / 255, 131 / 255)
color_5 = (158 / 255, 149 / 255, 192 / 255)
colors = [color_1, color_2, color_3, color_4, color_5]

linestyles = ["-", "--", ":", "-.", (0, (3, 1, 1, 1))]

curves = [
    (acc_FedAvg, "FedAvg"),
    (acc_LoMar, "LoMar"),
    (acc_FedLC, "FedLC"),
    (acc_FedRDN, "FedRDN"),
    (acc_FedKde, "FedKde"),
]
for (curve_acc, curve_label), curve_color, curve_style in zip(curves, colors, linestyles):
    plt.plot(curve_acc, label=curve_label, color=curve_color, linestyle=curve_style)

plt.xlabel("Round")
plt.ylabel("Accuracy")
plt.title("(d) CIFAR-10")
plt.grid(True, color="white", zorder=0)

legend = plt.legend()
legend.get_frame().set_facecolor(panel_color)
plt.show()
