from typing import List, Optional, Tuple, Iterator, TYPE_CHECKING
import json
import re
from pathlib import Path
import torch.nn as nn
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset

if TYPE_CHECKING:
    from .cross_entropy_actor import CrossEntropyPlotConfig



def discover_concepts(in_dir: Path):
    """Find concepts that have both positive and negative prompt files in ``in_dir``.

    A slug ``s`` qualifies when ``<s>_positive.jsonl`` exists and at least one
    ``*_negative.jsonl`` file maps to ``s``. For a negative file, the part of the
    name before ``_negative.jsonl`` is split on its last ``_`` and only the left
    piece is kept (a name with no ``_`` is kept whole), so ``anger_v1_negative.jsonl``
    maps to ``anger``.

    The label is the stripped ``"concept"`` string from the FIRST line of the
    positive file when that line is a JSON object with a string ``"concept"``
    field. Otherwise (empty file, unreadable file, invalid JSON, missing/blank
    field) it falls back to the slug with ``_`` replaced by spaces.

    Returns:
        A list of ``(slug, label)`` tuples sorted by slug.
    """
    positive_slugs = {
        path.name[: -len("_positive.jsonl")]
        for path in in_dir.glob("*_positive.jsonl")
    }
    negative_slugs = {
        path.name[: -len("_negative.jsonl")].rsplit("_", 1)[0]
        for path in in_dir.glob("*_negative.jsonl")
    }

    def _first_line_label(slug):
        # Best-effort: any I/O or JSON problem simply means "no label".
        probe = in_dir / f"{slug}_positive.jsonl"
        if not probe.exists():
            return None
        try:
            with open(probe, "r", encoding="utf-8") as fh:
                first_line = next(fh, None)
                if first_line is None:
                    return None
                row = json.loads(first_line)
        except Exception:
            return None
        if isinstance(row, dict) and isinstance(row.get("concept"), str):
            return row["concept"].strip()
        return None

    return [
        (slug, _first_line_label(slug) or slug.replace("_", " "))
        for slug in sorted(positive_slugs & negative_slugs)
    ]

def slugify(s: str) -> str:
    """Lower-case ``s`` and collapse every run of non-ASCII-alphanumerics into ``-``.

    Leading/trailing dashes are trimmed; an empty result becomes ``"concept"``.
    """
    lowered = s.strip().lower()
    dashed = re.sub(r"[^A-Za-z0-9]+", "-", lowered)
    return dashed.strip("-") or "concept"

def model_slug(name: str) -> str:
    """Turn a model id such as ``"org/model-name"`` into a filesystem-safe slug."""
    return slugify("-".join(name.split("/")))

def read_lines(path: Path) -> List[str]:
    """Return the non-blank lines of a UTF-8 text file, each stripped of surrounding whitespace."""
    with open(path, "r", encoding="utf-8") as fh:
        stripped = (raw.strip() for raw in fh)
        return [text for text in stripped if text]


def read_jsonl_texts(path: Path, n_prompts=None, text_key="text"):
    """Collect the ``text_key`` string from each JSON line of ``path``.

    - Returns ``[]`` when ``path`` does not exist.
    - ``n_prompts`` caps the number of *physical lines read* (blank and invalid
      lines count toward the cap), not the number of texts returned.
    - Lines that fail to parse, or that are not JSON objects, are skipped.
    - Objects without ``text_key`` contribute ``""`` (empty strings are kept);
      non-string values are skipped.
    """

    def _extract(raw_line):
        # None means "skip this line"; "" is a legitimate (kept) value.
        try:
            value = json.loads(raw_line).get(text_key, "")
        except Exception:
            return None
        return value if isinstance(value, str) else None

    texts = []
    if path.exists():
        with open(path, "r", encoding="utf-8") as fh:
            for line_no, raw_line in enumerate(fh):
                if n_prompts is not None and line_no >= n_prompts:
                    break
                value = _extract(raw_line)
                if value is not None:
                    texts.append(value)
    return texts


