import os
import time
import random
import argparse
import csv
import numpy as np

import torch
import torch.nn as nn
import sklearn.metrics
from torch import autocast
from torch.utils.data import DataLoader

from datasets import load_dataset
from transformers import (
    BertTokenizerFast,
    BertForSequenceClassification,
    AdamW,
    get_linear_schedule_with_warmup,
    DataCollatorWithPadding,
)

from tqdm import tqdm


# =============================
# 0) Utils & Mixup Model
# =============================
def set_seed(seed: int = 42):
    """Seed every RNG the script draws from: Python, NumPy, torch CPU and all CUDA devices."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,  # lazy: safe to call even without a GPU
    )
    for seed_fn in seeders:
        seed_fn(seed)


def accuracy_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> float:
    """Return the fraction of rows whose arg-max over the last dim equals ``labels``."""
    hits = logits.argmax(dim=-1) == labels
    return hits.float().mean().item()


def _rankdata_average_ties(x: np.ndarray) -> np.ndarray:
    """Tie-aware rankdata (1..N), like scipy.stats.rankdata(method='average')."""
    x = np.asarray(x)
    n = x.shape[0]
    if n == 0:
        return x.astype(np.float64)

    sorter = np.argsort(x, kind="mergesort")
    x_sorted = x[sorter]

    unequal = np.empty(n, dtype=bool)
    unequal[0] = True
    unequal[1:] = x_sorted[1:] != x_sorted[:-1]
    group_starts = np.nonzero(unequal)[0]
    group_starts = np.concatenate([group_starts, np.array([n], dtype=group_starts.dtype)])

    ranks_sorted = np.empty(n, dtype=np.float64)
    for i in range(group_starts.shape[0] - 1):
        start = int(group_starts[i])
        end = int(group_starts[i + 1])
        avg_rank = 0.5 * (start + (end - 1)) + 1.0
        ranks_sorted[start:end] = avg_rank

    ranks = np.empty(n, dtype=np.float64)
    ranks[sorter] = ranks_sorted
    return ranks


def spearman_corr_from_logits(logits: torch.Tensor, labels: torch.Tensor, eps: float = 1e-8) -> float:
    """Spearman rank correlation between flattened ``logits`` and ``labels``.

    Both tensors are flattened, cast to float32 and moved to NumPy. Tied values
    get average ranks. The result is the Pearson correlation of the centred
    ranks; ``eps`` is added to the denominator, so a constant input yields 0.0
    instead of dividing by zero. Empty input returns 0.0.
    """
    pred_vals, true_vals = (
        t.view(-1).detach().float().cpu().numpy() for t in (logits, labels)
    )
    if pred_vals.size == 0:
        return 0.0

    rank_pred = _rankdata_average_ties(pred_vals)
    rank_true = _rankdata_average_ties(true_vals)
    rank_pred = rank_pred - rank_pred.mean()
    rank_true = rank_true - rank_true.mean()

    spread_pred = np.sqrt((rank_pred * rank_pred).mean())
    spread_true = np.sqrt((rank_true * rank_true).mean())
    covariance = (rank_pred * rank_true).mean()
    return float(covariance / ((spread_pred * spread_true) + eps))


def matthews_corr_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> float:
    """Matthews correlation (sklearn) between arg-max predictions and ``labels``."""
    y_true = labels.detach().cpu().numpy()
    y_pred = logits.argmax(dim=-1).detach().cpu().numpy()
    score = sklearn.metrics.matthews_corrcoef(y_true, y_pred)
    return float(score)


def metric_from_logits(task: str, logits: torch.Tensor, labels: torch.Tensor) -> float:
    """Compute the GLUE headline metric for ``task`` (case-insensitive).

    Uses MCC for CoLA, Spearman for STS-B and accuracy for every other task.
    """
    scorers = {
        "cola": matthews_corr_from_logits,
        "stsb": spearman_corr_from_logits,
    }
    scorer = scorers.get(task.lower(), accuracy_from_logits)
    return scorer(logits, labels)


def safe_filename(name: str) -> str:
    """Replace every char of ``str(name)`` that is not alphanumeric or one of ``-_.`` with ``_``."""
    kept_punctuation = "-_."
    pieces = []
    for ch in str(name):
        pieces.append(ch if ch.isalnum() or ch in kept_punctuation else "_")
    return "".join(pieces)


def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace holding the task, optimisation, mixup and runtime
        options (see the individual ``add_argument`` calls for defaults).
    """
    glue_tasks = ["cola", "sst2", "mrpc", "stsb", "qqp", "mnli", "qnli", "rte", "wnli"]
    cli = argparse.ArgumentParser(description="BERT fine-tuning on GLUE with Mixup")
    cli.add_argument(
        "--task",
        type=str,
        default="rte",
        choices=glue_tasks,
        help="Which GLUE sub-task to run",
    )

    # Plain ``--flag VALUE`` options without help text, kept in --help order.
    plain_options = (
        ("--model_name", str, "bert-base-cased"),
        ("--max_len", int, 128),
        ("--lr", float, 2e-5),
        ("--weight_decay", float, 0.01),
        ("--epochs", int, 10),
        ("--bs", int, 32),
        ("--eval_batch_size", int, 32),
        ("--warmup_ratio", float, 0.1),
        ("--set_seed", int, 0),
        ("--log_every", int, 50),
        ("--grad_clip", float, 1.0),
        ("--output_dir", str, "./bert_glue_mixup"),
        ("--num_workers", int, 2),
    )
    for flag, value_type, default in plain_options:
        cli.add_argument(flag, type=value_type, default=default)

    cli.add_argument(
        "--mixup_alpha",
        type=float,
        default=1.0,
        help="Beta(alpha, alpha) for lam. <=0 disables mixup.",
    )
    cli.add_argument(
        "--mixup_mode",
        type=str,
        default="encoder",
        choices=["encoder", "embeddings"],
        help="Where to apply mixup: pooled encoder output or input embeddings.",
    )
    cli.add_argument("--gpu_id", type=str, default="0", help="which gpu to use")
    cli.add_argument("--save_best", action="store_true", help="save best checkpoint to output_dir/best")
    cli.add_argument("--bf16", action="store_true", help="enable bf16 autocast (AMP) on CUDA")
    return cli.parse_args()


