import os
import re
import random
import tempfile
from pathlib import Path

import numpy as np
import pandas as pd
from scipy.io import arff

from sklearn.model_selection import train_test_split
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.impute import SimpleImputer

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

import matplotlib.pyplot as plt
from matplotlib import rcParams

from scipy.spatial.distance import pdist, squareform

ARFF_PATH = "./synthetic/labor/labor.arff"

# Global seed: Python, NumPy and torch (CPU + all CUDA devices) RNGs are seeded
# once here, at import time.
SEED = 43
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Client population. Ids [0, NUM_FULL_CLIENTS) are "full" clients and ids
# [NUM_FULL_CLIENTS, NUM_CLIENTS) are "partial" clients. The noisy-client picker
# and the KDE weighting both split clients at NUM_FULL_CLIENTS, so NUM_CLIENTS is
# derived from the two group sizes instead of being set independently.
NUM_FULL_CLIENTS = 10
NUM_PARTIAL_CLIENTS = 50
NUM_CLIENTS = NUM_FULL_CLIENTS + NUM_PARTIAL_CLIENTS

# Requested number of shards; the count actually used may be lower.
TARGET_NUM_SHARDS = 30

# Federated training schedule.
ROUNDS = 30
LOCAL_EPOCHS = 2
BATCH_SIZE = 64
LR = 1e-3

# Number of training samples drawn for every client.
CLIENT_TRAIN_SIZE = 10

# Feature noise: how many partial clients are noisy, and the noise strengths
# passed to add_tabular_feature_noise during their local training.
NOISY_CLIENT_NUM = 25
TAB_NOISE_STD = 0.2
TAB_SP_PROB = 0.1
TAB_FLIP_RATIO = 0.1

def noisy_clients_setup(num_partial_clients: int, noisy_client_num: int):
    """Pick which clients get noisy features.

    Only partial clients are eligible: the candidates are the ids
    ``NUM_FULL_CLIENTS`` up to ``NUM_FULL_CLIENTS + num_partial_clients - 1``.
    Full clients are never chosen. Sampling uses the global NumPy RNG, without
    replacement.

    Args:
        num_partial_clients: number of partial clients.
        noisy_client_num: how many distinct partial clients to mark as noisy.

    Returns:
        A ``set`` of the selected client ids (NumPy integer scalars).
    """
    first_partial = NUM_FULL_CLIENTS
    candidates = np.arange(first_partial, first_partial + num_partial_clients)
    picked = np.random.choice(candidates, size=noisy_client_num, replace=False)
    return set(picked)

def add_tabular_feature_noise(xb: torch.Tensor,
                              gaussian_std: float = 0.2,
                              sp_prob: float = 0.05,
                              flip_ratio: float = 0.02):
    """Return a perturbed version of a batch of preprocessed tabular features.

    The input tensor is never modified in place. The steps are applied in this
    order; a strength ``<= 0`` disables a step:

    - Gaussian: add ``N(0, gaussian_std^2)`` noise. The noise is de-meaned over
      the batch dimension (dim 0) so the batch as a whole is not shifted.
      Consequence: a batch of size 1 receives no Gaussian noise at all.
    - Salt & pepper: each entry is independently set to 0 with probability
      ``sp_prob / 2``, or to 1 with probability ``sp_prob / 2``.
    - Flip: exactly ``int(xb.numel() * flip_ratio)`` distinct entries, chosen
      uniformly, are replaced by ``1 - x``.

    Args:
        xb: float feature tensor, expected to be ``[B, D]`` (the Gaussian step
            averages over dim 0). It may be non-contiguous.
        gaussian_std: standard deviation of the additive Gaussian noise.
        sp_prob: total salt-and-pepper probability per entry.
        flip_ratio: fraction of all entries to flip.

    Returns:
        A tensor with the same shape as ``xb``. If every step is disabled (or
        the flip count rounds to 0 with the other steps off), ``xb`` itself is
        returned unchanged.
    """
    x = xb

    if gaussian_std > 0:
        eps = torch.randn_like(x) * gaussian_std
        eps = eps - eps.mean(dim=0, keepdim=True)
        x = x + eps  # out-of-place: allocates a new tensor

    if sp_prob > 0:
        r = torch.rand_like(x)
        # masked_fill is out-of-place, so the caller's tensor is untouched.
        x = x.masked_fill(r < sp_prob / 2, 0.0)
        x = x.masked_fill(r > 1 - sp_prob / 2, 1.0)

    if flip_ratio > 0:
        total = x.numel()
        m = int(total * flip_ratio)
        if m > 0:
            # reshape+clone works for non-contiguous input (view(-1) raises) and
            # guarantees we write into a private buffer. With the two steps above
            # disabled, x is still the caller's tensor, and an in-place flip
            # through a view would corrupt it.
            flat = x.reshape(-1).clone()
            idx = torch.randperm(total, device=x.device)[:m]
            flat[idx] = 1.0 - flat[idx]
            x = flat.view(x.shape)

    return x

