import os
import random
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from scipy.spatial.distance import pdist, squareform
from matplotlib import rcParams


# Default checkpoint locations. analyze_kf_plot expects each file to contain a
# non-empty list/tuple of per-client state_dicts.
CLEAN_PTH = "./saved_models/mnist/without_noisy/round8_local_models.pth"
NOISY_PTH = "./saved_models/mnist/with_noisy/round8_local_models.pth"


# Client population layout. NUM_BINARY is used by compute_kde_weights to split
# the clients into two groups and by train_fedkde to build the first group.
TOTAL_CLIENTS = 60
NUM_FULL = 25  # clients whose main data comes from labels 2-9 (see train_fedkde)
NUM_BINARY = TOTAL_CLIENTS - NUM_FULL  # clients whose main data is labels 0 and 1
START_CLIENT = 0  # default reference client for the distance ordering in analyze_kf_plot
VALUE_MODE = "mean"  # default reduction for tensor_to_value: "mean", "l2" or "ms"

# Scalar Kalman filter settings (see kalman_filter_scalar_sequence).
KF_ALPHA = 1e-2  # added to the covariance in every prediction step
GAUSSIAN_STD = 0.02  # default std for the `noise` CLI command
KF_BETA = GAUSSIAN_STD  # `beta` in the gain K = P_minus / (P_minus + beta)
KF_P0 = 1.0  # initial covariance
LAMBDA = 0.1  # covariance is multiplied by 1 / LAMBDA at step k == TOTAL_CLIENTS - NUM_FULL
RHO = 0.1  # weight of the previous mean in the blended prediction at that same step


# Federated training settings (see train_fedkde / local_train).
NUM_CLIENTS = 60  # number of local models trained per round
NOISY_CLIENT_NUM = 25  # how many client ids are drawn to train on the noisy dataset
GAUSSION_NOISE_PARAMETER = 0.3  # std passed to AddGaussianNoise for the noisy dataset
SALTPEPPER_NOISE_PARAMETER = 0.1  # NOTE(review): not referenced by the visible code
PIXELFLIP_PARAMETER = 0.1  # NOTE(review): not referenced by the visible code
EPOCHS = 1  # local epochs per client per round
ROUNDS = 9  # communication rounds
BATCH_SIZE = 64  # batch size used by get_loader
LR = 0.01  # SGD learning rate in local_train
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SAVE_ROUND = 8  # 1-based round after which the local models are saved
SAVE_PATH_WITH_NOISY = "./saved_models/mnist/with_noisy/"
SAVE_PATH_WITHOUT_NOISY = "./saved_models/mnist/without_noisy/"  # NOTE(review): not referenced by the visible code
SEED = 60  # seed passed to set_seed at the start of train_fedkde