def _use_bf16(args, device: torch.device) -> bool:
    if not getattr(args, "bf16", False):
        return False
    if device.type != "cuda":
        return False
    return bool(torch.cuda.is_available() and torch.cuda.is_bf16_supported())


class MixupBertForSequenceClassification(BertForSequenceClassification):
    """
    ``BertForSequenceClassification`` extended with two mixup forward passes.

    * ``forward_mix_embed``: mix the embedding-layer outputs of two inputs,
      then run the mixture once through the encoder, pooler and classifier.
    * ``forward_mix_encoder``: run two independent BERT forwards and mix their
      pooled outputs (after dropout) just before the classifier.

    Both return raw classifier logits; computing the mixup loss is left to the
    caller. The standard ``forward`` is inherited unchanged.
    """

    def _forward_init(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
        """
        Validate inputs and prepare the masks the BERT encoder expects.

        Defaults ``attention_mask`` to all ones and ``token_type_ids`` to all
        zeros when they are ``None``. Converts the attention mask with
        ``self.bert.get_extended_attention_mask`` and builds the cross-attention
        mask only in decoder mode. Expands ``head_mask`` with
        ``self.bert.get_head_mask``.

        Returns:
            10-tuple ``(input_ids, attention_mask, token_type_ids, position_ids,
            head_mask, inputs_embeds, encoder_hidden_states,
            encoder_attention_mask, extended_attention_mask,
            encoder_extended_attention_mask)``.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds``
                are given.
        """
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # Extended (broadcastable) attention mask in the form the encoder consumes.
        extended_attention_mask = self.bert.get_extended_attention_mask(
            attention_mask, input_shape, device
        )

        # Cross-attention mask is only needed when BERT is configured as a decoder.
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.bert.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Per-layer head mask.
        head_mask = self.bert.get_head_mask(head_mask, self.config.num_hidden_layers)

        return (
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            inputs_embeds,
            encoder_hidden_states,
            encoder_attention_mask,
            extended_attention_mask,
            encoder_extended_attention_mask,
        )

    def _embed_for_mix(self, input_ids, attention_mask, token_type_ids):
        """Embed one side of a mixup pair.

        Each side uses its OWN ``token_type_ids``. ``position_ids`` stays
        ``None``, so both sides get the embedding layer's default positions.

        Returns:
            ``(embedding_output, extended_attention_mask, head_mask)``.
        """
        (
            input_ids,
            _attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            inputs_embeds,
            _encoder_hidden_states,
            _encoder_attention_mask,
            extended_attention_mask,
            _encoder_extended_attention_mask,
        ) = self._forward_init(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        embedding_output = self.bert.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )
        return embedding_output, extended_attention_mask, head_mask

    def forward_mix_embed(
        self,
        input_ids1,
        attention_mask1,
        token_type_ids1,
        input_ids2,
        attention_mask2,
        token_type_ids2,
        lam,
    ):
        """
        Mixup at the embedding layer.

        Computes ``lam * emb(x1) + (1 - lam) * emb(x2)`` and feeds it through the
        encoder, pooler, dropout and classifier. The two embedding outputs and
        masks are combined element-wise, so both inputs must share the same
        sequence length (e.g. a padded batch and a permutation of it).

        Args:
            input_ids1, attention_mask1, token_type_ids1: first input.
            input_ids2, attention_mask2, token_type_ids2: second input.
            lam: weight of the first input in the mixture.

        Returns:
            Classifier logits for the mixed input.
        """
        # Input 1 is embedded before input 2; this keeps the dropout RNG order.
        embedding_output1, extended_attention_mask1, head_mask = self._embed_for_mix(
            input_ids1, attention_mask1, token_type_ids1
        )
        embedding_output2, extended_attention_mask2, head_mask = self._embed_for_mix(
            input_ids2, attention_mask2, token_type_ids2
        )

        embedding_output = lam * embedding_output1 + (1.0 - lam) * embedding_output2

        # HF extended masks are additive (0 = attend, large negative = masked), so the
        # element-wise max keeps a position visible if EITHER input attends to it.
        extended_attention_mask = torch.max(extended_attention_mask1, extended_attention_mask2)

        encoder_outputs = self.bert.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
        )

        sequence_output = encoder_outputs[0]
        pooled_output = self.bert.pooler(sequence_output)
        pooled_output = self.dropout(pooled_output)
        return self.classifier(pooled_output)

    def forward_mix_encoder(
        self,
        input_ids1,
        attention_mask1,
        token_type_ids1,
        input_ids2,
        attention_mask2,
        token_type_ids2,
        lam,
    ):
        """
        Mixup at the pooled encoder output.

        Runs two independent standard BERT forwards, applies dropout to each
        pooled output (``outputs[1]``), mixes them as
        ``lam * pooled1 + (1 - lam) * pooled2`` and classifies the mixture.

        Returns:
            Classifier logits for the mixed representation.
        """
        outputs1 = self.bert(
            input_ids=input_ids1,
            attention_mask=attention_mask1,
            token_type_ids=token_type_ids1,
        )
        outputs2 = self.bert(
            input_ids=input_ids2,
            attention_mask=attention_mask2,
            token_type_ids=token_type_ids2,
        )

        pooled_output1 = self.dropout(outputs1[1])
        pooled_output2 = self.dropout(outputs2[1])

        pooled_output = lam * pooled_output1 + (1.0 - lam) * pooled_output2
        return self.classifier(pooled_output)