def load_contexts_for_concept(contexts_file: str, concept_slug: str, concept_label: str):
    """Load the negative and positive prompts for one concept.

    For a ``.jsonl`` file, each non-blank line is parsed as JSON:
      - a line like ``{"negative": ["neg1", "neg2", ...]}`` contributes negatives,
        which are shared across all concepts;
      - a line like ``{"depression": ["p1", "p2", ...]}`` contributes positives
        when its key equals ``concept_slug`` or ``concept_label``. Keys are looked
        up in that order (slug first, then label); a label equal to the slug is
        only looked up once.
    Lines whose JSON value is not an object are ignored.

    For any other file suffix, the file is read as plain text (one prompt per
    non-blank, stripped line). All lines are returned, each with source index -1,
    and no emptiness check is applied.

    Returns:
        contexts:            list[str]  negatives first, then positives
        source_line_indices: list[int]  same length; each entry is the 0-based
                             physical line index in the JSONL file (blank lines
                             are counted), or -1 for the plain-text path.

    Raises:
        ValueError: for a ``.jsonl`` file, when no positives or no negatives were found.
        json.JSONDecodeError: when a non-blank ``.jsonl`` line is not valid JSON.
    """
    path = Path(contexts_file)

    if path.suffix != ".jsonl":
        # read_lines is defined in this module; no need to re-import it.
        contexts = read_lines(path)
        return contexts, [-1] * len(contexts)

    negatives: List[str] = []
    negatives_src: List[int] = []
    positives: List[str] = []
    positives_src: List[int] = []

    # dict.fromkeys de-duplicates while keeping a deterministic order (slug, then
    # label). A set here made the positives' order depend on string hash
    # randomization, so it could change from run to run.
    concept_keys = list(dict.fromkeys(k for k in (concept_slug, concept_label) if k))

    with path.open("r", encoding="utf-8") as f:
        for line_idx, raw in enumerate(f):
            line = raw.strip()
            if not line:
                continue
            obj = json.loads(line)
            if not isinstance(obj, dict):
                # e.g. a bare string or list. A string would otherwise hit a
                # substring test in `"negative" in obj` and then crash on indexing.
                continue

            neg = obj.get("negative")
            if isinstance(neg, list):
                negatives.extend(neg)
                negatives_src.extend([line_idx] * len(neg))

            for key in concept_keys:
                vals = obj.get(key)
                if isinstance(vals, list):
                    positives.extend(vals)
                    positives_src.extend([line_idx] * len(vals))

    if len(positives) == 0 or len(negatives) == 0:
        raise ValueError('positive or negatives is empty')

    return negatives + positives, negatives_src + positives_src


def _from_pretrained_with_dtype(cls, model_name: str, *, dtype, **kwargs):
    """Call ``cls.from_pretrained`` passing the dtype under whichever keyword it accepts.

    ``dtype=`` is tried first. If that raises ``TypeError``, the call is retried
    with ``torch_dtype=``, the older keyword name. All other keyword arguments are
    forwarded unchanged.
    """
    load = cls.from_pretrained
    try:
        return load(model_name, dtype=dtype, **kwargs)
    except TypeError:
        # NOTE(review): any TypeError raised inside loading triggers this retry,
        # not just an unknown `dtype` keyword.
        return load(model_name, torch_dtype=dtype, **kwargs)

