from __future__ import annotations
from importlib.metadata import entry_points

from huggingface_hub.errors import EntryNotFoundError
import torch
import json
from typing import Any, Iterable, Optional
from tqdm import tqdm

from src.run.config import RunConfig
from src.model.config import Transformer
from src.run.utils import get_batch, log_line, get_select_mask
from src.run.logger import get_tqdm_kwargs
from src.run.distributed import get_raw_model, is_main_process, barrier

# --------------------------------------------------------------------------- #
# EVAL LOSS                                                                   #
# --------------------------------------------------------------------------- #

@torch.inference_mode()
def eval_loss(
    model: Transformer,
    config: RunConfig,
    data_label: str,
    expert_labels: Optional[Iterable[str]] = None,
    num_batches: Optional[int] = None,
    split: str = "test",
) -> float:
    """Return the mean per-batch loss of ``model`` on one dataset split.

    The loader ``config.loaders[data_label][split]`` is reset first, so every
    call evaluates the same leading batches.

    Args:
        model: Possibly wrapped model; ``get_raw_model`` unwraps it to read
            ``model_type``. It is called with ``tokens``/``targets`` (plus
            ``select_mask`` for routed models), and element ``[1]`` of its
            output is taken as the loss (``.item()`` is called on it).
        config: Run config providing ``loaders``, ``logger`` and ``aux_labels``.
        data_label: Dataset key in ``config.loaders``.
        expert_labels: Experts to enable for routed models, forwarded to
            ``get_select_mask``. Unused for non-routed models.
        num_batches: Number of batches to average over. ``None`` means the
            whole loader. The value is clamped to ``len(loader)``.
        split: Split key inside the loader dict, e.g. ``"test"`` or ``"test_ood"``.

    Returns:
        The arithmetic mean of the per-batch losses, or NaN when no batches
        are evaluated (empty loader or non-positive ``num_batches``).
    """
    # Ensure all GPUs are synchronized before starting evaluation
    barrier()

    logger = config.logger
    loader = config.loaders[data_label][split]
    loader.reset()  # Ensure we start from the beginning of the dataset

    # Get raw model for accessing config and model_type
    raw_model = get_raw_model(model)

    max_batches = len(loader)
    if num_batches is None:
        num_batches = max_batches
    num_batches = min(num_batches, max_batches)

    # Guard against empty loaders (e.g., test split smaller than batch size)
    # and non-positive requests. Without the guard, a negative count falls
    # through to `0.0 / num_batches` and yields -0.0 instead of "no data".
    if num_batches <= 0:
        logger.warning(f"No evaluation batches available for label={data_label}; returning NaN loss")
        return float('nan')

    is_routed = raw_model.model_type == "routed"
    routing_labels = ["core"] + config.aux_labels if is_routed else None
    # The select mask is identical for every batch, so build it once. It is
    # built lazily so it is created on the device of the first batch. Building
    # it once also means a one-shot iterator passed as `expert_labels` is
    # consumed only once.
    sel_mask = None

    total_loss = 0.0
    for _ in range(num_batches):
        x, y, _ = get_batch(loader)

        if is_routed:
            if sel_mask is None:
                sel_mask = get_select_mask(routing_labels, expert_labels, device=x.device)
            loss = model(tokens=x, targets=y, select_mask=sel_mask)[1]
        else:
            loss = model(tokens=x, targets=y)[1]

        total_loss += loss.item()

    return total_loss / num_batches


# --------------------------------------------------------------------------- #
# GENERATION                                                                  #
# --------------------------------------------------------------------------- #

