#!/usr/bin/env python3
import os
import sys
import json
import argparse
from pathlib import Path
from typing import List, Tuple, Dict

from dotenv import load_dotenv
import torch
import torch._dynamo
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

from src.utils import (
    output_decoder,
    read_prompt_template,
    print_label_stats,
    get_unique_symbols,
)
from src.substitute_tokens import (
    create_token_substitution,
    apply_token_mapping,
    create_modified_prompt,
    create_formal_language_prompt,
)

# Ask torch._dynamo to suppress compilation errors rather than raise them.
torch._dynamo.config.suppress_errors = True
# Default random seed constant.
seed = 42

# Load variables from a .env file (if one is found) into the process environment.
load_dotenv()
# os.environ only accepts str values: re-exporting an unset HF_TOKEN (None) used to
# raise TypeError at import time, so only export it when it is actually set.
_hf_token = os.getenv("HF_TOKEN")
if _hf_token is not None:
    os.environ["HF_TOKEN"] = _hf_token
# Root directory of the per-language datasets; None when DATASET_DIR is unset.
DATASET_PATH = os.getenv("DATASET_DIR")

# ----------------------------- helpers -----------------------------

def load_data(language: str, type_of_set: str | None = None, n: int = 100) -> Tuple[List[str], List[int]]:
    """Load sequences and labels from DATASET_DIR/<language>/(data|labels).(train|test)."""
    if type_of_set is None:
        sys.exit(1)

    base_path = Path(f"{DATASET_PATH}/{language}/")
    sequences: List[str] = []
    labels: List[int] = []

    if type_of_set == "train":
        with open(base_path / "data.train", encoding="utf-8") as f:
            sequences = [line.strip() for line in f][:n]
        with open(base_path / "labels.train", encoding="utf-8") as f:
            labels = [int(line.strip()) for line in f][:n]

    elif type_of_set == "test":
        with open(base_path / "data.test", encoding="utf-8") as f:
            sequences = [line.strip() for line in f]
        with open(base_path / "labels.test", encoding="utf-8") as f:
            labels = [int(line.strip()) for line in f]
    else:
        raise ValueError(f"Unknown type_of_set: {type_of_set}")

    return sequences, labels


def create_prompt(
    examples: List[Tuple[str, int]],
    val_sequence: str,
    prompt_template: str,
    mapping: Dict[str, str] | None = None,
    encoding_technique: str = "none",
    k: int = 5,
) -> str:
    """Build the prompt for one test string, encoding it if a token mapping is given.

    Without ``mapping``, examples are inserted via ``create_formal_language_prompt`` and
    the raw ``val_sequence`` replaces the ``test_string`` placeholder. With ``mapping``,
    examples go through ``create_modified_prompt`` and the test string through
    ``apply_token_mapping``. Both receive an arity of 1 for ``one_to_one`` and
    ``many_to_one``, ``k`` for ``one_to_many``, and 0 for any other technique.
    """
    if mapping is None:
        with_examples = create_formal_language_prompt(prompt_template, examples)
        return with_examples.replace("test_string", val_sequence)

    arity_by_technique = {"one_to_one": 1, "many_to_one": 1, "one_to_many": k}
    arity = arity_by_technique.get(encoding_technique, 0)

    # Examples are encoded first, then the test string (same call order as before).
    with_examples = create_modified_prompt(prompt_template, examples, mapping, encoding_technique, arity)
    encoded_test = apply_token_mapping(val_sequence, mapping, encoding_technique=encoding_technique, k=arity)
    return with_examples.replace("test_string", encoded_test)


def compute_accuracy(raw_results: List[Dict]) -> float:
    """Return the fraction of entries whose ``model_output`` equals ``true_label``.

    Missing keys are read as None via ``dict.get``, so an entry lacking both keys counts
    as a match. An empty input yields 0.0.
    """
    hits = [record.get("model_output") == record.get("true_label") for record in raw_results]
    if not hits:
        return 0.0
    return sum(hits) / len(hits)


