import os
import re
import random
import tempfile
from pathlib import Path

import numpy as np
import pandas as pd
from scipy.io import arff

from sklearn.model_selection import train_test_split
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.impute import SimpleImputer

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

import matplotlib.pyplot as plt
from matplotlib import rcParams

from scipy.spatial.distance import pdist, squareform

# Absolute path to the autos ARFF dataset (machine-specific; edit before running).
ARFF_PATH = "/Users/xinyun/Desktop/test code/synthetic/autos/autos.arff"

# Seed the Python, NumPy and PyTorch random number generators with one value.
# This runs once at import time; nothing below re-seeds them.
SEED = 43
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use CUDA when it is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Federated layout.
NUM_CLIENTS = 60  # clients trained in every round
NUM_SHARDS = 30  # shards the training indices are split into
NUM_FULL_CLIENTS = 10  # leading client ids that sample from shards 0, 1, 3, 4 and 5
# NOTE(review): not referenced in the visible code; equals
# NUM_CLIENTS - NUM_FULL_CLIENTS -- confirm it is still needed.
NUM_PARTIAL_CLIENTS = 50

# Training schedule.
ROUNDS = 30  # communication rounds
LOCAL_EPOCHS = 2  # epochs each client trains per round
BATCH_SIZE = 64  # mini-batch size of the client loaders
LR = 1e-3  # learning rate passed to the client optimizer

# Number of sample indices drawn for every client.
CLIENT_TRAIN_SIZE = 200

# Column names tried in order as the prediction target; the last column of
# the frame is used when none of them is present.
PREFERRED_TARGETS = ["price"]

# Number of quantile bins requested when turning the target into classes.
NUM_BINS = 3

# True: aggregate with the weights from compute_kde_weights.
# False: aggregate with the per-client sample counts.
USE_FEDKDE = True


def sanitize_arff_nominal_values(src_path: str) -> str:
    """Write a cleaned copy of an ARFF file and return the copy's path.

    Every ``{...}`` group in the file is split on commas; each entry is
    stripped of surrounding whitespace and, when it is wrapped in a matching
    pair of single or double quotes, of those quotes and any whitespace
    inside them. The entries are then re-joined with bare commas.

    Args:
        src_path: Path of the ARFF file to read.

    Returns:
        Path of a newly created temporary ``.arff`` file holding the cleaned
        text. The caller is responsible for deleting it.

    Raises:
        FileNotFoundError: If ``src_path`` is not an existing file.
    """
    if not os.path.isfile(src_path):
        raise FileNotFoundError(
            f"Dataset file not found:\n{src_path}\n"
            f"Please set ARFF_PATH to the correct autos.arff path."
        )

    def _unquote(token):
        # Trim the entry, then drop one matching pair of enclosing quotes.
        token = token.strip()
        quoted = len(token) >= 2 and token[0] == token[-1] and token[0] in ("'", '"')
        return token[1:-1].strip() if quoted else token

    def _rewrite_group(match):
        entries = match.group(1).split(",")
        return "{" + ",".join(_unquote(entry) for entry in entries) + "}"

    raw_text = Path(src_path).read_text(encoding="utf-8", errors="ignore")
    cleaned_text = re.sub(r"\{([^}]*)\}", _rewrite_group, raw_text)

    # mkstemp returns an open descriptor; close it and write through pathlib.
    handle, out_path = tempfile.mkstemp(suffix=".arff", prefix="autos_clean_")
    os.close(handle)
    Path(out_path).write_text(cleaned_text, encoding="utf-8")
    return out_path


def load_arff_to_dataframe(arff_path: str) -> pd.DataFrame:
    """Load an ARFF file into a DataFrame with text columns decoded to ``str``.

    The file is first passed through ``sanitize_arff_nominal_values``, which
    writes a cleaned temporary copy; that copy is parsed and then deleted so
    repeated calls do not accumulate temporary files.

    Args:
        arff_path: Path of the ARFF file to load.

    Returns:
        A DataFrame built from the parsed records. In ``object`` columns,
        ``bytes``/``bytearray`` values are decoded as UTF-8 (undecodable
        bytes ignored) and stripped; other values are left untouched.

    Raises:
        FileNotFoundError: Propagated from the sanitizer when ``arff_path``
            is not an existing file.
    """
    clean_path = sanitize_arff_nominal_values(arff_path)
    try:
        data, _meta = arff.loadarff(clean_path)
    finally:
        # The sanitized copy is only needed while parsing. Cleanup is
        # best-effort: failing to delete it must not mask a parse error
        # or discard a successfully loaded dataset.
        try:
            os.remove(clean_path)
        except OSError:
            pass

    df = pd.DataFrame(data)

    def _decode(value):
        if isinstance(value, (bytes, bytearray)):
            return value.decode("utf-8", errors="ignore").strip()
        return value

    for col in df.columns:
        if df[col].dtype == object:
            df[col] = df[col].apply(_decode)
    return df


