from __future__ import annotations

import os
import random
import re
import time
from collections import Counter, defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

# FEVER verdict string -> integer class id (also the index into LABEL_TOKENS).
LABEL_MAP = {"SUPPORTS": 0, "REFUTES": 1, "NOT ENOUGH INFO": 2}
# Inverse of LABEL_MAP, derived from it so the two tables cannot drift apart.
ID_TO_LABEL = {class_id: name for name, class_id in LABEL_MAP.items()}
NUM_CLASSES = len(LABEL_MAP)
# Special tokens emitted after [LABELSEP]; entry i encodes class id i.
LABEL_TOKENS = ["[SUPPORTS]", "[REFUTES]", "[NEI]"]
# Function words ignored when building lexical claim signatures.
STOPWORDS = {
    "the", "a", "an", "of", "in", "on", "at", "for", "to", "and", "or", "is", "was", "were",
    "are", "be", "been", "by", "with", "as", "that", "this", "it", "from", "into", "about"
}
# Pooling modes whose consistency path is fed hidden states with everything from
# [SEP] onward zeroed.
STRICT_EVIDENCE_VARIANTS = {"evidence_only_strict", "evidence_only_random_labels"}


@dataclass
class PretrainedGPT2Config:
    """Hyper-parameters and I/O settings for the FEVER GPT-2 claim-consistency experiment."""

    # HuggingFace dataset id given to datasets.load_dataset; its "train" and
    # "validation" splits are read.
    dataset_name: str = "copenlu/fever_gold_evidence"
    # Maximum number of labelled examples kept per split, taken in seeded random order.
    num_train_samples: int = 50_000
    num_eval_samples: int = 5_000
    # Total tokens per sequence, including BOS, SEP, LABELSEP, the label token and EOS.
    max_seq_len: int = 256
    seed: int = 42
    # Checkpoint name used for both the tokenizer and the GPT-2 backbone.
    model_name: str = "gpt2"
    # Number of initial epochs during which the lower half of the transformer blocks
    # and the token/position embeddings are frozen; 0 disables freezing.
    freeze_lower_layers_epochs: int = 1
    batch_size: int = 16
    num_epochs: int = 5
    lr: float = 5e-5
    # Multiplier on the consistency-head cross-entropy added to the LM loss.
    consistency_loss_weight: float = 0.5
    # Max gradient norm passed to clip_grad_norm_ every step.
    grad_clip: float = 1.0
    # Variants to run; each trains a fresh model and yields one results row.
    pooling_modes: Tuple[str, ...] = (
        "no_consistency_loss",
        "evidence_only_pooling",
        "evidence_only_strict",
        "full_sequence_pooling",
        "claim_only_pooling",
        "evidence_only_random_labels",
    )
    # CSV rewritten after every variant with all results so far.
    results_path: str = "results_fever_pretrained_gpu.csv"
    # Use tiny synthetic data, a single epoch and only "evidence_only_strict".
    smoke_test: bool = False
    # Raise RuntimeError instead of running when CUDA is unavailable.
    require_gpu: bool = False
    # "auto" selects CUDA when available (else CPU); any other value is passed to torch.device.
    device: str = "auto"


def load_fever_data(cfg: PretrainedGPT2Config):
    """Load and flatten the FEVER train/validation splits named by ``cfg.dataset_name``.

    Each split is visited in a random order seeded with ``cfg.seed`` (the same seed
    for both splits). Examples whose label is not in ``LABEL_MAP`` are skipped. The
    sentence text of every evidence entry (its third field, when present and
    non-empty) is joined with spaces into one evidence string. Collection stops once
    the requested number of samples is reached.

    Returns:
        ``(train_samples, eval_samples)``: lists of dicts with keys
        ``evidence``, ``claim``, ``label`` (int) and ``label_str``.
    """
    from datasets import load_dataset

    print(f"Loading FEVER dataset from {cfg.dataset_name} ...")
    train_ds = load_dataset(cfg.dataset_name, split="train")
    val_ds = load_dataset(cfg.dataset_name, split="validation")

    def to_sample(example):
        # Returns None for labels this experiment does not model.
        verdict = example["label"]
        if verdict not in LABEL_MAP:
            return None
        sentences = [str(entry[2]) for entry in example["evidence"] if len(entry) >= 3 and entry[2]]
        return {
            "evidence": " ".join(sentences) if sentences else "No evidence available.",
            "claim": str(example["claim"]),
            "label": LABEL_MAP[verdict],
            "label_str": verdict,
        }

    def parse_split(hf_ds, n_samples, split_name="train"):
        visit_order = list(range(len(hf_ds)))
        random.Random(cfg.seed).shuffle(visit_order)
        collected = []
        for position in visit_order:
            sample = to_sample(hf_ds[position])
            if sample is None:
                continue
            collected.append(sample)
            if len(collected) >= n_samples:
                break
        print(f" {split_name}: {len(collected)} samples loaded")
        histogram = Counter(item["label_str"] for item in collected)
        print(f" label distribution: {dict(histogram)}")
        return collected

    train_samples = parse_split(train_ds, cfg.num_train_samples, "train")
    eval_samples = parse_split(val_ds, cfg.num_eval_samples, "eval")
    return train_samples, eval_samples


def make_smoke_data(n_train: int = 32, n_eval: int = 16, seed: int = 42):
    """Build tiny synthetic train/eval splits for smoke-testing the pipeline.

    Every sample pairs a random evidence template with a random claim template and a
    label drawn uniformly from ``LABEL_MAP``. The three draws are independent, so the
    text carries no label signal. The train split is drawn first, then the eval
    split, from one ``random.Random(seed)`` stream.

    Returns:
        ``(train_samples, eval_samples)`` in the same dict format as ``load_fever_data``.
    """
    rng = random.Random(seed)
    label_names = list(LABEL_MAP)
    evidence_pool = [
        "The study showed that X is true in most cases observed.",
        "Research indicates that Y is false and has been refuted.",
        "No significant evidence was found in the literature.",
        "Experiments confirmed that Z leads to positive outcomes.",
        "Scientists have disputed the claim about W being correct.",
    ]
    claim_pool = [
        "X is commonly found to be true.",
        "Y has been proven false by scientists.",
        "There is no clear evidence about this topic.",
        "Z leads to positive results overall.",
        "W is confirmed by the scientific community.",
    ]

    def draw(count):
        drawn = []
        for _ in range(count):
            # Draw order (label, evidence, claim) fixes the RNG stream; keep it.
            label_name = rng.choice(label_names)
            evidence = rng.choice(evidence_pool)
            claim = rng.choice(claim_pool)
            drawn.append({
                "evidence": evidence,
                "claim": claim,
                "label": LABEL_MAP[label_name],
                "label_str": label_name,
            })
        return drawn

    train_split = draw(n_train)
    eval_split = draw(n_eval)
    return train_split, eval_split