def set_seed(seed: int):
    """Seed every random number generator this script relies on.

    Covers Python's ``random``, NumPy's global generator, and PyTorch's CPU
    and CUDA generators, in that order.

    Args:
        seed: Value handed unchanged to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


class AddGaussianNoise:
    """Callable transform that adds Gaussian noise to a tensor.

    The noise is drawn with ``torch.randn`` (default device and dtype), scaled
    by ``std`` and shifted by ``mean``. The input tensor is not modified.
    """

    def __init__(self, mean=0.0, std=1.0):
        self.mean, self.std = mean, std

    def __call__(self, tensor):
        """Return ``tensor + N(0, 1) * std + mean`` with a fresh noise draw."""
        standard_noise = torch.randn(tensor.size())
        scaled_noise = standard_noise * self.std
        # Keep the original evaluation order: (tensor + scaled) + mean.
        return tensor + scaled_noise + self.mean


class AddSaltPepperNoise:
    """Callable transform that applies salt-and-pepper noise to a tensor.

    One uniform draw is made per element. Draws below ``prob / 2`` set the
    element to 0.0 and draws above ``1 - prob / 2`` set it to 1.0. The input
    is cloned first, so the caller's tensor is left untouched.
    """

    def __init__(self, prob=0.3):
        self.prob = prob

    def __call__(self, tensor):
        """Return a corrupted copy of ``tensor``."""
        half_prob = self.prob / 2
        draws = torch.rand(tensor.size())
        corrupted = tensor.clone()
        # Pepper first, then salt: the salt assignment wins if both masks hit.
        corrupted[draws < half_prob] = 0.0
        corrupted[draws > 1 - half_prob] = 1.0
        return corrupted


class RandomPixelFlip(object):
    """Callable transform that inverts a random subset of pixels.

    ``int(numel * ratio)`` distinct positions are chosen with
    ``torch.randperm`` and each selected value ``v`` becomes ``1 - v``.
    The result is returned with shape ``(1, 28, 28)``, so the input must
    hold exactly 784 elements.
    """

    def __init__(self, ratio=0.2):
        self.ratio = ratio

    def __call__(self, tensor):
        """Return a flipped copy of ``tensor``; the input is not modified."""
        total = tensor.numel()
        idx = torch.randperm(total)[: int(total * self.ratio)]
        # Work on a private copy: a flat view of the input would share its
        # storage, so the in-place flip below would corrupt the caller's
        # tensor. reshape (not view) also accepts non-contiguous inputs.
        flat = tensor.clone().reshape(-1)
        flat[idx] = 1 - flat[idx]
        return flat.view(1, 28, 28)


def flip_label(y, p=0.5):
    """Randomly replace a label with a uniformly drawn digit.

    Args:
        y: Original label, returned unchanged when no flip happens.
        p: Probability of replacing ``y``.

    Returns:
        ``random.randint(0, 9)`` with probability ``p`` (which may happen to
        equal ``y``), otherwise ``y``.
    """
    # The coin toss is evaluated first; randint is only drawn on a flip.
    return random.randint(0, 9) if random.random() < p else y


class CNN(nn.Module):
    """Small convolutional classifier with two conv layers and two linear layers.

    The forward pass flattens the second conv stage to 320 features, which
    matches 1x28x28 inputs (20 channels * 4 * 4), and emits 10 logits.
    """

    def __init__(self):
        super().__init__()
        # Layer names and creation order define the state_dict keys and the
        # seeded initialisation, so they are kept as-is.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return raw class logits for a batch of images."""
        stage1 = torch.max_pool2d(self.conv1(x), 2)
        stage1 = torch.relu(stage1)
        stage2 = torch.max_pool2d(self.conv2(stage1), 2)
        stage2 = torch.relu(stage2)
        features = stage2.view(-1, 320)
        hidden = torch.relu(self.fc1(features))
        return self.fc2(hidden)


def get_loader(dataset):
    """Wrap ``dataset`` in a shuffling DataLoader with batch size BATCH_SIZE."""
    loader_options = {"batch_size": BATCH_SIZE, "shuffle": True}
    return DataLoader(dataset, **loader_options)


def local_train(model, loader, noisy=False):
    """Train ``model`` in place with plain SGD and return its state_dict.

    Args:
        model: Module to train; it is switched to training mode.
        loader: Iterable of ``(inputs, targets)`` batches.
        noisy: Accepted for the caller's benefit but not used here.

    Returns:
        ``model.state_dict()`` after EPOCHS passes over ``loader``.
    """
    criterion = nn.CrossEntropyLoss()
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=LR)

    for _epoch in range(EPOCHS):
        for batch_inputs, batch_targets in loader:
            batch_inputs = batch_inputs.to(DEVICE)
            batch_targets = batch_targets.to(DEVICE)

            optimizer.zero_grad()
            batch_loss = criterion(model(batch_inputs), batch_targets)
            batch_loss.backward()
            optimizer.step()

    return model.state_dict()


def test(model, loader):
    """Return the classification accuracy of ``model`` over ``loader``.

    The model is put in eval mode and gradients are disabled. Accuracy is the
    number of argmax predictions equal to the target divided by the number of
    samples seen, so an empty loader raises ZeroDivisionError.
    """
    model.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            predicted = model(inputs).argmax(dim=1)
            hits += predicted.eq(targets).sum().item()
            seen += targets.size(0)
    return hits / seen