def build_chat_inputs(
    tokenizer: AutoTokenizer,
    raw_prompts: List[str],
    system_prompt: str | None = None,
) -> List[str]:
    """Render each plain prompt through the tokenizer's chat template.

    Each prompt becomes one user turn, preceded by a system turn when ``system_prompt``
    is non-empty. Rendering uses ``tokenize=False`` (text out) and
    ``add_generation_prompt=True`` (appends the assistant header).

    Returns:
        One rendered string per entry of ``raw_prompts``, in the same order.
    """

    def to_messages(user_text: str) -> List[Dict[str, str]]:
        # Fresh list and dicts for every prompt, so nothing is shared between calls.
        head = [{"role": "system", "content": system_prompt}] if system_prompt else []
        return head + [{"role": "user", "content": user_text}]

    return [
        tokenizer.apply_chat_template(
            to_messages(user_text),
            tokenize=False,
            add_generation_prompt=True,
        )
        for user_text in raw_prompts
    ]


def collect_stop_ids(tokenizer: AutoTokenizer) -> List[int]:
    """Return the stop token ids for generation: eos plus Llama-3 ``<|eot_id|>`` if known.

    The result preserves insertion order (eos first) and contains no duplicates.
    """
    candidates: List[int] = []
    eos_id = getattr(tokenizer, "eos_token_id", None)
    if eos_id is not None:
        candidates.append(eos_id)

    try:
        eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    except Exception:
        # Deliberately best-effort: a tokenizer that cannot resolve the token simply
        # contributes no extra stop id.
        eot_id = None
    # Tokenizers without "<|eot_id|>" can map it to unk_token_id instead of failing;
    # that id must not become a stop token, so reject it explicitly.
    unk_id = getattr(tokenizer, "unk_token_id", None)
    if isinstance(eot_id, int) and eot_id >= 0 and eot_id != unk_id:
        candidates.append(eot_id)

    # Order-preserving dedupe.
    return list(dict.fromkeys(candidates))


# ----------------------------- core -----------------------------

def _decode_label(out_text: str, prompt_template: str) -> int:
    """Map raw model text to an int label; -1 when the text is empty or not decodable."""
    if not out_text:
        return -1
    try:
        return int(output_decoder(out_text, prompt_template))
    except (TypeError, ValueError):
        # One malformed answer must not abort the batch: results are only saved after
        # every response has been decoded.
        return -1


def run_validation(
    examples: List[Tuple[str, int]],
    val_sequences: List[str],
    val_labels: List[int],
    model: str,
    prompt_template: str,
    mapping: Dict[str, str] | None = None,
    file_path: str = "default",
    encoding_technique: str = "none",
    lm: LLM | None = None,
    sampling_params: SamplingParams | None = None,
    tokenizer: AutoTokenizer | None = None,
    system_prompt: str | None = None,
    k: int = 5,
) -> None:
    """Generate answers for not-yet-scored validation sequences and append them to the results file.

    Reads the JSON results at ``file_path`` and skips every sequence already present in
    its ``raw_results``. For each remaining sequence it builds a chat-formatted prompt and
    generates all of them in one vLLM batch. It then appends one record per sequence
    (``sequence``, ``raw_response``, ``model_output``, ``true_label``), recomputes
    ``validation_size`` and ``accuracy``, and writes the file back.

    Args:
        examples: (sequence, label) pairs passed to ``create_prompt``.
        val_sequences: Validation strings, aligned with ``val_labels``.
        val_labels: Gold labels for ``val_sequences``.
        model: Model name, used for logging only.
        prompt_template: Template identifier passed to ``create_prompt`` and ``output_decoder``.
        mapping: Optional token substitution mapping (see ``create_prompt``).
        file_path: Existing JSON results file to extend.
        encoding_technique: Encoding passed to ``create_prompt``.
        lm: vLLM engine used for generation.
        sampling_params: Sampling parameters forwarded to ``lm.generate``.
        tokenizer: Tokenizer providing the chat template.
        system_prompt: Optional system message for every prompt.
        k: Encoding arity for ``one_to_many``.

    Raises:
        ValueError: if there are sequences to process but ``tokenizer`` or ``lm`` is None.
    """
    print(f"Creating validation for model: {model}")

    with open(file_path, "r", encoding="utf-8") as file:
        results = json.load(file)

    # Skip sequences already processed
    done = {entry["sequence"] for entry in results.get("raw_results", [])}
    all_requests: List[Tuple[Tuple[str, int], str]] = []
    for i, (seq, label) in enumerate(zip(val_sequences, val_labels)):
        if seq in done:
            print(f"Skipping sequence {i}: {seq}")
            continue
        prompt_text = create_prompt(examples, seq, prompt_template, mapping, encoding_technique, k)
        all_requests.append(((seq, label), prompt_text))

    if not all_requests:
        print("No new sequences to process.")
        return

    # Explicit checks (an `assert` is stripped under `python -O`; `lm` was unchecked).
    if tokenizer is None:
        raise ValueError("tokenizer is required")
    if lm is None:
        raise ValueError("lm is required")

    chat_inputs = build_chat_inputs(tokenizer, [p for (_, p) in all_requests], system_prompt=system_prompt)
    responses = lm.generate(chat_inputs, sampling_params, use_tqdm=True)

    # Reload file (in case another process wrote meanwhile)
    with open(file_path, "r", encoding="utf-8") as file:
        results = json.load(file)
    raw_results = results.setdefault("raw_results", [])

    # Responses are paired with requests by position; one output per request (n=1).
    for ((seq, true_label), _), resp in zip(all_requests, responses):
        first = resp.outputs[0] if resp.outputs else None
        out_text = first.text if first is not None and first.text is not None else ""
        raw_results.append(
            {
                "sequence": seq,
                "raw_response": out_text,
                "model_output": _decode_label(out_text, prompt_template),
                "true_label": true_label,
            }
        )

    results["validation_size"] = len(raw_results)
    results["accuracy"] = compute_accuracy(raw_results)

    # Write to a sibling temp file and atomically swap it in, so an interruption
    # mid-write cannot leave a truncated results file behind.
    tmp_path = f"{file_path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as file:
        json.dump(results, file, indent=4, ensure_ascii=False)
    os.replace(tmp_path, file_path)


