"""
Finetune EBM with cached encodings (forward_from_encoded).

Objective (margin ranking / hinge loss, see `contrastive_loss`):
    loss = max(0, margin + energy(question_related, answer) - energy(question, answer))

Input CSV (required):
    question, answer, question_related
(Optional columns are ignored): energy_full, energy_related
"""

import argparse
import logging
import os
import random
from pathlib import Path
from typing import Dict, List, Tuple

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm
from torch import nn
import math
import json

# Project imports
from src.energy_model.config import EBMConfig
from src.energy_model.models import EnergyModel

# ------- Caching helpers (compatible with your earlier utilities) -------

def contrastive_loss(
    pos_energy: torch.Tensor,
    neg_energy: torch.Tensor,
    margin: float,
) -> torch.Tensor:
    """
    Margin ranking (hinge) loss that pushes positive energies below negatives.

    Each pair is penalised by max(0, margin + pos - neg), i.e. the penalty
    vanishes once the positive energy is at least `margin` below the negative.
    The per-pair penalties are reduced by the criterion's default reduction.
    """
    # A target of +1 asks the criterion to rank its FIRST argument higher,
    # which is why the negative energies are passed first.
    rank_target = torch.ones_like(pos_energy)
    criterion = nn.MarginRankingLoss(margin=margin)
    return criterion(neg_energy, pos_energy, rank_target)
    
