import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

##################################################
# ReLU MLP baseline: input -> hidden (ReLU) -> scalar output (two linear layers)
##################################################

class MLP(nn.Module):
    """One-hidden-layer ReLU network with a scalar output.

    The layers live in ``self.net`` as ``Linear(d, width) -> ReLU ->
    Linear(width, 1)``.

    Args:
        d: Size of the last (feature) dimension of the input.
        width: Number of hidden units.
    """

    def __init__(self, d, width):
        super().__init__()
        # Layers are created in network order so that seeded initialisation
        # draws random numbers in the same sequence.
        hidden = nn.Linear(d, width)
        activation = nn.ReLU()
        readout = nn.Linear(width, 1)
        self.net = nn.Sequential(hidden, activation, readout)

    def forward(self, x):
        """Apply the network to ``x`` and drop the trailing size-1 output axis."""
        out = self.net(x)
        return out.squeeze(-1)


def train_mlp(X, y, d, width, lr=1e-3, iters=1000, seed=0):
    """Train an ``MLP`` on ``(X, y)`` with full-batch Adam on the MSE loss.

    Args:
        X: Array-like inputs; converted to a float32 tensor.
        y: Array-like regression targets; converted to a float32 tensor. A
            target whose shape differs from the prediction shape but has the
            same number of elements (e.g. a column ``(n, 1)``) is reshaped to
            match the predictions.
        d: Input feature dimension passed to ``MLP``.
        width: Hidden width passed to ``MLP``.
        lr: Adam learning rate.
        iters: Number of full-batch gradient steps.
        seed: Seed applied to both the torch and the global NumPy RNG.

    Returns:
        Tuple ``(model, losses)`` where ``losses`` is a NumPy array holding
        the loss computed at each iteration before that iteration's step.
    """
    # NOTE: reseeds the *global* torch and NumPy generators.
    torch.manual_seed(seed)
    np.random.seed(seed)

    X_t = torch.tensor(X, dtype=torch.float32)
    y_t = torch.tensor(y, dtype=torch.float32)

    model = MLP(d, width)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    losses = []

    for _ in range(iters):
        optimizer.zero_grad()

        pred = model(X_t)

        # Without this, a column target (n, 1) would broadcast against the
        # 1-D prediction (n,) into an (n, n) matrix and the optimiser would
        # minimise the wrong quantity. No-op once the shapes agree.
        if y_t.shape != pred.shape and y_t.numel() == pred.numel():
            y_t = y_t.reshape(pred.shape)

        loss = ((pred - y_t) ** 2).mean()

        loss.backward()
        optimizer.step()

        losses.append(loss.item())

    return model, np.array(losses)

def _aligned_mse(pred, target):
    """Return the mean squared error between ``pred`` and ``target``.

    If ``target`` has the same number of elements as ``pred`` but a different
    shape (e.g. a column ``(n, 1)`` against ``(n,)``), it is reshaped to
    ``pred.shape`` first so the difference is element-wise instead of being
    broadcast into an ``(n, n)`` matrix.
    """
    target = np.asarray(target)
    if target.shape != pred.shape and target.size == pred.size:
        target = target.reshape(pred.shape)
    return np.mean((pred - target) ** 2)


def eval_mlp(model, X_train, y_train, X_test, y_test):
    """Evaluate ``model`` on a train and a test set.

    Args:
        model: Callable mapping a float32 tensor of inputs to a tensor of
            predictions.
        X_train: Array-like training inputs.
        y_train: Array-like training targets.
        X_test: Array-like test inputs.
        y_test: Array-like test targets.

    Returns:
        Tuple ``(train_mse, test_mse, yhat_tr, yhat_te)``: the two mean
        squared errors followed by the train and test predictions as NumPy
        arrays.
    """
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        Xtr = torch.tensor(X_train, dtype=torch.float32)
        Xte = torch.tensor(X_test, dtype=torch.float32)

        yhat_tr = model(Xtr).cpu().numpy()
        yhat_te = model(Xte).cpu().numpy()

    train_mse = _aligned_mse(yhat_tr, y_train)
    test_mse = _aligned_mse(yhat_te, y_test)

    return train_mse, test_mse, yhat_tr, yhat_te


def _mean_and_std(values):
    """Return ``(mean, std)`` of ``values``, or ``(nan, nan)`` when empty."""
    if values.size == 0:
        return float("nan"), float("nan")
    return float(values.mean()), float(values.std())


