import os
import json
import glob
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
import matplotlib.pyplot as plt
from scipy.spatial.distance import pdist, squareform
from matplotlib import rcParams

# Saved per-client state_dict lists that analyze_saved_models_with_kf() loads
# and compares (clean run vs. noise-perturbed copy).
CLEAN_PTH = "./saved_models/femnist/without_noisy/round80_local_models.pth"
NOISY_PTH = "./saved_models/femnist/with_noisy/round80_local_models.pth"

# Client split and analysis settings.
# NOTE(review): NUM_FULL and NUM_BINARY are assigned a second time further
# down (from NUM_CLIENTS); the later assignment is the one in effect at
# runtime. Both currently evaluate to the same values — keep them in sync.
TOTAL_CLIENTS = 60
NUM_FULL = 25
NUM_BINARY = TOTAL_CLIENTS - NUM_FULL
# Index of the reference client whose distance row orders the other clients.
START_CLIENT = 0
# Reduction handed to tensor_to_value(): "mean", "l2" or "ms".
VALUE_MODE = "mean"

# Kalman filter settings: KF_ALPHA / KF_BETA / KF_P0 are passed as
# alpha / beta / p0 to kalman_filter_scalar_sequence().
KF_ALPHA = 1e-2
GAUSSIAN_STD = 0.02
# NOTE(review): beta is used in the gain as K = P / (P + beta), i.e. where a
# variance would normally appear, but it is set to a standard deviation here
# — confirm this is intended.
KF_BETA = GAUSSIAN_STD
KF_P0 = 1.0
# LAMBDA and RHO are not referenced anywhere in the visible part of this file.
LAMBDA = 0.1
RHO = 0.1

# Directories scanned for "*.json" files by load_femnist_json_dir().
FEMNIST_TRAIN_DIR = "./data/femnist/train"
FEMNIST_TEST_DIR = "./data/femnist/test"

# Federated setup: the first NUM_BINARY clients are built mostly from labels
# 0 and 1, the remaining NUM_FULL clients mostly from labels 2..NUM_CLASSES-1.
NUM_CLIENTS = 60
NUM_FULL = 25
NUM_BINARY = NUM_CLIENTS - NUM_FULL
# How many of the binary clients are sampled to train on corrupted data.
NOISY_CLIENT_NUM = 0

# Local training hyper-parameters.
EPOCHS = 1
ROUNDS = 90
BATCH_SIZE = 64
LR = 0.005
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Round (1-based) after which the local models are saved, and where.
# run_training_and_save() additionally saves after rounds 18, 28 and 80.
SAVE_ROUND = 8
SAVE_PATH = "./saved_models/femnist/without_noisy/"

SEED = 60
# Size of the classifier output layer and the range used by flip_label().
NUM_CLASSES = 62

def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Args:
        seed: Integer seed handed unchanged to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# Seed every RNG once at import time, before any data splitting or model
# initialisation below can draw random numbers.
set_seed(SEED)

def flatten_state_dict(state_dict):
    """Concatenate every tensor of a state dict into one 1-D NumPy array.

    Tensors are detached, moved to CPU and flattened in the dict's iteration
    order before being joined.

    Args:
        state_dict: Mapping whose values are tensors.

    Returns:
        A 1-D ``numpy.ndarray`` holding all values back to back.
    """
    pieces = []
    for tensor in state_dict.values():
        pieces.append(tensor.detach().cpu().flatten())
    joined = torch.cat(pieces)
    return joined.numpy()

def tensor_to_value(t: torch.Tensor, mode: str) -> float:
    """Reduce a tensor to a single Python float.

    Args:
        t: Tensor to summarise; it is detached, moved to CPU and flattened.
        mode: ``"mean"`` for the arithmetic mean, ``"l2"`` for the Euclidean
            norm, ``"ms"`` for the mean of the squared entries.

    Returns:
        The selected statistic as a ``float``.

    Raises:
        ValueError: If ``mode`` is not one of the three supported names.
    """
    values = t.detach().cpu().numpy().reshape(-1)
    if mode == "mean":
        result = values.mean()
    elif mode == "l2":
        result = np.linalg.norm(values, ord=2)
    elif mode == "ms":
        result = (values ** 2).sum() / values.shape[0]
    else:
        raise ValueError("VALUE_MODE must be 'mean', 'l2', or 'ms'")
    return float(result)