def sanitize_arff_nominal_values(src_path: str) -> str:
    """Normalize every ``{ ... }`` value list in an ARFF file and save a copy.

    Each comma-separated entry inside braces is stripped of surrounding
    whitespace. If the entry is wrapped in one matching pair of single or
    double quotes, the quotes are removed and the entry is stripped again.

    Note that the brace pattern is applied to the whole file text, not only to
    the ``@attribute`` lines. Undecodable bytes in the source are dropped when
    reading.

    Args:
        src_path: path to the original ARFF file.

    Returns:
        Path of a new temporary ``.arff`` file holding the cleaned text. This
        function does not delete it; the caller owns the file.
    """
    raw = Path(src_path).read_text(encoding="utf-8", errors="ignore")

    def _unquote(token):
        token = token.strip()
        quoted = len(token) >= 2 and token[0] == token[-1] and token[0] in ("'", '"')
        return token[1:-1].strip() if quoted else token

    def _rewrite(match):
        entries = match.group(1).split(",")
        return "{" + ",".join(_unquote(entry) for entry in entries) + "}"

    cleaned = re.sub(r"\{([^}]*)\}", _rewrite, raw)

    handle, out_path = tempfile.mkstemp(suffix=".arff", prefix="labor_clean_")
    os.close(handle)
    Path(out_path).write_text(cleaned, encoding="utf-8")
    return out_path

def load_arff_to_dataframe(arff_path: str) -> pd.DataFrame:
    """Load an ARFF file into a DataFrame with decoded strings and NaN for missing values.

    The file is first passed through ``sanitize_arff_nominal_values``, which
    writes a cleaned temporary copy. That copy is deleted once parsing
    finishes, whether parsing succeeded or raised.

    Processing of the parsed data:
    - In ``object`` columns, ``bytes`` cells are decoded as UTF-8 (undecodable
      bytes dropped) and stripped.
    - In every column, the literal string ``"?"`` is replaced with ``np.nan``.

    Args:
        arff_path: path to the source ARFF file.

    Returns:
        The parsed data as a ``pd.DataFrame``.
    """
    clean_path = sanitize_arff_nominal_values(arff_path)
    try:
        data, _meta = arff.loadarff(clean_path)
    finally:
        # The sanitized copy comes from tempfile.mkstemp, which never deletes
        # it; without this every call leaks one temporary file.
        Path(clean_path).unlink(missing_ok=True)
    df = pd.DataFrame(data)

    def _decode(value):
        if isinstance(value, (bytes, bytearray)):
            return value.decode("utf-8", errors="ignore").strip()
        return value

    for col in df.columns:
        if df[col].dtype == object:
            df[col] = df[col].apply(_decode)
        # Applied to every column, matching the original both-branches behaviour.
        df[col] = df[col].replace("?", np.nan)
    return df