def run_mlp_experiment(
    X_train, y_train,
    X_test, y_test,
    X_train_test,
    y_train_test,
    d,
    width=64,
    seeds=10,
    lr=1e-3,
    iters=2000
):
    """Train and evaluate one MLP per seed and collect the results.

    For each seed ``s`` in ``range(seeds)`` a model is trained with
    ``train_mlp``, scored with ``eval_mlp`` and then applied to
    ``X_train_test``.

    Args:
        X_train, y_train: Training inputs and targets.
        X_test, y_test: Test inputs and targets.
        X_train_test: Inputs on which every trained model's predictions are
            recorded.
        y_train_test: Accepted for interface compatibility; not used here.
        d: Input feature dimension.
        width: Hidden width of the MLP.
        seeds: Number of seeds (``0 .. seeds - 1``) to run.
        lr: Adam learning rate.
        iters: Number of training iterations per seed.

    Returns:
        Dict with the keys

        * ``"losses"`` / ``"loss_curves"``: list of per-seed loss arrays,
        * ``"train_mse"`` / ``"all_train"``: array of per-seed train MSEs,
        * ``"test_mse"`` / ``"all_test"``: array of per-seed test MSEs,
        * ``"preds"``: list of per-seed predictions on ``X_train_test``,
        * ``"train_mean"``, ``"train_std"``, ``"test_mean"``, ``"test_std"``:
          summary statistics over seeds (NaN when ``seeds`` is 0).

        The second name in each pair, and the summary statistics, are the
        keys that ``plot_mlp_results`` reads.
    """
    all_losses = []
    final_train_mse = []
    final_test_mse = []
    final_preds = []

    # Same input for every seed, so convert it once.
    X_all = torch.tensor(X_train_test, dtype=torch.float32)

    for s in range(seeds):

        model, losses = train_mlp(
            X_train, y_train,
            d=d,
            width=width,
            lr=lr,
            iters=iters,
            seed=s
        )

        train_mse, test_mse, _, _ = eval_mlp(
            model, X_train, y_train, X_test, y_test
        )

        with torch.no_grad():
            yhat_all = model(X_all).cpu().numpy()

        all_losses.append(losses)
        final_train_mse.append(train_mse)
        final_test_mse.append(test_mse)
        final_preds.append(yhat_all)

    train_arr = np.array(final_train_mse)
    test_arr = np.array(final_test_mse)
    train_mean, train_std = _mean_and_std(train_arr)
    test_mean, test_std = _mean_and_std(test_arr)

    return {
        # Original keys, kept for existing callers.
        "losses": all_losses,
        "train_mse": train_arr,
        "test_mse": test_arr,
        "preds": final_preds,
        # Keys consumed by plot_mlp_results.
        "loss_curves": all_losses,
        "all_train": train_arr,
        "all_test": test_arr,
        "train_mean": train_mean,
        "train_std": train_std,
        "test_mean": test_mean,
        "test_std": test_std,
    }

def _first_present(results, *keys):
    """Return ``results[k]`` for the first ``k`` in ``keys`` that is present.

    Raises:
        KeyError: If none of ``keys`` is in ``results``.
    """
    for key in keys:
        if key in results:
            return results[key]
    raise KeyError(f"results is missing all of the keys: {keys}")


def plot_mlp_results(results, title="MLP performance over seeds"):
    """Plot a train/test MSE summary and the per-seed training-loss curves.

    Args:
        results: Dict of experiment results. Per-seed MSEs are read from
            ``"all_train"`` / ``"all_test"`` (falling back to
            ``"train_mse"`` / ``"test_mse"``) and loss curves from
            ``"loss_curves"`` (falling back to ``"losses"``), so the dict
            returned by ``run_mlp_experiment`` can be passed directly.
            ``"train_mean"``, ``"test_mean"``, ``"train_std"`` and
            ``"test_std"`` are used when present and otherwise computed from
            the per-seed MSEs.
        title: Title of the summary bar plot.

    Returns:
        Tuple ``(fig, fig2)``: the bar-plot figure and the loss-curve figure.

    Raises:
        KeyError: If neither name for a required entry is present.
    """
    train = np.asarray(_first_present(results, "all_train", "train_mse"), dtype=float)
    test = np.asarray(_first_present(results, "all_test", "test_mse"), dtype=float)
    curves = _first_present(results, "loss_curves", "losses")

    # Prefer precomputed statistics; otherwise derive them from the raw values.
    train_mean = results["train_mean"] if "train_mean" in results else train.mean()
    test_mean = results["test_mean"] if "test_mean" in results else test.mean()
    train_std = results["train_std"] if "train_std" in results else train.std()
    test_std = results["test_std"] if "test_std" in results else test.std()

    ##################################################
    # 1) Bar / summary plot (train vs test)
    ##################################################
    fig, ax = plt.subplots(1, 1, figsize=(6, 4))

    ax.bar(["train", "test"], [train_mean, test_mean],
           yerr=[train_std, test_std],
           capsize=5)

    ax.set_title(title)
    ax.set_ylabel("MSE")

    ##################################################
    # 2) Loss curves over seeds
    ##################################################
    fig2, ax2 = plt.subplots(1, 1, figsize=(7, 4))

    # plot individual seeds (light)
    for c in curves:
        ax2.plot(c, color="gray", alpha=0.3)

    # plot mean curve; skipped when there are no curves, since max() of an
    # empty sequence would raise.
    if len(curves) > 0:
        max_len = max(len(c) for c in curves)

        # Shorter curves are extended with their last value so that all
        # curves can be averaged position by position.
        padded = np.array([
            np.pad(c, (0, max_len - len(c)), mode="edge")
            for c in curves
        ])

        mean_curve = padded.mean(axis=0)

        ax2.plot(mean_curve, color="black", linewidth=2, label="mean loss")
        ax2.legend()

    ax2.set_title("Training loss over initializations")
    ax2.set_xlabel("iteration")
    ax2.set_ylabel("loss")

    plt.show()

    return fig, fig2
