import os
import warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

warnings.filterwarnings("ignore")
from timeit import default_timer as timer
import random
import torch
import traceback
import numpy as np
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils.commongen import complete_prompt, commongen_dfa
import utils.logits_processors as lp
from argparse import ArgumentParser
import utils.dfa_beam_search as dfa_beam_search
import json, traceback
from utils.utils import generate_variants_commongen_ctrlg

# Output categories; each gets its own sub-directory under the results dir and
# its own JSON file per hyper-parameter configuration.
_OUTPUT_KINDS = ("references", "candidates", "times", "errors", "tracebacks")


def _group_spans(dataset_split):
    """Return ``(start, end)`` row-index spans of consecutive rows sharing a concept set.

    Concept lists are compared as sets, so their order is ignored. Grouping only
    joins *adjacent* rows: a concept set that reappears later, after a different
    one, starts a new group.
    """
    spans = []
    n_rows = len(dataset_split)
    i = 0
    while i < n_rows:
        current_concepts = set(dataset_split[i]["concepts"])
        start = i
        while i < n_rows and set(dataset_split[i]["concepts"]) == current_concepts:
            i += 1
        spans.append((start, i))
    return spans


def _generate_candidate(model, tokenizer, concepts, *, device, num_beams,
                        max_length, alpha, gamma, final_norm):
    """Run DFA-constrained beam search for one concept set.

    Returns:
        ``(sentence, seconds)``. ``sentence`` is the decoded text of the first
        returned sequence, with the first ``inputs_len`` (prompt-length) token
        positions skipped. ``seconds`` is the wall-clock time from prompt
        encoding through the end of ``generate_with_dfa``; DFA construction is
        excluded.
    """
    variants = generate_variants_commongen_ctrlg(concepts)
    concepts_str = "\"{}\"".format(", ".join(concepts))
    dfa_layer = commongen_dfa(variants, tokenizer)
    prompt = complete_prompt(concepts_str)

    start_time = timer()
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device=device)
    inputs_len = input_ids.shape[-1]

    dfa_logits_processor = lp.DFALogitsProcessor(
        dfa_layer=dfa_layer,
        device=device,
        variants=variants,
        concepts=concepts + ["_dot_", "eos", "_others_"],
        tokenizer=tokenizer,
        num_beams=num_beams,
        # The processor's length budget includes the prompt tokens.
        max_length=max_length + inputs_len,
        alpha=alpha,
        gamma=gamma,
        eps_favor=1e-2,
    )

    results = dfa_beam_search.generate_with_dfa(
        model=model,
        tokenizer=tokenizer,
        dfa_processor=dfa_logits_processor,
        input_ids=input_ids,
        prompt=prompt,
        num_beams=num_beams,
        max_new_tokens=max_length,
        length_penalty=0,
        half=True,
        batch_size=16,
        final_norm=final_norm,
    )
    end_time = timer()

    sequences = results["sequences"]
    sentence = tokenizer.decode(sequences[0][inputs_len:], skip_special_tokens=True)
    return sentence, end_time - start_time


def main(
    dataset,
    model_name: str,
    max_length: int,
    beam_nums_list: list,
    alpha_list: list,
    gamma_list: list,
    data_partition: str,
    seed: int,
    final_norm: bool,
    device: str,
    supervised: bool = False,
):
    """Run DFA-constrained CommonGen generation over an (alpha, gamma, beams) grid.

    For every combination of ``alpha_list`` x ``gamma_list`` x ``beam_nums_list``,
    one sentence is generated for each group of consecutive rows in
    ``dataset[data_partition]`` that share a concept set. References,
    candidates, timings, errors and tracebacks are then written as JSON under
    ``./LLM/results/gpt2_{supervised|unsupervised}_commongen/<kind>/``.

    Args:
        dataset: indexed by split name (``dataset[data_partition]``). Presumably
            the object returned by ``datasets.load_dataset``; rows are expected
            to have ``"concepts"`` and ``"target"`` fields.
        model_name: HF model id, loaded in half precision on ``device``.
        max_length: maximum number of new tokens to generate.
        beam_nums_list: beam widths to sweep.
        alpha_list: ``alpha`` values forwarded to ``DFALogitsProcessor``.
        gamma_list: ``gamma`` values forwarded to ``DFALogitsProcessor``.
        data_partition: split name, e.g. ``"validation"``.
        seed: used only to tag the output file names.
        final_norm: forwarded to ``generate_with_dfa``.
        device: torch device string for the model and inputs.
        supervised: selects the output directory / file-name prefix.

    A failure on one group is recorded in the errors/tracebacks output and the
    run continues with the next group.
    """
    supervised_str = "supervised" if supervised else "unsupervised"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device).half()

    results_dir = f"./LLM/results/gpt2_{supervised_str}_commongen"
    for kind in _OUTPUT_KINDS:
        os.makedirs(os.path.join(results_dir, kind), exist_ok=True)

    dataset_split = dataset[data_partition]

    # Group boundaries and reference rows do not depend on the hyper-parameters,
    # so scan the split once and reuse the result for every configuration.
    groups = []
    for start, end in _group_spans(dataset_split):
        group_references = []
        for j in range(start, end):
            row = dataset_split[j]
            group_references.append({
                "concept_set_idx": j,
                "concepts": row["concepts"],
                "target": row["target"],
            })
        groups.append((start, group_references))
    references = [ref for _, group_refs in groups for ref in group_refs]

    for alpha in alpha_list:
        for gamma in gamma_list:
            for num_beams in beam_nums_list:
                print(f"\nalpha={alpha}, gamma={gamma}, beam={num_beams}")

                candidates = []
                times_logits_processor = []
                errors = []
                tracebacks = []
                for start_idx, group_references in tqdm(groups, desc="Progress"):
                    try:
                        concepts = dataset_split[start_idx]["concepts"]
                        sentence, elapsed = _generate_candidate(
                            model,
                            tokenizer,
                            concepts,
                            device=device,
                            num_beams=num_beams,
                            max_length=max_length,
                            alpha=alpha,
                            gamma=gamma,
                            final_norm=final_norm,
                        )
                    except Exception as e:
                        print(f"Group {start_idx}: {str(e)}")
                        errors.append({
                            "concept_set_idx": start_idx,
                            "error": str(e)
                        })
                        tracebacks.append({
                            "concept_set_idx": start_idx,
                            "traceback": traceback.format_exc()
                        })
                    else:
                        candidates.append({
                            "concept_set_idx": start_idx,
                            "group_size": len(group_references),
                            "concepts": concepts,
                            "sentence": sentence
                        })
                        times_logits_processor.append({
                            "concept_set_idx": start_idx,
                            "time": elapsed
                        })
                    finally:
                        # Release cached CUDA memory after every group, including
                        # failed ones (e.g. out-of-memory), so a single failure is
                        # less likely to cascade into the following groups.
                        torch.cuda.empty_cache()

                config_tag = f"{data_partition}_finalNorm_{final_norm}_maxLength_{max_length}_beams_{num_beams}_alpha_{str(alpha).replace('.', '_')}_gamma_{str(gamma).replace('.', '_')}_seed_{seed}"

                outputs = {
                    "references": references,
                    "candidates": candidates,
                    "times": times_logits_processor,
                    "errors": errors,
                    "tracebacks": tracebacks,
                }
                for kind, payload in outputs.items():
                    path = os.path.join(results_dir, kind, f"{kind}_{supervised_str}_{config_tag}.json")
                    with open(path, 'w') as f:
                        json.dump(payload, f, indent=4)

    print("\nDone!")

