import os
import json
import glob
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
import matplotlib.pyplot as plt
from scipy.spatial.distance import pdist, squareform
from matplotlib import rcParams

# Directories holding the FEMNIST JSON shards read by load_femnist_json_dir.
FEMNIST_TRAIN_DIR = "./data/femnist/train"
FEMNIST_TEST_DIR = "./data/femnist/test"

# Client population. Client ids 0..NUM_BINARY-1 are the "binary" clients
# (mostly labels 0 and 1); ids NUM_BINARY..NUM_CLIENTS-1 are the "full"
# clients (mostly labels 2 and up).
NUM_BINARY = 50
NUM_FULL = 10
# Derived rather than hard-coded so it can never disagree with the two group
# sizes (the client-partition assertion checks exactly this).
NUM_CLIENTS = NUM_BINARY + NUM_FULL
# Number of binary clients that receive noisy images and flipped labels.
NOISY_CLIENT_NUM = 25

# Training hyper-parameters.
EPOCHS = 1  # local epochs per client per round
ROUNDS = 100  # communication rounds per k value
BATCH_SIZE = 64
LR = 0.005  # Adam learning rate for local training
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

SEED = 60


def set_seed(seed: int):
    """Seed Python's ``random``, NumPy's global RNG, and torch (CPU and every CUDA device)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)


# Seed every RNG before any data partitioning or model initialisation below.
set_seed(seed=SEED)


def load_femnist_json_dir(json_dir):
    """Load and concatenate every ``*.json`` shard found in ``json_dir``.

    Each file is read as JSON and its ``"user_data"`` mapping is walked; every
    entry's ``"x"`` and ``"y"`` lists are appended to the output. Entries with
    no ``x`` samples are skipped. Files are processed in sorted filename order
    so the resulting sample order is reproducible.

    Args:
        json_dir: directory containing the JSON shards.

    Returns:
        ``(X, y)`` NumPy arrays built from the concatenated ``x`` and ``y``
        lists (file order, then per-file user order).

    Raises:
        FileNotFoundError: if ``json_dir`` contains no ``*.json`` files.
        ValueError: if a user's ``x`` and ``y`` lists differ in length; left
            unchecked, that would silently shift every later label onto the
            wrong image.
    """
    all_x, all_y = [], []
    json_files = sorted(glob.glob(os.path.join(json_dir, "*.json")))
    if not json_files:
        raise FileNotFoundError(f"No json files found in: {json_dir}")

    for fp in json_files:
        # JSON text is UTF-8 by specification; don't depend on the locale.
        with open(fp, "r", encoding="utf-8") as f:
            data = json.load(f)

        for user, ud in data.get("user_data", {}).items():
            xs = ud.get("x", [])
            ys = ud.get("y", [])
            if not xs:
                continue
            if len(xs) != len(ys):
                raise ValueError(
                    f"{fp}: user {user!r} has {len(xs)} samples but {len(ys)} labels"
                )
            all_x.extend(xs)
            all_y.extend(ys)

    return np.array(all_x), np.array(all_y)


class FEMNISTDataset(Dataset):
    """Map-style dataset over in-memory image/label arrays.

    ``dataset[i]`` returns ``(image, label)``. ``image`` is a float32 tensor
    with a leading channel axis of size 1: ``(1, 28, 28)`` for flat 784-long
    rows, or ``(1, H, W)`` for samples that are already 2-D. ``label`` is a
    Python int.

    Args:
        X: indexable array of samples. A 1-D sample is reshaped to 28x28, a
            2-D sample is used as-is, and anything else is squeezed.
        y: NumPy array of labels (cast to int64).
        transform: optional callable applied to the image tensor.
    """

    def __init__(self, X, y, transform=None):
        self.X = X
        self.y = y.astype(np.int64)
        self.transform = transform

    def __len__(self):
        return len(self.y)

    @staticmethod
    def _as_image(sample):
        """Return ``sample`` as a float32 array, divided by 255 if it exceeds 1.0."""
        if sample.ndim == 1:
            sample = sample.reshape(28, 28)
        elif sample.ndim != 2:
            sample = sample.squeeze()
        img = sample.astype(np.float32)
        # Per-sample heuristic: any pixel above 1.0 means the sample is on a 0-255 scale.
        return img / 255.0 if img.max() > 1.0 else img

    def __getitem__(self, idx):
        image = torch.tensor(self._as_image(self.X[idx])).unsqueeze(0)
        label = int(self.y[idx])
        if self.transform is None:
            return image, label
        return self.transform(image), label


class AddGaussianNoise:
    """Add i.i.d. Gaussian noise: ``tensor + N(0, 1) * std + mean``.

    The result is not clamped, so values may leave the input's range.
    """

    def __init__(self, mean=0.0, std=1.0):
        self.mean, self.std = mean, std

    def __call__(self, tensor):
        scaled_noise = torch.randn_like(tensor) * self.std
        return tensor + scaled_noise + self.mean


class AddSaltPepperNoise:
    """Salt-and-pepper noise on a copy of the input tensor.

    One uniform draw per element: draws below ``prob / 2`` set the element to
    0.0, draws above ``1 - prob / 2`` set it to 1.0 (the 1.0 rule wins if both
    apply). The input tensor is not modified.
    """

    def __init__(self, prob=0.3):
        self.prob = prob

    def __call__(self, tensor):
        draw = torch.rand_like(tensor)
        half = self.prob / 2
        # masked_fill is out-of-place, so this works on a fresh copy just like clone + assign.
        return tensor.masked_fill(draw < half, 0.0).masked_fill(draw > 1 - half, 1.0)


class RandomPixelFlip(object):
    """Invert ``int(numel * ratio)`` randomly chosen elements as ``1 - value``.

    Works on a copy; the input tensor is left untouched. The ``1 - value``
    inversion presumably assumes intensities in [0, 1].

    Args:
        ratio: fraction of elements to flip; a count that rounds down to 0
            returns an unmodified copy.
    """

    def __init__(self, ratio=0.2):
        self.ratio = ratio

    def __call__(self, tensor):
        # A plain clone() keeps the input's strides, so for a non-contiguous
        # input (e.g. a transpose) the flat view below would raise. Cloning
        # into contiguous memory makes `flat` an alias of `out`.
        out = tensor.clone(memory_format=torch.contiguous_format)
        total = out.numel()
        k = int(total * self.ratio)
        if k <= 0:
            return out
        idx = torch.randperm(total)[:k]
        flat = out.view(-1)
        flat[idx] = 1.0 - flat[idx]
        return out


# Number of output classes; labels are expected in 0..NUM_CLASSES-1.
NUM_CLASSES = 62


class CNN_FEMNIST(nn.Module):
    """Small two-convolution CNN for single-channel 28x28 images.

    Two (5x5 conv -> 2x2 max-pool -> ReLU) stages reduce an (N, 1, 28, 28)
    batch to (N, 20, 4, 4) = 320 features per sample. A 50-unit hidden layer
    and a NUM_CLASSES-way linear layer follow; the output is raw logits.

    The submodule names (conv1, conv2, fc1, fc2) form the state-dict keys
    that the weight aggregation averages over, so they must stay stable.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, NUM_CLASSES)

    def forward(self, x):
        x = torch.relu(torch.max_pool2d(self.conv1(x), 2))
        x = torch.relu(torch.max_pool2d(self.conv2(x), 2))
        # Flatten each sample's (C, H, W) feature map. Unlike view(-1, 320),
        # this never regroups values across samples, so an input of the wrong
        # spatial size fails in fc1 instead of silently changing the batch size.
        x = torch.flatten(x, start_dim=-3)
        if x.dim() == 1:
            # Unbatched (1, 28, 28) input: keep the (1, 320) shape view() produced.
            x = x.unsqueeze(0)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)


