import os
import re
import json
import outlines
from tqdm import tqdm
from outlines.types import Regex
from outlines.models import transformers
from timeit import default_timer as timer
from utils.utils import generate_variants_commongen_ctrlg
from transformers import AutoTokenizer, AutoModelForCausalLM
from utils.ordered_commongen import complete_prompt_ordered_commongen

# Expose only physical GPU 2 to CUDA-aware libraries in this process.
# NOTE(review): this is set after the framework imports above; it presumably
# still takes effect because CUDA is initialised lazily -- confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

# Hugging Face model identifiers to evaluate; each is passed to `from_pretrained`.
model_names = ["meta-llama/Llama-3.1-8B"]
# Generation budgets to evaluate; each value is passed as `max_new_tokens`.
max_tokens = [16]

def build_ordered_concept_regex(variants, tokenizer):
    """Build a regex requiring one variant of each concept, in order.

    Every inner list of ``variants`` becomes one alternation group of its
    literal (escaped) strings. The groups are joined with ``.*`` so that
    arbitrary text may appear before, between and after them, and the
    tokenizer's end-of-sequence token is appended as a literal suffix.

    Args:
        variants: Iterable of iterables of strings; one inner iterable per
            concept, listing the acceptable surface forms of that concept.
        tokenizer: Object exposing an ``eos_token`` attribute (a string, or
            ``None`` / empty when the tokenizer defines no EOS token).

    Returns:
        The regex pattern as a string.
    """
    group_patterns = [
        "(" + "|".join(re.escape(v) for v in variant_group) + ")"
        for variant_group in variants
    ]

    # The EOS token must be matched literally. Tokens such as
    # "<|end_of_text|>" contain "|", which, left unescaped, would split the
    # whole pattern into top-level alternatives and void the constraint.
    eos_token = tokenizer.eos_token
    eos_pattern = re.escape(eos_token) if eos_token else ""

    return ".*" + ".*".join(group_patterns) + ".*" + eos_pattern

# Load the evaluation set once; every model / token-budget combination reuses it.
with open("../data/Ordered CommonGen/ordered_commongen.json", "r", encoding="utf-8") as f:
    dataset = json.load(f)

for model_name in model_names:
    print(f"\n=== Processing model: {model_name} ===")

    # The tokenizer and weights do not depend on the token budget, so load
    # them once per model rather than once per (model, max_token) pair.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # CUDA_VISIBLE_DEVICES (set at the top of this file) exposes a single GPU,
    # which is ordinal 0 inside this process; "cuda:1" would be an invalid
    # device ordinal.
    hf_model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda:0")
    model = transformers.Transformers(hf_model, tokenizer)

    for max_token in max_tokens:
        # NOTE(review): the output directory does not encode `max_token`, so
        # adding a second budget to `max_tokens` would overwrite these files.
        results_dir = f"results_outlines_ordered/{model_name.split('/')[-1]}"
        os.makedirs(results_dir, exist_ok=True)
        references_path = os.path.join(results_dir, "references.json")
        candidates_path = os.path.join(results_dir, "candidates.json")
        times_path = os.path.join(results_dir, "times.json")

        references = []
        candidates = []
        times = []

        for example in tqdm(dataset, desc="Processing examples"):
            concepts = example["concepts"]
            try:
                example_id = example["id"]
                targets = example["targets"]

                # Prompt and constraint construction are inside the
                # best-effort `try` so one bad example cannot abort the run
                # before any results are written, but they stay outside the
                # timed region: only the generation call is measured.
                prompt = complete_prompt_ordered_commongen(concepts)
                variants = generate_variants_commongen_ctrlg(concepts)
                concept_pattern = build_ordered_concept_regex(variants, tokenizer=tokenizer)
                generator = outlines.Generator(model, Regex(concept_pattern))

                start = timer()
                output = generator(prompt, max_new_tokens=max_token)
                elapsed = timer() - start
            except Exception as exc:
                # Log and skip this example; remaining examples still run.
                print(f"Error {concepts}: {exc}")
                continue

            # Record all three entries together so the output files stay
            # aligned on `concept_set_idx`.
            references.append({
                "concept_set_idx": example_id,
                "concepts": concepts,
                "target": targets,
            })
            candidates.append({
                "concept_set_idx": example_id,
                "concepts": concepts,
                "sentence": output,
            })
            times.append({
                "concept_set_idx": example_id,
                "time": elapsed,
            })

        with open(references_path, "w", encoding="utf-8") as f:
            json.dump(references, f, indent=4)

        with open(candidates_path, "w", encoding="utf-8") as f:
            json.dump(candidates, f, indent=4)

        with open(times_path, "w", encoding="utf-8") as f:
            json.dump(times, f, indent=4)

        print(f"Saved results to: {results_dir}")