import asyncio
import gc
import json
from dataclasses import dataclass
from pathlib import Path

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from monarch.actor import Actor, endpoint

from .utils import slugify, load_model_and_tokenizer


@dataclass
class GenConfig:
    """Generation settings for one concept-generation request.

    Field order is part of the positional constructor signature and must not
    change. ``constrastive`` is a misspelled field name kept for backward
    compatibility; prefer reading it through the ``contrastive`` property.
    """

    # Whether to use the tokenizer's chat template, and the model name for
    # which that is allowed.
    model_instruct_mode: bool = True
    name_of_model_instruct: str = "google/gemma-3-27b-it"

    # True: negatives come from an explicit "opposite of the concept" prompt.
    # False: negatives are sampled from a single start token.
    constrastive: bool = True

    # Row counts per kind; a requested count must be a multiple of batch_size,
    # which is also the number of sequences returned per generate call.
    n_related: int = 500
    n_unrelated: int = 500
    batch_size: int = 100

    # Sampling parameters forwarded to ``generate``.
    max_new_tokens: int = 300
    temperature: float = 0.9
    top_k: int = 50
    top_p: float = 0.95

    # Base RNG seed; the caller offsets it by the rank.
    seed: int = 42

    @property
    def contrastive(self) -> bool:
        """Correctly spelled, read-only view of ``constrastive``."""
        return self.constrastive