def flip_label(y, p=0.5):
    """With probability ``p``, return a uniformly random class id; otherwise ``y``.

    The random class is drawn from 0..NUM_CLASSES-1 and may equal ``y``.
    Exactly one ``random.random()`` draw is always consumed.
    """
    should_flip = random.random() < p
    if not should_flip:
        return y
    return random.randint(0, NUM_CLASSES - 1)


# --- Load FEMNIST and build the NUM_CLIENTS client partitions ---------------
# NOTE: every random draw below (np.random.shuffle / choice / randint) uses
# the global NumPy RNG seeded earlier, so the exact statement order determines
# which samples each client receives. Do not reorder.
X_train, y_train = load_femnist_json_dir(FEMNIST_TRAIN_DIR)
X_test, y_test = load_femnist_json_dir(FEMNIST_TEST_DIR)

# Noise-free views over the loaded arrays; noisy clients get their own view below.
train_dataset_clean = FEMNISTDataset(X_train, y_train, transform=None)
test_dataset = FEMNISTDataset(X_test, y_test, transform=None)

print("Train samples:", len(train_dataset_clean), " Test samples:", len(test_dataset))

# Per-class pools of training indices (classes 0..NUM_CLASSES-1), each shuffled
# once. Binary clients draw windows from these pools independently of each
# other, so their samples can overlap.
labels = np.array(y_train)
label_indices = {l: np.where(labels == l)[0] for l in range(NUM_CLASSES)}
for l in label_indices:
    np.random.shuffle(label_indices[l])