def _results_path(
    model: str,
    prompt_template: str,
    encoding_technique: str,
    language: str,
    k: int,
    run_seed: int,
) -> str:
    """Return the JSON results path for one (template, encoding, language, k, seed) run."""
    model_dir = model.replace("/", "_")
    if encoding_technique == "one_to_many":
        # NOTE(review): "100_long__" (double underscore) differs from the other branch's
        # "100_long_"; kept verbatim so existing result files are still found and resumed.
        return (
            f"experiments/{model_dir}/{prompt_template}/{encoding_technique}_{k}/"
            f"{language}/100_long__{run_seed}.json"
        )
    return (
        f"experiments/{model_dir}/{prompt_template}/{encoding_technique}/"
        f"{language}/100_long_{run_seed}.json"
    )


def _init_results_file(
    file_path: str,
    encoding_technique: str,
    prompt_template: str,
    model: str,
    train_data: List[Tuple[str, int]],
    val_sequences: List[str],
    val_labels: List[int],
) -> None:
    """Write an empty results skeleton (accuracy -1, no raw results) to ``file_path``."""
    metrics_dict = {
        "accuracy": -1,
        "encoding_technique": encoding_technique,  # (kept original key spelling)
        "prompt_template": prompt_template,
        "raw_results": [],
        "examples_used": list(train_data),
        "model": model,
        "validation_size": len(val_sequences),
        "validation_set": val_sequences,
        "label_distribution": {
            "training": {
                "label_0": sum(1 for _, label in train_data if label == 0),
                "label_1": sum(1 for _, label in train_data if label == 1),
            },
            "validation": {
                "label_0": sum(1 for label in val_labels if label == 0),
                "label_1": sum(1 for label in val_labels if label == 1),
            },
        },
    }
    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(metrics_dict, f, indent=4, ensure_ascii=False)


