import os
import math
import time
import random
import argparse
import numpy as np
import csv
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import vmap
from torch.utils.data import DataLoader
from torch import autocast 
import sklearn.metrics
from datasets import load_dataset
from transformers import (
    BertTokenizerFast,
    BertForSequenceClassification,
    AdamW,
    get_linear_schedule_with_warmup,
    DataCollatorWithPadding,
)

from tqdm import tqdm
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from options import options
from utils import (
    adjust_lambda_reg_sin,
    pca,
)
from poly.wd_regularization_torch import (
    polynomial_regularization,
    precompute_chebyshev_matrix,
)


# =============================
# 0) Utils
# =============================
def set_seed(seed: int = 42):
    """Seed every random number generator used by this script.

    The same ``seed`` is applied, in order, to Python's ``random`` module,
    NumPy's global generator, PyTorch's default generator and all CUDA
    generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def accuracy_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> float:
    """Return the fraction of entries whose arg-max over the last dim equals the label."""
    predicted = logits.argmax(dim=-1)
    hits = predicted.eq(labels).float()
    return hits.mean().item()


def pearson_corr_from_logits(logits: torch.Tensor, labels: torch.Tensor, eps: float = 1e-8) -> float:
    """Pearson correlation between flattened predictions and labels.

    Args:
        logits: Predicted values; flattened to 1-D before use.
        labels: Target values; flattened to 1-D before use.
        eps: Added to the denominator so constant inputs do not divide by zero.

    Returns:
        The correlation as a Python float, or ``0.0`` when the input is empty
        (same convention as ``spearman_corr_from_logits``).
    """
    # reshape (not view) so non-contiguous tensors, e.g. transposed ones, are accepted.
    x = logits.detach().reshape(-1).float().cpu().numpy()
    y = labels.detach().reshape(-1).float().cpu().numpy()
    if x.size == 0:
        # The mean of an empty array is NaN; report "no correlation" instead.
        return 0.0

    x = x - x.mean()
    y = y - y.mean()
    denom = (np.sqrt((x * x).mean()) * np.sqrt((y * y).mean())) + eps
    return float((x * y).mean() / denom)


def _rankdata_average_ties(x: np.ndarray) -> np.ndarray:
    """Rank values from 1..N, giving tied values the mean of their ranks.

    Intended to behave like ``scipy.stats.rankdata(method='average')``.
    Returns a float64 array aligned with the input order; an empty input
    yields an empty float64 array.
    """
    values = np.asarray(x)
    count = values.shape[0]
    if count == 0:
        return values.astype(np.float64)

    # Stable sort keeps equal values in a deterministic order.
    order = np.argsort(values, kind="mergesort")
    ordered = values[order]

    # A tie group starts at position 0 and wherever the sorted value changes.
    starts_group = np.concatenate(([True], ordered[1:] != ordered[:-1]))
    starts = np.flatnonzero(starts_group)
    stops = np.append(starts[1:], count)

    # Mean of the 0-based positions [start, stop) shifted to 1-based ranks.
    group_rank = 0.5 * (starts + (stops - 1)) + 1.0
    ranks_in_sorted_order = np.repeat(group_rank, stops - starts)

    ranks = np.empty(count, dtype=np.float64)
    ranks[order] = ranks_in_sorted_order
    return ranks


def spearman_corr_from_logits(logits: torch.Tensor, labels: torch.Tensor, eps: float = 1e-8) -> float:
    """Spearman rank correlation between flattened predictions and labels.

    Both inputs are flattened, converted to tie-aware average ranks, and the
    Pearson formula is applied to the centred ranks. ``eps`` is added to the
    denominator. Returns ``0.0`` for empty input.
    """
    flat_pred = logits.view(-1).detach().float().cpu().numpy()
    flat_true = labels.view(-1).detach().float().cpu().numpy()
    if flat_pred.size == 0:
        return 0.0

    centred = []
    for values in (flat_pred, flat_true):
        ranks = _rankdata_average_ties(values)
        centred.append(ranks - ranks.mean())
    rank_pred, rank_true = centred

    covariance = (rank_pred * rank_true).mean()
    spread = np.sqrt((rank_pred * rank_pred).mean()) * np.sqrt((rank_true * rank_true).mean())
    return float(covariance / (spread + eps))


def matthews_corr_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> float:
    """Matthews correlation coefficient of the arg-max predictions against ``labels``.

    Both tensors are detached and moved to CPU NumPy arrays before being
    handed to ``sklearn.metrics.matthews_corrcoef``.
    """
    predicted_classes = logits.argmax(dim=-1).detach().cpu().numpy()
    true_classes = labels.detach().cpu().numpy()
    return sklearn.metrics.matthews_corrcoef(true_classes, predicted_classes)


def metric_from_logits(task: str, logits: torch.Tensor, labels: torch.Tensor) -> float:
    """Compute the evaluation metric used for ``task`` (case-insensitive).

    CoLA uses the Matthews correlation, STS-B uses the Spearman correlation,
    and every other task falls back to plain accuracy.
    """
    special_metrics = {
        "cola": matthews_corr_from_logits,
        "stsb": spearman_corr_from_logits,
    }
    scorer = special_metrics.get(task.lower(), accuracy_from_logits)
    return scorer(logits, labels)


def safe_filename(name: str) -> str:
    """Return ``str(name)`` with every character that is not alphanumeric, '-', '_' or '.' replaced by '_'."""
    allowed_punctuation = "-_."
    cleaned = []
    for char in str(name):
        keep = char.isalnum() or char in allowed_punctuation
        cleaned.append(char if keep else "_")
    return "".join(cleaned)


def parse_args():
    """Extend the shared ``options()`` parser with the GLUE-specific flags and parse the command line."""
    parser = options()

    glue_tasks = ["cola", "sst2", "mrpc", "stsb", "qqp", "mnli", "qnli", "rte", "wnli"]
    parser.add_argument(
        "--task",
        type=str,
        default="rte",
        choices=glue_tasks,
        help="Which GLUE sub-task to run",
    )

    # Flags that only need a type and a default, registered in declaration order.
    plain_flags = (
        ("--model_name", str, "bert-base-cased"),
        ("--max_len", int, 128),
        ("--eval_batch_size", int, 32),
        ("--warmup_ratio", float, 0.1),
        ("--log_every", int, 50),
        ("--grad_clip", float, 1.0),
        ("--output_dir", str, "./bert_mnli_wd"),
        ("--num_workers", int, 2),
    )
    for flag, value_type, default_value in plain_flags:
        parser.add_argument(flag, type=value_type, default=default_value)

    parser.add_argument(
        "--mixup_mode",
        type=str,
        default=None,
        help="Mixup mode: 'embeddings' or 'pooled'",
    )
    return parser.parse_args()


def forward_from_embeddings(model: BertForSequenceClassification, embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Run encoder, pooler, dropout and classifier on pre-computed embeddings.

    ``attention_mask`` is turned into an additive mask in the dtype of
    ``embeddings``: kept positions (1) become 0 and masked positions (0)
    become -10000. Returns the classifier logits.
    """
    # Two singleton dims are inserted after the first dim so the mask can
    # broadcast inside the encoder (assumes a [batch, seq] mask - TODO confirm).
    keep = attention_mask.unsqueeze(1).unsqueeze(2).to(dtype=embeddings.dtype)
    additive_mask = (1.0 - keep) * -10000.0

    layer_count = model.config.num_hidden_layers
    hidden_states = model.bert.encoder(
        embeddings,
        attention_mask=additive_mask,
        head_mask=[None] * layer_count,
    )[0]

    pooled = model.dropout(model.bert.pooler(hidden_states))
    return model.classifier(pooled)


