import os
import json
import glob
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
import matplotlib.pyplot as plt
from scipy.spatial.distance import pdist, squareform
from matplotlib import rcParams

# Directories that are globbed for "*.json" files by load_femnist_json_dir().
FEMNIST_TRAIN_DIR = "./data/femnist/train"
FEMNIST_TEST_DIR = "./data/femnist/test"

# Client population. setup_clients_datasets() asserts that it builds exactly
# NUM_CLIENTS datasets, i.e. NUM_BINARY + NUM_FULL must equal NUM_CLIENTS.
NUM_CLIENTS = 60
# Clients 0 .. NUM_BINARY-1: mostly labels 0 and 1 plus a few extra classes.
NUM_BINARY = 50
# Clients NUM_BINARY .. NUM_CLIENTS-1: labels 2 .. NUM_CLASSES-1 plus a few 0/1 samples.
NUM_FULL = 10
# How many of the NUM_BINARY clients are selected into NOISY_CLIENTS.
NOISY_CLIENT_NUM = 25

# Local epochs per client per round, number of federated rounds, and the
# optimizer / loader hyper-parameters used by local_train() and get_loader().
EPOCHS = 1
ROUNDS = 100
BATCH_SIZE = 64
LR = 0.005
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed passed to set_seed() once at start-up.
SEED = 60
def set_seed(seed: int):
    """Seed every random number generator this script draws from.

    Applies ``seed`` to Python's ``random`` module, NumPy's global generator,
    ``torch.manual_seed`` and ``torch.cuda.manual_seed_all``, in that order.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# Seed all RNGs once, before any data partitioning or model construction below.
# This is the only call in the file, so later draws continue the same streams.
set_seed(SEED)

def load_femnist_json_dir(json_dir):
    """Load and concatenate all samples from the ``*.json`` files in a directory.

    Each file is expected to hold a top-level ``"user_data"`` mapping whose
    values are dicts with ``"x"`` (list of samples) and ``"y"`` (list of
    labels). Files are read in sorted filename order; users whose ``"x"`` list
    is empty are skipped. Per-user grouping is discarded.

    Args:
        json_dir: Directory to search (non-recursively) for ``*.json`` files.

    Returns:
        Tuple ``(X, y)`` of NumPy arrays built from all samples and all labels.

    Raises:
        FileNotFoundError: If the directory contains no ``*.json`` file.
    """
    json_files = sorted(glob.glob(os.path.join(json_dir, "*.json")))
    if not json_files:
        raise FileNotFoundError(f"No json files found in: {json_dir}")

    all_x, all_y = [], []
    for fp in json_files:
        # Decode explicitly as UTF-8 instead of relying on the platform's
        # locale encoding, which can mis-decode non-ASCII content.
        with open(fp, "r", encoding="utf-8") as f:
            data = json.load(f)

        for ud in data.get("user_data", {}).values():
            xs = ud.get("x", [])
            if not xs:
                continue
            all_x.extend(xs)
            all_y.extend(ud.get("y", []))

    return np.array(all_x), np.array(all_y)

class FEMNISTDataset(Dataset):
    """Map-style dataset over in-memory sample and label arrays.

    Each item is returned as ``(image, label)`` where ``image`` is a float32
    tensor with a leading channel axis and ``label`` is a Python ``int``.
    """

    def __init__(self, X, y, transform=None):
        self.X = X
        self.y = y.astype(np.int64)
        self.transform = transform

    def __len__(self):
        return len(self.y)

    @staticmethod
    def _as_image(sample):
        """Convert one raw sample into a float32 array scaled by the max>1 rule."""
        if sample.ndim == 1:
            # Flat vectors are interpreted as 28x28 images.
            sample = sample.reshape(28, 28)
        elif sample.ndim != 2:
            # Anything that is not already 2-D has its singleton axes dropped.
            sample = sample.squeeze()

        sample = sample.astype(np.float32)
        # Samples whose maximum exceeds 1.0 are divided by 255; others are
        # left as they are.
        if sample.max() > 1.0:
            sample = sample / 255.0
        return sample

    def __getitem__(self, idx):
        image = torch.tensor(self._as_image(self.X[idx])).unsqueeze(0)
        label = int(self.y[idx])

        # The optional transform operates on the tensor, after scaling.
        if self.transform is not None:
            image = self.transform(image)

        return image, label

class AddGaussianNoise(object):
    """Transform that adds element-wise Gaussian noise to a tensor.

    The result is ``tensor + N(0, 1) * std + mean``; it is not clipped.
    """

    def __init__(self, mean=0.0, std=1.0):
        self.mean, self.std = mean, std

    def __call__(self, tensor):
        # Standard-normal noise matching the input's shape, dtype and device,
        # scaled by std and then shifted by mean.
        unit_noise = torch.randn_like(tensor)
        scaled_noise = unit_noise * self.std
        return tensor + scaled_noise + self.mean

class AddSaltPepperNoise(object):
    """Transform that sets a random fraction of elements to 0.0 or 1.0.

    For each element a uniform value ``u`` in ``[0, 1)`` is drawn; elements
    with ``u < prob/2`` become 0.0 and elements with ``u > 1 - prob/2`` become
    1.0. The input tensor is left unmodified.
    """

    def __init__(self, prob=0.3):
        self.prob = prob

    def __call__(self, tensor):
        uniform = torch.rand_like(tensor)
        half = self.prob/2
        # "Pepper" is applied first, then "salt", so salt wins if both match.
        peppered = tensor.masked_fill(uniform < half, 0.0)
        return peppered.masked_fill(uniform > 1 - half, 1.0)

class RandomPixelFlip(object):
    """Transform that inverts (``v -> 1 - v``) a random subset of elements.

    ``int(numel * ratio)`` distinct positions are chosen with
    ``torch.randperm``. The input tensor is never modified; a new tensor of
    the same shape is returned.
    """

    def __init__(self, ratio=0.2):
        self.ratio = ratio

    def __call__(self, tensor):
        # Force a contiguous copy: a plain clone() keeps the input's strides,
        # and view(-1) below would raise for non-contiguous inputs such as
        # transposed tensors.
        out = tensor.clone(memory_format=torch.contiguous_format)
        total = out.numel()
        k = int(total * self.ratio)
        if k <= 0:
            return out
        idx = torch.randperm(total)[:k]
        # flat shares storage with out, so the in-place write updates the copy.
        flat = out.view(-1)
        flat[idx] = 1.0 - flat[idx]
        return flat.view_as(out)

# Number of output classes; sizes the final linear layer and the label ranges
# used elsewhere in this file.
NUM_CLASSES = 62


class CNN_FEMNIST(nn.Module):
    """Small two-conv / two-linear classifier producing NUM_CLASSES logits.

    For single-channel 28x28 inputs the second pooling stage yields
    20 x 4 x 4 = 320 features per sample, matching the input width of ``fc1``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, NUM_CLASSES)

    def forward(self, x):
        # Stage 1: conv -> 2x2 max-pool -> ReLU.
        stage1 = torch.max_pool2d(self.conv1(x), 2)
        stage1 = torch.relu(stage1)

        # Stage 2: same pattern with the second convolution.
        stage2 = torch.max_pool2d(self.conv2(stage1), 2)
        stage2 = torch.relu(stage2)

        # Flatten to rows of 320 features for the fully connected head.
        features = stage2.view(-1, 320)
        hidden = torch.relu(self.fc1(features))
        logits = self.fc2(hidden)
        return logits