_tokenizer = None
# Checkpoint name that ``_tokenizer`` was loaded from. None means it was never loaded
# here (or was assigned from outside), in which case the cached object is trusted.
_tokenizer_model_name = None


def get_tokenizer(model_name: str = "gpt2"):
    """Return the process-wide GPT-2 fast tokenizer extended with this file's special tokens.

    The tokenizer is loaded once and cached in the module global ``_tokenizer``.
    Previously the cache ignored ``model_name``, so asking for a different checkpoint
    silently returned the first one loaded. It is now reloaded when a different
    ``model_name`` is requested than the one it was loaded from. A tokenizer assigned
    to ``_tokenizer`` from outside (with no recorded name) is returned as-is.

    Special tokens: BOS/EOS/PAD are replaced with ``[BOS]``/``[EOS]``/``[PAD]``, and
    ``[SEP]``, ``[LABELSEP]`` and the three label tokens are added.
    """
    global _tokenizer, _tokenizer_model_name
    stale = _tokenizer_model_name is not None and _tokenizer_model_name != model_name
    if _tokenizer is None or stale:
        from transformers import GPT2TokenizerFast
        tok = GPT2TokenizerFast.from_pretrained(model_name)
        tok.add_special_tokens({
            "bos_token": "[BOS]",
            "eos_token": "[EOS]",
            "pad_token": "[PAD]",
            "additional_special_tokens": ["[SEP]", "[LABELSEP]", "[SUPPORTS]", "[REFUTES]", "[NEI]"],
        })
        _tokenizer = tok
        _tokenizer_model_name = model_name
        print(f"Tokenizer vocabulary size after special tokens: {len(tok)}")
    return _tokenizer


class FEVERTokenDataset(Dataset):
    """Eagerly tokenised FEVER examples in the fixed training layout.

    Each record's ``input_ids`` is
    ``[BOS] evidence [SEP] claim [LABELSEP] label_token [EOS]`` right-padded with
    ``[PAD]`` to ``max_seq_len``. Five positions are reserved for special tokens. The
    evidence gets whatever the claim leaves over, but never fewer than 4 tokens; if
    that overflows, the claim is cut instead. Records also keep the positions of
    ``[SEP]`` (``sep_pos``) and ``[LABELSEP]`` (``labelsep_pos``) and the raw fields.
    """

    def __init__(self, samples: List[dict], max_seq_len: int = 256, model_name: str = "gpt2"):
        self.max_seq_len = max_seq_len
        self.tokenizer = get_tokenizer(model_name)
        tok = self.tokenizer
        bos_id, eos_id, pad_id = tok.bos_token_id, tok.eos_token_id, tok.pad_token_id
        sep_id = tok.convert_tokens_to_ids("[SEP]")
        labelsep_id = tok.convert_tokens_to_ids("[LABELSEP]")
        n_special = 5  # BOS, SEP, LABELSEP, label token, EOS
        self.records = []
        for sample in samples:
            evidence_ids = tok.encode(sample["evidence"], add_special_tokens=False)
            claim_ids = tok.encode(sample["claim"], add_special_tokens=False)
            label_id = tok.convert_tokens_to_ids(LABEL_TOKENS[sample["label"]])
            evidence_budget = max(4, max_seq_len - n_special - len(claim_ids))
            claim_ids = claim_ids[:max_seq_len - n_special - evidence_budget]
            evidence_ids = evidence_ids[:evidence_budget]
            sequence = [bos_id, *evidence_ids, sep_id, *claim_ids, labelsep_id, label_id, eos_id]
            sep_pos = len(evidence_ids) + 1
            labelsep_pos = sep_pos + len(claim_ids) + 1
            n_pad = max(0, max_seq_len - len(sequence))
            padded_ids = (sequence + [pad_id] * n_pad)[:max_seq_len]
            mask = ([1] * len(sequence) + [0] * n_pad)[:max_seq_len]
            self.records.append({
                "input_ids": torch.tensor(padded_ids, dtype=torch.long),
                "attention_mask": torch.tensor(mask, dtype=torch.long),
                "sep_pos": sep_pos,
                "labelsep_pos": labelsep_pos,
                "label": sample["label"],
                "label_str": sample["label_str"],
                "evidence": sample["evidence"],
                "claim": sample["claim"],
            })

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        return self.records[idx]


def collate_fn(batch):
    """Stack FEVERTokenDataset records into one batch dict.

    Token ids and masks are stacked into 2-D tensors and labels become a ``long``
    tensor. The ``sep_pos``/``labelsep_pos`` positions stay plain Python lists
    because the model indexes them per example.
    """
    def column(key):
        return [record[key] for record in batch]

    stacked_ids = torch.stack(column("input_ids"))
    stacked_mask = torch.stack(column("attention_mask"))
    label_tensor = torch.tensor(column("label"), dtype=torch.long)
    return {
        "input_ids": stacked_ids,
        "attention_mask": stacked_mask,
        "labels": label_tensor,
        "sep_pos": column("sep_pos"),
        "labelsep_pos": column("labelsep_pos"),
    }


