# -*- coding: utf-8 -*-
import argparse, json, math, random, os
from pathlib import Path
from typing import Dict, List

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import RobertaTokenizerFast, RobertaModel, AdamW

# ------------------------------------------------------------------
# Dataset
# ------------------------------------------------------------------
class AttackLogDataset(Dataset):
    """Load (prompt, response, thinking, logits, success) tuples.

    Each record is a JSON object. Its ``prompt``/``response``/``thinking``
    fields are joined with ``" <sep> "`` (empty/missing fields skipped) and
    the whole file is tokenized in a single tokenizer call. Its ``logits``
    list is reduced to its mean (0.0 when missing or empty). Its ``success``
    field (default 0) becomes an integer label.

    Args:
        file_path: Path to a ``.jsonl`` file (one object per line) or a
            ``.json`` file whose top level is a list of objects.
        tokenizer: Tokenizer called once on all texts with
            ``truncation=True, padding=True, return_tensors="pt"``.
        max_len: ``max_length`` passed to the tokenizer.

    Raises:
        FileNotFoundError: if ``file_path`` does not exist.
        ValueError: for an unsupported suffix, or a ``.json`` file whose top
            level is not a list.
    """

    def __init__(self, file_path: str, tokenizer: "RobertaTokenizerFast", max_len: int = 512):
        self.file_path = Path(file_path)
        self.samples = self._load(self.file_path)

        texts = [self._concat_text(s) for s in self.samples]
        self.enc = tokenizer(texts, truncation=True, padding=True, max_length=max_len, return_tensors="pt")
        self.logit_mean = torch.tensor([self._mean_logits(s) for s in self.samples]).unsqueeze(1)  # (N,1)
        self.success = torch.tensor([s.get("success", 0) for s in self.samples], dtype=torch.long)  # (N,)

    # --------------------------------------------------------------
    def _load(self, p: Path) -> List[Dict]:
        """Read the records stored in *p* (decoded as UTF-8)."""
        if not p.exists():
            raise FileNotFoundError(p)
        suffix = p.suffix.lower()
        if suffix == ".jsonl":
            # Iterate file lines rather than str.splitlines(): splitlines also
            # breaks on U+2028/U+2029/U+0085, which JSON permits unescaped
            # inside strings (json.dumps(..., ensure_ascii=False) emits them).
            with p.open(encoding="utf-8") as f:
                return [json.loads(line) for line in f if line.strip()]
        if suffix == ".json":
            data = json.loads(p.read_text(encoding="utf-8"))
            if not isinstance(data, list):
                raise ValueError(f"{p}: expected a JSON list of records, got {type(data).__name__}")
            return data
        raise ValueError("Only .jsonl/.json files are supported")

    def _concat_text(self, sample):
        """Join the non-empty prompt/response/thinking fields with " <sep> "."""
        parts = [sample.get(k, "") for k in ("prompt", "response", "thinking")]
        return " <sep> ".join([p for p in parts if p])

    def _mean_logits(self, sample):
        """Return the arithmetic mean of ``sample["logits"]``, or 0.0 if absent/empty."""
        logits = sample.get("logits") or []
        return float(sum(logits) / len(logits)) if logits else 0.0

    # --------------------------------------------------------------
    def __len__(self):
        return len(self.success)

    def __getitem__(self, idx):
        """Return row *idx* of every tokenizer output plus its logit mean and label."""
        item = {k: v[idx] for k, v in self.enc.items()}
        item["logit_mean"] = self.logit_mean[idx]
        item["success"] = self.success[idx]
        return item

