import os
import warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
import argparse
warnings.filterwarnings("ignore")
from timeit import default_timer as timer
import random
import torch
import traceback
import numpy as np
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
import utils.logits_processors as lp
import utils.dfa_beam_search as dfa_beam_search

import json, traceback
from utils.utils import generate_variants_commongen_ctrlg
from utils.ordered_commongen import ordered_commongen_dfa, complete_prompt_ordered_commongen


# Hugging Face ids of the instruction-tuned checkpoints known to this script.
# Kept at module level so code that imports this name keeps working.
instruct_model_list = ["meta-llama/Llama-3.1-8B-Instruct","mistralai/Mistral-Nemo-Instruct-2407"]

def clean_model_name(model_name):
    """Return an identifier-friendly version of a model id.

    The last ``/``-separated component of ``model_name`` is kept (the model
    part of an ``org/model`` id, or the whole string when there is no ``/``),
    and every ``.`` and ``-`` is replaced by ``_``.

    Args:
        model_name: Model id such as ``"meta-llama/Llama-3.1-8B"``.

    Returns:
        The cleaned name, e.g. ``"Llama_3_1_8B"``.
    """
    # Always take the last component: index 0 would yield the organisation
    # ("meta_llama") for any "org/model" id, and index 1 would fail on ids
    # that contain no "/" at all.
    short_name = model_name.split("/")[-1]
    return short_name.replace(".", "_").replace("-", "_")

def _ensure_dir(path):
    """Create ``path`` (including missing parents) if needed and return it."""
    os.makedirs(path, exist_ok=True)
    return path


def _dump_json(payload, path):
    """Serialize ``payload`` to ``path`` as JSON with a 4-space indent."""
    with open(path, 'w') as f:
        json.dump(payload, f, indent=4)