class PretrainedGPT2WithConsistencyHead(nn.Module):
    """Pretrained GPT-2 LM plus a linear consistency classifier over pooled hidden states.

    The LM path reads the final hidden state at ``[LABELSEP]`` and projects it with
    GPT-2's ``lm_head``. The consistency path mean-pools a region of the final hidden
    states (chosen by ``pooling_mode``) and maps it to ``NUM_CLASSES`` logits.
    """

    def __init__(self, model_name: str = "gpt2", vocab_size_after_special: Optional[int] = None):
        super().__init__()
        from transformers import GPT2LMHeadModel
        print(f"Loading pretrained GPT-2 backbone: {model_name}")
        self.gpt2 = GPT2LMHeadModel.from_pretrained(model_name)
        if vocab_size_after_special is not None:
            # Make room for the special/label tokens added by the tokenizer.
            self.gpt2.resize_token_embeddings(vocab_size_after_special)
            print(f" Token embeddings resized to {vocab_size_after_special}")
        hidden_size = self.gpt2.config.hidden_size
        self.consistency_head = nn.Linear(hidden_size, NUM_CLASSES)
        nn.init.normal_(self.consistency_head.weight, std=0.02)
        nn.init.zeros_(self.consistency_head.bias)

    @property
    def transformer_blocks(self):
        """The GPT-2 transformer block list (``self.gpt2.transformer.h``)."""
        return self.gpt2.transformer.h

    def freeze_lower_layers(self, freeze: bool = True):
        """Freeze (or unfreeze) the lower half of the blocks plus token/position embeddings.

        The first ``len(blocks) // 2`` blocks, ``wte`` and ``wpe`` get
        ``requires_grad = not freeze``. The upper blocks are always set trainable.
        Other parameters (e.g. the consistency head) are not touched.

        NOTE(review): GPT-2's ``lm_head`` is normally weight-tied to ``wte``. If so,
        freezing ``wte`` also freezes the LM output projection, including the rows of
        the newly added label tokens. Confirm this is intended.
        """
        blocks = self.transformer_blocks
        n_freeze = len(blocks) // 2
        for i, block in enumerate(blocks):
            for p in block.parameters():
                p.requires_grad = (not freeze) if i < n_freeze else True
        for p in self.gpt2.transformer.wte.parameters():
            p.requires_grad = not freeze
        for p in self.gpt2.transformer.wpe.parameters():
            p.requires_grad = not freeze
        print(f" Lower {n_freeze}/{len(blocks)} transformer blocks + embeddings {'frozen' if freeze else 'unfrozen'}")

    def _pool(self, hidden_states, pooling_mode, sep_pos, labelsep_pos, attention_mask):
        """Mean-pool a per-example span of ``hidden_states`` (B, T, H) -> (B, H).

        Spans (end exclusive):
          * evidence modes: ``[1, max(2, sep_pos))``, i.e. the evidence tokens
            (or the ``[SEP]`` slot when there is no evidence);
          * ``claim_only_pooling``: ``[sep_pos + 1, max(sep_pos + 2, labelsep_pos))``;
          * ``full_sequence_pooling``: the attended span, truncated right after
            ``[LABELSEP]`` when ``labelsep_pos`` is given;
          * anything else: all ``T`` positions, padding included.
        Spans are clamped to ``[0, T)`` and always contain at least one position.
        """
        B, T, _ = hidden_states.shape
        pooled = []
        for i in range(B):
            if pooling_mode in {"evidence_only_pooling", "evidence_only_strict", "evidence_only_random_labels"}:
                start, end = 1, max(2, sep_pos[i])
            elif pooling_mode == "claim_only_pooling":
                start, end = sep_pos[i] + 1, max(sep_pos[i] + 2, labelsep_pos[i])
            elif pooling_mode == "full_sequence_pooling":
                positions = attention_mask[i].nonzero(as_tuple=True)[0]
                start, end = (positions[0].item(), positions[-1].item() + 1) if len(positions) > 0 else (0, T)
                if labelsep_pos is not None:
                    # Every input built in this file places the label token (and EOS)
                    # right after [LABELSEP]. Pooling over them would hand the answer to
                    # the classifier, so stop the span at [LABELSEP].
                    end = min(end, labelsep_pos[i] + 1)
            else:
                start, end = 0, T
            start = max(0, min(start, T - 1))
            end = max(start + 1, min(end, T))
            pooled.append(hidden_states[i, start:end, :].mean(dim=0))
        return torch.stack(pooled, dim=0)

    def _strict_evidence_hidden_states(self, hidden_states, sep_pos):
        """Return a copy of ``hidden_states`` with positions ``>= sep_pos[i]`` zeroed per row.

        ``_pool`` already restricts evidence modes to positions before ``sep_pos``.
        This masking therefore only changes the pooled vector when a row has no
        evidence tokens (``sep_pos == 1``); that pooled vector is then all zeros.
        """
        masked = hidden_states.clone()
        for i in range(masked.size(0)):
            if sep_pos[i] < masked.size(1):
                masked[i, sep_pos[i]:, :] = 0.0
        return masked

    def forward(self, input_ids, attention_mask, pooling_mode="evidence_only_pooling", sep_pos=None, labelsep_pos=None):
        """Run the backbone once and return ``(lm_logits, cons_logits)``.

        ``lm_logits`` has shape (B, vocab): the LM head applied at each row's
        ``[LABELSEP]`` position (clamped to ``T - 2``), or at position ``T - 2`` when
        ``labelsep_pos`` is None. ``cons_logits`` has shape (B, NUM_CLASSES) from the
        pooled span. ``sep_pos``/``labelsep_pos`` are per-row position lists and are
        required by the evidence/claim pooling modes.
        """
        B, T = input_ids.shape
        device = input_ids.device
        outputs = self.gpt2.transformer(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=False)
        hidden_states = outputs.last_hidden_state
        if labelsep_pos is not None:
            ls_idx = torch.tensor(labelsep_pos, dtype=torch.long, device=device).clamp(0, T - 2)
            h_at_ls = hidden_states[torch.arange(B, device=device), ls_idx, :]
        else:
            h_at_ls = hidden_states[:, -2, :]
        lm_logits = self.gpt2.lm_head(h_at_ls)
        cons_hidden = self._strict_evidence_hidden_states(hidden_states, sep_pos) if pooling_mode in STRICT_EVIDENCE_VARIANTS else hidden_states
        pooled = self._pool(cons_hidden, pooling_mode, sep_pos, labelsep_pos, attention_mask)
        cons_logits = self.consistency_head(pooled)
        return lm_logits, cons_logits

    def get_full_lm_logits(self, input_ids, attention_mask=None):
        """LM logits at every position: shape (B, T, vocab)."""
        outputs = self.gpt2.transformer(input_ids=input_ids, attention_mask=attention_mask)
        return self.gpt2.lm_head(outputs.last_hidden_state)


