#!/usr/bin/env python3
import argparse
import json
import os
import sys
from itertools import islice
from pathlib import Path
from typing import Dict, List, Tuple

from dotenv import load_dotenv
import torch
import torch._dynamo
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

from src.utils import (
    output_decoder,
    read_prompt_template,
    print_label_stats,
    get_unique_symbols,
)
from src.substitute_tokens import (
    create_token_substitution,
    apply_token_mapping,
    create_modified_prompt,
    create_formal_language_prompt,
)

# Let TorchDynamo fall back to eager execution instead of raising on compile errors.
torch._dynamo.config.suppress_errors = True
# Module-level seed constant; same value as the --seed CLI default.
seed = 42

# Pull variables from a local .env file (if any) into the process environment.
load_dotenv()
# Re-export HF_TOKEN only when it is actually set: assigning None into
# os.environ raises TypeError, which would crash this script at import time
# for anyone running without a token (e.g. on public models).
_hf_token = os.getenv("HF_TOKEN")
if _hf_token is not None:
    os.environ["HF_TOKEN"] = _hf_token
# Root directory of the datasets; None when DATASET_DIR is not set.
DATASET_PATH = os.getenv("DATASET_DIR")
# Module-level k; create_prompt and run_validation each use their own `k`
# rather than this constant.
k = 5


# ----------------------------- helpers -----------------------------

def load_data(language: str, type_of_set: str | None = None, n: int = 100) -> Tuple[List[str], List[int]]:
    """Load sequences and labels from DATASET_DIR/<language>/(data|labels).(train|test)."""
    if type_of_set is None:
        sys.exit(1)

    base_path = Path(f"{DATASET_PATH}/{language}/")
    sequences: List[str] = []
    labels: List[int] = []

    if type_of_set == "train":
        with open(base_path / "data.train", encoding="utf-8") as f:
            sequences = [line.strip() for line in f][:n]
        with open(base_path / "labels.train", encoding="utf-8") as f:
            labels = [int(line.strip()) for line in f][:n]

    elif type_of_set == "test":
        with open(base_path / "data.test", encoding="utf-8") as f:
            sequences = [line.strip() for line in f]
        with open(base_path / "labels.test", encoding="utf-8") as f:
            labels = [int(line.strip()) for line in f]
    else:
        raise ValueError(f"Unknown type_of_set: {type_of_set}")

    return sequences, labels


def create_prompt(
    examples: List[Tuple[str, int]],
    val_sequence: str,
    prompt_template: str,
    mapping: Dict[str, str] | None = None,
    encoding_technique: str = "none",
    k: int = 5,
) -> str:
    """Build the final prompt for one test sequence.

    Without a mapping, the examples are rendered with
    ``create_formal_language_prompt`` and the raw ``val_sequence`` replaces every
    ``test_string`` placeholder. With a mapping, the examples are rendered with
    ``create_modified_prompt`` and the sequence is re-encoded with
    ``apply_token_mapping`` before substitution. In that case the ``k`` value
    forwarded to both helpers is 1 for ``many_to_one`` and ``one_to_one``, this
    function's ``k`` for ``one_to_many``, and 0 for any other technique.

    The finished prompt is echoed to stdout before being returned.
    """
    k_for_technique = {"many_to_one": 1, "one_to_one": 1, "one_to_many": k}
    effective_k = k_for_technique.get(encoding_technique, 0)

    if mapping is None:
        filled_template = create_formal_language_prompt(prompt_template, examples)
        test_text = val_sequence
    else:
        filled_template = create_modified_prompt(
            prompt_template, examples, mapping, encoding_technique, effective_k
        )
        test_text = apply_token_mapping(
            val_sequence, mapping, encoding_technique=encoding_technique, k=effective_k
        )

    prompt = filled_template.replace("test_string", test_text)
    print(f"prompt:\n{prompt}\n---")
    return prompt


def compute_accuracy(raw_results: List[Dict]) -> float:
    """Return the fraction of entries whose ``model_output`` equals ``true_label``.

    Both keys are read with ``dict.get``, so an entry missing both counts as a
    match (None == None). An empty input yields 0.0 instead of dividing by zero.
    """
    entries = list(raw_results)
    if not entries:
        return 0.0
    hits = sum(entry.get("model_output") == entry.get("true_label") for entry in entries)
    return hits / len(entries)