def compute_kde_weights(local_weights, k):
    """Compute per-client aggregation weights from a k-nearest-neighbour KDE.

    Each state_dict is flattened into one vector. For every client, a Gaussian
    kernel (bandwidth = median of the full pairwise distance matrix) is
    averaged over the k entries that follow the first one in that client's
    distance ranking. The clients are then split at NUM_BINARY: densities are
    normalised to sum to one inside each group, and each group is scaled by
    its share of the two groups' mean densities.

    Args:
        local_weights: Sequence of state_dicts, the first NUM_BINARY belonging
            to the first group and the rest to the second.
        k: Number of neighbours used for each density estimate.

    Returns:
        A list with one weight per client, in the input order.
    """
    vectors = []
    for weights in local_weights:
        pieces = [param.flatten() for param in weights.values()]
        vectors.append(torch.cat(pieces).detach().cpu().numpy())

    distances = squareform(pdist(np.array(vectors)))
    bandwidth = np.median(distances) + 1e-12

    density = np.zeros(len(local_weights))
    for row, row_distances in enumerate(distances):
        # Position 0 of the ranking is skipped (distance 0 to the client itself).
        neighbours = np.argsort(row_distances)[1 : k + 1]
        d = row_distances[neighbours]
        density[row] = np.mean(np.exp(-d * d / (2 * bandwidth * bandwidth)))
    # Small offset keeps the normalisations below away from a zero denominator.
    density += 1e-12

    binary_density = density[:NUM_BINARY]
    full_density = density[NUM_BINARY:]

    binary_mean = np.mean(binary_density)
    full_mean = np.mean(full_density)
    mean_total = binary_mean + full_mean
    binary_share = binary_mean / mean_total
    full_share = full_mean / mean_total

    binary_scaled = [v * binary_share for v in binary_density / binary_density.sum()]
    full_scaled = [v * full_share for v in full_density / full_density.sum()]
    return binary_scaled + full_scaled


def fed_kde_avg(local_weights, kde_weights):
    """Return the weighted sum of client state_dicts.

    Args:
        local_weights: Non-empty sequence of state_dicts sharing the keys of
            the first one.
        kde_weights: One scalar weight per entry of ``local_weights``.

    Returns:
        A dict mapping each key to ``sum_i kde_weights[i] * local_weights[i][key]``.
    """
    merged = {}
    for name in local_weights[0].keys():
        total = 0
        for i, client_state in enumerate(local_weights):
            total = total + kde_weights[i] * client_state[name]
        merged[name] = total
    return merged


def flatten_state_dict(state_dict):
    """Concatenate every tensor of ``state_dict`` into one 1-D NumPy array.

    Tensors are detached, moved to the CPU and flattened in the dict's
    iteration order.
    """
    chunks = []
    for tensor in state_dict.values():
        chunks.append(tensor.detach().cpu().flatten())
    return torch.cat(chunks).numpy()


def tensor_to_value(t: torch.Tensor, mode: str) -> float:
    """Reduce a tensor to one float.

    Args:
        t: Tensor to summarise; it is flattened first.
        mode: ``"mean"`` for the arithmetic mean, ``"l2"`` for the Euclidean
            norm, ``"ms"`` for the mean of squares.

    Returns:
        The chosen statistic as a Python float.

    Raises:
        ValueError: If ``mode`` is not one of the three supported names.
    """
    flat = t.detach().cpu().numpy().reshape(-1)
    if mode == "mean":
        result = flat.mean()
    elif mode == "l2":
        result = np.linalg.norm(flat, ord=2)
    elif mode == "ms":
        result = (flat**2).sum() / flat.shape[0]
    else:
        raise ValueError("VALUE_MODE must be 'mean', 'l2', or 'ms'")
    return float(result)


def kalman_filter_scalar_sequence(observations, alpha, beta, p0=1.0, m0=None):
    """Run a scalar Kalman-style filter over a sequence of observations.

    At every step the covariance prediction adds ``alpha`` and the gain is
    ``P_minus / (P_minus + beta)``. At the single step whose index equals
    ``TOTAL_CLIENTS - NUM_FULL`` the prediction differs: the mean becomes
    ``RHO * M + (1 - RHO) * z`` and the covariance is multiplied by
    ``1 / LAMBDA`` before ``alpha`` is added.

    Args:
        observations: Sequence of scalar measurements.
        alpha: Value added to the covariance in each prediction.
        beta: Value used alongside the predicted covariance in the gain.
        p0: Initial covariance.
        m0: Initial mean; defaults to the first observation.

    Returns:
        Tuple ``(means, gains, covariances)`` of float64 arrays, one entry
        per observation. Three empty arrays when there are no observations.
    """
    z = np.asarray(observations, dtype=np.float64)
    n = len(z)
    if n == 0:
        return np.array([]), np.array([]), np.array([])

    state = float(m0) if m0 is not None else float(z[0])
    cov = float(p0)
    switch_step = TOTAL_CLIENTS - NUM_FULL

    means = np.zeros(n, dtype=np.float64)
    gains = np.zeros(n, dtype=np.float64)
    covs = np.zeros(n, dtype=np.float64)

    for step, obs in enumerate(z):
        # Prediction: a special blend / inflation at the switch step only.
        if step == switch_step:
            pred_mean = RHO * state + (1 - RHO) * obs
            pred_cov = (1.0 / LAMBDA) * cov + alpha
        else:
            pred_mean = state
            pred_cov = cov + alpha

        # Update.
        gain = pred_cov / (pred_cov + beta)
        state = pred_mean + gain * (obs - pred_mean)
        cov = (1.0 - gain) ** 2 * pred_cov + gain**2 * beta

        means[step] = state
        gains[step] = gain
        covs[step] = cov

    return means, gains, covs