def set_seed(seed: int = 42) -> None:
    """Seed Python's `random`, torch's CPU generator and, when CUDA is available, every CUDA device."""
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def get_cached_batch(
    texts: List[str],
    cache: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Assemble a batch from cached per-text encodings.

    Looks up each text in `cache` (a missing text raises KeyError), stacks the
    embeddings and the masks along a new leading dimension in the order of
    `texts`, and transfers both stacks to `device`.
    """
    entries = [cache[text] for text in texts]
    stacked_embs = torch.stack([entry[0] for entry in entries])
    stacked_masks = torch.stack([entry[1] for entry in entries])
    return (
        stacked_embs.to(device, non_blocking=True),
        stacked_masks.to(device, non_blocking=True),
    )

@torch.no_grad()
def update_encoding_cache(
    texts: List[str],
    cache: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
    model: torch.nn.Module,
    batch_size: int = 256,
) -> None:
    """
    Encode every text that is not yet in `cache` and store (emb, mask) on CPU.

    Args:
        texts: Strings to make available in the cache; duplicates and strings
            that are already cached are skipped.
        cache: Mapping text -> (embedding, mask), updated in place.
        model: Module exposing `text_encoder(list_of_str) -> (embs, masks)`.
        batch_size: Number of texts passed to the encoder per call.

    Raises:
        ValueError: If `batch_size` is not positive.

    The encoder runs in eval mode so that train-only layers (e.g. dropout, if
    the encoder has any) cannot bake noise into encodings that are reused for
    the whole run. Each module's previous train/eval flag is restored afterwards.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # dict.fromkeys drops duplicates while keeping first-seen order, so a text
    # repeated in `texts` is encoded only once.
    to_encode = list(dict.fromkeys(t for t in texts if t not in cache))
    if not to_encode:
        return

    # Remember the mode of every submodule: a blanket model.train() afterwards
    # could flip submodules that the caller deliberately keeps in eval mode.
    previous_modes = [(module, module.training) for module in model.modules()]
    model.eval()
    try:
        for i in tqdm(range(0, len(to_encode), batch_size), desc="Encoding cache", leave=False):
            chunk = to_encode[i:i + batch_size]
            embs, masks = model.text_encoder(chunk)  # expect [B, ...] tensors on device
            # move to CPU to save VRAM
            embs = embs.cpu()
            masks = masks.cpu()
            for j, text in enumerate(chunk):
                cache[text] = (embs[j], masks[j])
    finally:
        for module, was_training in previous_modes:
            module.training = was_training

# -------------------------- Dataset --------------------------

class QAPairs(Dataset):
    """
    Dataset of (question, answer, question_related) string triples.

    Rows with a missing value in any of the three columns are dropped and the
    remaining values are converted to stripped strings. The cleaned frame
    (including any extra columns) is kept on `orig_df`.
    """

    def __init__(self, df: pd.DataFrame):
        required = ["question", "answer", "question_related"]
        absent = next((col for col in required if col not in df.columns), None)
        if absent is not None:
            raise ValueError(f"Missing required column '{absent}' in CSV")

        # Work on a copy so the caller's frame is left untouched.
        cleaned = df.dropna(subset=required).copy()
        for col in required:
            cleaned[col] = cleaned[col].astype(str).str.strip()

        self.orig_df = cleaned  # keep for optional energy write
        self.q, self.a, self.qr = (cleaned[col].tolist() for col in required)

    def __len__(self):
        return len(self.q)

    def __getitem__(self, idx):
        return self.q[idx], self.a[idx], self.qr[idx]

def collate_text(batch):
    """Transpose a batch of (question, answer, question_related) triples into three parallel lists."""
    questions, answers, related = (list(column) for column in zip(*batch))
    return questions, answers, related

class IMDBSentimentPairs(Dataset):
    """
    Dataset for IMDB sentiment analysis with sentence importance scores.

    Each usable CSV row yields one training pair built from its
    `sentence_analysis` JSON, which must decode to a list of objects carrying
    `sentence` and `importance_score`:
      * positive text: the top 80% of sentences by importance, joined by spaces
      * negative text: the remaining (least important) sentences
    Rows with fewer than 5 sentences are filtered out; rows whose JSON cannot
    be parsed into that shape are skipped with a warning.

    Attributes:
        positive_texts / negative_texts / predictions: parallel lists, one
            entry per retained row.
        orig_df: cleaned copy of the input frame (includes skipped rows).
    """

    def __init__(self, df: pd.DataFrame):
        needed = ["review_text", "formatted_prediction", "sentence_analysis"]
        for c in needed:
            if c not in df.columns:
                raise ValueError(f"Missing required column '{c}' in CSV")

        # Filter and prepare data
        df = df.dropna(subset=needed).copy()
        for c in needed:
            df[c] = df[c].astype(str).str.strip()

        self.positive_texts = []  # High importance sentences
        self.negative_texts = []  # Low importance sentences
        self.predictions = []     # Formatted predictions (same for both)
        self.orig_df = df.copy()  # Keep original for reference

        filtered_count = 0
        for idx, row in df.iterrows():
            try:
                sentence_data = json.loads(row["sentence_analysis"])

                # A JSON object or string also has a length, so without this
                # check it would get past the size filter and then crash the
                # whole load when sorted. Treat it like any other malformed row.
                if not isinstance(sentence_data, list):
                    raise TypeError(
                        f"sentence_analysis must be a JSON list, got {type(sentence_data).__name__}"
                    )

                # Filter records with less than 5 sentences
                if len(sentence_data) < 5:
                    filtered_count += 1
                    continue

                # Most important sentences first
                ranked = sorted(sentence_data, key=lambda s: s["importance_score"], reverse=True)

                # Top 80% -> positive text, bottom 20% -> negative text.
                # With >= 5 sentences both sides are guaranteed non-empty.
                split_point = int(0.8 * len(ranked))
                high_importance_text = " ".join(s["sentence"] for s in ranked[:split_point])
                low_importance_text = " ".join(s["sentence"] for s in ranked[split_point:])
            except (json.JSONDecodeError, KeyError, TypeError) as e:
                logging.warning(f"Skipping row {idx} due to parsing error: {e}")
                continue

            # Store the pair only once the whole row parsed cleanly, so the
            # three lists always stay aligned.
            self.positive_texts.append(high_importance_text)
            self.negative_texts.append(low_importance_text)
            self.predictions.append(row["formatted_prediction"])

        logging.info(f"Filtered {filtered_count} records with <5 sentences")
        logging.info(f"Created {len(self.positive_texts)} training pairs from IMDB dataset")

    def __len__(self):
        return len(self.positive_texts)

    def __getitem__(self, idx):
        # Return in the same format as QAPairs: (question, answer, question_related)
        # Here: (negative_text, prediction, positive_text)
        # This maintains the energy objective: energy(positive, prediction) < energy(negative, prediction)
        return self.negative_texts[idx], self.predictions[idx], self.positive_texts[idx]

# -------------------------- Finetune --------------------------

def build_parser() -> argparse.ArgumentParser:
    """
    Build the command-line parser for the finetuning script.

    Options are declared in a table and registered in order, so `--help`
    lists them in the same sequence as they appear below.
    """

    def _parse_bool(raw: str) -> bool:
        # Anything other than "0" / "false" / "no" (case-insensitive) counts as true.
        return raw.lower() not in ["0", "false", "no"]

    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    option_table = [
        # Data & model
        ("--csv", dict(required=True, help="CSV with columns: question, answer, question_related")),
        ("--ebm_checkpoint", dict(required=True, help="Path to pretrained EBM checkpoint (.pt/.bin)")),
        ("--output_checkpoint", dict(required=True, help="Where to save finetuned checkpoint (.pt)")),
        # Model arch (must match checkpoint)
        ("--ebm_self_attention_layers", dict(type=int, default=2)),
        ("--ebm_cross_attention_layers", dict(type=int, default=6)),
        # Train
        ("--epochs", dict(type=int, default=3)),
        ("--batch_size", dict(type=int, default=32)),
        ("--lr", dict(type=float, default=2e-5)),
        ("--weight_decay", dict(type=float, default=0.0)),
        ("--grad_clip", dict(type=float, default=1.0)),
        ("--grad_accum_steps", dict(type=int, default=1)),
        ("--fp16", dict(
            action="store_true",
            help="Enable mixed precision in the energy head (safe with cached inputs)",
        )),
        # Cache behavior
        ("--assume_frozen_text_encoder", dict(
            type=_parse_bool,
            default=True,
            help="If False, cache is rebuilt each epoch (slower).",
        )),
        ("--cache_encode_batch_size", dict(type=int, default=256)),
        # After-train energies
        ("--write_energies_after", dict(action="store_true")),
        ("--output_csv", dict(
            default=None,
            help="Path to write energies CSV (default: <input>_energies_after.csv)",
        )),
        # Misc
        ("--device", dict(default=default_device)),
        ("--seed", dict(type=int, default=42)),
        ("--loglevel", dict(default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR"])),
        ("--checkpoints_per_epoch", dict(
            type=int,
            default=0,
            help="How many intra-epoch checkpoints to save (0 = only final).",
        )),
        ("--margin", dict(
            type=float,
            default=0.2,
            help="Margin for contrastive loss: max(0, margin + pos - neg)",
        )),
        ("--val_csv", dict(default=None, help="CSV for validation (same columns as train).")),
        ("--val_rows", dict(
            default=None,
            help='Optional row range "start:end" to evaluate a subset of the validation CSV. 0-based, end exclusive.',
        )),
        ("--val_dump_csv", dict(
            default=None,
            help="If provided, do not log validation metrics; instead write per-row energies for the selected rows.",
        )),
        # Resume training
        ("--resume_checkpoint", dict(
            default=None,
            help="Path to checkpoint file to resume training from. If provided, will load model and optimizer state.",
        )),
        # IMDB sentiment analysis mode
        ("--imdb_mode", dict(
            action="store_true",
            help="Use IMDB sentiment analysis dataset with sentence importance scores. Requires CSV with columns: review_text, formatted_prediction, sentence_analysis",
        )),
    ]

    parser = argparse.ArgumentParser(description="Finetune EBM (cached encodings)")
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser

def save_checkpoint(model, optimizer, path, epoch, global_step):
    """Save model checkpoint with training state.

    Args:
        model: Module whose `state_dict()` is stored under "model_state_dict".
        optimizer: Optimizer whose `state_dict()` is stored under "optimizer_state_dict".
        path: Destination file path (str or path-like), or anything else
            `torch.save` accepts.
        epoch: Epoch number recorded in the checkpoint.
        global_step: Optimizer-step counter recorded in the checkpoint.
    """
    # Create the destination folder up front (as the CSV writers in this file
    # do), so a missing directory cannot make the save fail after training.
    # Non-path targets such as open file objects are passed through untouched.
    if isinstance(path, (str, os.PathLike)):
        Path(path).parent.mkdir(parents=True, exist_ok=True)

    checkpoint = {
        "epoch": epoch,
        "global_step": global_step,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }
    torch.save(checkpoint, path)
    logging.info(f"Saved checkpoint: {path}")


def load_checkpoint(checkpoint_path, model, optimizer, device):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    Args:
        checkpoint_path: Path to checkpoint file
        model: Model whose weights are overwritten (strict key matching)
        optimizer: Optimizer to restore, or None to skip optimizer state
        device: `map_location` used when reading the file

    Returns:
        tuple: (epoch, global_step) stored in the checkpoint; each falls back
        to 0 when the checkpoint does not record it.

    Raises:
        FileNotFoundError: If `checkpoint_path` does not exist.
    """
    if not Path(checkpoint_path).exists():
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    logging.info(f"Loading checkpoint: {checkpoint_path}")
    state = torch.load(checkpoint_path, map_location=device)

    # New-style checkpoints wrap the weights; old-style files ARE the weights.
    weights = state["model_state_dict"] if "model_state_dict" in state else state
    model.load_state_dict(weights, strict=True)

    if optimizer is not None and "optimizer_state_dict" in state:
        optimizer.load_state_dict(state["optimizer_state_dict"])

    resumed_epoch = state.get("epoch", 0)
    resumed_step = state.get("global_step", 0)
    logging.info(f"Resumed from epoch {resumed_epoch}, global step {resumed_step}")
    return resumed_epoch, resumed_step


@torch.no_grad()
def validate(
    model: EnergyModel,
    dl: DataLoader,
    cache: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
    device: torch.device,
    margin: float,
) -> dict:
    """
    Simple validation loop using the same loss/logic as training.

    Puts `model` in eval mode (it is not switched back) and scores every batch
    from cached encodings; all texts yielded by `dl` must already be in `cache`.

    Returns:
        dict with per-sample averages over the whole loader:
          - "loss": margin ranking loss
          - "accuracy": percentage of pairs with pos energy < neg energy
          - "pos_energy_mean" / "neg_energy_mean": mean energies
        Every value is 0.0 when the loader yields no samples.
    """
    model.eval()
    total_loss, total_correct, total_samples = 0.0, 0, 0
    # Sum energies per sample rather than averaging batch means: a smaller
    # final batch would otherwise be over-weighted, and an empty loader would
    # divide by zero.
    pos_energy_sum, neg_energy_sum = 0.0, 0.0

    for qs, ans, qrs in tqdm(dl, desc="Validation", dynamic_ncols=True):
        # cached encodings
        x_pos_q, x_pos_m = get_cached_batch(qrs, cache, device)
        y_pos_q, y_pos_m = get_cached_batch(ans, cache, device)
        x_neg_q, x_neg_m = get_cached_batch(qs,  cache, device)
        y_neg_q, y_neg_m = y_pos_q, y_pos_m  # both pairs share the same answers

        pos_e = model.forward_from_encoded((x_pos_q, x_pos_m), (y_pos_q, y_pos_m)).flatten()
        neg_e = model.forward_from_encoded((x_neg_q, x_neg_m), (y_neg_q, y_neg_m)).flatten()
        loss = contrastive_loss(pos_e, neg_e, margin=margin)

        # metrics (the batch loss is a mean, so re-weight it by batch size)
        batch_n = pos_e.size(0)
        total_loss += loss.item() * batch_n
        total_correct += (pos_e < neg_e).sum().item()
        total_samples += batch_n
        pos_energy_sum += pos_e.sum().item()
        neg_energy_sum += neg_e.sum().item()

    denom = max(1, total_samples)
    return {
        "loss": total_loss / denom,
        "accuracy": 100.0 * total_correct / denom,
        "pos_energy_mean": pos_energy_sum / denom,
        "neg_energy_mean": neg_energy_sum / denom,
    }

@torch.no_grad()
def validate_dump_rows(
    model: EnergyModel,
    df: pd.DataFrame,
    cache: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
    device: torch.device,
    out_csv: str,
    batch_size: int,
    margin: float,
) -> None:
    """
    Score every row of `df` from cached encodings and write one CSV row each.

    Reads the `question`, `answer` and `question_related` columns (converted
    to stripped strings, which must already be keys of `cache`) and writes the
    texts, both energies, the per-row hinge loss / accuracy, and the same
    metrics as a JSON blob to `out_csv`. Nothing is written for an empty frame,
    and no summary is logged.
    """
    n_rows = len(df)
    if n_rows == 0:
        logging.warning("Validation slice is empty; nothing to dump.")
        return

    def _column_texts(column: str, window: slice) -> List[str]:
        return [str(value).strip() for value in df[column].iloc[window].tolist()]

    records = []
    for start in tqdm(range(0, n_rows, batch_size), desc="Validation (dump)", dynamic_ncols=True):
        window = slice(start, min(start + batch_size, n_rows))
        questions = _column_texts("question", window)
        answers = _column_texts("answer", window)
        related = _column_texts("question_related", window)

        related_enc = get_cached_batch(related, cache, device)
        answer_enc = get_cached_batch(answers, cache, device)  # shared by both pairs
        question_enc = get_cached_batch(questions, cache, device)

        # positive pair = (question_related, answer); negative pair = (question, answer)
        pos_values = model.forward_from_encoded(related_enc, answer_enc).flatten().cpu().tolist()
        neg_values = model.forward_from_encoded(question_enc, answer_enc).flatten().cpu().tolist()

        for j, (e_pos, e_neg) in enumerate(zip(pos_values, neg_values)):
            hinge = max(0.0, margin + e_pos - e_neg)  # margin-ranking hinge
            hit = 100.0 if e_pos < e_neg else 0.0
            metrics_blob = json.dumps(
                {
                    "loss": hinge,
                    "accuracy": hit,
                    "pos_energy_mean": e_pos,
                    "neg_energy_mean": e_neg,
                },
                ensure_ascii=False,
            )
            records.append({
                "question": questions[j],
                "answer": answers[j],
                "question_related": related[j],
                # raw energies under the names used by the energies CSV
                "energy_full": e_neg,
                "energy_related": e_pos,
                # per-row metrics as flat columns
                "loss": hinge,
                "accuracy": hit,
                "pos_energy_mean": e_pos,
                "neg_energy_mean": e_neg,
                # the same metrics in object form
                "metrics_json": metrics_blob,
            })

    destination = Path(out_csv)
    destination.parent.mkdir(parents=True, exist_ok=True)
    pd.DataFrame(records).to_csv(destination, index=False)


def _parse_row_range(spec: str, total: int) -> slice:
    """
    Parse "start:end" into a Python slice, clamped to [0, total].
    Examples: "0:100", ":200", "500:", "-100:" (last 100)
    """
    if spec is None:
        return slice(0, total)
    parts = spec.split(":")
    if len(parts) != 2:
        raise ValueError(f'--val_rows must be "start:end", got: {spec}')
    def _to_int(x, default):
        return int(x) if x.strip() else default
    start = _to_int(parts[0], 0)
    end   = _to_int(parts[1], total)
    # support negatives
    if start < 0: start = total + start
    if end   < 0: end   = total + end
    # clamp
    start = max(0, min(start, total))
    end   = max(0, min(end, total))
    if end < start:
        end = start
    return slice(start, end)


@torch.no_grad()
def compute_energies_after(model: EnergyModel, df: pd.DataFrame,
                           cache: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
                           device: torch.device, out_csv: str | Path):
    """
    Score every row of `df` one at a time and write the result as CSV.

    For each row the stripped `question`, `answer` and `question_related`
    strings are looked up in `cache` (they must already be present). A copy of
    `df` with two added columns is written to `out_csv`:
      - energy_full:    energy of (question, answer)
      - energy_related: energy of (question_related, answer)
    """
    def _energy(x_text, y_enc):
        # Single-row batch -> scalar energy.
        x_enc = get_cached_batch([x_text], cache, device)
        return model.forward_from_encoded(x_enc, y_enc).flatten()[0].item()

    full_energies, related_energies = [], []
    for i in tqdm(range(len(df)), desc="Post-train energies", dynamic_ncols=True):
        record = df.iloc[i]
        question, answer, related = (
            str(record[column]).strip()
            for column in ("question", "answer", "question_related")
        )
        answer_enc = get_cached_batch([answer], cache, device)  # shared by both pairs
        full_energies.append(_energy(question, answer_enc))
        related_energies.append(_energy(related, answer_enc))

    scored = df.copy()
    scored["energy_full"] = full_energies
    scored["energy_related"] = related_energies

    destination = Path(out_csv)
    destination.parent.mkdir(parents=True, exist_ok=True)
    scored.to_csv(destination, index=False)

def _unique_dataset_texts(ds, imdb_mode: bool) -> List[str]:
    """Collect every distinct string the dataset will look up in the encoding cache."""
    if imdb_mode:
        return list(set(ds.positive_texts + ds.negative_texts + ds.predictions))
    return list(set(ds.q + ds.a + ds.qr))


def _imdb_pairs_frame(ds) -> pd.DataFrame:
    """Expose IMDB pairs under the question/answer/question_related columns the CSV writers read."""
    # Same mapping as IMDBSentimentPairs.__getitem__: (negative, prediction, positive).
    return pd.DataFrame({
        "question": ds.negative_texts,
        "answer": ds.predictions,
        "question_related": ds.positive_texts,
    })


def _pair_energies(model, cache, device, qs, ans, qrs):
    """Return (pos_e, neg_e): energies of (question_related, answer) and (question, answer)."""
    x_pos = get_cached_batch(qrs, cache, device)  # question_related
    y_enc = get_cached_batch(ans, cache, device)  # answer, shared by both pairs
    x_neg = get_cached_batch(qs, cache, device)   # question
    pos_e = model.forward_from_encoded(x_pos, y_enc).flatten()
    neg_e = model.forward_from_encoded(x_neg, y_enc).flatten()
    return pos_e, neg_e


def _apply_optimizer_step(model, optimizer, scaler, grad_clip: float) -> None:
    """Clip gradients, take one (AMP-aware) optimizer step and clear the gradients."""
    if scaler.is_enabled():
        scaler.unscale_(optimizer)  # clip the real gradients, not the scaled ones
        clip_grad_norm_(model.parameters(), grad_clip)
        scaler.step(optimizer)
        scaler.update()
    else:
        clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
    optimizer.zero_grad(set_to_none=True)


def main():
    """
    Command-line entry point: finetune the EBM on cached text encodings.

    Steps: parse arguments, load the training CSV, load the pretrained model,
    pre-encode all texts to a CPU cache, train with the margin ranking loss
    (optional fp16 and gradient accumulation), save the final checkpoint, then
    optionally validate and/or write post-training energies.
    """
    args = build_parser().parse_args()
    logging.basicConfig(level=getattr(logging, args.loglevel), format="%(asctime)s - %(levelname)s - %(message)s")
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    set_seed(args.seed)

    device = torch.device(args.device)
    logging.info(f"Using device: {device} (CUDA devices: {torch.cuda.device_count()})")

    # Load CSV + dataset
    df = pd.read_csv(args.csv)

    # Choose dataset class based on mode
    if args.imdb_mode:
        logging.info("Using IMDB sentiment analysis mode")
        ds = IMDBSentimentPairs(df)
    else:
        logging.info("Using standard QA pairs mode")
        ds = QAPairs(df)

    if len(ds) == 0:
        # Fail with a clear message instead of an opaque sampler error.
        raise ValueError(f"No usable training rows found in {args.csv}")

    dl = DataLoader(ds, batch_size=args.batch_size, shuffle=True, num_workers=0, collate_fn=collate_text)

    steps_per_epoch = max(1, (len(ds) + args.batch_size - 1) // args.batch_size)
    save_every_steps = None
    if args.checkpoints_per_epoch > 0:
        save_every_steps = max(1, steps_per_epoch // args.checkpoints_per_epoch)
    accum_steps = max(1, args.grad_accum_steps)  # guards against modulo/division by zero

    # Build & load model
    ebm_config = EBMConfig(
        self_attention_n_layers=args.ebm_self_attention_layers,
        cross_attention_n_layers=args.ebm_cross_attention_layers,
    )
    model = EnergyModel(ebm_config).to(device)
    logging.info(f"Loading checkpoint: {args.ebm_checkpoint}")
    ckpt = torch.load(args.ebm_checkpoint, map_location=device)
    state_dict = ckpt.get("model_state_dict", ckpt)
    model.load_state_dict(state_dict, strict=True)

    # Prebuild encoding cache on CPU (unique strings from CSV)
    encoding_cache: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
    all_texts = _unique_dataset_texts(ds, args.imdb_mode)

    logging.info(f"Pre-encoding {len(all_texts)} unique texts to CPU cache …")
    # Encode in eval mode so train-only layers in the text encoder (if any)
    # cannot add noise to encodings that are reused for the whole run.
    model.eval()
    update_encoding_cache(all_texts, encoding_cache, model, batch_size=args.cache_encode_batch_size)

    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scaler = torch.cuda.amp.GradScaler(enabled=args.fp16 and device.type == "cuda")
    use_amp = scaler.is_enabled()

    # Resume from checkpoint if provided
    start_epoch = 1
    global_step = 0
    if args.resume_checkpoint:
        start_epoch, global_step = load_checkpoint(args.resume_checkpoint, model, optimizer, device)
        start_epoch += 1  # Start from next epoch after checkpoint

    model.train()

    # Tracked separately from the loop variable so the final save still works
    # when a resumed run has no epochs left to train.
    last_epoch = start_epoch - 1
    if start_epoch > args.epochs:
        logging.warning(f"Nothing to train: start epoch {start_epoch} exceeds --epochs {args.epochs}")

    for epoch in range(start_epoch, args.epochs + 1):
        last_epoch = epoch

        # If encoder is NOT frozen, rebuild cache each epoch (slower)
        if not args.assume_frozen_text_encoder:
            logging.info("Re-encoding cache (assume_frozen_text_encoder=False) …")
            # update_encoding_cache skips texts that are already cached, so the
            # cache must be emptied first or nothing would be re-encoded.
            encoding_cache.clear()
            model.eval()
            update_encoding_cache(all_texts, encoding_cache, model, batch_size=args.cache_encode_batch_size)
            model.train()

        pbar = tqdm(dl, desc=f"Epoch {epoch}/{args.epochs}", dynamic_ncols=True)
        running = 0.0
        pending_micro_batches = 0

        # Count steps explicitly: tqdm only refreshes `pbar.n` at its display
        # interval, so it is not a reliable per-iteration counter.
        for step, (qs, ans, qrs) in enumerate(pbar, start=1):
            # Forward via cached encodings
            if use_amp:
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    pos_e, neg_e = _pair_energies(model, encoding_cache, device, qs, ans, qrs)
                    loss = contrastive_loss(pos_e, neg_e, args.margin)
            else:
                pos_e, neg_e = _pair_energies(model, encoding_cache, device, qs, ans, qrs)
                loss = contrastive_loss(pos_e, neg_e, args.margin)

            # Backward + (optional) grad accumulation. Dividing by the window
            # size makes the accumulated gradient an average over micro-batches
            # instead of a sum.
            scaled_loss = loss / accum_steps
            if use_amp:
                scaler.scale(scaled_loss).backward()
            else:
                scaled_loss.backward()
            pending_micro_batches += 1

            if pending_micro_batches == accum_steps:
                _apply_optimizer_step(model, optimizer, scaler, args.grad_clip)
                pending_micro_batches = 0
                global_step += 1

            running += loss.item()
            avg_loss = running / step
            avg_pos = pos_e.detach().mean().item()
            avg_neg = neg_e.detach().mean().item()
            pbar.set_postfix(loss=f"{avg_loss:.4f}", posE=f"{avg_pos:.3f}", negE=f"{avg_neg:.3f}")

            # Intra-epoch checkpointing (optional)
            if save_every_steps is not None and step % save_every_steps == 0:
                tmp_path = Path(args.output_checkpoint)
                tmp_path = tmp_path.with_name(tmp_path.stem + f"_e{epoch}_step{step}" + tmp_path.suffix)
                save_checkpoint(model, optimizer, tmp_path, epoch=epoch, global_step=global_step)

        # Apply gradients from an incomplete accumulation window so they are
        # neither dropped nor mixed into the next epoch's first step.
        if pending_micro_batches:
            _apply_optimizer_step(model, optimizer, scaler, args.grad_clip)
            global_step += 1

    final_path = Path(args.output_checkpoint)
    final_path = final_path if final_path.suffix == ".pt" else final_path.with_suffix(".pt")
    save_checkpoint(model, optimizer, final_path, epoch=last_epoch, global_step=global_step)
    logging.info(f"Saved FINAL checkpoint: {final_path}")

    # -------- Validation after training --------
    if args.val_csv:
        df_val_all = pd.read_csv(args.val_csv)
        val_slice = _parse_row_range(args.val_rows, len(df_val_all))  # None -> all rows
        df_val = df_val_all.iloc[val_slice].reset_index(drop=True)

        # Switch to eval BEFORE encoding unseen validation texts.
        model.eval()
        val_encode_bs = max(256, args.batch_size)

        if args.val_dump_csv:
            # Dump-only mode: no metric logging, just write per-row energies.
            # validate_dump_rows reads question/answer/question_related, so IMDB
            # pairs are first mapped onto those columns.
            dump_df = _imdb_pairs_frame(IMDBSentimentPairs(df_val)) if args.imdb_mode else df_val
            # Cache keys must match the stripped strings the dump looks up.
            val_texts = list({
                str(value).strip()
                for column in ("question", "answer", "question_related")
                for value in dump_df[column].tolist()
            })
            update_encoding_cache(val_texts, encoding_cache, model, batch_size=val_encode_bs)
            validate_dump_rows(
                model=model,
                df=dump_df,
                cache=encoding_cache,
                device=device,
                out_csv=args.val_dump_csv,
                batch_size=args.batch_size,
                margin=args.margin,
            )
            # Do NOT log summary metrics in this mode
        else:
            # Normal metrics mode over the (possibly sliced) validation set
            val_ds = IMDBSentimentPairs(df_val) if args.imdb_mode else QAPairs(df_val)
            # Take the texts from the dataset itself so the cache keys are
            # exactly the cleaned strings the loader will yield.
            val_texts = _unique_dataset_texts(val_ds, args.imdb_mode)
            update_encoding_cache(val_texts, encoding_cache, model, batch_size=val_encode_bs)
            if len(val_ds) == 0:
                logging.warning("Validation set is empty after filtering; skipping validation")
            else:
                val_dl = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, num_workers=0, collate_fn=collate_text)
                metrics = validate(model, val_dl, encoding_cache, device, args.margin)
                logging.info(f"[Validation] loss={metrics['loss']:.4f}, "
                            f"acc={metrics['accuracy']:.2f}%, "
                            f"posE={metrics['pos_energy_mean']:.3f}, "
                            f"negE={metrics['neg_energy_mean']:.3f}")
    else:
        logging.info("No validation CSV provided, skipping validation")

    # Optional: compute energies after finetune (using the same cache)
    if args.write_energies_after:
        out_csv = args.output_csv or (Path(args.csv).with_suffix("").as_posix() + "_energies_after.csv")
        logging.info("Writing energies after finetune …")
        model.eval()
        # compute_energies_after reads question/answer/question_related; in IMDB
        # mode orig_df has no such columns, so score the derived pairs instead.
        energies_df = _imdb_pairs_frame(ds) if args.imdb_mode else ds.orig_df
        compute_energies_after(model, energies_df, encoding_cache, device, out_csv)
        logging.info(f"Wrote: {out_csv}")

# Run the finetuning pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
