import argparse
import random
import os
import json
import time

import tqdm

from utils.config import Config
from utils.logger import create_logger, display_exp_setting
from utils.loader import load_parser, load_llm, load_test_data, load_train_data
from utils.evaluation import nlp_evaluation
from prompt_compiler.compiler import PromptCompiler

# Command-line interface, declared as an ordered table of (flag, add_argument
# keyword arguments). Arguments are registered in exactly this order, so the
# generated --help listing keeps the same layout.
_ARGUMENT_SPECS = (
    # run setup
    ("--mode", dict(default="testset", help="cluster")),
    ("--log_dir", dict(default="exp")),
    ("--seed", dict(type=int, default=1)),

    # data selection
    ("--dataset", dict(type=str, required=True)),
    ("--split_name", dict(type=str)),
    ("--domain", dict(type=str)),
    ("--num_shot", dict(type=int)),
    ("--quickrun", dict(action="store_true")),
    ("--cut_len", dict(type=int)),

    # llm
    ("--engine", dict(type=str, required=True)),
    ("--temperature", dict(type=float, default=0.0)),
    ("--freq_penalty", dict(type=float, default=0.0)),
    ("--max_tokens", dict(type=int, default=100)),
    ("--llm_cache_dir", dict(type=str, default="llm_cache")),

    # generation
    ("--use_generic_grammar", dict(action="store_true")),
    ("--num_samples", dict(type=int, default=100)),
    ("--rule_temperature", dict(type=float, default=0.0)),

    # prompting
    ("--prompt_mode", dict(type=str, required=True)),
    ("--prompt_template", dict(type=str, required=True)),
    ("--add_rule_instruction_flag", dict(action="store_true")),
    ("--retrieve_fn", dict(type=str)),
    ("--batch_size", dict(type=int)),

    # standard prompting
    ("--use_linearized_tree", dict(action="store_true")),

    # grammar-based prompting
    ("--use_oracle_rule_flag", dict(action="store_true")),
    ("--constrain_rule_gen_flag", dict(action="store_true")),
    ("--constrain_prog_gen_flag", dict(action="store_true")),
    ("--lazy_constrain_flag", dict(action="store_true")),
    # specific to DSL grammar-based prompting
    ("--kg_rule_flag", dict(action="store_true")),
    ("--use_action_list_flag", dict(action="store_true")),
    ("--add_rule_list_flag", dict(action="store_true")),

    # iterative prompting
    ("--num_iterations", dict(type=int, default=1)),
    ("--iter_prompt_template", dict(type=str, default="iter")),
    ("--iter_engine", dict(type=str)),
    ("--iter_retrieve_fn", dict(type=str)),
)

parser = argparse.ArgumentParser()
for _flag, _options in _ARGUMENT_SPECS:
    parser.add_argument(_flag, **_options)

# Parsed at module level (as before), so `args` is available to the
# __main__ section below.
args = parser.parse_args()

if __name__ == "__main__":
    # Script entry point: for each hard-coded domain, build a PromptCompiler
    # from that domain's training examples, run it over the matching test
    # examples, and checkpoint all accumulated outputs to
    # <cfg.result_dir>/results.json after every example.
    start_time = time.time()
    random.seed(args.seed)
    cfg = Config(args)
    logger = create_logger(os.path.join(cfg.log_dir, 'log.txt'))
    display_exp_setting(logger, cfg)
    llm = load_llm(cfg.engine)

    sources, prompts, predictions, targets, grammars, times, infos = [], [], [], [], [], [], []

    # Built once: the dict holds references to the accumulator lists above,
    # which are only ever appended to, so it always reflects the latest state.
    json_results = {
        "test_prompts": prompts,
        "test_predictions": predictions,
        "test_grammars": grammars,
        "test_sources": sources,
        "test_target": targets,
        "test_time": times,
        "test_infos": infos
    }

    # Metadata fields copied from each test record into its "test_infos" entry.
    info_keys = ("bigAreas", "bigProb", "smallProb", "procedure_id", "cut_id")

    if cfg.mode == "testset":
        # Explicit check instead of `assert`, which is removed under `python -O`.
        if cfg.prompt_mode not in ("std", "rot"):
            raise ValueError(
                f"testset mode supports prompt_mode 'std' or 'rot', got {cfg.prompt_mode!r}"
            )
        test_examples = load_test_data(cfg.quickrun, cfg.cut_len)

        results_path = f"{cfg.result_dir}/results.json"
        tmp_results_path = results_path + ".tmp"

        for dataset in ["BioEng", "Ecology", "Genetics", "Medical"]:
            train_examples = load_train_data(dataset=dataset)
            # Test records carry their domain under the "bigAreas" key.
            test_data_subset = [a for a in test_examples if a["bigAreas"] == dataset]
            logger.info(f"{dataset}: loaded {len(train_examples)} indist examples, {len(test_data_subset)} test examples")

            global_parser, global_rules = load_parser(dataset, cfg.use_action_list_flag)
            prompt_compiler = PromptCompiler(dataset=dataset, prompt_mode=cfg.prompt_mode, llm=llm, retrieve_fn=cfg.retrieve_fn,
                                            batch_size=cfg.batch_size, train_examples=train_examples, prompt_template=cfg.prompt_template,
                                            global_rules=global_rules, add_rule_instruction_flag=cfg.add_rule_instruction_flag, use_oracle_rule_flag=cfg.use_oracle_rule_flag,
                                            use_linearized_tree_flag=cfg.use_linearized_tree, constrain_prog_gen_flag=cfg.constrain_prog_gen_flag,
                                            lazy_constrain_flag=cfg.lazy_constrain_flag, constrain_rule_gen_flag=cfg.constrain_rule_gen_flag, kg_rule_flag=cfg.kg_rule_flag,
                                            global_parser=global_parser, temperature=cfg.temperature, seed=cfg.seed, use_action_list_flag=cfg.use_action_list_flag,
                                            max_tokens=cfg.max_tokens, llm_cache_dir=cfg.llm_cache_dir, freq_penalty=cfg.freq_penalty, add_rule_list_flag=cfg.add_rule_list_flag)

            for record in tqdm.tqdm(test_data_subset, total=len(test_data_subset)):
                infos.append(json.dumps({key: record[key] for key in info_keys}))
                # The object under "cut" exposes .source and .target and is
                # what the compiler consumes.
                example = record["cut"]

                # Time only the compile call for this example.
                s_time = time.time()
                prompt, prediction, grammar = prompt_compiler.compile(example)
                e_time = time.time()

                prompts.append(prompt)
                predictions.append(prediction)
                grammars.append(grammar)
                sources.append(example.source)
                targets.append(example.target)
                times.append(str(e_time - s_time))

                # Wall-clock time since the script started.
                execution_time = time.time() - start_time
                logger.info(f"program time: {execution_time}s")

                # Checkpoint after every example so partial results survive an
                # interrupted run. Write to a temporary file and swap it in, so
                # a crash mid-dump cannot truncate the previous checkpoint.
                logger.info(f"dumping results to {results_path}")
                with open(tmp_results_path, "w", encoding="utf-8") as f:
                    json.dump(json_results, f, indent=2)
                os.replace(tmp_results_path, results_path)
    else:
        # Only the "testset" mode is implemented here; say so rather than
        # exiting silently with no output.
        logger.info(f"mode {cfg.mode!r} is not handled by this script; nothing was run")


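# Dataset sizes, apparently copied from the log output of an earlier run
# (one line per domain, in the format logged by the main loop):
# BioEng: loaded 623 indist examples, 101 test examples
# Ecology: loaded 285 indist examples, 4 test examples
# Genetics: loaded 768 indist examples, 161 test examples
# Medical: loaded 859 indist examples, 113 test examples
# Total: 379 test examples
# 239  (unlabeled count left by the original author; TODO: document what it refers to)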