# ------------------------------------------------------------------
# Mutual‑Information Critic  fθ(z,t)  (InfoNCE)
# ------------------------------------------------------------------
class Critic(nn.Module):
    """Separable InfoNCE critic: score(z_i, t_j) = <z_i, e(t_j)>.

    ``z`` combines a frozen ``roberta-base`` encoding of the sample text with
    a learned projection of the sample's mean logit; ``t`` is a learned
    embedding of the success label (two rows, so labels must be 0 or 1).

    Args:
        z_dim: Width of the shared space holding ``z`` and label embeddings.
        logit_dim: Width of the hidden projection of ``logit_mean``.
    """

    def __init__(self, z_dim=768, logit_dim=32):
        super().__init__()
        self.text_encoder = RobertaModel.from_pretrained("roberta-base")
        # The encoder is used as a fixed feature extractor; only heads train.
        for p in self.text_encoder.parameters():
            p.requires_grad = False
        # Take the input width from the encoder itself: hard-coding z_dim
        # here broke any z_dim other than the encoder's hidden size.
        hidden_size = self.text_encoder.config.hidden_size
        self.logit_fc = nn.Linear(1, logit_dim)
        self.z_proj = nn.Linear(hidden_size + logit_dim, z_dim)
        self.t_embed = nn.Embedding(2, z_dim)

    def encode_z(self, input_ids, attention_mask, logit_mean):
        """Return ReLU(z_proj([text features ; logit_fc(logit_mean)])), shape (B, z_dim).

        ``logit_mean`` is expected as a (B, 1) float tensor, as built by the
        dataset in this file.
        """
        # NOTE(review): pooler_output runs the first-token state through
        # RoBERTa's pooler head, which stays frozen here; confirm the
        # checkpoint ships trained pooler weights, or consider
        # last_hidden_state[:, 0] instead.
        text_feat = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask).pooler_output
        logit_feat = self.logit_fc(logit_mean)  # (B, logit_dim)
        return torch.relu(self.z_proj(torch.cat([text_feat, logit_feat], dim=-1)))

    def forward(self, batch):
        """Score every z in the batch against every label; returns (B, B).

        Entry [i, j] is <z_i, e(success_j)>, so the diagonal holds the
        matched (positive) pairs used by InfoNCE.
        """
        z = self.encode_z(batch["input_ids"], batch["attention_mask"], batch["logit_mean"])
        t = self.t_embed(batch["success"])  # (B, z_dim)
        return z @ t.T

# ------------------------------------------------------------------
# InfoNCE helpers
# ------------------------------------------------------------------

def info_nce_loss(scores: torch.Tensor) -> torch.Tensor:
    """Mean InfoNCE loss over a score matrix whose diagonal holds positives.

    For row i the loss is ``logsumexp_j scores[i, j] - scores[i, i]``, i.e.
    the cross-entropy of picking the matched column; the rows are averaged.
    """
    matched = scores.diagonal()
    normaliser = scores.logsumexp(dim=1)
    return (normaliser - matched).mean()


def batch_leak_bits(scores: torch.Tensor) -> torch.Tensor:
    """Per-sample InfoNCE lower-bound terms for I(Z;T), in bits.

    Args:
        scores: (B, K) critic matrix whose diagonal holds the matched pairs
            (square in this file).

    Returns:
        1-D tensor with one entry per diagonal element:
        ``(log K + scores[i, i] - logsumexp_j scores[i, j]) / ln 2``.
        Its mean is the InfoNCE lower bound on I(Z;T), which can never exceed
        ``log2(K)`` bits.
    """
    if scores.numel() == 0:
        return scores.new_zeros(0)
    num_candidates = scores.shape[1]
    pos = torch.diag(scores)
    denom = torch.logsumexp(scores, dim=1)
    # InfoNCE: I(Z;T) >= log K - L_NCE. Without the log K term every value
    # is <= 0 and cannot be read as a mutual-information estimate.
    return (pos - denom + math.log(num_candidates)) / math.log(2)  # nats→bits

# ------------------------------------------------------------------
# Training Loop
# ------------------------------------------------------------------

