# %%
import json
import os
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import utils as u
import seaborn as sns
import torch
import torch.nn as nn
import pandas as pd
from tqdm import tqdm

# Ten-colour colour-blind-friendly palette (not referenced in the visible
# part of this script).
cmap = sns.color_palette(palette="colorblind", n_colors=10)

# %%

dataset_path = Path("../leaf/data/femnist/data/train")


def _load_agents(directory, pattern):
    """Load every JSON shard of ``directory`` whose name matches ``pattern``.

    Each shard is expected to hold a ``"users"`` list and a ``"user_data"``
    mapping giving, for every user, its ``"x"`` and ``"y"`` entries.

    Args:
        directory: ``pathlib.Path`` of the folder containing the shards.
        pattern: glob pattern selecting the shard files.

    Returns:
        A list with one dict per user, with keys ``"id"``, ``"x"``, ``"y"``
        and ``"file"`` (name of the shard the user was read from). Shards are
        visited in sorted file-name order.

    Raises:
        FileNotFoundError: if no file matches ``pattern`` in ``directory``.
    """
    shards = sorted(directory.glob(pattern))
    if not shards:
        # Fail here with an explicit message instead of later with an
        # IndexError on an empty list.
        raise FileNotFoundError(
            f"No file matching {pattern!r} found in {directory}"
        )

    loaded = []
    for json_file in shards:
        with open(json_file, "r", encoding="utf-8") as f:
            data = json.load(f)

        # One entry per user stored in this shard
        for user in data["users"]:
            user_data = data["user_data"][user]
            loaded.append(
                {
                    "id": user,
                    "x": user_data["x"],
                    "y": user_data["y"],
                    "file": json_file.name,  # shard of origin
                }
            )
    return loaded


# Build the list of agents from the train shards
agents = _load_agents(dataset_path, "all_data_*_train_*")

print(f"Nombre total d'agents chargés : {len(agents)}")
print("Exemple d'un agent :")
print(agents[0])

# %% test data

# Same loading procedure for the shards named all_data_*_test_*
dataset_path = Path("../leaf/data/femnist/data/test")
agents_test = _load_agents(dataset_path, "all_data_*_test_*")


# %%
def plot_points_per_agent(agents, title="Number of points by agents"):
    """Show a bar chart of the per-agent sample counts, in increasing order.

    Args:
        agents: iterable of dicts; the length of each ``agent["x"]`` is the
            number of samples of that agent.
        title: title displayed above the chart.
    """
    sizes = np.array([len(agent["x"]) for agent in agents])
    sizes.sort()  # in-place ascending sort: bars go from smallest to largest

    _, ax = plt.subplots(figsize=(12, 6))
    ax.bar(range(len(sizes)), sizes)
    ax.set_xlabel("Agent index")
    ax.set_ylabel("Nombre d'exemples")
    ax.set_title(title)
    plt.show()


# The split read from the "train" folder (``agents``) serves as the test set
# in this script and the split read from the "test" folder (``agents_test``)
# serves as the training set, hence the apparently swapped titles.
plot_points_per_agent(agents, title="Number of points by agents (test)")
plot_points_per_agent(agents_test, title="Number of points by agents (train)")


# %% Train on the "test" split so that the larger "train" split is kept for evaluating the accuracy
# NOTE(review): agent i of one split is later paired with agent i of the
# other split; this presumably relies on both folders listing the same users
# in the same order — confirm.
X_list_test, Y_list_test = [], []
for agent in agents:
    X_list_test.append(np.array(agent["x"]))
    Y_list_test.append(np.array(agent["y"]))

X_list, Y_list = [], []
for agent in agents_test:
    X_list.append(np.array(agent["x"]))
    Y_list.append(np.array(agent["y"]))


# %%


class PerceptronClassifier(nn.Module):
    """One-hidden-layer perceptron mapping ``d`` features to ``n_classes`` logits."""

    def __init__(self, d, hidden_dim=32, n_classes=62):
        """Build the ``Linear -> ReLU -> Linear`` stack.

        Args:
            d: number of input features.
            hidden_dim: width of the hidden layer.
            n_classes: number of output logits.
        """
        super().__init__()
        # Layers are instantiated in forward order so that parameter
        # initialisation order and state-dict keys (net.0.*, net.2.*) are
        # those of a directly written nn.Sequential.
        input_layer = nn.Linear(d, hidden_dim)
        activation = nn.ReLU()
        output_layer = nn.Linear(hidden_dim, n_classes)
        self.net = nn.Sequential(input_layer, activation, output_layer)

    def forward(self, x):
        """Return the raw (un-normalised) class scores for ``x``."""
        logits = self.net(x)
        return logits


