"""
Teacher-based Synthetic Data Generation for Data Distillation.

Generates synthetic training data by having a teacher model produce responses
to prompts from a base dataset. Tracks energy consumption during generation.
"""

import os
import torch
import datasets
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from pathlib import Path
from typing import Optional
from torch.nn.utils.rnn import pad_sequence

from distill_bench.core.energy_logger import EnergyTracker
from distill_bench.core.config_loader import Config, load_config


def _left_pad(sequences: list, padding_value: int) -> torch.Tensor:
    """Stack 1-D tensors into one (batch, max_len) tensor, padding on the LEFT.

    Decoder-only models continue from the last position of every row, so the
    prompts of a batch must be right-aligned before calling ``generate``.
    """
    # pad_sequence only pads on the right: reverse each row, pad, reverse back.
    reversed_rows = [seq.flip(0) for seq in sequences]
    padded = pad_sequence(reversed_rows, batch_first=True, padding_value=padding_value)
    return padded.flip(1)


def _truncate_after_eos(generated: torch.Tensor, eos_token_id: Optional[int]) -> torch.Tensor:
    """Cut a generated continuation right after its first EOS token.

    In batched generation, rows that finish early are filled with pad tokens
    until the longest row finishes; those filler tokens are not part of the
    response and must not become training labels.
    """
    if eos_token_id is None:
        return generated
    eos_positions = (generated == eos_token_id).nonzero(as_tuple=True)[0]
    if len(eos_positions) == 0:
        return generated
    return generated[: eos_positions[0].item() + 1]


def _extract_prompt(example: dict, max_prompt_tokens: int):
    """Return ``(prompt_ids, prompt_attention_mask)`` for one example, or None.

    The prompt is everything before the first position whose label is not
    -100. Examples with no response region, or whose response starts at
    position 0 (empty prompt), yield None. Prompts longer than
    ``max_prompt_tokens`` keep only their tail.
    """
    # as_tensor is a no-op for torch-formatted rows and also accepts plain lists.
    input_ids = torch.as_tensor(example["input_ids"])
    attention_mask = torch.as_tensor(example["attention_mask"])
    existing_labels = torch.as_tensor(example["labels"])

    response_positions = (existing_labels != -100).nonzero(as_tuple=True)[0]
    if len(response_positions) == 0:
        return None
    response_start = response_positions[0].item()
    if response_start == 0:
        return None

    prompt_ids = input_ids[:response_start]
    prompt_attention_mask = attention_mask[:response_start]

    # Drop tokens from the front (keep the tail so the request stays coherent)
    # when the prompt would not leave headroom for a full generation.
    if prompt_ids.shape[0] > max_prompt_tokens:
        prompt_ids = prompt_ids[-max_prompt_tokens:]
        prompt_attention_mask = prompt_attention_mask[-max_prompt_tokens:]
    return prompt_ids, prompt_attention_mask