def set_seed(seed: int):
    """Seed Python's, NumPy's and PyTorch's global RNGs (and all CUDA devices when present)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def resolve_device(cfg: PretrainedGPT2Config) -> torch.device:
    """Turn ``cfg.device`` into a ``torch.device``, enforcing ``cfg.require_gpu``.

    ``"auto"`` selects CUDA when available, otherwise CPU. Any other value is passed
    to ``torch.device`` unchanged.

    Raises:
        RuntimeError: ``cfg.require_gpu`` is set and either CUDA is unavailable or
            the configured device is not a CUDA device.
    """
    if cfg.require_gpu and not torch.cuda.is_available():
        # The CLI flag is spelled with an underscore (see the argparse setup).
        raise RuntimeError("CUDA is not available but --require_gpu was set. Please run on a machine with a CUDA-capable GPU.")
    if cfg.device == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(cfg.device)
    if cfg.require_gpu and device.type != "cuda":
        # Previously an explicit non-CUDA device silently bypassed --require_gpu.
        raise RuntimeError(f"--require_gpu was set but the configured device is {cfg.device!r}.")
    return device


def maybe_randomize_labels(labels, pooling_mode, seed_offset=0):
    """Return consistency-head targets, scrambled only for the random-label control.

    For any mode other than ``"evidence_only_random_labels"`` the input tensor is
    returned unchanged (same object).

    For the control, the 1-D ``labels`` tensor is randomly permuted across the batch.
    The permutation comes from a private ``torch.Generator`` seeded with
    ``seed_offset``, so the result is reproducible and the global RNG is untouched.
    The batch's class histogram is preserved, but each example's target no longer
    depends on its own input, so a head trained on it should stay near chance.

    A fixed cyclic shift ``(y + k) % NUM_CLASSES`` would not be a valid control: it
    is a deterministic function of the true label (never equal to it), so the head
    could still learn label information and would score systematically *below*
    chance on the true labels.
    """
    if pooling_mode != "evidence_only_random_labels":
        return labels
    generator = torch.Generator()
    generator.manual_seed(int(seed_offset))
    perm = torch.randperm(labels.size(0), generator=generator)
    return labels[perm.to(labels.device)]


def train_epoch(model, loader, optimizer, device, pooling_mode, consistency_loss_weight, scaler=None, grad_clip=1.0):
    """Run one optimisation pass over ``loader``.

    The LM loss is cross-entropy of the logits at ``[LABELSEP]`` against the token
    that follows it (the label token). The consistency loss is cross-entropy of the
    head against the labels, passed through ``maybe_randomize_labels`` with the batch
    index as seed. Both losses are computed every step. The consistency loss is only
    added to the objective (times ``consistency_loss_weight``) when ``pooling_mode``
    is not ``"no_consistency_loss"``. Gradients are clipped to ``grad_clip``. AMP plus
    the grad scaler are used only when a scaler is given and ``device`` is CUDA.

    Returns:
        ``(mean_lm_loss, mean_cons_loss)`` over batches (0.0 each for an empty loader).
    """
    model.train()
    use_amp = scaler is not None and device.type == "cuda"
    lm_sum, cons_sum, steps = 0.0, 0.0, 0
    for step_idx, batch in enumerate(loader):
        ids = batch["input_ids"].to(device)
        mask = batch["attention_mask"].to(device)
        gold = batch["labels"].to(device)
        sep_positions = batch["sep_pos"]
        labelsep_positions = batch["labelsep_pos"]
        batch_size, seq_len = ids.shape
        optimizer.zero_grad()
        with torch.amp.autocast(device_type="cuda", enabled=use_amp):
            lm_logits, cons_logits = model(ids, mask, pooling_mode=pooling_mode, sep_pos=sep_positions, labelsep_pos=labelsep_positions)
            labelsep_idx = torch.tensor(labelsep_positions, dtype=torch.long, device=device).clamp(0, seq_len - 2)
            rows = torch.arange(batch_size, device=device)
            next_token_targets = ids[rows, (labelsep_idx + 1).clamp(0, seq_len - 1)]
            lm_loss = F.cross_entropy(lm_logits, next_token_targets)
            cons_loss = F.cross_entropy(cons_logits, maybe_randomize_labels(gold, pooling_mode, step_idx))
            if pooling_mode == "no_consistency_loss":
                loss = lm_loss
            else:
                loss = lm_loss + consistency_loss_weight * cons_loss
        if use_amp:
            scaler.scale(loss).backward()
            # Unscale before clipping so the norm is measured on true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
        lm_sum += lm_loss.item()
        cons_sum += cons_loss.item()
        steps += 1
    denom = max(1, steps)
    return lm_sum / denom, cons_sum / denom


@torch.no_grad()
def evaluate_classification(model, loader, device, pooling_mode):
    """Accuracy of the consistency head's argmax over ``loader`` (0.0 if it is empty)."""
    model.eval()
    n_correct, n_seen = 0, 0
    for batch in loader:
        gold = batch["labels"].to(device)
        _, logits = model(
            batch["input_ids"].to(device),
            batch["attention_mask"].to(device),
            pooling_mode=pooling_mode,
            sep_pos=batch["sep_pos"],
            labelsep_pos=batch["labelsep_pos"],
        )
        n_correct += (logits.argmax(dim=-1) == gold).sum().item()
        n_seen += gold.size(0)
    return n_correct / max(1, n_seen)


@torch.no_grad()
def evaluate_generation_accuracy(model, eval_samples, device, max_seq_len=256, n_samples=500, model_name="gpt2"):
    """Greedy next-token accuracy of the LM path on the gold label token.

    Each of the first ``n_samples`` raw samples is encoded as the prompt
    ``[BOS] evidence [SEP] claim [LABELSEP]``. It counts as correct when the argmax
    of the logits at the last prompt position is that sample's label token.

    Returns:
        Fraction correct (0.0 for an empty sample list).
    """
    model.eval()
    tokenizer = get_tokenizer(model_name)
    bos_id = tokenizer.bos_token_id
    sep_id = tokenizer.convert_tokens_to_ids("[SEP]")
    labelsep_id = tokenizer.convert_tokens_to_ids("[LABELSEP]")
    label_tok_ids = [tokenizer.convert_tokens_to_ids(t) for t in LABEL_TOKENS]
    # Reserve room for all five special tokens (BOS, SEP, LABELSEP, label, EOS),
    # exactly as FEVERTokenDataset does, even though the prompt stops at [LABELSEP].
    # Budgeting only three would give long examples two extra evidence tokens that
    # neither training nor classification evaluation ever saw.
    overhead = 5
    samples_to_use = eval_samples[:n_samples]
    correct = 0
    for s in samples_to_use:
        ev_ids = tokenizer.encode(s["evidence"], add_special_tokens=False)
        cl_ids = tokenizer.encode(s["claim"], add_special_tokens=False)
        ev_budget = max(4, max_seq_len - overhead - len(cl_ids))
        cl_ids = cl_ids[:max_seq_len - overhead - ev_budget]
        ev_ids = ev_ids[:ev_budget]
        prompt_ids = [bos_id] + ev_ids + [sep_id] + cl_ids + [labelsep_id]
        prompt_t = torch.tensor([prompt_ids], dtype=torch.long, device=device)
        full_logits = model.get_full_lm_logits(prompt_t, torch.ones_like(prompt_t))
        if full_logits[0, -1, :].argmax().item() == label_tok_ids[s["label"]]:
            correct += 1
    return correct / max(1, len(samples_to_use))