def main(
    model_name:str,
    gamma_list:list,
    alpha_list:list,
    beam_nums_list:list,
    max_length:int,
    final_norm:bool,
    seed:int,
    dataset:list,
    results_dir:str,
    batch_size:int,
):
    """Run DFA-constrained beam search over ``dataset`` for a hyper-parameter grid.

    For every ``(gamma, alpha, num_beams)`` combination each dataset example
    is processed once, and four JSON files are written for that combination:
    generated candidates (directly in ``results_dir``), per-example wall-clock
    times (``times/``), error messages (``errors/``) and tracebacks
    (``tracebacks/``).

    Args:
        model_name: Model id passed to ``AutoTokenizer`` / ``AutoModelForCausalLM``.
        gamma_list: Values forwarded as ``gamma`` to ``DFALogitsProcessor``.
        alpha_list: Values forwarded as ``alpha`` to ``DFALogitsProcessor``.
        beam_nums_list: Beam counts to evaluate.
        max_length: Maximum number of new tokens to generate per example.
        final_norm: Forwarded as ``final_norm`` to ``generate_with_dfa``.
        seed: Only used to tag the output file names; seeding itself is the
            caller's responsibility.
        dataset: Sequence of examples, each providing ``"id"`` and ``"concepts"``.
        results_dir: Root directory for all outputs (created if missing).
        batch_size: Forwarded as ``batch_size`` to ``generate_with_dfa``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name,device_map="auto")
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda:0", torch_dtype=torch.float16)

    # Candidates are stored directly in the results root; the other outputs
    # each get their own sub-directory.
    candidates_dir = _ensure_dir(results_dir)
    times_dir = _ensure_dir(os.path.join(results_dir, "times"))
    errors_dir = _ensure_dir(os.path.join(results_dir, "errors"))
    tracebacks_dir = _ensure_dir(os.path.join(results_dir, "tracebacks"))

    for gamma in gamma_list:
        for alpha in alpha_list:
            for num_beams in beam_nums_list:
                print(f"\nalpha={alpha}, gamma={gamma}, beam={num_beams}")

                # Accumulators are reset for every hyper-parameter combination.
                candidates = []
                times_logits_processor = []
                errors = []
                tracebacks = []
                for example in tqdm(dataset):
                    # A failure on one example is recorded and must not abort
                    # the rest of the run.
                    try:
                        concepts = example["concepts"]
                        variants = generate_variants_commongen_ctrlg(concepts)
                        concepts_str = "\"{}\"".format(", ".join(concepts))
                        dfa_layer = ordered_commongen_dfa(variants, tokenizer)
                        prompt = complete_prompt_ordered_commongen(concepts_str)

                        # The timed span covers tokenization, processor
                        # construction and generation; building the DFA above
                        # is deliberately outside of it.
                        start_time = timer()
                        input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device="cuda:0")
                        inputs_len = input_ids.shape[-1]

                        dfa_logits_processor = lp.DFALogitsProcessor(
                            dfa_layer=dfa_layer,
                            device="cuda:0",
                            variants=variants,
                            concepts=concepts + ["_dot_", "eos", "_others_"],
                            tokenizer=tokenizer,
                            num_beams=num_beams,
                            # The processor is given the total length
                            # (prompt tokens + new tokens).
                            max_length=max_length + inputs_len,
                            alpha=alpha,
                            gamma=gamma,
                            eps_favor=1e-2,
                        )

                        results = dfa_beam_search.generate_with_dfa(
                            model=model,
                            tokenizer=tokenizer,
                            dfa_processor=dfa_logits_processor,
                            input_ids=input_ids,
                            prompt=prompt,
                            num_beams=num_beams,
                            max_new_tokens=max_length,
                            length_penalty=0,
                            half=True,
                            batch_size=batch_size,
                            final_norm=final_norm,
                        )
                        end_time = timer()

                        # Keep the first returned sequence and drop the prompt
                        # tokens before decoding.
                        # NOTE(review): presumably index 0 is the top-ranked
                        # beam - confirm against generate_with_dfa.
                        sequences = results["sequences"]
                        candidate = tokenizer.decode(sequences[0][inputs_len:], skip_special_tokens=True)

                        candidates.append({
                            "id": example["id"],
                            "concepts": concepts,
                            "sentence": candidate
                        })
                        times_logits_processor.append({
                            "id": example["id"],
                            "time": end_time - start_time
                        })

                        torch.cuda.empty_cache()

                    except Exception as e:
                        # Format the traceback once and reuse it for both the
                        # console output and the stored record.
                        tb_text = traceback.format_exc()
                        print(f"Exception: {str(e)}")
                        print(f"Traceback: {tb_text}")

                        errors.append({
                            "id": example["id"],
                            "error": str(e)
                        })
                        tracebacks.append({
                            "id": example["id"],
                            "traceback": tb_text
                        })

                # Persist inside the beam loop: the accumulators above belong
                # to this exact (gamma, alpha, num_beams) combination, so every
                # combination needs its own set of files.
                # The "LLAMA" prefix is fixed and does not depend on model_name.
                config_tag = f"LLAMA_ordered_commongen_finalNorm_{final_norm}_maxLength_{max_length}_beams_{num_beams}_alpha_{str(alpha).replace('.', '_')}_gamma_{str(gamma).replace('.', '_')}_seed_{seed}"

                _dump_json(candidates, os.path.join(candidates_dir, f'candidates_{config_tag}.json'))
                _dump_json(times_logits_processor, os.path.join(times_dir, f'times_{config_tag}.json'))
                _dump_json(errors, os.path.join(errors_dir, f'errors_{config_tag}.json'))
                _dump_json(tracebacks, os.path.join(tracebacks_dir, f'tracebacks_{config_tag}.json'))

    print("\nDone!")

def str_to_bool(v):
    """Convert a textual flag into a ``bool`` (usable as an argparse ``type``).

    The comparison is case-insensitive.

    Args:
        v: One of ``yes/true/t/y/1`` or ``no/false/f/n/0``.

    Returns:
        ``True`` for the affirmative spellings, ``False`` for the negative ones.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is none of the accepted spellings.
    """
    affirmative = ('yes', 'true', 't', 'y', '1')
    negative = ('no', 'false', 'f', 'n', '0')

    lowered = v.lower()
    if lowered in affirmative:
        return True
    if lowered in negative:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
    
if __name__ == "__main__":
    # Script entry point: parse the CLI, seed the RNGs, load the dataset and
    # run the hyper-parameter sweep.

    # Defaults for the command-line options below.
    max_length = 16
    model_name = "meta-llama/Llama-3.1-8B"
    seed = 42
    final_norm = True

    argparser = argparse.ArgumentParser()
    argparser.add_argument("--model_name", type=str, default=model_name, help="Model name")
    argparser.add_argument("--dataset_path", type=str, default="./data/Ordered CommonGen/ordered_commongen.json", help="Path to the dataset")
    # The three list options are given as comma-separated strings, e.g. "0.25,0.5".
    argparser.add_argument("--gamma_list", type=str, default="1", help="List of gamma values")
    argparser.add_argument("--alpha_list", type=str, default="0.25", help="List of alpha values")
    argparser.add_argument("--beam_nums_list", type=str, default="64", help="List of beam numbers")
    argparser.add_argument("--max_length", type=int, default=max_length, help="Max length of the generated text")
    # type=bool would treat every non-empty string (including "False") as
    # True, so the explicit parser is used instead.
    argparser.add_argument("--final_norm", type=str_to_bool, default=final_norm, help="Final normalization")
    argparser.add_argument("--batch_size", type=int, default=16, help="Batch size")
    argparser.add_argument("--seed", type=int, default=seed, help="Random seed")
    args = argparser.parse_args()

    model_name = args.model_name
    gamma_list = [float(g) for g in args.gamma_list.split(",")]
    alpha_list = [float(a) for a in args.alpha_list.split(",")]
    beam_nums_list = [int(b) for b in args.beam_nums_list.split(",")]
    max_length = args.max_length
    # Honour the values given on the command line rather than the defaults.
    final_norm = args.final_norm
    seed = args.seed

    # Seed every RNG the script has imported.
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)

    print(f"Running on {model_name}")
    print(f"Clean model name: {clean_model_name(model_name)}")
    # The output location is fixed and does not depend on model_name.
    results_dir = "./LLM/results/llama8b_ordered_commongen"
    os.makedirs(results_dir, exist_ok=True)
    with open(args.dataset_path, 'r') as f:
        dataset = json.load(f)

    main(
        model_name=model_name,
        gamma_list=gamma_list,
        alpha_list=alpha_list,
        beam_nums_list=beam_nums_list,
        max_length=max_length,
        final_norm=final_norm,
        seed=seed,
        dataset=dataset,
        results_dir=results_dir,
        batch_size=args.batch_size,
    )