client_datasets = []

# Binary-client ids (0..NUM_BINARY-1) that get noisy images here and randomly
# flipped labels during local training (noisy=... in the training loop).
NOISY_CLIENTS = np.random.choice(NUM_BINARY, size=NOISY_CLIENT_NUM, replace=False)
print("🎨 noisy 客户端:", NOISY_CLIENTS)

# Binary clients: up to 600 samples each of labels 0 and 1, plus 1-5 extra
# labels drawn from 2..NUM_CLASSES-1 with 20-80 samples each (all counts are
# capped by the size of the label's pool).
for i in range(NUM_BINARY):

    def safe_take(lbl, want):
        """Return a contiguous window of up to ``want`` indices from label ``lbl``'s pool.

        NOTE(review): np.random.randint's upper bound is exclusive, so the last
        valid start (len(idxs) - take) is never chosen, and when exactly one
        alternative start exists the window is pinned to 0. Fixing this would
        change the RNG stream and therefore every partition.
        """
        idxs = label_indices.get(lbl, np.array([], dtype=int))
        if len(idxs) == 0:
            return np.array([], dtype=int)
        take = min(want, len(idxs))
        start = 0 if len(idxs) - take <= 1 else np.random.randint(0, len(idxs) - take)
        return idxs[start:start + take]

    idx0 = safe_take(0, 600)
    idx1 = safe_take(1, 600)

    # np.random.randint(1, 6) yields 1..5 extra labels.
    EXTRA_LABEL_NUM = np.random.randint(1, 6)
    extra_labels = np.random.choice(range(2, NUM_CLASSES), size=EXTRA_LABEL_NUM, replace=False)

    extra_indices = []
    for lbl in extra_labels:
        idxs = label_indices.get(lbl, np.array([], dtype=int))
        if len(idxs) == 0:
            continue
        # Sample size in [min(20, max_take), max_take], where max_take <= 80.
        max_take = min(80, len(idxs))
        min_take = min(20, max_take)
        take = np.random.randint(min_take, max_take + 1)
        # Same window sampling (and the same exclusive-bound quirk) as safe_take.
        start = 0 if len(idxs) - take <= 1 else np.random.randint(0, len(idxs) - take)
        extra_indices.extend(idxs[start:start + take])

    extra_indices = np.array(extra_indices, dtype=int)

    subset_indices = np.concatenate([idx0, idx1, extra_indices])
    np.random.shuffle(subset_indices)

    if i in NOISY_CLIENTS:
        # Gaussian noise (std 0.4), then salt-and-pepper (prob 0.1), then
        # invert 10% of the pixels. A separate FEMNISTDataset over the same
        # arrays is built so that only this client's Subset sees the transform.
        noisy_transform = lambda x: RandomPixelFlip(0.1)(
            AddSaltPepperNoise(0.1)(
                AddGaussianNoise(std=0.4)(x)
            )
        )
        noisy_dataset = FEMNISTDataset(X_train, y_train, transform=noisy_transform)
        client_datasets.append(Subset(noisy_dataset, subset_indices))
    else:
        client_datasets.append(Subset(train_dataset_clean, subset_indices))

# Full clients: every sample of labels 2..NUM_CLASSES-1, shuffled and split
# into NUM_FULL disjoint, equal-sized slices (any remainder is left unused).
indices_2_61 = []
for lbl in range(2, NUM_CLASSES):
    indices_2_61.extend(label_indices.get(lbl, []))
indices_2_61 = np.array(indices_2_61, dtype=int)
np.random.shuffle(indices_2_61)

# Combined pool of label-0 and label-1 indices, sprinkled into the full clients.
indices_0_1 = np.concatenate([label_indices.get(0, []), label_indices.get(1, [])]).astype(int)
np.random.shuffle(indices_0_1)

print(f"标签 2~61 样本数量: {len(indices_2_61)}")
print(f"标签 0,1 样本数量: {len(indices_0_1)}")

samples_per_client = len(indices_2_61) // NUM_FULL if NUM_FULL > 0 else 0