def build_preprocessor(df: pd.DataFrame, target_col: str):
    """Split ``df`` into features and target, and build an unfitted preprocessor.

    Target handling:
    - Target values are converted to stripped strings.
    - Rows whose target is missing (NaN, or the literal ``"?"``) are treated as
      unlabeled and dropped from both ``X`` and ``y``.

    Feature handling:
    - Columns with ``object`` dtype are categorical: most-frequent imputation,
      then one-hot encoding that ignores unknown categories.
    - All other columns are numeric: median imputation, then standardization.

    Args:
        df: full dataset, including the target column.
        target_col: name of the label column.

    Returns:
        ``(X, y, preprocessor)``: the feature frame, the string label Series
        (same index as ``X``), and an unfitted ``ColumnTransformer``.
    """
    raw_target = df[target_col]
    y = raw_target.astype(str).str.strip().replace("?", np.nan)
    # astype(str) maps NaN to the literal string "nan". Restore the real missing
    # marker so unlabeled rows cannot become a spurious "nan" class.
    y = y.mask(raw_target.isna())
    labeled = y.notna()
    X = df.loc[labeled].drop(columns=[target_col])
    y = y.loc[labeled]

    cat_cols = [c for c in X.columns if X[c].dtype == object]
    num_cols = [c for c in X.columns if c not in cat_cols]

    numeric_pipe = Pipeline(steps=[
        ("imputer", SimpleImputer(strategy="median")),
        ("scaler", StandardScaler()),
    ])

    categorical_pipe = Pipeline(steps=[
        ("imputer", SimpleImputer(strategy="most_frequent")),
        ("onehot", OneHotEncoder(handle_unknown="ignore")),
    ])

    preprocessor = ColumnTransformer(
        transformers=[
            ("num", numeric_pipe, num_cols),
            ("cat", categorical_pipe, cat_cols),
        ]
    )

    return X, y, preprocessor

