# fed_vote_ceka.py
# -------------------------------------------------------
# Federated Learning (FedAvg / FedKDE) on WEKA vote.arff
# - Fix ARFF nominal mismatch by stripping nominal values in header + bytes->str strip
# - 60 clients, 30 shards
# - 10 clients: multi-shards (diverse), 50 clients: single-shard (biased)
# - Equal data size per client (auto-sized from dataset)
# - Robust to small datasets: avoids empty-shard sampling + guards k >= N
# -------------------------------------------------------

import os
import re
import random
import tempfile
from pathlib import Path

import numpy as np
import pandas as pd
from scipy.io import arff
from scipy.spatial.distance import pdist, squareform

from sklearn.model_selection import train_test_split
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.impute import SimpleImputer

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

import matplotlib.pyplot as plt
from matplotlib import rcParams


# =========================
# 0) Basic configuration
# =========================
ARFF_PATH = "/Users/xinyun/Desktop/test code/synthetic/vote/vote.arff"  # <- path to the vote dataset

# Seed the Python, NumPy and PyTorch global RNGs once, at import time.
SEED = 43
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use CUDA when it is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

NUM_CLIENTS = 60  # total number of simulated clients
NUM_SHARDS = 30  # number of shards the training indices are split into
NUM_FULL_CLIENTS = 10  # clients that sample from a mix of several shards
NUM_PARTIAL_CLIENTS = NUM_CLIENTS - NUM_FULL_CLIENTS  # remaining clients (single shard each)

ROUNDS = 30  # number of federated rounds
LOCAL_EPOCHS = 2  # local training epochs per client per round
BATCH_SIZE = 64  # mini-batch size for local training
LR = 1e-3  # learning rate handed to the local optimizer


# =========================
# 1) ARFF cleanup: fix trailing whitespace in nominal values
# =========================
def sanitize_arff_nominal_values(src_path: str) -> str:
    """Write a cleaned copy of an ARFF file and return the copy's path.

    Inside every ``{...}`` group of the header (everything before the
    ``@data`` line) each comma-separated value is stripped of surrounding
    whitespace; a value wrapped in matching single or double quotes is
    unquoted and stripped again. The data section is copied verbatim, so
    sparse instances (which also use ``{...}``) are not rewritten.

    Values are split on commas, so a quoted nominal value that itself
    contains a comma is not supported.

    Args:
        src_path: Path of the ARFF file to read. It is decoded as UTF-8 and
            undecodable bytes are dropped.

    Returns:
        Path of a newly created temporary ``.arff`` file. The caller owns
        this file and is responsible for deleting it.
    """
    text = Path(src_path).read_text(encoding="utf-8", errors="ignore")

    def _clean_nominal_list(match):
        inner = match.group(1)
        cleaned = []
        for part in inner.split(","):
            value = part.strip()
            is_quoted = len(value) >= 2 and value[0] == value[-1] and value[0] in "'\""
            if is_quoted:
                value = value[1:-1].strip()
            cleaned.append(value)
        return "{" + ",".join(cleaned) + "}"

    # Nominal declarations only live in the header; leave everything from the
    # @data marker onwards untouched. Without a marker the whole text is header.
    data_marker = re.search(r"^[ \t]*@data\b", text, flags=re.IGNORECASE | re.MULTILINE)
    split_at = data_marker.start() if data_marker else len(text)
    header, body = text[:split_at], text[split_at:]

    sanitized = re.sub(r"\{([^}]*)\}", _clean_nominal_list, header) + body

    fd, tmp_path = tempfile.mkstemp(suffix=".arff", prefix="vote_clean_")
    os.close(fd)
    Path(tmp_path).write_text(sanitized, encoding="utf-8")
    return tmp_path