def analyze_kf_plot(
    clean_pth=CLEAN_PTH,
    noisy_pth=NOISY_PTH,
    value_mode=VALUE_MODE,
    start_client=START_CLIENT,
    kf_alpha=KF_ALPHA,
    kf_beta=KF_BETA,
    kf_p0=KF_P0,
):
    """Plot clean, noisy and Kalman-filtered per-client values as grouped bars.

    Both checkpoints are loaded onto the CPU. Clients are ordered by the
    Euclidean distance between their flattened clean model and that of
    ``start_client``. The first parameter tensor of every model is reduced to
    a scalar with ``tensor_to_value``, the noisy sequence is filtered, and ten
    evenly spaced clients from the ordering are drawn.

    Args:
        clean_pth: Path to the list of clean client state_dicts.
        noisy_pth: Path to the list of noisy client state_dicts.
        value_mode: Reduction name forwarded to ``tensor_to_value``.
        start_client: Index of the reference client for the ordering.
        kf_alpha: ``alpha`` forwarded to the Kalman filter.
        kf_beta: ``beta`` forwarded to the Kalman filter.
        kf_p0: Initial covariance forwarded to the Kalman filter.

    Raises:
        FileNotFoundError: If either checkpoint file does not exist.
        ValueError: If a checkpoint is not a non-empty list/tuple, the two
            lengths differ, or ``start_client`` is out of range.
    """
    panel_color = (230 / 255, 230 / 255, 238 / 255)

    def _is_model_list(candidate):
        # True for a non-empty list/tuple, the only accepted checkpoint shape.
        return isinstance(candidate, (list, tuple)) and len(candidate) != 0

    if not os.path.exists(clean_pth):
        raise FileNotFoundError(f"Clean model file not found: {clean_pth}")
    if not os.path.exists(noisy_pth):
        raise FileNotFoundError(f"Noisy model file not found: {noisy_pth}")

    clean_models = torch.load(clean_pth, map_location="cpu")
    noisy_models = torch.load(noisy_pth, map_location="cpu")

    if not _is_model_list(clean_models):
        raise ValueError("CLEAN_PTH is not a non-empty list/tuple of state_dicts")
    if not _is_model_list(noisy_models):
        raise ValueError("NOISY_PTH is not a non-empty list/tuple of state_dicts")

    num_clients = len(clean_models)
    if num_clients != len(noisy_models):
        raise ValueError(f"Mismatch: {len(clean_models)} vs {len(noisy_models)}")
    if start_client >= num_clients or start_client < 0:
        raise ValueError(f"START_CLIENT out of range: {start_client}")

    # Order clients by distance of their clean model to the reference client.
    clean_vectors = np.array([flatten_state_dict(state) for state in clean_models])
    pairwise = squareform(pdist(clean_vectors))
    order = np.argsort(pairwise[start_client])

    param_name = list(clean_models[0].keys())[0]

    def _ordered_values(models):
        # Scalar summary of the chosen parameter for each client, in `order`.
        return np.array(
            [tensor_to_value(models[i][param_name], value_mode) for i in order],
            dtype=np.float64,
        )

    clean_vals = _ordered_values(clean_models)
    noisy_vals = _ordered_values(noisy_models)

    noisy_kf, _, _ = kalman_filter_scalar_sequence(
        noisy_vals, alpha=kf_alpha, beta=kf_beta, p0=kf_p0, m0=noisy_vals[0]
    )

    # Ten evenly spaced positions along the ordering are plotted.
    picks = np.linspace(0, len(clean_vals) - 1, 10, dtype=int)
    x_sample = np.arange(1, num_clients + 1)[picks]
    clean_sample = clean_vals[picks]
    noisy_sample = noisy_vals[picks]
    kf_sample = noisy_kf[picks]

    plt.figure(figsize=(6, 4))
    rcParams["font.family"] = "Times New Roman"

    ax = plt.gca()
    ax.set_facecolor(panel_color)
    for side in ("bottom", "top", "right", "left"):
        ax.spines[side].set_color("white")

    width = 1.6
    bar_groups = (
        (x_sample - width, clean_sample, "without noisy", 2),
        (x_sample, kf_sample, "with noisy after filter", 4),
        (x_sample + width, noisy_sample, "with noisy before filter", 3),
    )
    for positions, heights, label, layer in bar_groups:
        plt.bar(positions, heights, width=width, label=label, zorder=layer)

    plt.ylim(top=0.03)
    plt.xlabel("clients")
    plt.ylabel("margin")
    plt.title("MNIST")
    plt.grid(True, color="white", zorder=0)
    legend = plt.legend()
    legend.get_frame().set_facecolor(panel_color)
    plt.tight_layout()
    plt.show()