# =============================
# 1) Prepare GLUE dataset
# =============================
def build_dataloaders(args):
    """Load a GLUE task, tokenize it, and wrap train/validation in DataLoaders.

    Text columns are tokenized (single sentence or sentence pair, depending on
    the task) with truncation to ``args.max_len``; ``None`` texts become ``""``.
    Raw text columns and ``idx`` are dropped, ``label`` is renamed to
    ``labels``, and STS-B labels are cast to Python floats. Padding happens per
    batch via ``DataCollatorWithPadding``. MNLI validates on
    ``validation_matched`` when it exists.

    Returns:
        ``(tokenizer, train_loader, val_loader)``.
    """
    task = args.task.lower()
    raw = load_dataset("glue", task)
    tokenizer = BertTokenizerFast.from_pretrained(args.model_name)

    # Text column(s) fed to the tokenizer: one column -> single sentence,
    # two columns -> sentence pair.
    text_columns = {
        "cola": ("sentence",),
        "sst2": ("sentence",),
        "mrpc": ("sentence1", "sentence2"),
        "stsb": ("sentence1", "sentence2"),
        "rte": ("sentence1", "sentence2"),
        "wnli": ("sentence1", "sentence2"),
        "qqp": ("question1", "question2"),
        "mnli": ("premise", "hypothesis"),
        "qnli": ("question", "sentence"),
    }

    def tokenize_fn(batch):
        columns = text_columns.get(task)
        if columns is None:
            raise ValueError(f"Unsupported task: {task}")
        texts = [["" if v is None else v for v in batch[col]] for col in columns]
        return tokenizer(*texts, truncation=True, max_length=args.max_len)

    tokenized = raw.map(tokenize_fn, batched=True)

    droppable = (
        "sentence",
        "sentence1",
        "sentence2",
        "question1",
        "question2",
        "premise",
        "hypothesis",
        "question",
        "idx",
    )
    present = tokenized["train"].column_names
    to_drop = [col for col in droppable if col in present]
    if to_drop:
        tokenized = tokenized.remove_columns(to_drop)

    tokenized = tokenized.rename_column("label", "labels")

    if task == "stsb":
        def cast_label(batch):
            return {"labels": [float(x) for x in batch["labels"]]}

        tokenized = tokenized.map(cast_label, batched=True)

    tokenized.set_format("torch")
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    def make_loader(split, batch_size, shuffle):
        return DataLoader(
            tokenized[split],
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=args.num_workers,
            collate_fn=data_collator,
            pin_memory=True,
        )

    train_loader = make_loader("train", args.bs, True)

    if task == "mnli" and "validation_matched" in tokenized:
        val_split = "validation_matched"
    else:
        val_split = "validation"
    val_loader = make_loader(val_split, args.eval_batch_size, False)

    return tokenizer, train_loader, val_loader