def flip_label(y, p=0.5):
    """Return a random class label with probability ``p``, otherwise ``y``.

    The replacement is drawn uniformly from ``0 .. NUM_CLASSES - 1``, so it
    may coincide with the original label.
    """
    return random.randint(0, NUM_CLASSES - 1) if random.random() < p else y

# Load every sample from the train / test JSON directories into flat arrays.
# This runs at import time and raises FileNotFoundError if a directory has no
# "*.json" file.
X_train, y_train = load_femnist_json_dir(FEMNIST_TRAIN_DIR)
X_test, y_test   = load_femnist_json_dir(FEMNIST_TEST_DIR)

# Untransformed dataset wrappers: the training one backs the non-noisy client
# subsets, the test one backs the shared evaluation loader.
train_dataset_clean = FEMNISTDataset(X_train, y_train, transform=None)
test_dataset = FEMNISTDataset(X_test, y_test, transform=None)

# Map each class id to the positions of its training samples, shuffled in
# place with NumPy's global RNG so later contiguous slices are random samples.
labels = np.array(y_train)
label_indices = {l: np.where(labels == l)[0] for l in range(NUM_CLASSES)}
for l in label_indices:
    np.random.shuffle(label_indices[l])

# Ids (drawn without replacement from 0 .. NUM_BINARY-1) of the clients whose
# data gets the noisy transform when feature skew is enabled.
NOISY_CLIENTS = np.random.choice(NUM_BINARY, size=NOISY_CLIENT_NUM, replace=False)

