
import os
import random
import argparse
from dataclasses import dataclass
from typing import Dict, List, Tuple, Sequence, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from matplotlib import rcParams
from scipy.spatial.distance import pdist, squareform
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms


@dataclass
class MnistKFPlotConfig:
    """Settings for ``run_mnist_kf_plot`` (the ``mnist_kf_plot`` CLI subcommand).

    The Kalman-filter fields (kf_alpha, kf_beta, kf_p0, lam, rho, total_clients,
    num_full) are forwarded to ``kalman_filter_scalar_sequence``.
    NOTE(review): ``bar_sample_n``, ``y_top`` and ``gaussian_std`` have no CLI flag
    in ``build_parser``; they can only be changed by constructing this class directly.
    """

    # torch.save'd list/tuple of per-client state_dicts from the run without injected noise.
    clean_pth: str = "./saved_models/mnist/without_noisy/round8_local_models.pth"
    # Same layout as clean_pth for the noisy run; must contain the same number of clients.
    noisy_pth: str = "./saved_models/mnist/with_noisy/round8_local_models.pth"
    # The Kalman filter applies its one-off reset at sequence index total_clients - num_full.
    total_clients: int = 60
    num_full: int = 25
    # Anchor client: clients are ordered by distance of their flattened clean weights to this one.
    start_client: int = 0
    value_mode: str = "mean"  # scalar tensor summary: "mean" | "l2" (norm) | "ms" (mean of squares)
    # Process-noise variance added to the predicted variance at every filter step.
    kf_alpha: float = 1e-2
    # NOTE(review): not read anywhere in this file — TODO confirm it is still needed.
    gaussian_std: float = 0.02
    # Measurement-noise variance used in the Kalman gain.
    kf_beta: float = 0.02
    # Initial state variance.
    kf_p0: float = 1.0
    # At the reset step the predicted variance is scaled by 1/lam before adding kf_alpha.
    lam: float = 0.1
    # At the reset step the predicted mean becomes rho * M + (1 - rho) * observation.
    rho: float = 0.1
    # True: line chart over all clients; False: grouped bar chart over a sample of clients.
    is_line_chart: bool = False
    # Number of evenly spaced clients shown in the bar chart.
    bar_sample_n: int = 10
    # Upper y-axis limit for the bar chart; None leaves matplotlib's automatic limit.
    y_top: Optional[float] = 0.03
    # Plot title.
    title: str = "MNIST"


def flatten_state_dict(state_dict: Dict[str, torch.Tensor]) -> np.ndarray:
    """Concatenate every tensor of ``state_dict`` (in key order) into one flat 1-D NumPy array.

    Tensors are detached and copied to CPU first, so parameters that require grad
    or live on a GPU are accepted.
    """
    pieces = []
    for tensor in state_dict.values():
        pieces.append(torch.flatten(tensor.detach().cpu()))
    joined = torch.cat(pieces)
    return joined.numpy()


def tensor_to_value(t: torch.Tensor, mode: str) -> float:
    """Reduce tensor ``t`` (any shape, flattened first) to a single Python float.

    Args:
        t: tensor to summarize; it is detached and copied to CPU.
        mode: ``"mean"`` for the arithmetic mean, ``"l2"`` for the Euclidean norm,
            or ``"ms"`` for the mean of squared entries.

    Raises:
        ValueError: if ``mode`` is not one of the three supported names.
    """
    flat = t.detach().cpu().numpy().reshape(-1)
    if mode == "mean":
        summary = flat.mean()
    elif mode == "l2":
        summary = np.linalg.norm(flat, ord=2)
    elif mode == "ms":
        summary = (flat**2).sum() / flat.shape[0]
    else:
        raise ValueError("value_mode must be 'mean', 'l2', or 'ms'")
    return float(summary)