def load_arff_to_dataframe(arff_path: str) -> pd.DataFrame:
    """Load an ARFF file into a DataFrame with text columns as stripped ``str``.

    The file is first passed through ``sanitize_arff_nominal_values`` and the
    sanitized temporary copy is parsed with ``scipy.io.arff.loadarff``. The
    temporary copy is removed once parsing has finished (or failed).

    Args:
        arff_path: Path of the ARFF file to load.

    Returns:
        A DataFrame in which every ``bytes``/``bytearray`` cell of an
        object-dtype column has been decoded as UTF-8 (undecodable bytes
        dropped) and stripped; other cells are left unchanged.
    """
    clean_path = sanitize_arff_nominal_values(arff_path)
    try:
        data, _meta = arff.loadarff(clean_path)
    finally:
        # The sanitized copy is only needed while parsing; delete it so that
        # repeated calls do not pile up temp files. Cleanup is best-effort.
        try:
            os.remove(clean_path)
        except OSError:
            pass

    df = pd.DataFrame(data)

    def _to_text(cell):
        # bytes -> str, then strip; anything else passes through untouched.
        if isinstance(cell, (bytes, bytearray)):
            return cell.decode("utf-8", errors="ignore").strip()
        return cell

    for col in df.columns:
        if df[col].dtype == object:
            df[col] = df[col].apply(_to_text)
    return df


# =========================
# 2) Data preprocessing (mixed numeric / categorical features)
# =========================
def infer_target_col(df: pd.DataFrame) -> str:
    """Guess which column of ``df`` holds the label.

    Column names are compared case-insensitively against a fixed list of
    candidate names, in priority order; the first candidate that exists wins.
    If none exists, the last column of ``df`` is returned.
    """
    by_lowercase = {}
    for name in df.columns:
        by_lowercase[name.lower()] = name

    for candidate in ("class", "target", "label", "y", "outcome"):
        if candidate in by_lowercase:
            return by_lowercase[candidate]

    # No well-known name present: fall back to the final column.
    return df.columns[-1]


def build_preprocessor(df: pd.DataFrame, target_col: str):
    """Split ``df`` into features/labels and build an unfitted preprocessor.

    Object-dtype feature columns are treated as categorical (most-frequent
    imputation, then one-hot encoding that ignores unknown categories); all
    other feature columns are treated as numeric (median imputation, then
    standard scaling).

    Returns:
        ``(X, y, preprocessor)`` where ``X`` is ``df`` without ``target_col``,
        ``y`` is the target column cast to ``str`` and stripped, and
        ``preprocessor`` is the unfitted ``ColumnTransformer``.
    """
    features = df.drop(columns=[target_col])
    labels = df[target_col].astype(str).str.strip()

    # Partition the feature columns by dtype, keeping their original order.
    categorical_names = []
    numeric_names = []
    for name in features.columns:
        if features[name].dtype == object:
            categorical_names.append(name)
        else:
            numeric_names.append(name)

    numeric_steps = [
        ("imputer", SimpleImputer(strategy="median")),
        ("scaler", StandardScaler()),
    ]
    categorical_steps = [
        ("imputer", SimpleImputer(strategy="most_frequent")),
        ("onehot", OneHotEncoder(handle_unknown="ignore")),
    ]

    column_transformer = ColumnTransformer(
        transformers=[
            ("num", Pipeline(steps=numeric_steps), numeric_names),
            ("cat", Pipeline(steps=categorical_steps), categorical_names),
        ]
    )
    return features, labels, column_transformer