def find_block_list(model: nn.Module, override_path: Optional[str] = None) -> nn.ModuleList:
    """Locate the ``nn.ModuleList`` holding a model's transformer blocks.

    Args:
        model: The model to search.
        override_path: Optional dotted attribute path (e.g. ``"model.layers"``).
            When given, it is followed verbatim and auto-detection is skipped.

    Returns:
        The block ``nn.ModuleList``.

    Raises:
        ValueError: if ``override_path`` cannot be resolved or does not end at a
            ``ModuleList``, or if no known attribute layout matches.
    """
    if override_path:
        target = model
        for attr in override_path.split("."):
            if not hasattr(target, attr):
                raise ValueError(f"layer_path '{override_path}' not found at '{attr}'")
            target = getattr(target, attr)
        if isinstance(target, nn.ModuleList):
            return target
        raise ValueError(f"layer_path '{override_path}' is not a ModuleList")

    # Attribute chains tried in order; the first one ending at a ModuleList wins.
    known_layouts = (
        ("model", "layers"),                    # LLaMA / Mistral / Qwen style
        ("model", "decoder", "layers"),         # e.g. OPT
        ("transformer", "h"),                   # e.g. GPT-2
        ("transformer", "layers"),
        ("gpt_neox", "layers"),
        ("model", "encoder", "layers"),
        ("model", "language_model", "layers"),  # multimodal Gemma 3 (parameter count >= 4B)
    )

    def _walk(root, attrs):
        # Follow `attrs` from `root`; None signals that some link was missing.
        node = root
        for attr in attrs:
            if not hasattr(node, attr):
                return None
            node = getattr(node, attr)
        return node

    for layout in known_layouts:
        found = _walk(model, layout)
        if isinstance(found, nn.ModuleList):
            return found

    # Last resort: the blocks may hang directly off the object we were given.
    for name in ("layers", "h", "blocks", "block"):
        candidate = getattr(model, name, None)
        if isinstance(candidate, nn.ModuleList):
            return candidate

    raise ValueError("Could not locate transformer block ModuleList; provide --layer_path.")


def load_model_and_tokenizer(model_name: str, dtype_str: str = "float32") -> Tuple[AutoTokenizer, nn.Module]:
    """Load a causal LM and its tokenizer, with the whole model placed on device 0, in eval mode.

    Args:
        model_name: Hugging Face hub id or local path.
        dtype_str: Name of the parameter dtype (case-insensitive):
            ``"float32"``/``"fp32"`` -> float32,
            ``"float16"``/``"fp16"``/``"half"`` -> float16,
            ``"bfloat16"``/``"bf16"`` -> bfloat16.
            Any other value falls back to bfloat16, which is what every
            non-``"float32"`` string mapped to previously.

    Returns:
        ``(tokenizer, model)``. The tokenizer is requested with ``use_fast=False``.
        It gets a pad token (EOS, else BOS) when it has none, and the model's
        generation config is aligned with the tokenizer's pad/EOS ids.
    """
    tok = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token or tok.bos_token

    # Previously anything but the exact string "float32" became bfloat16, so
    # "float16"/"fp16" (and even "fp32") were silently loaded in the wrong dtype.
    dtype_by_name = {
        "float32": torch.float32,
        "fp32": torch.float32,
        "float16": torch.float16,
        "fp16": torch.float16,
        "half": torch.float16,
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
    }
    dtype = dtype_by_name.get(str(dtype_str).strip().lower(), torch.bfloat16)

    common = dict(low_cpu_mem_usage=True, device_map={"": 0})
    try:
        model = _from_pretrained_with_dtype(
            AutoModelForCausalLM, model_name, dtype=dtype,
            attn_implementation="flash_attention_2", **common
        )
    except Exception:
        # Best-effort: if FlashAttention-2 cannot be used (e.g. not installed or
        # unsupported for this model/dtype), retry with PyTorch SDPA attention.
        model = _from_pretrained_with_dtype(
            AutoModelForCausalLM, model_name, dtype=dtype,
            attn_implementation="sdpa", **common
        )
    model.eval()
    model.generation_config.pad_token_id = tok.pad_token_id
    if tok.eos_token_id is not None:
        model.generation_config.eos_token_id = tok.eos_token_id
    return tok, model

def load_steer_vector(steer_dir: Path, model_name: str, concept_slug: str, layer_idx: int) -> torch.Tensor:
    """Load the saved steering vector for one (model, concept, layer) combination.

    Reads ``<steer_dir>/<model_slug(model_name)>/<concept_slug>/layer_<layer_idx>.pt``
    onto the CPU and returns its ``"steering_vector"`` entry. A tensor is returned
    as stored; any other value is converted with ``torch.tensor(..., dtype=float32)``.
    """
    ckpt_path = steer_dir / model_slug(model_name) / concept_slug / f"layer_{layer_idx}.pt"
    # NOTE(review): torch.load is pickle-based; only point this at trusted files.
    payload = torch.load(ckpt_path, map_location="cpu")
    vector = payload["steering_vector"]  # presumably shape [H] -- TODO confirm against the writer
    if isinstance(vector, torch.Tensor):
        return vector
    return torch.tensor(vector, dtype=torch.float32)