def add_zero_mean_noise_to_models(clean_path, noisy_path, std):
    """Save a copy of a model list with zero-mean Gaussian noise on its weights.

    Args:
        clean_path: Path to a file holding an iterable of state_dicts.
        noisy_path: Destination file. Its parent directory is created when
            the path has one; a bare file name is written to the current
            working directory.
        std: Standard deviation of the noise added to each tensor.

    Raises:
        FileNotFoundError: If ``clean_path`` does not exist.
    """
    if not os.path.exists(clean_path):
        raise FileNotFoundError("CLEAN_PTH not found")

    local_models = torch.load(clean_path, map_location="cpu")

    def _perturb(value):
        # torch.randn_like is undefined for integer/bool tensors (e.g. step
        # counters), so those and all non-tensor entries are copied through.
        noisable = isinstance(value, torch.Tensor) and (
            value.is_floating_point() or value.is_complex()
        )
        if not noisable:
            return value
        return value + torch.randn_like(value) * std

    noisy_models = [
        {name: _perturb(value) for name, value in state_dict.items()}
        for state_dict in local_models
    ]

    # os.path.dirname returns "" for a bare file name and os.makedirs("")
    # raises FileNotFoundError, so only create a directory when there is one.
    out_dir = os.path.dirname(noisy_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torch.save(noisy_models, noisy_path)


def train_fedkde(save_path=SAVE_PATH_WITH_NOISY):
    """Run the federated KDE-weighted training experiment on MNIST and plot accuracy.

    Steps, in order:
      1. Seed the RNGs with SEED and create ``save_path``.
      2. Load MNIST (downloading into ``./data`` if needed) as a clean train
         set, a clean test set and a train set with additive Gaussian noise.
      3. Build NUM_BINARY clients dominated by labels 0 and 1, then NUM_FULL
         clients dominated by labels 2-9. Client ids drawn into
         ``NOISY_CLIENTS`` read from the noisy train set.
      4. For ROUNDS rounds: train one local model per client starting from the
         global weights, aggregate with ``compute_kde_weights`` and
         ``fed_kde_avg``, and evaluate the global model on the test set.
      5. After round SAVE_ROUND (1-based), save that round's list of local
         state_dicts under ``save_path``.
      6. Plot global accuracy per round.

    The partition depends on the exact sequence of ``np.random`` calls below,
    so reordering statements changes which samples each client receives.

    NOTE(review): ``client_datasets_without_noisy`` is populated but never
    used for training in this function; training always uses the
    ``with_noisy`` datasets regardless of ``save_path``.

    Args:
        save_path: Directory that receives ``round{SAVE_ROUND}_local_models.pth``.
    """
    set_seed(SEED)
    os.makedirs(save_path, exist_ok=True)

    clean_transform = transforms.Compose([transforms.ToTensor()])
    train_dataset_clean = datasets.MNIST("./data", train=True, download=True, transform=clean_transform)
    test_dataset = datasets.MNIST("./data", train=False, download=True, transform=clean_transform)

    # Same training images, with Gaussian noise added after ToTensor.
    noisy_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            AddGaussianNoise(std=GAUSSION_NOISE_PARAMETER),
        ]
    )
    train_dataset_noisy = datasets.MNIST("./data", train=True, download=True, transform=noisy_transform)

    # Per-digit lists of sample indices, each shuffled once.
    labels = np.array(train_dataset_clean.targets)
    label_indices = {l: np.where(labels == l)[0] for l in range(10)}
    for l in label_indices:
        np.random.shuffle(label_indices[l])

    client_datasets_with_noisy = []
    client_datasets_without_noisy = []

    # Global client ids (0 .. NUM_CLIENTS - 1) that train on the noisy dataset.
    NOISY_CLIENTS = np.random.choice(NUM_CLIENTS, size=NOISY_CLIENT_NUM, replace=False)

    # --- First group: NUM_BINARY clients built mostly from labels 0 and 1. ---
    for i in range(NUM_BINARY):
        # 600 consecutive entries of the shuffled index list for each of labels 0 and 1.
        rand0 = rand1 = 600
        start0 = np.random.randint(0, len(label_indices[0]) - rand0)
        start1 = np.random.randint(0, len(label_indices[1]) - rand1)

        idx0 = label_indices[0][start0 : start0 + rand0]
        idx1 = label_indices[1][start1 : start1 + rand1]

        # Between 1 and 5 distinct extra labels drawn from 2-9.
        EXTRA_LABEL_NUM = np.random.randint(1, 6)
        extra_labels = np.random.choice(range(2, 10), size=EXTRA_LABEL_NUM, replace=False)

        extra_indices = []
        for lbl in extra_labels:
            # Take between 20 and 80 samples of this label (capped by availability).
            total_lbl = len(label_indices[lbl])
            max_take = min(80, total_lbl)
            min_take = min(20, max_take)
            take = np.random.randint(min_take, max_take + 1)

            # Guard against an empty range for randint when nearly all samples are taken.
            if total_lbl - take <= 1:
                start_lbl = 0
            else:
                start_lbl = np.random.randint(0, total_lbl - take)

            extra_indices.extend(label_indices[lbl][start_lbl : start_lbl + take])

        extra_indices = np.array(extra_indices)
        subset_indices = np.concatenate([idx0, idx1, extra_indices])
        np.random.shuffle(subset_indices)

        # Same sample indices for both variants; only the underlying dataset differs.
        base_dataset = train_dataset_noisy if (i in NOISY_CLIENTS) else train_dataset_clean
        client_datasets_with_noisy.append(Subset(base_dataset, subset_indices))
        client_datasets_without_noisy.append(Subset(train_dataset_clean, subset_indices))

    # --- Second group: NUM_FULL clients built mostly from labels 2-9. ---
    indices_2_9 = []
    for label in range(2, 10):
        label_idx = np.where(train_dataset_clean.targets == label)[0]
        indices_2_9.extend(label_idx)

    indices_2_9 = np.array(indices_2_9)
    np.random.shuffle(indices_2_9)

    indices_0 = np.where(train_dataset_clean.targets == 0)[0]
    indices_1 = np.where(train_dataset_clean.targets == 1)[0]
    indices_0_1 = np.concatenate([indices_0, indices_1])
    np.random.shuffle(indices_0_1)

    # Equal, non-overlapping shares of the label 2-9 pool; any remainder is unused.
    samples_per_client = len(indices_2_9) // NUM_FULL

    for i in range(NUM_FULL):
        subset_main = indices_2_9[i * samples_per_client : (i + 1) * samples_per_client]

        # Add between 10 and 30 samples of labels 0/1 (may overlap across clients).
        extra_num = np.random.randint(10, 31)
        if extra_num > len(indices_0_1):
            extra_num = len(indices_0_1)
        extra_indices = np.random.choice(indices_0_1, size=extra_num, replace=False)

        subset = np.concatenate([subset_main, extra_indices])
        np.random.shuffle(subset)

        # Second-group clients are numbered after the first group.
        full_global_id = NUM_BINARY + i
        base_dataset = train_dataset_noisy if (full_global_id in NOISY_CLIENTS) else train_dataset_clean
        client_datasets_with_noisy.append(Subset(base_dataset, subset))
        client_datasets_without_noisy.append(Subset(train_dataset_clean, subset))

    test_loader = get_loader(test_dataset)
    k_values = [10]  # neighbour counts to evaluate; one training run per value
    global_accs = []

    for k_value in k_values:
        global_acc = []
        global_model = CNN().to(DEVICE)

        for r in range(ROUNDS):
            local_weights = []
            for i in range(NUM_CLIENTS):
                # Each client starts the round from the current global weights.
                local_model = CNN().to(DEVICE)
                local_model.load_state_dict(global_model.state_dict())
                loader = get_loader(client_datasets_with_noisy[i])
                w = local_train(local_model, loader, noisy=(i in NOISY_CLIENTS))
                local_weights.append(w)

            kde_weights = compute_kde_weights(local_weights, k_value)
            aggregated = fed_kde_avg(local_weights, kde_weights)
            global_model.load_state_dict(aggregated)

            acc = test(global_model, test_loader)
            global_acc.append(acc)

            # r is 0-based, SAVE_ROUND is 1-based.
            if (r + 1) == SAVE_ROUND:
                torch.save(local_weights, f"{save_path}/round{SAVE_ROUND}_local_models.pth")

        global_accs.append(global_acc)

    # --- Accuracy-per-round plot, one line per k value. ---
    plt.figure(figsize=(6, 4), facecolor="white")
    rcParams["font.family"] = "Times New Roman"

    ax = plt.gca()
    ax.set_facecolor((230 / 255, 230 / 255, 238 / 255))
    ax.spines["bottom"].set_color("white")
    ax.spines["top"].set_color("white")
    ax.spines["right"].set_color("white")
    ax.spines["left"].set_color("white")

    linestyles = ["-"]
    for idx, k_value in enumerate(k_values):
        plt.plot(global_accs[idx], label=f"k={k_values[idx]}", linestyle=linestyles[idx % len(linestyles)])

    plt.xlabel("Round")
    plt.ylabel("Accuracy")
    plt.title("MNIST")
    plt.grid(True, color="white", zorder=0)
    legend = plt.legend()
    legend.get_frame().set_facecolor((230 / 255, 230 / 255, 238 / 255))
    plt.show()