# %%
def train_weighted_classifier(
    X_list,
    y_list,
    weights,
    d,
    Xtest,
    ytest,
    n_classes=62,
    n_epochs=50,
    lr=1e-3,
    hidden_dim=32,
    device="cpu",
    patience=100,
):
    """Train a ``PerceptronClassifier`` on an agent-weighted cross-entropy.

    Every epoch performs one full-batch Adam step on
    ``sum_t weights[t] * mean_cross_entropy(X_list[t], y_list[t])``; agents
    whose weight is exactly zero are skipped. After each step the model is
    evaluated on ``(Xtest, ytest)`` and training stops once the test accuracy
    has not improved for ``patience`` consecutive epochs.

    Args:
        X_list: per-agent feature arrays, each convertible to a float32
            tensor accepted by a ``Linear(d, ...)`` layer.
        y_list: per-agent integer class labels, aligned with ``X_list``.
        weights: 1-D sequence with (at least) one weight per agent.
        d: number of input features.
        Xtest: features used for the per-epoch evaluation.
        ytest: labels used for the per-epoch evaluation.
        n_classes: number of output classes.
        n_epochs: maximum number of epochs (one optimizer step each).
        lr: Adam learning rate.
        hidden_dim: width of the hidden layer.
        device: torch device on which tensors and model are placed.
        patience: number of epochs without test-accuracy improvement
            tolerated before stopping.

    Returns:
        ``(model, train_errors, test_errors)`` where ``train_errors`` holds
        the weighted training loss of each epoch (measured before that
        epoch's optimizer step) and ``test_errors`` holds, despite its name,
        the test *accuracy* after each epoch.

    Raises:
        ValueError: if ``weights`` is not 1-D, has fewer entries than there
            are agents, or contains no non-zero entry for any agent.
    """
    T = len(X_list)
    weights = torch.tensor(weights, dtype=torch.float32, device=device)
    if weights.ndim != 1 or weights.numel() < T:
        raise ValueError(
            "weights must be a 1-D sequence with at least one entry per agent"
        )

    # The set of contributing agents never changes during training, so it is
    # resolved once instead of comparing tensor entries at every epoch.
    active_agents = torch.nonzero(weights[:T], as_tuple=True)[0].tolist()
    if not active_agents:
        # Without any active agent there is no loss to differentiate.
        raise ValueError("weights must contain at least one non-zero entry")

    model = PerceptronClassifier(d, hidden_dim, n_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Per-sample losses: they are averaged per agent, then the agent weights
    # are applied manually.
    loss_fn = nn.CrossEntropyLoss(reduction="none")

    train_errors = []
    test_errors = []

    best_test_acc = 0.0
    epochs_no_improve = 0

    # Move the evaluation data to the target device once
    Xtest_t = torch.tensor(Xtest, dtype=torch.float32, device=device)
    ytest_t = torch.tensor(ytest, dtype=torch.long, device=device)

    X_list_torch = [torch.tensor(X, dtype=torch.float32, device=device) for X in X_list]
    y_list_torch = [torch.tensor(y, dtype=torch.long, device=device) for y in y_list]

    for epoch in range(n_epochs):
        model.train()
        optimizer.zero_grad()
        total_loss = 0.0

        for t in active_agents:
            logits = model(X_list_torch[t])
            loss_t = loss_fn(logits, y_list_torch[t]).mean()
            total_loss += weights[t] * loss_t

        total_loss.backward()
        optimizer.step()
        train_errors.append(total_loss.item())

        model.eval()
        with torch.no_grad():
            # Accuracy on the evaluation set after this epoch's update
            logits_test = model(Xtest_t)
            preds_test = logits_test.argmax(dim=1)
            test_acc = (preds_test == ytest_t).float().mean().item()
            test_errors.append(test_acc)

        # Early stopping check
        if test_acc > best_test_acc:
            best_test_acc = test_acc
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print(
                f"Early stopping at epoch {epoch} (no improvement for {patience} epochs)"
            )
            break

        if epoch % 50 == 0:
            print(f"Epoch {epoch:4d} | Test acc = {test_acc:.4f}")

    return model, train_errors, test_errors


# %%
# Features computed by the project helper ``u.RFF`` from the training arrays
# (presumably random Fourier features of dimension D — confirm in utils).
D = 1000
RFFs, _, _ = u.RFF(X_list, D)

# Constants handed to the aggregation helper
T = len(X_list)
M = 1
u0 = np.log(T)
c0, cbs = np.sqrt(u0), u0

# Aggregation weights computed without a target agent
weights_fed = u.Qaggregation(RFFs, 1, c0=c0, cbs=cbs)

# Row i stores the aggregation weights computed for target agent i
Mat = np.zeros((T, T))
for i, row in enumerate(Mat):
    row[:] = u.Qaggregation(RFFs, 1, i=i, c0=c0, cbs=cbs)


# %%
n_epochs = 2000
lr = 0.001
hidden_dim = 32
d = X_list[0].shape[1]
X = X_list
y = Y_list

# Number of training samples held by each agent (kept as a plain list)
train_sizes = [Xi.shape[0] for Xi in X_list]

Tmax = T
patience = 400

# Per-agent best test accuracy of each method, plus the aggregation weights
# used for the "fed" model of each agent.
results = {"local": [], "grandmean": [], "fed": [], "weights": []}

# Hyper-parameters shared by the three training runs of every agent
train_kwargs = dict(
    n_epochs=n_epochs,
    lr=lr,
    hidden_dim=hidden_dim,
    patience=patience,
)

# Grand-mean weights only depend on the sample counts, so they are computed
# once, outside the loop over agents.
weights_gm = np.asarray(train_sizes, dtype=float) / np.sum(train_sizes)

for i in tqdm(range(Tmax)):
    print(f"Agent {i}")

    # Every model of this iteration is evaluated on agent i's held-out data
    Xtest = X_list_test[i]
    ytest = Y_list_test[i]

    print("Local")
    # Only agent i's own data contributes to the loss
    weights_loc = np.zeros(T)
    weights_loc[i] = 1
    model_loc, trainloss_loc, testloss_loc = train_weighted_classifier(
        X, y, weights_loc, d, Xtest, ytest, **train_kwargs
    )

    print("Grandmean")
    # Every agent contributes proportionally to its number of samples
    model_gm, trainloss_gm, testloss_gm = train_weighted_classifier(
        X, y, weights_gm, d, Xtest, ytest, **train_kwargs
    )

    print("Federated")
    # Weights returned by the aggregation helper for target agent i
    weights_fed = u.Qaggregation(RFFs, 1, i=i, c0=c0, cbs=cbs)

    model_fed, trainloss_fed, testloss_fed = train_weighted_classifier(
        X, y, weights_fed, d, Xtest, ytest, **train_kwargs
    )

    # The third returned list holds per-epoch test accuracies; the best one
    # reached during training is kept.
    results["local"].append(np.max(testloss_loc))
    results["grandmean"].append(np.max(testloss_gm))
    results["fed"].append(np.max(testloss_fed))
    results["weights"].append(weights_fed)


# %%
iterations = range(1, len(results["local"]) + 1)
# Agents are displayed by increasing accuracy of the "fed" method
sort_index = np.argsort(results["fed"])

# Keep coherent colors
method_colors = {
    "Local": "#1f77b4",  # blue
    "Grand Mean": "#ff7f0e",  # orange
    "Fed": "#2ca02c",  # green
}
# Explicit category order shared by both categorical plots and by the
# hard-coded tick labels below; without it each plot infers its own order
# from the order of appearance in the data it receives.
method_order = ["Local", "Grand Mean", "Fed"]

fig, axes = plt.subplots(1, 2, figsize=(10, 3.5), gridspec_kw={"width_ratios": [2, 1]})

# Optional background bars showing the (normalised) training-set sizes
avecsize = False
if avecsize:
    sizes = np.asarray(train_sizes[: len(results["fed"])], dtype=float)
    sizes = sizes / np.max(sizes)

    axes[0].bar(
        iterations,
        sizes[sort_index],
        color="lightgray",
        alpha=0.7,
        width=0.8,
        zorder=0,
    )

# ---- Plot 1 : Scatter per agent (left figure) ----
# (results key, legend label, colour key, marker)
scatter_specs = [
    ("local", "Local", "Local", "^"),
    ("grandmean", "Grand Mean", "Grand Mean", "s"),
    ("fed", "Q-aggreg", "Fed", "o"),
]
for result_key, label, color_key, marker in scatter_specs:
    axes[0].scatter(
        iterations,
        np.array(results[result_key])[sort_index],
        s=20,  # size of points
        label=label,
        color=method_colors[color_key],
        alpha=0.7,
        zorder=1,
        marker=marker,
    )

axes[0].set_xlabel("Agents", fontsize=16)
axes[0].set_ylabel("Accuracy", fontsize=16)
axes[0].legend(
    fontsize=14,
    loc="upper center",
    bbox_to_anchor=(0.5, 0.2),
    ncol=3,
)
axes[0].grid(True)
axes[0].set_xticks([])

# ---- Plot 2 : Boxplot (right) ----
# NOTE: the "MSE" column holds the accuracies collected in ``results``; the
# column name is kept unchanged and the axis label is cleared below.
df = pd.DataFrame(
    {
        "MSE": results["local"] + results["grandmean"] + results["fed"],
        "Method": ["Local"] * len(results["local"])
        + ["Grand Mean"] * len(results["grandmean"])
        + ["Fed"] * len(results["fed"]),
    }
)

sns.boxplot(
    x="Method",
    y="MSE",
    data=df,
    order=method_order,
    palette=method_colors,
    ax=axes[1],
    width=0.5,
)

# Overlay at most 20 randomly drawn points per method. Groups are sampled
# one by one (instead of through groupby.apply) so that the "Method" column
# is always kept in the result.
sampled_groups = [
    group.sample(min(len(group), 20))
    for _, group in df.groupby("Method", sort=False)
]
subset_df = (
    pd.concat(sampled_groups, ignore_index=True) if sampled_groups else df.copy()
)

sns.stripplot(
    x="Method",
    y="MSE",
    data=subset_df,
    order=method_order,
    color="black",
    size=5,
    jitter=True,
    alpha=0.6,
    ax=axes[1],
)

axes[1].set_xlabel("")
# Labels follow ``method_order``
axes[1].set_xticklabels(["Local", "G. Mean", "Q-aggreg"], fontsize=15)
axes[1].set_ylabel("")

plt.tight_layout()
plt.show()