class LLMActor(Actor):
    """
    One actor per GPU. Loads a single HF causal LM on its local device and can
    generate JSONLs per concept:

      - <concept>_positive.jsonl
      - <concept>_<model_slug>_negative.jsonl   (negative_model_tag given;
                                                 non-contrastive: one per negative model)
      - <concept>_negative.jsonl                (no negative_model_tag;
                                                 contrastive: single model)

    Every row is ``{"concept": ..., "kind": "positive"|"negative", "text": ...}``.

    Behavior is controlled by cfg.constrastive (the cfg dict may spell the key
    either "constrastive" or "contrastive"):

      - constrastive=False:
          * "related" is generated by the model passed at construction time
            (--model_generating_concept).
          * "unrelated" is generated by whatever model is currently loaded
            (reloaded via reload_model), unconditionally from BOS.

      - constrastive=True:
          * both "related" and "unrelated" are generated by the *same* concept chat model
            (--model_generating_concept), using explicit positive/negative prompts.
          * generate_prompts.py should NOT reload negative models in this mode.
    """

    # Accepted values for generate_for_concept(mode=...).
    _MODES = ("both", "related", "unrelated")

    def __init__(self, model_name: str):
        # New tensors created without an explicit device land on the GPU.
        torch.set_default_device("cuda")
        torch.backends.cuda.matmul.allow_tf32 = True
        self.model_name = model_name
        self.tok, self.model = load_model_and_tokenizer(model_name, dtype_str="bfloat16")

    @endpoint
    async def reload_model(self, model_name: str):
        """
        Reload the underlying HF model on this GPU.

        The previous model/tokenizer references are dropped and garbage
        collected before ``torch.cuda.empty_cache()``. Otherwise tensors kept
        alive only by reference cycles would still occupy GPU memory while
        the new weights load.

        Returns:
            dict with the newly loaded ``model_name``.
        """
        # Drop each reference independently. A missing attribute (e.g. after
        # a failed earlier load) must not keep the other one alive.
        for attr in ("model", "tok"):
            if hasattr(self, attr):
                delattr(self, attr)
        gc.collect()
        torch.cuda.empty_cache()

        self.tok, self.model = load_model_and_tokenizer(model_name, dtype_str="bfloat16")
        # Only record the new name once loading has succeeded.
        self.model_name = model_name
        return {"model_name": self.model_name}

    def _special_token_ids(self):
        """Return ``(start_id, pad_id)`` used for generation.

        ``start_id`` seeds unconditional generation: BOS, else EOS, else PAD.
        ``pad_id`` is PAD, else EOS, else ``start_id``.

        Raises:
            ValueError: if the tokenizer defines none of BOS/EOS/PAD.
        """
        tok = self.tok
        start_id = next(
            (t for t in (tok.bos_token_id, tok.eos_token_id, tok.pad_token_id) if t is not None),
            None,
        )
        if start_id is None:
            raise ValueError("Tokenizer must define at least one of BOS/EOS/PAD token ids.")
        pad_id = next(
            (t for t in (tok.pad_token_id, tok.eos_token_id) if t is not None),
            start_id,
        )
        return start_id, pad_id

    def _encode_prompt(self, chat_messages, fallback_text, use_chat):
        """Tokenize a single prompt and move it to the GPU.

        Uses the tokenizer's chat template on ``chat_messages`` when
        ``use_chat`` is true, otherwise tokenizes ``fallback_text`` as-is.

        Returns:
            ``(input_ids, attention_mask, prompt_len)`` where ``prompt_len`` is
            the token length of the encoded prompt.
        """
        if use_chat:
            text = self.tok.apply_chat_template(
                chat_messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        else:
            text = fallback_text

        enc = self.tok(text, return_tensors="pt")
        input_ids = enc["input_ids"].to("cuda")
        attention_mask = enc.get("attention_mask")
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, device="cuda")
        else:
            attention_mask = attention_mask.to("cuda")
        return input_ids, attention_mask, input_ids.shape[1]

    async def _write_jsonl(self, path, n_batches, generate_batch, concept, kind):
        """Call ``generate_batch`` ``n_batches`` times and write every text as a JSONL row.

        Args:
            path: output file; truncated and rewritten.
            n_batches: number of ``generate_batch`` calls.
            generate_batch: zero-argument callable returning a list of strings.
            concept: value stored under ``"concept"`` in each row.
            kind: value stored under ``"kind"`` in each row.

        Returns:
            Number of rows written.
        """
        produced = 0
        with open(path, "w", encoding="utf-8") as fh:
            for i in range(n_batches):
                texts = generate_batch()
                for t in texts:
                    fh.write(
                        json.dumps({"concept": concept, "kind": kind, "text": t.strip()})
                        + "\n"
                    )

                # Periodic flush so partial output is visible on disk mid-run.
                if i % 2 == 1:
                    fh.flush()
                produced += len(texts)
                if i % 5 == 4:
                    # Keep the mailbox responsive.
                    # NOTE(review): the caller holds torch.inference_mode()/autocast
                    # across this yield; confirm other coroutines running meanwhile
                    # on this thread are fine under that state.
                    await asyncio.sleep(0)
        return produced

    @endpoint
    async def generate_for_concept(
        self,
        concept: str,
        cfg_dict: dict,
        out_dir: str,
        rank_hint: int,
        mode: str = "both",  # "related", "unrelated", or "both"
        negative_model_tag=None,
    ):
        """
        Generate positive ("related") and/or negative ("unrelated") samples for ``concept``.

        Args:
            concept: concept name; used in the prompts, in the (slugified)
                output filenames, and in every JSONL row.
            cfg_dict: keyword arguments for ``GenConfig``; ``"contrastive"`` is
                accepted as an alias for ``"constrastive"``.
            out_dir: output directory; created if missing.
            rank_hint: added to ``cfg.seed`` to derive this call's RNG seed.
            mode: ``"related"``, ``"unrelated"`` or ``"both"``.
            negative_model_tag: when truthy, slugified into the negative filename.

        Returns:
            dict with ``rank``, ``concept``, the row counts ``related`` and
            ``unrelated``, and the list of ``files`` written.

        Raises:
            ValueError: unknown ``mode``, non-positive ``batch_size``, a
                requested count that is negative or not a multiple of
                ``batch_size``, or a tokenizer with no BOS/EOS/PAD id.
        """
        if mode not in self._MODES:
            raise ValueError(f"mode must be one of {self._MODES}, got {mode!r}")
        want_rel = mode in ("both", "related")
        want_unrel = mode in ("both", "unrelated")

        cfg_dict = dict(cfg_dict or {})
        if "contrastive" in cfg_dict and "constrastive" not in cfg_dict:
            cfg_dict["constrastive"] = cfg_dict.pop("contrastive")

        cfg = GenConfig(**cfg_dict)

        # Explicit checks rather than `assert`, so they still run under `python -O`.
        if cfg.batch_size <= 0:
            raise ValueError("batch_size must be positive")
        if want_rel and (cfg.n_related < 0 or cfg.n_related % cfg.batch_size != 0):
            raise ValueError("n_related must be a non-negative multiple of batch_size")
        if want_unrel and (cfg.n_unrelated < 0 or cfg.n_unrelated % cfg.batch_size != 0):
            raise ValueError("n_unrelated must be a non-negative multiple of batch_size")

        seed = cfg.seed + int(rank_hint)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

        out_dir_p = Path(out_dir)
        out_dir_p.mkdir(parents=True, exist_ok=True)
        s = slugify(concept)

        path_rel = out_dir_p / f"{s}_positive.jsonl"

        if negative_model_tag:
            model_slug = slugify(negative_model_tag)
            path_unrel = out_dir_p / f"{s}_{model_slug}_negative.jsonl"
        else:
            path_unrel = out_dir_p / f"{s}_negative.jsonl"

        pos_plain_prompt = f"""Describe the concept {concept} precisely and concisely.
Use cautious language, and clearly distinguish between established facts and interpretations.
Structure the response as: definition then example."""

        pos_chat_messages = [
            {"role": "system", "content": "You are an assistant."},
            {
                "role": "user",
                "content": (
                    f"Describe the concept: {concept}. Begin with a one-sentence definition. "
                    "Then give several natural first-person examples that demonstrate the concept in an unmistakable way. "
                    "Write as if you are the concept itself. Your examples must make the concept obvious through what you say or do "
                    "— for example, if the concept is *lying*, give clearly false statements (e.g., claiming impossible facts or events).\n"
                    "Do not use third-person references. Do not explain anything. "
                    "After the definition, output only the examples as simple first-person sentences that fully embody the concept."
                ),
            },
        ]

        neg_plain_prompt = f"""Describe the opposite / negative of the concept {concept} precisely and concisely.
Structure the response as: one-sentence definition then several first-person examples."""

        neg_chat_messages = [
            {"role": "system", "content": "You are an assistant."},
            {
                "role": "user",
                "content": (
                    f"Describe the opposite of the concept: {concept}. Begin with a one-sentence definition. "
                    "Then give several natural first-person examples that demonstrate the opposite in an unmistakable way. "
                    "Write as if you are the opposite of the concept itself.\n"
                    "Do not use third-person references. Do not explain anything. "
                    "After the definition, output only the examples as simple first-person sentences."
                ),
            },
        ]

        # The chat template is used only for the configured instruct model.
        has_chat = bool(getattr(self.tok, "chat_template", None))
        use_chat = bool(
            cfg.model_instruct_mode
            and has_chat
            and (self.model_name == cfg.name_of_model_instruct)
        )

        pos_input_ids = pos_attention_mask = None
        pos_prompt_len = 0
        if want_rel:
            pos_input_ids, pos_attention_mask, pos_prompt_len = self._encode_prompt(
                pos_chat_messages, pos_plain_prompt, use_chat
            )

        neg_input_ids = neg_attention_mask = None
        neg_prompt_len = 0
        if cfg.constrastive and want_unrel:
            neg_input_ids, neg_attention_mask, neg_prompt_len = self._encode_prompt(
                neg_chat_messages, neg_plain_prompt, use_chat
            )

        start_id, pad_id = self._special_token_ids()
        start = torch.tensor([[start_id]], device="cuda", dtype=torch.long)
        # Explicit all-ones mask. When pad_id == start_id (e.g. BOS == EOS and
        # no PAD), HF `generate` would otherwise infer the mask from padding and
        # mask out the only prompt token.
        start_mask = torch.ones_like(start)

        gen_kwargs = dict(
            max_new_tokens=cfg.max_new_tokens,
            do_sample=True,
            temperature=cfg.temperature,
            top_k=cfg.top_k,
            top_p=cfg.top_p,
            num_return_sequences=cfg.batch_size,
            use_cache=True,
            eos_token_id=None,  # bounded by max_new_tokens
            pad_token_id=pad_id,
        )

        def gen_positive():
            out_ids = self.model.generate(
                input_ids=pos_input_ids,
                attention_mask=pos_attention_mask,
                **gen_kwargs,
            )
            # Keep only the newly generated continuation, not the prompt.
            return self.tok.batch_decode(out_ids[:, pos_prompt_len:], skip_special_tokens=True)

        def gen_negative_contrastive():
            out_ids = self.model.generate(
                input_ids=neg_input_ids,
                attention_mask=neg_attention_mask,
                **gen_kwargs,
            )
            return self.tok.batch_decode(out_ids[:, neg_prompt_len:], skip_special_tokens=True)

        def gen_negative_unconditional():
            out_ids = self.model.generate(
                input_ids=start, attention_mask=start_mask, **gen_kwargs
            )
            if out_ids.shape[1] > 1:
                out_ids = out_ids[:, 1:]  # drop the start token for cleaner strings
            return self.tok.batch_decode(out_ids, skip_special_tokens=True)

        produced_rel = 0
        produced_unrel = 0
        files = []
        amp = torch.autocast(device_type="cuda", dtype=torch.bfloat16)

        with torch.inference_mode(), amp:
            if want_rel:
                produced_rel = await self._write_jsonl(
                    path_rel,
                    cfg.n_related // cfg.batch_size,
                    gen_positive,
                    concept,
                    "positive",
                )
                files.append(str(path_rel))

            if want_unrel:
                gen_negative = (
                    gen_negative_contrastive if cfg.constrastive else gen_negative_unconditional
                )
                produced_unrel = await self._write_jsonl(
                    path_unrel,
                    cfg.n_unrelated // cfg.batch_size,
                    gen_negative,
                    concept,
                    "negative",
                )
                files.append(str(path_unrel))

        torch.cuda.empty_cache()
        return {
            "rank": int(rank_hint),
            "concept": concept,
            "related": produced_rel,
            "unrelated": produced_unrel,
            "files": files,
        }