def build_chat_inputs(
    tokenizer: AutoTokenizer,
    raw_prompts: List[str],
    system_prompt: str | None = None,
) -> List[str]:
    """Render each prompt through the tokenizer's chat template.

    Each prompt becomes a single user turn. It is preceded by a system turn
    when ``system_prompt`` is non-empty. The template is rendered to text
    (``tokenize=False``) with ``add_generation_prompt=True``, and the rendered
    strings are returned in the same order as ``raw_prompts``.
    """

    def render(user_text: str) -> str:
        conversation = [{"role": "user", "content": user_text}]
        if system_prompt:
            conversation.insert(0, {"role": "system", "content": system_prompt})
        return tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True,  # open the assistant turn for the model to fill
        )

    return [render(prompt) for prompt in raw_prompts]


def collect_stop_ids(tokenizer: AutoTokenizer) -> List[int]:
    """Return the token ids that should terminate generation.

    Includes ``tokenizer.eos_token_id`` when the tokenizer defines one. It also
    includes the id of the Llama-3 end-of-turn token ``<|eot_id|>``, but only
    when that token really exists in the tokenizer's vocabulary. Duplicates are
    removed and first-seen order is kept.
    """
    stop_ids: List[int] = []
    eos_id = getattr(tokenizer, "eos_token_id", None)
    if eos_id is not None:
        stop_ids.append(eos_id)

    eot_token = "<|eot_id|>"
    try:
        eot_id = tokenizer.convert_tokens_to_ids(eot_token)
        # Hugging Face tokenizers typically map an unknown token to
        # unk_token_id (or None) instead of raising. Round-tripping the id
        # prevents non-Llama-3 models from getting <unk> as a stop token.
        if (
            isinstance(eot_id, int)
            and eot_id >= 0
            and tokenizer.convert_ids_to_tokens(eot_id) == eot_token
        ):
            stop_ids.append(eot_id)
    except Exception as exc:
        # Best effort: custom (trust_remote_code) tokenizers may not support
        # this lookup. Warn instead of failing, but do not hide it silently.
        print(f"Warning: could not look up {eot_token!r} token id: {exc}", file=sys.stderr)

    # dict preserves insertion order, so this dedupes while keeping order.
    return list(dict.fromkeys(stop_ids))


# ----------------------------- core -----------------------------

def _decode_label(out_text: str, prompt_template: str) -> int:
    """Turn a raw completion into an int label; -1 if it is empty or cannot be parsed."""
    if not out_text:
        return -1
    try:
        return int(output_decoder(out_text, prompt_template))
    except (TypeError, ValueError):
        # Score an unparseable answer as wrong instead of aborting the batch,
        # which would discard every generation before anything is saved.
        return -1