def kalman_filter_scalar_sequence(
    observations: Sequence[float],
    alpha: float,
    beta: float,
    total_clients: int,
    num_full: int,
    lam: float,
    rho: float,
    p0: float = 1.0,
    m0: Optional[float] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Run a scalar random-walk Kalman filter over ``observations``.

    At every step the state is predicted to stay unchanged, with its variance grown by
    ``alpha``. The observation is then fused using measurement variance ``beta``. At the
    single step ``k == total_clients - num_full`` the prediction is replaced by a reset:
    the predicted mean becomes ``rho * M + (1 - rho) * z[k]`` and the predicted variance
    becomes ``(1 / lam) * P + alpha``. The posterior variance uses the Joseph form
    ``(1 - K)^2 * P_minus + K^2 * beta``.

    Args:
        observations: scalar measurements, converted to float64.
        alpha: process-noise variance added at each prediction.
        beta: measurement-noise variance.
        total_clients, num_full: define the reset index ``total_clients - num_full``.
        lam: variance inflation factor (divisor) at the reset step.
        rho: weight kept on the previous estimate at the reset step.
        p0: initial state variance.
        m0: initial state mean; defaults to the first observation.

    Returns:
        ``(filtered_means, gains, posterior_variances)``. Each is a float64 array with
        one entry per observation; all three are empty when there are no observations.
    """
    z = np.asarray(observations, dtype=np.float64)
    steps = len(z)
    if steps == 0:
        return np.array([]), np.array([]), np.array([])

    mean = float(z[0]) if m0 is None else float(m0)
    var = float(p0)
    means, gains, variances = (np.zeros(steps, dtype=np.float64) for _ in range(3))
    reset_step = total_clients - num_full

    for k, obs in enumerate(z):
        # Prediction: plain random walk, except for the one-off reset step.
        if k == reset_step:
            prior_mean = rho * mean + (1.0 - rho) * obs
            prior_var = (1.0 / lam) * var + alpha
        else:
            prior_mean = mean
            prior_var = var + alpha

        # Update with the measurement.
        gain = prior_var / (prior_var + beta)
        mean = prior_mean + gain * (obs - prior_mean)
        var = (1.0 - gain) ** 2 * prior_var + (gain**2) * beta

        means[k], gains[k], variances[k] = mean, gain, var

    return means, gains, variances


def run_mnist_kf_plot(cfg: MnistKFPlotConfig) -> None:
    """Plot clean vs. noisy vs. Kalman-filtered noisy parameter summaries across clients.

    Steps:
      1. Load two checkpoints (clean and noisy), each a list/tuple of client state_dicts.
      2. Order the clients by Euclidean distance of their flattened clean weights to
         ``cfg.start_client`` (nearest first).
      3. Reduce the first tensor of each state_dict (by key order of client 0) to a scalar
         with ``tensor_to_value(..., cfg.value_mode)``.
      4. Run ``kalman_filter_scalar_sequence`` over the noisy scalars.
      5. Show a line chart of every client, or a grouped bar chart of
         ``cfg.bar_sample_n`` evenly spaced clients.

    Also sets matplotlib's global ``font.family`` to "Times New Roman".

    Raises:
        FileNotFoundError: if either checkpoint path does not exist.
        ValueError: if a checkpoint is not a non-empty list/tuple, the client counts
            differ, or ``cfg.start_client`` is out of range.
    """
    if not os.path.exists(cfg.clean_pth):
        raise FileNotFoundError(f"Clean model file not found: {cfg.clean_pth}")
    if not os.path.exists(cfg.noisy_pth):
        raise FileNotFoundError(f"Noisy model file not found: {cfg.noisy_pth}")

    clean_models = torch.load(cfg.clean_pth, map_location="cpu")
    noisy_models = torch.load(cfg.noisy_pth, map_location="cpu")

    def _is_nonempty_seq(obj) -> bool:
        return isinstance(obj, (list, tuple)) and len(obj) > 0

    if not _is_nonempty_seq(clean_models):
        raise ValueError("CLEAN_PTH is not a non-empty list/tuple of state_dicts.")
    if not _is_nonempty_seq(noisy_models):
        raise ValueError("NOISY_PTH is not a non-empty list/tuple of state_dicts.")
    if len(clean_models) != len(noisy_models):
        raise ValueError(f"clean/noisy client count mismatch: {len(clean_models)} vs {len(noisy_models)}")

    num_clients = len(clean_models)
    if not 0 <= cfg.start_client < num_clients:
        raise ValueError(f"start_client out of range: {cfg.start_client}")

    # Rank every client by distance of its clean weight vector to the anchor client.
    clean_vectors = np.array([flatten_state_dict(sd) for sd in clean_models])
    distances = squareform(pdist(clean_vectors))
    order = np.argsort(distances[cfg.start_client])

    first_key = list(clean_models[0])[0]

    def _summaries(models) -> np.ndarray:
        return np.array(
            [tensor_to_value(models[i][first_key], cfg.value_mode) for i in order],
            dtype=np.float64,
        )

    clean_vals = _summaries(clean_models)
    noisy_vals = _summaries(noisy_models)

    noisy_kf = kalman_filter_scalar_sequence(
        noisy_vals,
        alpha=cfg.kf_alpha,
        beta=cfg.kf_beta,
        total_clients=cfg.total_clients,
        num_full=cfg.num_full,
        lam=cfg.lam,
        rho=cfg.rho,
        p0=cfg.kf_p0,
        m0=float(noisy_vals[0]) if len(noisy_vals) > 0 else None,
    )[0]

    positions = np.arange(1, num_clients + 1)
    rcParams["font.family"] = "Times New Roman"

    def _draw_line_chart() -> None:
        plt.figure(figsize=(8, 3))
        plt.plot(positions, clean_vals, marker="o", label="clean model parameter")
        plt.plot(positions, noisy_vals, marker="o", label="noisy model parameter (raw)")
        plt.plot(positions, noisy_kf, linewidth=2, label=f"noisy model parameter (KF, α={cfg.kf_alpha}, β={cfg.kf_beta})")
        plt.fill_between(positions, clean_vals, noisy_vals, alpha=0.12)
        plt.xlabel("clients reordered by distance (near → far)")
        plt.ylabel(f"parameter tensor {cfg.value_mode} value")
        plt.title(cfg.title)
        plt.legend()
        plt.grid(True, linestyle="--", alpha=0.3)
        plt.tight_layout()
        plt.show()

    def _draw_bar_chart() -> None:
        n_points = len(clean_vals)
        if n_points > 0:
            picks = np.linspace(0, n_points - 1, cfg.bar_sample_n, dtype=int)
        else:
            picks = np.array([], dtype=int)
        centers = positions[picks]
        width = 1.6
        panel_bg = (230 / 255, 230 / 255, 238 / 255)

        plt.figure(figsize=(6, 4))
        ax = plt.gca()
        ax.set_facecolor(panel_bg)
        for side in ("bottom", "top", "right", "left"):
            ax.spines[side].set_color("white")

        # (x positions, heights, legend label, RGB color, zorder) in drawing/legend order.
        bar_specs = (
            (centers - width, clean_vals[picks], "without noisy", (128 / 255, 149 / 255, 192 / 255), 2),
            (centers, noisy_kf[picks], "with noisy after filter", (141 / 255, 185 / 255, 149 / 255), 4),
            (centers + width, noisy_vals[picks], "with noisy before filter", (215 / 255, 164 / 255, 133 / 255), 3),
        )
        for xs, heights, label, color, z in bar_specs:
            plt.bar(xs, heights, width=width, label=label, color=color, zorder=z)

        if cfg.y_top is not None:
            plt.ylim(top=cfg.y_top)

        plt.xlabel("clients")
        plt.ylabel("margin")
        plt.title(cfg.title)
        plt.grid(True, color="white", zorder=0)
        plt.legend().get_frame().set_facecolor(panel_bg)
        plt.tight_layout()
        plt.show()

    if cfg.is_line_chart:
        _draw_line_chart()
    else:
        _draw_bar_chart()


@dataclass
class AddNoiseConfig:
    """Default paths and noise level for the ``add_noise`` CLI subcommand.

    In this file the class is only read for its field defaults (in ``build_parser``);
    it is not instantiated here.
    """

    # torch.save'd list/tuple of per-client state_dicts to perturb.
    clean_pth: str = "./saved_models/cifar_10/without_noisy/round8_local_models.pth"
    # Output file for the perturbed state_dicts.
    noisy_pth: str = "./saved_models/cifar_10/with_noisy/round8_local_models.pth"
    # Standard deviation of the zero-mean Gaussian noise (default for ``--std``).
    gaussian_std: float = 0.02


def add_zero_mean_noise_to_models(clean_path: str, noisy_path: str, std: float) -> None:
    """Add zero-mean Gaussian noise to a saved list of client state_dicts and save the result.

    Each floating-point (or complex) tensor ``v`` becomes ``v + torch.randn_like(v) * std``,
    using torch's global RNG. Integer tensors (e.g. BatchNorm ``num_batches_tracked``)
    and non-tensor values are copied unchanged, because ``torch.randn_like`` is not
    defined for integer dtypes.

    Args:
        clean_path: file written by ``torch.save`` holding a non-empty list/tuple of
            state_dicts.
        noisy_path: destination file. Its parent directory is created if it is missing;
            a bare filename is written to the current directory.
        std: standard deviation of the noise.

    Raises:
        FileNotFoundError: if ``clean_path`` does not exist.
        ValueError: if the file does not hold a non-empty list/tuple.
    """
    if not os.path.exists(clean_path):
        raise FileNotFoundError(f"Clean file not found: {clean_path}")

    local_models = torch.load(clean_path, map_location="cpu")
    if not isinstance(local_models, (list, tuple)) or len(local_models) == 0:
        raise ValueError("clean_path should contain a non-empty list/tuple of state_dicts.")

    noisy_models: List[Dict[str, torch.Tensor]] = []
    for state_dict in local_models:
        noisy_state: Dict[str, torch.Tensor] = {}
        for k, v in state_dict.items():
            if isinstance(v, torch.Tensor) and (v.is_floating_point() or v.is_complex()):
                noise = torch.randn_like(v) * std
                noisy_state[k] = v + noise
            else:
                # Integer buffers / non-tensors: randn_like would raise, so pass through.
                noisy_state[k] = v
        noisy_models.append(noisy_state)

    out_dir = os.path.dirname(noisy_path)
    if out_dir:  # dirname("file.pth") == "" and os.makedirs("") raises FileNotFoundError
        os.makedirs(out_dir, exist_ok=True)
    torch.save(noisy_models, noisy_path)



@dataclass
class CifarKDETrainConfig:
    """Settings for ``run_cifar_kde_train`` (the ``train_cifar_kde`` CLI subcommand)."""

    # Total number of federated clients.
    num_clients: int = 60
    # Clients with id >= num_clients - num_full are "full" (mostly labels 2-9);
    # the others are "binary" (mostly labels 0 and 1) — see build_cifar10_clients.
    num_full: int = 25
    # Number of randomly chosen clients whose training images get Gaussian pixel noise.
    noisy_client_num: int = 25
    # Std of that pixel noise (passed to AddGaussianNoise).
    gaussian_noise_parameter: float = 0.3

    # Local epochs per round.
    epochs: int = 1
    # Number of federated rounds per k value.
    rounds: int = 9
    batch_size: int = 64
    # SGD learning rate (momentum is fixed at 0.9 in local_train).
    lr: float = 0.01
    # Seed for random, NumPy and torch (see set_seed).
    seed: int = 60

    # 1-based round whose local state_dicts are saved.
    save_round: int = 8
    # Directory for round{save_round}_local_models.pth.
    save_path: str = "./saved_models/cifar_10/without_noisy/"
    # CIFAR-10 root directory (downloaded if missing).
    data_root: str = "./data"

    # Number of nearest neighbours in the KDE; one full training run per value.
    k_values: Tuple[int, ...] = (8,)
    # Plot title.
    title: str = "CIFAR-10"


def set_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy's global RNG, and torch (CPU and every CUDA device)."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)


class AddGaussianNoise:
    """Transform that adds Gaussian noise to a tensor.

    Each call returns ``tensor + N(0, 1) * std + mean``, drawing fresh noise from
    torch's global RNG. The input tensor is not modified.
    """

    def __init__(self, mean: float = 0.0, std: float = 0.2):
        self.std = std
        self.mean = mean

    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
        scaled_noise = torch.randn_like(tensor) * self.std
        noisy = tensor + scaled_noise
        return noisy + self.mean


class CNN_CIFAR10(nn.Module):
    """Small CNN for 3x32x32 inputs with 10 output logits.

    ``features`` holds three (Conv3x3 -> ReLU -> MaxPool2) stages with 32/64/128 channels,
    reducing 32x32 to 4x4. ``classifier`` is Flatten -> Linear(2048, 256) -> ReLU ->
    Linear(256, 10). The Sequential indices (and thus state_dict keys) are part of the
    saved-checkpoint format.
    """

    def __init__(self):
        super().__init__()
        stages = []
        # Layers are created in the same order as an explicit listing, so parameter init
        # consumes torch's RNG identically.
        for in_ch, out_ch in ((3, 32), (32, 64), (64, 128)):
            stages.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ])
        self.features = nn.Sequential(*stages)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 256),
            nn.ReLU(),
            nn.Linear(256, 10),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feature_maps = self.features(x)
        return self.classifier(feature_maps)


def get_loader(dataset, batch_size: int) -> DataLoader:
    """Wrap ``dataset`` in a DataLoader that reshuffles every epoch (``shuffle=True``)."""
    loader = DataLoader(dataset, shuffle=True, batch_size=batch_size)
    return loader


def local_train(model: nn.Module, loader: DataLoader, device: torch.device, lr: float, epochs: int) -> Dict[str, torch.Tensor]:
    """Train ``model`` in place on ``loader`` and return a CPU snapshot of its weights.

    Uses cross-entropy loss and SGD with momentum 0.9 for ``epochs`` passes. A new
    optimizer is built on every call, so momentum does not carry over between calls.

    Returns:
        A dict mapping each ``state_dict`` key to a detached, cloned CPU tensor, so it
        is unaffected by later changes to ``model``.
    """
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()

    for _epoch in range(epochs):
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()

    snapshot: Dict[str, torch.Tensor] = {}
    for name, tensor in model.state_dict().items():
        snapshot[name] = tensor.detach().cpu().clone()
    return snapshot


def test(model: nn.Module, loader: DataLoader, device: torch.device) -> float:
    correct, total = 0, 0
    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(dim=1)
            correct += pred.eq(y).sum().item()
            total += y.size(0)
    return correct / total if total > 0 else 0.0


def compute_kde_weights(local_weights: List[Dict[str, torch.Tensor]], num_binary: int, k: int) -> List[float]:
    """Compute per-client aggregation weights from a k-nearest-neighbour Gaussian KDE.

    1. Each client's state_dict is flattened, concatenating the values in key order.
       Pairwise Euclidean distances are then computed between the clients.
    2. The bandwidth ``h`` is the median of the full distance matrix (including its zero
       diagonal) plus 1e-12.
    3. Each client's density is the mean of ``exp(-d^2 / (2 h^2))`` over its ``k``
       nearest other clients. The first sorted entry, normally the client itself at
       distance 0, is skipped. A client with no neighbours (a single client) gets density
       0. Every density then gets 1e-12 added.
    4. Clients ``[0, num_binary)`` form the "binary" group and the rest form the "full"
       group. Densities are normalised to sum to 1 within each group. Each group is then
       scaled by its share ``avg_g / (avg_binary + avg_full)`` of the mean group densities,
       so the returned weights sum to 1.
       An empty group (``num_binary == 0`` or ``== len(local_weights)``) gets share 0
       instead of producing NaN.

    Args:
        local_weights: client state_dicts, all with the same keys and shapes.
        num_binary: number of leading clients in the "binary" group.
        k: number of nearest neighbours used in the density estimate (>= 1).

    Returns:
        One float weight per client, in input order.

    Raises:
        ValueError: if ``k < 1``, ``local_weights`` is empty, or ``num_binary`` is
            outside ``[0, len(local_weights)]``.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    n_clients = len(local_weights)
    if n_clients == 0:
        raise ValueError("local_weights must contain at least one state_dict")
    if not 0 <= num_binary <= n_clients:
        raise ValueError(f"num_binary must be in [0, {n_clients}], got {num_binary}")

    flat = np.array([
        torch.cat([p.detach().cpu().flatten() for p in w.values()]).numpy()
        for w in local_weights
    ])
    dist_matrix = squareform(pdist(flat))
    h = np.median(dist_matrix) + 1e-12

    kde_vals = np.zeros(n_clients, dtype=np.float64)
    for i in range(n_clients):
        idx = np.argsort(dist_matrix[i])[1 : k + 1]
        d = dist_matrix[i, idx]
        if d.size:  # empty only when there is a single client; np.mean([]) would be NaN
            kde_vals[i] = np.mean(np.exp(-(d * d) / (2.0 * h * h)))

    kde_vals += 1e-12

    binary_kde = kde_vals[:num_binary]
    full_kde = kde_vals[num_binary:]

    # Mean density per group; an empty group contributes 0 so the other gets all weight.
    binary_avg = float(np.mean(binary_kde)) if binary_kde.size else 0.0
    full_avg = float(np.mean(full_kde)) if full_kde.size else 0.0

    denom = (binary_avg + full_avg) if (binary_avg + full_avg) != 0 else 1.0
    binary_weight = binary_avg / denom
    full_weight = full_avg / denom

    def _scaled(group: np.ndarray, share: float) -> List[float]:
        # Normalise the group to sum 1, then scale it by the group's share.
        if group.size == 0:
            return []
        normalized = group / group.sum()
        return [float(x) * share for x in normalized]

    return _scaled(binary_kde, binary_weight) + _scaled(full_kde, full_weight)


def fed_kde_avg(local_weights: List[Dict[str, torch.Tensor]], kde_weights: List[float]) -> Dict[str, torch.Tensor]:
    """Return the per-key weighted sum of client state_dicts.

    For each key of the first client's state_dict, the result is
    ``sum_i kde_weights[i] * local_weights[i][key]``. The weights are used as given and
    are not renormalised. ``kde_weights`` needs at least one entry per client; extra
    entries are ignored.
    """
    return {
        key: sum(kde_weights[i] * client[key] for i, client in enumerate(local_weights))
        for key in local_weights[0].keys()
    }


def build_cifar10_clients(cfg: CifarKDETrainConfig, noisy_clients: np.ndarray):
    """Split the CIFAR-10 training set into ``cfg.num_clients`` non-IID client subsets.

    "Binary" clients: the first ``num_binary = cfg.num_clients - cfg.num_full`` clients.
      Each gets a random contiguous window of 400 samples from the shuffled label-0 indices
      and 400 from the shuffled label-1 indices. It also gets 1-5 extra labels from 2-9,
      with 20-60 samples each. Windows of different clients may overlap.
    "Full" clients: the remaining ``cfg.num_full`` clients. Each gets a disjoint shard of
      the shuffled label-2..9 pool plus 30-80 random label-0/1 samples.

    Clients whose id appears in ``noisy_clients`` read images through
    ``AddGaussianNoise(std=cfg.gaussian_noise_parameter)``. A single noisy view of the
    training set is shared by all noisy clients instead of re-loading CIFAR-10 for each
    one; the transform is stateless, so per-sample noise is unaffected.

    All sampling uses NumPy's global RNG, so the split depends on its prior seeding.

    Returns:
        ``(client_datasets, test_ds, num_binary)``.

    Raises:
        RuntimeError: if the number of built client datasets differs from
            ``cfg.num_clients`` (e.g. ``cfg.num_full > cfg.num_clients``).
    """
    clean_transform = transforms.Compose([transforms.ToTensor()])

    train_clean = datasets.CIFAR10(cfg.data_root, train=True, download=True, transform=clean_transform)
    test_ds = datasets.CIFAR10(cfg.data_root, train=False, download=True, transform=clean_transform)

    noisy_train = None  # built on the first noisy client, then shared

    def _subset_for(client_id: int, indices: np.ndarray) -> Subset:
        # Pick the noisy or clean training view for this client.
        nonlocal noisy_train
        if client_id not in noisy_clients:
            return Subset(train_clean, indices)
        if noisy_train is None:
            noisy_transform = transforms.Compose([
                transforms.ToTensor(),
                AddGaussianNoise(std=cfg.gaussian_noise_parameter),
            ])
            noisy_train = datasets.CIFAR10(cfg.data_root, train=True, download=True, transform=noisy_transform)
        return Subset(noisy_train, indices)

    labels = np.array(train_clean.targets)
    label_indices = {lbl: np.where(labels == lbl)[0] for lbl in range(10)}
    for lbl in label_indices:
        np.random.shuffle(label_indices[lbl])

    num_binary = cfg.num_clients - cfg.num_full
    client_datasets: List[Subset] = []

    for i in range(num_binary):
        rand0 = rand1 = 400
        start0 = np.random.randint(0, len(label_indices[0]) - rand0)
        start1 = np.random.randint(0, len(label_indices[1]) - rand1)

        idx0 = label_indices[0][start0 : start0 + rand0]
        idx1 = label_indices[1][start1 : start1 + rand1]

        extra_label_num = np.random.randint(1, 6)
        extra_labels = np.random.choice(range(2, 10), size=extra_label_num, replace=False)

        extra_indices: List[int] = []
        for lbl in extra_labels:
            total_lbl = len(label_indices[lbl])
            max_take = min(60, total_lbl)
            min_take = min(20, max_take)
            take = np.random.randint(min_take, max_take + 1)
            start_lbl = 0 if total_lbl - take <= 1 else np.random.randint(0, total_lbl - take)
            extra_indices.extend(label_indices[lbl][start_lbl : start_lbl + take])

        subset_indices = np.concatenate([idx0, idx1, np.array(extra_indices, dtype=int)])
        np.random.shuffle(subset_indices)
        client_datasets.append(_subset_for(i, subset_indices))

    indices_2_9 = np.concatenate([np.where(labels == lbl)[0] for lbl in range(2, 10)])
    np.random.shuffle(indices_2_9)

    indices_0_1 = np.concatenate([np.where(labels == 0)[0], np.where(labels == 1)[0]])
    np.random.shuffle(indices_0_1)

    # Guard against num_full == 0, which previously raised ZeroDivisionError.
    samples_per_client = len(indices_2_9) // cfg.num_full if cfg.num_full > 0 else 0

    for j in range(cfg.num_full):
        subset_main = indices_2_9[j * samples_per_client : (j + 1) * samples_per_client]

        extra_num = np.random.randint(30, 81)
        extra_num = min(extra_num, len(indices_0_1))
        extra_idx = np.random.choice(indices_0_1, size=extra_num, replace=False)

        subset_indices = np.concatenate([subset_main, extra_idx])
        np.random.shuffle(subset_indices)
        client_datasets.append(_subset_for(num_binary + j, subset_indices))

    if len(client_datasets) != cfg.num_clients:
        raise RuntimeError(f"client_datasets={len(client_datasets)} != num_clients={cfg.num_clients}")

    return client_datasets, test_ds, num_binary


def run_cifar_kde_train(cfg: CifarKDETrainConfig) -> None:
    """Run KDE-weighted federated averaging on CIFAR-10 and plot test accuracy per round.

    After seeding, the noisy client ids are drawn and the client datasets are built. Then,
    for every ``k`` in ``cfg.k_values``, a fresh global CNN_CIFAR10 is trained for
    ``cfg.rounds`` rounds:
      - each client copies the global weights and trains locally (``local_train``);
      - the server aggregates with ``compute_kde_weights`` + ``fed_kde_avg``;
      - global test accuracy is recorded.
    On round ``cfg.save_round`` (1-based) the list of local state_dicts is saved to
    ``<save_path>/round{save_round}_local_models.pth``. With several k values, later runs
    overwrite that file. Finally one accuracy curve per k is shown, and matplotlib's
    global font family is set to "Times New Roman".
    """
    set_seed(cfg.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(cfg.save_path, exist_ok=True)

    # Drawn right after seeding and before the client split, so the RNG stream is fixed.
    noisy_clients = np.random.choice(cfg.num_clients, size=cfg.noisy_client_num, replace=False)
    client_datasets, test_ds, num_binary = build_cifar10_clients(cfg, noisy_clients)
    test_loader = get_loader(test_ds, cfg.batch_size)
    checkpoint_file = os.path.join(cfg.save_path, f"round{cfg.save_round}_local_models.pth")

    global_accs: List[List[float]] = []

    for k_value in cfg.k_values:
        global_model = CNN_CIFAR10().to(device)
        acc_curve: List[float] = []

        for round_idx in range(cfg.rounds):
            local_weights: List[Dict[str, torch.Tensor]] = []

            for client_id in range(cfg.num_clients):
                local_model = CNN_CIFAR10().to(device)
                local_model.load_state_dict(global_model.state_dict())
                client_loader = get_loader(client_datasets[client_id], cfg.batch_size)
                local_weights.append(local_train(local_model, client_loader, device, cfg.lr, cfg.epochs))

            weights = compute_kde_weights(local_weights, num_binary=num_binary, k=k_value)
            global_model.load_state_dict(fed_kde_avg(local_weights, weights))
            acc_curve.append(test(global_model, test_loader, device))

            if round_idx + 1 == cfg.save_round:
                torch.save(local_weights, checkpoint_file)

        global_accs.append(acc_curve)

    panel_bg = (230 / 255, 230 / 255, 238 / 255)
    palette = [
        (128 / 255, 149 / 255, 192 / 255),
        (215 / 255, 164 / 255, 133 / 255),
        (141 / 255, 185 / 255, 149 / 255),
        (197 / 255, 128 / 255, 131 / 255),
        (158 / 255, 149 / 255, 192 / 255),
        (171 / 255, 155 / 255, 140 / 255),
    ]
    dash_styles = ["-", "--", ":", "-.", (0, (3, 1, 1, 1)), (0, (8, 2))]

    rcParams["font.family"] = "Times New Roman"
    plt.figure(figsize=(6, 4), facecolor="white")

    ax = plt.gca()
    ax.set_facecolor(panel_bg)
    for side in ("bottom", "top", "right", "left"):
        ax.spines[side].set_color("white")

    for idx, (k_value, curve) in enumerate(zip(cfg.k_values, global_accs)):
        plt.plot(
            curve,
            label=f"k={k_value}",
            color=palette[idx % len(palette)],
            linestyle=dash_styles[idx % len(dash_styles)],
        )

    plt.xlabel("Round")
    plt.ylabel("Accuracy")
    plt.title(cfg.title)
    plt.grid(True, color="white", zorder=0)
    plt.legend().get_frame().set_facecolor(panel_bg)
    plt.tight_layout()
    plt.show()


def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser with three required subcommands (stored in ``cmd``).

    - ``mnist_kf_plot``: options default to MnistKFPlotConfig fields; ``--line``
      selects the line chart.
    - ``add_noise``: options default to AddNoiseConfig fields; ``--std`` is the noise std.
    - ``train_cifar_kde``: options default to CifarKDETrainConfig fields; ``--k_values``
      is a comma-separated string (default "8").
    """

    def _add_typed(target: argparse.ArgumentParser, specs) -> None:
        # Register plain `--flag VALUE` options in order (order drives --help output).
        for flag, kind, default in specs:
            target.add_argument(flag, type=kind, default=default)

    parser = argparse.ArgumentParser(description="Unified script (3-in-1), with all prints removed.")
    commands = parser.add_subparsers(dest="cmd", required=True)

    kf_defaults = MnistKFPlotConfig
    plot_cmd = commands.add_parser("mnist_kf_plot", help="Plot MNIST clean/noisy vs KF filtered curves/bars from saved local models.")
    _add_typed(plot_cmd, [
        ("--clean_pth", str, kf_defaults.clean_pth),
        ("--noisy_pth", str, kf_defaults.noisy_pth),
        ("--start_client", int, kf_defaults.start_client),
    ])
    plot_cmd.add_argument("--value_mode", type=str, default=kf_defaults.value_mode, choices=["mean", "l2", "ms"])
    _add_typed(plot_cmd, [
        ("--kf_alpha", float, kf_defaults.kf_alpha),
        ("--kf_beta", float, kf_defaults.kf_beta),
        ("--kf_p0", float, kf_defaults.kf_p0),
        ("--lam", float, kf_defaults.lam),
        ("--rho", float, kf_defaults.rho),
        ("--total_clients", int, kf_defaults.total_clients),
        ("--num_full", int, kf_defaults.num_full),
    ])
    plot_cmd.add_argument("--line", action="store_true", help="Use line chart (default: bar chart).")
    plot_cmd.add_argument("--title", type=str, default=kf_defaults.title)

    noise_cmd = commands.add_parser("add_noise", help="Add zero-mean Gaussian noise to saved local models and save.")
    _add_typed(noise_cmd, [
        ("--clean_pth", str, AddNoiseConfig.clean_pth),
        ("--noisy_pth", str, AddNoiseConfig.noisy_pth),
        ("--std", float, AddNoiseConfig.gaussian_std),
    ])

    train_defaults = CifarKDETrainConfig
    train_cmd = commands.add_parser("train_cifar_kde", help="Train CIFAR-10 KDE-weighted FL and save local models at save_round.")
    _add_typed(train_cmd, [
        ("--num_clients", int, train_defaults.num_clients),
        ("--num_full", int, train_defaults.num_full),
        ("--noisy_client_num", int, train_defaults.noisy_client_num),
        ("--gaussian_noise_parameter", float, train_defaults.gaussian_noise_parameter),
        ("--epochs", int, train_defaults.epochs),
        ("--rounds", int, train_defaults.rounds),
        ("--batch_size", int, train_defaults.batch_size),
        ("--lr", float, train_defaults.lr),
        ("--seed", int, train_defaults.seed),
        ("--save_round", int, train_defaults.save_round),
        ("--save_path", str, train_defaults.save_path),
        ("--data_root", str, train_defaults.data_root),
    ])
    train_cmd.add_argument("--k_values", type=str, default="8", help="Comma-separated, e.g. '5,8,10'.")
    train_cmd.add_argument("--title", type=str, default=train_defaults.title)

    return parser


def main():
    """CLI entry point: parse arguments and dispatch to the selected subcommand.

    - ``mnist_kf_plot`` builds a MnistKFPlotConfig and calls run_mnist_kf_plot.
    - ``add_noise`` calls add_zero_mean_noise_to_models.
    - ``train_cifar_kde`` parses ``--k_values`` into a tuple of ints, skipping blank
      items and falling back to ``(8,)`` if none remain, then calls run_cifar_kde_train.
    """
    args = build_parser().parse_args()
    command = args.cmd

    if command == "mnist_kf_plot":
        plot_cfg = MnistKFPlotConfig(
            clean_pth=args.clean_pth,
            noisy_pth=args.noisy_pth,
            start_client=args.start_client,
            value_mode=args.value_mode,
            kf_alpha=args.kf_alpha,
            kf_beta=args.kf_beta,
            kf_p0=args.kf_p0,
            lam=args.lam,
            rho=args.rho,
            total_clients=args.total_clients,
            num_full=args.num_full,
            is_line_chart=bool(args.line),
            title=args.title,
        )
        run_mnist_kf_plot(plot_cfg)
    elif command == "add_noise":
        add_zero_mean_noise_to_models(args.clean_pth, args.noisy_pth, args.std)
    elif command == "train_cifar_kde":
        tokens = (piece.strip() for piece in args.k_values.split(","))
        k_vals = tuple(int(tok) for tok in tokens if tok)
        train_cfg = CifarKDETrainConfig(
            num_clients=args.num_clients,
            num_full=args.num_full,
            noisy_client_num=args.noisy_client_num,
            gaussian_noise_parameter=args.gaussian_noise_parameter,
            epochs=args.epochs,
            rounds=args.rounds,
            batch_size=args.batch_size,
            lr=args.lr,
            seed=args.seed,
            save_round=args.save_round,
            save_path=args.save_path,
            data_root=args.data_root,
            k_values=k_vals or (8,),
            title=args.title,
        )
        run_cifar_kde_train(train_cfg)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