for i in range(NUM_FULL):
    subset_main = indices_2_61[i * samples_per_client: (i + 1) * samples_per_client]

    # 10-30 label-0/1 samples, drawn without replacement within this client
    # (different full clients may still share some of them).
    extra_num = np.random.randint(10, 31)
    extra_num = min(extra_num, len(indices_0_1))
    extra_indices = np.random.choice(indices_0_1, size=extra_num, replace=False) if extra_num > 0 else np.array([], dtype=int)

    subset = np.concatenate([subset_main, extra_indices])
    np.random.shuffle(subset)

    client_datasets.append(Subset(train_dataset_clean, subset))
    print(f"Full Client {i} → 主标签 2-61 样本 = {len(subset_main)}, 加入 0-1 样本 = {extra_num}")

# Client ids 0..NUM_BINARY-1 are binary clients; the rest are full clients.
assert len(client_datasets) == NUM_CLIENTS, f"client_datasets={len(client_datasets)} != {NUM_CLIENTS}"


def get_loader(dataset):
    """Return a DataLoader over ``dataset`` that reshuffles on every pass, with batches of BATCH_SIZE."""
    return DataLoader(dataset=dataset, shuffle=True, batch_size=BATCH_SIZE)


def local_train(model, loader, noisy=False):
    """Train ``model`` in place on ``loader``; return a CPU copy of its weights.

    Runs EPOCHS passes of Adam (lr=LR, fresh optimizer state on every call)
    with cross-entropy loss. Batches are moved to DEVICE, so ``model`` must
    already live there.

    Args:
        model: network to train (mutated in place).
        loader: iterable of ``(inputs, labels)`` batches.
        noisy: if true, each label is independently replaced by
            ``flip_label(label, p=0.5)`` before the loss is computed.

    Returns:
        dict mapping state-dict names to detached CPU tensor copies.
    """
    model.train()
    opt = optim.Adam(model.parameters(), lr=LR)
    loss_fn = nn.CrossEntropyLoss()

    for _ in range(EPOCHS):
        for x, y in loader:
            if noisy:
                # Flip labels on the host before the device copy: iterating a
                # CUDA tensor element by element forces one sync per label.
                # flip_label is still called once per label, in batch order.
                y = torch.tensor([flip_label(t, p=0.5) for t in y.tolist()], dtype=torch.long)
            x = x.to(DEVICE)
            y = y.to(DEVICE)

            opt.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            opt.step()

    return {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}