# =========================
# 3) PyTorch model
# =========================
class MLP(nn.Module):
    """Feed-forward classifier: in_dim -> 128 -> 64 -> num_classes.

    Each hidden linear layer is followed by ReLU and Dropout(0.1); the output
    layer produces raw logits.
    """

    def __init__(self, in_dim: int, num_classes: int):
        super().__init__()
        widths = (in_dim, 128, 64)

        # Hidden stack: Linear -> ReLU -> Dropout for each consecutive width pair.
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.extend([nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(0.1)])

        # Output layer (logits).
        layers.append(nn.Linear(widths[-1], num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the logits produced by the layer stack for ``x``."""
        logits = self.net(x)
        return logits


def get_model_params(model: nn.Module):
    """Return detached CPU copies of the model's parameters, in iteration order."""
    snapshot = []
    for param in model.parameters():
        # detach -> cpu -> clone so the copy is independent of the live model.
        snapshot.append(param.detach().cpu().clone())
    return snapshot


def set_model_params(model: nn.Module, params):
    """Copy ``params`` into the model's parameters in place, pairwise in order.

    Each source tensor is moved to the device of the parameter it overwrites.
    The copy runs without gradient tracking.
    """
    destinations = model.parameters()
    with torch.no_grad():
        for dst, src in zip(destinations, params):
            moved = src.to(dst.device)
            dst.copy_(moved)


def fedavg(list_of_params, weights):
    """Return the weighted average of per-client parameter lists.

    ``weights`` are normalised by their sum; for every parameter position the
    clients' tensors are scaled by their normalised weight, stacked, and
    summed. The number of positions is taken from the first client.
    """
    weight_sum = float(sum(weights))
    shares = [w / weight_sum for w in weights]

    num_positions = len(list_of_params[0])
    averaged = []
    for pos in range(num_positions):
        scaled = []
        for client_no, client_params in enumerate(list_of_params):
            scaled.append(client_params[pos] * shares[client_no])
        averaged.append(torch.stack(scaled, dim=0).sum(dim=0))
    return averaged


@torch.no_grad()
def evaluate(model, loader):
    """Return the model's accuracy over all batches yielded by ``loader``.

    The model is switched to eval mode; batches are moved to ``DEVICE`` and
    the arg-max over dim 1 of the logits is compared with the labels. An
    empty loader yields 0.0 (the denominator is clamped to 1).
    """
    model.eval()
    hits = 0
    seen = 0
    for features, labels in loader:
        features = features.to(DEVICE)
        labels = labels.to(DEVICE)
        predicted = model(features).argmax(dim=1)
        hits += (predicted == labels).sum().item()
        seen += labels.numel()
    return hits / max(seen, 1)


def train_one_client(model, train_loader, epochs=1, lr=1e-3):
    """Train ``model`` in place on ``train_loader`` for ``epochs`` passes.

    Uses a fresh Adam optimizer with learning rate ``lr`` and cross-entropy
    loss; every batch is moved to ``DEVICE`` before the forward pass.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    model.train()

    for _epoch in range(epochs):
        for features, labels in train_loader:
            features = features.to(DEVICE)
            labels = labels.to(DEVICE)

            optimizer.zero_grad()
            batch_loss = criterion(model(features), labels)
            batch_loss.backward()
            optimizer.step()


# =========================
# 4) Build shards + clients (robust: empty shards are skipped when sampling)
# =========================
def make_shards(indices, y, num_shards=6):
    """Split ``indices`` into ``num_shards`` lists, class by class.

    For every distinct label in ``y`` the matching indices are shuffled and
    divided as evenly as possible across the shards, so each shard receives a
    slice of every class. With few samples some slices (and possibly whole
    shards) can be empty; that is allowed here. Each shard is shuffled at the
    end. Uses the global NumPy RNG.

    Returns:
        A list of ``num_shards`` Python lists of indices.
    """
    labels = np.asarray(y)
    buckets = [[] for _ in range(num_shards)]

    for label in np.unique(labels):
        members = indices[labels[indices] == label].copy()
        np.random.shuffle(members)
        # array_split tolerates uneven division and yields empty pieces if needed.
        pieces = np.array_split(members, num_shards)
        for bucket, piece in zip(buckets, pieces):
            bucket.extend(piece.tolist())

    for bucket in buckets:
        np.random.shuffle(bucket)
    return buckets


def sample_fixed_size(pool_indices: np.ndarray, size: int):
    """Draw exactly ``size`` indices from ``pool_indices``.

    Sampling is without replacement when the pool holds at least ``size``
    entries and with replacement otherwise. Uses the global NumPy RNG.

    Raises:
        ValueError: If the pool is empty.
    """
    pool = np.asarray(pool_indices, dtype=int)
    if len(pool) == 0:
        raise ValueError("采样池为空：请检查 shards 是否出现空分片/标签是否异常。")

    # Only fall back to replacement when the pool is too small.
    needs_replacement = len(pool) < size
    return np.random.choice(pool, size=size, replace=bool(needs_replacement))


def nonempty_shard_ids(shards):
    """Return the positions of all shards that contain at least one element."""
    kept = []
    for shard_id, members in enumerate(shards):
        if len(members) > 0:
            kept.append(shard_id)
    return kept


def build_clients_from_shards(
    shards,
    client_train_size,
    num_clients=60,
    num_full_clients=10,
    full_mix_shards=5,
):
    """
    Assign ``client_train_size`` training indices to every client.

    full clients (ids ``0 .. num_full_clients - 1``): sample from the union
    of up to ``full_mix_shards`` randomly chosen non-empty shards (richer).
    partial clients (remaining ids): sample from one randomly chosen
    non-empty shard (more biased).

    Returns a dict mapping client id -> array of sampled indices.
    Raises ValueError when every shard is empty.
    """
    assert num_full_clients <= num_clients

    usable = nonempty_shard_ids(shards)
    if not usable:
        raise ValueError("所有 shards 都为空：请减小 NUM_SHARDS 或检查数据/标签。")

    def _draw(shard_ids):
        # Merge the chosen shards into one pool, then sample a fixed-size set.
        merged = np.array([i for sid in shard_ids for i in shards[sid]], dtype=int)
        return sample_fixed_size(merged, client_train_size)

    assignment = {}

    # full clients
    mix_count = min(full_mix_shards, len(usable))
    for cid in range(num_full_clients):
        assignment[cid] = _draw(random.sample(usable, k=mix_count))

    # partial clients (single shard, always a non-empty one)
    for cid in range(num_full_clients, num_clients):
        assignment[cid] = _draw([random.choice(usable)])

    return assignment


# =========================
# 5) FedKDE weights
# =========================
def compute_kde_weights(local_state_dicts, k, eps=1e-12, num_full_clients=None):
    """Compute FedKDE aggregation weights for a list of client models.

    Every state_dict is flattened into one vector and the pairwise distances
    between clients are computed with ``pdist``. The kernel bandwidth ``h`` is
    the median of the full distance matrix plus ``eps``. Each client's score
    is the mean Gaussian kernel value ``exp(-d^2 / (2 h^2))`` over its ``k``
    nearest other clients (``k`` is clamped to ``[1, N - 1]``).

    The scores are then split into a "full" group (the first
    ``num_full_clients`` entries) and a "partial" group (the rest). Each group
    is normalised to sum to 1 and scaled by that group's share of the two
    group-mean scores.

    Args:
        local_state_dicts: List of state_dicts, one per client, in client
            order. Only tensor values are used.
        k: Number of nearest neighbours used for the density estimate.
        eps: Small constant added to the bandwidth, the scores and the
            group-share denominator.
        num_full_clients: Size of the leading "full" group. Defaults to the
            module-level ``NUM_FULL_CLIENTS``.

    Returns:
        A float64 array of length ``N`` with one weight per client. With a
        single client the result is ``[1.0]``; when either group is empty the
        scores are simply normalised to sum to 1.

    Raises:
        ValueError: If ``local_state_dicts`` is empty.
    """
    if num_full_clients is None:
        num_full_clients = NUM_FULL_CLIENTS

    N = len(local_state_dicts)
    if N == 0:
        raise ValueError("compute_kde_weights needs at least one client state_dict.")
    if N == 1:
        # No neighbours exist, so a density estimate is undefined; the single
        # client simply receives the whole weight.
        return np.ones(1, dtype=np.float64)

    flat = []
    for sd in local_state_dicts:
        vec = [t.detach().cpu().reshape(-1) for t in sd.values() if torch.is_tensor(t)]
        flat.append(torch.cat(vec).numpy())
    flat = np.asarray(flat)

    dist_matrix = squareform(pdist(flat))
    h = float(np.median(dist_matrix) + eps)

    kde_vals = np.zeros(N, dtype=np.float64)

    k_eff = int(max(1, min(k, N - 1)))
    for i in range(N):
        idx = np.argsort(dist_matrix[i])[1 : k_eff + 1]  # skip the zero self-distance
        d = dist_matrix[i, idx]
        kde_vals[i] = np.mean(np.exp(-d * d / (2.0 * h * h)))

    kde_vals += eps

    # Two-stage weighting: full group vs partial group.
    full = kde_vals[:num_full_clients]
    part = kde_vals[num_full_clients:]

    if full.size == 0 or part.size == 0:
        # With one group empty the two-stage split would divide 0 by 0;
        # fall back to a plain normalisation of the scores.
        return kde_vals / kde_vals.sum()

    full_norm = full / full.sum()
    part_norm = part / part.sum()

    full_avg = float(np.mean(full))
    part_avg = float(np.mean(part))

    part_weight = part_avg / (part_avg + full_avg + eps)
    full_weight = full_avg / (part_avg + full_avg + eps)

    full_norm = full_norm * full_weight
    part_norm = part_norm * part_weight

    return np.concatenate([full_norm, part_norm], axis=0)


# =========================
# 6) Main pipeline
# =========================
def main(k_value):
    """Run one FedKDE experiment and return the per-round test accuracies.

    Steps: load the ARFF file at ``ARFF_PATH``, make a stratified 80/20
    train/test split, fit the preprocessor on the training part, partition
    the training indices into shards and clients, then for ``ROUNDS`` rounds
    train every client from the current global parameters and aggregate the
    client parameters with the weights from ``compute_kde_weights``.

    Args:
        k_value: Neighbour count forwarded to ``compute_kde_weights``.

    Returns:
        List with the global model's test accuracy after each round.
    """
    # Re-seed every RNG so that successive calls (one per k) start from the
    # same state: identical shards, client samples and initial weights. This
    # keeps the comparison across k values from being confounded by RNG drift.
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    df = load_arff_to_dataframe(ARFF_PATH)
    target_col = infer_target_col(df)

    X_raw, y_raw, preprocessor = build_preprocessor(df, target_col)

    X_train_raw, X_test_raw, y_train_raw, y_test_raw = train_test_split(
        X_raw, y_raw, test_size=0.2, random_state=SEED, stratify=y_raw
    )

    # Fit on the training split only, then apply the same transform to test.
    X_train = preprocessor.fit_transform(X_train_raw)
    X_test = preprocessor.transform(X_test_raw)

    # The transformer may return a sparse matrix; densify for torch.tensor.
    if hasattr(X_train, "toarray"):
        X_train = X_train.toarray()
    if hasattr(X_test, "toarray"):
        X_test = X_test.toarray()

    classes = sorted(y_train_raw.unique().tolist())
    class_to_id = {c: i for i, c in enumerate(classes)}
    y_train = y_train_raw.map(class_to_id).to_numpy(dtype=np.int64)
    y_test = y_test_raw.map(class_to_id).to_numpy(dtype=np.int64)

    in_dim = X_train.shape[1]
    num_classes = len(classes)

    test_ds = TensorDataset(
        torch.tensor(X_test, dtype=torch.float32),
        torch.tensor(y_test, dtype=torch.long),
    )
    test_loader = DataLoader(test_ds, batch_size=1024, shuffle=False)

    # Auto client_train_size: same size for every client, clamped to [64, 512].
    train_n = int(len(X_train))
    client_train_size = int(max(64, min(512, max(1, train_n // NUM_CLIENTS))))

    train_indices = np.arange(train_n)
    shards = make_shards(train_indices, y_train, num_shards=NUM_SHARDS)

    clients = build_clients_from_shards(
        shards=shards,
        client_train_size=client_train_size,
        num_clients=NUM_CLIENTS,
        num_full_clients=NUM_FULL_CLIENTS,
        full_mix_shards=5,
    )

    # Client data never changes between rounds, so build the tensors once.
    client_datasets = {
        cid: TensorDataset(
            torch.tensor(X_train[idx], dtype=torch.float32),
            torch.tensor(y_train[idx], dtype=torch.long),
        )
        for cid, idx in clients.items()
    }

    global_model = MLP(in_dim, num_classes).to(DEVICE)
    global_params = get_model_params(global_model)

    global_accs = []

    all_cids = list(range(NUM_CLIENTS))

    for rnd in range(1, ROUNDS + 1):
        client_params_list = []
        local_state_dicts = []

        for cid in all_cids:
            train_loader = DataLoader(client_datasets[cid], batch_size=BATCH_SIZE, shuffle=True)

            # Every client starts the round from the current global parameters.
            local_model = MLP(in_dim, num_classes).to(DEVICE)
            set_model_params(local_model, global_params)

            train_one_client(local_model, train_loader, epochs=LOCAL_EPOCHS, lr=LR)

            local_state_dicts.append(local_model.state_dict())
            client_params_list.append(get_model_params(local_model))

        # Aggregate with KDE-derived weights (client order must match all_cids).
        kde_weights = compute_kde_weights(local_state_dicts, k_value)
        global_params = fedavg(client_params_list, kde_weights)
        set_model_params(global_model, global_params)

        acc = evaluate(global_model, test_loader)
        global_accs.append(acc)
        print(f"[Round {rnd:02d}] Global Test Acc = {acc:.4f}")

    print(f"Final Test Acc = {global_accs[-1]:.4f}")
    return global_accs


if __name__ == "__main__":
    # Script entry point: run main() once per k and plot accuracy per round.
    plt.figure(figsize=(6, 4), facecolor="white")
    rcParams["font.family"] = "Times New Roman"

    ax = plt.gca()
    ax.set_facecolor((230 / 255, 230 / 255, 238 / 255))
    for s in ["bottom", "top", "right", "left"]:
        ax.spines[s].set_color("white")

    colors = [
        (128 / 255, 149 / 255, 192 / 255),
        (215 / 255, 164 / 255, 133 / 255),
        (141 / 255, 185 / 255, 149 / 255),
        (197 / 255, 128 / 255, 131 / 255),
        (158 / 255, 149 / 255, 192 / 255),
        (171 / 255, 155 / 255, 140 / 255),
    ]
    linestyles = ["-", "--", ":", "-.", (0, (3, 1, 1, 1)), (0, (8, 2))]

    k_values = [5, 8, 10, 13, 15, 20]
    for idx, k_value in enumerate(k_values):
        print(f"\n===== k = {k_value} =====")
        global_accs = main(k_value)

        # Rounds are reported 1-based, so plot them at x = 1..len(global_accs).
        rounds = range(1, len(global_accs) + 1)
        plt.plot(
            rounds,
            global_accs,
            label=f"k={k_value}",
            # Cycle through the styles so extra k values do not raise IndexError.
            color=colors[idx % len(colors)],
            linestyle=linestyles[idx % len(linestyles)],
        )

        # Average over round 16 onwards; skipped when fewer than 16 rounds ran.
        tail = global_accs[15:]
        if tail:
            print(f"k={k_value}, avg acc (round 16~end) = {sum(tail) / len(tail):.4f}")

    plt.xlabel("Round")
    plt.ylabel("Accuracy")
    plt.title("vote")
    plt.grid(True, color="white", zorder=0)
    legend = plt.legend()
    legend.get_frame().set_facecolor((230 / 255, 230 / 255, 238 / 255))
    plt.show()
