#!/usr/bin/env python3
"""Models and tokenization for LeanCheck."""

from __future__ import annotations

import json
import re
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple

import torch
import torch.nn as nn

from leancheck_data import SPECIAL_TOKENS


# Two-way mapping between label names and their integer class ids.
LABEL_TO_ID = dict(VERIFIES=1, FAILS=0)
ID_TO_LABEL = dict(zip(LABEL_TO_ID.values(), LABEL_TO_ID))


class SimpleTokenizer:
    """Small deterministic tokenizer for offline smoke tests.

    Holds a ``token -> id`` vocabulary that starts from ``SPECIAL_TOKENS``
    (unless one is supplied) and can be grown with :meth:`fit`. The
    vocabulary must contain ``"[PAD]"`` and ``"[EOS]"``; their ids are read
    at construction time.
    """

    def __init__(self, vocab: Optional[Dict[str, int]] = None):
        """Create a tokenizer.

        Args:
            vocab: Existing ``token -> id`` mapping. When ``None`` or empty,
                a fresh vocabulary numbering ``SPECIAL_TOKENS`` from 0 is
                built instead.

        Raises:
            KeyError: If the vocabulary lacks ``"[PAD]"`` or ``"[EOS]"``.
        """
        # Copy the supplied mapping so that fit() never mutates a dict the
        # caller still holds a reference to.
        self.vocab = dict(vocab) if vocab else {tok: i for i, tok in enumerate(SPECIAL_TOKENS)}
        self.pad_token = "[PAD]"
        self.pad_token_id = self.vocab[self.pad_token]
        self.eos_token = "[EOS]"
        self.eos_token_id = self.vocab[self.eos_token]

    @staticmethod
    def split(text: str) -> List[str]:
        """Split ``text`` into tokens.

        Alternatives are tried in order: a special token, an identifier
        (letters, digits, ``_``, ``'`` and ``.``), a run of digits, one of
        ``:=``, ``=>``, ``->``, and finally any single non-whitespace
        character. Whitespace is dropped.
        """
        specials = "|".join(re.escape(t) for t in SPECIAL_TOKENS)
        return re.findall(rf"{specials}|[A-Za-z_][A-Za-z0-9_'.]*|[0-9]+|:=|=>|->|[^\s]", text)

    def fit(self, texts: Iterable[str]) -> "SimpleTokenizer":
        """Add every unseen token in ``texts`` to the vocabulary.

        New tokens receive the next free id (the current vocabulary size).

        Returns:
            ``self``, so calls can be chained.
        """
        for text in texts:
            for tok in self.split(text):
                if tok not in self.vocab:
                    self.vocab[tok] = len(self.vocab)
        return self

    def encode(self, text: str, max_length: int) -> Dict[str, torch.Tensor]:
        """Encode ``text`` to exactly ``max_length`` ids.

        The token sequence is truncated on the right, then right-padded with
        the pad id. ``attention_mask`` is 1 for real tokens and 0 for padding.

        Args:
            text: Text to tokenize.
            max_length: Exact output length; must be non-negative.

        Returns:
            Dict with 1-D integer tensors ``"input_ids"`` and
            ``"attention_mask"``, each of length ``max_length``.

        Raises:
            ValueError: If ``max_length`` is negative.
        """
        # A negative length would slice tokens off the *end* and return a
        # sequence of the wrong size, so reject it up front.
        if max_length < 0:
            raise ValueError(f"max_length must be non-negative, got {max_length}")
        # NOTE(review): out-of-vocabulary tokens are mapped to the EOS id, so
        # an EOS id in the output does not necessarily mark end-of-sequence.
        unknown_id = self.eos_token_id
        ids = [self.vocab.get(tok, unknown_id) for tok in self.split(text)][:max_length]
        n_real = len(ids)
        n_pad = max_length - n_real
        ids.extend([self.pad_token_id] * n_pad)
        attn = [1] * n_real + [0] * n_pad
        # Explicit dtype: torch.tensor([]) would otherwise infer float32 when
        # max_length is 0.
        return {
            "input_ids": torch.tensor(ids, dtype=torch.long),
            "attention_mask": torch.tensor(attn, dtype=torch.long),
        }

    def save_pretrained(self, path: str) -> None:
        """Write the vocabulary to ``<path>/simple_tokenizer.json``.

        The directory (and any missing parents) is created if needed.
        """
        Path(path).mkdir(parents=True, exist_ok=True)
        (Path(path) / "simple_tokenizer.json").write_text(json.dumps(self.vocab, indent=2), encoding="utf-8")

    @classmethod
    def from_pretrained(cls, path: str) -> "SimpleTokenizer":
        """Load a tokenizer from ``<path>/simple_tokenizer.json``."""
        vocab = json.loads((Path(path) / "simple_tokenizer.json").read_text(encoding="utf-8"))
        return cls(vocab)

    def convert_tokens_to_ids(self, tok: str) -> int:
        """Return the id of ``tok``.

        Raises:
            KeyError: If ``tok`` is not in the vocabulary.
        """
        return self.vocab[tok]

    def __len__(self) -> int:
        """Return the number of tokens in the vocabulary."""
        return len(self.vocab)