def _claim_signature(text: str):
    """Return the set of content words of ``text`` used for lexical claim matching.

    The text is lower-cased and maximal ``[a-z0-9]`` runs are extracted. Tokens of
    length three or more that are not in ``STOPWORDS`` are kept.
    """
    lowered = text.lower()
    return {
        word
        for word in re.findall(r"[a-z0-9]+", lowered)
        if word not in STOPWORDS and len(word) > 2
    }


def build_matched_claim_candidates(records, max_per_item=32, max_index_tokens=12):
    """For each record, rank differently-labelled records by claim word overlap.

    Each claim is reduced to its content-word signature. Up to ``max_index_tokens``
    of those words are indexed in an inverted index, and candidates are records
    sharing an indexed word. Candidates with the same label, an empty signature or
    zero Jaccard similarity are dropped. The rest are sorted by
    ``(jaccard, index)`` descending, and at most ``max_per_item`` are kept.

    Returns:
        ``{record_index: [candidate_index, ...]}`` for every record (possibly empty lists).
    """
    signatures = [_claim_signature(r["claim"]) for r in records]
    # sorted(), not list(): str hashes are randomised per process, so the iteration
    # order of a set of strings -- and thus *which* words get indexed -- would
    # otherwise change between runs and make the matched counterfactuals irreproducible.
    index_tokens = [sorted(sig)[:max_index_tokens] for sig in signatures]
    buckets = defaultdict(list)
    for i, toks in enumerate(index_tokens):
        for tok in toks:
            buckets[tok].append(i)
    candidates = {}
    for i, r in enumerate(records):
        orig_label = r["label"]
        sig = signatures[i]
        if not sig:
            candidates[i] = []
            continue
        pooled = set()
        for tok in index_tokens[i]:
            pooled.update(buckets.get(tok, []))
        scored = []
        for j in pooled:
            other_sig = signatures[j]
            if j == i or records[j]["label"] == orig_label or not other_sig:
                continue
            # union is non-empty because sig is non-empty
            score = len(sig & other_sig) / len(sig | other_sig)
            if score > 0:
                scored.append((score, j))
        scored.sort(reverse=True)
        candidates[i] = [j for _, j in scored[:max_per_item]]
    return candidates


@torch.no_grad()
def evaluate_counterfactual_swap(model, eval_dataset, device, pooling_mode, n_pairs=500, seed=42, model_name="gpt2", matched_claim_only=False):
    """Probe whether predictions follow swapped-in evidence or the original claim's label.

    A seeded sample of up to ``n_pairs`` eval records is taken. For each one, its
    evidence is replaced by the evidence of a record with a *different* label while
    the original claim is kept. The donor is either a random record of a randomly
    chosen other label, or (``matched_claim_only=True``) one of the lexically closest
    differently-labelled claims. Records without a donor are skipped.

    Returns:
        ``(cls_follows_swap, cls_follows_orig, gen_follows_swap, gen_follows_orig)``:
        fractions of evaluated pairs where the consistency-head argmax / greedy
        next-token prediction equals the donor's or the original label. Generated
        tokens that are not label tokens count towards neither.
    """
    model.eval()
    tokenizer = get_tokenizer(model_name)
    rng = random.Random(seed)
    records = eval_dataset.records
    n = min(n_pairs, len(records))
    indices = list(range(len(records)))
    rng.shuffle(indices)
    eval_indices = indices[:n]
    label_to_idxs = {lab: [] for lab in range(NUM_CLASSES)}
    for i, r in enumerate(records):
        label_to_idxs[r["label"]].append(i)
    matched_candidates = build_matched_claim_candidates(records) if matched_claim_only else None
    bos_id = tokenizer.bos_token_id
    sep_id = tokenizer.convert_tokens_to_ids("[SEP]")
    labelsep_id = tokenizer.convert_tokens_to_ids("[LABELSEP]")
    pad_id = tokenizer.pad_token_id
    eos_id = tokenizer.eos_token_id
    label_tok_ids = [tokenizer.convert_tokens_to_ids(t) for t in LABEL_TOKENS]
    T = eval_dataset.max_seq_len
    overhead = 5  # BOS, SEP, LABELSEP, label token, EOS
    cls_follows_swap = cls_follows_orig = gen_follows_swap = gen_follows_orig = n_valid = 0
    for idx in eval_indices:
        orig = records[idx]
        orig_label = orig["label"]
        if matched_claim_only:
            candidates = matched_candidates.get(idx, [])
            if not candidates:
                continue
            swap_rec = records[rng.choice(candidates)]
            swap_label = swap_rec["label"]
        else:
            other_labels = [lab for lab in range(NUM_CLASSES) if lab != orig_label]
            swap_label = rng.choice(other_labels)
            candidates = label_to_idxs[swap_label]
            if not candidates:
                continue
            swap_rec = records[rng.choice(candidates)]
        ev_ids = tokenizer.encode(swap_rec["evidence"], add_special_tokens=False)
        cl_ids = tokenizer.encode(orig["claim"], add_special_tokens=False)
        # Truncate exactly like FEVERTokenDataset: evidence keeps at least 4 tokens and
        # an over-long claim absorbs the shortfall. (The old rule did the reverse and
        # cut the claim to 4 tokens.)
        ev_budget = max(4, T - overhead - len(cl_ids))
        cl_ids = cl_ids[:T - overhead - ev_budget]
        ev_ids = ev_ids[:ev_budget]
        new_label_tok_id = label_tok_ids[swap_label]
        ids = [bos_id] + ev_ids + [sep_id] + cl_ids + [labelsep_id] + [new_label_tok_id] + [eos_id]
        sep_pos_new = 1 + len(ev_ids)
        labelsep_pos_new = sep_pos_new + 1 + len(cl_ids)
        pad_len = T - len(ids)
        attn = ([1] * len(ids) + [0] * max(0, pad_len))[:T]
        ids_pad = (ids + [pad_id] * max(0, pad_len))[:T]
        swapped_ids = torch.tensor([ids_pad], dtype=torch.long, device=device)
        swapped_mask = torch.tensor([attn], dtype=torch.long, device=device)
        _, cons_logits = model(swapped_ids, swapped_mask, pooling_mode=pooling_mode, sep_pos=[sep_pos_new], labelsep_pos=[labelsep_pos_new])
        pred_cls = cons_logits[0].argmax().item()
        # Generation prompt = the same sequence up to and including [LABELSEP].
        prompt_t = torch.tensor([ids[:labelsep_pos_new + 1]], dtype=torch.long, device=device)
        full_logits = model.get_full_lm_logits(prompt_t, torch.ones_like(prompt_t))
        pred_gen_tok = full_logits[0, -1, :].argmax().item()
        pred_gen_label = next((li for li, ltid in enumerate(label_tok_ids) if pred_gen_tok == ltid), None)
        if pred_cls == swap_label:
            cls_follows_swap += 1
        if pred_cls == orig_label:
            cls_follows_orig += 1
        if pred_gen_label is not None:
            if pred_gen_label == swap_label:
                gen_follows_swap += 1
            if pred_gen_label == orig_label:
                gen_follows_orig += 1
        n_valid += 1
    d = max(1, n_valid)
    return cls_follows_swap / d, cls_follows_orig / d, gen_follows_swap / d, gen_follows_orig / d