def kalman_filter_scalar_sequence(observations, alpha, beta, p0=1.0, m0=None):
    """Filter a 1-D sequence with a scalar random-walk Kalman filter.

    Each step predicts with an unchanged mean and a covariance inflated by
    ``alpha``, then corrects towards the observation with gain
    ``K = P_minus / (P_minus + beta)``.

    Args:
        observations: Sequence of scalar observations.
        alpha: Value added to the covariance in every predict step.
        beta: Observation-noise term used in the gain and covariance update.
        p0: Initial covariance.
        m0: Initial mean; defaults to the first observation when ``None``.

    Returns:
        A tuple ``(M_filt, K_list, P_list)`` of float64 arrays with one entry
        per observation: filtered means, gains and posterior covariances.
        Three empty arrays are returned for an empty input.
    """
    z = np.asarray(observations, dtype=np.float64)
    n = len(z)
    # Bail out before touching the data: the previous version computed an
    # unused variance of the first differences here, which emitted
    # RuntimeWarnings for sequences of length 0 or 1.
    if n == 0:
        return np.array([]), np.array([]), np.array([])
    M = float(z[0]) if m0 is None else float(m0)
    P = float(p0)
    M_filt = np.zeros(n, dtype=np.float64)
    K_list = np.zeros(n, dtype=np.float64)
    P_list = np.zeros(n, dtype=np.float64)
    for k in range(n):
        # Predict: the state model is a random walk, so only P grows.
        M_minus = M
        P_minus = P + alpha
        # Update.
        K = P_minus / (P_minus + beta)
        M = M_minus + K * (z[k] - M_minus)
        # Symmetric (Joseph-style) covariance update.
        P = (1.0 - K) ** 2 * P_minus + K ** 2 * beta
        M_filt[k] = M
        K_list[k] = K
        P_list[k] = P
    return M_filt, K_list, P_list

