from __future__ import annotations

import argparse
import math
import random
from dataclasses import dataclass
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset


# Control markers first, then one "S<k>" token per latent state (0-7) and one
# "V<k>" token per claim bin (0-9).
SPECIAL_TOKENS = (
    ["[PAD]", "[BOS]", "[SEP]", "[RAT]", "[CLAIM]", "[EOS]"]
    + [f"S{k}" for k in range(8)]
    + [f"V{k}" for k in range(10)]
)

# Single-token markers that open the input, rationale and claim segments of a
# sequence; kept as lists so they can be concatenated with token lists.
INPUT_PREFIX = ["[BOS]"]
RAT_PREFIX = ["[RAT]"]
CLAIM_PREFIX = ["[CLAIM]"]


@dataclass
class Config:
    """Hyperparameters and output settings for one experiment run.

    Raises:
        ValueError: From ``__post_init__`` when ``batch_size`` is not
            positive, ``n_heads`` is not positive, ``d_model`` is not a
            multiple of ``n_heads``, or ``dropout`` lies outside ``[0, 1]``.
    """

    # Seed for the global RNGs and the base seed of the synthetic datasets.
    seed: int = 42
    # NOTE(review): not read by any other code in this file.
    num_states: int = 8
    # Number of synthetic samples in the train / eval datasets.
    num_train: int = 2048
    num_eval: int = 512
    # Optimisation settings (AdamW).
    batch_size: int = 64
    epochs: int = 20
    lr: float = 3e-4
    weight_decay: float = 0.01
    # Transformer size.
    d_model: int = 128
    n_layers: int = 2
    n_heads: int = 4
    d_ff: int = 256
    dropout: float = 0.1
    # Loss weights: bin-classification ("consistency") cross-entropy,
    # next-token LM cross-entropy, and scalar-claim MSE respectively.
    consistency_weight: float = 0.5
    rationale_loss_weight: float = 1.0
    claim_loss_weight: float = 1.0
    # Path of the CSV results table.
    output_csv: str = "generated_rationale_scalar_results.csv"

    def __post_init__(self):
        # Fail fast on settings that torch would otherwise reject only once
        # the model / loaders are constructed.
        if self.batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {self.batch_size}")
        if self.n_heads < 1:
            raise ValueError(f"n_heads must be >= 1, got {self.n_heads}")
        if self.d_model % self.n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by n_heads ({self.n_heads})"
            )
        if not 0.0 <= self.dropout <= 1.0:
            raise ValueError(f"dropout must be in [0, 1], got {self.dropout}")


# Oracle scalar attached to each latent state 0-7 (consecutive values are
# 0.13 apart).
STATE_SCALARS = dict(enumerate([0.05, 0.18, 0.31, 0.44, 0.57, 0.70, 0.83, 0.96]))

# Claim bin for each state: the scalar scaled by 9, rounded to the nearest
# integer and clamped to the bin range 0-9.
STATE_TO_BIN = {
    state: max(0, min(9, round(9 * value)))
    for state, value in STATE_SCALARS.items()
}

# Rationale templates per latent state: four templates of four word tokens
# each. Templates are written as space-separated phrases and split into token
# lists. Several words are shared between states, so a single word does not
# always identify the state.
RATIONAL_TEMPLATES = {
    state: [phrase.split() for phrase in phrases]
    for state, phrases in {
        0: (
            "low stable edge quiet",
            "weak calm small sealed",
            "thin cold flat settled",
            "narrow dim fixed slow",
        ),
        1: (
            "low mixed edge tense",
            "weak cool small open",
            "thin drift flat fragile",
            "narrow odd fixed brittle",
        ),
        2: (
            "light balanced lane active",
            "mild steady medium mobile",
            "soft warm plain ready",
            "clear firm open alert",
        ),
        3: (
            "light mixed lane sharp",
            "mild rising medium live",
            "soft warm plain swing",
            "clear firm open press",
        ),
        4: (
            "solid balanced center active",
            "strong steady broad mobile",
            "thick warm round ready",
            "deep firm linked alert",
        ),
        5: (
            "solid rising center sharp",
            "strong hot broad live",
            "thick bright round swing",
            "deep charged linked press",
        ),
        6: (
            "heavy rising core forcing",
            "dense hot wide driving",
            "large bright arched severe",
            "deep urgent locked strong",
        ),
        7: (
            "heavy surging core decisive",
            "dense bright wide crushing",
            "large hot arched dominant",
            "deep urgent locked final",
        ),
    }.items()
}


