import os
import re
import json
from tqdm import tqdm
from timeit import default_timer as timer
from guidance import models, gen
from datasets import load_dataset
from utils.commongen import complete_prompt
from utils.utils import generate_variants_commongen_ctrlg
from itertools import permutations
from transformers import AutoTokenizer

# Expose only the first GPU to CUDA-based libraries used by this script.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Checkpoints to benchmark; each one is combined with every token budget below.
model_names = [
    "ctrlg/gpt2-large_common-gen",
    "openai-community/gpt2-large",
]

# Generation budgets passed as ``max_tokens`` to the constrained generator.
max_tokens = [16, 32, 64, 100]

def build_permutation_regex(variants, tokenizer):
    """Build a regex requiring one variant of every concept, in any order.

    Each element of ``variants`` lists the alternative surface forms of one
    concept. The returned pattern is an alternation over every ordering of
    the concepts; within an ordering, consecutive concepts are separated by
    a lazy ``.*?`` gap. The pattern has no leading or trailing wildcard, so
    a full match starts with a concept variant and ends with a concept
    variant immediately followed by the tokenizer's end-of-sequence text.

    Args:
        variants: Iterable of iterables of strings, one inner iterable per
            concept. Every string is matched literally.
        tokenizer: Object exposing an ``eos_token`` string attribute.

    Returns:
        The regex pattern as a string.

    Note:
        The number of alternatives grows factorially with the number of
        concepts (``n`` concepts produce ``n!`` orderings).
    """
    concept_groups = [
        "(" + "|".join(re.escape(form) for form in group) + ")"
        for group in variants
    ]
    # One alternative per possible ordering of the concept groups.
    orderings = [".*?".join(order) for order in permutations(concept_groups)]
    # The EOS text must be escaped: e.g. GPT-2's "<|endoftext|>" contains "|",
    # which unescaped would split the whole pattern into the alternatives
    # "(...)<", "endoftext" and ">" and let the concept constraint be skipped.
    return "(" + "|".join(orderings) + ")" + re.escape(tokenizer.eos_token)

# Benchmark driver: for every (model, max_tokens) configuration, generate one
# regex-constrained sentence per CommonGen concept set and dump the references,
# the generated candidates and the per-example generation times as JSON.

# The validation split and its grouping depend on neither the model nor the
# token budget, so they are prepared once and shared by all configurations.
dataset = load_dataset("common_gen", split="validation")
dataset = dataset.add_column("idx", list(range(len(dataset))))

# Collapse consecutive rows that share the same concept set (order-insensitive)
# into one group, so each concept set is generated for only once.
grouped_data = []
previous_concepts = None
for row in dataset:
    row_concepts = set(row["concepts"])
    if not grouped_data or row_concepts != previous_concepts:
        grouped_data.append([])
        previous_concepts = row_concepts
    grouped_data[-1].append(row)

for model_name in model_names:
    # Loading the checkpoint is expensive and independent of the token budget,
    # so it happens once per model rather than once per configuration.
    gpt2 = models.Transformers(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    for max_token in max_tokens:

        print(f"\n=== Processing model: {model_name} max new tokens: {max_token} ===")

        results_dir = f"results_guidance/{model_name.split('/')[-1]}_max_{max_token}"
        os.makedirs(results_dir, exist_ok=True)
        references_path = os.path.join(results_dir, "references.json")
        candidates_path = os.path.join(results_dir, "candidates.json")
        times_path = os.path.join(results_dir, "times.json")

        references = []
        candidates = []
        times = []

        for group in tqdm(grouped_data, desc="Processing groups"):
            # The first row of a group supplies the concepts and the index
            # that identify the whole concept set in the output files.
            example = group[0]
            concepts = example["concepts"]
            prompt = complete_prompt(concepts)
            variants = generate_variants_commongen_ctrlg(concepts)
            concept_pattern = build_permutation_regex(variants, tokenizer)

            # Only the generation itself is best-effort: a failing concept
            # set is reported and skipped without leaving partial records.
            try:
                start = timer()
                output = gpt2 + f"{prompt}{gen(regex=concept_pattern, max_tokens=max_token)}"
                end = timer()
                full_text = str(output)
            except Exception as e:
                print(f"Error {concepts}: {e}")
                continue

            # Drop the prompt prefix so only the newly generated text remains.
            # NOTE(review): any end-of-sequence text emitted by the model is
            # kept in the sentence - confirm downstream scoring strips it.
            generated_text = full_text[len(prompt):].strip()

            references.extend(
                {
                    "concept_set_idx": ex["idx"],
                    "concepts": ex["concepts"],
                    "target": ex["target"],
                }
                for ex in group
            )

            candidates.append({
                "concept_set_idx": example["idx"],
                "group_size": len(group),
                "concepts": concepts,
                "sentence": generated_text
            })

            times.append({
                "concept_set_idx": example["idx"],
                "time": end - start
            })

        for output_path, payload in (
            (references_path, references),
            (candidates_path, candidates),
            (times_path, times),
        ):
            with open(output_path, "w", encoding="utf-8") as f:
                json.dump(payload, f, indent=4)

        print(f"Saved results to: {results_dir}")