def test(model, loader):
    """Return the top-1 accuracy of ``model`` over all batches in ``loader``.

    Switches the model to eval mode (it is left in eval mode on return) and
    evaluates on DEVICE without tracking gradients. Raises ZeroDivisionError
    if the loader yields no samples.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            predictions = model(inputs).argmax(dim=1)
            n_correct += predictions.eq(targets).sum().item()
            n_seen += targets.size(0)
    return n_correct / n_seen


def compute_kde_weights(local_weights, k, num_binary=None):
    """Turn client models into aggregation weights with a k-NN Gaussian KDE.

    Each client's state dict is flattened into one vector. A client's density
    score is the mean Gaussian kernel value over its ``k`` nearest other
    clients (Euclidean distance). The bandwidth is the median of the distance
    matrix. Scores are normalised separately within the leading
    ``num_binary`` clients and within the remaining clients. Each group's
    share of the total is its mean score over the sum of the group means, so
    the returned weights sum to 1. If one group is empty, the other group
    receives all of the weight.

    Args:
        local_weights: list of state dicts (name -> tensor), one per client,
            with identical keys and shapes; binary clients first.
        k: number of nearest neighbours, excluding the client itself; must be
            >= 1 (fewer are used if ``k >= len(local_weights)``).
        num_binary: size of the leading group. Defaults to ``NUM_BINARY``.

    Returns:
        list of floats, one weight per client, in input order.

    Raises:
        ValueError: if ``k < 1`` or ``local_weights`` is empty.
    """
    if num_binary is None:
        num_binary = NUM_BINARY
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    n = len(local_weights)
    if n == 0:
        raise ValueError("local_weights is empty")
    if n == 1:
        return [1.0]  # nothing to compare against; the lone client gets all the weight

    flat = np.array([
        torch.cat([p.flatten() for p in w.values()]).cpu().numpy()
        for w in local_weights
    ])
    dist_matrix = squareform(pdist(flat))
    # NOTE: the median runs over the full N x N matrix, including the N zero
    # self-distances on the diagonal, so it sits slightly below the median of
    # the off-diagonal pairwise distances. Kept as-is to preserve results.
    h = np.median(dist_matrix) + 1e-12

    # Column 0 of each sorted row is normally the client's own zero distance,
    # so columns 1..k are its k nearest other clients.
    nearest = np.argsort(dist_matrix, axis=1)[:, 1:k + 1]
    d = np.take_along_axis(dist_matrix, nearest, axis=1)
    # The epsilon keeps every score > 0 so the normalisations below cannot divide by zero.
    kde_vals = np.exp(-d * d / (2 * h * h)).mean(axis=1) + 1e-12

    groups = [kde_vals[:num_binary], kde_vals[num_binary:]]
    total_mean = sum(g.mean() for g in groups if g.size)
    parts = [g / g.sum() * (g.mean() / total_mean) for g in groups if g.size]
    return np.concatenate(parts).tolist()


def fed_kde_avg(local_weights, kde_weights):
    """Return the ``kde_weights``-weighted sum of the client state dicts, key by key.

    The keys come from the first state dict; ``kde_weights[i]`` scales
    client ``i``. No renormalisation is applied here.
    """
    averaged = {}
    for name in local_weights[0].keys():
        acc = 0
        for i, state in enumerate(local_weights):
            acc = acc + kde_weights[i] * state[name]
        averaged[name] = acc
    return averaged


test_loader = get_loader(test_dataset)

# Accuracy of one randomly chosen binary client and one randomly chosen full
# client, measured right after each local update. These lists are appended to
# every round for every k value (they are not reset between k values).
binary_acc, full_acc = [], []
binary_monitor = random.randint(0, NUM_BINARY - 1)
full_monitor = random.randint(NUM_BINARY, NUM_CLIENTS - 1)

k_values = [5, 8, 10, 13, 15, 20]
global_accs = []  # one per-round global-accuracy curve per k value

# DataLoaders only draw their shuffle order when iterated, so build them once
# instead of once per client per round.
client_loaders = [get_loader(ds) for ds in client_datasets]
noisy_ids = set(NOISY_CLIENTS.tolist())

for k_value in k_values:
    print("k_value =", k_value)
    # Restart every k value from the same RNG state. Each run then starts from
    # the same initial global model and consumes the same stream of shuffles
    # and noise draws, so differences between the curves come from k rather
    # than from a different random draw.
    set_seed(SEED)

    global_acc = []
    global_model = CNN_FEMNIST().to(DEVICE)

    for r in range(ROUNDS):
        print(f"\n===== Round {r + 1}/{ROUNDS} =====")

        local_weights = []
        for i, loader in enumerate(client_loaders):
            local_model = CNN_FEMNIST().to(DEVICE)
            local_model.load_state_dict(global_model.state_dict())

            w = local_train(local_model, loader, noisy=(i in noisy_ids))
            local_weights.append(w)

            if i == binary_monitor:
                binary_acc.append(test(local_model, test_loader))
            if i == full_monitor:
                full_acc.append(test(local_model, test_loader))

        kde_weights = compute_kde_weights(local_weights, k_value)
        aggregated = fed_kde_avg(local_weights, kde_weights)
        global_model.load_state_dict(aggregated)

        acc = test(global_model, test_loader)
        global_acc.append(acc)
        print("🌍 Global Acc =", acc)

    global_accs.append(global_acc)


# Set the font before any figure or axes exist so every text element picks it up.
rcParams['font.family'] = 'Times New Roman'
plt.figure(figsize=(6, 4), facecolor='white')

panel_bg = (230 / 255, 230 / 255, 238 / 255)  # light grey-blue axes and legend background
ax = plt.gca()
ax.set_facecolor(panel_bg)
for s in ['bottom', 'top', 'right', 'left']:
    ax.spines[s].set_color('white')

colors = [
    (128 / 255, 149 / 255, 192 / 255),
    (215 / 255, 164 / 255, 133 / 255),
    (141 / 255, 185 / 255, 149 / 255),
    (197 / 255, 128 / 255, 131 / 255),
    (158 / 255, 149 / 255, 192 / 255),
    (171 / 255, 155 / 255, 140 / 255),
]
linestyles = ['-', '--', ':', '-.', (0, (3, 1, 1, 1)), (0, (8, 2))]

# One curve per k value; the x axis is the 0-based round index. Styles cycle
# so that k_values may hold more entries than there are colours/linestyles.
for idx, (k_value, accs) in enumerate(zip(k_values, global_accs)):
    plt.plot(
        accs,
        label=f'k={k_value}',
        color=colors[idx % len(colors)],
        linestyle=linestyles[idx % len(linestyles)],
    )

plt.xlabel("Round")
plt.ylabel("Accuracy")
plt.title("(c) FEMNIST")
plt.grid(True, color='white', zorder=0)
legend = plt.legend()
legend.get_frame().set_facecolor(panel_bg)
plt.show()
