import torch
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score, confusion_matrix, ConfusionMatrixDisplay
from tqdm import tqdm
import matplotlib.pyplot as plt
import traceback


def _positive_class_probs(outputs):
    """Convert raw model outputs into one positive-class score per sample.

    * Outputs of shape ``(B, 2)``: a sigmoid is applied to each logit
      independently (not a softmax, to match the Raindrop reference
      evaluation) and column 1 is returned.
    * Any other shape: flattened and treated as one logit per element.
    """
    if outputs.dim() == 2 and outputs.size(1) == 2:
        return torch.sigmoid(outputs)[:, 1]
    return torch.sigmoid(outputs.view(-1))


def _ranking_metrics(y_true, y_score):
    """Return ``(roc_auc, average_precision)`` for binary labels.

    ROC-AUC needs both classes present (``roc_auc_score`` raises
    ``ValueError`` otherwise), and average precision is ill-defined without
    positive samples. When ``y_true`` holds a single class, both are reported
    as ``0.0``, the same sentinel used for an empty epoch.
    """
    if np.unique(y_true).size < 2:
        return 0.0, 0.0
    return roc_auc_score(y_true, y_score), average_precision_score(y_true, y_score)


def _log_confusion_matrix(writer, logging_dir, epoch, y_true, y_score):
    """Add a row-normalised (percent of each true class) confusion-matrix figure to ``writer``."""
    y_pred = (y_score > 0.5).astype(int)
    # labels=[0, 1] keeps the matrix 2x2 even if a class is missing, so it
    # always matches the two display labels below.
    cm = confusion_matrix(y_true.astype(int), y_pred, labels=[0, 1])
    row_totals = cm.sum(axis=1, keepdims=True)
    # A class absent from y_true has a zero row total; show 0% instead of NaN.
    cm_percentage = np.divide(
        cm * 100.0, row_totals, out=np.zeros(cm.shape, dtype=float), where=row_totals > 0
    )

    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        disp = ConfusionMatrixDisplay(confusion_matrix=cm_percentage, display_labels=[0, 1])
        disp.plot(ax=ax, cmap="Blues", values_format=".2f")
        ax.set_title(f"{logging_dir} Confusion Matrix - Epoch {epoch}")
        fig.tight_layout()
        writer.add_figure(f"{logging_dir} Confusion Matrix/Epoch {epoch}", fig, global_step=epoch)
    finally:
        plt.close(fig)


def run_epoch(
    epoch,
    model,
    dloader,
    loss_fn,
    optimizer,
    scheduler,
    device,
    running_loss,
    n_samples,
    all_probas,
    all_y_batch,
    running_acc,
    writer,
    is_train,
    cleanup=True,
    train_after_oom=False,
    log_interval=50,
    scaler: torch.cuda.amp.GradScaler | None = None,
    enable_plots: bool = False,
):
    """Run one pass over ``dloader`` and return epoch-level metrics.

    For each batch:
    - The model runs under CUDA autocast. The dtype is bf16 when supported,
      else fp16, and autocast is disabled when CUDA is unavailable.
    - The loss is computed on the flattened outputs.
    - When ``is_train`` is true, a gradient-clipped (max-norm 1.0) optimizer
      step is taken, through ``scaler`` if one is given.

    Any exception raised while processing a batch is printed and the batch is
    skipped as a whole. None of its loss, samples, predictions or labels enter
    the accumulators.

    Args:
        epoch: Epoch index; used as the writer step for epoch-level scalars.
        model: Called as ``model(data_batch, batch_size)``. If it returns a
            tuple, the first element is used as the logits.
        dloader: Iterable of ``(data_batch, y_batch)`` pairs. ``data_batch`` is
            passed to the model unchanged (it is not moved to ``device``).
        loss_fn: Called as ``loss_fn(outputs.view(-1), y_batch)``.
        optimizer: Used only when ``is_train``.
        scheduler: Accepted for interface compatibility; not stepped here.
        device: Device that ``y_batch`` is moved to.
        running_loss, n_samples, running_acc: Starting values of the summed
            batch loss, the sample count and the correct-prediction count.
        all_probas, all_y_batch: Lists that receive one NumPy array of
            positive-class scores / labels per successful batch (mutated in place).
        writer: Object exposing ``add_scalar``/``add_figure`` (e.g. a
            TensorBoard ``SummaryWriter``), or ``None`` to disable logging.
        is_train: Whether to back-propagate and step the optimizer.
        cleanup, train_after_oom: Accepted for interface compatibility; unused.
        log_interval: Every ``log_interval`` batches, the running accuracy and
            (when defined) running AUC are logged with global step
            ``epoch * len(dloader) + batch_index + 1``.
        scaler: Optional ``GradScaler`` used for backward/step when training.
        enable_plots: Log a confusion-matrix figure for non-training epochs.

    Returns:
        ``(avg_loss, avg_acc, auc, auprc)``, where:
        - ``avg_loss`` is the mean per-batch loss over successfully processed
          batches.
        - ``avg_acc`` thresholds the positive-class score at 0.5.
        - ``auc`` and ``auprc`` are 0.0 when undefined (no samples, or a
          single class).
    """
    logging_dir = "Train" if is_train else "Test"
    print()
    print(f"{logging_dir} run")

    use_autocast = torch.cuda.is_available()
    autocast_dtype = torch.bfloat16 if use_autocast and torch.cuda.is_bf16_supported() else torch.float16
    n_batches = 0  # successfully processed batches; denominator of avg_loss

    with torch.autograd.set_detect_anomaly(False):
        for i, (data_batch, y_batch) in enumerate(tqdm(dloader)):
            try:
                y_batch = y_batch.to(device, non_blocking=True)

                # Forward + loss under autocast for speed
                with torch.autocast(device_type="cuda", dtype=autocast_dtype, enabled=use_autocast):
                    outputs = model(data_batch, y_batch.shape[0])
                    if isinstance(outputs, tuple):
                        outputs = outputs[0]
                    loss = loss_fn(outputs.view(-1), y_batch)

                if is_train:
                    optimizer.zero_grad(set_to_none=True)
                    if scaler is not None:
                        scaler.scale(loss).backward()
                        # Gradients are still multiplied by the loss scale here;
                        # unscale them so the clip threshold applies to the true
                        # gradient norm.
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                        optimizer.step()

                # Compute every per-batch quantity before touching the
                # accumulators so a failure cannot leave them out of sync.
                pos_probs = _positive_class_probs(outputs).detach()
                batch_correct = ((pos_probs > 0.5) == y_batch).sum().item()
                # .float(): NumPy has no bfloat16, so a bf16 tensor (the usual
                # autocast output on bf16-capable GPUs) cannot be converted directly.
                batch_probas = pos_probs.float().cpu().numpy()
                batch_labels = y_batch.detach().cpu().numpy()
                batch_loss = loss.item()
            except Exception as e:
                print(f"Skipping batch {i} due to error: {e}")
                traceback.print_exc()
                continue

            running_loss += batch_loss
            n_samples += len(y_batch)
            running_acc += batch_correct
            all_probas.append(batch_probas)
            all_y_batch.append(batch_labels)
            n_batches += 1

            if writer is not None and n_samples > 0 and (i + 1) % log_interval == 0:
                step = epoch * len(dloader) + i + 1
                writer.add_scalar(f"{logging_dir} Accuracy/Step", running_acc / n_samples, step)
                seen_labels = np.concatenate(all_y_batch)
                # Running AUC is undefined until both classes have been seen.
                if np.unique(seen_labels).size > 1:
                    seen_probas = np.concatenate(all_probas)
                    writer.add_scalar(
                        f"{logging_dir} AUC/Step", roc_auc_score(seen_labels, seen_probas), step
                    )

    # Aggregate epoch-level metrics
    probas = np.concatenate(all_probas) if all_probas else np.array([])
    labels = np.concatenate(all_y_batch) if all_y_batch else np.array([])

    if labels.size > 0:
        avg_acc = running_acc / n_samples
        auc, auprc = _ranking_metrics(labels, probas)
    else:
        auc = auprc = avg_acc = 0.0
    avg_loss = running_loss / max(1, n_batches)

    # Scalars are always logged when a writer exists
    if writer is not None:
        writer.add_scalar(f"{logging_dir} Loss/Epoch", avg_loss, epoch)
        writer.add_scalar(f"{logging_dir} Accuracy/Epoch", avg_acc, epoch)
        writer.add_scalar(f"{logging_dir} AUC/Epoch", auc, epoch)
        writer.add_scalar(f"{logging_dir} AUPRC/Epoch", auprc, epoch)

        # The confusion matrix is comparatively expensive; only when explicitly enabled
        if enable_plots and not is_train and labels.size > 0:
            _log_confusion_matrix(writer, logging_dir, epoch, labels, probas)

    return avg_loss, avg_acc, auc, auprc