def chunked(seq, n):
    """Yield consecutive lists of ``n`` items from ``seq``; the final list may be shorter.

    Each yielded list is a fresh object, so callers may keep or mutate it.
    """
    batch = []
    for item in seq:
        batch.append(item)
        if len(batch) < n:
            continue
        yield batch
        batch = []
    if batch:
        yield batch



def _get_eos_id(tokenizer) -> Optional[int]:
    """Return the first ``int`` among the tokenizer's EOS, SEP and PAD token ids, else ``None``."""
    candidates = (
        getattr(tokenizer, attr_name, None)
        for attr_name in ("eos_token_id", "sep_token_id", "pad_token_id")
    )
    return next((tid for tid in candidates if isinstance(tid, int)), None)


    











def iter_eval_blocks_from_parquet(
    tokenizer,
    parquet_path: str,
    cfg: "CrossEntropyPlotConfig",
    batch_size: int,
) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
    """Yield ``(input_ids, labels)`` CPU tensors of shape ``[B, T]`` from a parquet file.

    Each non-blank string in column ``cfg.text_field`` is tokenized without special
    tokens (truncated to ``cfg.max_doc_tokens`` when that is set). It is then
    optionally followed by an EOS id (``cfg.add_eos_between_docs``) and cut into
    windows of ``cfg.eval_seq_len + 1`` tokens starting every ``cfg.eval_stride``
    tokens (defaults to ``cfg.eval_seq_len``).

    Documents shorter than one window are skipped, so nothing is padded.
    ``labels`` is the window shifted left by one relative to ``input_ids``.
    Windows are batched ``batch_size`` at a time; the last batch may be smaller.

    Behavior:
      - If ``cfg.eval_max_blocks`` is 0/None, the *full parquet* is iterated (can
        be very large); otherwise at most that many windows are produced.
      - Existing local ``.parquet``/``.pq`` files are read with ``pyarrow`` in row
        batches and tokenized in batches.
      - Otherwise, or when pyarrow is unavailable or fails *before* anything has
        been yielded, HF ``datasets`` streaming is used. If pyarrow fails after
        batches were already yielded, the error is re-raised, because restarting
        from the first row would yield duplicate blocks.

    Raises:
        ValueError: if the effective stride is not positive.
    """
    required_amount = int(cfg.eval_seq_len) + 1  # T input tokens + 1 for the shifted label
    stride = int(cfg.eval_stride) if cfg.eval_stride else int(cfg.eval_seq_len)
    if stride <= 0:
        # A zero step crashes `range`; a negative one silently yields no blocks.
        raise ValueError(f"eval stride must be positive, got {stride}")
    eos_id = _get_eos_id(tokenizer)

    max_blocks = int(cfg.eval_max_blocks) if getattr(cfg, "eval_max_blocks", None) else None
    if max_blocks is not None and max_blocks <= 0:
        max_blocks = None

    read_rows = 2048          # parquet rows per IO batch
    tokenize_batch_size = 64  # texts per tokenizer call
    max_doc_tokens = getattr(cfg, "max_doc_tokens", None)

    buffer: List[List[int]] = []
    emitted = 0

    def _limit_reached() -> bool:
        """True once ``max_blocks`` windows have been emitted (never when unlimited)."""
        return max_blocks is not None and emitted >= max_blocks

    def _flush_buffer() -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
        """Turn the buffered windows into one ``(inputs, labels)`` pair and clear the buffer."""
        nonlocal buffer
        if not buffer:
            return None
        batch = torch.tensor(buffer, dtype=torch.long)
        if torch.cuda.is_available():
            # Optional: pinned host memory, presumably so callers can copy to
            # the GPU asynchronously. Failure here is harmless.
            try:
                batch = batch.pin_memory()
            except Exception:
                pass
        buffer = []
        return batch[:, :-1], batch[:, 1:]

    def _yield_from_token_ids(token_ids: List[int]):
        """Slice one document's token ids into windows, yielding full batches as they fill."""
        nonlocal emitted, buffer
        if not token_ids:
            return
        if cfg.add_eos_between_docs and eos_id is not None:
            token_ids = token_ids + [int(eos_id)]
        if len(token_ids) < required_amount:
            return

        for start in range(0, len(token_ids) - required_amount + 1, stride):
            buffer.append(token_ids[start : start + required_amount])
            emitted += 1

            if len(buffer) >= batch_size:
                out = _flush_buffer()
                if out is not None:
                    yield out

            if _limit_reached():
                break

    def _yield_from_texts(texts: List[str]):
        """Tokenize a batch of texts (blank/non-str entries dropped) and window each one."""
        good = [t for t in texts if isinstance(t, str) and t.strip()]
        if not good:
            return

        enc = tokenizer(
            good,
            add_special_tokens=False,
            truncation=bool(max_doc_tokens),
            max_length=int(max_doc_tokens) if max_doc_tokens else None,
            return_attention_mask=False,
        )
        for token_ids in enc.get("input_ids", []):
            yield from _yield_from_token_ids(token_ids)
            if _limit_reached():
                break

    path = Path(parquet_path)
    used_pyarrow = False
    if path.exists() and path.suffix in {".parquet", ".pq"}:
        yielded_any = False
        try:
            import pyarrow.parquet as pq  # type: ignore

            pf = pq.ParquetFile(str(path))
            used_pyarrow = True

            for rb in pf.iter_batches(batch_size=read_rows, columns=[cfg.text_field]):
                texts = rb.column(0).to_pylist()

                for j in range(0, len(texts), tokenize_batch_size):
                    for out in _yield_from_texts(texts[j : j + tokenize_batch_size]):
                        yielded_any = True
                        yield out
                    if _limit_reached():
                        break

                if _limit_reached():
                    break

        except Exception:
            if yielded_any:
                # The caller already received data; falling back would restart
                # from the first row and silently produce duplicate blocks.
                raise
            # Nothing reached the caller yet: discard partial state and fall back.
            buffer = []
            emitted = 0
            used_pyarrow = False

    if not used_pyarrow:
        dataset = load_dataset("parquet", data_files=str(parquet_path), split="train", streaming=True)
        text_buf: List[str] = []
        for sample in dataset:
            text = sample.get(cfg.text_field, None)
            if isinstance(text, str) and text.strip():
                text_buf.append(text)

            if len(text_buf) >= tokenize_batch_size:
                yield from _yield_from_texts(text_buf)
                text_buf = []

            if _limit_reached():
                break

        if text_buf and not _limit_reached():
            yield from _yield_from_texts(text_buf)

    out = _flush_buffer()
    if out is not None:
        yield out