def main():
    """Run every (encoding, language, prompt template, k) experiment for one model.

    Parses ``--model`` and ``--seed`` and loads the tokenizer and one vLLM engine. For
    each combination it then creates or resumes a JSON results file under
    ``experiments/`` and calls ``run_validation`` on the test split. Combinations whose
    results file already holds an ``accuracy`` other than -1 are skipped.
    """
    parser = argparse.ArgumentParser(description="Batch experiment with LLMs on formal languages")
    parser.add_argument("--model", required=True, help="Model name")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
    args = parser.parse_args()

    # Tokenizer for chat formatting + stop ids
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
    stop_ids = collect_stop_ids(tokenizer)

    # vLLM engine shared by all runs; tensor parallel over all visible GPUs (at least 1).
    llm = LLM(
        model=args.model,
        tensor_parallel_size=torch.cuda.device_count() or 1,
        distributed_executor_backend="mp",
        dtype=torch.float16,
        gpu_memory_utilization=0.95,
        enforce_eager=False,
        max_model_len=32768,           # ↓ cap context to fit KV cache
        disable_custom_all_reduce=True,
        trust_remote_code=True,
        enable_prefix_caching=True,
    )

    languages = ["binary-addition", "binary-multiplication", "bucket-sort", "compute-sqrt", "cycle-navigation", "dyck-2-3", "even-pairs", "first", "majority", "marked-copy", "marked-reversal", "missing-duplicate-string", "modular-arithmetic-simple", "odds-first", "parity", "repeat-01", "stack-manipulation", "unmarked-reversal"]

    encoding_techniques = ["one_to_one", "one_to_many", "many_to_one"]

    prompt_templates = ["io_prompt", "zsr_prompt"]

    for encoding_technique in encoding_techniques:
        for language in languages:
            # Both splits depend only on the language: read them once here instead of
            # once per prompt template and once per k.
            train_sequences, train_labels = load_data(language, "train")
            train_data = list(zip(train_sequences, train_labels))
            print_label_stats(train_data, "Full training data")
            val_sequences_sampled, val_labels_sampled = load_data(language, "test")

            for prompt_template in prompt_templates:
                sampling_params = SamplingParams(
                    n=1,
                    top_p=1.0,
                    temperature=0.2,
                    # Honor --seed; the module constant `seed` is always 42 and ignored the flag.
                    seed=args.seed,
                    # io_prompt is capped at a single output token, zsr_prompt at 1000.
                    max_tokens=1000 if prompt_template == "zsr_prompt" else 1,
                    stop_token_ids=stop_ids if stop_ids else None,
                )
                base_prompt = read_prompt_template(f"prompts/{prompt_template}_template.txt")

                ks = range(1, 6) if encoding_technique == "one_to_many" else [1]
                for k in ks:
                    file_path = _results_path(
                        args.model, prompt_template, encoding_technique, language, k, args.seed
                    )
                    os.makedirs(os.path.dirname(file_path), exist_ok=True)

                    # Resume support: an existing file with a real accuracy is finished.
                    if os.path.exists(file_path):
                        with open(file_path, "r", encoding="utf-8") as file:
                            existing = json.load(file)
                        if existing.get("accuracy", -1) != -1:
                            print(
                                f"Skipping encoding {encoding_technique}, prompt {prompt_template}, language {language}: already processed"
                            )
                            continue
                    else:
                        _init_results_file(
                            file_path,
                            encoding_technique,
                            prompt_template,
                            args.model,
                            train_data,
                            val_sequences_sampled,
                            val_labels_sampled,
                        )
                        print(f"File created: {file_path}")

                    print(f"Train data size: {len(train_sequences)}")
                    print(f"Validation data size: {len(val_sequences_sampled)}")

                    unique_symbols = get_unique_symbols(train_data, val_sequences_sampled)
                    print(f"Unique symbols: {unique_symbols}")

                    mapping = create_token_substitution(
                        encoding_technique=encoding_technique,
                        symbols=unique_symbols,
                        prompt_text=base_prompt,
                        model_name=args.model,
                        seed=args.seed,
                        verbose=False,
                    )
                    print(f"prompt template: {prompt_template}, encoding: {encoding_technique}, k={k}")
                    # NOTE(review): `base_prompt` (the template text) only feeds the token
                    # substitution; the prompt itself is built from the template *name*.
                    # Confirm the src.substitute_tokens helpers resolve that name to the text.
                    run_validation(
                        # Pass a copy so the shared per-language list cannot be mutated downstream.
                        examples=list(train_data),
                        val_sequences=val_sequences_sampled,
                        val_labels=val_labels_sampled,
                        model=args.model,
                        prompt_template=prompt_template,
                        mapping=mapping,
                        file_path=file_path,
                        encoding_technique=encoding_technique,
                        lm=llm,
                        sampling_params=sampling_params,
                        tokenizer=tokenizer,
                        k=k,
                    )


if __name__ == "__main__":
    main()