# =============================
# 2) Train / Eval loops
# =============================
def mixup_criterion_cross_entropy(criterion, pred, y_a, y_b, lam):
    """Mixup classification loss ``lam * L(pred, y_a) + (1 - lam) * L(pred, y_b)``.

    The prediction is scored against both label sets, and the two losses are
    interpolated with the same ``lam`` that mixed the inputs.
    """
    loss_first = criterion(pred, y_a)
    loss_second = criterion(pred, y_b)
    weight_second = 1 - lam
    return lam * loss_first + weight_second * loss_second


def mixup_criterion_regression(criterion, pred, y_a, y_b, lam):
    """Mixup regression loss: score ``pred`` against the mixed target ``lam*y_a + (1-lam)*y_b``."""
    return criterion(pred, lam * y_a + (1.0 - lam) * y_b)


@torch.no_grad()
def eval_on_dataloader(model, dataloader, device, task: str, num_labels: int, args):
    """Evaluate ``model`` on ``dataloader`` without a progress bar.

    Used for the periodic validation inside the training loop. Runs the
    standard (non-mixup) forward under optional bf16 autocast, accumulates a
    batch-size-weighted average of ``outputs.loss``, and computes the task
    metric over the concatenated logits of the whole loader.

    Returns:
        ``(avg_loss, metric)``; both are ``0.0`` when no samples were seen.
        The model's original train/eval mode is restored before returning.
    """
    restore_train = model.training
    model.eval()

    flatten = num_labels == 1  # regression head: keep logits/labels 1-D
    loss_sum, n_samples = 0.0, 0
    logit_chunks, label_chunks = [], []

    use_bf16 = _use_bf16(args, device)
    for batch in dataloader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        token_type_ids = batch.get("token_type_ids")
        token_type_ids = None if token_type_ids is None else token_type_ids.to(device)

        with autocast(device_type="cuda", dtype=torch.bfloat16, enabled=use_bf16):
            out = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                labels=labels,
            )

        n = labels.size(0)
        loss_sum += float(out.loss.item()) * n
        n_samples += n

        batch_logits = out.logits.view(-1) if flatten else out.logits
        batch_labels = labels.view(-1) if flatten else labels
        logit_chunks.append(batch_logits.detach().cpu())
        label_chunks.append(batch_labels.detach().cpu())

    stacked_logits = torch.cat(logit_chunks, dim=0) if logit_chunks else torch.empty(0)
    stacked_labels = torch.cat(label_chunks, dim=0) if label_chunks else torch.empty(0)

    if n_samples > 0:
        metric = metric_from_logits(task, stacked_logits, stacked_labels)
        avg_loss = loss_sum / n_samples
    else:
        metric, avg_loss = 0.0, 0.0

    if restore_train:
        model.train()
    return avg_loss, metric