import torch
import torch.nn as nn
from typing import Any, Optional, Sequence


class TempFp32LayerWrapper(nn.Module):
    """Store a block's weights in ``storage_dtype`` but run it in ``compute_dtype``.

    On construction every parameter of ``module`` is cast to ``storage_dtype``
    and frozen (``requires_grad = False``). Each ``forward`` call:
      1. casts the parameters to ``compute_dtype``;
      2. casts floating-point tensor inputs (nested in lists/tuples/dicts) to
         ``compute_dtype`` and runs the wrapped module;
      3. casts floating-point tensor outputs back to ``storage_dtype``;
      4. restores the parameters to ``storage_dtype``, even if the module raised.

    Integer and boolean tensors (e.g. position ids, cache positions, boolean
    masks) are passed through untouched. Converting them to float would break
    indexing ops or change mask semantics.

    Attribute lookups not found on the wrapper are forwarded to the wrapped module.
    """

    def __init__(
        self,
        module: nn.Module,
        storage_dtype: torch.dtype = torch.bfloat16,
        compute_dtype: torch.dtype = torch.float32,
    ):
        super().__init__()
        self.module = module
        self.storage_dtype = storage_dtype
        self.compute_dtype = compute_dtype

        with torch.no_grad():
            for p in self.module.parameters():
                p.data = p.data.to(storage_dtype)
                p.requires_grad = False

    def __getattr__(self, name: str):
        try:
            return super().__getattr__(name)
        except AttributeError:
            if name == "module":
                # `module` is not registered (e.g. a partially constructed or
                # unpickled instance); delegating would recurse forever.
                raise
            try:
                return getattr(self.module, name)
            except AttributeError:
                raise AttributeError(
                    f"'{self.__class__.__name__}' object and its inner module "
                    f"have no attribute '{name}'"
                ) from None

    def _cast_params(self, dtype: torch.dtype) -> None:
        """Cast every parameter of the wrapped module to ``dtype`` in place (skips ones already there)."""
        with torch.no_grad():
            for p in self.module.parameters():
                if p.dtype != dtype:
                    p.data = p.data.to(dtype)

    def _cast_tensors(self, x: Any, dtype: torch.dtype):
        """Recursively cast floating-point tensors inside ``x`` to ``dtype``; leave everything else as is."""
        if torch.is_tensor(x):
            return x.to(dtype) if x.is_floating_point() else x
        if isinstance(x, tuple) and hasattr(x, "_make"):
            # namedtuple: its constructor takes fields positionally, not an iterable.
            return x._make(self._cast_tensors(v, dtype) for v in x)
        if isinstance(x, (list, tuple)):
            return type(x)(self._cast_tensors(v, dtype) for v in x)
        if isinstance(x, dict):
            return {k: self._cast_tensors(v, dtype) for k, v in x.items()}
        return x

    def forward(self, *args, **kwargs):
        self._cast_params(self.compute_dtype)
        try:
            args_c = self._cast_tensors(args, self.compute_dtype)
            kwargs_c = self._cast_tensors(kwargs, self.compute_dtype)
            out = self.module(*args_c, **kwargs_c)
            return self._cast_tensors(out, self.storage_dtype)
        finally:
            # Always return weights to the low-precision storage dtype, so an
            # exception cannot leave a block holding full-precision copies.
            self._cast_params(self.storage_dtype)