def sample_alpha_sequences(args, resolution: int, device: torch.device, num_pairs: int):
    """Build the alpha grids used for interpolation and for the regularizer fit.

    Returns ``(full_alpha, inner_alpha)``:
      * ``full_alpha`` has shape [num_pairs, resolution] with values in [-1, 1].
      * ``inner_alpha`` holds the interior points mapped to [0, 1] via
        ``(alpha + 1) / 2`` with shape [num_pairs, resolution - 2, 1, 1], or
        ``None`` when ``resolution == 2``.

    With ``args.random_alpha`` each pair gets its own jittered grid; otherwise
    the grid from ``precompute_chebyshev_matrix`` is shared by all pairs.

    Raises:
        ValueError: if ``resolution`` is smaller than 2.
    """
    if resolution < 2:
        raise ValueError("resolution must be >= 2 for WD regularization")

    inner_count = resolution - 2
    has_inner = inner_count > 0

    if not args.random_alpha:
        # Deterministic grid shared (as an expanded view) across all pairs.
        cached = precompute_chebyshev_matrix(resolution, args.max_degree, device)
        grid = cached["alpha_values"].to(torch.float32)
        full_alpha_tensor = grid.unsqueeze(0).expand(num_pairs, -1)
        inner_alpha_tensor = None
        if has_inner:
            inner_mix = (grid[1:-1] + 1) * 0.5
            inner_alpha_tensor = inner_mix.view(1, inner_count, 1, 1).expand(num_pairs, -1, -1, -1)
        return full_alpha_tensor, inner_alpha_tensor

    def draw_inner_nodes():
        # One jittered point per stratum, squeezed away from the endpoints and
        # mapped through -cos so the nodes lie strictly inside (-1, 1).
        steps = torch.arange(inner_count, device=device, dtype=torch.float32)
        jitter = torch.rand(inner_count, device=device, dtype=torch.float32)
        raw_fraction = (steps + jitter) / inner_count
        padding = 1.0 / resolution
        compressed_fraction = padding + raw_fraction * (1 - 2 * padding)
        theta = compressed_fraction * math.pi
        nodes, _ = torch.sort(-torch.cos(theta))
        return nodes

    per_pair_full = []
    per_pair_inner = []
    for _ in range(num_pairs):
        inner_nodes = draw_inner_nodes() if has_inner else torch.tensor([], device=device)

        lower_end = torch.tensor([-1.0], device=device)
        upper_end = torch.tensor([1.0], device=device)
        per_pair_full.append(torch.cat([lower_end, inner_nodes, upper_end]))

        if has_inner:
            per_pair_inner.append(((inner_nodes + 1) * 0.5).view(-1, 1, 1))

    full_alpha_tensor = torch.stack(per_pair_full)
    inner_alpha_tensor = torch.stack(per_pair_inner) if has_inner else None
    return full_alpha_tensor, inner_alpha_tensor

