#!/usr/bin/env python
# training.py - Training functions
# --------------------------------------------------------------------
import time
import torch
from torch.optim.lr_scheduler import MultiStepLR

from models import ResMLP, class_from_z
from metrics import compute_onc_metrics_from_net, compute_additional_metrics

def _effective_thresholds(loss_layer):
    """Return the loss layer's class cut-points as a detached CPU tensor.

    If the layer defines ``_b()``, its result is used as-is; otherwise the
    ``b`` attribute is read and its first and last entries are dropped.
    NOTE(review): presumably those outer entries of ``b`` are fixed
    boundary sentinels that ``class_from_z`` does not need -- confirm.
    """
    if hasattr(loss_layer, '_b'):
        return loss_layer._b().detach().cpu()
    return loss_layer.b.detach().cpu()[1:-1]


def _evaluate(net, loss_layer, loader, device):
    """Run one gradient-free pass over ``loader``.

    Each batch is forwarded through ``net`` exactly once; that single output
    feeds both ``loss_layer`` (for the NLL) and ``class_from_z`` (for the
    predicted class).

    Returns:
        ``(mean_nll, mean_mae, preds, labels)``: the sample-weighted mean NLL,
        the mean absolute difference between predicted and true class, and
        NumPy arrays of predicted and true labels aligned element-wise.
    """
    # Parameters do not change during evaluation, so compute the cut-points once.
    b_eff = _effective_thresholds(loss_layer)
    nll_sum, mae_sum, count = 0.0, 0.0, 0
    preds, labels = [], []
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            z = net(xb)
            nll, *_ = loss_layer(z, yb)
            pred = class_from_z(z.cpu(), b_eff)
            yb_cpu = yb.cpu()
            # The batch loss is re-weighted by batch size, i.e. treated as a
            # per-sample mean, so the final value is an average over samples.
            nll_sum += nll.item() * len(yb)
            mae_sum += torch.abs(pred - yb_cpu).sum().item()
            count += len(yb)
            preds.append(pred)
            labels.append(yb_cpu)
    return (nll_sum / count, mae_sum / count,
            torch.cat(preds).numpy(), torch.cat(labels).numpy())