def _save_best_checkpoint(model, tokenizer, output_dir):
    """Save ``model`` (and ``tokenizer`` when given) to ``<output_dir>/best``."""
    best_dir = os.path.join(output_dir, "best")
    os.makedirs(best_dir, exist_ok=True)
    model.save_pretrained(best_dir)
    if tokenizer is not None:
        tokenizer.save_pretrained(best_dir)
    print(f"[Info] Saved best checkpoint to {best_dir}")


def _plain_train_step(model, args, is_regression, use_bf16,
                      input_ids, attention_mask, token_type_ids, labels):
    """Standard forward + backward using the model's built-in loss.

    Returns:
        ``(loss, batch_metric)``; gradients are accumulated into the model.
    """
    with autocast(device_type="cuda", dtype=torch.bfloat16, enabled=use_bf16):
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            labels=labels,
        )
        loss = outputs.loss
        logits = outputs.logits

    # backward outside autocast
    loss.backward()
    metric_logits = logits.view(-1) if is_regression else logits
    return loss, metric_from_logits(args.task, metric_logits, labels)


def _mixup_train_step(model, args, criterion, is_regression, use_bf16, device,
                      input_ids, attention_mask, token_type_ids, labels):
    """Mixup forward + backward on the batch and a random permutation of it.

    ``lam`` is drawn from ``Beta(args.mixup_alpha, args.mixup_alpha)``.
    ``args.mixup_mode`` selects embedding- or encoder-level mixing.

    Returns:
        ``(loss, batch_metric)``. The metric is the lam-weighted accuracy for
        classification, or the Spearman correlation against the mixed targets
        for regression.
    """
    bs = input_ids.size(0)
    lam = float(np.random.beta(args.mixup_alpha, args.mixup_alpha))
    index = torch.randperm(bs).to(device)

    input_ids2 = input_ids[index]
    attention_mask2 = attention_mask[index]
    labels2 = labels[index]
    token_type_ids2 = token_type_ids[index] if token_type_ids is not None else None

    mix_forward = (
        model.forward_mix_embed if args.mixup_mode == "embeddings" else model.forward_mix_encoder
    )
    with autocast(device_type="cuda", dtype=torch.bfloat16, enabled=use_bf16):
        logits = mix_forward(
            input_ids1=input_ids,
            attention_mask1=attention_mask,
            token_type_ids1=token_type_ids,
            input_ids2=input_ids2,
            attention_mask2=attention_mask2,
            token_type_ids2=token_type_ids2,
            lam=lam,
        )
        if is_regression:
            # Compare in fp32: casting the targets to the (possibly bf16) logits
            # dtype would quantise the regression labels.
            pred = logits.view(-1).float()
            y_a = labels.view(-1).float()
            y_b = labels2.view(-1).float()
            loss = mixup_criterion_regression(criterion, pred, y_a, y_b, lam)
        else:
            loss = mixup_criterion_cross_entropy(criterion, logits, labels, labels2, lam)

    # backward outside autocast
    loss.backward()

    if is_regression:
        y_mix = lam * y_a + (1.0 - lam) * y_b
        metric_val = spearman_corr_from_logits(pred, y_mix)
    else:
        preds = torch.argmax(logits, dim=1)
        weighted_hits = (
            lam * (preds == labels).float().sum()
            + (1.0 - lam) * (preds == labels2).float().sum()
        )
        metric_val = weighted_hits.item() / bs
    return loss, metric_val


