import torch
import torch.nn as nn
import numpy as np
from copy import deepcopy
from regularizer import loss_ridge, loss_conf, loss_itv
from sklearn.metrics import roc_auc_score, average_precision_score

def train_adam(model, X, lr, max_iter, lam_h, lam_c, lam_v, lookback=5, check_every=10,
               verbose=True, label=None, device="cpu"):
    """Train ``model`` with Adam and return the best checkpoint seen.

    Data convention:
      - ``X`` has shape ``[num_env, B, T, d]`` with ``num_env == model.n``,
        ``d == model.d`` and ``T == model.T + 1``.
      - For each slice ``x = X[idx]``, network ``i`` predicts ``x[:, 1:, i]``
        from ``x[:, :-1, :]`` (one-step-ahead prediction, MSE loss).
      - The reported loss is the total objective divided by ``d``.

    Args:
        model: Module exposing ``n``, ``T``, ``d``, an indexable ``networks``
            collection whose members are called as ``net(x, idx=idx)``, and
            (only needed when ``label`` is given) ``est_gc(threshold=None)``.
        X: Input tensor, see the data convention above.
        lr: Adam learning rate.
        max_iter: Maximum number of optimisation steps.
        lam_h: Weight passed to ``loss_ridge``; the term is skipped if <= 0.
        lam_c: Weight passed to ``loss_conf``; the term is skipped if <= 0.
        lam_v: Weight passed to ``loss_itv``; the term is skipped if <= 0.
        lookback: Number of consecutive non-improving checks tolerated
            before stopping early.
        check_every: Evaluate/report every this many iterations.
        verbose: Print progress (and AUROC/AUPRC when ``label`` is given).
            Does not affect checkpointing or early stopping.
        label: Optional non-empty NumPy array with the same shape as
            ``model.est_gc(threshold=None)``; used only for reporting.
        device: Device on which the model and data are placed.

    Returns:
        A deep copy of ``model.state_dict()`` taken at the check step with
        the lowest reported loss. If no check step recorded an improvement
        (e.g. ``max_iter < check_every``), the final weights are returned.

    Raises:
        AssertionError: If the shape of ``X`` is inconsistent with ``model``.
        ValueError: If ``label`` and the estimated matrix differ in shape.
    """
    model = model.to(device)
    X = X.to(device)

    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss(reduction="mean")

    num_env, B, T, d = X.shape

    assert d == model.d, "X last dim must equal model.d"
    assert T == model.T + 1, "X time length must equal model.T + 1"
    assert num_env == model.n, "X first dim must equal model.n"

    # Patience is measured from iteration 0 until a first checkpoint exists,
    # so a run whose loss is never finite still terminates early.
    best_it = 0
    best_loss = np.inf
    best_model = None

    for it in range(max_iter):

        # Prediction loss: summed over environments and over target nodes.
        loss_p = 0.0
        for idx, x in enumerate(X):
            loss_p += sum(
                loss_fn(model.networks[i](x[:, :-1, :], idx=idx), x[:, 1:, i:i + 1])
                for i in range(d)
            )

        # Regularization (each term is skipped when its weight is not positive).
        loss_h = sum(loss_ridge(net, lam_h) for net in model.networks) if lam_h > 0.0 else 0.0
        loss_c = sum(loss_conf(net, lam_c) for net in model.networks) if lam_c > 0.0 else 0.0
        loss_v = sum(loss_itv(net, lam_v) for net in model.networks) if lam_v > 0.0 else 0.0
        loss = loss_p + loss_h + loss_c + loss_v

        # Clear gradients before backward so stale gradients present at entry
        # cannot leak into the first update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Check Process
        if (it + 1) % check_every != 0:
            continue

        # Plain float: avoids keeping autograd references alive in best_loss.
        mean_loss = loss.item() / d

        if verbose:
            print(('-' * 10 + 'Iter = %d' + '-' * 10) % (it + 1))
            print('Loss = %f' % mean_loss)

            if label is not None and label.size > 0:  # Ensure label is a valid non-empty array
                est = model.est_gc(threshold=None).detach().cpu().numpy()
                # Check if label and est shapes match
                if label.shape != est.shape:
                    raise ValueError(f"Shape Mismatch: Label Size: {label.shape}; Pred Size: {est.shape}")

                auc_roc = roc_auc_score(label.flatten(), est.flatten())
                au_prc = average_precision_score(label.flatten(), est.flatten())
                print(f"AUROC: {auc_roc:.4f}, AUPRC: {au_prc:.4f}")

        # Early stopping / checkpointing runs regardless of verbosity.
        if mean_loss < best_loss - 1e-8:
            best_loss = mean_loss
            best_it = it
            best_model = deepcopy(model.state_dict())
        elif it - best_it >= lookback * check_every:
            if verbose:
                print(f"Early stopping at iter {it}")
            break

    if best_model is None:
        # No checkpoint was recorded; fall back to the final weights rather
        # than handing the caller ``None``.
        best_model = deepcopy(model.state_dict())
    return best_model