def generate_synthetic_dataset(
    config: Config,
    energy_tracker: Optional[EnergyTracker] = None,
    stage_name: str = "synthetic dataset teacher generation (sft)",
) -> datasets.DatasetDict:
    """
    Generate synthetic dataset using teacher model.

    Prompts are taken from the preprocessed dataset at ``dataset_path`` (the
    part of each example before the first label != -100), batched with left
    padding, and completed by the teacher model. Each stored example is the
    unpadded prompt followed by the teacher response up to and including its
    first EOS token; labels are -100 on the prompt and equal to the token ids
    on the response.

    Args:
        config: Configuration object
        energy_tracker: Optional energy tracker for measuring generation
        stage_name: Name of the energy stage started here when the tracker has
            no stage currently running

    Returns:
        DatasetDict with synthetic training data (95/5 train/test split)

    Raises:
        ValueError: If ``dataset_path`` is not configured, if the tokenizer
            has neither a pad nor an EOS token id, or if fewer than two
            examples were generated (too few to split).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    gen_config = config.get("synthetic_data.generation", {})
    temperature = gen_config.get("temperature", 0.7)
    top_p = gen_config.get("top_p", 0.9)
    max_new_tokens = gen_config.get("max_new_tokens", 1024)
    decoding_strategy = gen_config.get("decoding_strategy", "sampling")
    generation_batch_size = config.get("batch_size", 4)

    max_seq_len = getattr(config, "max_sequence_length", None) or config.get("data.max_sequence_length", 2048)
    dataset_path = config.dataset_path or config.get("data.dataset_path")
    dataset_name = config.dataset_name or config.get("data.dataset_name")

    if not dataset_path:
        raise ValueError("dataset_path is not set. Run tulu_preprocess_dataset.py first.")

    print("===============================")
    print(f"Using dataset_choice='{dataset_name}' from preprocessed dataset at: {dataset_path}")
    print("===============================")

    # Start energy tracking for teacher generation (single stage)
    if energy_tracker and energy_tracker.current_stage is None:
        energy_tracker.start_stage(stage_name)

    # Load tokenizer and teacher model
    print(f"Loading teacher model: {config.teacher_model_name}")
    tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name)
    tokenizer.padding_side = "left"
    teacher_model = AutoModelForCausalLM.from_pretrained(
        config.teacher_model_name,
        torch_dtype=torch.bfloat16,
        # Follow the detected device instead of hard-coding CUDA.
        device_map=device.type,
    )
    teacher_model.eval()
    teacher_model.requires_grad_(False)

    teacher_model.config.use_cache = True

    print(f"Loading preprocessed prompt dataset from: {dataset_path}")
    prompt_dataset = datasets.load_from_disk(dataset_path)
    if isinstance(prompt_dataset, datasets.DatasetDict):
        prompt_dataset = prompt_dataset["train"]

    max_gen_examples = getattr(config, "max_gen_examples", 7000)
    print(f"Generating {(max_gen_examples)} synthetic examples...")

    # Storage for synthetic data
    synthetic_data = {
        "input_ids": [],
        "attention_mask": [],
        "labels": [],
    }

    total_tokens_generated = 0
    successful_generations = 0

    # Filtering settings never change during the run, so read them once.
    filtering_config = config.get("synthetic_data.filtering", {})
    filtering_enabled = filtering_config.get("enabled", True)
    min_length = filtering_config.get("min_length", 10)
    max_length = filtering_config.get("max_length", max_seq_len)
    # Sequences are only clamped once they exceed max_length by this margin.
    clamp_tolerance = 200

    eos_token_id = tokenizer.eos_token_id
    # Explicit None check: 0 is a valid pad token id and must not fall through.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = eos_token_id
    if pad_token_id is None:
        raise ValueError("Tokenizer defines neither pad_token_id nor eos_token_id; cannot pad generation batches.")

    base_generation_kwargs = {
        "pad_token_id": pad_token_id,
        "eos_token_id": eos_token_id,
        "do_sample": decoding_strategy == "sampling",
    }
    if decoding_strategy == "sampling":
        base_generation_kwargs["temperature"] = temperature
        base_generation_kwargs["top_p"] = top_p

    # Prompts never exceed this length so a full generation still fits in max_seq_len.
    headroom_for_prompt = max(max_seq_len - max_new_tokens, 1)

    pending_prompts = []

    def flush_batch():
        """Generate responses for all pending prompts and store the results."""
        nonlocal pending_prompts, total_tokens_generated, successful_generations
        if not pending_prompts:
            return

        # Detach the batch before generating: if generation raises, the failed
        # prompts must not stay queued and be retried with every later batch.
        batch, pending_prompts = pending_prompts, []

        prompt_lengths = [p["prompt_ids"].shape[0] for p in batch]
        batch_input_ids = _left_pad([p["prompt_ids"] for p in batch], pad_token_id).to(device)
        batch_attention_mask = _left_pad([p["prompt_attention_mask"] for p in batch], 0).to(device)
        padded_prompt_len = batch_input_ids.shape[1]

        batch_outputs = teacher_model.generate(
            input_ids=batch_input_ids,
            attention_mask=batch_attention_mask,
            max_new_tokens=max_new_tokens,
            **base_generation_kwargs,
        ).cpu()

        for prompt_info, padded_output, prompt_length in zip(batch, batch_outputs, prompt_lengths):
            # Row layout: [left padding | prompt | generated tokens | filler padding].
            prompt_tokens = padded_output[padded_prompt_len - prompt_length:padded_prompt_len]
            generated_tokens = _truncate_after_eos(padded_output[padded_prompt_len:], eos_token_id)
            output = torch.cat([prompt_tokens, generated_tokens])

            synthetic_labels = torch.full_like(output, fill_value=-100)
            synthetic_labels[prompt_length:] = output[prompt_length:]
            output_attention_mask = torch.ones_like(output)

            if filtering_enabled:
                if len(output) > max_length + clamp_tolerance:
                    # Clamp to max_length instead of skipping long samples.
                    output = output[:max_length]
                    synthetic_labels = synthetic_labels[:max_length]
                    output_attention_mask = output_attention_mask[:max_length]
                    generated_tokens = output[prompt_length:]
                response_length = len(generated_tokens)
                if response_length < min_length:
                    # Keep short generations but note it for visibility.
                    print(f"Response length shorter than min length - keeping idx {prompt_info['idx']} (len={response_length})")

            synthetic_data["input_ids"].append(output.tolist())
            synthetic_data["attention_mask"].append(output_attention_mask.tolist())
            synthetic_data["labels"].append(synthetic_labels.tolist())

            total_tokens_generated += len(generated_tokens)
            successful_generations += 1

        if device.type == "cuda" and successful_generations and successful_generations % 500 == 0:
            torch.cuda.empty_cache()

    # Generate responses
    progress_bar = tqdm(prompt_dataset, desc="Collecting prompts (generation runs in batches)")
    with torch.inference_mode():
        processed_examples = 0
        for idx, example in enumerate(progress_bar):
            if max_gen_examples is not None and processed_examples >= max_gen_examples:
                print(f"[EARLY STOP] Reached synthetic generation limit ({max_gen_examples})")
                break
            processed_examples += 1

            # A malformed example only costs that example.
            try:
                prompt = _extract_prompt(example, headroom_for_prompt)
            except Exception as e:
                print(f"Warning: Failed to prepare prompt for example {idx}: {e}")
                continue
            if prompt is None:
                continue

            prompt_ids, prompt_attention_mask = prompt
            pending_prompts.append(
                {
                    "idx": idx,
                    "prompt_ids": prompt_ids,
                    "prompt_attention_mask": prompt_attention_mask,
                }
            )

            if len(pending_prompts) >= generation_batch_size:
                flushed_batch_size = len(pending_prompts)
                tqdm.write(
                    f"[FLUSH] Generating batch of {flushed_batch_size} prompts "
                    f"(processed={processed_examples}, successes={successful_generations}, tokens={total_tokens_generated})"
                )
                # A failed batch is dropped (flush_batch already dequeued it) and the run continues.
                try:
                    flush_batch()
                except Exception as e:
                    tqdm.write(
                        f"Warning: Generation failed for batch of {flushed_batch_size} prompts "
                        f"ending at example {idx}; batch dropped: {e}"
                    )
                progress_bar.set_postfix(
                    successes=successful_generations,
                    tokens=total_tokens_generated,
                    batch=flushed_batch_size,
                )

        if pending_prompts:
            final_batch_size = len(pending_prompts)
            tqdm.write(
                f"[FLUSH] Generating final batch of {final_batch_size} prompts "
                f"(processed={processed_examples}, successes={successful_generations}, tokens={total_tokens_generated})"
            )
            # Protect the last flush too, so one bad batch cannot discard all earlier results.
            try:
                flush_batch()
            except Exception as e:
                tqdm.write(f"Warning: Generation failed for final batch of {final_batch_size} prompts; batch dropped: {e}")

    # End energy tracking
    if energy_tracker:
        energy_tracker.end_stage(tokens_processed=total_tokens_generated)

    print(f"Successfully generated {successful_generations} examples")
    print(f"Total tokens generated: {total_tokens_generated:,}")

    # Clean up teacher model
    del teacher_model
    if device.type == "cuda":
        torch.cuda.empty_cache()

    if successful_generations < 2:
        raise ValueError(
            f"Only {successful_generations} synthetic example(s) were generated; "
            "at least 2 are required to build the train/test split."
        )

    # Create dataset
    synthetic_dataset = datasets.Dataset.from_dict(synthetic_data)
    synthetic_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

    # Split into train/eval
    split_dataset = synthetic_dataset.train_test_split(
        test_size=0.05,
        seed=config.seed,
    )

    # Save if path specified
    synthetic_path = config.get("synthetic_data.synthetic_dataset_path")
    if synthetic_path:
        os.makedirs(synthetic_path, exist_ok=True)
        split_dataset.save_to_disk(synthetic_path)
        print(f"Saved synthetic dataset to: {synthetic_path}")
    else:
        print("synthetic_data.synthetic_dataset_path is not set; synthetic dataset was not saved to disk.")

    return split_dataset


def load_synthetic_dataset(config: Config) -> Optional[datasets.DatasetDict]:
    """Load existing synthetic dataset.

    Args:
        config: Configuration object; the dataset location is read from the
            "synthetic_data.synthetic_dataset_path" key.

    Returns:
        The dataset loaded from disk, or None when loading fails. On failure
        a message including the underlying error is printed.
    """
    synthetic_path = config.get("synthetic_data.synthetic_dataset_path")

    try:
        print(f"Loading existing synthetic dataset from: {synthetic_path}")
        return datasets.load_from_disk(synthetic_path)
    except Exception as e:
        # Best-effort loader: keep returning None, but surface the real cause
        # (missing directory, unset path, corrupt files) instead of hiding it.
        print(
            f"Synthetic dataset not found at {synthetic_path}. Generate dataset before running. "
            f"({type(e).__name__}: {e})"
        )
        return None


def run_basic_checks(
    split_dataset: datasets.DatasetDict,
    tokenizer: AutoTokenizer,
    num_examples: int = 3,
    max_token_check_examples: int = 1000,
    ) -> None:
    """Print sanity diagnostics for a synthetic dataset; nothing is returned.

    Three groups of checks run on the 'train' split:
      1. split sizes and column names,
      2. min/max token id over the first ``max_token_check_examples`` rows,
         compared against ``len(tokenizer)``,
      3. for the first ``num_examples`` rows: length consistency between
         columns, the response region implied by labels != -100, and decoded
         previews of prompt / response / full text.
    """

    def as_plain_list(value):
        # Rows may come back as torch tensors; everything below works on lists.
        return value.tolist() if isinstance(value, torch.Tensor) else value

    # ---- Structure ----
    if not isinstance(split_dataset, datasets.DatasetDict):
        print("[CHECK] Dataset is not a DatasetDict; skipping split checks.")
        return
    if "train" not in split_dataset:
        print("[CHECK] No 'train' split found; available keys:", list(split_dataset.keys()))
        return

    train_split = split_dataset["train"]
    train_len = len(train_split)
    test_len = len(split_dataset.get("test", []))
    total_len = train_len + test_len
    print(f"[CHECK] Columns: {train_split.column_names}")
    print(f"[CHECK] Split sizes -> train: {train_len}, test: {test_len}, total: {total_len}")

    if train_len == 0:
        print("[CHECK] Train split is empty; nothing to inspect.")
        return

    # ---- Token id range over a prefix of the train split ----
    print("\n[CHECK] Verifying token ID range on train split...")
    highest_id = -1
    lowest_id = float("inf")
    inspected = 0

    prefix_size = min(max_token_check_examples, train_len)
    for row in train_split.select(range(prefix_size)):
        row_ids = as_plain_list(row["input_ids"])
        if row_ids:
            highest_id = max(highest_id, max(row_ids))
            lowest_id = min(lowest_id, min(row_ids))
            inspected += 1

    if inspected:
        print(f"[CHECK] Token ID range (first {inspected} train examples): {lowest_id} .. {highest_id}")
        vocab_size = len(tokenizer)
        print(f"[CHECK] Tokenizer vocab size: {vocab_size}")
        if highest_id < vocab_size:
            print("[CHECK] ✓ All token IDs are within vocabulary range.")
        else:
            print(
                f"[CHECK][WARN] Max token ID ({highest_id}) >= vocab size ({vocab_size}). "
                "This will cause issues in training; check tokenizer/model mismatch."
            )
    else:
        print("[CHECK] No non-empty examples to check token ranges.")

    # ---- Per-sample label mask inspection and decoding ----
    print(f"\n[CHECK] Showing up to {num_examples} decoded samples from 'train' with label mask info:")
    n_samples = min(num_examples, len(train_split))
    for i in range(n_samples):
        record = train_split[i]
        input_ids = as_plain_list(record["input_ids"])
        attn = as_plain_list(record.get("attention_mask", None))
        labels = as_plain_list(record.get("labels", None))

        # Columns that exist must be as long as input_ids.
        if attn is not None and len(input_ids) != len(attn):
            print(f"[CHECK][WARN] Sample {i}: len(input_ids) != len(attention_mask)")
        if labels is not None and len(input_ids) != len(labels):
            print(f"[CHECK][WARN] Sample {i}: len(input_ids) != len(labels)")

        # Positions carrying a real label form the response region.
        if labels is None:
            response_idxs = []
        else:
            response_idxs = [pos for pos, lab in enumerate(labels) if lab != -100]

        if response_idxs:
            prompt_end = response_idxs[0]
            # Inside the response region, labels are expected to mirror input_ids.
            mismatches = sum(1 for pos in response_idxs if labels[pos] != input_ids[pos])
            if mismatches > 0:
                print(f"[CHECK][WARN] Sample {i}: {mismatches} mismatches where labels != input_ids in response region")
        else:
            print(f"[CHECK][WARN] Sample {i}: no labels != -100 (no response region detected)")
            prompt_end = len(input_ids)

        # Full text uses only the positions the attention mask marks as 1.
        if attn is None:
            trimmed_ids = input_ids
        else:
            trimmed_ids = [tok for tok, keep in zip(input_ids, attn) if keep == 1]
        decoded_full = tokenizer.decode(trimmed_ids, skip_special_tokens=True)

        # Prompt and response are split at the first labelled position.
        prompt_ids = input_ids[:prompt_end]
        if prompt_end < len(input_ids):
            response_ids = input_ids[prompt_end:]
        else:
            response_ids = []
        decoded_prompt = tokenizer.decode(prompt_ids, skip_special_tokens=True)
        decoded_response = tokenizer.decode(response_ids, skip_special_tokens=True)

        labels_preview = labels[:30] if labels is not None else 'N/A'

        print(f"\n[CHECK] Sample {i}:")
        print(f"  length(input_ids)={len(input_ids)}, prompt_end={prompt_end}, response_len={len(response_ids)}")
        print(f"  first 30 token ids: {input_ids[:30]}")
        print(f"  first 30 labels:    {labels_preview}")
        print(f"  [decoded] PROMPT (tail, up to 200 chars): {decoded_prompt[-200:]}")
        print(f"  [decoded] RESPONSE (head, up to 200 chars): {decoded_response[:200]}")
        print(f"  [decoded] FULL (trimmed by attention_mask, up to 200 chars): {decoded_full[:200]}")


if __name__ == "__main__":
    import argparse

    # Command-line entry point: generate the dataset, save the energy summary,
    # then run best-effort sanity checks on the result.
    cli = argparse.ArgumentParser(description="Generate synthetic dataset with energy tracking")
    cli.add_argument("--config", type=str, required=True, help="Path to experiment config YAML")
    cli_args = cli.parse_args()

    experiment_cfg = load_config(cli_args.config)

    # Output directory precedence: the `run_dir` attribute, then the
    # "output.run_dir" config key, then the `output_dir` attribute ("logs" if absent).
    run_dir_setting = getattr(experiment_cfg, "run_dir", None)
    if not run_dir_setting:
        run_dir_setting = experiment_cfg.get("output.run_dir", None)
    if not run_dir_setting:
        run_dir_setting = getattr(experiment_cfg, "output_dir", "logs")
    output_root = Path(run_dir_setting)
    output_root.mkdir(parents=True, exist_ok=True)

    energy_tracker = EnergyTracker(
        run_dir=str(output_root),
        experiment_name="synthetic_generation",
        config=experiment_cfg,
    )

    generated_splits = generate_synthetic_dataset(experiment_cfg, energy_tracker=energy_tracker)
    energy_tracker.save_summary()

    # The checks are diagnostic only; a failure here must not fail the run.
    try:
        check_tokenizer = AutoTokenizer.from_pretrained(experiment_cfg.tokenizer_name)
        run_basic_checks(generated_splits, check_tokenizer)
    except Exception as e:
        print(f"[CHECK] Skipping dataset checks due to error: {e}")