def add_zero_mean_noise_to_models(clean_path, noisy_path, std):
    """Write a copy of saved client models with Gaussian noise added.

    The file at ``clean_path`` is loaded onto CPU and iterated as a sequence
    of state dicts. Every floating-point (or complex) tensor gets independent
    zero-mean noise with standard deviation ``std``; all other entries
    (integer/bool tensors, non-tensor values) are copied through unchanged,
    because ``torch.randn_like`` cannot sample for those dtypes.

    Args:
        clean_path: Path of the saved list of state dicts.
        noisy_path: Destination path; its parent directory is created when
            the path has one.
        std: Standard deviation of the added noise.

    Raises:
        FileNotFoundError: If ``clean_path`` does not exist.
    """
    if not os.path.exists(clean_path):
        raise FileNotFoundError("CLEAN_PTH 文件不存在，请检查路径是否正确")
    local_models = torch.load(clean_path, map_location="cpu")
    noisy_models = []
    for state_dict in local_models:
        noisy_state = {}
        for key, value in state_dict.items():
            can_perturb = isinstance(value, torch.Tensor) and (
                value.is_floating_point() or value.is_complex()
            )
            if can_perturb:
                noisy_state[key] = value + torch.randn_like(value) * std
            else:
                noisy_state[key] = value
        noisy_models.append(noisy_state)
    # os.makedirs("") raises, so only create a directory when there is one.
    out_dir = os.path.dirname(noisy_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torch.save(noisy_models, noisy_path)

def load_femnist_json_dir(json_dir):
    """Load every ``*.json`` file of a directory into two NumPy arrays.

    Files are read in sorted path order. From each file the ``"user_data"``
    mapping is walked and every user's ``"x"`` and ``"y"`` lists are appended;
    users with an empty ``"x"`` list are skipped.

    Args:
        json_dir: Directory containing the JSON files.

    Returns:
        ``(X, y)`` — arrays built from all collected samples and labels.

    Raises:
        FileNotFoundError: If the directory contains no ``*.json`` file.
    """
    paths = glob.glob(os.path.join(json_dir, "*.json"))
    paths.sort()
    if not paths:
        raise FileNotFoundError(f"No json files found in: {json_dir}")

    features, targets = [], []
    for path in paths:
        with open(path, "r") as handle:
            payload = json.load(handle)
        for record in payload.get("user_data", {}).values():
            samples = record.get("x", [])
            user_labels = record.get("y", [])
            if len(samples) == 0:
                continue
            features.extend(samples)
            targets.extend(user_labels)

    return np.array(features), np.array(targets)

class FEMNISTDataset(Dataset):
    """Dataset view over sample and label arrays yielding 1x28x28 tensors.

    Args:
        X: Indexable collection of NumPy samples (flat, 2-D, or with extra
            singleton axes).
        y: NumPy array of labels; stored as int64.
        transform: Optional callable applied to the image tensor.
    """

    def __init__(self, X, y, transform=None):
        self.X = X
        self.y = y.astype(np.int64)
        self.transform = transform

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx``."""
        sample = self.X[idx]
        if sample.ndim == 1:
            # Flat vectors are laid out as 28x28 images.
            sample = sample.reshape(28, 28)
        elif sample.ndim != 2:
            sample = sample.squeeze()
        sample = sample.astype(np.float32)
        # Samples whose maximum exceeds 1 are rescaled by 255.
        if sample.max() > 1.0:
            sample = sample / 255.0
        image = torch.tensor(sample).unsqueeze(0)
        label = int(self.y[idx])
        if self.transform is None:
            return image, label
        return self.transform(image), label

class AddGaussianNoise(object):
    """Transform adding Gaussian noise with a given mean and std to a tensor."""

    def __init__(self, mean=0.0, std=1.0):
        self.mean, self.std = mean, std

    def __call__(self, tensor):
        """Return ``tensor`` plus scaled standard-normal noise plus ``mean``."""
        scaled_noise = torch.randn_like(tensor) * self.std
        perturbed = tensor + scaled_noise
        return perturbed + self.mean

class AddSaltPepperNoise(object):
    """Transform that sets random entries to 0.0 or 1.0.

    One uniform draw is made per entry; entries whose draw falls below
    ``prob / 2`` become 0.0 and entries whose draw exceeds ``1 - prob / 2``
    become 1.0. The input tensor is left untouched.
    """

    def __init__(self, prob=0.3):
        self.prob = prob

    def __call__(self, tensor):
        draws = torch.rand_like(tensor)
        pepper_mask = draws < self.prob / 2
        salt_mask = draws > 1 - self.prob / 2
        result = tensor.clone()
        # Pepper first, then salt, matching the order of the two assignments.
        result[pepper_mask] = 0.0
        result[salt_mask] = 1.0
        return result

class RandomPixelFlip(object):
    """Transform replacing a random fraction of entries ``v`` with ``1 - v``.

    Args:
        ratio: Fraction of entries to flip; ``int(numel * ratio)`` distinct
            positions are chosen with ``torch.randperm``.
    """

    def __init__(self, ratio=0.2):
        self.ratio = ratio

    def __call__(self, tensor):
        """Return a flipped copy of ``tensor``; the input is not modified."""
        # Force a contiguous copy: a plain clone() keeps the input's strides,
        # so view(-1) below would raise for e.g. transposed inputs.
        out = tensor.clone(memory_format=torch.contiguous_format)
        total = out.numel()
        k = int(total * self.ratio)
        if k <= 0:
            return out
        idx = torch.randperm(total)[:k]
        # `flat` aliases `out`, so the in-place write updates the result.
        flat = out.view(-1)
        flat[idx] = 1.0 - flat[idx]
        return out

class CNN_FEMNIST(nn.Module):
    """Two-convolution, two-linear-layer classifier with NUM_CLASSES outputs.

    The flatten step reshapes the convolutional features to 320 values per
    sample, which is what the conv/pool stack yields for 1x28x28 inputs.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5)
        self.fc1 = nn.Linear(in_features=320, out_features=50)
        self.fc2 = nn.Linear(in_features=50, out_features=NUM_CLASSES)

    def forward(self, x):
        """Return the class logits for a batch of images."""
        features = x
        # Each stage: convolution -> 2x2 max pooling -> ReLU.
        for conv in (self.conv1, self.conv2):
            features = torch.relu(torch.max_pool2d(conv(features), 2))
        flattened = features.view(-1, 320)
        hidden = torch.relu(self.fc1(flattened))
        return self.fc2(hidden)

def flip_label(y, p=0.5):
    """Return a uniformly random class label with probability ``p``, else ``y``.

    The replacement is drawn from ``0 .. NUM_CLASSES - 1`` and may coincide
    with the original label.
    """
    return random.randint(0, NUM_CLASSES - 1) if random.random() < p else y

