import os
import json
import torch
import random
import numpy as np
import warnings
import traceback
from tqdm import tqdm
from time import time
from argparse import ArgumentParser
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

import utils.ctrlg as ctrlg
from utils.commongen import complete_prompt
from utils.utils import generate_variants_commongen_ctrlg

warnings.filterwarnings("ignore")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def str_to_bool(v):
    """Interpret a command-line style yes/no value as a ``bool``.

    A ``bool`` is returned unchanged. Strings are compared case-insensitively
    against common truthy ("yes", "true", "t", "y", "1") and falsy
    ("no", "false", "f", "n", "0") spellings.

    Raises:
        ValueError: if ``v`` is a string in neither set.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise ValueError(f"Invalid boolean string: {v}")

def _seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _make_reference(split, row_idx, concepts):
    """Build one reference record pointing at row ``row_idx`` of ``split``."""
    return {
        "concept_set_idx": row_idx,
        "concepts": concepts,
        "target": split[row_idx]["target"],
    }


def _collect_references(split, concept_idx, concepts, all_indices):
    """Return references for ``concept_idx`` plus adjacent rows with the same concept set.

    Starting from ``concept_idx``, walks left and right while the neighbouring
    rows have the same (unordered) concept set. Every such row becomes an
    extra reference and is removed from ``all_indices`` so it is not drawn as
    a new sample later. Neighbour references reuse the anchor's ``concepts``
    list.

    NOTE(review): assumes rows sharing a concept set are stored contiguously
    in the split (the same assumption the old fixed +/-2 window made) —
    confirm against the dataset.

    Args:
        split: Indexable dataset split whose rows have "concepts" and "target".
        concept_idx: Row index of the sampled anchor.
        concepts: The anchor row's concept list.
        all_indices: Mutable list of still-drawable row indices; updated in place.

    Returns:
        list[dict]: Anchor reference first, then neighbours in ascending row order.
    """
    concepts_set = set(concepts)
    num_rows = len(split)

    def same_concepts(row_idx):
        return 0 <= row_idx < num_rows and set(split[row_idx]["concepts"]) == concepts_set

    first = concept_idx
    while same_concepts(first - 1):
        first -= 1
    last = concept_idx
    while same_concepts(last + 1):
        last += 1

    references = [_make_reference(split, concept_idx, concepts)]
    for row_idx in range(first, last + 1):
        if row_idx == concept_idx:
            continue
        references.append(_make_reference(split, row_idx, concepts))
        if row_idx in all_indices:
            all_indices.remove(row_idx)
    return references


def _generate_candidate(base_model, tokenizer, hmm_model, concepts, beam_size,
                        min_new_tokens, max_new_tokens, prefix_ids, suffix_ids,
                        device):
    """Generate one constrained sentence for ``concepts`` and time the two phases.

    Builds one Aho-Corasick automaton per keyphrase group returned by
    ``generate_variants_commongen_ctrlg`` and combines them with
    ``ctrlg.DFA_prod(..., mode='intersection')`` (presumably: every group must
    appear in the output). It then runs beam search guided by
    ``ctrlg.ConstraintLogitsProcessor`` and keeps the top sequence as ranked
    by ``ctrlg.rank_generated_ids``.

    Returns:
        tuple: ``(sentence, automata_seconds, generation_seconds)``.
    """
    vocab_size = hmm_model.vocab_size
    length_penalty = 0.2

    automata_start_time = time()
    ac_builder = ctrlg.AhoCorasickBuilder(vocab_size)
    dfa_graphs = []
    for keyphrase in generate_variants_commongen_ctrlg(concepts):
        patterns = [tokenizer.encode(x) for x in keyphrase]
        dfa_graphs.append(ac_builder.build(patterns))
    dfa_graph = ctrlg.DFA_prod(dfa_graphs, mode='intersection')
    dfa_model = ctrlg.DFAModel(dfa_graph, vocab_size).to(device)
    generation_start_time = time()

    concepts_str = "\"{}\"".format(", ".join(concepts))
    prompt_ids = tokenizer.encode(complete_prompt(concepts_str))
    constraint_logits_processor = ctrlg.ConstraintLogitsProcessor(
        hmm_model, dfa_model,
        min_new_tokens, max_new_tokens,
        prompt_ids, prefix_ids=prefix_ids, suffix_ids=suffix_ids
    )
    # The HMM batch is set to match the number of beams passed to generate().
    constraint_logits_processor.hmm_batch_size = beam_size

    input_ids = torch.tensor([prompt_ids], device=device)
    outputs = base_model.generate(
        input_ids=input_ids,
        do_sample=False,
        length_penalty=length_penalty,
        num_beams=beam_size,
        num_return_sequences=beam_size,
        min_new_tokens=min_new_tokens,
        max_new_tokens=max_new_tokens,
        logits_processor=LogitsProcessorList([constraint_logits_processor]),
        pad_token_id=tokenizer.eos_token_id,
    )

    generated_ids = ctrlg.extract_generated_ids(
        outputs.tolist(), prompt_ids, suffix_ids, hmm_model.eos_token_id)
    generated_ids = ctrlg.rank_generated_ids(
        base_model, generated_ids, prompt_ids, suffix_ids, length_penalty=length_penalty)
    # Raises IndexError if no candidate survived; the caller logs and skips it.
    sentence = tokenizer.decode(list(generated_ids[0]), skip_special_tokens=True)
    end_time = time()

    return (sentence,
            generation_start_time - automata_start_time,
            end_time - generation_start_time)


def _write_results(output_dir, out_prefix, references, candidates, total_time, sample_times):
    """Write the references, candidates and timing JSON files for one beam size."""
    # Creating times/ also creates output_dir. times/ must exist before the
    # timing file below is opened.
    os.makedirs(os.path.join(output_dir, "times"), exist_ok=True)
    with open(os.path.join(output_dir, f"references_{out_prefix}.json"), 'w') as f:
        json.dump(references, f, indent=4)
    with open(os.path.join(output_dir, f"candidates_{out_prefix}.json"), 'w') as f:
        json.dump(candidates, f, indent=4)
    with open(os.path.join(output_dir, "times", f"times_{out_prefix}.json"), 'w') as f:
        json.dump({
            "total_time": total_time,
            "sample_times": sample_times
        }, f, indent=4)


def main(args):
    """Run CTRL-G constrained generation on CommonGen for each requested beam size.

    For every beam size in ``args.beam_sizes``, this draws up to
    ``args.num_test`` rows from ``args.data_partition`` without replacement
    and generates one constrained sentence per row. It then writes
    ``references_*.json``, ``candidates_*.json`` and ``times/times_*.json``
    under ``args.output_dir``.

    Args:
        args: Parsed CLI namespace (see the ``__main__`` block) providing
            ``cuda_device``, ``base_model_path``, ``hmm_model_path``,
            ``data_partition``, ``seed``, ``min_new_tokens``,
            ``max_new_tokens``, ``num_test``, ``beam_sizes`` and
            ``output_dir``.
    """
    # Restrict CUDA to the requested GPU(s). This must happen before CUDA is
    # initialised. Afterwards the selected GPU is the first visible device, so
    # it is addressed as index 0, not by its physical id.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_device
    use_cuda = torch.cuda.is_available()
    device = "cuda:0" if use_cuda else "cpu"

    dataset = load_dataset("allenai/common_gen")
    data_partition = args.data_partition
    split = dataset[data_partition]
    seed = args.seed
    _seed_everything(seed)

    tokenizer = AutoTokenizer.from_pretrained(args.base_model_path)
    base_model = AutoModelForCausalLM.from_pretrained(args.base_model_path)
    if use_cuda:
        # fp16 only on GPU; CPU half-precision support is limited, so keep fp32 there.
        base_model = base_model.half()
    base_model = base_model.to(device)
    base_model.eval()
    hmm_model = ctrlg.HMM.from_pretrained(args.hmm_model_path).to(device)

    # Suffix token ids passed to the constraint processor (presumably the text
    # the generation must be able to end with) — confirm against ctrlg docs.
    suffix_ids = tokenizer.encode(".<|endoftext|>")
    prefix_ids = tokenizer.encode("")

    for beam_size in args.beam_sizes:
        # Re-seed per beam size so every run draws the same rows, regardless
        # of which other beam sizes were requested in this invocation.
        _seed_everything(seed)
        start_run_time = time()

        references = []
        candidates = []
        generation_times = []
        all_indices = list(range(len(split)))

        for _ in tqdm(range(args.num_test), desc=f"Beam size {beam_size}"):
            if not all_indices:
                print("No undrawn rows left; stopping early.")
                break
            try:
                concept_idx = int(np.random.choice(all_indices))
                all_indices.remove(concept_idx)
                concepts = split[concept_idx]["concepts"]
                sample_refs = _collect_references(split, concept_idx, concepts, all_indices)

                sample_start_time = time()
                candidate, automata_time, generation_time = _generate_candidate(
                    base_model, tokenizer, hmm_model, concepts, beam_size,
                    args.min_new_tokens, args.max_new_tokens,
                    prefix_ids, suffix_ids, device,
                )
                final_time = time() - sample_start_time
            except Exception as e:
                # Best effort: one failing sample should not abort the run.
                print(f"Error during generation: {str(e)}")
                traceback.print_exc()
                continue

            # Record references only for samples that produced a candidate, so
            # the references and candidates files describe the same samples.
            references.extend(sample_refs)
            generation_times.append({
                "concept_set_idx": concept_idx,
                "final_time": final_time,
                "automata_time": automata_time,
                "generation_time": generation_time,
            })
            candidates.append({
                "concept_set_idx": concept_idx,
                "concepts": concepts,
                "sentence": candidate
            })

        total_time = time() - start_run_time
        out_prefix = f"{data_partition}_beam_size_{beam_size}_seed_{seed}"
        _write_results(args.output_dir, out_prefix, references, candidates,
                       total_time, generation_times)

    print("Done.")

if __name__ == "__main__":
    # Command-line entry point. Defaults point at the ctrlg GPT-2-large
    # CommonGen checkpoints; alternatives are noted next to each option.
    cli = ArgumentParser()
    cli.add_argument('--cuda_device', type=str, default='0')
    # Unsupervised alternative: "openai-community/gpt2-large".
    cli.add_argument('--base_model_path', type=str, default='ctrlg/gpt2-large_common-gen')
    # Alternative HMM: 'ctrlg/hmm_gpt2-large_common-gen_4096'.
    cli.add_argument('--hmm_model_path', type=str, default='ctrlg/hmm_gpt2-large_common-gen_32768')
    cli.add_argument('--data_partition', type=str, default='validation')
    cli.add_argument('--num_test', type=int, default=993)
    cli.add_argument('--min_new_tokens', type=int, default=3)
    cli.add_argument('--max_new_tokens', type=int, default=32)
    cli.add_argument('--beam_sizes', nargs='+', type=int, default=[2, 4, 8, 16, 32, 64, 128])
    cli.add_argument('--seed', type=int, default=42)
    cli.add_argument('--output_dir', type=str, default='LLM/results/ctrlg')

    main(cli.parse_args())