class MLP(nn.Module):
    """Multi-layer perceptron: in_dim -> 128 -> 64 -> num_classes.

    Each hidden layer is followed by ReLU and Dropout(p=0.1). ``forward``
    returns raw logits; no softmax is applied. All layers live in ``self.net``
    (a ``nn.Sequential``), so parameter names are ``net.0.*``, ``net.3.*`` and
    ``net.6.*``.
    """

    def __init__(self, in_dim: int, num_classes: int):
        super().__init__()
        layers = []
        width_in = in_dim
        for width_out in (128, 64):
            layers += [nn.Linear(width_in, width_out), nn.ReLU(), nn.Dropout(0.1)]
            width_in = width_out
        layers.append(nn.Linear(width_in, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits for a batch ``x`` of shape ``[B, in_dim]``."""
        return self.net(x)

def get_model_params(model: nn.Module):
    """Snapshot a model's parameters as detached CPU copies.

    Returns:
        A list in ``model.parameters()`` order. Each entry is a fresh tensor
        with no autograd history and its own storage, so later updates to
        ``model`` do not affect the snapshot.
    """
    snapshot = []
    for param in model.parameters():
        snapshot.append(param.detach().cpu().clone())
    return snapshot

def set_model_params(model: nn.Module, params):
    """Overwrite ``model``'s parameters in place with the given tensors.

    Args:
        model: module whose parameters are overwritten (no autograd tracking).
        params: tensors in ``model.parameters()`` order, e.g. the output of
            ``get_model_params``. Each one is moved to the device of the
            matching parameter before copying.

    Raises:
        ValueError: if the number of tensors differs from the number of model
            parameters. Plain ``zip`` would silently skip the extras and leave
            part of the model stale.
    """
    targets = list(model.parameters())
    sources = list(params)
    if len(targets) != len(sources):
        raise ValueError(
            f"set_model_params: model has {len(targets)} parameters "
            f"but {len(sources)} tensors were given"
        )
    with torch.no_grad():
        for p, new_p in zip(targets, sources):
            p.copy_(new_p.to(p.device))

def fedavg(list_of_params, weights):
    """Aggregate client parameter lists into one weighted average (FedAvg).

    Args:
        list_of_params: one entry per client. Each entry is an indexable list
            of tensors, in the same order and with the same shapes for every
            client (e.g. the output of ``get_model_params``).
        weights: one weight per client. Weights are normalized by their sum,
            so they only need to be proportional.

    Returns:
        A list of tensors; entry ``j`` is ``sum_c (w_c / sum(w)) * list_of_params[c][j]``.
    """
    denom = float(sum(weights))
    coeffs = [w / denom for w in weights]
    n_clients = len(list_of_params)

    def _weighted_sum(j):
        terms = [list_of_params[c][j] * coeffs[c] for c in range(n_clients)]
        return torch.stack(terms, dim=0).sum(dim=0)

    return [_weighted_sum(j) for j in range(len(list_of_params[0]))]

@torch.no_grad()
def evaluate(model, loader):
    """Return the top-1 classification accuracy of ``model`` over ``loader``.

    Behaviour:
    - Switches ``model`` to eval mode and leaves it there.
    - Runs every batch on ``DEVICE`` without gradient tracking.
    - Counts a hit when the arg-max logit equals the label.

    An empty loader yields 0.0 instead of a division by zero.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    for features, labels in loader:
        features = features.to(DEVICE)
        labels = labels.to(DEVICE)
        predictions = model(features).argmax(dim=1)
        n_correct += (predictions == labels).sum().item()
        n_seen += labels.numel()
    return n_correct / max(n_seen, 1)

def train_one_client(model, train_loader, epochs=1, lr=1e-3,
                     noisy_feature=False,
                     gaussian_std=0.2, sp_prob=0.1, flip_ratio=0.1):
    """Train ``model`` in place on one client's data with Adam and cross-entropy.

    A fresh Adam optimizer with learning rate ``lr`` is created on every call,
    so no optimizer state carries over between calls.

    For each of ``epochs`` passes over ``train_loader``, every batch is:
    1. moved to ``DEVICE``;
    2. if ``noisy_feature`` is true, perturbed by ``add_tabular_feature_noise``
       with the given strengths (fresh noise for every batch and epoch);
    3. used for one forward/backward/optimizer step.

    Returns None.
    """
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    noise_kwargs = dict(gaussian_std=gaussian_std, sp_prob=sp_prob, flip_ratio=flip_ratio)

    for _epoch in range(epochs):
        for features, labels in train_loader:
            features = features.to(DEVICE)
            labels = labels.to(DEVICE)

            if noisy_feature:
                features = add_tabular_feature_noise(features, **noise_kwargs)

            optimizer.zero_grad()
            loss = criterion(model(features), labels)
            loss.backward()
            optimizer.step()

def make_shards(indices, y, num_shards=6):
    """Split ``indices`` into ``num_shards`` class-balanced shards.

    For every class in ``np.unique(y)``:
    1. the members of ``indices`` with that label are shuffled;
    2. they are split into ``num_shards`` near-equal parts (``np.array_split``);
    3. part ``s`` goes to shard ``s``.

    Each shard is then shuffled. All shuffles use the global NumPy RNG.

    Args:
        indices: NumPy integer array of row indices into ``y``. It must be an
            array because it is indexed with a boolean mask.
        y: labels for all rows (any array-like).
        num_shards: number of shards to produce.

    Returns:
        A list of ``num_shards`` lists of Python ints.
    """
    labels = np.asarray(y)
    buckets = [[] for _ in range(num_shards)]

    for label in np.unique(labels):
        members = indices[labels[indices] == label].copy()
        np.random.shuffle(members)
        for bucket, piece in zip(buckets, np.array_split(members, num_shards)):
            bucket.extend(piece.tolist())

    for bucket in buckets:
        np.random.shuffle(bucket)
    return buckets

def sample_fixed_size(pool_indices: np.ndarray, size: int):
    """Draw exactly ``size`` indices from ``pool_indices`` using the global NumPy RNG.

    Sampling is without replacement when the pool has at least ``size``
    entries. Otherwise it is with replacement, so small pools still yield
    ``size`` items (with repeats).

    Raises:
        ValueError: if the pool is empty. The message (in Chinese) says the
            sampling pool is empty and that a common cause is ``num_shards``
            being so large that a shard ends up empty.
    """
    pool = np.asarray(pool_indices, dtype=int)
    if not len(pool):
        raise ValueError("采样池为空：请检查分片/数据是否正确（常见原因：num_shards 太大导致 shard 为空）。")
    with_replacement = len(pool) < size
    return np.random.choice(pool, size=size, replace=with_replacement)

def build_clients_from_shards(shards, client_train_size, num_clients=10, num_full_clients=2,
                              is_label_skew_p=True, full_client_shards=6, partial_client_shards=4):
    """Assign a fixed-size set of training indices to every client.

    Full clients (``0 .. num_full_clients - 1``) sample from the union of the
    first ``full_client_shards`` shards.

    Partial clients (the remaining ids) sample from the union of the first
    ``partial_client_shards`` shards. If ``is_label_skew_p`` is true, each
    partial client instead samples from a single shard, chosen uniformly among
    those shards.

    Both shard counts are capped at ``len(shards)``, so having fewer shards
    than requested no longer fails with IndexError. All sampling goes through
    ``sample_fixed_size`` (global NumPy RNG).

    NOTE(review): make_shards spreads every class over all shards, so a single
    shard holds roughly the global class mix. The single-shard option may
    therefore produce little actual label skew; confirm this is intended.

    Args:
        shards: list of lists of training-row indices.
        client_train_size: number of indices drawn per client.
        num_clients: total number of clients.
        num_full_clients: number of leading "full" clients.
        is_label_skew_p: draw each partial client from one random shard.
        full_client_shards: how many leading shards full clients draw from.
        partial_client_shards: how many leading shards partial clients draw from.

    Returns:
        A dict mapping client id to a NumPy array of ``client_train_size`` indices.

    Raises:
        ValueError: if ``num_full_clients > num_clients`` or fewer than 2 shards are given.
    """
    if num_full_clients > num_clients:
        raise ValueError(
            f"num_full_clients ({num_full_clients}) must not exceed num_clients ({num_clients})"
        )
    if len(shards) < 2:
        raise ValueError("至少需要 2 个 shards（请降低 TARGET_NUM_SHARDS 或增大训练集）")

    full_ids = list(range(min(full_client_shards, len(shards))))
    partial_ids = list(range(min(partial_client_shards, len(shards))))

    def _pool(shard_ids):
        # Concatenate the indices of the given shards into one int array.
        return np.array([i for s in shard_ids for i in shards[s]], dtype=int)

    clients = {}

    full_pool = _pool(full_ids)
    for cid in range(num_full_clients):
        clients[cid] = sample_fixed_size(full_pool, client_train_size)

    partial_pool = _pool(partial_ids)
    for cid in range(num_full_clients, num_clients):
        if is_label_skew_p:
            shard = int(np.random.choice(partial_ids))
            pool = _pool([shard])
        else:
            pool = partial_pool
        clients[cid] = sample_fixed_size(pool, client_train_size)

    return clients

def compute_kde_weights(local_weights, k, num_full_clients=None):
    """Compute per-client aggregation weights from a k-nearest-neighbour Gaussian KDE.

    Scoring:
    - Each client's state_dict is flattened into one vector.
    - Client ``i`` scores the mean of ``exp(-d^2 / (2 h^2))`` over its ``k``
      nearest other clients, where ``d`` is Euclidean distance.
    - The bandwidth ``h`` is the median of the full pairwise-distance matrix,
      diagonal zeros included.

    Weighting:
    - Scores are normalized separately within the "full" group (first
      ``num_full_clients`` entries) and the "partial" group (the rest).
    - Each group is scaled by its mean score relative to the other group's.
      The full group therefore sums to ``full_avg / (full_avg + partial_avg)``
      and the partial group to the complement.
    - If one group is empty, scores are simply normalized over all clients.

    Args:
        local_weights: client state_dicts with identical keys and shapes,
            ordered by client id.
        k: number of neighbours. ``0`` returns uniform weights without
            computing any distances.
        num_full_clients: size of the leading "full" group. ``None`` means
            the module-level ``NUM_FULL_CLIENTS``.

    Returns:
        ``np.ones(N) / N`` (an ndarray) when ``k == 0``. Otherwise a list of
        ``N`` weights summing to 1, with the full-group weights first.
    """
    n_clients = len(local_weights)
    if k == 0:
        # Uniform (plain FedAvg) weights: skip the O(N^2 * P) distance matrix.
        return np.ones(n_clients) / n_clients
    if num_full_clients is None:
        num_full_clients = NUM_FULL_CLIENTS

    flat = np.array([
        torch.cat([p.flatten() for p in w.values()]).detach().cpu().numpy()
        for w in local_weights
    ])

    dist_matrix = squareform(pdist(flat))
    h = np.median(dist_matrix) + 1e-12

    kde_vals = np.zeros(n_clients, dtype=np.float64)
    for i in range(n_clients):
        # Position 0 of the sorted row is assumed to be client i itself
        # (distance 0). Exact-duplicate models could tie with it.
        idx = np.argsort(dist_matrix[i])[1:k + 1]
        d = dist_matrix[i, idx]
        kde_vals[i] = np.mean(np.exp(-d * d / (2 * h * h)))

    kde_vals += 1e-12

    full = kde_vals[:num_full_clients]
    partial = kde_vals[num_full_clients:]
    if full.size == 0 or partial.size == 0:
        # Only one group present: per-group normalisation would divide by an
        # empty sum and yield NaN weights, so normalise over everyone instead.
        return (kde_vals / kde_vals.sum()).tolist()

    partial_kde_vals = partial / partial.sum()
    full_kde_vals = full / full.sum()

    partial_avg = np.mean(partial)
    full_avg = np.mean(full)

    partial_weight = partial_avg / (partial_avg + full_avg)
    full_weight = full_avg / (partial_avg + full_avg)

    partial_kde_vals = [x * partial_weight for x in partial_kde_vals]
    full_kde_vals = [x * full_weight for x in full_kde_vals]

    return full_kde_vals + partial_kde_vals

def main(k_value, is_feature_skew=True, is_label_skew=True, seed=SEED):
    """Run one federated-learning experiment and return the global test accuracy per round.

    Steps:
      1. Re-seed Python, NumPy and torch with ``seed``, so every call starts
         from the same RNG state. Without this, consecutive calls (the
         FedAvg / LoMar / FedLC / FedRDN / FedKde comparison runs) continue the
         global RNG stream. Each run would then get different client
         partitions, noisy-client sets and model initialisations, which
         confounds the comparison.
      2. Load ``ARFF_PATH`` and split it 80/20 (stratified when possible).
         Fit the preprocessor on the training split only, and integer-encode
         labels by sorted training class.
      3. Shard the training indices and build ``NUM_CLIENTS`` clients.
      4. For each of ``ROUNDS`` rounds:
         - every client trains a copy of the global model; noisy features are
           used for the selected clients when ``is_feature_skew`` is true;
         - the global model becomes the weighted average of the client models,
           with weights from ``compute_kde_weights(..., k_value)``.

    Args:
        k_value: neighbour count for KDE weighting; ``0`` means uniform averaging.
        is_feature_skew: apply feature noise to the selected noisy clients.
        is_label_skew: forwarded to ``build_clients_from_shards`` as ``is_label_skew_p``.
        seed: RNG seed applied at the start of the run (defaults to ``SEED``).

    Returns:
        A list with one global test accuracy per round.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    df = load_arff_to_dataframe(ARFF_PATH)

    target_col = "class" if "class" in df.columns else df.columns[-1]
    X_raw, y_raw, preprocessor = build_preprocessor(df, target_col)

    try:
        X_train_raw, X_test_raw, y_train_raw, y_test_raw = train_test_split(
            X_raw, y_raw, test_size=0.2, random_state=seed, stratify=y_raw
        )
    except ValueError:
        # Fall back to an unstratified split if stratification is rejected.
        X_train_raw, X_test_raw, y_train_raw, y_test_raw = train_test_split(
            X_raw, y_raw, test_size=0.2, random_state=seed, stratify=None
        )

    X_train = preprocessor.fit_transform(X_train_raw)
    X_test = preprocessor.transform(X_test_raw)

    # The transformer may return a sparse matrix; torch tensors need dense arrays.
    if hasattr(X_train, "toarray"):
        X_train = X_train.toarray()
    if hasattr(X_test, "toarray"):
        X_test = X_test.toarray()

    y_train_raw = y_train_raw.reset_index(drop=True)
    y_test_raw = y_test_raw.reset_index(drop=True)

    classes = sorted(pd.Series(y_train_raw).dropna().unique().tolist())
    class_to_id = {c: i for i, c in enumerate(classes)}
    y_train = pd.Series(y_train_raw).map(class_to_id).to_numpy(dtype=np.int64)
    y_test = pd.Series(y_test_raw).map(class_to_id).to_numpy(dtype=np.int64)

    in_dim = X_train.shape[1]
    num_classes = len(classes)

    test_ds = TensorDataset(torch.tensor(X_test, dtype=torch.float32),
                            torch.tensor(y_test, dtype=torch.long))
    test_loader = DataLoader(test_ds, batch_size=1024, shuffle=False)

    noisy_clients = noisy_clients_setup(NUM_PARTIAL_CLIENTS, NOISY_CLIENT_NUM)

    # Cap the shard count at the smallest training class (never below 2), so
    # np.array_split in make_shards can give each shard at least one sample per class.
    train_indices = np.arange(len(X_train))
    class_counts = [np.sum(y_train == c) for c in np.unique(y_train)]
    min_class = int(np.min(class_counts)) if class_counts else 1
    num_shards = int(max(2, min(TARGET_NUM_SHARDS, min_class)))

    shards = make_shards(train_indices, y_train, num_shards=num_shards)

    clients = build_clients_from_shards(
        shards=shards,
        client_train_size=CLIENT_TRAIN_SIZE,
        num_clients=NUM_CLIENTS,
        num_full_clients=NUM_FULL_CLIENTS,
        is_label_skew_p=is_label_skew,
    )

    # Client data never changes across rounds, so build each loader once.
    # Shuffling happens per iteration, not at construction.
    client_loaders = {
        cid: DataLoader(
            TensorDataset(
                torch.tensor(X_train[idx], dtype=torch.float32),
                torch.tensor(y_train[idx], dtype=torch.long),
            ),
            batch_size=BATCH_SIZE,
            shuffle=True,
        )
        for cid, idx in clients.items()
    }

    global_model = MLP(in_dim, num_classes).to(DEVICE)
    global_params = get_model_params(global_model)

    global_accs = []
    for rnd in range(1, ROUNDS + 1):
        client_params_list = []
        local_model_state_dicts = []

        for cid in range(NUM_CLIENTS):
            # A fresh MLP per client; its random init is overwritten by
            # global_params right away. NOTE: reusing one model would save work
            # but would change the torch RNG stream, and thus the results.
            local_model = MLP(in_dim, num_classes).to(DEVICE)
            set_model_params(local_model, global_params)

            is_noisy_feature_client = is_feature_skew and (cid in noisy_clients)

            train_one_client(
                local_model,
                client_loaders[cid],
                epochs=LOCAL_EPOCHS,
                lr=LR,
                noisy_feature=is_noisy_feature_client,
                gaussian_std=TAB_NOISE_STD,
                sp_prob=TAB_SP_PROB,
                flip_ratio=TAB_FLIP_RATIO,
            )

            local_model_state_dicts.append(local_model.state_dict())
            client_params_list.append(get_model_params(local_model))

        kde_weights = compute_kde_weights(local_model_state_dicts, k_value)
        global_params = fedavg(client_params_list, kde_weights)
        set_model_params(global_model, global_params)

        acc = evaluate(global_model, test_loader)
        global_accs.append(acc)
        print(f"[Round {rnd:02d}] Global Test Acc = {acc:.4f}")

    print("\nDone.")
    if global_accs:
        print(f"Final Test Acc = {global_accs[-1]:.4f}")
    return global_accs

if __name__ == "__main__":

    k_value_FedAvg = 0
    is_feature_skew_FedAvg = True
    is_label_skew_FedAvg = True

    k_value_LoMar = 60
    is_feature_skew_LoMar = True
    is_label_skew_LoMar = True

    k_value_FedLC = 0
    is_feature_skew_FedLC = True
    is_label_skew_FedLC = False

    k_value_FedRDN = 0
    is_feature_skew_FedRDN = False
    is_label_skew_FedRDN = True

    k_value_FedKde = 8
    is_feature_skew_FedKde = True
    is_label_skew_FedKde = True

    print("FedAvg:")
    acc_FedAvg = main(k_value_FedAvg, is_feature_skew_FedAvg, is_label_skew_FedAvg)
    print("LoMar:")
    acc_LoMar = main(k_value_LoMar, is_feature_skew_LoMar, is_label_skew_LoMar)
    print("FedLC:")
    acc_FedLC = main(k_value_FedLC, is_feature_skew_FedLC, is_label_skew_FedLC)
    print("FedRDN:")
    acc_FedRDN = main(k_value_FedRDN, is_feature_skew_FedRDN, is_label_skew_FedRDN)
    print("FedKde:")
    acc_FedKde = main(k_value_FedKde, is_feature_skew_FedKde, is_label_skew_FedKde)

    index = 20
    mean_FedAvg = np.mean(acc_FedAvg[-index:])
    mean_LoMar = np.mean(acc_LoMar[-index:])
    mean_FedLC = np.mean(acc_FedLC[-index:])
    mean_FedRDN = np.mean(acc_FedRDN[-index:])
    mean_FedKde = np.mean(acc_FedKde[-index:])
    print(f"mean_FedAvg: {mean_FedAvg}")
    print(f"mean_LoMar: {mean_LoMar}")
    print(f"mean_FedLC: {mean_FedLC}")
    print(f"mean_FedRDN: {mean_FedRDN}")
    print(f"mean_FedKde: {mean_FedKde}")

    plt.figure(figsize=(6, 4), facecolor='white')
    rcParams['font.family'] = 'Times New Roman'

    ax = plt.gca()
    ax.set_facecolor((230/255, 230/255, 238/255))
    ax.spines['bottom'].set_color('white')
    ax.spines['top'].set_color('white')
    ax.spines['right'].set_color('white')
    ax.spines['left'].set_color('white')

    colors = [
        (128/255, 149/255, 192/255),
        (215/255, 164/255, 133/255),
        (141/255, 185/255, 149/255),
        (197/255, 128/255, 131/255),
        (158/255, 149/255, 192/255),
        (171/255, 155/255, 140/255),
    ]
    linestyles = ['-', '--', ':', '-.', (0, (3, 1, 1, 1)), (0, (8, 2))]

    plt.plot(acc_FedAvg, label='FedAvg', color=colors[0], linestyle=linestyles[0])
    plt.plot(acc_LoMar, label='LoMar', color=colors[1], linestyle=linestyles[1])
    plt.plot(acc_FedLC, label='FedLC', color=colors[2], linestyle=linestyles[2])
    plt.plot(acc_FedRDN, label='FedRDN', color=colors[3], linestyle=linest_