def safe_take(lbl, want):
    """Return a random contiguous window of sample indices for one label.

    Args:
        lbl: Class id to look up in the module-level ``label_indices`` map.
        want: Desired number of indices; capped at the number available.

    Returns:
        A slice of ``label_indices[lbl]`` with ``min(want, available)``
        entries, or an empty int array when the label has no samples or
        ``want`` is not positive.
    """
    idxs = label_indices.get(lbl, np.array([], dtype=int))
    if len(idxs) == 0 or want <= 0:
        return np.array([], dtype=int)
    take = min(want, len(idxs))
    # Valid window starts are 0 .. slack inclusive. np.random.randint excludes
    # its upper bound, hence slack + 1; otherwise the last window (and with it
    # the tail samples of the class) could never be selected.
    slack = len(idxs) - take
    start = np.random.randint(0, slack + 1) if slack > 0 else 0
    return idxs[start:start + take]

def setup_clients_datasets(is_feature_skew=True, is_label_skew=True):
    """Build the per-client training subsets for one experiment run.

    The first ``NUM_BINARY`` clients each receive up to 100 samples of label 0
    and of label 1, plus samples from a few randomly chosen extra labels in
    ``2 .. NUM_CLASSES - 1``. The remaining ``NUM_FULL`` clients split the
    shuffled samples of labels ``2 .. NUM_CLASSES - 1`` evenly and each add
    10-30 random samples of labels 0/1.

    Args:
        is_feature_skew: When true, binary clients listed in ``NOISY_CLIENTS``
            read their samples through a Gaussian -> salt-and-pepper ->
            pixel-flip noise transform.
        is_label_skew: When true, each binary client gets 1-5 extra labels
            with 20-80 samples each; when false, 2-5 extra labels with
            40-100 samples each.

    Returns:
        List of ``NUM_CLIENTS`` ``Subset`` objects, binary clients first.

    Note:
        All sampling uses NumPy's global RNG. The discarded 600-sample draws
        and the discarded first extra-label-count draw of the earlier
        implementation are no longer made, so fewer random numbers are
        consumed per client.
    """
    client_datasets = []

    # Sampling ranges for the extra labels of binary clients.
    if is_label_skew:
        min_extra_labels, take_floor, take_cap = 1, 20, 80
    else:
        min_extra_labels, take_floor, take_cap = 2, 40, 100

    # The noise pipeline is stateless, so one transform and one dataset
    # wrapper are shared by all noisy clients instead of being rebuilt per
    # client (and per sample).
    noisy_dataset = None
    if is_feature_skew:
        gaussian = AddGaussianNoise(std=0.4)
        salt_pepper = AddSaltPepperNoise(0.1)
        pixel_flip = RandomPixelFlip(0.1)

        def noisy_transform(x):
            return pixel_flip(salt_pepper(gaussian(x)))

        noisy_dataset = FEMNISTDataset(X_train, y_train, transform=noisy_transform)

    for i in range(NUM_BINARY):
        idx0 = safe_take(0, 100)
        idx1 = safe_take(1, 100)

        extra_label_num = np.random.randint(min_extra_labels, 6)
        extra_labels = np.random.choice(range(2, NUM_CLASSES), size=extra_label_num, replace=False)

        extra_indices = []
        for lbl in extra_labels:
            available = len(label_indices.get(lbl, ()))
            if available == 0:
                continue
            max_take = min(take_cap, available)
            min_take = min(take_floor, max_take)
            take = np.random.randint(min_take, max_take + 1)
            # Same random-window selection as for labels 0 and 1.
            extra_indices.extend(safe_take(lbl, take))

        extra_indices = np.array(extra_indices, dtype=int)

        subset_indices = np.concatenate([idx0, idx1, extra_indices])
        np.random.shuffle(subset_indices)

        if noisy_dataset is not None and i in NOISY_CLIENTS:
            client_datasets.append(Subset(noisy_dataset, subset_indices))
        else:
            client_datasets.append(Subset(train_dataset_clean, subset_indices))

    # Pool of all samples with labels 2 .. NUM_CLASSES-1, shuffled once and
    # then cut into equal consecutive chunks for the full clients.
    rest_parts = [
        label_indices.get(lbl, np.array([], dtype=int))
        for lbl in range(2, NUM_CLASSES)
    ]
    if rest_parts:
        indices_2_61 = np.concatenate(rest_parts).astype(int)
    else:
        indices_2_61 = np.array([], dtype=int)
    np.random.shuffle(indices_2_61)

    indices_0_1 = np.concatenate([label_indices.get(0, []), label_indices.get(1, [])]).astype(int)
    np.random.shuffle(indices_0_1)

    print(f"标签 2~61 样本数量: {len(indices_2_61)}")
    print(f"标签 0,1 样本数量: {len(indices_0_1)}")

    samples_per_client = len(indices_2_61) // NUM_FULL if NUM_FULL > 0 else 0

    for i in range(NUM_FULL):
        subset_main = indices_2_61[i * samples_per_client: (i + 1) * samples_per_client]

        # A small number (10-30, capped by availability) of label-0/1 samples.
        extra_num = np.random.randint(10, 31)
        extra_num = min(extra_num, len(indices_0_1))
        if extra_num > 0:
            extra_indices = np.random.choice(indices_0_1, size=extra_num, replace=False)
        else:
            extra_indices = np.array([], dtype=int)

        subset = np.concatenate([subset_main, extra_indices])
        np.random.shuffle(subset)

        client_datasets.append(Subset(train_dataset_clean, subset))
        print(f"Full Client {i} → 主标签 2-61 样本 = {len(subset_main)}, 加入 0-1 样本 = {extra_num}")

    assert len(client_datasets) == NUM_CLIENTS, f"client_datasets={len(client_datasets)} != {NUM_CLIENTS}"

    return client_datasets