def get_loader(dataset):
    """Wrap ``dataset`` in a shuffling DataLoader using the module BATCH_SIZE."""
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

def local_train(model, loader, noisy=False):
    """Train ``model`` on one client's data and return its weights.

    Runs ``EPOCHS`` passes over ``loader`` with a fresh Adam optimiser
    (learning rate ``LR``) and cross-entropy loss.

    Args:
        model: Module to train in place.
        loader: Iterable of ``(inputs, labels)`` batches.
        noisy: When true, every label of every batch is passed through
            ``flip_label`` with probability 0.5 before the loss is computed.

    Returns:
        A dict mapping state-dict keys to detached CPU copies of the trained
        tensors.
    """
    model.train()
    opt = optim.Adam(model.parameters(), lr=LR)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(EPOCHS):
        for x, y in loader:
            if noisy:
                # Flip on the host before the transfer: iterating a device
                # tensor element by element forces one sync per label.
                y = torch.tensor(
                    [flip_label(int(label), p=0.5) for label in y.tolist()],
                    dtype=torch.long,
                )
            x = x.to(DEVICE)
            y = y.to(DEVICE)
            opt.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            opt.step()
            # The running loss total was never read and its .item() call
            # blocked on the device every batch, so it is no longer kept.
    return {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

def test(model, loader):
    """Return the top-1 accuracy of ``model`` over all batches of ``loader``.

    The model is switched to eval mode and gradients are disabled while the
    batches are scored on ``DEVICE``.
    """
    model.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            predictions = model(inputs).argmax(dim=1)
            hits += (predictions == targets).sum().item()
            seen += targets.size(0)
    return hits / seen

def compute_kde_weights(local_weights, k):
    """Compute aggregation weights from k-nearest-neighbour kernel densities.

    Every model is flattened to a vector; the pairwise Euclidean distances
    give a bandwidth (the median of the full distance matrix) and, per model,
    the mean Gaussian kernel value over the ``k`` entries that follow the
    first one in its sorted distance row. The first ``NUM_BINARY`` models and
    the remaining ones are normalised separately, and each group is then
    scaled by its share of the two groups' mean densities.

    Args:
        local_weights: List of state dicts, binary clients first.
        k: Number of neighbours used per model.

    Returns:
        A list with one weight per model, in input order.
    """
    vectors = np.array([
        torch.cat([tensor.flatten() for tensor in state.values()]).cpu().numpy()
        for state in local_weights
    ])
    pairwise = squareform(pdist(vectors))
    bandwidth = np.median(pairwise) + 1e-12

    n_models = len(local_weights)
    density = np.zeros(n_models)
    for row in range(n_models):
        # Position 0 of the sorted row is skipped (the zero self-distance).
        neighbours = np.argsort(pairwise[row])[1:k + 1]
        dists = pairwise[row, neighbours]
        density[row] = np.mean(np.exp(-dists * dists / (2 * bandwidth * bandwidth)))
    density += 1e-12

    binary_part = density[:NUM_BINARY]
    full_part = density[NUM_BINARY:]
    binary_avg = np.mean(binary_part)
    full_avg = np.mean(full_part)
    binary_share = binary_avg / (binary_avg + full_avg)
    full_share = full_avg / (binary_avg + full_avg)

    weights = [value * binary_share for value in binary_part / binary_part.sum()]
    weights += [value * full_share for value in full_part / full_part.sum()]
    return weights

def fed_kde_avg(local_weights, kde_weights):
    """Return the weighted sum of client state dicts, key by key.

    Args:
        local_weights: List of state dicts sharing the keys of the first one.
        kde_weights: One scalar weight per entry of ``local_weights``.

    Returns:
        A dict with, for every key, the sum over clients of weight * tensor.
    """
    aggregated = {}
    for name in local_weights[0].keys():
        running = 0
        for client, state in enumerate(local_weights):
            running = running + kde_weights[client] * state[name]
        aggregated[name] = running
    return aggregated

def _take_label_window(label_indices, lbl, want):
    """Return up to ``want`` consecutive indices of label ``lbl`` from a random offset."""
    idxs = label_indices.get(lbl, np.array([], dtype=int))
    if len(idxs) == 0:
        return np.array([], dtype=int)
    take = min(want, len(idxs))
    start = 0 if len(idxs) - take <= 1 else np.random.randint(0, len(idxs) - take)
    return idxs[start:start + take]


def _noisy_pixel_transform(x):
    """Corrupt an image with Gaussian noise, then salt-and-pepper, then pixel flips."""
    return RandomPixelFlip(0.1)(
        AddSaltPepperNoise(0.1)(
            AddGaussianNoise(std=0.4)(x)
        )
    )


def _build_client_datasets(X_train, y_train, train_dataset_clean):
    """Split the training data into NUM_BINARY + NUM_FULL client subsets.

    Returns ``(client_datasets, noisy_clients)``: the per-client ``Subset``
    objects (binary clients first) and the indices of the clients chosen to
    see corrupted inputs.
    """
    labels = np.array(y_train)
    label_indices = {l: np.where(labels == l)[0] for l in range(NUM_CLASSES)}
    for l in label_indices:
        np.random.shuffle(label_indices[l])

    client_datasets = []
    noisy_clients = np.random.choice(NUM_BINARY, size=NOISY_CLIENT_NUM, replace=False)

    # Binary clients: up to 600 samples each of labels 0 and 1, plus 20-80
    # samples from each of 1-5 randomly chosen other labels.
    for i in range(NUM_BINARY):
        idx0 = _take_label_window(label_indices, 0, 600)
        idx1 = _take_label_window(label_indices, 1, 600)

        extra_label_num = np.random.randint(1, 6)
        extra_labels = np.random.choice(range(2, NUM_CLASSES), size=extra_label_num, replace=False)

        extra_indices = []
        for lbl in extra_labels:
            idxs = label_indices.get(lbl, np.array([], dtype=int))
            if len(idxs) == 0:
                continue
            max_take = min(80, len(idxs))
            min_take = min(20, max_take)
            take = np.random.randint(min_take, max_take + 1)
            start = 0 if len(idxs) - take <= 1 else np.random.randint(0, len(idxs) - take)
            extra_indices.extend(idxs[start:start + take])

        extra_indices = np.array(extra_indices, dtype=int)

        subset_indices = np.concatenate([idx0, idx1, extra_indices])
        np.random.shuffle(subset_indices)

        if i in noisy_clients:
            noisy_dataset = FEMNISTDataset(X_train, y_train, transform=_noisy_pixel_transform)
            client_datasets.append(Subset(noisy_dataset, subset_indices))
        else:
            client_datasets.append(Subset(train_dataset_clean, subset_indices))

    indices_2_61 = []
    for lbl in range(2, NUM_CLASSES):
        indices_2_61.extend(label_indices.get(lbl, []))
    indices_2_61 = np.array(indices_2_61, dtype=int)
    np.random.shuffle(indices_2_61)

    indices_0_1 = np.concatenate([label_indices.get(0, []), label_indices.get(1, [])]).astype(int)
    np.random.shuffle(indices_0_1)

    samples_per_client = len(indices_2_61) // NUM_FULL if NUM_FULL > 0 else 0

    # Full clients: an equal slice of all labels >= 2, plus 10-30 samples
    # drawn from labels 0 and 1.
    for i in range(NUM_FULL):
        subset_main = indices_2_61[i * samples_per_client: (i + 1) * samples_per_client]
        extra_num = np.random.randint(10, 31)
        extra_num = min(extra_num, len(indices_0_1))
        if extra_num > 0:
            extra_indices = np.random.choice(indices_0_1, size=extra_num, replace=False)
        else:
            extra_indices = np.array([], dtype=int)
        subset = np.concatenate([subset_main, extra_indices])
        np.random.shuffle(subset)
        client_datasets.append(Subset(train_dataset_clean, subset))

    return client_datasets, noisy_clients


def _plot_global_accuracy(k_values, global_accs):
    """Plot one accuracy-per-round curve for every entry of ``k_values``."""
    plt.figure(figsize=(6, 4), facecolor="white")
    rcParams["font.family"] = "Times New Roman"
    ax = plt.gca()
    ax.set_facecolor((230 / 255, 230 / 255, 238 / 255))
    for s in ["bottom", "top", "right", "left"]:
        ax.spines[s].set_color("white")

    colors = [
        (128 / 255, 149 / 255, 192 / 255),
        (215 / 255, 164 / 255, 133 / 255),
        (141 / 255, 185 / 255, 149 / 255),
        (197 / 255, 128 / 255, 131 / 255),
        (158 / 255, 149 / 255, 192 / 255),
        (171 / 255, 155 / 255, 140 / 255),
    ]
    linestyles = ["-", "--", ":", "-.", (0, (3, 1, 1, 1)), (0, (8, 2))]

    for idx, k_value in enumerate(k_values):
        plt.plot(global_accs[idx], label=f"k={k_value}", color=colors[idx], linestyle=linestyles[idx])

    plt.xlabel("Round")
    plt.ylabel("Accuracy")
    plt.title("FEMNIST")
    plt.grid(True, color="white", zorder=0)
    legend = plt.legend()
    legend.get_frame().set_facecolor((230 / 255, 230 / 255, 238 / 255))
    plt.show()


def run_training_and_save():
    """Run the federated training experiment, save checkpoints and plot accuracy.

    Loads the FEMNIST train/test JSON directories, builds the client
    datasets, and for every neighbour count in ``k_values`` trains a fresh
    global model for ``ROUNDS`` rounds: each client trains a copy of the
    global model, the results are combined with KDE weights, and the global
    model is evaluated on the test set. The list of local models is written
    to ``SAVE_PATH`` after rounds ``SAVE_ROUND``, 18, 28 and 80.
    """
    X_train, y_train = load_femnist_json_dir(FEMNIST_TRAIN_DIR)
    X_test, y_test = load_femnist_json_dir(FEMNIST_TEST_DIR)

    train_dataset_clean = FEMNISTDataset(X_train, y_train, transform=None)
    test_dataset = FEMNISTDataset(X_test, y_test, transform=None)

    client_datasets, noisy_clients = _build_client_datasets(X_train, y_train, train_dataset_clean)
    assert len(client_datasets) == NUM_CLIENTS

    test_loader = get_loader(test_dataset)

    # torch.save does not create missing parent directories, so make sure the
    # checkpoint directory exists before any training time is spent.
    if SAVE_PATH:
        os.makedirs(SAVE_PATH, exist_ok=True)
    save_rounds = {SAVE_ROUND, 18, 28, 80}

    k_values = [8]
    global_accs = []

    for k_value in k_values:
        global_acc = []
        global_model = CNN_FEMNIST().to(DEVICE)

        for r in range(ROUNDS):
            local_weights = []
            for i in range(NUM_CLIENTS):
                local_model = CNN_FEMNIST().to(DEVICE)
                local_model.load_state_dict(global_model.state_dict())
                loader = get_loader(client_datasets[i])
                w = local_train(local_model, loader, noisy=(i in noisy_clients))
                local_weights.append(w)

            kde_weights = compute_kde_weights(local_weights, k_value)
            aggregated = fed_kde_avg(local_weights, kde_weights)
            global_model.load_state_dict(aggregated)

            acc = test(global_model, test_loader)
            global_acc.append(acc)

            if (r + 1) in save_rounds:
                torch.save(local_weights, os.path.join(SAVE_PATH, f"round{r + 1}_local_models.pth"))

        global_accs.append(global_acc)

    _plot_global_accuracy(k_values, global_accs)

def analyze_saved_models_with_kf():
    """Compare clean and noisy saved client models, raw and Kalman-filtered.

    Loads the two lists of state dicts from ``CLEAN_PTH`` and ``NOISY_PTH``,
    orders the clients by the Euclidean distance of their flattened clean
    model to client ``START_CLIENT``, reduces the first parameter tensor of
    every model to an absolute scalar (``VALUE_MODE``), runs the scalar
    Kalman filter over the noisy sequence and plots the three series.

    Raises:
        FileNotFoundError: If either checkpoint file is missing.
        ValueError: If a checkpoint is not a non-empty list/tuple, the two
            lists differ in length, or ``START_CLIENT`` is out of range.
    """
    if not os.path.exists(CLEAN_PTH):
        raise FileNotFoundError(f"❌ Clean model file not found: {CLEAN_PTH}")
    if not os.path.exists(NOISY_PTH):
        raise FileNotFoundError(f"❌ Noisy model file not found: {NOISY_PTH}")

    clean_models = torch.load(CLEAN_PTH, map_location="cpu")
    noisy_models = torch.load(NOISY_PTH, map_location="cpu")

    if not isinstance(clean_models, (list, tuple)) or len(clean_models) == 0:
        raise ValueError("❌ CLEAN_PTH 内容不是非空 list/tuple（期望多个客户端 state_dict 列表）")
    if not isinstance(noisy_models, (list, tuple)) or len(noisy_models) == 0:
        raise ValueError("❌ NOISY_PTH 内容不是非空 list/tuple（期望多个客户端 state_dict 列表）")

    if len(clean_models) != len(noisy_models):
        raise ValueError(f"❌ clean/noisy 客户端数不一致：{len(clean_models)} vs {len(noisy_models)}")

    num_clients = len(clean_models)

    fv_clean = np.array([flatten_state_dict(m) for m in clean_models])

    if START_CLIENT < 0 or START_CLIENT >= num_clients:
        raise ValueError(f"❌ START_CLIENT out of range: {START_CLIENT}")

    # Client order: nearest to farthest from the reference client.
    dist_mat = squareform(pdist(fv_clean))
    order = np.argsort(dist_mat[START_CLIENT])

    # Only the first parameter tensor of each model is summarised.
    param_name = list(clean_models[0].keys())[0]

    clean_vals = np.array([abs(tensor_to_value(clean_models[i][param_name], VALUE_MODE)) for i in order], dtype=np.float64)
    noisy_vals = np.array([abs(tensor_to_value(noisy_models[i][param_name], VALUE_MODE)) for i in order], dtype=np.float64)

    # The filter runs on the noisy values themselves; gains and covariances
    # are not needed for the plots.
    noisy_kf, _, _ = kalman_filter_scalar_sequence(
        noisy_vals, alpha=KF_ALPHA, beta=KF_BETA, p0=KF_P0, m0=noisy_vals[0]
    )

    x = np.arange(1, num_clients + 1)

    # Manual switch between the full line chart and the sampled bar chart.
    use_line_chart = False

    if use_line_chart:
        plt.figure(figsize=(8, 3))
        plt.plot(x, clean_vals, marker="o", label="clean model parameter")
        plt.plot(x, noisy_vals, marker="o", label="noisy model parameter (raw)")
        plt.plot(x, noisy_kf, linewidth=2, label=f"noisy model parameter (KF, α={KF_ALPHA}, β={KF_BETA})")
        plt.fill_between(x, clean_vals, noisy_vals, alpha=0.12)
        plt.xlabel("clients reordered by distance (near → far)")
        plt.ylabel(f"parameter tensor {VALUE_MODE} value")
        plt.title("Model parameter values: clean vs noisy (raw) vs noisy (Kalman-filtered)")
        plt.legend()
        plt.grid(True, linestyle="--", alpha=0.3)
        plt.tight_layout()
        plt.show()
    else:
        # Ten evenly spaced positions along the ordered client sequence.
        idx = np.linspace(0, len(clean_vals) - 1, 10, dtype=int)

        color_1 = (128 / 255, 149 / 255, 192 / 255)
        color_2 = (215 / 255, 164 / 255, 133 / 255)
        color_3 = (141 / 255, 185 / 255, 149 / 255)

        x_sample = x[idx]
        clean_sample = clean_vals[idx]
        noisy_sample = noisy_vals[idx]
        kf_sample = noisy_kf[idx]

        plt.figure(figsize=(6, 4))
        rcParams["font.family"] = "Times New Roman"

        ax = plt.gca()
        ax.set_facecolor((230 / 255, 230 / 255, 238 / 255))
        for s in ["bottom", "top", "right", "left"]:
            ax.spines[s].set_color("white")

        width = 1.6

        plt.bar(x_sample - width, clean_sample, width=width, label="without noisy", color=color_1, zorder=2)
        plt.bar(x_sample, kf_sample, width=width, label="with noisy after filter", color=color_3, zorder=4)
        plt.bar(x_sample + width, noisy_sample, width=width, label="with noisy before filter", color=color_2, zorder=3)

        plt.ylim(top=0.011)
        plt.xlabel("clients")
        plt.ylabel("margin")
        plt.title("FEMNIST")
        plt.grid(True, color="white", zorder=0)
        legend = plt.legend()
        legend.get_frame().set_facecolor((230 / 255, 230 / 255, 238 / 255))
        plt.tight_layout()
        plt.show()

def main():
    """Script entry point: run the Kalman-filter analysis of the saved models.

    Only the analysis is invoked here; run_training_and_save() is not called.
    """
    analyze_saved_models_with_kf()

# Run the analysis only when the file is executed as a script.
if __name__ == "__main__":
    main()