def pick_target_col(df: pd.DataFrame) -> str:
    """Return the name of the column to use as the prediction target.

    The first entry of ``PREFERRED_TARGETS`` that is a column of ``df`` wins;
    when none is present, the last column of ``df`` is returned.
    """
    candidates = (name for name in PREFERRED_TARGETS if name in df.columns)
    return next(candidates, df.columns[-1])


def bin_continuous_to_classes(y: pd.Series, num_bins: int = 3):
    """Convert a continuous target into quantile-bin class codes.

    Values are converted to stripped strings and parsed as numbers; entries
    that fail to parse are excluded from binning.

    Args:
        y: Raw target values.
        num_bins: Number of quantile bins requested. Duplicate bin edges are
            dropped, so fewer bins may result.

    Returns:
        A tuple ``(codes, valid, n_bins)``: integer bin codes for the rows
        that parsed (indexed like those rows), a boolean mask over ``y``
        marking the rows that parsed, and the number of distinct codes.
    """
    numeric = pd.to_numeric(y.astype(str).str.strip(), errors="coerce")
    valid = numeric.notna()

    codes = pd.qcut(
        numeric[valid], q=num_bins, labels=False, duplicates="drop"
    ).astype(int)

    return codes, valid, int(codes.nunique())


def merge_rare_classes(y: pd.Series, min_count: int = 2, rare_label: str = "__RARE__"):
    """Collapse classes with fewer than ``min_count`` samples into one label.

    Labels are compared as stripped strings.

    Args:
        y: Class labels.
        min_count: Classes with fewer occurrences than this are merged.
        rare_label: Replacement label for the merged classes.

    Returns:
        A tuple ``(labels, rare_classes, counts)``: the (possibly relabelled)
        string labels, the list of class names that were merged, and the
        value counts computed before merging.
    """
    labels = y.astype(str).str.strip()
    counts = labels.value_counts()
    rare_classes = [name for name, n in counts.items() if n < min_count]

    if not rare_classes:
        return labels, rare_classes, counts

    merged = labels.mask(labels.isin(rare_classes), rare_label)
    return merged, rare_classes, counts


def build_preprocessor(df: pd.DataFrame, target_col: str):
    """Split features from the target and build an unfitted preprocessor.

    Columns are routed by dtype: object, string and categorical columns go
    through most-frequent imputation and one-hot encoding (unknown categories
    ignored at transform time); every other column goes through median
    imputation and standard scaling.

    Args:
        df: Full dataset including the target column.
        target_col: Name of the target column.

    Returns:
        A tuple ``(X, y, preprocessor)``: the feature frame, the target
        series, and an unfitted ``ColumnTransformer``.
    """
    X = df.drop(columns=[target_col])
    y = df[target_col]

    def _is_categorical(dtype) -> bool:
        # A bare ``dtype == object`` check misses pandas' dedicated string
        # dtype and categorical columns, which would then be handed to the
        # median imputer and scaler.
        return (
            pd.api.types.is_object_dtype(dtype)
            or pd.api.types.is_string_dtype(dtype)
            or isinstance(dtype, pd.CategoricalDtype)
        )

    cat_cols = [c for c in X.columns if _is_categorical(X[c].dtype)]
    cat_lookup = set(cat_cols)
    num_cols = [c for c in X.columns if c not in cat_lookup]

    numeric_pipe = Pipeline(steps=[
        ("imputer", SimpleImputer(strategy="median")),
        ("scaler", StandardScaler()),
    ])

    categorical_pipe = Pipeline(steps=[
        ("imputer", SimpleImputer(strategy="most_frequent")),
        ("onehot", OneHotEncoder(handle_unknown="ignore")),
    ])

    preprocessor = ColumnTransformer(
        transformers=[
            ("num", numeric_pipe, num_cols),
            ("cat", categorical_pipe, cat_cols),
        ]
    )
    return X, y, preprocessor