# ...existing code...

def compute_wd_regularization(
    model,
    batch,
    args,
    num_labels: int,
    device: torch.device,
    lambda_reg: float,
    mixup_mode: str = "embeddings",
) -> torch.Tensor:
    """Compute the WD (polynomial) regularization term for one tokenized batch.

    Random pairs of distinct samples are drawn from the batch. For each pair the
    model is evaluated at interpolation points between sample 1 and sample 2, and
    the sequence ``[target of sample 1, inner model outputs..., target of sample 2]``
    is scored by ``polynomial_regularization``.

    mixup_mode:
      - "embeddings": interpolate the token embeddings before the encoder, then
        run encoder + pooler + classifier (``forward_from_embeddings``).
      - "pooled": run the encoder once, interpolate the pooled representations
        and apply only the classifier.

    Args:
        model: Sequence-classification BERT exposing ``bert``, ``dropout`` and
            ``classifier``.
        batch: Mapping with "input_ids", "attention_mask", "labels" and
            optionally "token_type_ids".
        args: Parsed options. Reads ``resolution``, ``nums_pairs``,
            ``mixup_mode``, ``random_alpha``, ``pca_reg`` (optional), ``miu``,
            ``max_degree``, ``remove_const``, ``use_norm``, ``square`` and
            ``degree_mode``.
        num_labels: Classifier output size; ``1`` selects the regression path.
        device: Device the batch is moved to.
        lambda_reg: Not used inside this function; kept so existing callers
            keep working.
        mixup_mode: Fallback mode used only when ``args.mixup_mode`` is None.

    Returns:
        Scalar float32 tensor with the mean regularization over the pairs, or a
        zero scalar when ``resolution < 2`` or the batch has fewer than 2 samples.

    Raises:
        ValueError: if the resolved mixup mode is not recognised.
    """
    is_regression = (num_labels == 1)

    resolution = int(args.resolution)
    if resolution < 2:
        return torch.tensor(0.0, device=device)

    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    labels = batch["labels"].to(device)

    token_type_ids = batch.get("token_type_ids", None)
    if token_type_ids is not None:
        token_type_ids = token_type_ids.to(device)

    batch_size = input_ids.size(0)
    if batch_size < 2:
        return torch.tensor(0.0, device=device)

    # Never ask for more pairs than there are unordered pairs in the batch.
    num_pairs = min(args.nums_pairs, batch_size * (batch_size - 1) // 2)
    if num_pairs == 0:
        return torch.tensor(0.0, device=device)

    # Randomly select sample pairs; a non-zero offset guarantees x2 != x1.
    x1_indices = torch.randint(low=0, high=batch_size, size=(num_pairs,), device=device)
    offset = torch.randint(low=1, high=batch_size, size=(num_pairs,), device=device)
    x2_indices = (x1_indices + offset) % batch_size

    full_alpha, inner_alpha = sample_alpha_sequences(args, resolution, device, num_pairs)
    n_inner = resolution - 2

    # The command-line setting wins over the keyword default.
    if args.mixup_mode is not None:
        mixup_mode = args.mixup_mode
    mixup_mode = (mixup_mode or "pooled").lower()
    if mixup_mode in ("embedding", "emb", "embeddings"):
        mixup_mode = "embeddings"
    elif mixup_mode in ("pooled", "pool", "encoder_after", "post_encoder"):
        mixup_mode = "pooled"
    else:
        raise ValueError(f"Unknown mixup_mode={mixup_mode}. Use 'embeddings' or 'pooled'.")

    if n_inner > 0 and inner_alpha is None:
        return torch.tensor(0.0, device=device)

    # ==========================================================
    # 1) Compute inner outputs (choose one of two mixup modes)
    #    - Classification: inner_probs = softmax(logits)
    #    - Regression: inner_probs = logits (no softmax), shape [..., 1]
    #    In both modes a mixing weight of 0 corresponds to sample 1 and a
    #    weight of 1 to sample 2, matching the endpoint order built in step 2.
    # ==========================================================
    if mixup_mode == "embeddings":
        # Before encoder: interpolate embeddings, then encoder + pooler + classifier.
        embeddings = model.bert.embeddings(input_ids=input_ids, token_type_ids=token_type_ids)

        emb1 = embeddings[x1_indices]  # [P, L, H]
        emb2 = embeddings[x2_indices]  # [P, L, H]
        mask1 = attention_mask[x1_indices]
        mask2 = attention_mask[x2_indices]

        if n_inner > 0:
            diff = emb2 - emb1  # [P, L, H]
            # inner_alpha: [P, n_inner, 1, 1] broadcast to [P, n_inner, L, H]
            inner_emb = emb1.unsqueeze(1) + inner_alpha.to(emb1.dtype) * diff.unsqueeze(1)
            inner_emb_flat = inner_emb.reshape(-1, emb1.size(1), emb1.size(2))  # [P*n_inner, L, H]

            # A mixed sequence attends to every position that is real in either sample.
            union_mask = mask1.to(torch.bool) | mask2.to(torch.bool)  # [P, L]
            inner_mask_flat = (
                union_mask.unsqueeze(1).expand(-1, n_inner, -1).reshape(-1, mask1.size(1))
            )  # [P*n_inner, L]

            inner_logits = forward_from_embeddings(model, inner_emb_flat, inner_mask_flat)
            if is_regression:
                inner_probs = inner_logits.reshape(num_pairs, n_inner, 1)
            else:
                inner_probs = F.softmax(inner_logits, dim=-1).reshape(num_pairs, n_inner, num_labels)
        else:
            inner_dim = 1 if is_regression else num_labels
            inner_probs = torch.zeros(num_pairs, 0, inner_dim, device=device, dtype=embeddings.dtype)

    else:
        # After encoder: interpolate the pooled representations and apply only the classifier.
        bert_out = model.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        )
        pooled_all = bert_out.pooler_output  # [B, H]

        pooled1 = pooled_all[x1_indices]  # [P, H]
        pooled2 = pooled_all[x2_indices]  # [P, H]

        if n_inner > 0:
            # inner_alpha: [P, n_inner, 1, 1] -> lam: [P, n_inner, 1]
            lam = inner_alpha.squeeze(-1).squeeze(-1).to(dtype=pooled1.dtype).unsqueeze(-1)

            # Dropout is applied to each endpoint representation before mixing.
            pooled1 = model.dropout(pooled1)
            pooled2 = model.dropout(pooled2)
            # lam = 0 -> sample 1, lam = 1 -> sample 2, the same direction as the
            # "embeddings" branch and the [label1, ..., label2] endpoint order.
            # (Previously the weights were swapped, so points next to the first
            # endpoint were built from the second sample.)
            inner_pooled = (1.0 - lam) * pooled1.unsqueeze(1) + lam * pooled2.unsqueeze(1)  # [P, n_inner, H]
            inner_pooled_flat = inner_pooled.reshape(-1, pooled1.size(-1))  # [P*n_inner, H]

            inner_logits = model.classifier(inner_pooled_flat)  # [P*n_inner, C]
            if is_regression:
                inner_probs = inner_logits.reshape(num_pairs, n_inner, 1)
            else:
                inner_probs = F.softmax(inner_logits, dim=-1).reshape(num_pairs, n_inner, num_labels)
        else:
            inner_dim = 1 if is_regression else num_labels
            inner_probs = torch.zeros(num_pairs, 0, inner_dim, device=device, dtype=pooled_all.dtype)

    # ==========================================================
    # 2) Endpoints:
    #    - Classification: one-hot label
    #    - Regression: the true label itself (one value)
    # ==========================================================
    label1 = labels[x1_indices]
    label2 = labels[x2_indices]
    if is_regression:
        endpoint_1 = label1.to(dtype=inner_probs.dtype).unsqueeze(1).unsqueeze(-1)  # [P, 1, 1]
        endpoint_2 = label2.to(dtype=inner_probs.dtype).unsqueeze(1).unsqueeze(-1)  # [P, 1, 1]
    else:
        endpoint_1 = F.one_hot(label1, num_classes=num_labels).to(dtype=inner_probs.dtype).unsqueeze(1)
        endpoint_2 = F.one_hot(label2, num_classes=num_labels).to(dtype=inner_probs.dtype).unsqueeze(1)

    full_sequence = torch.cat([endpoint_1, inner_probs, endpoint_2], dim=1)

    # Optional PCA step, skipped for single-channel (regression) outputs.
    if int(getattr(args, "pca_reg", 0) or 0) > 0 and full_sequence.size(-1) > 1:
        full_sequence = pca(full_sequence, num_pairs, k=int(args.pca_reg))

    # ==========================================================
    # 3) Polynomial regularization (vmap over pairs)
    # ==========================================================
    # Random alphas differ per pair and are mapped over dim 0; the fixed grid is
    # identical for all pairs, so a single row is passed unmapped.
    alpha_inputs = full_alpha if args.random_alpha else full_alpha[0]
    alpha_in_dims = 0 if args.random_alpha else None

    def reg_wrapper(alpha, sample_output):
        return polynomial_regularization(
            alpha,
            sample_output,
            resolution,
            args.miu,
            args.max_degree,
            have_const=not args.remove_const,
            use_norm=args.use_norm,
            random_alpha=args.random_alpha,
            square=args.square,
            degree_mode=args.degree_mode,
        )

    batched_reg = vmap(reg_wrapper, in_dims=(alpha_in_dims, 0))
    reg_terms = batched_reg(alpha_inputs, full_sequence)

    # Returning float32 is more stable (especially when training with bf16 autocast)
    return torch.mean(reg_terms).to(torch.float32)

