import re
import os
import json
from tqdm import tqdm
from timeit import default_timer as timer
from datasets import load_dataset
from utils.commongen import complete_prompt
from utils.utils import generate_variants_commongen_ctrlg
import outlines
from outlines.types import Regex
from outlines.models import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM

# Expose only GPU 0 to CUDA for this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Hugging Face checkpoints the benchmark loop below iterates over.
model_names = [
    "ctrlg/gpt2-large_common-gen",
    "openai-community/gpt2-large",
]

# Generation budgets; each value is passed as `max_new_tokens`.
max_tokens = [
    32,
]

def build_concept_regex(variants, tokenizer):
    """Build a regex that requires one variant from every concept group.

    Each group in ``variants`` contributes a lookahead of the form
    ``(?=.*(v1|v2|...))``, so the groups may be satisfied in any order. The
    lookaheads are followed by ``.*`` and the tokenizer's EOS token.

    Args:
        variants: Iterable of groups, each an iterable of literal strings
            (the accepted surface forms of one concept).
        tokenizer: Object exposing an ``eos_token`` attribute. A ``None``
            EOS token results in no terminator being appended.

    Returns:
        The assembled pattern as a string.
    """
    # NOTE(review): lookahead support depends on the regex engine that
    # consumes this pattern -- confirm the constraint backend accepts it.
    lookaheads = "".join(
        "(?=.*(" + "|".join(re.escape(v) for v in variant_group) + "))"
        for variant_group in variants
    )

    # The EOS token is literal text and must be escaped: GPT-2's
    # "<|endoftext|>" contains "|", which would otherwise split the whole
    # pattern into alternatives that bypass the concept lookaheads.
    eos_token = tokenizer.eos_token or ""
    return lookaheads + ".*" + re.escape(eos_token)

def _group_consecutive_by_concepts(examples):
    """Split ``examples`` into runs of consecutive rows sharing a concept set."""
    groups = []
    previous_key = None
    for example in examples:
        # Compare as sets so concept order within a row does not matter.
        key = frozenset(example["concepts"])
        if not groups or key != previous_key:
            groups.append([])
            previous_key = key
        groups[-1].append(example)
    return groups


def _write_json(path, payload):
    """Write ``payload`` to ``path`` as JSON indented by four spaces."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=4)


# The validation split and its grouping depend on neither the model nor the
# token budget, so they are built once instead of once per configuration.
dataset = load_dataset("common_gen", split="validation")
dataset = dataset.add_column("idx", list(range(len(dataset))))
grouped_data = _group_consecutive_by_concepts(dataset)

for model_name in model_names:
    # Load each checkpoint once and reuse it for every token budget.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    hf_model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda:0")
    model = transformers.Transformers(hf_model, tokenizer)

    for max_token in max_tokens:
        print(f"\n=== Processing model: {model_name} max new tokens: {max_token} ===")

        results_dir = f"results_outlines_commongen/{model_name.split('/')[-1]}_max_{max_token}"
        os.makedirs(results_dir, exist_ok=True)
        references_path = os.path.join(results_dir, "references.json")
        candidates_path = os.path.join(results_dir, "candidates.json")
        times_path = os.path.join(results_dir, "times.json")

        references = []
        candidates = []
        times = []

        for group in tqdm(grouped_data, desc="Processing groups"):
            # One generation per concept set; the first row represents the group.
            example = group[0]
            concepts = example["concepts"]
            prompt = complete_prompt(concepts)
            variants = generate_variants_commongen_ctrlg(concepts)
            concept_pattern = build_concept_regex(variants, tokenizer)
            generator = outlines.Generator(model, Regex(concept_pattern))

            try:
                # Only the generation call itself is timed.
                start = timer()
                output = generator(prompt, max_new_tokens=max_token)
                elapsed = timer() - start
            except Exception as e:
                # Best-effort run: report the failing concept set and skip it,
                # so no reference/candidate/time entry is recorded for it.
                print(f"Error {concepts}: {e}")
                continue

            references.extend(
                {
                    "concept_set_idx": ex["idx"],
                    "concepts": ex["concepts"],
                    "target": ex["target"],
                }
                for ex in group
            )
            candidates.append({
                "concept_set_idx": example["idx"],
                "group_size": len(group),
                "concepts": concepts,
                "sentence": output,
            })
            times.append({
                "concept_set_idx": example["idx"],
                "time": elapsed,
            })

        _write_json(references_path, references)
        _write_json(candidates_path, candidates)
        _write_json(times_path, times)

        print(f"Saved results to: {results_dir}")