def train_one_epoch(
    model,
    dataloader,
    val_dataloader,
    optimizer,
    scheduler,
    device,
    epoch_idx: int,
    args,
    num_labels: int,
    best_state: dict,
    tokenizer=None,
    log_every: int = 100,
    grad_clip: float = 1.0,
):
    """Train ``model`` for one epoch with mixup and periodic validation.

    Some batches use the model's standard forward/loss: those with fewer than
    2 examples, and all batches when ``args.mixup_alpha <= 0``. Every other
    batch uses mixup in ``args.mixup_mode``. The scheduler is stepped once per
    batch.

    Every ``log_every`` steps (disabled when ``log_every <= 0``) the
    validation set is evaluated. ``best_state`` is updated in place when the
    validation metric improves. If ``args.save_best`` is set, the new best
    model is saved to ``<args.output_dir>/best``.

    Args:
        epoch_idx: 1-based epoch number; used for the global step count.
        num_labels: 1 selects regression (MSE), otherwise cross-entropy.
        best_state: dict with ``metric_name`` / ``best_*`` keys, mutated in place.
        grad_clip: maximum global gradient norm. ``None`` or a value ``<= 0``
            disables clipping.

    Returns:
        ``(epoch_loss, epoch_metric)``: batch-size-weighted averages of the
        training loss and per-batch metric, or ``(0.0, 0.0)`` if the
        dataloader was empty.
    """
    model.train()
    is_regression = (num_labels == 1)
    criterion = nn.MSELoss() if is_regression else nn.CrossEntropyLoss()
    use_bf16 = _use_bf16(args, device)

    running_loss = 0.0
    running_acc = 0.0
    seen = 0

    pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Train Epoch {epoch_idx}")
    for step, batch in pbar:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        token_type_ids = batch.get("token_type_ids")
        if token_type_ids is not None:
            token_type_ids = token_type_ids.to(device)

        bs = input_ids.size(0)
        optimizer.zero_grad(set_to_none=True)

        if bs < 2 or args.mixup_alpha <= 0:
            loss, metric_val = _plain_train_step(
                model, args, is_regression, use_bf16,
                input_ids, attention_mask, token_type_ids, labels,
            )
        else:
            loss, metric_val = _mixup_train_step(
                model, args, criterion, is_regression, use_bf16, device,
                input_ids, attention_mask, token_type_ids, labels,
            )

        # clip_grad_norm_ with max_norm <= 0 would zero (or sign-flip) every
        # gradient, so a non-positive value means "no clipping".
        if grad_clip is not None and grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

        optimizer.step()
        scheduler.step()

        running_loss += float(loss.item()) * bs
        running_acc += float(metric_val) * bs
        seen += bs

        global_step = (epoch_idx - 1) * len(dataloader) + (step + 1)

        if log_every > 0 and (step + 1) % log_every == 0:
            pbar.set_postfix(
                loss=f"{running_loss/seen:.4f}",
                metric=f"{running_acc/seen:.4f}",
                lr=f"{scheduler.get_last_lr()[0]:.2e}",
            )

            # Periodic validation (bf16 autocast is enabled inside when supported).
            val_loss, val_metric = eval_on_dataloader(
                model=model,
                dataloader=val_dataloader,
                device=device,
                task=args.task,
                num_labels=num_labels,
                args=args,
            )
            print(
                f"[Eval@step {global_step}] "
                f"val_loss={val_loss:.4f} val_{best_state['metric_name']}={val_metric:.4f}"
            )

            if val_metric > best_state["best_val_metric"]:
                best_state["best_val_metric"] = float(val_metric)
                best_state["best_val_loss"] = float(val_loss)
                best_state["best_epoch"] = int(epoch_idx)
                best_state["best_step"] = int(global_step)
                best_state["best_train_loss"] = float(running_loss / seen)
                best_state["best_train_metric"] = float(running_acc / seen)

                print(
                    f"[Best@step {global_step}] "
                    f"best_val_{best_state['metric_name']}={best_state['best_val_metric']:.4f}"
                )

                if args.save_best:
                    _save_best_checkpoint(model, tokenizer, args.output_dir)

    if seen == 0:
        return 0.0, 0.0
    return running_loss / seen, running_acc / seen