def train_epoch(epoch, model, dloader, loss_fn, optimizer, scheduler, device, train_after_oom=False, writer=None, scaler: torch.cuda.amp.GradScaler | None = None, enable_plots: bool = False):
    """Train ``model`` for one epoch over ``dloader`` and return its metrics.

    The model is switched to training mode. The batch loop is then delegated
    to ``run_epoch`` with fresh (zero/empty) accumulators and
    ``is_train=True``.

    Note:
        ``train_after_oom`` is accepted for interface compatibility but is
        not forwarded to ``run_epoch``.

    Returns:
        ``(avg_loss, avg_acc, auc, auprc)`` exactly as produced by ``run_epoch``.
    """
    model.train()

    epoch_metrics = run_epoch(
        epoch,
        model,
        dloader,
        loss_fn,
        optimizer,
        scheduler,
        device,
        0.0,  # running_loss
        0,  # n_samples
        [],  # all_probas
        [],  # all_y_batch
        0.0,  # running_acc
        writer,
        is_train=True,
        scaler=scaler,
        enable_plots=enable_plots,
    )
    loss_avg, acc_avg, roc_auc, pr_auc = epoch_metrics
    return loss_avg, acc_avg, roc_auc, pr_auc


def test_epoch(epoch, model, dloader, loss_fn, device, train_after_oom=False, writer=None, enable_plots: bool = False):
    """Evaluate ``model`` on ``dloader`` for one epoch with gradients disabled.

    The model is switched to eval mode. Inside ``torch.no_grad()``, the batch
    loop is then delegated to ``run_epoch`` with fresh accumulators, no
    optimizer or scheduler, and ``is_train=False``.

    Note:
        ``train_after_oom`` is accepted for interface compatibility but is
        not used.

    Returns:
        ``(avg_loss, avg_acc, auc, auprc)`` exactly as produced by ``run_epoch``.
    """
    model.eval()

    with torch.no_grad():
        loss_avg, acc_avg, roc_auc, pr_auc = run_epoch(
            epoch,
            model,
            dloader,
            loss_fn,
            None,  # optimizer: unused when not training
            None,  # scheduler
            device,
            0.0,  # running_loss
            0,  # n_samples
            [],  # all_probas
            [],  # all_y_batch
            0.0,  # running_acc
            writer,
            is_train=False,
            enable_plots=enable_plots,
        )

    return loss_avg, acc_avg, roc_auc, pr_auc