def str_to_bool(v):
    """Convert a command-line flag value to a ``bool``.

    A ``bool`` is returned as-is. A string is matched case-insensitively against
    the accepted spellings ``yes/true/t/y/1`` and ``no/false/f/n/0``.

    Raises:
        ValueError: if the string matches neither list.
    """
    if not isinstance(v, bool):
        lowered = v.lower()
        if lowered in {"yes", "true", "t", "y", "1"}:
            v = True
        elif lowered in {"no", "false", "f", "n", "0"}:
            v = False
        else:
            raise ValueError(f"Invalid boolean string: {v}")
    return v

def str_to_float_list(v):
    """Parse a comma-separated string such as ``"0.5,1.0"`` into a list of floats.

    A ``list`` argument is returned unchanged.

    Raises:
        ValueError: if any comma-separated item is not a valid float. The
            underlying parse error is chained as ``__cause__``.
    """
    if isinstance(v, list):
        return v
    try:
        return [float(x) for x in v.split(",")]
    except ValueError as err:
        raise ValueError(f"Invalid float list string: {v}") from err
def str_to_int_list(v):
    """Parse a comma-separated string such as ``"16,64"`` into a list of ints.

    A ``list`` argument is returned unchanged.

    Raises:
        ValueError: if any comma-separated item is not a valid int. The
            underlying parse error is chained as ``__cause__``.
    """
    if isinstance(v, list):
        return v
    try:
        return [int(x) for x in v.split(",")]
    except ValueError as err:
        raise ValueError(f"Invalid int list string: {v}") from err

if __name__ == "__main__":
    # Default experiment settings; each can be overridden on the command line.
    max_length = 32
    beam_nums_list = [64]
    alpha_list = [0.5]
    gamma_list = [1.]
    data_partition = "validation"
    seed = 42
    final_norm = True
    supervised = False

    parser = ArgumentParser()
    parser.add_argument('--data_partition', type=str, default=data_partition)
    parser.add_argument("--seed", type=int, default=seed)
    parser.add_argument("--max_length", type=int, default=max_length)
    # Comma-separated ints, e.g. "--beam_nums_list 16,64". With type=str a CLI
    # value stayed a plain string and main() iterated over its characters.
    parser.add_argument("--beam_nums_list", type=str_to_int_list, default=beam_nums_list)
    parser.add_argument("--alpha_list", type=str_to_float_list, default=alpha_list)
    parser.add_argument("--gamma_list", type=str_to_float_list, default=gamma_list)
    parser.add_argument("--final_norm", type=str_to_bool, default=final_norm)
    parser.add_argument("--supervised", type=str_to_bool, default=supervised)
    args = parser.parse_args()

    # Load only after argument parsing so `--help` or invalid arguments exit
    # before the dataset is fetched.
    dataset = load_dataset("allenai/common_gen")

    seed = args.seed
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)

    model_name = "ctrlg/gpt2-large_common-gen" if args.supervised else "openai-community/gpt2-large"
    main(
        dataset=dataset,
        model_name=model_name,
        max_length=args.max_length,
        beam_nums_list=args.beam_nums_list,
        alpha_list=args.alpha_list,
        gamma_list=args.gamma_list,
        data_partition=args.data_partition,
        seed=args.seed,
        final_norm=args.final_norm,
        device="cuda:0",
        supervised=args.supervised,
    )