@torch.no_grad()
def evaluate_shuffled_pairing(model, eval_dataset, device, pooling_mode, n_pairs=500, seed=42, model_name="gpt2"):
    """Accuracy against each claim's own label when its evidence comes from another record.

    A sample of up to ``n_pairs`` records is drawn with ``random.Random(seed + 1)``.
    Each sampled claim is paired with the evidence of the record ``max(1, n // 3)``
    places later in the sample (wrapping around). Self-pairings (only possible when
    the sample has one record) are skipped.

    Returns:
        ``(cls_acc, gen_acc)``: fractions of pairs where the consistency-head argmax /
        greedy next-token prediction equals the claim's original label.
    """
    model.eval()
    tokenizer = get_tokenizer(model_name)
    rng = random.Random(seed + 1)
    records = eval_dataset.records
    n = min(n_pairs, len(records))
    indices = list(range(len(records)))
    rng.shuffle(indices)
    eval_indices = indices[:n]
    bos_id = tokenizer.bos_token_id
    sep_id = tokenizer.convert_tokens_to_ids("[SEP]")
    labelsep_id = tokenizer.convert_tokens_to_ids("[LABELSEP]")
    pad_id = tokenizer.pad_token_id
    eos_id = tokenizer.eos_token_id
    label_tok_ids = [tokenizer.convert_tokens_to_ids(t) for t in LABEL_TOKENS]
    T = eval_dataset.max_seq_len
    overhead = 5  # BOS, SEP, LABELSEP, label token, EOS
    offset = max(1, n // 3)
    shuffled_pairs = [(eval_indices[i], eval_indices[(i + offset) % n]) for i in range(n)]
    cls_correct = gen_correct = n_valid = 0
    for orig_idx, ev_idx in shuffled_pairs:
        if orig_idx == ev_idx:
            # The rotation wraps onto itself only when n == 1; a claim paired with its
            # own evidence is not a shuffled pairing, so do not score it.
            continue
        orig = records[orig_idx]
        ev_rec = records[ev_idx]
        orig_label = orig["label"]
        ev_ids = tokenizer.encode(ev_rec["evidence"], add_special_tokens=False)
        cl_ids = tokenizer.encode(orig["claim"], add_special_tokens=False)
        # Truncate exactly like FEVERTokenDataset: evidence keeps at least 4 tokens and
        # an over-long claim absorbs the shortfall.
        ev_budget = max(4, T - overhead - len(cl_ids))
        cl_ids = cl_ids[:T - overhead - ev_budget]
        ev_ids = ev_ids[:ev_budget]
        label_tok_id = label_tok_ids[orig_label]
        ids = [bos_id] + ev_ids + [sep_id] + cl_ids + [labelsep_id] + [label_tok_id] + [eos_id]
        sep_pos_new = 1 + len(ev_ids)
        labelsep_pos_new = sep_pos_new + 1 + len(cl_ids)
        pad_len = T - len(ids)
        attn = ([1] * len(ids) + [0] * max(0, pad_len))[:T]
        ids_pad = (ids + [pad_id] * max(0, pad_len))[:T]
        shuf_ids = torch.tensor([ids_pad], dtype=torch.long, device=device)
        shuf_mask = torch.tensor([attn], dtype=torch.long, device=device)
        _, cons_logits = model(shuf_ids, shuf_mask, pooling_mode=pooling_mode, sep_pos=[sep_pos_new], labelsep_pos=[labelsep_pos_new])
        pred_cls = cons_logits[0].argmax().item()
        # Generation prompt = the same sequence up to and including [LABELSEP].
        prompt_t = torch.tensor([ids[:labelsep_pos_new + 1]], dtype=torch.long, device=device)
        full_logits = model.get_full_lm_logits(prompt_t, torch.ones_like(prompt_t))
        pred_gen_tok = full_logits[0, -1, :].argmax().item()
        pred_gen_label = next((li for li, ltid in enumerate(label_tok_ids) if pred_gen_tok == ltid), None)
        if pred_cls == orig_label:
            cls_correct += 1
        if pred_gen_label == orig_label:
            gen_correct += 1
        n_valid += 1
    d = max(1, n_valid)
    return cls_correct / d, gen_correct / d


def run_pretrained_gpt2_experiment(cfg: PretrainedGPT2Config) -> pd.DataFrame:
    """Train and evaluate one fresh GPT-2 + consistency head per pooling variant.

    Data is loaded from FEVER, or from synthetic templates when ``cfg.smoke_test`` is
    set (which also forces one epoch and the single ``evidence_only_strict``
    variant). It is tokenised once. For every variant the global RNGs are reseeded
    and a new model is built and trained, with the lower layers frozen for the first
    ``cfg.freeze_lower_layers_epochs`` epochs. The model is then evaluated and one
    results row is appended. ``cfg.results_path`` is rewritten after every variant so
    partial results survive a crash.

    Returns:
        DataFrame with one row per variant.
    """
    set_seed(cfg.seed)
    device = resolve_device(cfg)
    print(f"Using device: {device}")
    use_cuda = device.type == "cuda"
    if use_cuda:
        print("Mixed precision (AMP) enabled.")
    if cfg.smoke_test:
        train_samples, eval_samples = make_smoke_data(min(cfg.num_train_samples, 32), min(cfg.num_eval_samples, 16), cfg.seed)
    else:
        train_samples, eval_samples = load_fever_data(cfg)
    print("Building tokenized datasets...")
    train_dataset = FEVERTokenDataset(train_samples, cfg.max_seq_len, cfg.model_name)
    eval_dataset = FEVERTokenDataset(eval_samples, cfg.max_seq_len, cfg.model_name)
    vocab_size = len(get_tokenizer(cfg.model_name))
    train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn, num_workers=0, pin_memory=use_cuda)
    eval_loader = DataLoader(eval_dataset, batch_size=cfg.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
    variants = ["evidence_only_strict"] if cfg.smoke_test else list(cfg.pooling_modes)
    num_epochs = 1 if cfg.smoke_test else cfg.num_epochs
    eval_n = min(500, len(eval_samples))
    all_results = []
    for variant in variants:
        print(f"\n{'='*60}\nVariant: {variant}\n{'='*60}")
        set_seed(cfg.seed)
        # Fresh scaler per variant. A shared one would carry its dynamic loss scale
        # over from the previous variant, making early steps depend on run order.
        scaler = torch.cuda.amp.GradScaler() if use_cuda else None
        model = PretrainedGPT2WithConsistencyHead(model_name=cfg.model_name, vocab_size_after_special=vocab_size).to(device)
        optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=cfg.lr, weight_decay=0.01)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=cfg.lr * 0.1)
        final_lm_loss = final_cons_loss = 0.0
        t_start = time.time()
        for epoch in range(1, num_epochs + 1):
            if cfg.freeze_lower_layers_epochs > 0:
                if epoch == 1:
                    model.freeze_lower_layers(True)
                elif epoch == cfg.freeze_lower_layers_epochs + 1:
                    model.freeze_lower_layers(False)
                    # Rebuild optimizer/scheduler for the fully-unfrozen phase at a lower LR.
                    optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=cfg.lr * 0.5, weight_decay=0.01)
                    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(1, num_epochs - epoch + 1), eta_min=cfg.lr * 0.05)
            final_lm_loss, final_cons_loss = train_epoch(model, train_loader, optimizer, device, variant, cfg.consistency_loss_weight, scaler, cfg.grad_clip)
            scheduler.step()
            print(f" Epoch {epoch}/{num_epochs}: lm_loss={final_lm_loss:.4f}, cons_loss={final_cons_loss:.4f}, elapsed={time.time()-t_start:.1f}s")
        cls_acc = evaluate_classification(model, eval_loader, device, variant)
        gen_acc = evaluate_generation_accuracy(model, eval_samples, device, cfg.max_seq_len, eval_n, cfg.model_name)
        cf_cls_swap, cf_cls_orig, cf_gen_swap, cf_gen_orig = evaluate_counterfactual_swap(model, eval_dataset, device, variant, eval_n, cfg.seed, cfg.model_name, matched_claim_only=False)
        mc_cls_swap, mc_cls_orig, _, _ = evaluate_counterfactual_swap(model, eval_dataset, device, variant, eval_n, cfg.seed, cfg.model_name, matched_claim_only=True)
        shuf_cls_acc, shuf_gen_acc = evaluate_shuffled_pairing(model, eval_dataset, device, variant, eval_n, cfg.seed, cfg.model_name)
        all_results.append({
            "variant": variant,
            "final_lm_loss": round(final_lm_loss, 4),
            "final_cons_loss": round(final_cons_loss, 4),
            "gen_claim_acc": round(gen_acc, 4),
            "cls_claim_acc": round(cls_acc, 4),
            "cfact_gen_follows_swap": round(cf_gen_swap, 4),
            "cfact_gen_follows_orig": round(cf_gen_orig, 4),
            "cfact_cls_follows_swap": round(cf_cls_swap, 4),
            "cfact_cls_follows_orig": round(cf_cls_orig, 4),
            "matched_cfact_cls_follows_swap": round(mc_cls_swap, 4),
            "matched_cfact_cls_follows_orig": round(mc_cls_orig, 4),
            "shuffled_gen_acc": round(shuf_gen_acc, 4),
            "shuffled_cls_acc": round(shuf_cls_acc, 4),
        })
        pd.DataFrame(all_results).to_csv(cfg.results_path, index=False)
        # Drop this variant's model and optimizer state before the next one is built.
        # Otherwise two full models plus their AdamW moments coexist at peak.
        del model, optimizer, scheduler, scaler
        if use_cuda:
            torch.cuda.empty_cache()
    df = pd.DataFrame(all_results)
    df.to_csv(cfg.results_path, index=False)
    print(f"\nFinal results saved to {cfg.results_path}")
    return df


