import os
import random
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from scipy.spatial.distance import pdist, squareform
from matplotlib import rcParams

# Checkpoints handled by main(): CLEAN_PTH is loaded as the list of per-client
# state dicts, NOISY_PTH is (re)written with a noise-perturbed copy of them.
CLEAN_PTH = "./saved_models/fashion_mnist/without_noisy/round8_local_models.pth"
NOISY_PTH = "./saved_models/fashion_mnist/with_noisy/round8_local_models.pth"

# Client grouping. compute_kde_weights() treats the first NUM_BINARY entries as
# the "binary" group and the remaining ones as the "full" group; the Kalman
# filter applies its special step at index TOTAL_CLIENTS - NUM_FULL.
TOTAL_CLIENTS = 60
NUM_FULL = 25
NUM_BINARY = TOTAL_CLIENTS - NUM_FULL
START_CLIENT = 0  # row of the distance matrix that defines the client ordering in main()
VALUE_MODE = "mean"  # reduction given to tensor_to_value(): "mean", "l2" or "ms"

# Parameters of kalman_filter_scalar_sequence() and of the noise injection.
KF_ALPHA = 1e-2  # added to the predicted variance at every step
GAUSSIAN_STD = 0.02  # std of the noise added by add_zero_mean_noise_to_models()
# NOTE(review): KF_BETA enters the gain as K = P / (P + beta); it is set to the
# noise std rather than std ** 2 -- confirm this is intended.
KF_BETA = GAUSSIAN_STD
KF_P0 = 1.0  # initial variance of the filter
LAMBDA = 0.1  # previous variance is scaled by 1 / LAMBDA at the special step
RHO = 0.1  # weight of the previous estimate at the special step

# Training hyperparameters. NUM_CLIENTS, NOISY_CLIENT_NUM,
# GAUSSION_NOISE_PARAMETER (sic), ROUNDS, SAVE_ROUND and SAVE_PATH are not
# referenced in the visible code -- presumably used by a training driver; verify.
NUM_CLIENTS = 60
NOISY_CLIENT_NUM = 25
GAUSSION_NOISE_PARAMETER = 0.3
EPOCHS = 1  # passes over the loader per local_train() call
ROUNDS = 9
BATCH_SIZE = 64  # batch size used by get_loader()
LR = 0.01  # SGD learning rate used by local_train()
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # batches are moved here
SAVE_ROUND = 8
SAVE_PATH = "./saved_models/fashion_mnist/without_noisy/"
SEED = 60  # seed given to set_seed() in main()