@torch.no_grad()
def eval_one_epoch(model, dataloader, device, epoch_idx: int, task: str, num_labels: int, args):
    """Run a full validation pass with a tqdm progress bar (once per epoch).

    Uses the model's standard (non-mixup) forward under optional bf16
    autocast. Averages ``outputs.loss`` weighted by batch size and computes
    the task metric over the concatenated logits of the whole loader.

    Returns:
        ``(avg_loss, metric)``, or ``(0.0, 0.0)`` when the dataloader yields no
        samples. The model's original train/eval mode is restored before
        returning.
    """
    was_training = model.training
    model.eval()

    is_regression = (num_labels == 1)
    running_loss = 0.0
    seen = 0
    all_logits = []
    all_labels = []

    use_bf16 = _use_bf16(args, device)
    pbar = tqdm(dataloader, total=len(dataloader), desc=f"Eval  Epoch {epoch_idx}")
    for batch in pbar:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        token_type_ids = batch.get("token_type_ids")
        if token_type_ids is not None:
            token_type_ids = token_type_ids.to(device)

        # Standard forward (no mixup during evaluation)
        with autocast(device_type="cuda", dtype=torch.bfloat16, enabled=use_bf16):
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                labels=labels,
            )
            loss = outputs.loss
            logits = outputs.logits

        bs = labels.size(0)
        running_loss += float(loss.item()) * bs
        seen += bs

        if is_regression:
            all_logits.append(logits.view(-1).detach().cpu())
            all_labels.append(labels.view(-1).detach().cpu())
        else:
            all_logits.append(logits.detach().cpu())
            all_labels.append(labels.detach().cpu())

        if seen > 0:
            pbar.set_postfix(loss=f"{running_loss/seen:.4f}")

    if seen == 0:
        # Nothing to score: torch.cat([]) would raise and the mean would divide by zero.
        avg_loss, metric = 0.0, 0.0
    else:
        avg_loss = running_loss / seen
        metric = metric_from_logits(task, torch.cat(all_logits, dim=0), torch.cat(all_labels, dim=0))

    if was_training:
        model.train()
    return avg_loss, metric


