import argparse
import os
import datetime

from model.decoder import create_generate_decoder
from method.baseline_engine import create_inference_engine
from utils import *


# Instructions longer than this many characters are truncated before generation.
_MAX_INSTRUCTION_CHARS = 6000


def _ordinal(n):
    """Return ``n`` with its English ordinal suffix, e.g. 1 -> "1st", 12 -> "12th"."""
    if n % 100 in (11, 12, 13):
        suffix = "th"
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
    return "{}{}".format(n, suffix)


def main():
    """Generate outputs for the evaluation subset of the dataset and save them.

    Parses CLI arguments (``parse_arguments``), builds the data loader, decoder and
    inference engine, then, for every dataloader item whose id appears under
    ``"branch_good"`` or ``"branch_bad"`` in the JSON file at
    ``args.eval_data_path``, runs ``infer_engine.generate`` and passes a result
    record to ``save_results`` with ``<save_path>/generation.json``. Items whose id
    is in neither list are skipped. When ``args.limit_dataset_size`` is non-zero,
    at most that many dataloader items are consumed (skipped items count too).
    """
    import json  # explicit import instead of relying on `from utils import *`

    ## Set args ...
    args = parse_arguments()
    os.makedirs(args.save_path, exist_ok=True)
    file_path = os.path.join(args.save_path, "generation.json")
    print('*****************************')
    print(args)
    print('*****************************')
    fix_seed(args.random_seed)

    ## Set loaders ...
    print("setup data loader and decoder ...")
    dataloader = setup_data_loader(args)
    decoder = create_generate_decoder(args)
    # Return values are unused; presumably these populate fields on `args`
    # that the inference engine reads -- TODO confirm in utils.
    get_system_prompt(args)
    get_demo(args)
    infer_engine = create_inference_engine(args)
    print_now()

    ### for eval data
    with open(args.eval_data_path, 'r', encoding='utf-8') as f:
        eval_data_id = json.load(f)
    # Build the id sets once so each per-item membership test is O(1).
    good_ids = set(eval_data_id["branch_good"])
    bad_ids = set(eval_data_id["branch_bad"])

    ## Inference ...
    limit = args.limit_dataset_size
    for i, data in enumerate(dataloader):
        # Enforce the limit before doing any work: items skipped below via
        # `continue` must also count, otherwise an item past the limit could
        # still be generated.
        if limit != 0 and i >= limit:
            break

        print('*************************')
        print("{} data".format(_ordinal(i + 1)))

        x, ind, parse_groups = data
        # Batches have size 1 (see --minibatch_size), so take the single element.
        x = x[0].strip()[:_MAX_INSTRUCTION_CHARS]
        ind = ind[0].strip()
        main_id = int(ind)

        ### for eval data
        if main_id in good_ids:
            eval_label = "good"
        elif main_id in bad_ids:
            eval_label = "bad"
        else:
            continue  # not part of the evaluation subset

        generation = infer_engine.generate(decoder, x, parse_groups)

        print("main_id: {}".format(ind))
        print("Instructions: {}".format(x))
        print("generation : {}".format(generation))
        print('*************************')

        record = {'main_id': main_id, 'model': args.model, 'method': args.method, 'instruction': x, 'eval_label': eval_label, 'generated': generation}
        save_results(record, file_path)

def _str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (including ``"False"``), so it cannot express a False flag.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):  # non-string defaults are passed through as-is
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1", "on"):
        return True
    if lowered in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


def parse_arguments():
    """Parse command-line options for the generation script.

    After parsing, ``args.save_path`` is rewritten to a run-specific directory:
    ``<save_path>/<dataset>/<model>/[stat/]<method>/<MMDD_HH:MM>``, where the
    ``stat`` component is present only when ``--stat`` is true.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser(description="HoT")

    ## model args
    parser.add_argument("--call_mode", type=str, default="huggingface", choices=["vllm", "huggingface", "gpt"], help="call mode for engine")
    # "llama3-8b" (the default) is accepted alongside the legacy "llama3-8B"
    # spelling so the default can also be passed explicitly.
    parser.add_argument(
        "--model", type=str, default="llama3-8b", choices=["gpt3.5", "llama3-8b", "llama3-8B"], help="model used for decoding.")
    parser.add_argument("--query_type", type=str, default="lamma3-instruct", help="query type")
    parser.add_argument("--temperature", type=float, default=0.0, help="temperature for llm")
    parser.add_argument("--top_k", type=int, default=1, help="top_k for llm")
    parser.add_argument("--top_p", type=float, default=0.95, help="top_p for llm")
    parser.add_argument("--random_seed", type=int, default=1, help="random seed")
    parser.add_argument("--api_time_interval", type=float, default=1.0, help="")
    ## data args
    parser.add_argument(
        "--dataset", type=str, default="complexbench", choices=["complexbench", "parallel"], help="dataset used for experiment"
    )
    parser.add_argument(
        "--limit_dataset_size", type=int, default=1500, help="whether to limit test dataset size. if 0, the dataset size is unlimited and we use all the samples in the dataset for testing."
    )
    parser.add_argument("--minibatch_size", type=int, default=1, choices=[1], help="minibatch size should be 1 because GPT-3 API takes only 1 input for each request")
    parser.add_argument("--max_num_worker", type=int, default=3, help="maximum number of workers for dataloader")
    ## method args
    parser.add_argument(
        "--method", type=str, default="direct_prompt", choices=["direct_prompt", "zero_shot_cot", "plan_solve", "think_execute", "NL_plan"], help="method"
    )
    # Boolean options use _str2bool: with type=bool, "--stat False" would be True.
    parser.add_argument(
        "--stat", type=_str2bool, default=False, help="structure attention or not"
    )
    parser.add_argument(
        "--stat_mode", type=str, default="span_only", choices=["relative_only", "span_only", "relative+span"], help="stat mode"
    )
    parser.add_argument(
        "--dynamic_span", type=_str2bool, default=False, help="dynamic add new generated token for parallel or not"
    )
    parser.add_argument("--demo_flag", type=_str2bool, default=False, help="demo flag")
    parser.add_argument("--max_tokens", type=int, default=8192, help="max_tokens for llm")
    parser.add_argument(
        "--max_length", type=int, default=8192, help="maximum length of output tokens by model for reasoning extraction"
    )
    ## path args
    parser.add_argument("--save_path", type=str, default="llm_generations", help="log directory")
    parser.add_argument("--stat_head_config", type=str, default="/home/cuisen/users/lfr/FocusAtten/Structure_Emb/head_profiling/head_config/mask_heads_branch_sort_go_intersection_100.json", help="head config path")
    parser.add_argument("--eval_data_path", type=str, default="/home/cuisen/users/lfr/FocusAtten/Structure_Emb/llm_generations/complexbench/llama3-8b/direct_prompt/1212_16:32/eval_data_id.json", help="eval data path")

    args = parser.parse_args()

    # Build a per-run output directory; the "stat" component marks runs with
    # structure attention enabled.
    timestamp = datetime.datetime.now().strftime('%m%d_%H:%M')
    parts = [args.save_path, args.dataset, args.model]
    if args.stat:
        parts.append("stat")
    parts += [args.method, timestamp]
    args.save_path = "/".join(parts)
    return args

if __name__ == "__main__":
    # Script entry point: run main() only when executed directly, not on import.
    main()