def set_seed(seed: int):
    """Seed Python's ``random``, NumPy and PyTorch with the same value.

    When CUDA is available the seed is also applied to every CUDA device.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def build_vocab() -> Tuple[Dict[str, int], Dict[int, str]]:
    """Build the token <-> id lookup tables.

    The vocabulary is the union of ``SPECIAL_TOKENS`` and every word that
    occurs in ``RATIONAL_TEMPLATES``. Ids follow the sorted order of the
    tokens, so the mapping is deterministic.

    Returns:
        ``(tok2id, id2tok)``: token -> id, and the inverse id -> token.
    """
    template_words = {
        word
        for group in RATIONAL_TEMPLATES.values()
        for template in group
        for word in template
    }
    ordered = sorted(template_words.union(SPECIAL_TOKENS))
    id2tok = dict(enumerate(ordered))
    tok2id = {token: index for index, token in id2tok.items()}
    return tok2id, id2tok


def oracle_scalar_from_state(state: int) -> float:
    """Return the oracle scalar stored in ``STATE_SCALARS`` for ``state``.

    Raises:
        KeyError: If ``state`` is not a key of ``STATE_SCALARS``.
    """
    scalar = STATE_SCALARS[state]
    return scalar


class SyntheticDataset(Dataset):
    """In-memory synthetic corpus of ``state -> rationale -> claim`` sequences.

    Each sample is the token sequence
    ``[BOS] S<state> [RAT] <rationale tokens> [CLAIM] V<bin> [EOS]``, where the
    rationale is one template drawn from ``RATIONAL_TEMPLATES[state]`` and the
    bin is ``STATE_TO_BIN[state]``. All samples are generated eagerly in
    ``__init__`` from a private ``random.Random(seed)``, so two datasets built
    with the same arguments are identical.

    Args:
        n: Number of samples to generate.
        tok2id: Token -> id mapping; must contain every token used above.
        seed: Seed of the private generator.
        num_states: States are drawn uniformly from ``range(num_states)``.
            The default of 8 reproduces the previously hard-coded behaviour.

    Raises:
        ValueError: If ``num_states`` is not positive or asks for a state
            that has no entry in ``RATIONAL_TEMPLATES``.
    """

    def __init__(self, n: int, tok2id: Dict[str, int], seed: int = 42, num_states: int = 8):
        if num_states < 1 or any(s not in RATIONAL_TEMPLATES for s in range(num_states)):
            raise ValueError(
                f"num_states must be between 1 and {len(RATIONAL_TEMPLATES)}, got {num_states}"
            )
        self.tok2id = tok2id
        self.pad_id = tok2id["[PAD]"]
        rng = random.Random(seed)
        self.samples = []
        for _ in range(n):
            state = rng.randrange(num_states)
            # Copy so a caller mutating a sample cannot corrupt the shared
            # module-level template table.
            template = list(rng.choice(RATIONAL_TEMPLATES[state]))
            claim_bin = STATE_TO_BIN[state]
            seq = (
                INPUT_PREFIX
                + [f"S{state}"]
                + RAT_PREFIX
                + template
                + CLAIM_PREFIX
                + [f"V{claim_bin}"]
                + ["[EOS]"]
            )
            self.samples.append({
                "state": state,
                "input_ids": torch.tensor([tok2id[t] for t in seq], dtype=torch.long),
                "rationale_tokens": template,
                "claim_bin": claim_bin,
                "scalar": oracle_scalar_from_state(state),
                # Carried on every sample so the collate function can pad
                # without access to the dataset object.
                "pad_id": self.pad_id,
            })

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


def collate_fn(batch):
    """Pad a list of dataset samples and shift them into LM inputs/targets.

    Every sequence is right-padded with the batch's ``pad_id`` to the longest
    length, then split into ``input_ids`` (all but the last position) and
    ``targets`` (all but the first). ``attention_mask`` is 1 exactly where the
    target at that position is a real (non-padding) token.

    Returns:
        A dict of stacked tensors (``input_ids``, ``targets``,
        ``attention_mask``, ``states``, ``claim_bins``, ``scalars``) plus the
        per-sample python lists ``rat_positions`` and ``claim_positions``.
    """
    longest = max(len(sample["input_ids"]) for sample in batch)
    pad_id = batch[0]["pad_id"]

    # Layout of a full sequence: [BOS] S<k> [RAT] r1 r2 r3 r4 [CLAIM] V<k> [EOS].
    # Positions are shifted left by one so they index the *input* position
    # whose next-token target is the rationale token / the claim token.
    first_rat, n_rat = 3, 4
    rat_span = (first_rat - 1, first_rat + n_rat - 1)
    claim_index = first_rat + n_rat

    input_rows, target_rows, mask_rows = [], [], []
    for sample in batch:
        ids = sample["input_ids"]
        n_pad = longest - len(ids)
        full = F.pad(ids, (0, n_pad), value=pad_id)
        input_rows.append(full[:-1])
        target_rows.append(full[1:])
        mask_rows.append(
            torch.tensor([1] * (len(ids) - 1) + [0] * n_pad, dtype=torch.long)
        )

    return {
        "input_ids": torch.stack(input_rows),
        "targets": torch.stack(target_rows),
        "attention_mask": torch.stack(mask_rows),
        "states": torch.tensor([s["state"] for s in batch], dtype=torch.long),
        "claim_bins": torch.tensor([s["claim_bin"] for s in batch], dtype=torch.long),
        "scalars": torch.tensor([s["scalar"] for s in batch], dtype=torch.float32),
        "rat_positions": [rat_span for _ in batch],
        "claim_positions": [claim_index for _ in batch],
    }


class DecoderBlock(nn.Module):
    """Pre-LayerNorm transformer block: self-attention, then a GELU MLP.

    Both sub-layers normalise their input first and add their output back
    onto the residual stream.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.ln2 = nn.LayerNorm(d_model)
        mlp_layers = [
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        ]
        self.ff = nn.Sequential(*mlp_layers)

    def forward(self, x, key_padding_mask=None, attn_mask=None):
        """Apply the block to ``x``; both masks are forwarded to the attention layer."""
        normed = self.ln1(x)
        attended = self.attn(
            normed,
            normed,
            normed,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
            need_weights=False,
        )[0]
        residual = x + attended
        return residual + self.ff(self.ln2(residual))