def set_seed(seed):
    """Seed every random number generator used by this script.

    The same value is applied, in order, to Python's ``random`` module,
    NumPy's global generator, PyTorch's default generator and all CUDA
    generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def flatten_state_dict(state_dict):
    """Concatenate every tensor of a state dict into one 1-D NumPy array.

    Tensors are detached, moved to the CPU and flattened in the dict's
    iteration order before being joined.
    """
    pieces = []
    for tensor in state_dict.values():
        pieces.append(tensor.detach().cpu().flatten())
    joined = torch.cat(pieces)
    return joined.numpy()

def tensor_to_value(t, mode):
    """Reduce a tensor to a single Python float.

    Args:
        t: Tensor to summarise; it is detached, moved to the CPU and
            flattened before the reduction.
        mode: ``"mean"`` for the arithmetic mean, ``"l2"`` for the Euclidean
            norm, ``"ms"`` for the mean of the squared entries.

    Returns:
        The reduced value as a ``float``.

    Raises:
        ValueError: If ``mode`` is not one of the supported names.
    """
    x = t.detach().cpu().numpy().reshape(-1)
    if mode == "mean":
        return float(x.mean())
    if mode == "l2":
        return float(np.linalg.norm(x, ord=2))
    if mode == "ms":
        return float((x ** 2).sum() / x.shape[0])
    # Name the offending value so a typo in the configuration is easy to spot.
    raise ValueError(f"unknown mode {mode!r}; expected 'mean', 'l2' or 'ms'")

def kalman_filter_scalar_sequence(observations, alpha, beta, p0=1.0, m0=None):
    """Filter a sequence of scalar observations with a scalar Kalman recursion.

    At every step the variance is predicted as ``P + alpha``, the gain is
    ``K = P_minus / (P_minus + beta)`` and the estimate is pulled towards the
    observation by ``K``.  At the single step ``TOTAL_CLIENTS - NUM_FULL`` the
    prediction is instead ``RHO * M + (1 - RHO) * z[k]`` with variance
    ``P / LAMBDA + alpha``.

    Args:
        observations: Sequence of scalar observations.
        alpha: Value added to the variance in every prediction step.
        beta: Observation-noise term used in the gain and variance update.
        p0: Initial variance.
        m0: Initial estimate; defaults to the first observation.

    Returns:
        Tuple ``(M_filt, K_list, P_list)`` of NumPy arrays holding the filtered
        estimate, the gain and the posterior variance for every step.  All
        three are empty when ``observations`` is empty.
    """
    z = np.asarray(observations, dtype=np.float64)
    n = len(z)
    M_filt = np.zeros(n)
    K_list = np.zeros(n)
    P_list = np.zeros(n)
    if n == 0:
        # Nothing to filter; returning here also avoids indexing z[0] below.
        return M_filt, K_list, P_list

    M = float(z[0]) if m0 is None else float(m0)
    P = float(p0)
    # Index of the one step that uses the blended prediction.
    switch_step = TOTAL_CLIENTS - NUM_FULL
    for k in range(n):
        if k == switch_step:
            M_minus = RHO * M + (1 - RHO) * z[k]
            P_minus = (1.0 / LAMBDA) * P + alpha
        else:
            M_minus = M
            P_minus = P + alpha
        K = P_minus / (P_minus + beta)
        M = M_minus + K * (z[k] - M_minus)
        P = (1.0 - K) ** 2 * P_minus + K ** 2 * beta
        M_filt[k] = M
        K_list[k] = K
        P_list[k] = P
    return M_filt, K_list, P_list

def add_zero_mean_noise_to_models(clean_path, noisy_path, std):
    """Write a copy of saved client models with Gaussian noise added.

    Args:
        clean_path: File holding an iterable of state dicts, readable by
            ``torch.load``.
        noisy_path: Destination file; its parent directory is created when
            the path has one.
        std: Standard deviation of the zero-mean noise added to each entry.

    Floating-point (and complex) tensors receive noise.  Integer or boolean
    tensors and non-tensor values are copied unchanged, because
    ``torch.randn_like`` cannot sample for those dtypes.
    """
    # NOTE: depending on the PyTorch version torch.load may unpickle arbitrary
    # objects; only point clean_path at checkpoints from a trusted source.
    local_models = torch.load(clean_path, map_location="cpu")
    noisy_models = []
    for state_dict in local_models:
        noisy_state = {}
        for key, value in state_dict.items():
            is_noisable = isinstance(value, torch.Tensor) and (
                value.is_floating_point() or value.is_complex()
            )
            if is_noisable:
                noisy_state[key] = value + torch.randn_like(value) * std
            else:
                noisy_state[key] = value
        noisy_models.append(noisy_state)
    # A bare file name has an empty dirname, which os.makedirs rejects.
    out_dir = os.path.dirname(noisy_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torch.save(noisy_models, noisy_path)

class AddGaussianNoise:
    """Callable that adds zero-mean Gaussian noise to a tensor.

    The noise is drawn with ``torch.randn`` in the shape of the input and
    scaled by ``std``.
    """

    def __init__(self, std):
        # Scale applied to the standard-normal sample.
        self.std = std

    def __call__(self, tensor):
        scaled_noise = torch.randn(tensor.size()) * self.std
        return tensor + scaled_noise

class CNN(nn.Module):
    """Small convolutional classifier with two conv and two linear layers.

    Each 5x5 convolution is followed by 2x2 max pooling and ReLU.  The result
    is reshaped to rows of 320 features (20 channels of 4x4 for a 28x28
    single-channel input) and passed through a 50-unit hidden layer to 10
    output logits.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5)
        self.fc1 = nn.Linear(in_features=320, out_features=50)
        self.fc2 = nn.Linear(in_features=50, out_features=10)

    def forward(self, x):
        # Both convolutional stages share the same conv -> pool -> ReLU shape.
        for conv in (self.conv1, self.conv2):
            x = torch.relu(torch.max_pool2d(conv(x), 2))
        hidden = torch.relu(self.fc1(x.view(-1, 320)))
        return self.fc2(hidden)

def get_loader(dataset):
    """Wrap ``dataset`` in a shuffling DataLoader using ``BATCH_SIZE``."""
    loader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE)
    return loader

def local_train(model, loader):
    """Train ``model`` on ``loader`` for ``EPOCHS`` passes and return its state.

    Uses plain SGD with learning rate ``LR`` and cross-entropy loss.  Only the
    batches are moved to ``DEVICE``; the model is used where it already is.

    Returns:
        ``model.state_dict()`` after training.
    """
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=LR)
    for _epoch in range(EPOCHS):
        for inputs, targets in loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()
    return model.state_dict()