# ...existing code...


# =============================
# 1) Prepare the GLUE dataset (task selected via --task)
# =============================
def build_dataloaders(args):
    """Load a GLUE task, tokenize it and wrap the splits in DataLoaders.

    Reads ``args.task``, ``args.model_name``, ``args.max_len``, ``args.bs``,
    ``args.eval_batch_size`` and ``args.num_workers``.

    Returns:
        ``(tokenizer, train_loader, val_loader)``. For MNLI the validation
        loader uses the "validation_matched" split when it exists; every other
        case uses "validation".
    """
    task = args.task.lower()
    raw = load_dataset("glue", task)
    tokenizer = BertTokenizerFast.from_pretrained(args.model_name)

    # Text column(s) fed to the tokenizer for each supported task.
    text_columns_by_task = {
        "cola": ("sentence",),
        "sst2": ("sentence",),
        "mrpc": ("sentence1", "sentence2"),
        "stsb": ("sentence1", "sentence2"),
        "rte": ("sentence1", "sentence2"),
        "wnli": ("sentence1", "sentence2"),
        "qqp": ("question1", "question2"),
        "mnli": ("premise", "hypothesis"),
        "qnli": ("question", "sentence"),
    }

    def tokenize_fn(batch):
        columns = text_columns_by_task.get(task)
        if columns is None:
            raise ValueError(f"Unsupported task: {task}")
        # Some fields (e.g. in QQP) may be None, which the tokenizer cannot
        # handle; substitute an empty string.
        texts = [
            ["" if value is None else value for value in batch[column]]
            for column in columns
        ]
        return tokenizer(*texts, truncation=True, max_length=args.max_len)

    tokenized = raw.map(tokenize_fn, batched=True)

    # Drop the raw text / index columns; which ones exist depends on the task.
    droppable = (
        "sentence",
        "sentence1",
        "sentence2",
        "question1",
        "question2",
        "premise",
        "hypothesis",
        "question",
        "idx",
    )
    remove_cols = [col for col in droppable if col in tokenized["train"].column_names]
    if remove_cols:
        tokenized = tokenized.remove_columns(remove_cols)

    tokenized = tokenized.rename_column("label", "labels")

    # STS-B labels are regression scores, so make sure they are floats.
    if task == "stsb":
        def cast_label(batch):
            return {"labels": [float(score) for score in batch["labels"]]}

        tokenized = tokenized.map(cast_label, batched=True)

    tokenized.set_format("torch")

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    def make_loader(split_name, batch_size, shuffle):
        return DataLoader(
            tokenized[split_name],
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=args.num_workers,
            collate_fn=data_collator,
            pin_memory=True,
        )

    train_loader = make_loader("train", args.bs, True)

    # MNLI ships matched/mismatched validation sets; the matched one is used here.
    if task == "mnli" and "validation_matched" in tokenized:
        val_split = "validation_matched"
    else:
        val_split = "validation"
    val_loader = make_loader(val_split, args.eval_batch_size, False)

    return tokenizer, train_loader, val_loader