class TinyVerifierLM(nn.Module):
    """Small causal transformer LM with two auxiliary read-out heads.

    * ``lm_head`` (weights tied to ``token_emb``) predicts the next token at
      every position.
    * ``scalar_head`` reads the hidden state at each sample's claim position
      and emits one raw (pre-sigmoid) value.
    * ``bin_head`` reads the mean hidden state over each sample's rationale
      span and emits logits over 10 bins.

    Args:
        vocab_size: Number of rows of the token embedding / LM head.
        cfg: Supplies ``d_model``, ``n_heads``, ``d_ff``, ``dropout`` and
            ``n_layers``.
        max_len: Number of learned positions, i.e. the longest input the
            model accepts. Defaults to the previously hard-coded 32.
    """

    def __init__(self, vocab_size: int, cfg: Config, max_len: int = 32):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, cfg.d_model)
        self.pos_emb = nn.Embedding(max_len, cfg.d_model)
        self.blocks = nn.ModuleList([
            DecoderBlock(cfg.d_model, cfg.n_heads, cfg.d_ff, cfg.dropout) for _ in range(cfg.n_layers)
        ])
        self.ln_f = nn.LayerNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, vocab_size, bias=False)
        self.scalar_head = nn.Linear(cfg.d_model, 1)
        self.bin_head = nn.Linear(cfg.d_model, 10)
        # Weight tying: the output projection shares the embedding matrix.
        self.lm_head.weight = self.token_emb.weight
        self.apply(self._init)

    def _init(self, module):
        """Initialise weights: N(0, 0.02) for Linear/Embedding, identity for LayerNorm."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def causal_mask(self, T, device):
        """Return a ``(T, T)`` additive mask: ``-inf`` above the diagonal, 0 elsewhere."""
        return torch.triu(torch.full((T, T), float("-inf"), device=device), diagonal=1)

    def forward(self, input_ids, attention_mask, rat_positions, claim_positions):
        """Run the model on a batch.

        Args:
            input_ids: ``(B, T)`` token ids.
            attention_mask: ``(B, T)``; positions equal to 0 are treated as
                padding keys.
            rat_positions: Per-sample ``(start, end)`` slice bounds of the
                hidden states pooled for ``bin_head``.
            claim_positions: Per-sample index of the hidden state fed to
                ``scalar_head``.

        Returns:
            ``(lm_logits, scalar_pred, bin_logits)`` with shapes
            ``(B, T, vocab)``, ``(B,)`` and ``(B, 10)``.

        Raises:
            ValueError: If ``T`` exceeds the number of learned positions.
        """
        B, T = input_ids.shape
        max_len = self.pos_emb.num_embeddings
        if T > max_len:
            # Without this check the embedding lookup fails with an opaque
            # index error (or a device-side assert on CUDA).
            raise ValueError(f"sequence length {T} exceeds the {max_len} learned positions")

        device = input_ids.device
        pos = torch.arange(T, device=device).unsqueeze(0).expand(B, T)
        x = self.token_emb(input_ids) + self.pos_emb(pos)

        # Pass both masks as boolean (True = blocked). MultiheadAttention
        # deprecates mixing a bool key_padding_mask with a float attn_mask;
        # the boolean form blocks exactly the positions the -inf mask did.
        key_padding_mask = attention_mask == 0
        attn_mask = torch.isinf(self.causal_mask(T, device))
        for blk in self.blocks:
            x = blk(x, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        h = self.ln_f(x)
        lm_logits = self.lm_head(h)

        pooled = []
        claim_h = []
        for i in range(B):
            rs, re = rat_positions[i]
            pooled.append(h[i, rs:re, :].mean(dim=0))
            claim_h.append(h[i, claim_positions[i], :])
        pooled = torch.stack(pooled)
        claim_h = torch.stack(claim_h)
        scalar_pred = self.scalar_head(claim_h).squeeze(-1)
        bin_logits = self.bin_head(pooled)
        return lm_logits, scalar_pred, bin_logits


@torch.no_grad()
def evaluate(model, loader, device):
    """Compute token, bin and scalar metrics of ``model`` over ``loader``.

    The model is switched to eval mode. Metrics are normalised by what was
    actually iterated (valid tokens / samples seen), so the result is correct
    with ``drop_last`` or a sampler, and an empty loader yields zeros instead
    of raising ``ZeroDivisionError``.

    Args:
        model: Callable as ``model(input_ids, attention_mask, rat_positions,
            claim_positions)`` returning ``(lm_logits, scalar_pred, bin_logits)``.
        loader: Iterable of batches in the format produced by ``collate_fn``.
        device: Device the batch tensors are moved to.

    Returns:
        Dict with ``token_acc`` (next-token accuracy over positions where
        ``attention_mask`` is non-zero), ``claim_bin_acc`` (accuracy of the
        arg-max bin) and ``scalar_mse`` (mean squared error between
        ``sigmoid(scalar_pred)`` and the target scalar).
    """
    model.eval()
    n_samples = 0
    n_tokens = 0
    token_correct = 0
    claim_bin_correct = 0
    scalar_sq_err = 0.0

    for batch in loader:
        input_ids = batch["input_ids"].to(device)
        targets = batch["targets"].to(device)
        attn = batch["attention_mask"].to(device)
        claim_bins = batch["claim_bins"].to(device)
        scalars = batch["scalars"].to(device)

        lm_logits, scalar_pred, bin_logits = model(
            input_ids, attn, batch["rat_positions"], batch["claim_positions"]
        )

        mask = attn.bool()
        token_correct += ((lm_logits.argmax(dim=-1) == targets) & mask).sum().item()
        n_tokens += mask.sum().item()
        claim_bin_correct += (bin_logits.argmax(dim=-1) == claim_bins).sum().item()
        scalar_sq_err += F.mse_loss(torch.sigmoid(scalar_pred), scalars, reduction="sum").item()
        n_samples += claim_bins.size(0)

    sample_denom = max(1, n_samples)
    return {
        "token_acc": token_correct / max(1, n_tokens),
        "claim_bin_acc": claim_bin_correct / sample_denom,
        "scalar_mse": scalar_sq_err / sample_denom,
    }


@torch.no_grad()
def evaluate_counterfactual(model, dataset, tok2id, device, n=256, seed=42):
    """Probe which signal the heads follow when the rationale is swapped.

    For up to ``n`` randomly chosen samples ``a``, a donor sample ``b`` with a
    different state is drawn and a sequence is built from ``a``'s state token
    and claim token but ``b``'s rationale. The model's bin prediction is then
    compared with both ``b``'s bin ("follows swap") and ``a``'s bin
    ("follows orig"), and its sigmoid scalar with ``b``'s scalar.

    Samples for which no donor exists (every sample shares one state) are
    skipped instead of raising ``IndexError``; rates are averaged over the
    samples actually evaluated.

    Args:
        model: Called as in ``evaluate``; switched to eval mode.
        dataset: Indexable collection of samples with ``state``,
            ``rationale_tokens``, ``claim_bin`` and ``scalar`` entries.
        tok2id: Token -> id mapping used to encode the swapped sequence.
        device: Device the encoded sequence is created on.
        n: Maximum number of samples to evaluate.
        seed: Seed of the private generator picking samples and donors.

    Returns:
        Dict with ``cfact_cls_follows_swap``, ``cfact_cls_follows_orig`` and
        ``cfact_scalar_mse_to_swap``.
    """
    model.eval()
    rng = random.Random(seed)
    idxs = list(range(len(dataset)))
    rng.shuffle(idxs)
    idxs = idxs[:n]

    # Read every state once and cache the donor list per state rather than
    # rescanning the whole dataset for each evaluated sample. The lists keep
    # dataset order, so the random draws are unchanged.
    states = [dataset[j]["state"] for j in range(len(dataset))]
    donors_by_state = {}

    # Positions in the fixed layout [BOS] S<k> [RAT] r1 r2 r3 r4 [CLAIM] V<k>:
    # the span pooled for the bin head and the [CLAIM] position.
    rat_positions = [(2, 6)]
    claim_positions = [7]

    evaluated = 0
    follows_swap = 0
    follows_orig = 0
    mse_shift = 0.0

    for idx in idxs:
        a = dataset[idx]
        donors = donors_by_state.get(a["state"])
        if donors is None:
            donors = [j for j, s in enumerate(states) if s != a["state"]]
            donors_by_state[a["state"]] = donors
        if not donors:
            continue
        b = dataset[rng.choice(donors)]

        swapped_seq = (
            INPUT_PREFIX
            + [f"S{a['state']}"]
            + RAT_PREFIX
            + b["rationale_tokens"]
            + CLAIM_PREFIX
            + [f"V{a['claim_bin']}"]
            + ["[EOS]"]
        )
        # Drop the final [EOS]: the model consumes inputs, not targets.
        ids = torch.tensor([[tok2id[t] for t in swapped_seq[:-1]]], dtype=torch.long, device=device)
        attn = torch.ones_like(ids)
        _, scalar_pred, bin_logits = model(ids, attn, rat_positions, claim_positions)
        pred_bin = bin_logits.argmax(dim=-1).item()
        pred_scalar = torch.sigmoid(scalar_pred).item()

        if pred_bin == b["claim_bin"]:
            follows_swap += 1
        if pred_bin == a["claim_bin"]:
            follows_orig += 1
        mse_shift += (pred_scalar - b["scalar"]) ** 2
        evaluated += 1

    d = max(1, evaluated)
    return {
        "cfact_cls_follows_swap": follows_swap / d,
        "cfact_cls_follows_orig": follows_orig / d,
        "cfact_scalar_mse_to_swap": mse_shift / d,
    }


def train_variant(cfg: Config, variant: str, train_loader, eval_loader, train_dataset, eval_dataset, tok2id):
    """Train a fresh model under one loss variant and return its metrics.

    Variants (every one includes the weighted LM loss):

    * ``lm_only``: LM loss only.
    * ``no_consistency_loss``: adds the scalar MSE loss.
    * ``rationale_only``: adds the bin cross-entropy ("consistency") loss.
    * ``full_consistency``: adds both the scalar and the bin loss.
    * ``random_consistency``: like ``full_consistency`` but the bin targets
      are redrawn uniformly at random on every step.

    Args:
        cfg: Hyperparameters (model size, optimiser, epochs, loss weights).
        variant: One of the names above.
        train_loader: Batches used for optimisation.
        eval_loader: Batches passed to ``evaluate``.
        train_dataset: Provides ``pad_id`` for the LM loss ``ignore_index``.
        eval_dataset: Passed to ``evaluate_counterfactual``.
        tok2id: Token -> id mapping; its size is the vocabulary size.

    Returns:
        Dict with ``"variant"`` plus the entries returned by ``evaluate`` and
        ``evaluate_counterfactual``.

    Raises:
        ValueError: If ``variant`` is unknown. Raised before any work is
            done, so a typo is caught even when ``epochs`` is 0 or the train
            loader is empty.
    """
    known_variants = (
        "lm_only",
        "no_consistency_loss",
        "rationale_only",
        "full_consistency",
        "random_consistency",
    )
    if variant not in known_variants:
        raise ValueError(variant)
    uses_scalar_loss = variant in ("no_consistency_loss", "full_consistency", "random_consistency")
    uses_true_bins = variant in ("rationale_only", "full_consistency")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TinyVerifierLM(len(tok2id), cfg).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

    for _ in range(cfg.epochs):
        # evaluate() leaves the model in eval mode, so re-enable training here.
        model.train()
        for batch in train_loader:
            input_ids = batch["input_ids"].to(device)
            targets = batch["targets"].to(device)
            attn = batch["attention_mask"].to(device)
            claim_bins = batch["claim_bins"].to(device)
            scalars = batch["scalars"].to(device)

            lm_logits, scalar_pred, bin_logits = model(
                input_ids, attn, batch["rat_positions"], batch["claim_positions"]
            )
            lm_loss = F.cross_entropy(
                lm_logits.reshape(-1, lm_logits.size(-1)),
                targets.reshape(-1),
                ignore_index=train_dataset.pad_id,
            )

            # Terms are added in the order LM -> scalar -> bin.
            loss = cfg.rationale_loss_weight * lm_loss
            if uses_scalar_loss:
                loss = loss + cfg.claim_loss_weight * F.mse_loss(torch.sigmoid(scalar_pred), scalars)
            if uses_true_bins:
                loss = loss + cfg.consistency_weight * F.cross_entropy(bin_logits, claim_bins)
            elif variant == "random_consistency":
                rand_bins = torch.randint(0, 10, claim_bins.shape, device=device)
                loss = loss + cfg.consistency_weight * F.cross_entropy(bin_logits, rand_bins)

            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()

    metrics = evaluate(model, eval_loader, device)
    cf = evaluate_counterfactual(model, eval_dataset, tok2id, device)
    return {"variant": variant, **metrics, **cf}


def main(cfg: Config):
    """Train every loss variant on shared data and save the results table.

    Builds the vocabulary and the train/eval datasets (eval uses
    ``cfg.seed + 1``), trains each variant with ``train_variant``, then writes
    the metrics to ``cfg.output_csv`` and to a markdown file next to it, and
    prints the table.

    The markdown path is ``cfg.output_csv`` with its extension replaced by
    ``.md``. Only the final extension is touched, and the path never
    coincides with the CSV path, so the markdown cannot overwrite the CSV.
    """
    import os  # local import: only needed for the output-path handling below

    set_seed(cfg.seed)
    tok2id, _ = build_vocab()
    train_dataset = SyntheticDataset(cfg.num_train, tok2id, seed=cfg.seed)
    eval_dataset = SyntheticDataset(cfg.num_eval, tok2id, seed=cfg.seed + 1)
    train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn)
    eval_loader = DataLoader(eval_dataset, batch_size=cfg.batch_size, shuffle=False, collate_fn=collate_fn)

    variants = ["lm_only", "no_consistency_loss", "rationale_only", "full_consistency", "random_consistency"]
    rows = [
        train_variant(cfg, variant, train_loader, eval_loader, train_dataset, eval_dataset, tok2id)
        for variant in variants
    ]

    df = pd.DataFrame(rows)
    df.to_csv(cfg.output_csv, index=False)

    root, _ext = os.path.splitext(cfg.output_csv)
    md_path = root + ".md"
    if md_path == cfg.output_csv:
        # The CSV path itself ends in ".md"; keep the two files distinct.
        md_path = cfg.output_csv + ".md"
    # NOTE: DataFrame.to_markdown relies on pandas' optional "tabulate"
    # dependency; the CSV above is already on disk if it is missing.
    with open(md_path, "w", encoding="utf-8") as f:
        f.write(df.to_markdown(index=False))
    print(df.to_string(index=False))
    print(f"Saved to {cfg.output_csv} and {md_path}")


if __name__ == "__main__":
    # Command-line entry point. Defaults are read from a default ``Config`` so
    # the CLI and the dataclass cannot drift apart, and every tunable
    # hyperparameter of ``Config`` has a flag.
    defaults = Config()
    parser = argparse.ArgumentParser(description="Generated rationale + scalar claim verifier experiment")
    parser.add_argument("--num-train", type=int, default=defaults.num_train)
    parser.add_argument("--num-eval", type=int, default=defaults.num_eval)
    parser.add_argument("--batch-size", type=int, default=defaults.batch_size)
    parser.add_argument("--epochs", type=int, default=defaults.epochs)
    parser.add_argument("--lr", type=float, default=defaults.lr)
    parser.add_argument("--weight-decay", type=float, default=defaults.weight_decay)
    parser.add_argument("--d-model", type=int, default=defaults.d_model)
    parser.add_argument("--n-layers", type=int, default=defaults.n_layers)
    parser.add_argument("--n-heads", type=int, default=defaults.n_heads)
    parser.add_argument("--d-ff", type=int, default=defaults.d_ff)
    parser.add_argument("--dropout", type=float, default=defaults.dropout)
    parser.add_argument("--consistency-weight", type=float, default=defaults.consistency_weight)
    parser.add_argument("--rationale-loss-weight", type=float, default=defaults.rationale_loss_weight)
    parser.add_argument("--claim-loss-weight", type=float, default=defaults.claim_loss_weight)
    parser.add_argument("--output-csv", type=str, default=defaults.output_csv)
    parser.add_argument("--seed", type=int, default=defaults.seed)
    args = parser.parse_args()

    # Every argparse destination (dashes become underscores) is named after
    # the Config field it sets, so the namespace maps onto Config directly.
    cfg = Config(**vars(args))
    main(cfg)