def test(model, loader):
    """Return the classification accuracy of ``model`` over ``loader``.

    The model is put in eval mode and gradients are disabled.  Predictions
    are the arg-max over dimension 1 of the model output.
    """
    model.eval()
    hits, seen = 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            predicted = model(inputs).argmax(dim=1)
            hits += (predicted == targets).sum().item()
            seen += targets.size(0)
    return hits / seen

def compute_kde_weights(local_weights, k):
    """Compute aggregation weights from a k-nearest-neighbour density score.

    Each client's state dict is flattened to a vector.  Its score is the mean
    Gaussian kernel value over its ``k`` nearest neighbours, with the
    bandwidth set to the median of the pairwise distance matrix.  Scores are
    normalised within the first ``NUM_BINARY`` clients and within the rest,
    and each group is scaled by its share of the two group means, so the
    weights sum to 1.

    Args:
        local_weights: Sequence of client state dicts.
        k: Number of nearest neighbours used for each score.

    Returns:
        List with one weight per client, in input order.
    """
    # Reuse the file's flattening helper so both code paths detach tensors.
    flat = np.array([flatten_state_dict(w) for w in local_weights])
    dist_matrix = squareform(pdist(flat))
    # The median is taken over the whole matrix, zero diagonal included; the
    # epsilon keeps the bandwidth strictly positive.
    h = np.median(dist_matrix) + 1e-12
    kde_vals = np.zeros(len(local_weights))
    for i, row in enumerate(dist_matrix):
        # Skip sorted position 0, which is normally the client itself.
        neighbours = np.argsort(row)[1:k + 1]
        d = row[neighbours]
        kde_vals[i] = np.mean(np.exp(-d * d / (2 * h * h)))
    kde_vals += 1e-12

    binary = kde_vals[:NUM_BINARY]
    full = kde_vals[NUM_BINARY:]
    if binary.size == 0 or full.size == 0:
        # With only one group present, the mean of the empty group is NaN
        # and would poison every weight; normalise the scores directly.
        return list(kde_vals / kde_vals.sum())
    bw = binary.mean()
    fw = full.mean()
    binary = binary / binary.sum() * bw / (bw + fw)
    full = full / full.sum() * fw / (bw + fw)
    return list(binary) + list(full)

def fed_kde_avg(local_weights, weights):
    """Return the weighted sum of client state dicts, entry by entry.

    Args:
        local_weights: Sequence of state dicts; the keys of the first one
            define the entries of the result.
        weights: One scalar weight per client, indexed like ``local_weights``.

    Returns:
        Dict mapping every key to ``sum_i weights[i] * local_weights[i][key]``.
    """
    aggregated = {}
    for name in local_weights[0]:
        total = 0
        for i, client_state in enumerate(local_weights):
            total = total + weights[i] * client_state[name]
        aggregated[name] = total
    return aggregated

def main():
    """Plot clean, Kalman-filtered and noisy parameter summaries side by side.

    Seeds the generators, writes the noisy checkpoint, and loads both
    checkpoints.  Clients are ordered by distance from ``START_CLIENT`` in
    flattened clean-parameter space.  The first parameter tensor of each
    client is reduced to a scalar, the noisy sequence is filtered, and ten
    evenly spaced positions are shown as grouped bars.
    """
    set_seed(SEED)
    add_zero_mean_noise_to_models(CLEAN_PTH, NOISY_PTH, GAUSSIAN_STD)

    clean_models = torch.load(CLEAN_PTH, map_location="cpu")
    noisy_models = torch.load(NOISY_PTH, map_location="cpu")

    # Client ordering comes from the clean models only.
    clean_vectors = np.array([flatten_state_dict(state) for state in clean_models])
    distances = squareform(pdist(clean_vectors))
    order = np.argsort(distances[START_CLIENT])

    param_name = list(clean_models[0].keys())[0]

    def ordered_values(models):
        # One scalar per client, following the distance ordering.
        return np.array([tensor_to_value(models[i][param_name], VALUE_MODE) for i in order])

    clean_vals = ordered_values(clean_models)
    noisy_vals = ordered_values(noisy_models)

    noisy_kf = kalman_filter_scalar_sequence(noisy_vals, KF_ALPHA, KF_BETA, KF_P0)[0]
    idx = np.linspace(0, len(clean_vals) - 1, 10, dtype=int)

    rcParams['font.family'] = 'Times New Roman'
    plt.figure(figsize=(6, 4))
    width = 1.6
    # Left, centre and right bar of every group.
    grouped = ((-width, clean_vals), (0, noisy_kf), (width, noisy_vals))
    for offset, series in grouped:
        plt.bar(idx + offset, series[idx], width)
    plt.grid(True)
    plt.tight_layout()
    plt.show()

# Run the plotting pipeline only when the file is executed as a script, not
# when it is imported.
if __name__ == "__main__":
    main()