class TinyCausalBackbone(nn.Module):
    """Minimal causal transformer language model.

    Token and learned absolute-position embeddings are summed, passed through
    a stack of ``nn.TransformerEncoderLayer`` blocks under a causal mask, and
    projected to vocabulary logits by a linear head.

    Args:
        vocab_size: Number of rows in the token embedding and width of the
            output logits.
        hidden_size: Model width; also sets the feed-forward width
            (``4 * hidden_size``).
        n_layers: Number of encoder layers.
        n_heads: Number of attention heads.
        max_positions: Size of the position-embedding table, i.e. the longest
            sequence ``forward`` accepts.
    """

    def __init__(
        self,
        vocab_size: int,
        hidden_size: int = 128,
        n_layers: int = 2,
        n_heads: int = 4,
        max_positions: int = 512,
    ):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.pos = nn.Embedding(max_positions, hidden_size)
        layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=n_heads,
            dim_feedforward=hidden_size * 4,
            dropout=0.1,
            batch_first=True,
            activation="gelu",
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.lm_head = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the backbone.

        Args:
            input_ids: 2-D ``(batch, seq)`` tensor of token ids.
            attention_mask: Same shape as ``input_ids``; positions equal to 0
                are treated as padding keys.

        Returns:
            ``(logits, hidden)`` where ``hidden`` is the encoder output and
            ``logits`` is ``lm_head(hidden)``.

        Raises:
            ValueError: If the sequence is longer than the position table.
        """
        _, seqlen = input_ids.shape
        # Fail with a readable message instead of an out-of-range embedding
        # lookup deep inside nn.Embedding.
        if seqlen > self.pos.num_embeddings:
            raise ValueError(
                f"sequence length {seqlen} exceeds the {self.pos.num_embeddings} "
                "positions supported by this backbone"
            )
        positions = torch.arange(seqlen, device=input_ids.device)
        # (seq, hidden) position embeddings broadcast over the batch dimension.
        x = self.embed(input_ids) + self.pos(positions)
        # True above the diagonal: each position may not attend to later ones.
        causal = torch.triu(torch.ones(seqlen, seqlen, device=input_ids.device), diagonal=1).bool()
        padding = attention_mask == 0
        h = self.encoder(x, mask=causal, src_key_padding_mask=padding)
        return self.lm_head(h), h


class LeanCheckModel(nn.Module):
    """Causal LM backbone plus a two-way head pooled over a token span.

    The backbone is the local ``TinyCausalBackbone`` when ``model_name`` is
    ``"tiny-local"``; otherwise it is loaded with
    ``transformers.AutoModelForCausalLM`` and its embeddings are resized to
    ``len(tokenizer)``. The final hidden states are mean-pooled over a span
    chosen by ``variant`` (see :meth:`span_mask`) and fed to a linear layer
    with two outputs.

    Args:
        model_name: ``"tiny-local"`` or a name/path accepted by
            ``AutoModelForCausalLM.from_pretrained``.
        tokenizer: Object supporting ``len()`` and
            ``convert_tokens_to_ids(token)``.
        variant: Name of the pooling-span variant.
        hidden_size: Width of the tiny backbone (ignored otherwise).
        tiny_layers: Number of layers in the tiny backbone (ignored otherwise).
    """

    def __init__(
        self,
        model_name: str,
        tokenizer,
        variant: str,
        hidden_size: int = 128,
        tiny_layers: int = 2,
    ):
        super().__init__()
        self.variant = variant
        self.is_tiny = model_name == "tiny-local"
        if self.is_tiny:
            self.backbone = TinyCausalBackbone(len(tokenizer), hidden_size=hidden_size, n_layers=tiny_layers)
            head_dim = hidden_size
        else:
            # Imported lazily so the tiny path does not require transformers.
            from transformers import AutoModelForCausalLM

            self.backbone = AutoModelForCausalLM.from_pretrained(model_name)
            self.backbone.resize_token_embeddings(len(tokenizer))
            head_dim = self.backbone.config.hidden_size
        self.consistency_head = nn.Linear(head_dim, 2)
        # Remember the marker ids of the construction-time tokenizer so that
        # forward() still works when it is called without a tokenizer.
        self._default_markers = self.marker_ids(tokenizer)

    def marker_ids(self, tokenizer) -> Dict[str, int]:
        """Return ``{special token: id}`` for every entry of ``SPECIAL_TOKENS``."""
        return {tok: tokenizer.convert_tokens_to_ids(tok) for tok in SPECIAL_TOKENS}

    @staticmethod
    def span_mask(input_ids: torch.Tensor, attention_mask: torch.Tensor, marker: Dict[str, int], variant: str) -> torch.Tensor:
        """Build the boolean mask of positions to pool for each sequence.

        Marker positions are the first occurrence of ``[THEOREM]``,
        ``[PROOF]``, ``[RAT]`` and ``[CLAIM]`` in the row; a missing marker
        falls back to the position of the previous one (``[THEOREM]`` falls
        back to 0). ``[EOS]`` is located at or after the ``[CLAIM]`` position
        and falls back to the sequence length.

        Spans (half-open ``[start, end)``) per variant:

        * ``"full_sequence"``: after ``[THEOREM]`` up to ``[CLAIM]``
        * ``"proof_only"``: after ``[PROOF]`` up to ``[RAT]``
        * ``"claim_only"``: ``[CLAIM]`` and at most one following token,
          stopping before ``[EOS]``
        * ``"wrong_span"``: after ``[THEOREM]`` up to ``[PROOF]``
        * anything else: after ``[RAT]`` up to ``[CLAIM]``

        The span is clamped to the sequence and always covers at least one
        position before padding is removed.

        Args:
            input_ids: 2-D ``(batch, seq)`` tensor of token ids.
            attention_mask: Same shape; zero entries are excluded from the
                result.
            marker: Mapping from special-token string to id.
            variant: Span variant name.

        Returns:
            Boolean tensor shaped like ``attention_mask``.
        """

        def locate(ids: list, tok: str, default: int, begin: int = 0) -> int:
            # Index of the first occurrence of the marker at or after `begin`.
            try:
                return ids.index(marker[tok], begin)
            except ValueError:
                return default

        _, seqlen = input_ids.shape
        mask = torch.zeros_like(attention_mask, dtype=torch.bool)
        # Convert the whole batch at once: a single host transfer instead of
        # one per row.
        for i, ids in enumerate(input_ids.tolist()):
            th = locate(ids, "[THEOREM]", 0)
            pr = locate(ids, "[PROOF]", th)
            rat = locate(ids, "[RAT]", pr)
            claim = locate(ids, "[CLAIM]", rat)
            # Search from the claim onwards: an EOS id appearing earlier in
            # the row must not truncate the claim span.
            eos = locate(ids, "[EOS]", seqlen, begin=claim)
            if variant == "full_sequence":
                start, end = th + 1, claim
            elif variant == "proof_only":
                start, end = pr + 1, rat
            elif variant == "claim_only":
                start, end = claim, min(eos, claim + 2)
            elif variant == "wrong_span":
                start, end = th + 1, pr
            else:
                # "rationale_only", "random_consistency", "no_consistency_loss",
                # "lm_only" and any unrecognised variant pool the rationale.
                start, end = rat + 1, claim
            start = max(0, min(start, seqlen - 1))
            end = max(start + 1, min(end, seqlen))
            mask[i, start:end] = True
        return mask & attention_mask.bool()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        cls_labels: Optional[torch.Tensor] = None,
        tokenizer=None,
    ) -> Dict[str, torch.Tensor]:
        """Compute LM logits, span-pooled class logits and optional losses.

        Args:
            input_ids: 2-D ``(batch, seq)`` tensor of token ids.
            attention_mask: Same shape; zeros mark padding.
            labels: Optional next-token targets aligned with ``input_ids``;
                entries equal to ``-100`` are ignored by the loss.
            cls_labels: Optional class indices for the two-way head.
            tokenizer: Optional tokenizer used to look up marker ids. When
                omitted, the ids recorded at construction are used.

        Returns:
            Dict with ``"logits"`` and ``"cls_logits"``, plus ``"lm_loss"``
            when ``labels`` is given and ``"cons_loss"`` when ``cls_labels``
            is given.

        Raises:
            ValueError: If no tokenizer is passed and no marker ids were
                recorded on this instance.
        """
        if self.is_tiny:
            logits, hidden = self.backbone(input_ids, attention_mask)
        else:
            out = self.backbone(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
                return_dict=True,
            )
            logits = out.logits
            hidden = out.hidden_states[-1]

        if tokenizer is not None:
            marker = self.marker_ids(tokenizer)
        else:
            marker = getattr(self, "_default_markers", None)
            if marker is None:
                raise ValueError("forward() needs a tokenizer because no marker ids were recorded at construction")
        smask = self.span_mask(input_ids, attention_mask, marker, self.variant)
        # Mean over the selected positions; clamp avoids division by zero
        # when the whole span falls on padding.
        denom = smask.sum(dim=1).clamp_min(1).unsqueeze(-1)
        pooled = (hidden * smask.unsqueeze(-1)).sum(dim=1) / denom
        cls_logits = self.consistency_head(pooled)
        result = {"logits": logits, "cls_logits": cls_logits}
        if labels is not None:
            # Position t predicts token t + 1.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            result["lm_loss"] = nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )
        if cls_labels is not None:
            result["cons_loss"] = nn.functional.cross_entropy(cls_logits, cls_labels)
        return result