class MLP(nn.Module):
    """Fully connected classifier with two hidden layers (128 and 64 units).

    Each hidden layer is followed by ReLU and dropout with p=0.1; the final
    linear layer outputs one logit per class.
    """

    def __init__(self, in_dim: int, num_classes: int):
        super().__init__()
        stages = []
        width = in_dim
        # Layers are created in forward order so parameter names and
        # initialisation order match a hand-written Sequential.
        for hidden in (128, 64):
            stages.extend([nn.Linear(width, hidden), nn.ReLU(), nn.Dropout(0.1)])
            width = hidden
        stages.append(nn.Linear(width, num_classes))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return the class logits for a batch of feature vectors."""
        return self.net(x)


def get_model_params(model: nn.Module):
    """Return detached CPU copies of the model's parameters, in order.

    The returned tensors are clones, so modifying them does not affect the
    model.
    """
    snapshot = []
    for tensor in model.parameters():
        snapshot.append(tensor.detach().cpu().clone())
    return snapshot


def set_model_params(model: nn.Module, params):
    """Overwrite the model's parameters in place with the given tensors.

    Args:
        model: Model whose parameters are replaced.
        params: Iterable of tensors, one per model parameter and in the same
            order as ``model.parameters()``. Each tensor is moved to the
            device of the parameter it replaces.

    Raises:
        ValueError: If the number of tensors differs from the number of
            model parameters. The model is left untouched in that case,
            instead of being silently updated only in part.
    """
    targets = list(model.parameters())
    sources = list(params)
    if len(targets) != len(sources):
        raise ValueError(
            f"expected {len(targets)} parameter tensors, got {len(sources)}"
        )

    with torch.no_grad():
        for target, source in zip(targets, sources):
            target.copy_(source.to(target.device))


def fedavg(list_of_params, weights):
    """Return the weighted average of several clients' parameter lists.

    Args:
        list_of_params: One list of tensors per client; all lists must have
            the same length and matching tensor shapes position by position.
        weights: One non-normalised weight per client. The weights are
            divided by their sum before averaging.

    Returns:
        A list of tensors, one per parameter position, each the weighted sum
        of the clients' tensors at that position.

    Raises:
        ValueError: If no client is given, if the number of weights differs
            from the number of clients, if clients disagree on the number of
            tensors, or if the weights do not sum to a positive value.
    """
    client_params = list(list_of_params)
    raw_weights = [float(w) for w in weights]

    if not client_params:
        raise ValueError("fedavg requires at least one client parameter list")
    if len(raw_weights) != len(client_params):
        # Surplus weights would otherwise inflate the normalisation constant
        # without contributing any tensor.
        raise ValueError(
            f"got {len(raw_weights)} weights for {len(client_params)} clients"
        )

    num_tensors = len(client_params[0])
    if any(len(tensors) != num_tensors for tensors in client_params):
        raise ValueError("all clients must provide the same number of tensors")

    total = sum(raw_weights)
    if not total > 0:
        raise ValueError("aggregation weights must sum to a positive value")
    coeffs = [w / total for w in raw_weights]

    return [
        torch.stack(
            [tensors[pos] * coeff for tensors, coeff in zip(client_params, coeffs)],
            dim=0,
        ).sum(dim=0)
        for pos in range(num_tensors)
    ]


@torch.no_grad()
def evaluate_acc(model, loader):
    """Return the fraction of samples in ``loader`` the model classifies correctly.

    The model is switched to eval mode and left in it. Batches are moved to
    ``DEVICE`` before the forward pass. An empty loader yields 0.0.
    """
    model.eval()
    hits = 0
    seen = 0
    for features, labels in loader:
        features = features.to(DEVICE)
        labels = labels.to(DEVICE)
        predicted = model(features).argmax(dim=1)
        hits += (predicted == labels).sum().item()
        seen += labels.numel()
    # Guard the denominator so an empty loader does not divide by zero.
    return hits / max(seen, 1)


def train_one_client(model, train_loader, epochs=1, lr=1e-3):
    """Train ``model`` in place on one client's data.

    Uses a fresh Adam optimizer and cross-entropy loss. The model is switched
    to train mode and every batch is moved to ``DEVICE``.

    Args:
        model: Model to update in place.
        train_loader: Iterable of ``(features, labels)`` batches.
        epochs: Number of passes over ``train_loader``.
        lr: Learning rate for Adam.
    """
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for _epoch in range(epochs):
        for features, labels in train_loader:
            features = features.to(DEVICE)
            labels = labels.to(DEVICE)

            optimizer.zero_grad()
            batch_loss = criterion(model(features), labels)
            batch_loss.backward()
            optimizer.step()


def make_shards(indices, y, num_shards=30):
    """Split sample indices into shards with per-class balancing.

    For every class, that class's indices are shuffled and divided as evenly
    as possible over the shards; each shard is shuffled once more at the end.
    Shuffling uses the global NumPy random state.

    Args:
        indices: NumPy array of sample indices into ``y`` (it is indexed
            with a boolean mask, so a plain list does not work).
        y: Class label for every sample.
        num_shards: Number of shards to produce.

    Returns:
        A list of ``num_shards`` lists of integer indices.
    """
    labels = np.asarray(y)
    shards = [[] for _ in range(num_shards)]

    for label in np.unique(labels):
        members = indices[labels[indices] == label].copy()
        np.random.shuffle(members)
        for shard, chunk in zip(shards, np.array_split(members, num_shards)):
            shard.extend(chunk.tolist())

    for shard in shards:
        np.random.shuffle(shard)
    return shards


def sample_fixed_size(pool_indices: np.ndarray, size: int):
    """Draw exactly ``size`` indices from ``pool_indices``.

    Sampling is without replacement when the pool holds at least ``size``
    entries and with replacement otherwise. Uses the global NumPy random
    state.

    Raises:
        ValueError: If the pool is empty (the message is Chinese for
            "sampling pool is empty").
    """
    pool = np.asarray(pool_indices, dtype=int)
    if len(pool) == 0:
        raise ValueError("采样池为空")

    # Only fall back to replacement when the pool is too small.
    needs_replacement = len(pool) < size
    return np.random.choice(pool, size=size, replace=needs_replacement)


def build_clients_from_shards(shards, client_train_size, num_clients=60, num_full_clients=10):
    """Assign a fixed-size sample of training indices to every client.

    Client ids below ``num_full_clients`` sample from the union of shards
    0, 1, 3, 4 and 5 (those that exist, else shard 0). The remaining ids
    sample from shard 2 alone (shard 0 when there is no shard 2).

    Args:
        shards: List of index lists.
        client_train_size: Number of indices drawn for each client.
        num_clients: Total number of clients.
        num_full_clients: Number of leading client ids using the multi-shard
            pool; must not exceed ``num_clients``.

    Returns:
        A dict mapping client id to an array of ``client_train_size`` indices.
    """
    assert num_full_clients <= num_clients

    shard_count = len(shards)

    def _draw(shard_ids):
        # Pool all indices of the chosen shards, then sample a fixed amount.
        pool = np.array(
            [sample for shard_id in shard_ids for sample in shards[shard_id]],
            dtype=int,
        )
        return sample_fixed_size(pool, client_train_size)

    assignment = {}

    for client_id in range(num_full_clients):
        multi = [s for s in (0, 1, 3, 4, 5) if s < shard_count] or [0]
        assignment[client_id] = _draw(multi)

    for client_id in range(num_full_clients, num_clients):
        single = [2] if shard_count > 2 else [0]
        assignment[client_id] = _draw(single)

    return assignment


def compute_kde_weights(local_state_dicts, k, num_full_clients=None):
    """Compute aggregation weights from a k-nearest-neighbour kernel density.

    Every client's state dict is flattened into one vector. For each client
    the mean Gaussian kernel value over its ``k`` nearest other clients is
    taken, with the bandwidth set to the median of the full pairwise distance
    matrix. The first ``num_full_clients`` clients and the remaining clients
    are normalised separately, and each group is scaled by its share of the
    two groups' mean density, so the returned weights sum to 1.

    Args:
        local_state_dicts: One mapping of tensors per client, ordered with
            the "full" clients first.
        k: Number of nearest neighbours per client (at least 1). Values
            larger than the number of other clients use all of them.
        num_full_clients: Size of the leading "full" group. Defaults to the
            module-level ``NUM_FULL_CLIENTS``. When the value leaves either
            group empty, all clients are normalised as a single group
            instead of producing NaN weights.

    Returns:
        A list of floats, one weight per client, in input order.

    Raises:
        ValueError: If no state dicts are given or ``k`` is below 1.
    """
    if num_full_clients is None:
        num_full_clients = NUM_FULL_CLIENTS

    n_clients = len(local_state_dicts)
    if n_clients == 0:
        raise ValueError("compute_kde_weights requires at least one client")
    if k < 1:
        raise ValueError("k must be at least 1")
    if n_clients == 1:
        # A single client has no neighbours to estimate a density from.
        return [1.0]

    flat = np.array([
        torch.cat([p.flatten() for p in sd.values()]).detach().cpu().numpy()
        for sd in local_state_dicts
    ])

    dist_matrix = squareform(pdist(flat))
    # Small epsilon keeps the bandwidth non-zero when all clients coincide.
    h = np.median(dist_matrix) + 1e-12

    # Column 0 of every sorted row is the zero self-distance; skip it.
    nearest = np.sort(dist_matrix, axis=1)[:, 1:k + 1]
    kde_vals = np.exp(-(nearest * nearest) / (2 * h * h)).mean(axis=1) + 1e-12

    if not 0 < num_full_clients < n_clients:
        return (kde_vals / kde_vals.sum()).tolist()

    full_vals = kde_vals[:num_full_clients]
    part_vals = kde_vals[num_full_clients:]

    full_avg = full_vals.mean()
    part_avg = part_vals.mean()
    avg_total = full_avg + part_avg

    full_share = (full_vals / full_vals.sum()) * (full_avg / avg_total)
    part_share = (part_vals / part_vals.sum()) * (part_avg / avg_total)

    return np.concatenate([full_share, part_share], axis=0).tolist()


def main(k_value):
    """Run one full federated training experiment and return its accuracy curve.

    Loads the dataset at ARFF_PATH, bins the target into classes, splits the
    rows into train and test sets, distributes training indices over
    NUM_CLIENTS clients and trains for ROUNDS rounds. In every round all
    clients start from the current global parameters and their results are
    averaged with `fedavg`.

    Args:
        k_value: Neighbour count passed to `compute_kde_weights`; only used
            when USE_FEDKDE is true.

    Returns:
        List with the global model's test accuracy after each round.

    Raises:
        ValueError: If binning the target produced fewer than two classes.
    """
    df = load_arff_to_dataframe(ARFF_PATH)
    target_col = pick_target_col(df)

    # The preprocessor returned here is unfitted; it is fitted on the
    # training split further down.
    X_raw, y_price, preprocessor = build_preprocessor(df, target_col)

    # `mask` marks the rows whose target could be parsed as a number.
    y_bins, mask, actual_bins = bin_continuous_to_classes(y_price, num_bins=NUM_BINS)

    # Keep only the rows with a usable target and renumber them from 0 so
    # features and labels line up by position.
    X_raw = X_raw.loc[mask].reset_index(drop=True)
    y_raw = pd.Series(y_bins.values, index=X_raw.index).astype(int).astype(str)

    # Abort when quantile binning collapsed to fewer than two classes.
    if actual_bins < 2:
        raise ValueError("actual_bins < 2")

    # Classes with fewer than 2 samples are merged into "__RARE__"
    # (presumably so the stratified split below has enough members per
    # class -- confirm). The other two return values are not used here.
    y_raw2, rare_classes, vc = merge_rare_classes(y_raw, min_count=2, rare_label="__RARE__")

    # 80/20 stratified split with a fixed random_state.
    X_train_raw, X_test_raw, y_train_raw, y_test_raw = train_test_split(
        X_raw, y_raw2, test_size=0.2, random_state=SEED, stratify=y_raw2
    )

    # Fit on the training split only; the test split is just transformed.
    X_train = preprocessor.fit_transform(X_train_raw)
    X_test = preprocessor.transform(X_test_raw)

    # Convert to dense arrays when the transformer returned a sparse matrix.
    if hasattr(X_train, "toarray"):
        X_train = X_train.toarray()
    if hasattr(X_test, "toarray"):
        X_test = X_test.toarray()

    # Class ids are assigned from the sorted labels seen in the training split.
    classes = sorted(pd.Series(y_train_raw).unique().tolist())
    class_to_id = {c: i for i, c in enumerate(classes)}

    y_train = pd.Series(y_train_raw).map(class_to_id).to_numpy(dtype=np.int64)
    y_test = pd.Series(y_test_raw).map(class_to_id).to_numpy(dtype=np.int64)

    in_dim = X_train.shape[1]
    num_classes = len(classes)

    test_ds = TensorDataset(torch.tensor(X_test, dtype=torch.float32), torch.tensor(y_test, dtype=torch.long))
    test_loader = DataLoader(test_ds, batch_size=1024, shuffle=False)

    # Partition the training indices into shards and give each client a
    # fixed-size sample drawn from its shards.
    train_indices = np.arange(len(X_train))
    shards = make_shards(train_indices, y_train, num_shards=NUM_SHARDS)

    clients = build_clients_from_shards(
        shards=shards,
        client_train_size=CLIENT_TRAIN_SIZE,
        num_clients=NUM_CLIENTS,
        num_full_clients=NUM_FULL_CLIENTS
    )

    global_model = MLP(in_dim, num_classes).to(DEVICE)
    global_params = get_model_params(global_model)

    global_accs = []

    for rnd in range(1, ROUNDS + 1):
        # Every client takes part in every round.
        all_cids = list(range(NUM_CLIENTS))

        client_params_list = []
        client_weights = []
        local_state_dicts = []

        for cid in all_cids:
            idx = clients[cid]
            xb = torch.tensor(X_train[idx], dtype=torch.float32)
            yb = torch.tensor(y_train[idx], dtype=torch.long)
            train_loader = DataLoader(TensorDataset(xb, yb), batch_size=BATCH_SIZE, shuffle=True)

            # Each client trains a fresh model initialised with the current
            # global parameters.
            local_model = MLP(in_dim, num_classes).to(DEVICE)
            set_model_params(local_model, global_params)

            train_one_client(local_model, train_loader, epochs=LOCAL_EPOCHS, lr=LR)

            client_params_list.append(get_model_params(local_model))
            client_weights.append(len(idx))
            local_state_dicts.append(local_model.state_dict())

        # Aggregation weights: KDE-based, or the clients' sample counts.
        if USE_FEDKDE:
            weights = compute_kde_weights(local_state_dicts, k_value)
        else:
            weights = client_weights

        global_params = fedavg(client_params_list, weights)
        set_model_params(global_model, global_params)

        acc = evaluate_acc(global_model, test_loader)
        global_accs.append(acc)
        print(f"[Round {rnd:02d}] Global Test Acc = {acc:.4f}")

    print(f"Final Test Acc = {global_accs[-1]:.4f}")
    return global_accs


if __name__ == "__main__":
    # Train once per neighbour count k and draw all accuracy curves on one
    # styled figure.
    panel_rgb = (230/255, 230/255, 238/255)

    plt.figure(figsize=(6, 4), facecolor='white')
    rcParams['font.family'] = 'Times New Roman'

    axes = plt.gca()
    axes.set_facecolor(panel_rgb)
    for edge in ('bottom', 'top', 'right', 'left'):
        axes.spines[edge].set_color('white')

    # One (k, line colour, line style) entry per experiment.
    curve_styles = [
        (5, (128/255, 149/255, 192/255), '-'),
        (8, (215/255, 164/255, 133/255), '--'),
        (10, (141/255, 185/255, 149/255), ':'),
        (13, (197/255, 128/255, 131/255), '-.'),
        (15, (158/255, 149/255, 192/255), (0, (3, 1, 1, 1))),
        (20, (171/255, 155/255, 140/255), (0, (8, 2))),
    ]

    for k_value, curve_rgb, dash in curve_styles:
        accuracy_curve = main(k_value)
        plt.plot(accuracy_curve, label=f'k={k_value}', color=curve_rgb, linestyle=dash)

    plt.xlabel("Round")
    plt.ylabel("Accuracy")
    plt.title("autos")
    plt.grid(True, color='white', zorder=0)
    legend_box = plt.legend()
    legend_box.get_frame().set_facecolor(panel_rgb)
    plt.show()