def wrap_blocks_with_temp_fp32(
    model: nn.Module,
    layer_indices=None,
    *,
    override_path: Optional[str] = None,
    storage_dtype: torch.dtype = torch.bfloat16,
    compute_dtype: torch.dtype = torch.float32,
) -> nn.Module:
    """Wrap selected transformer blocks of ``model`` in ``TempFp32LayerWrapper`` and freeze the model.

    Args:
        model: Model whose block list is found via ``find_block_list``.
        layer_indices: Block indices to wrap (negative values count from the
            end). ``None`` wraps every block. Duplicates (e.g. ``[-1, n-1]``)
            are wrapped once.
        override_path: Forwarded to ``find_block_list``.
        storage_dtype: Dtype the wrapped blocks' weights are stored in.
        compute_dtype: Dtype the wrapped blocks run in.

    Returns:
        The same ``model`` object, modified in place. Every parameter has
        ``requires_grad`` set to False.

    Raises:
        IndexError: if an index is out of range for the block list.
    """
    blocks = find_block_list(model, override_path=override_path)
    n_layers = len(blocks)

    if layer_indices is None:
        idxs = list(range(n_layers))
    else:
        idxs = []
        seen = set()
        for raw in layer_indices:
            i = n_layers + raw if raw < 0 else raw
            if not (0 <= i < n_layers):
                # Report the index the caller passed, not the normalized one.
                raise IndexError(f"layer index {raw} out of range [0, {n_layers-1}]")
            if i not in seen:
                seen.add(i)
                idxs.append(i)

    for i in idxs:
        if isinstance(blocks[i], TempFp32LayerWrapper):
            # Already wrapped (e.g. this function was called twice): nesting
            # wrappers would only repeat the dtype round-trips.
            continue
        blocks[i] = TempFp32LayerWrapper(
            blocks[i],
            storage_dtype=storage_dtype,
            compute_dtype=compute_dtype,
        )

    for p in model.parameters():
        p.requires_grad_(False)

    return model