def _fit_critic(critic, ds, args, device):
    """Train *critic* in place on *ds* with the InfoNCE loss for ``args.epochs`` epochs."""
    # drop_last avoids a ragged final batch, but only when at least one full
    # batch exists; otherwise a dataset smaller than batch_size would run
    # zero optimisation steps without any warning.
    loader = DataLoader(
        ds,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=len(ds) >= args.batch_size,
    )
    # The RoBERTa encoder is frozen, so give the optimizer only trainable
    # parameters. torch.optim.AdamW replaces the deprecated transformers
    # AdamW; eps and weight_decay mirror that class's defaults (1e-6, 0.0).
    trainable = [p for p in critic.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(trainable, lr=args.lr, eps=1e-6, weight_decay=0.0)

    critic.train()
    global_step = 0
    for _ in range(args.epochs):
        for batch in loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            loss = info_nce_loss(critic(batch))
            opt.zero_grad()
            loss.backward()
            opt.step()

            if global_step % 100 == 0:
                print(f"step={global_step}\tloss={loss.item():.4f}")
            global_step += 1


def _estimate_leak_bits(critic, ds, batch_size, device):
    """Return one ``batch_leak_bits`` value per sample, in dataset order."""
    critic.eval()
    bits = []
    # DataLoader does not shuffle by default, so results stay aligned with ds.samples.
    with torch.no_grad():
        for batch in DataLoader(ds, batch_size=batch_size):
            batch = {k: v.to(device) for k, v in batch.items()}
            bits.extend(batch_leak_bits(critic(batch)).cpu().tolist())
    return bits


def _save_outputs(critic, ds, all_bits, mi_dataset, args):
    """Write per-sample estimates next to the input, plus weights/config to output_dir."""
    # 1) save jsonl with leak_bits
    out_jsonl = Path(args.data_file).with_suffix(".with_bits.jsonl")
    with out_jsonl.open("w", encoding="utf-8") as f:
        for sample, bits in zip(ds.samples, all_bits):
            rec = {**sample, "leak_bits_est": bits}
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")
    print(f"Per‑sample estimates written to {out_jsonl}")

    # 2) save model + config
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    model_path = out_dir / "critic.pt"
    torch.save(critic.state_dict(), model_path)
    config = {
        "data_file": str(Path(args.data_file).resolve()),
        "epochs": args.epochs,
        "batch_size": args.batch_size,
        "lr": args.lr,
        "mi_bits": mi_dataset
    }
    with (out_dir / "config.json").open("w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)
    print(f"Model saved to {model_path} (config.json included)")


def train(args):
    """Fit the InfoNCE critic on an attack log, then estimate per-sample leakage.

    Steps:
      1. Load ``args.data_file`` and train a ``Critic`` for ``args.epochs``
         epochs with the InfoNCE loss.
      2. Score every sample once, in file order, with ``batch_leak_bits``;
         print the mean as the dataset-level I(Z;T) estimate.
      3. Write the records plus a ``leak_bits_est`` field to
         ``<data_file>.with_bits.jsonl`` (suffix replaced), and save the critic
         weights and run config under ``args.output_dir``.

    Args:
        args: Parsed CLI namespace; reads ``data_file``, ``output_dir``,
            ``epochs``, ``batch_size``, ``lr`` and ``device``.

    Raises:
        ValueError: if the data file yields no samples.
    """
    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
    ds = AttackLogDataset(args.data_file, tokenizer)
    if len(ds) == 0:
        raise ValueError(f"No samples found in {args.data_file}")

    device = torch.device(args.device)
    critic = Critic().to(device)
    _fit_critic(critic, ds, args, device)

    # ----------------------------------------------------------
    # Estimate per-sample leak_bits + dataset MI
    all_bits = _estimate_leak_bits(critic, ds, args.batch_size, device)
    mi_dataset = sum(all_bits) / len(all_bits)
    print(f"\nEstimated I(Z;T) ≈ {mi_dataset:.4f} bits")

    _save_outputs(critic, ds, all_bits, mi_dataset, args)

# ------------------------------------------------------------------
# CLI
# ------------------------------------------------------------------

def main():
    """Build the command-line interface, parse ``sys.argv`` and run ``train``."""
    parser = argparse.ArgumentParser(description="MI estimation from attack logs (InfoNCE)")
    fallback_device = "cuda" if torch.cuda.is_available() else "cpu"
    # (flag, add_argument keyword options), registered in this order.
    option_specs = (
        ("--data_file", {"required": True, "help": "Path to JSONL attack log"}),
        ("--output_dir", {"default": "model", "help": "Directory to save model + config"}),
        ("--epochs", {"type": int, "default": 2}),
        ("--batch_size", {"type": int, "default": 32}),
        ("--lr", {"type": float, "default": 1e-4}),
        ("--device", {"default": fallback_device}),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    train(parser.parse_args())


# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