def train_clm_explicit(Xtr, ytr, Xte, yte, loss_layer, tag, device, lr=1e-2,
                       epochs=5000, batch_size=2048, log_every=200):
    """Train a ``ResMLP`` scorer against ``loss_layer`` and record per-epoch metrics.

    The network is ``ResMLP(Xtr.shape[1], width=128, depth=4)``. It is
    optimized with Adam (weight decay 5e-3 on the network; 0 on
    ``loss_layer.delta`` if present). A ``MultiStepLR`` schedule multiplies the
    learning rate by 0.1 at epochs 200, 800 and 3000; it is stepped once per
    epoch.

    Args:
        Xtr, ytr, Xte, yte: Train/test features and labels. Each must be
            convertible with ``torch.tensor``. ``Xtr`` must be 2-D, because
            ``Xtr.shape[1]`` sets the input width.
        loss_layer: Module called as ``loss_layer(z, y)``. Its first output is
            the loss to minimise. It must expose thresholds via ``_b()`` or a
            ``b`` attribute.
        tag: Label prefixed to progress lines.
        device: Device for training.
        lr: Initial Adam learning rate.
        epochs: Number of training epochs (must be >= 1).
        batch_size: Mini-batch size for all loaders.
        log_every: Print a progress line every this many epochs (and at the end).

    Returns:
        A tuple, in this order:
        ``(net_cpu, loss_layer_cpu, final_val_nll, Xtr, ytr, tr_nll, te_nll,
        onc1_tr, onc21_tr, onc22_tr, onc3_tr, onc1_te, onc21_te, onc22_te,
        onc3_te, tr_mae, te_mae, tr_acc, te_acc, tr_within1, te_within1,
        tr_min_sens, te_min_sens, tr_qwk, te_qwk)``.
        Every history list has one entry per epoch.

    Raises:
        ValueError: If ``epochs < 1``.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")
    start_time = time.time()

    # Built once and reused by the datasets and the per-epoch ONC metrics.
    ytr_t, yte_t = torch.tensor(ytr), torch.tensor(yte)
    train_ds = torch.utils.data.TensorDataset(torch.tensor(Xtr), ytr_t)
    test_ds = torch.utils.data.TensorDataset(torch.tensor(Xte), yte_t)
    trL = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    # Evaluation is order-independent. The train-set metrics pass therefore uses
    # an unshuffled loader and does not consume the shuffle RNG.
    trL_eval = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=False)
    teL = torch.utils.data.DataLoader(test_ds, batch_size=batch_size, shuffle=False)

    net = ResMLP(Xtr.shape[1], width=128, depth=4).to(device)
    loss_layer = loss_layer.to(device)

    param_groups = [{'params': net.parameters(), 'weight_decay': 5e-3}]
    if hasattr(loss_layer, 'delta'):
        # ``delta`` is trained but excluded from weight decay.
        param_groups.append({'params': [loss_layer.delta], 'weight_decay': 0.0})
    opt = torch.optim.Adam(param_groups, lr=lr)
    scheduler = MultiStepLR(opt, milestones=[200, 800, 3000], gamma=0.1)

    hist = {name: [] for name in (
        'tr_nll', 'te_nll', 'tr_mae', 'te_mae',
        'onc1_tr', 'onc21_tr', 'onc22_tr', 'onc3_tr',
        'onc1_te', 'onc21_te', 'onc22_te', 'onc3_te',
        'tr_acc', 'te_acc', 'tr_within1', 'te_within1',
        'tr_min_sens', 'te_min_sens', 'tr_qwk', 'te_qwk',
    )}

    for epoch in range(1, epochs + 1):
        net.train()
        for xb, yb in trL:
            xb, yb = xb.to(device), yb.to(device)
            nll, *_ = loss_layer(net(xb), yb)
            opt.zero_grad()
            nll.backward()
            opt.step()
        scheduler.step()

        net.eval()
        # One forward pass per split yields NLL, MAE and the label arrays
        # consumed by the additional metrics.
        for split, loader in (('tr', trL_eval), ('te', teL)):
            nll_mean, mae_mean, preds, labels = _evaluate(net, loss_layer, loader, device)
            extra = compute_additional_metrics(labels, preds)
            hist[f'{split}_nll'].append(nll_mean)
            hist[f'{split}_mae'].append(mae_mean)
            hist[f'{split}_acc'].append(extra['accuracy'])
            hist[f'{split}_within1'].append(extra['within_1_acc'])
            hist[f'{split}_min_sens'].append(extra['min_sensitivity'])
            hist[f'{split}_qwk'].append(extra['qwk'])

        for split, X, y_t in (('tr', Xtr, ytr_t), ('te', Xte, yte_t)):
            onc = compute_onc_metrics_from_net(net, loss_layer, X, y_t)
            hist[f'onc1_{split}'].append(onc['ONC1'])
            hist[f'onc21_{split}'].append(onc['ONC2-1'])
            hist[f'onc22_{split}'].append(onc['ONC2-2'])
            hist[f'onc3_{split}'].append(onc['ONC3'])

        if epoch % log_every == 0 or epoch == epochs:
            elapsed = time.time() - start_time
            current_lr = scheduler.get_last_lr()[0]
            print(f"[{tag}] Epoch {epoch}/{epochs} | "
                  f"lr={current_lr:.1e} | "
                  f"train NLL={hist['tr_nll'][-1]:.4f}, MAE={hist['tr_mae'][-1]:.4f}, ACC={hist['tr_acc'][-1]:.4f} | "
                  f"val NLL={hist['te_nll'][-1]:.4f}, MAE={hist['te_mae'][-1]:.4f}, ACC={hist['te_acc'][-1]:.4f} | "
                  f"time={elapsed:.1f}s")

    return (
        net.cpu(), loss_layer.cpu(),
        hist['te_nll'][-1],
        Xtr, ytr,
        hist['tr_nll'], hist['te_nll'],
        hist['onc1_tr'], hist['onc21_tr'], hist['onc22_tr'], hist['onc3_tr'],
        hist['onc1_te'], hist['onc21_te'], hist['onc22_te'], hist['onc3_te'],
        hist['tr_mae'], hist['te_mae'],
        hist['tr_acc'], hist['te_acc'],
        hist['tr_within1'], hist['te_within1'],
        hist['tr_min_sens'], hist['te_min_sens'],
        hist['tr_qwk'], hist['te_qwk'],
    )