def run_validation(
    examples: List[Tuple[str, int]],
    val_sequences: List[str],
    val_labels: List[int],
    model: str,
    prompt_template: str,
    mapping: Dict[str, str] | None = None,
    file_path: str = "default",
    encoding_technique: str = "none",
    lm: LLM | None = None,
    sampling_params: SamplingParams | None = None,
    tokenizer: AutoTokenizer | None = None,
    system_prompt: str | None = None,
    k: int = 2,
) -> None:
    """Score the not-yet-processed validation sequences and update the results file.

    Steps:
      1. Read the JSON results file at ``file_path``.
      2. Build a prompt (via ``create_prompt``) for each sequence whose text is
         not already in ``raw_results``.
      3. Wrap the prompts with the tokenizer's chat template and generate with ``lm``.
      4. Re-read the file and append one record per new sequence.
      5. Recompute ``validation_size`` and ``accuracy``, then write the file back.

    Args:
        examples: (sequence, label) pairs forwarded to ``create_prompt``.
        val_sequences: Sequences to score, paired position-wise with ``val_labels``.
        val_labels: Ground-truth labels stored as ``true_label``.
        model: Model name; used only in the log line.
        prompt_template: Forwarded to ``create_prompt`` and to ``output_decoder``.
        mapping: Optional token substitution forwarded to ``create_prompt``.
        file_path: JSON results file; must already exist.
        encoding_technique: Forwarded to ``create_prompt``.
        lm: vLLM engine used for generation.
        sampling_params: Passed to ``lm.generate``.
        tokenizer: Tokenizer whose chat template wraps the prompts.
        system_prompt: Optional system turn added to every chat.
        k: Forwarded to ``create_prompt`` (which uses it for ``one_to_many``).
            Defaults to 2, the value previously hard-coded here.

    Raises:
        ValueError: If there are sequences to process but ``tokenizer`` or ``lm`` is None.
    """
    print(f"Creating validation for model: {model}")

    with open(file_path, "r", encoding="utf-8") as file:
        results = json.load(file)

    # Skip sequences already recorded by a previous (partial) run.
    done_sequences = {entry["sequence"] for entry in results.get("raw_results", [])}
    pending: List[Tuple[Tuple[str, int], str]] = []
    for i, (seq, label) in enumerate(zip(val_sequences, val_labels)):
        if seq in done_sequences:
            print(f"Skipping sequence {i}: {seq}")
            continue
        prompt_text = create_prompt(examples, seq, prompt_template, mapping, encoding_technique, k)
        pending.append(((seq, label), prompt_text))

    if not pending:
        print("No new sequences to process.")
        return

    # Explicit checks (asserts vanish under `python -O`). They are placed after
    # the early return so a fully processed file still needs no model.
    if tokenizer is None:
        raise ValueError("tokenizer is required")
    if lm is None:
        raise ValueError("lm is required")

    chat_inputs = build_chat_inputs(tokenizer, [prompt for _, prompt in pending], system_prompt=system_prompt)
    responses = lm.generate(chat_inputs, sampling_params, use_tqdm=True)

    # Reload file (in case another process wrote meanwhile)
    with open(file_path, "r", encoding="utf-8") as file:
        results = json.load(file)
    raw_results = results.setdefault("raw_results", [])

    # vLLM returns outputs in the same order as the inputs; one output per request (n=1).
    for ((seq, true_label), _), resp in zip(pending, responses):
        out_text = resp.outputs[0].text if resp.outputs and resp.outputs[0].text is not None else ""
        raw_results.append(
            {
                "sequence": seq,
                "raw_response": out_text,
                "model_output": _decode_label(out_text, prompt_template),
                "true_label": true_label,
            }
        )

    results["validation_size"] = len(raw_results)
    results["accuracy"] = compute_accuracy(raw_results)

    # Write to a sibling temp file and atomically swap it in, so a crash
    # mid-write cannot leave a truncated results file behind.
    tmp_path = f"{file_path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as file:
        json.dump(results, file, indent=4, ensure_ascii=False)
    os.replace(tmp_path, file_path)


def _parse_args() -> argparse.Namespace:
    """Build the command-line parser for this script and return the parsed arguments."""
    parser = argparse.ArgumentParser(description="Batch experiment with LLMs on formal languages")
    parser.add_argument("--language", required=True, help="Formal language task name")
    parser.add_argument("--model", required=True, help="Model name")
    parser.add_argument("--sample_sizes", type=str, required=True, help="Comma-separated list of sample sizes")
    parser.add_argument("--validation_sets", type=str, required=True, help="Comma-separated list of validation sets (long,short)")
    parser.add_argument("--prompt_template", required=True, help="Type of prompt template (io_prompt|zsr_prompt)")
    parser.add_argument("--encoding_technique", required=True, help="Encoding technique")
    parser.add_argument("--system", default=None, help="Optional system prompt (prepended in chat template)")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
    # NOTE(review): the next three flags are parsed but never read in this script.
    parser.add_argument("--validation_size", type=int, help="Number of validation examples to test (default: all)")
    parser.add_argument("--min_long_validation_length", type=int, help="Minimum sequence length for long validation set")
    parser.add_argument("--balance_validation", action="store_true", help="Whether to ensure balanced labels in validation data")
    return parser.parse_args()


def _write_initial_results(
    file_path: str,
    args: argparse.Namespace,
    train_data: List[Tuple[str, int]],
    val_sequences: List[str],
    val_labels: List[int],
) -> None:
    """Create a fresh results file (accuracy -1, no raw results) for one experiment."""
    metrics_dict = {
        "accuracy": -1,
        "encoding_technique": args.encoding_technique,
        "prompt_template": args.prompt_template,
        "raw_results": [],
        "examples_used": list(train_data),
        "model": args.model,
        "validation_size": len(val_sequences),
        "validation_set": val_sequences,
        "label_distribution": {
            "training": {
                "label_0": sum(1 for _, label in train_data if label == 0),
                "label_1": sum(1 for _, label in train_data if label == 1),
            },
            "validation": {
                "label_0": sum(1 for label in val_labels if label == 0),
                "label_1": sum(1 for label in val_labels if label == 1),
            },
        },
    }
    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(metrics_dict, f, indent=4, ensure_ascii=False)
    print(f"File created: {file_path}")