def get_loader(dataset):
    """Wrap ``dataset`` in a shuffling DataLoader with batch size BATCH_SIZE."""
    loader_options = {"batch_size": BATCH_SIZE, "shuffle": True}
    return DataLoader(dataset, **loader_options)

def local_train(model, loader, noisy=False):
    """Train ``model`` in place for EPOCHS passes and return its weights.

    A fresh Adam optimizer (learning rate ``LR``) and cross-entropy loss are
    created on every call, so no optimizer state carries over between rounds.

    Args:
        model: Module to train; it is updated in place and left in train mode.
        loader: Iterable of ``(inputs, labels)`` batches.
        noisy: Unused; retained so existing callers keep working.

    Returns:
        Dict mapping each ``state_dict`` key to a detached CPU copy of the
        trained tensor.
    """
    model.train()
    opt = optim.Adam(model.parameters(), lr=LR)
    loss_fn = nn.CrossEntropyLoss()

    for _ in range(EPOCHS):
        for x, y in loader:
            x = x.to(DEVICE)
            y = y.to(DEVICE)

            opt.zero_grad()
            # The loss value is never reported, so it is not read back with
            # .item(); that call existed only to feed an unused accumulator.
            loss_fn(model(x), y).backward()
            opt.step()

    # Detached CPU clones so the caller can aggregate without aliasing the
    # live model parameters.
    return {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

def test(model, loader):
    """Return the top-1 accuracy of ``model`` over all batches in ``loader``.

    The model is switched to eval mode (and left there) and gradients are
    disabled during evaluation.

    Args:
        model: Module mapping an input batch to per-class scores.
        loader: Iterable of ``(inputs, labels)`` batches.

    Returns:
        Fraction of correctly classified samples as a float in ``[0, 1]``;
        ``0.0`` when the loader yields no samples.
    """
    correct, total = 0, 0
    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            pred = model(x).argmax(dim=1)
            correct += pred.eq(y).sum().item()
            total += y.size(0)
    # Guard against an empty loader instead of raising ZeroDivisionError.
    if total == 0:
        return 0.0
    return correct / total

def compute_kde_weights(local_weights, k):
    """Compute per-client aggregation weights from a k-NN Gaussian density score.

    Every client's state dict is flattened into one vector. A client's score
    is the mean Gaussian kernel value over its ``k`` nearest other clients,
    with the bandwidth set to the median entry of the full pairwise distance
    matrix. Scores are normalised separately within the first ``NUM_BINARY``
    clients and within the rest, and each group is then scaled by its share
    of the two groups' mean scores, so all weights sum to 1.

    Args:
        local_weights: Sequence of state dicts (name -> tensor), one per
            client, ordered with the ``NUM_BINARY`` binary clients first.
        k: Number of nearest neighbours. ``0`` selects plain uniform weights.

    Returns:
        For ``k == 0`` a NumPy array of uniform weights ``1 / N``; otherwise a
        list of ``N`` float weights.
    """
    N = len(local_weights)

    # Uniform averaging needs no distances. Returning early also avoids the
    # NaN / "Mean of empty slice" warning that an empty neighbour set gives.
    if k == 0:
        return np.ones(N) / N

    flat = np.array([
        torch.cat([p.flatten() for p in w.values()]).cpu().numpy()
        for w in local_weights
    ])

    dist_matrix = squareform(pdist(flat))
    # Bandwidth: median over the whole matrix (zero diagonal included); the
    # epsilon keeps the division below finite when all clients coincide.
    h = np.median(dist_matrix) + 1e-12

    kde_vals = np.zeros(N)
    for i in range(N):
        # Position 0 of the sorted row is the client itself (distance 0).
        idx = np.argsort(dist_matrix[i])[1:k+1]
        d = dist_matrix[i, idx]
        kde_vals[i] = np.mean(np.exp(-d*d/(2*h*h)))

    kde_vals += 1e-12

    binary_scores = kde_vals[:NUM_BINARY]
    full_scores = kde_vals[NUM_BINARY:]

    binary_kde_vals = binary_scores / binary_scores.sum()
    full_kde_vals = full_scores / full_scores.sum()

    # Split the total weight between the groups in proportion to their mean
    # scores.
    binary_avg = np.mean(binary_scores)
    full_avg = np.mean(full_scores)
    binary_weight = binary_avg / (binary_avg + full_avg)
    full_weight = full_avg / (binary_avg + full_avg)

    binary_part = [x * binary_weight for x in binary_kde_vals]
    full_part = [x * full_weight for x in full_kde_vals]
    return binary_part + full_part

def fed_kde_avg(local_weights, kde_weights):
    """Return the weighted sum of client state dicts, computed key by key.

    ``kde_weights[i]`` scales every tensor of ``local_weights[i]``; the key
    set is taken from the first client's dict.
    """
    def weighted_sum(name):
        # Weight i multiplies the tensor stored under ``name`` by client i.
        return sum(
            kde_weights[i] * client_state[name]
            for i, client_state in enumerate(local_weights)
        )

    return {name: weighted_sum(name) for name in local_weights[0].keys()}

# Shared evaluation loader over the whole test set (get_loader shuffles it,
# which does not change the accuracy computed by test()).
test_loader = get_loader(test_dataset)

# Test accuracy of the two monitored clients after local training. training()
# appends to these lists every round and never clears them, so they accumulate
# across all runs.
binary_acc, full_acc = [], []
# One randomly chosen client id from each group (random.randint is inclusive
# on both ends).
binary_monitor = random.randint(0, NUM_BINARY - 1)
full_monitor = random.randint(NUM_BINARY, NUM_CLIENTS - 1)

def training(k_value, is_feature_skew=True, is_label_skew=True):
    """Run ROUNDS rounds of federated training and return per-round accuracy.

    Each round, every client trains a fresh copy of the global model on its
    own subset; the resulting weights are combined with the weights from
    ``compute_kde_weights(..., k_value)`` and loaded back into the global
    model, which is then scored on the shared test loader.

    Args:
        k_value: Neighbour count forwarded to ``compute_kde_weights``.
        is_feature_skew: Forwarded to ``setup_clients_datasets``.
        is_label_skew: Forwarded to ``setup_clients_datasets``.

    Returns:
        List with the global model's test accuracy after each round.

    Side effects:
        Prints progress and appends the monitored clients' accuracies to the
        module-level ``binary_acc`` / ``full_acc`` lists.
    """
    client_datasets = setup_clients_datasets(is_feature_skew, is_label_skew)

    global_model = CNN_FEMNIST().to(DEVICE)
    global_acc = []

    for round_no in range(1, ROUNDS + 1):
        print(f"\n===== Round {round_no}/{ROUNDS} =====")

        global_state = global_model.state_dict()
        local_weights = []

        for client_id in range(NUM_CLIENTS):
            # Every client starts the round from the current global weights.
            local_model = CNN_FEMNIST().to(DEVICE)
            local_model.load_state_dict(global_state)

            client_loader = get_loader(client_datasets[client_id])
            is_noisy = client_id in NOISY_CLIENTS
            local_weights.append(
                local_train(local_model, client_loader, noisy=is_noisy)
            )

            # Track how the two monitored clients' local models generalise.
            if client_id == binary_monitor:
                binary_acc.append(test(local_model, test_loader))
            if client_id == full_monitor:
                full_acc.append(test(local_model, test_loader))

        kde_weights = compute_kde_weights(local_weights, k_value)
        global_model.load_state_dict(fed_kde_avg(local_weights, kde_weights))

        round_acc = test(global_model, test_loader)
        global_acc.append(round_acc)
        print("🌍 Global Acc =", round_acc)

    return global_acc

# Experiment configurations. Every curve below comes from the same training()
# routine; the configurations differ only in the neighbour count k and in the
# two skew flags passed to setup_clients_datasets().
# k = 0 makes compute_kde_weights() return uniform weights.
k_value_FedAvg = 0
is_feature_skew_FedAvg = True
is_label_skew_FedAvg = True

# k = 60: the neighbour slice covers all other clients (at most N - 1 exist).
k_value_LoMar = 60
is_feature_skew_LoMar = True
is_label_skew_LoMar = True

# Uniform weights with the label-skew flag turned off.
k_value_FedLC = 0
is_feature_skew_FedLC = True
is_label_skew_FedLC = False

# Uniform weights with the feature-skew (noisy transform) flag turned off.
k_value_FedRDN = 0
is_feature_skew_FedRDN = False
is_label_skew_FedRDN = True

# KDE weighting over the 8 nearest clients, both skews enabled.
k_value_FedKde = 8
is_feature_skew_FedKde = True
is_label_skew_FedKde = True

# Run the five experiments one after another.
# NOTE(review): the RNGs are not re-seeded between runs, so each run draws its
# own client partition and initial model - confirm this is intended for the
# comparison.
acc_FedAvg = training(k_value_FedAvg, is_feature_skew_FedAvg, is_label_skew_FedAvg)
acc_LoMar  = training(k_value_LoMar,  is_feature_skew_LoMar,  is_label_skew_LoMar)
acc_FedLC  = training(k_value_FedLC,  is_feature_skew_FedLC,  is_label_skew_FedLC)
acc_FedRDN = training(k_value_FedRDN, is_feature_skew_FedRDN, is_label_skew_FedRDN)
acc_FedKde = training(k_value_FedKde, is_feature_skew_FedKde, is_label_skew_FedKde)

# Plot the per-round global accuracy of all five runs on one figure.
plt.figure(figsize=(6, 4), facecolor='white')
rcParams['font.family'] = 'Times New Roman'

# Light grey-blue plot area with white spines.
ax = plt.gca()
ax.set_facecolor((230/255, 230/255, 238/255))
ax.spines['bottom'].set_color('white')
ax.spines['top'].set_color('white')
ax.spines['right'].set_color('white')
ax.spines['left'].set_color('white')

# RGB palette (components scaled to 0-1); color_6 is defined but no curve
# below uses it.
color_1 = (128/255, 149/255, 192/255)
color_2 = (215/255, 164/255, 133/255)
color_3 = (141/255, 185/255, 149/255)
color_4 = (197/255, 128/255, 131/255)
color_5 = (158/255, 149/255, 192/255)
color_6 = (171/255, 155/255, 140/255)
colors = [color_1, color_2, color_3, color_4, color_5, color_6]

# One line style per curve; the tuple entries are (offset, on/off dash pattern).
linestyles = ['-', '--', ':', '-.', (0, (3, 1, 1, 1)), (0, (8, 2))]

# The x axis is the implicit list index, i.e. the zero-based round number.
plt.plot(acc_FedAvg, label='FedAvg', color=colors[0], linestyle=linestyles[0])
plt.plot(acc_LoMar, label='LoMar',  color=colors[1], linestyle=linestyles[1])
plt.plot(acc_FedLC, label='FedLC',  color=colors[2], linestyle=linestyles[2])
plt.plot(acc_FedRDN, label='FedRDN', color=colors[3], linestyle=linestyles[3])
plt.plot(acc_FedKde, label='FedKde', color=colors[4], linestyle=linestyles[4])

plt.xlabel("Round")
plt.ylabel("Accuracy")
plt.title("(c) FEMNIST")
plt.grid(True, color='white', zorder=0)
# Legend box uses the same background colour as the plot area.
legend = plt.legend()
legend.get_frame().set_facecolor((230/255, 230/255, 238/255))
plt.show()