# =============================
# 2) Train / Eval loops
# =============================
def train_one_epoch(
    model,
    dataloader,
    val_dataloader,
    optimizer,
    scheduler,
    device,
    epoch_idx: int,
    args,
    num_labels: int,
    lambda_reg: float,
    best_state: dict = None,
    log_every: int = 20,
    grad_clip: float = 1.0,
):
    """Train for one pass over ``dataloader`` with periodic validation.

    Each step runs the forward pass under bf16 autocast, adds
    ``lambda_reg * reg_term`` to the model loss when ``lambda_reg > 0``,
    back-propagates, optionally clips the gradient norm and steps the
    optimizer and scheduler. Every ``log_every`` steps the full validation
    loader is evaluated; if ``best_state`` is given and the validation metric
    improves, the dict is updated in place.

    Returns:
        ``(epoch_loss, epoch_metric, epoch_reg)``: sample-weighted averages of
        the total loss, the per-batch task metric and the regularization term.
    """
    model.train()

    loss_sum = 0.0
    metric_sum = 0.0
    reg_sum = 0.0
    seen = 0

    progress = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Train Epoch {epoch_idx}")
    for step, batch in progress:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        token_type_ids = batch.get("token_type_ids", None)
        if token_type_ids is not None:
            token_type_ids = token_type_ids.to(device)

        optimizer.zero_grad(set_to_none=True)

        # bf16 autocast; unlike float16 this is used here without a GradScaler.
        with autocast(device_type="cuda", dtype=torch.bfloat16):
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                labels=labels,
            )
            ce_loss = outputs.loss
            metric = metric_from_logits(args.task, outputs.logits.float(), labels)

            if lambda_reg > 0:
                reg_term = compute_wd_regularization(
                    model,
                    batch,
                    args,
                    num_labels=num_labels,
                    device=device,
                    lambda_reg=lambda_reg,
                )
                loss = ce_loss + lambda_reg * reg_term
            else:
                reg_term = torch.tensor(0.0, device=device)
                loss = ce_loss

        # The backward pass stays outside the autocast context.
        loss.backward()

        if grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

        optimizer.step()
        scheduler.step()

        # Weight the running statistics by the number of samples in the batch.
        batch_len = labels.size(0)
        loss_sum += loss.item() * batch_len
        metric_sum += metric * batch_len
        reg_sum += reg_term.detach().item() * batch_len
        seen += batch_len

        if (step + 1) % log_every != 0:
            continue

        denom = max(seen, 1)
        train_loss_so_far = loss_sum / denom
        train_acc_so_far = metric_sum / denom
        train_reg_so_far = reg_sum / denom

        val_loss, val_acc = eval_one_epoch(
            model=model,
            dataloader=val_dataloader,
            device=device,
            epoch_idx=epoch_idx,
            task=args.task,
        )

        if best_state is not None and val_acc > best_state.get("best_val_acc", -1.0):
            best_state.update(
                best_val_acc=val_acc,
                best_train_loss=train_loss_so_far,
                best_train_acc=train_acc_so_far,
                best_train_reg=train_reg_so_far,
                best_epoch=epoch_idx,
                best_step=step + 1,
                best_val_loss=val_loss,
            )
            print(f"[info] New best model found at epoch {epoch_idx}, step {step + 1} with val acc {val_acc:.4f}")

        # eval_one_epoch leaves the model in eval mode; switch back before continuing.
        model.train()

        progress.set_postfix(
            loss=f"{train_loss_so_far:.4f}",
            Lambda=f"{lambda_reg:.4f}",
            acc=f"{train_acc_so_far:.4f}",
            reg=f"{reg_term.item():.4f}",
            val_acc=f"{val_acc:.4f}",
            lr=f"{scheduler.get_last_lr()[0]:.2e}",
        )

    epoch_loss = loss_sum / seen
    epoch_acc = metric_sum / seen
    epoch_reg = reg_sum / seen
    return epoch_loss, epoch_acc, epoch_reg


