from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, precision_recall_curve, auc, average_precision_score
import numpy as np
import torch
import torch.nn as nn
from torch.optim import AdamW
import torch.nn.functional as F

class EnergyMLP(nn.Module):
    """Multi-layer perceptron that outputs one energy value per class.

    The network is ``[Linear -> activation (-> Dropout)] * len(hidden_sizes)``
    followed by a final ``Linear`` to ``num_classes`` outputs. The outputs are
    treated as energies: the training/evaluation code negates them (``-E``)
    to obtain logits, so a lower energy means a more likely class.

    Args:
        input_dim: size of the last input dimension.
        num_classes: number of energies produced per sample.
        hidden_sizes: widths of the hidden layers, applied in order. An empty
            sequence yields a single linear layer. Defaults to one hidden layer
            of width 512.
        dropout: dropout probability inserted after each hidden activation;
            no dropout layers are added when this is falsy or <= 0.
        activation: one of ``"relu"``, ``"gelu"``, ``"tanh"``. Any other value
            silently falls back to GELU.
    """

    def __init__(self, input_dim, num_classes, hidden_sizes=(512,), dropout=0.0, activation="gelu"):
        super().__init__()
        # Immutable tuple default: avoids the shared-mutable-default pitfall.
        act = {"relu": nn.ReLU, "gelu": nn.GELU, "tanh": nn.Tanh}.get(activation, nn.GELU)

        layers = []
        in_dim = input_dim
        for h in hidden_sizes:
            layers += [nn.Linear(in_dim, h), act()]
            # `dropout and ...` also tolerates dropout=None.
            if dropout and dropout > 0:
                layers.append(nn.Dropout(dropout))
            in_dim = h

        layers.append(nn.Linear(in_dim, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return per-class energies; the last dimension has size ``num_classes``."""
        return self.net(x)

def energy_nll_loss(E, y):
    Ey = E.gather(1, y.view(-1,1)).squeeze(1)            
    lse = torch.logsumexp(-E, dim=1)                  
    return (Ey + lse).mean()

@torch.no_grad()
def batched_energy_scores(model, loader, device, temperature: float = 1.0):
    """Return the free-energy score of every sample yielded by ``loader``.

    The model's outputs are per-class energies ``E``; the logits are ``-E``.
    The score is the temperature-scaled free energy
    ``-T * logsumexp(logits / T)`` (energy-based OOD scoring, Liu et al. 2020).
    Higher values mean higher energy, i.e. lower model confidence.
    Multiplying by ``T`` keeps scores on the same scale as the energies. It
    does not change the sample ranking, so ranking metrics are unaffected,
    but absolute thresholds derived from the scores are.

    Args:
        model: module mapping a batch of inputs to per-class energies; it is
            switched to eval mode and left there.
        loader: iterable of batches; each batch is either an input tensor or a
            tuple/list whose first element is the input tensor.
        device: device the inputs are moved to before the forward pass.
        temperature: strictly positive temperature ``T``.

    Returns:
        1-D numpy array with one score per sample, in loader order (empty when
        the loader yields nothing).

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    model.eval()
    chunks = []
    for batch in loader:
        # Bare tensor batches must not be indexed: batch[0] would be one sample.
        X = batch if isinstance(batch, torch.Tensor) else batch[0]
        logits = -model(X.to(device))
        # Division and multiplication by 1.0 are exact, so T == 1 matches the
        # plain -logsumexp(logits) score bit-for-bit.
        energy = -temperature * torch.logsumexp(logits / temperature, dim=1)
        chunks.append(energy.cpu())  # no .detach() needed under no_grad

    if not chunks:
        return np.empty(0, dtype=np.float32)
    return torch.cat(chunks, dim=0).numpy()

def fpr_at_tpr(id_scores: np.ndarray, ood_scores: np.ndarray, tpr: float = 0.95):
    """False-positive rate on OOD data at a threshold that keeps ``tpr`` of ID data.

    Samples with score <= threshold are accepted as in-distribution. The
    threshold is the ``tpr`` quantile of the ID scores, so roughly a ``tpr``
    fraction of ID samples is accepted. The FPR is the fraction of OOD
    samples that are (wrongly) accepted under the same rule.

    Returns:
        ``(fpr, threshold)`` where ``fpr`` is a Python float and ``threshold``
        is the value returned by ``np.quantile``.
    """
    cutoff = np.quantile(id_scores, tpr)
    wrongly_accepted = ood_scores <= cutoff
    rate = float(np.mean(wrongly_accepted))
    return rate, cutoff

def expected_calibration_error(y_true, y_score, n_bins=15):
    """Expected calibration error with ``n_bins`` equal-width confidence bins on [0, 1].

    Confidence is the max class score per row and the prediction is its
    argmax. Every bin is half-open ``[lo, hi)`` except the last, which also
    includes 1.0. ECE is the sum over non-empty bins of
    ``(bin size / n) * |accuracy(bin) - mean confidence(bin)|``.

    Args:
        y_true: integer labels, one per row of ``y_score``.
        y_score: 2-D array of per-class scores (rows = samples).
        n_bins: number of bins.

    Returns:
        ECE as a Python float.
    """
    labels = np.asarray(y_true)
    scores = np.asarray(y_score)
    conf = scores.max(axis=1)
    correct = (scores.argmax(axis=1) == labels).astype(float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    last_bin = n_bins - 1

    total = 0.0
    for b, (lo, hi) in enumerate(zip(edges[:-1], edges[1:])):
        # Only the final bin closes on the right so that confidence 1.0 is counted.
        below_hi = (conf <= hi) if b == last_bin else (conf < hi)
        members = (conf >= lo) & below_hi
        if not members.any():
            continue
        gap = abs(correct[members].mean() - conf[members].mean())
        total += members.mean() * gap  # members.mean() == |B_b| / n

    return float(total)

def _train_energy_epoch(model, loader, optimizer, device, max_grad_norm):
    """Run one optimisation pass over ``loader``; return the sample-weighted mean loss (nan if empty)."""
    model.train()
    running = torch.zeros((), device=device)
    count = 0
    for X, y in loader:
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        loss = energy_nll_loss(model(X), y)
        loss.backward()
        if max_grad_norm is not None:
            nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        # Accumulate on-device to avoid a host sync per batch.
        running += loss.detach() * y.size(0)
        count += y.size(0)
    return running.item() / count if count else float("nan")


@torch.no_grad()
def _evaluate_energy_classifier(model, loader, device, num_classes):
    """Compute accuracy, macro-F1, NLL, ECE, ROC AUC and (binary only) PR AUC on ``loader``."""
    import warnings

    model.eval()
    label_chunks, pred_chunks, prob_chunks = [], [], []
    nll_sum, n_samples = 0.0, 0
    for X, y in loader:
        X, y = X.to(device), y.to(device)
        logits = -model(X)  # energies -> logits
        nll_sum += F.cross_entropy(logits, y, reduction="sum").item()
        n_samples += y.size(0)
        label_chunks.append(y.cpu().numpy())
        pred_chunks.append(logits.argmax(1).cpu().numpy())
        prob_chunks.append(torch.softmax(logits, dim=1).cpu().numpy())

    if n_samples == 0:
        raise ValueError("test_loader yielded no samples; cannot compute metrics")

    y_true = np.concatenate(label_chunks)
    y_pred = np.concatenate(pred_chunks)
    y_score = np.concatenate(prob_chunks)

    # Decide binary vs multiclass from the model's output width, not from the
    # labels that happen to occur: with num_classes > 2 and only two labels
    # present, column 1 is not the positive-class score.
    binary = num_classes == 2 and y_score.shape[1] == 2
    try:
        if binary:
            roc = roc_auc_score(y_true, y_score[:, 1])
        else:
            roc = roc_auc_score(y_true, y_score, multi_class="ovr")
    except ValueError as err:
        # ROC AUC is undefined e.g. when a class is absent from the test set;
        # report None instead of discarding the whole training run.
        warnings.warn(f"ROC AUC could not be computed: {err}")
        roc = None

    pr_auc = None
    if binary and np.unique(y_true).size == 2:
        precision, recall, _ = precision_recall_curve(y_true, y_score[:, 1])
        pr_auc = auc(recall, precision)

    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'f1_macro': f1_score(y_true, y_pred, average="macro"),
        'nll': nll_sum / n_samples,
        'ece': expected_calibration_error(y_true, y_score, n_bins=15),
        'roc_auc': roc,
        'pr_auc': pr_auc,
    }


def _evaluate_energy_ood(model, id_loader, ood_loader, device, temperature, tpr_target):
    """Score ID vs OOD data by energy (OOD = positive class) and return AUROC, AUPR, FPR@TPR and threshold."""
    id_energy = batched_energy_scores(model, id_loader, device, temperature)
    ood_energy = batched_energy_scores(model, ood_loader, device, temperature)
    labels = np.concatenate([np.zeros_like(id_energy), np.ones_like(ood_energy)])
    scores = np.concatenate([id_energy, ood_energy])
    fpr95, thr = fpr_at_tpr(id_energy, ood_energy, tpr=tpr_target)
    return {
        'ood_auroc': roc_auc_score(labels, scores),
        'ood_aupr': average_precision_score(labels, scores),
        'fpr95': fpr95,
        'threshold': thr,
    }


def main_energy_mlp(train_loader, test_loader, input_dim, num_classes,
                    hidden_sizes=(512,), dropout=0.0, lr=1e-3,
                    weight_decay=1e-4, epochs=50, max_grad_norm: float = 1,
                    ood_loader=None, tpr_target=0.95, energy_temperature: float = 1.0):
    """Train an ``EnergyMLP`` with AdamW and report test-set (and optional OOD) metrics.

    Training minimises ``energy_nll_loss`` with gradient-norm clipping. After
    the last epoch, the model is evaluated once on ``test_loader``. When
    ``ood_loader`` is given, energy scores of ``test_loader`` (in-distribution)
    vs ``ood_loader`` are also compared.

    Args:
        train_loader, test_loader: iterables yielding ``(X, y)`` batches.
        input_dim, num_classes, hidden_sizes, dropout: forwarded to ``EnergyMLP``.
        lr, weight_decay: AdamW hyper-parameters.
        epochs: number of passes over ``train_loader``.
        max_grad_norm: clipping threshold for the global gradient norm;
            ``None`` disables clipping.
        ood_loader: optional iterable of OOD batches (first element = inputs).
        tpr_target: ID acceptance rate at which the OOD FPR is reported.
        energy_temperature: temperature for the OOD energy score.

    Returns:
        dict with ``'accuracy'``, ``'f1_macro'``, ``'nll'``, ``'ece'``,
        ``'roc_auc'``, ``'pr_auc'`` and ``'params'``. When ``ood_loader`` is
        given, it also has ``'ood_auroc'``, ``'ood_aupr'``, ``'fpr95'`` and
        ``'threshold'``. All values are None when ``epochs <= 0``. ``roc_auc``
        is None when it is undefined for the test labels. ``pr_auc`` is only
        set for binary problems with both classes present.

    Raises:
        ValueError: if evaluation runs and ``test_loader`` yields no samples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = EnergyMLP(input_dim=input_dim, num_classes=num_classes,
                      hidden_sizes=hidden_sizes, dropout=dropout).to(device)
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    for epoch in range(epochs):
        epoch_loss = _train_energy_epoch(model, train_loader, optimizer, device, max_grad_norm)
        # Progress every 5 epochs; the final epoch is followed by evaluation instead.
        if (epoch + 1) % 5 == 0 and (epoch + 1) != epochs:
            print(f"Epoch {epoch+1}: loss={epoch_loss:.4f}")

    # Key order matches the reported dict: classification metrics, then OOD metrics.
    metrics = dict.fromkeys(['accuracy', 'f1_macro', 'nll', 'ece', 'roc_auc', 'pr_auc', 'params'])
    ood_metrics = dict.fromkeys(['ood_auroc', 'ood_aupr', 'fpr95', 'threshold'])

    if epochs > 0:
        metrics.update(_evaluate_energy_classifier(model, test_loader, device, num_classes))
        metrics['params'] = sum(param.numel() for param in model.parameters())
        if ood_loader is not None:
            ood_metrics.update(_evaluate_energy_ood(
                model, test_loader, ood_loader, device, energy_temperature, tpr_target))

    if ood_loader is not None:
        metrics.update(ood_metrics)

    return metrics