import os
import re
import json
from tqdm import tqdm
from timeit import default_timer as timer
from guidance import models, gen
from utils.utils import generate_variants_commongen_ctrlg
from utils.ordered_commongen import complete_prompt_ordered_commongen
from transformers import AutoTokenizer

# Restrict CUDA to device index 3 for this process.
# NOTE(review): this unconditionally overwrites any CUDA_VISIBLE_DEVICES value
# already present in the environment, and presumably has to run before CUDA is
# initialised by the model libraries — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

# Model identifiers to evaluate; each one is handed to models.Transformers and
# AutoTokenizer.from_pretrained in the main loop below.
model_names = ["meta-llama/Llama-3.1-8B"]
# Generation budgets to sweep; each value is passed to gen(max_tokens=...) and
# also becomes part of the results directory name.
max_tokens = [32, 64, 100]

def build_ordered_concept_regex(variants, tokenizer):
    """Build a regex requiring one variant of every concept, in order.

    Each entry of ``variants`` becomes an alternation group of its escaped
    (literal) strings. The groups are joined with ``.*`` so arbitrary text may
    appear before, between and after them, and the tokenizer's EOS token is
    appended as a literal terminator.

    Args:
        variants: Iterable of iterables of strings; one inner iterable per
            concept, listing the acceptable surface forms of that concept.
        tokenizer: Object exposing an ``eos_token`` attribute. When the
            attribute is ``None`` or empty, no terminator is appended.

    Returns:
        The regex pattern as a string.
    """
    group_patterns = [
        "(" + "|".join(re.escape(variant) for variant in variant_group) + ")"
        for variant_group in variants
    ]

    # The EOS token has to be escaped like any other literal: tokens such as
    # "<|end_of_text|>" contain "|", which would otherwise split the whole
    # pattern into top-level alternatives and let strings without any concept
    # (e.g. just ">") satisfy the constraint.
    eos_token = tokenizer.eos_token
    eos_pattern = re.escape(eos_token) if eos_token else ""

    return ".*" + ".*".join(group_patterns) + ".*" + eos_pattern

# Load the evaluation set once; it is shared by every model / token-budget run.
# JSON text is UTF-8, so decode explicitly instead of relying on the platform
# default encoding.
with open("../data/Ordered CommonGen/ordered_commongen.json", 'r', encoding="utf-8") as f:
    dataset = json.load(f)

for model_name in model_names:
    # Loading the weights and the tokenizer is expensive and does not depend on
    # the token budget, so do it once per model instead of once per budget.
    lm = models.Transformers(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    for max_token in max_tokens:
        print(f"\n=== Processing model: {model_name} ===")

        # One output directory per (model, token budget) combination.
        results_dir = f"guidance_ordered/{model_name.split('/')[-1]}_max_{max_token}"
        os.makedirs(results_dir, exist_ok=True)
        references_path = os.path.join(results_dir, "references.json")
        candidates_path = os.path.join(results_dir, "candidates.json")
        times_path = os.path.join(results_dir, "times.json")

        references = []
        candidates = []
        times = []

        for example in tqdm(dataset, desc="Processing ordered examples"):
            concepts = example["concepts"]
            prompt = complete_prompt_ordered_commongen(concepts)
            variants = generate_variants_commongen_ctrlg(concepts)
            concept_pattern = build_ordered_concept_regex(variants, tokenizer=tokenizer)

            # Best effort: a failure on one example is reported and skipped so
            # the rest of the sweep still runs. The three result lists are only
            # appended to together, so they stay aligned.
            try:
                # Time only the constrained generation call itself.
                start = timer()
                output = lm + f"{prompt}{gen(max_tokens=max_token, regex=concept_pattern)}"
                end = timer()

                # Drop the prompt prefix to keep only the continuation.
                # NOTE(review): assumes str(output) starts with the prompt
                # verbatim, and the continuation may still end with the EOS
                # token text required by the regex — confirm.
                full_text = str(output)
                generated_text = full_text[len(prompt):].strip()

                references.append({
                    "concept_set_idx": example["id"],
                    "concepts": example["concepts"],
                    "target": example["targets"]
                })

                candidates.append({
                    "concept_set_idx": example["id"],
                    "concepts": example["concepts"],
                    "sentence": generated_text
                })

                times.append({
                    "concept_set_idx": example["id"],
                    "time": end - start
                })

            except Exception as e:
                print(f"Error {concepts}: {e}")

        # Persist the three aligned result lists for this configuration.
        for out_path, payload in (
            (references_path, references),
            (candidates_path, candidates),
            (times_path, times),
        ):
            with open(out_path, "w", encoding="utf-8") as f:
                json.dump(payload, f, indent=4)

        print(f"Saved results to: {results_dir}")