# %%
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import utils as u
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.colors import LogNorm


import seaborn as sns
import numpy as np

# Ten-colour seaborn "colorblind" palette.
# NOTE(review): `cmap` is not referenced anywhere else in the visible code.
cmap = sns.color_palette("colorblind", 10)


# %%
class Regressor(nn.Module):
    """Fully connected ReLU network mapping ``d`` features to one scalar.

    The stack is ``Linear(d, hidden_dim) -> ReLU``, followed by
    ``n_hidden_layers - 1`` further ``Linear(hidden_dim, hidden_dim) -> ReLU``
    pairs, and a final ``Linear(hidden_dim, 1)``.  At least one hidden layer
    is always present, even when ``n_hidden_layers`` is 0 or negative.
    """

    def __init__(self, d, hidden_dim=32, n_hidden_layers=1):
        super().__init__()

        # Input width of every hidden Linear layer, in order: the first one
        # reads the raw features, the others read the previous hidden layer.
        input_widths = [d] + [hidden_dim] * max(n_hidden_layers - 1, 0)

        stack = []
        for width in input_widths:
            stack += [nn.Linear(width, hidden_dim), nn.ReLU()]
        stack.append(nn.Linear(hidden_dim, 1))

        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the network to ``x``; the last dimension of the output is 1."""
        prediction = self.net(x)
        return prediction


# %%
def _agent_tensors(X, y, device):
    """Return (features, targets) as float32 tensors, targets as an (n, 1) column."""
    X_t = torch.as_tensor(X, dtype=torch.float32, device=device)
    y_t = torch.as_tensor(y, dtype=torch.float32, device=device)
    if y_t.ndim == 1:
        y_t = y_t.view(-1, 1)
    return X_t, y_t


def train_weighted_regressor(
    X_list,
    y_list,
    weights,
    d,
    Xtest,
    ytest,
    n_epochs=200,
    lr=1e-3,
    hidden_dim=32,
    n_hidden_layers=1,
    device="cpu",
):
    """
    Train a small perceptron for regression on a weighted mix of agents.

    Each epoch performs one full-batch Adam step on
    ``sum_t weights[t] * MSE(model(X_list[t]), y_list[t])``, where the weights
    are first normalised to sum to one.

    Inputs :
    - X_list : list of array-like, one per agent, shape (n_t, d)
    - y_list : list of array-like, one per agent, shape (n_t, 1) or (n_t,)
    - weights : array-like, one weight per agent (normalised internally)
    - d : feature dimension
    - Xtest, ytest : test data; ytest may be (n,) or (n, 1)
    - n_epochs, lr, hidden_dim, n_hidden_layers : hyperparameters
    - device : torch device on which the model and all tensors are placed

    Output :
    - model : the trained Regressor
    - train_errors : list of floats, weighted train MSE after each epoch
    - test_errors : list of floats, test MSE after each epoch

    Raises :
    - ValueError if X_list is empty, if X_list, y_list and weights do not all
      have the same length, or if the weights sum to zero.
    """

    T = len(X_list)
    if T == 0:
        raise ValueError("X_list must contain at least one agent")
    if len(y_list) != T or len(weights) != T:
        raise ValueError(
            "X_list, y_list and weights must have the same length "
            f"(got {T}, {len(y_list)} and {len(weights)})"
        )

    weights = torch.as_tensor(weights, dtype=torch.float32, device=device)
    weight_sum = weights.sum()
    if weight_sum.item() == 0:
        # Dividing by zero would silently turn every loss into NaN.
        raise ValueError("weights must not sum to zero")
    weights = weights / weight_sum  # normalisation

    # The data never changes during training, so convert it to tensors once
    # instead of twice per epoch.
    agents = [_agent_tensors(X, y, device) for X, y in zip(X_list, y_list)]
    Xtest_t, ytest_t = _agent_tensors(Xtest, ytest, device)

    model = Regressor(d, hidden_dim=hidden_dim, n_hidden_layers=n_hidden_layers).to(
        device
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss(reduction="mean")

    train_errors = []
    test_errors = []

    for epoch in range(n_epochs):
        model.train()
        optimizer.zero_grad()
        total_loss = 0.0

        # Forward pass for every agent, accumulated into one weighted loss
        for t, (X, y) in enumerate(agents):
            total_loss += weights[t] * loss_fn(model(X), y)

        total_loss.backward()
        optimizer.step()

        # Errors are measured after the update of this epoch
        model.eval()
        with torch.no_grad():
            # Weighted train error
            mse_train = 0.0
            for t, (X, y) in enumerate(agents):
                mse_train += weights[t] * torch.mean((model(X) - y) ** 2)
            train_errors.append(mse_train.item())

            # Test error
            mse_test = torch.mean((model(Xtest_t) - ytest_t) ** 2)
            test_errors.append(mse_test.item())

        if epoch % 50 == 0:
            print(
                f"Epoch {epoch:4d} | Train MSE = {mse_train.item():.4e} | Test MSE = {mse_test.item():.4e}"
            )

    return model, train_errors, test_errors


# %% Data generation
def sample_multidim_gaussian(mean_2d, cov_2d, n, d):
    """
    Draw n points of R^d from NumPy's global random state.

    - The first two coordinates follow N(mean_2d, cov_2d)
    - The remaining d - 2 coordinates are independent standard Gaussians

    Returns an array of shape (n, d).
    """
    leading = np.random.multivariate_normal(mean_2d, cov_2d, size=n)
    trailing = np.random.standard_normal((n, d - 2))
    return np.concatenate((leading, trailing), axis=1)


def generate_agent_data_clustered_covshift(
    n_samples=100, d=10, noise_std=0.1, agent_id=0, n_close=5, n_medium=10, mean=None
):
    """
    Generate one agent's regression data under covariate shift.

    Every agent shares the same p(y|x) but draws x from its own p(x), chosen
    from ``agent_id`` (unless ``mean`` is given):

    - ``agent_id < n_close``: Gaussian with std 0.4 around a mean close to 0
    - ``n_close <= agent_id < n_close + n_medium``: Gaussian with std 0.8
      around a mean close to 2
    - otherwise: uniform on [-6, 6]^d

    The random generator is seeded with ``agent_id``, so the output is fully
    determined by the arguments.

    Inputs :
    - n_samples : number of points to draw
    - d : feature dimension, at least 2
    - noise_std : standard deviation of the additive Gaussian noise on y
    - agent_id : agent index; selects the group and seeds the generator
    - n_close, n_medium : sizes of the first two groups
    - mean : optional; when given, x is drawn from a Gaussian with std 0.4
      around it, whatever the group of ``agent_id``

    Output :
    - X : array of shape (n_samples, d)
    - y : array of shape (n_samples, 1)
    - mean : the mean used for X (zeros for the uniform group)

    Raises :
    - ValueError if d < 2
    """

    # The target function reads the first two coordinates explicitly.
    if d < 2:
        raise ValueError(f"d must be at least 2 (got {d})")
    rng = np.random.default_rng(agent_id)

    # ---------- Distribution of X ----------
    if mean is not None:
        X = rng.normal(mean, 0.4, size=(n_samples, d))
    elif agent_id < n_close:
        mean = 0.1 * rng.normal(size=d)
        X = rng.normal(mean, 0.4, size=(n_samples, d))
    elif agent_id < n_close + n_medium:
        mean = rng.normal(2, 0.3, size=d)
        X = rng.normal(mean, 0.8, size=(n_samples, d))
    else:
        X = rng.uniform(-6, 6, size=(n_samples, d))
        mean = np.zeros(d)

    # ---------- True function ----------
    X1, X2 = X[:, 0], X[:, 1]
    signal = np.sin(3 * X1) + 0.5 * X2**2 + 0.1 * np.sum(X[:, 2:], axis=1)

    y = signal + noise_std * rng.standard_normal(n_samples)

    return X, y[:, None], mean


def generate_federated_dataset_clustered_covshift(
    n_agents=20,
    n_samples_per_agent=50,
    d=10,
    noise_std=0.1,
    n_close=5,
    n_medium=10,
):
    """
    Build the data of agents 0 .. n_agents - 1.

    Returns a list with one dict per agent, holding the keys "X", "y",
    "agent_id" and "mean" as produced by
    generate_agent_data_clustered_covshift for that agent id.
    """

    def build_record(agent_id):
        features, targets, center = generate_agent_data_clustered_covshift(
            n_samples=n_samples_per_agent,
            d=d,
            noise_std=noise_std,
            agent_id=agent_id,
            n_close=n_close,
            n_medium=n_medium,
        )
        return {"X": features, "y": targets, "agent_id": agent_id, "mean": center}

    return [build_record(agent_id) for agent_id in range(n_agents)]


# %%
# ---------- Experiment configuration ----------
D = 500  # second argument of u.RFF

# Optimisation settings shared by every training run
n_epochs = 2000
lr = 0.001

# Federated problem
T = 100  # number of agents
d = 4  # feature dimension
# NOTE(review): n_samples is never read below; the data-generation call in the
# loop hard-codes n_samples_per_agent=10 — confirm which value is intended.
n_samples = 20
cov_scale = 0.5  # not read in the visible code
noise_std = 0.2

n_close = 30
n_medium = 30
ntest = 2000

M = 1
u0 = 1
c0 = np.sqrt(u0)
cbs = u0

n_repet = 20

# ---------- Fixed weighting schemes ----------
# Agents 0 .. n_close - 1 are the ones whose inputs are drawn like agent 0's.
oracle = np.arange(T) < n_close

weights_loc = (np.arange(T) == 0).astype(float)  # agent 0 only
weights_gm = np.full(T, 1.0 / T)  # uniform over all agents
weights_or = oracle / oracle.sum()  # uniform over the close group
weights_fed = np.zeros(T)  # placeholder, recomputed at every repetition

architectures = [
    {"n_hidden_layers": depth, "hidden_dim": width}
    for depth in (1, 2, 3)
    for width in (16, 32)
]

methods = {
    "Local": weights_loc,
    "GrandMean": weights_gm,
    "Fed": weights_fed,
    "Oracle": weights_or,
}

# One list of best test MSEs per (architecture index, method name)
results = {
    (arch_id, method_name): []
    for arch_id in range(len(architectures))
    for method_name in methods
}


# ---------- Repetitions ----------
# NOTE(review): generate_agent_data_clustered_covshift seeds its generator with
# agent_id only, so `data` and the test set are the same at every repetition.
for i in range(n_repet):
    data = generate_federated_dataset_clustered_covshift(
        n_agents=T,
        n_samples_per_agent=10,
        d=d,
        noise_std=noise_std,
        n_close=n_close,
        n_medium=n_medium,
    )

    X = [agent["X"] for agent in data]
    y = [agent["y"] for agent in data]

    # Test set drawn around the mean of agent 0
    Xtest, ytest, _ = generate_agent_data_clustered_covshift(
        n_samples=ntest,
        d=d,
        noise_std=noise_std,
        agent_id=0,
        n_close=n_close,
        n_medium=n_medium,
        mean=data[0]["mean"],
    )

    # Weights for the "Fed" method, computed from the agents' inputs
    RFFs, _, _ = u.RFF(X, D)
    weights_fed = u.Qaggregation(RFFs, 1, c0=c0, cbs=cbs)
    methods["Fed"] = weights_fed

    for arch_id, arch in enumerate(architectures):
        print(
            f"\n=== Architecture {arch_id} | "
            f"{arch['n_hidden_layers']} layers, "
            f"{arch['hidden_dim']} units ==="
        )

        for method_name, weights in methods.items():

            print(f"  → {method_name}")

            # `arch` holds exactly the hidden_dim / n_hidden_layers keywords
            model, train_loss, test_loss = train_weighted_regressor(
                X,
                y,
                weights,
                d,
                Xtest,
                ytest,
                n_epochs=n_epochs,
                lr=lr,
                **arch,
            )

            # Keep the lowest test error reached over the epochs
            best_test_mse = np.min(test_loss)
            results[(arch_id, method_name)].append(best_test_mse)

            print(f" best test MSE = {best_test_mse:.3e}")


# %% test figures
# Flatten `results` into one row per (architecture, method, repetition)
rows = []
for arch_id, arch in enumerate(architectures):
    # Columns that only depend on the architecture
    arch_fields = {
        "Architecture": arch_id,
        "Layers": arch["n_hidden_layers"],
        "Hidden dim": arch["hidden_dim"],
    }
    for method in methods:
        for mse in results[(arch_id, method)]:
            rows.append({**arch_fields, "Method": method, "Test MSE": mse})

df_results = pd.DataFrame(rows)

# %%


# Relabel "Fed" for display on a separate frame, so that df_results keeps the
# method names used as keys in `methods` / `results`.
df_resultsplot = df_results.assign(
    Method=df_results["Method"].replace({"Fed": "Q-aggreg."})
)

sns.set_theme(style="whitegrid")

plt.figure(figsize=(7, 4))  # 7 x 4 inches

# One tick label per architecture, e.g. "2L-16H"
arch_labels = [
    f"{arch['n_hidden_layers']}L-{arch['hidden_dim']}H" for arch in architectures
]

# Boxplot of the best test MSE, grouped by architecture and coloured by method
sns.boxplot(
    data=df_resultsplot,
    x="Architecture",
    y="Test MSE",
    hue="Method",
)

fontsize = 18
plt.xticks(ticks=range(len(architectures)), labels=arch_labels, fontsize=fontsize - 2)

plt.xlabel("Architecture (Layers-HiddenDim)", fontsize=fontsize)
plt.ylabel("Test MSE", fontsize=fontsize)
# plt.yscale("log")

plt.legend()
plt.tight_layout()  # tighten the margins automatically
plt.savefig("boxplot_covshift_ncnm30_u01_D500_ntest2000_sigstd02.pdf")
plt.show()


# %% Similarity matrix generation
# Row i holds the weights returned by u.Qaggregation when called with i=i
# (presumably agent i is then the target agent — confirm against utils).
# NOTE(review): RFFs is whatever the last repetition of the experiment loop
# left behind.
Mat = np.zeros((T, T))
for i in range(T):
    Mat[i] = u.Qaggregation(RFFs, 1, i=i, c0=c0, cbs=cbs)

# Plot the matrix `Mat`, not the scalar constant `M`. The small offset avoids
# log(0) under LogNorm.
Mat_plot = np.abs(Mat) + 1e-12

plt.figure(figsize=(6, 5))

plt.imshow(Mat_plot, cmap="viridis", norm=LogNorm())

cbar = plt.colorbar()

plt.grid(False)
plt.xticks([])
plt.yticks([])
plt.tight_layout()
plt.show()