def main():
    """Evaluate one model on one language for every (sample size, validation set) pair.

    Each pair writes to
    ``experiments/<model>/<prompt_template>/<encoding_technique>/<language>/<size>_<val_set>_<seed>.json``,
    with ``/`` in the model name replaced by ``_``. A pair whose file already has
    an ``accuracy`` other than -1 is skipped. Otherwise the file is created if
    missing, and ``run_validation`` processes the sequences not yet recorded.
    """
    args = _parse_args()

    # Load base prompt template TEXT (not the name)
    prompt_template_path = (
        "prompts/io_prompt_template.txt" if args.prompt_template == "io_prompt" else "prompts/zsr_prompt_template.txt"
    )
    base_prompt = read_prompt_template(prompt_template_path)

    # Tokenizer for chat formatting + stop ids
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
    stop_ids = collect_stop_ids(tokenizer)

    # vLLM engine, tensor-parallel over all visible GPUs (at least 1).
    llm = LLM(
        model=args.model,
        tensor_parallel_size=torch.cuda.device_count() or 1,
        distributed_executor_backend="mp",
        dtype=torch.float16,
        gpu_memory_utilization=0.95,
        enforce_eager=False,
        max_model_len=8192,           # ↓ cap context to fit KV cache
        disable_custom_all_reduce=True,
        trust_remote_code=True,
        enable_prefix_caching=True,
    )

    # Sampling: up to 1000 new tokens for zsr_prompt, a single token otherwise.
    sampling_params = SamplingParams(
        n=1,
        top_p=1.0,
        temperature=0.2,
        # Follow --seed, which also names the results file. Without this,
        # runs with different seeds would sample identically.
        seed=args.seed,
        max_tokens=1000 if args.prompt_template == "zsr_prompt" else 1,
        stop_token_ids=stop_ids if stop_ids else None,
    )

    # Data
    train_sequences, train_labels = load_data(args.language, "train")
    train_data = list(zip(train_sequences, train_labels))
    print_label_stats(train_data, "Full training data")

    print(f"Language: {args.language}")
    sample_sizes = [int(size) for size in args.sample_sizes.split(",")]
    validation_sets = args.validation_sets.split(",")
    sizes = [(s, v) for s in sample_sizes for v in validation_sets]
    print(f"Sample sizes: {sizes}")

    # NOTE(review): `size` and `val_set` only name the output file. Every run
    # uses all loaded training examples (load_data's default n=100) and the
    # same "test" split. Confirm whether examples should be limited to `size`
    # and whether `val_set` should select a different split.
    for size, val_set in sizes:
        print(f"Processing size {size} and validation set {val_set}")

        file_path = f"experiments/{args.model.replace('/', '_')}/{args.prompt_template}/{args.encoding_technique}/{args.language}/{size}_{val_set}_{args.seed}.json"
        os.makedirs(os.path.dirname(file_path), exist_ok=True)

        # If file exists & already computed, skip
        file_exists = os.path.exists(file_path)
        if file_exists:
            with open(file_path, "r", encoding="utf-8") as file:
                existing = json.load(file)
            if existing.get("accuracy", -1) != -1:
                print(
                    f"Skipping size {size}, validation set {val_set}, encoding {args.encoding_technique}: already processed"
                )
                continue

        # Read the test split once per experiment. It is used both to
        # initialise a new file and for the validation run itself.
        val_sequences_sampled, val_labels_sampled = load_data(args.language, "test")
        if not file_exists:
            _write_initial_results(file_path, args, train_data, val_sequences_sampled, val_labels_sampled)

        print(f"Train data size: {len(train_sequences)}")
        print(f"Validation data size: {len(val_sequences_sampled)}")

        unique_symbols = get_unique_symbols(train_data, val_sequences_sampled)
        print(f"Unique symbols: {unique_symbols}")

        mapping = create_token_substitution(
            encoding_technique=args.encoding_technique,
            symbols=unique_symbols,
            prompt_text=base_prompt,
            model_name=args.model,
            seed=args.seed,
            verbose=False,
        )

        # NOTE(review): run_validation receives the template NAME
        # (args.prompt_template). It forwards that name to create_prompt as the
        # template, while the loaded text (base_prompt) only reaches
        # create_token_substitution. Confirm which one the src helpers expect.
        run_validation(
            examples=list(train_data),
            val_sequences=val_sequences_sampled,
            val_labels=val_labels_sampled,
            model=args.model,
            prompt_template=args.prompt_template,
            mapping=mapping,
            file_path=file_path,
            encoding_technique=args.encoding_technique,
            lm=llm,
            sampling_params=sampling_params,
            tokenizer=tokenizer,
            system_prompt=args.system,
        )


# Script entry point: main() runs only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()