def format_results_markdown(df: pd.DataFrame, cfg: PretrainedGPT2Config) -> str:
    """Render the experiment setup and per-variant results as a Markdown report.

    Numeric cells use 4 decimals; missing or NaN cells render as ``N/A``. The
    dataset line shows ``cfg.dataset_name`` only. It used to append a hard-coded
    "(copenlu/fever_gold_evidence)", which was wrong whenever the dataset was
    overridden. Smoke-test runs are flagged, because the runner caps sample counts
    and epochs in that mode.
    """
    cols = [
        "variant", "final_lm_loss", "final_cons_loss", "gen_claim_acc", "cls_claim_acc",
        "cfact_cls_follows_swap", "cfact_cls_follows_orig", "matched_cfact_cls_follows_swap",
        "matched_cfact_cls_follows_orig", "cfact_gen_follows_swap", "cfact_gen_follows_orig",
        "shuffled_cls_acc", "shuffled_gen_acc",
    ]
    lines = [
        "# FEVER Pretrained GPT-2 Claim-Consistency Coupling Results\n",
        "## Experiment Setup\n",
        f"- Model: `{cfg.model_name}` (pretrained HuggingFace GPT-2 backbone)",
        f"- Dataset: `{cfg.dataset_name}`",
        f"- Train samples: {cfg.num_train_samples:,} | Eval samples: {cfg.num_eval_samples:,}",
        f"- max_seq_len: {cfg.max_seq_len}",
        f"- Epochs: {cfg.num_epochs} | Batch size: {cfg.batch_size} | LR: {cfg.lr}",
        f"- Consistency loss weight: {cfg.consistency_loss_weight}",
        f"- Freeze lower layers epochs: {cfg.freeze_lower_layers_epochs}",
    ]
    if getattr(cfg, "smoke_test", False):
        lines.append("- Smoke test: yes (synthetic data; the runner caps sample counts and epochs in this mode)")
    lines += [
        "",
        "## Sequence Format\n",
        "```",
        "[BOS] <evidence_passage> [SEP] <claim> [LABELSEP] <label_token> [EOS]",
        "```\n",
        "- **Evidence pooling**: mean of hidden states at positions 1..sep_pos-1 (before [SEP])",
        "- **Evidence strict**: same pooling plus zeroed non-evidence hidden states for the consistency path",
        "- **Claim pooling**: mean over positions sep_pos+1..labelsep_pos-1",
        "- **Full pooling**: mean over all non-pad tokens",
        "- **Random-label control**: evidence-only strict path trained against permuted consistency labels",
        "",
        "## Results Table\n",
        "| " + " | ".join(cols) + " |",
        "|" + "|".join(["---"] * len(cols)) + "|",
    ]

    def _cell(row, col):
        value = row.get(col, float("nan"))
        return f"{float(value):.4f}" if pd.notna(value) else "N/A"

    for _, row in df.iterrows():
        vals = [str(row.get("variant", "?"))] + [_cell(row, c) for c in cols[1:]]
        lines.append("| " + " | ".join(vals) + " |")
    lines += [
        "",
        "## Metric Descriptions\n",
        "- `matched_cfact_cls_follows_swap/orig`: counterfactual classification metrics using swapped evidence from lexically similar claims with different labels",
        "- `evidence_only_strict`: consistency head can only use evidence-region activations after non-evidence hidden states are zeroed in the consistency path",
        "- `evidence_only_random_labels`: sanity control where the consistency head is trained on permuted labels while LM loss remains intact",
    ]
    return "\n".join(lines)