@torch.inference_mode()
def generate_samples(
    model: Transformer,
    config: RunConfig,
    data_label: str,
    expert_labels: Optional[Iterable[str]] = None,
    num_examples: int = 128,
    prefix_len: int = 64,
) -> Optional[dict[str, Any]]:
    """Greedy-decode continuations for prompts taken from the test split.

    Up to ``num_examples`` EOS-delimited token sequences are read from
    ``config.loaders[data_label]["test"]``. The first ``prefix_len`` tokens
    of each sequence form the prompt. The model extends each prompt by argmax
    decoding until it emits EOS or the sequence reaches ``ctx_len`` tokens.

    Args:
        model: Possibly wrapped model. Element ``[0]`` of its output is used
            as the next-token scores.
        config: Run config providing ``device``, ``logger``, ``loaders`` and
            ``aux_labels``.
        data_label: Dataset key in ``config.loaders``.
        expert_labels: Experts to enable for routed models, forwarded to
            ``get_select_mask``.
        num_examples: Maximum number of sequences to collect and continue.
        prefix_len: Number of leading tokens of each sequence used as the prompt.

    Returns:
        ``None`` if no complete sequence could be collected. Otherwise a dict
        with ``prefix_len``, ``num_examples`` (the number of generations
        actually produced) and ``generations``, a list of
        ``{"prompt", "continuation", "truth"}`` decoded strings.
    """
    device = config.device
    logger = config.logger

    # Get raw model for accessing config and model_type
    raw_model = get_raw_model(model)
    eos_token_id = raw_model.config.eos_token_id
    is_routed = raw_model.model_type == "routed"

    def _collect_full_sequences(data_label: str, n_examples: int) -> list[list[int]]:
        """Read up to ``n_examples`` EOS-terminated sequences, with the EOS stripped."""
        loader = config.loaders[data_label]["test"]
        loader.reset()

        sequences: list[list[int]] = []
        seg_buf: list[int] = []

        # Bound the scan to one pass over the loader. An unbounded loop would
        # keep calling get_batch forever if the split holds too few EOS tokens
        # (or eos_token_id is None). What happens past the end of the loader
        # depends on get_batch.
        for _ in range(len(loader)):
            if len(sequences) >= n_examples:
                break
            x = get_batch(loader)[0]
            for tok in x.flatten().tolist():
                seg_buf.append(tok)
                if tok == eos_token_id:
                    if len(seg_buf) > 1:
                        # drop EOS from the stored sequence
                        sequences.append(seg_buf[:-1])
                    seg_buf = []
                    if len(sequences) >= n_examples:
                        break

        if len(sequences) < n_examples:
            logger.warning(
                f"Only collected {len(sequences)}/{n_examples} complete sequences for label={data_label}"
            )
        return sequences[:n_examples]

    def _generate_batch(batch_prompts: list[list[int]]) -> list[list[int]]:
        """Greedy-decode one batch of prompts and return each prompt's new tokens (EOS excluded)."""
        if not batch_prompts:
            return []

        bs = len(batch_prompts)
        max_prompt_len = max(len(p) for p in batch_prompts)

        # Right-align prompts; fill left with EOS to create clean segment boundaries
        prompt_batch = torch.full(
            (bs, max_prompt_len), eos_token_id, dtype=torch.long, device=device
        )
        for i, p in enumerate(batch_prompts):
            if p:
                prompt_batch[i, -len(p) :] = torch.tensor(p, dtype=torch.long, device=device)

        # The routing mask does not change between decoding steps; build it once.
        sel_mask = None
        if is_routed:
            labels = ["core"] + config.aux_labels
            sel_mask = get_select_mask(labels, expert_labels, device=prompt_batch.device)

        finished = [False] * bs
        generated_tok: list[list[int]] = [[] for _ in range(bs)]
        # The whole sequence is re-fed each step, so the total length is capped at ctx_len.
        max_new_tokens = raw_model.config.ctx_len - max_prompt_len

        for _ in tqdm(range(max_new_tokens), **get_tqdm_kwargs(logger, desc=f"Generating samples | Data: {data_label}", ncols=100)):
            if is_routed:
                logits = model(
                    tokens=prompt_batch,
                    targets=None,
                    select_mask=sel_mask,
                )[0]
            else:
                logits = model(prompt_batch)[0]
            # Greedy decoding: most likely token at the last position.
            next_tok = torch.argmax(logits[:, -1, :], dim=-1)

            prompt_batch = torch.cat([prompt_batch, next_tok.unsqueeze(1)], dim=1)

            # Rows that already hit EOS keep decoding as part of the batch,
            # but their extra tokens are discarded.
            for i, tok in enumerate(next_tok.tolist()):
                if finished[i]:
                    continue
                if tok == eos_token_id:
                    finished[i] = True
                else:
                    generated_tok[i].append(tok)
            if all(finished):
                break

        return generated_tok

    logger.info(f"Generating samples | Data: {data_label}")

    # 1. Collect complete sequences (strip trailing EOS). Filter the sequences
    #    themselves so the truth/prompt/continuation lists stay aligned.
    sequences_tok = [seq for seq in _collect_full_sequences(data_label, num_examples) if seq]
    if not sequences_tok:
        return None
    prompts_tok = [seq[:prefix_len] for seq in sequences_tok]

    # 2. Batched autoregressive generation
    gen_batch_size = config.loaders[data_label]["test"].B
    generations_tok: list[list[int]] = []
    for start in range(0, len(prompts_tok), gen_batch_size):
        generations_tok.extend(_generate_batch(prompts_tok[start : start + gen_batch_size]))

    # 3. Decode token ids back to text
    tokenizer = raw_model.config.tokenizer
    generations = [
        {
            "prompt": tokenizer.decode(prompt, skip_special_tokens=False),
            "continuation": tokenizer.decode(gen, skip_special_tokens=False),
            "truth": tokenizer.decode(seq, skip_special_tokens=False),
        }
        for seq, prompt, gen in zip(sequences_tok, prompts_tok, generations_tok)
    ]

    return {
        "prefix_len": prefix_len,
        "num_examples": len(generations),
        "generations": generations,
    }


