import json
import os

import torch
from tqdm import tqdm

import utils
import lib
import mcmc

def load_gad_tasks(split, subset):
    """Load the tasks of one GAD dataset split from its JSON Lines file.

    Reads ``datasets/GAD-dataset/{split}.jsonl`` (relative to the current
    working directory) and parses every non-blank line as one JSON value.

    Args:
        split: Name of the split; one of ``"SLIA"``, ``"CP"`` or ``"BV4"``.
        subset: Optional iterable of integer positions selecting which tasks
            to return, in the order given. ``None`` returns every task.

    Returns:
        A list of the parsed tasks, in file order when ``subset`` is ``None``
        and in ``subset`` order otherwise.

    Raises:
        AssertionError: If ``split`` is not a known split name (only while
            assertions are enabled).
        FileNotFoundError: If the split file does not exist.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        IndexError: If ``subset`` holds a position outside the task list.
    """
    assert split in ["SLIA", "CP", "BV4"], f"unknown split: {split!r}"
    tasks_path = f"datasets/GAD-dataset/{split}.jsonl"
    # JSON text is UTF-8; do not depend on the platform's locale default.
    with open(tasks_path, "r", encoding="utf-8") as f:
        # Blank lines (e.g. an extra newline at the end of the file) are not
        # records; json.loads would raise on them, so skip them.
        tasks = [json.loads(line) for line in f if line.strip()]
    if subset is not None:
        tasks = [tasks[i] for i in subset]
    return tasks

def run_mcmc_gad_tasks(
    split,
    subset,
    *,
    model_id="meta-llama/Llama-3.1-8B-Instruct",
    root_log_dir="gad_dataset_runs",
    n_samples=100,
    n_steps=11,
    max_new_tokens=128,
    propose_styles=("restart", "priority", "prefix"),
):
    """Run MCMC sampling for every selected task of a GAD split.

    For each task, and for each proposal style, the task's grammar is set on
    the model, an ``mcmc.MCMC`` runner is built for the task prompt, and
    ``get_samples`` is called on it.

    Args:
        split: Split name forwarded to ``load_gad_tasks``.
        subset: Optional task positions forwarded to ``load_gad_tasks``;
            ``None`` runs every task in the split.
        model_id: Model identifier passed to ``lib.ConstrainedModel``.
        root_log_dir: Parent directory under which a fresh
            ``{timestamp}-{split}`` directory is created for this run.
        n_samples: Forwarded to ``MCMC.get_samples``.
        n_steps: Forwarded to ``MCMC.get_samples``.
        max_new_tokens: Forwarded to ``MCMC.get_samples``.
        propose_styles: Iterable of values passed one at a time as
            ``propose_style`` to ``mcmc.MCMC``.

    Returns:
        None. Whatever the runner records is handled by ``mcmc.MCMC``, which
        receives the per-run directory as ``root_log_dir``.
    """
    # Load the tasks first: a missing dataset file or a bad subset index then
    # fails before the model is constructed and before a run directory is
    # created.
    tasks = load_gad_tasks(split, subset)
    print(f"Loaded {len(tasks)} tasks")
    print([task["id"] for task in tasks])

    # The second positional argument is None here and is replaced per task by
    # _set_grammar_constraint below — presumably the initial grammar; confirm
    # against lib.ConstrainedModel.
    model = lib.ConstrainedModel(model_id, None, torch_dtype=torch.bfloat16)

    split_log_dir = f"{root_log_dir}/{utils.timestamp()}-{split}"
    # make sure the directory exists
    os.makedirs(split_log_dir, exist_ok=True)

    for task in tqdm(tasks):
        task_id = task["id"]
        task_prompt = task["prompt"]
        task_grammar = task["grammar"]
        print(f"Task ID: {task_id}")

        # NOTE(review): the grammar is set here and again at the top of every
        # proposal-style iteration. This first call looks redundant, but it is
        # kept because the semantics of _set_grammar_constraint are not
        # visible from this file — confirm before removing.
        model._set_grammar_constraint(task_grammar)
        for propose_style in propose_styles:
            print(f"Task ID: {task_id}, Propose Style: {propose_style}")

            model._set_grammar_constraint(task_grammar)
            mcmc_runner = mcmc.MCMC(
                model=model,
                prompt=task_prompt,
                propose_style=propose_style,
                name_prefix=task_id,
                root_log_dir=split_log_dir,
            )
            mcmc_runner.get_samples(
                n_samples=n_samples,
                n_steps=n_steps,
                max_new_tokens=max_new_tokens,
            )


if __name__ == "__main__":
    from argparse import ArgumentParser

    # Command-line entry point: choose a split and, optionally, which task
    # positions within it to run.
    cli = ArgumentParser()
    cli.add_argument("--split", required=True, choices=["SLIA", "CP", "BV4"])
    cli.add_argument("--subset", default=None)
    cli_args = cli.parse_args()

    # --subset is a comma-separated list of integers, e.g. "0,3,7"; when it
    # is omitted, None is passed through and no subsetting is applied.
    if cli_args.subset is None:
        subset = None
    else:
        subset = list(map(int, cli_args.subset.split(",")))
    print(f"Subset: {subset}")

    run_mcmc_gad_tasks(cli_args.split, subset)