def _build_cli():
    """Build the argument parser with the ``train``, ``noise`` and ``analyze`` commands.

    The chosen command is stored in ``args.cmd`` and is mandatory. Every
    option defaults to the matching module-level constant.
    """
    parser = argparse.ArgumentParser()
    commands = parser.add_subparsers(dest="cmd", required=True)

    train_cmd = commands.add_parser("train")
    train_cmd.add_argument("--save_path", type=str, default=SAVE_PATH_WITH_NOISY)

    noise_cmd = commands.add_parser("noise")
    noise_options = (
        ("--clean_path", str, CLEAN_PTH),
        ("--noisy_path", str, NOISY_PTH),
        ("--std", float, GAUSSIAN_STD),
    )
    for flag, kind, default in noise_options:
        noise_cmd.add_argument(flag, type=kind, default=default)

    analyze_cmd = commands.add_parser("analyze")
    analyze_options = (
        ("--clean_path", str, CLEAN_PTH),
        ("--noisy_path", str, NOISY_PTH),
        ("--value_mode", str, VALUE_MODE),
        ("--start_client", int, START_CLIENT),
        ("--kf_alpha", float, KF_ALPHA),
        ("--kf_beta", float, KF_BETA),
        ("--kf_p0", float, KF_P0),
    )
    for flag, kind, default in analyze_options:
        analyze_cmd.add_argument(flag, type=kind, default=default)

    return parser


def main():
    """Parse the command line and run the selected command."""
    args = _build_cli().parse_args()
    command = args.cmd

    if command == "train":
        train_fedkde(save_path=args.save_path)
    elif command == "noise":
        add_zero_mean_noise_to_models(args.clean_path, args.noisy_path, args.std)
    elif command == "analyze":
        analyze_kf_plot(
            clean_pth=args.clean_path,
            noisy_pth=args.noisy_path,
            value_mode=args.value_mode,
            start_client=args.start_client,
            kf_alpha=args.kf_alpha,
            kf_beta=args.kf_beta,
            kf_p0=args.kf_p0,
        )


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