# --------------------------------------------------------------------------- #
# MAIN                                                                       #
# --------------------------------------------------------------------------- #

@torch.inference_mode()
def do_eval(
    stage: dict,
    model: Transformer,
    config: RunConfig,
    data_labels: Optional[Iterable[str]] = None,
    expert_labels: Optional[Iterable[str]] = None,
    log: Optional[dict[str, Any]] = None,
    num_batches: Optional[int] = 10004,
) -> None:
    """Evaluate ``model`` on each data label/split and append results to ``stats.jsonl``.

    For every label in ``data_labels`` (default: ``"core"`` plus
    ``config.aux_labels``) and every split (``"test"``, plus ``"test_ood"``
    when ``config.test_ood`` is set), the mean loss is computed with
    ``eval_loss`` on all ranks. The main process then writes one JSON line to
    ``config.res_dir / "stats.jsonl"`` via ``log_line``.

    Args:
        stage: Stage description stored in each entry. If
            ``stage["gen_samples"]`` is truthy, greedy samples are generated
            for the ``"test"`` split.
        model: Possibly wrapped model to evaluate.
        config: Run config.
        data_labels: Dataset keys to evaluate. Defaults to all configured labels.
        expert_labels: Experts to enable for routed models.
        log: Extra key/values merged into each entry. Its keys override the
            built-in keys.
        num_batches: Per-split batch cap forwarded to ``eval_loss``. ``None``
            evaluates the whole loader.
    """
    gen_samples = stage.get("gen_samples", False)
    logger = config.logger
    log_fp = config.res_dir / "stats.jsonl"

    # Default to every configured dataset. Materialize the labels so a
    # one-shot iterable can be both validated and iterated.
    if data_labels is None:
        data_labels = ["core"] + config.aux_labels
    else:
        data_labels = list(data_labels)

    splits = ["test"]
    if config.test_ood:
        splits.append("test_ood")
        # Only the labels actually being evaluated need an OOD loader.
        missing = [x for x in data_labels if "test_ood" not in config.loaders[x]]
        assert not missing, f"test_ood loader missing for labels: {missing}"

    logger.info("---- Begin Eval ----")

    for data_label in data_labels:
        for split in splits:
            logger.info(f"Data: {data_label} | Experts: {expert_labels} | Split: {split}")

            loss = eval_loss(model, config, data_label, expert_labels, num_batches, split)
            logger.info(f"loss: {loss:.8f}")

            # File I/O (and sample generation) only on the main process. The
            # other ranks move on and block in the next eval_loss() barrier,
            # if there is one.
            if not is_main_process():
                continue

            entry = {
                "stage": stage,
                "function": "do_eval",
                "data_label": data_label,
                "expert_labels": expert_labels,
                "loss": loss,
                "split": split,
            }

            if log:
                entry.update(log)

            # generate_samples always reads the "test" loader. Generating for
            # "test_ood" as well would redo generation from the same test-split
            # prompts and file those samples under the OOD entry.
            if gen_samples and split == "test":
                entry["samples"] = generate_samples(
                    model=model,
                    config=config,
                    data_label=data_label,
                    expert_labels=expert_labels,
                )

            # Write to JSONL file
            log_line(entry, log_fp)