@torch.no_grad()
def eval_one_epoch(model, dataloader, device, epoch_idx: int, task: str):
    """Evaluate ``model`` on ``dataloader`` and return ``(mean_loss, metric)``.

    The loss is the sample-weighted mean of the per-batch model loss. The
    metric is computed once over all collected predictions with
    ``metric_from_logits`` rather than averaged per batch. The model is left
    in eval mode.

    Args:
        model: Model returning ``.loss`` and ``.logits`` when given labels.
        dataloader: Iterable of batches with "input_ids", "attention_mask",
            "labels" and optionally "token_type_ids".
        device: Device the batches are moved to.
        epoch_idx: Only used in the progress-bar description.
        task: GLUE task name; "stsb" (case-insensitive) selects the
            regression handling of the logits.
    """
    model.eval()
    is_regression = task.lower() == "stsb"

    running_loss = 0.0
    seen = 0

    # Collected on CPU so the whole split does not have to stay in GPU memory.
    all_logits = []
    all_labels = []

    pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Eval  Epoch {epoch_idx}")
    for step, batch in pbar:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        token_type_ids = batch.get("token_type_ids", None)
        if token_type_ids is not None:
            token_type_ids = token_type_ids.to(device)

        # Same precision settings as training.
        with autocast(device_type="cuda", dtype=torch.bfloat16):
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                labels=labels,
            )
            loss = outputs.loss
            logits = outputs.logits

        bs = labels.size(0)
        running_loss += loss.item() * bs
        seen += bs

        if is_regression:
            # STS-B logits are [batch, 1]. Drop only the last dim: a bare
            # squeeze() would also drop the batch dim for a batch of one,
            # giving a 0-dim tensor that torch.cat cannot concatenate.
            preds = logits.squeeze(-1).float().cpu()
        else:
            preds = logits.float().cpu()

        all_logits.append(preds)
        all_labels.append(labels.cpu())

        pbar.set_postfix(loss=f"{running_loss/seen:.4f}")

    # Concatenate all batches and compute the metric over the full split.
    all_logits = torch.cat(all_logits, dim=0)
    all_labels = torch.cat(all_labels, dim=0)

    final_acc = metric_from_logits(task, all_logits, all_labels)

    return running_loss / seen, final_acc