def build_pretrained_comparison_table(fever_df: pd.DataFrame, synthetic_path: str = "results_comparison_hard.csv"):
    """Stack FEVER results with the synthetic-benchmark CSV into one comparison table.

    Both frames are tagged with an ``experiment`` column and a legacy synthetic column
    name is renamed to ``cls_claim_acc``. Both are then projected onto a shared column
    list, with missing columns filled with NaN.

    Returns:
        ``None`` (after printing a notice) when ``synthetic_path`` does not exist,
        otherwise ``(combined_df, markdown_table)``.
    """
    if not os.path.exists(synthetic_path):
        print(f"Synthetic results not found at {synthetic_path} — skipping comparison.")
        return None
    common_cols = ["experiment", "variant", "final_lm_loss", "final_cons_loss", "gen_claim_acc", "cls_claim_acc", "cfact_cls_follows_swap", "cfact_cls_follows_orig", "shuffled_cls_acc", "shuffled_gen_acc"]
    legacy_col = "cls_claim_acc (rationale_pool)"

    synth = pd.read_csv(synthetic_path)
    synth["experiment"] = "synthetic_hard"
    if legacy_col in synth.columns:
        synth = synth.rename(columns={legacy_col: "cls_claim_acc"})
    fever = fever_df.copy()
    fever["experiment"] = "fever_pretrained_gpt2"

    def aligned(frame):
        missing = [c for c in common_cols if c not in frame.columns]
        if missing:
            frame = frame.copy()
            for c in missing:
                frame[c] = float("nan")
        return frame[common_cols]

    def cell(row, col):
        if col in ("experiment", "variant"):
            return str(row.get(col, "N/A"))
        value = row.get(col, float("nan"))
        return f"{float(value):.4f}" if pd.notna(value) else "N/A"

    combined = pd.concat([aligned(fever), aligned(synth)], ignore_index=True)
    md_lines = [
        "# FEVER Pretrained GPT-2 vs Synthetic Results Comparison\n",
        "| " + " | ".join(common_cols) + " |",
        "|" + "|".join(["---"] * len(common_cols)) + "|",
    ]
    md_lines.extend("| " + " | ".join(cell(row, c) for c in common_cols) + " |" for _, row in combined.iterrows())
    return combined, "\n".join(md_lines)


if __name__ == "__main__":
    # CLI entry point: run all variants, then write CSV + Markdown reports and,
    # when synthetic results exist, a FEVER-vs-synthetic comparison.
    import argparse
    parser = argparse.ArgumentParser(description="FEVER pretrained GPT-2 claim-consistency coupling experiment")
    parser.add_argument("--model_name", type=str, default="gpt2")
    parser.add_argument("--train_samples", type=int, default=50_000)
    parser.add_argument("--eval_samples", type=int, default=5_000)
    parser.add_argument("--max_seq_len", type=int, default=256)
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--lr", type=float, default=5e-5)
    parser.add_argument("--consistency_loss_weight", type=float, default=0.5)
    parser.add_argument("--freeze_lower_layers_epochs", type=int, default=1)
    # Default to the config's variant list so the two cannot drift apart.
    parser.add_argument("--variants", nargs="+", default=list(PretrainedGPT2Config.pooling_modes))
    parser.add_argument("--output_csv", type=str, default="results_fever_pretrained_gpu.csv")
    parser.add_argument("--smoke_test", action="store_true")
    parser.add_argument("--require_gpu", action="store_true")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--device", type=str, default="auto", help='"auto", or any torch.device string such as "cpu" or "cuda:1"')
    args = parser.parse_args()
    cfg = PretrainedGPT2Config(model_name=args.model_name, num_train_samples=args.train_samples, num_eval_samples=args.eval_samples, max_seq_len=args.max_seq_len, num_epochs=args.epochs, batch_size=args.batch_size, lr=args.lr, consistency_loss_weight=args.consistency_loss_weight, freeze_lower_layers_epochs=args.freeze_lower_layers_epochs, pooling_modes=tuple(args.variants), results_path=args.output_csv, smoke_test=args.smoke_test, require_gpu=args.require_gpu, seed=args.seed, device=args.device)
    df = run_pretrained_gpt2_experiment(cfg)
    print(df.to_string())
    # splitext, not str.replace(".csv", ".md"): the latter returns the CSV path
    # itself when it has no ".csv" substring, so the report would overwrite the results.
    md_path = os.path.splitext(args.output_csv)[0] + ".md"
    with open(md_path, "w", encoding="utf-8") as f:
        f.write(format_results_markdown(df, cfg))
    result = build_pretrained_comparison_table(df, "results_comparison_hard.csv")
    if result is not None:
        comp_df, comp_md = result
        comp_df.to_csv("results_fever_pretrained_vs_synthetic.csv", index=False)
        with open("results_fever_pretrained_vs_synthetic.md", "w", encoding="utf-8") as f:
            f.write(comp_md)