# =============================
# 3) Main: fine-tune BERT with Mixup on the selected GLUE task
# =============================
def _append_summary_csv(args, best_state) -> str:
    """Append one row describing this run to ``summary_mixup_<task>.csv``.

    The file is written to the current working directory (not
    ``output_dir``). The header row is written only when the file does not
    exist yet. Column names and values are declared side by side, so the
    header and the row cannot drift apart.

    Returns:
        The CSV filename that was written.
    """
    columns = [
        ("task", args.task),
        ("model_name", args.model_name),
        ("max_len", args.max_len),
        ("lr", args.lr),
        ("epochs", args.epochs),
        ("bs", args.bs),
        ("eval_batch_size", args.eval_batch_size),
        ("weight_decay", args.weight_decay),
        ("warmup_ratio", args.warmup_ratio),
        ("mixup_alpha", args.mixup_alpha),
        ("mixup_mode", args.mixup_mode),
        ("set_seed", args.set_seed),
        ("gpu_id", args.gpu_id),
        ("output_dir", args.output_dir),
        ("metric_name", best_state["metric_name"]),
        ("best_val_metric", best_state["best_val_metric"]),
        ("best_val_loss", best_state["best_val_loss"]),
        ("best_epoch", best_state["best_epoch"]),
        ("best_step", best_state["best_step"]),
    ]
    csv_filename = f"summary_mixup_{safe_filename(args.task)}.csv"
    write_header = not os.path.exists(csv_filename)
    with open(csv_filename, "a", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow([name for name, _ in columns])
        writer.writerow([value for _, value in columns])
    return csv_filename


def main():
    """Fine-tune BERT with mixup on one GLUE task and log a CSV summary.

    Side effects:
    - sets ``CUDA_VISIBLE_DEVICES`` from ``--gpu_id``;
    - creates ``--output_dir``;
    - with ``--save_best``, saves the best checkpoint to ``<output_dir>/best``;
    - appends a row to ``summary_mixup_<task>.csv`` in the working directory.
    """
    args = parse_args()
    print("[Args]", vars(args))

    # Must be set before CUDA is first initialised to take effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
    os.makedirs(args.output_dir, exist_ok=True)
    set_seed(args.set_seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[Info] Device: {device}")

    if args.bf16:
        # Same predicate the train/eval loops use to enable autocast.
        if _use_bf16(args, device):
            print("[Info] BF16 is supported and will be used via autocast.")
        else:
            print("[Warning] BF16 requested but not supported; will run in fp32.")

    task = args.task.lower()
    # STS-B is regression (1 output), MNLI has 3 classes, every other task is binary.
    num_labels = 1 if task == "stsb" else (3 if task == "mnli" else 2)

    tokenizer, train_loader, val_loader = build_dataloaders(args)

    model = MixupBertForSequenceClassification.from_pretrained(args.model_name, num_labels=num_labels)
    if task == "stsb":
        model.config.problem_type = "regression"
    model.to(device)

    # torch's AdamW replaces the deprecated ``transformers.AdamW``. eps=1e-6 keeps
    # the transformers default, so the update rule stays near-identical.
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
        eps=1e-6,
    )
    total_steps = args.epochs * len(train_loader)
    warmup_steps = int(total_steps * args.warmup_ratio)

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps,
    )

    print(
        f"[Info] task={task}, total_steps={total_steps}, warmup_steps={warmup_steps}, "
        f"mixup_alpha={args.mixup_alpha}, mixup_mode={args.mixup_mode}"
    )

    metric_name = "spearman" if task == "stsb" else ("mcc" if task == "cola" else "accuracy")

    best_state = {
        "metric_name": metric_name,
        # -inf so the first evaluation always registers, even at a metric's
        # lower bound (-1 for MCC / Spearman).
        "best_val_metric": float("-inf"),
        "best_val_loss": None,
        "best_epoch": None,
        "best_step": None,
        "best_train_loss": None,
        "best_train_metric": None,
    }

    for epoch in range(1, args.epochs + 1):
        start = time.time()

        train_loss, train_metric = train_one_epoch(
            model=model,
            dataloader=train_loader,
            val_dataloader=val_loader,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device,
            epoch_idx=epoch,
            args=args,
            num_labels=num_labels,
            best_state=best_state,
            tokenizer=tokenizer,
            log_every=args.log_every,
            grad_clip=args.grad_clip,
        )

        val_loss, val_metric = eval_one_epoch(
            model=model,
            dataloader=val_loader,
            device=device,
            epoch_idx=epoch,
            task=args.task,
            num_labels=num_labels,
            args=args,
        )

        elapsed = time.time() - start
        print(
            f"\n[Epoch {epoch}/{args.epochs}] "
            f"train_loss={train_loss:.4f} train_metric={train_metric:.4f} | "
            f"val_loss={val_loss:.4f} val_{metric_name}={val_metric:.4f} | "
            f"time={elapsed/60:.1f} min\n"
        )

        # Epoch-end evaluation also counts for best-checkpoint tracking.
        if val_metric > best_state["best_val_metric"]:
            best_state["best_val_metric"] = float(val_metric)
            best_state["best_val_loss"] = float(val_loss)
            best_state["best_epoch"] = int(epoch)
            best_state["best_step"] = int(epoch * len(train_loader))
            best_state["best_train_loss"] = float(train_loss)
            best_state["best_train_metric"] = float(train_metric)

            print(
                f"[Best@epoch {epoch}] best_val_{metric_name}={best_state['best_val_metric']:.4f}"
            )

            if args.save_best:
                best_dir = os.path.join(args.output_dir, "best")
                os.makedirs(best_dir, exist_ok=True)
                model.save_pretrained(best_dir)
                tokenizer.save_pretrained(best_dir)
                print(f"[Info] Saved best checkpoint to {best_dir}")

    print(f"[Done] Best validation {metric_name} = {best_state['best_val_metric']:.4f}")

    csv_filename = _append_summary_csv(args, best_state)
    print(f"[Done] Wrote summary to: {csv_filename}")


if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()