# =============================
# 3) Main: fine-tune BERT on a GLUE task
# =============================
def main():
    """Fine-tune BERT on one GLUE task and append a summary row to a CSV file.

    Steps: parse the arguments, restrict the visible GPU through
    ``CUDA_VISIBLE_DEVICES``, seed the RNGs, build the data loaders, load
    ``BertForSequenceClassification``, create AdamW with a linear
    warmup/decay schedule, train for ``args.epochs`` epochs while tracking the
    best validation metric in ``best_state``, and finally append one row to
    ``summary_reg_<task>.csv`` in the current working directory.
    """
    args = parse_args()
    print("[Args]", vars(args))  # Print all arguments
    os.environ["CUDA_VISIBLE_DEVICES"] = "{}".format(args.gpu_id)
    os.makedirs(args.output_dir, exist_ok=True)
    set_seed(args.set_seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[Info] Device: {device}")

    # Check if bf16 is supported (informational only; nothing is changed based on the result)
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        print("[Info] BF16 is supported and will be used.")
    else:
        print("[Warning] BF16 not supported on this device, autocast might fall back or fail.")

    # -------- Data --------
    tokenizer, train_loader, val_loader = build_dataloaders(args)

    # -------- Model --------
    task = args.task.lower()
    # STS-B is regression (one output), MNLI has three classes, every other task is treated as binary.
    num_labels = 1 if task == "stsb" else (3 if task == "mnli" else 2)
    model = BertForSequenceClassification.from_pretrained(args.model_name, num_labels=num_labels)
    if task == "stsb":
        # HuggingFace infers based on num_labels, but explicitly setting here is more stable
        model.config.problem_type = "regression"
    model.to(device)

    # -------- Optimizer + Scheduler --------
    optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # The scheduler is stepped once per batch, so the horizon is epochs * batches per epoch.
    total_steps = args.epochs * len(train_loader)
    warmup_steps = int(total_steps * args.warmup_ratio)

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps,
    )

    print(f"[Info] total_steps={total_steps}, warmup_steps={warmup_steps}")

    # Shared record of the best validation result; train_one_epoch updates it in place
    # during its periodic evaluations, and the epoch-end check below updates it as well.
    best_state = {
        "best_val_acc": -1.0,
        "best_train_loss": None,
        "best_train_acc": None,
        "best_train_reg": None,
        "best_epoch": None,
        "best_step": None,
        "best_val_loss": None,
    }
    # if args.lambda_reg <= 0:
    #     args.nums_pairs = 1


    # -------- Training loop --------
    for epoch in range(1, args.epochs + 1):
        start = time.time()
        lambda_reg = args.lambda_reg

        train_loss, train_acc, train_reg = train_one_epoch(
            model=model,
            dataloader=train_loader,
            val_dataloader=val_loader,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device,
            epoch_idx=epoch,
            args=args,
            num_labels=num_labels,
            lambda_reg=lambda_reg,
            best_state=best_state,
            log_every=args.log_every,
            grad_clip=args.grad_clip,
        )

        val_loss, val_acc = eval_one_epoch(
            model=model,
            dataloader=val_loader,
            device=device,
            epoch_idx=epoch,
            task=args.task,
        )

        # Also include the validation results at the end of each epoch in the best tracking.
        # prev_best is read after train_one_epoch, so it already reflects any in-epoch improvement.
        prev_best = best_state["best_val_acc"]
        if val_acc > prev_best:
            best_state["best_val_acc"] = val_acc
            best_state["best_train_loss"] = train_loss
            best_state["best_train_acc"] = train_acc
            best_state["best_train_reg"] = train_reg
            best_state["best_epoch"] = epoch
            # None marks "end of epoch" as opposed to a step inside the epoch.
            best_state["best_step"] = None
            best_state["best_val_loss"] = val_loss

        elapsed = time.time() - start
        print(
            f"\n[Epoch {epoch}/{args.epochs}] "
            f"train_loss={train_loss:.4f} train_acc={train_acc:.4f} | "
            f"val_loss={val_loss:.4f} val_acc={val_acc:.4f} | "
            f"reg={train_reg:.4f} lambda={lambda_reg:.4f} | "
            f"time={elapsed/60:.1f} min\n"
        )

        # -------- Save checkpoint per epoch --------
        # ckpt_dir = os.path.join(args.output_dir, f"checkpoint-epoch{epoch}")
        # os.makedirs(ckpt_dir, exist_ok=True)
        # model.save_pretrained(ckpt_dir)
        # tokenizer.save_pretrained(ckpt_dir)

        # -------- Track best --------
        # NOTE(review): the save_pretrained calls below are commented out, so the
        # "Saved to ..." message is printed even though this function writes no checkpoint.
        if val_acc > prev_best:
            best_dir = os.path.join(args.output_dir, "best")
            # os.makedirs(best_dir, exist_ok=True)
            # model.save_pretrained(best_dir)
            # tokenizer.save_pretrained(best_dir)
            print(
                f"[Info] New best val_acc={best_state['best_val_acc']:.4f} "
                f"(epoch={best_state.get('best_epoch')}, step={best_state.get('best_step')}). Saved to {best_dir}"
            )
            if best_state["best_train_loss"] is not None:
                print(
                    f"[Info] Corresponding train_loss={best_state['best_train_loss']:.4f}, "
                    f"train_acc={best_state['best_train_acc']:.4f}, train_reg={best_state['best_train_reg']:.4f}"
                )

    # NOTE(review): "accuracy" here is whatever metric_from_logits returns for the task
    # (Matthews for CoLA, Spearman for STS-B), and no checkpoint is written above.
    print(f"[Done] Best validation accuracy = {best_state['best_val_acc']:.4f}")
    print(f"[Done] All checkpoints saved under: {args.output_dir}")

    # -------- Save results to CSV --------
    # result_dir = os.path.join(args.output_dir, "results")
    # os.makedirs(result_dir, exist_ok=True)
    # The CSV is written relative to the current working directory, not args.output_dir.
    csv_filename = f"summary_reg_{safe_filename(args.task)}.csv"
    header = [
        "task",
        "lr",
        "epochs",
        "bs",
        "lambda_reg",
        "resolution",
        "nums_pairs",
        "max_degree",
        "label",
        "random_alpha",
        "output_dir",
        "best_val_acc",
        "train_loss", "train_acc", "train_reg", "set_seed", "mixup_mode"
    ]
    # Row values follow the header order; the train_* columns are the values recorded
    # at the moment the best validation metric was observed.
    row = [
        args.task,
        args.lr,
        args.epochs,
        args.bs,
        args.lambda_reg,
        args.resolution,
        args.nums_pairs,
        args.max_degree,
        args.label,
        args.random_alpha,
        args.output_dir,
        best_state["best_val_acc"],
        best_state["best_train_loss"],
        best_state["best_train_acc"],
        best_state["best_train_reg"],
        args.set_seed,
        args.mixup_mode,
    ]
    # Write the header only when the file is being created; later runs append rows.
    write_header = not os.path.exists(csv_filename)
    with open(csv_filename, "a", newline="") as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(header)
        writer.writerow(row)
        
# Run the training pipeline only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()