import os
import yaml
import json
import time
import pandas as pd
from tqdm import tqdm

from LLMClient import init_client
from API_Manager import API_Pool
from Agents import Data_Provider, API_Extractor, API_Verifier, API_Maintainer

def ensure_dir(dir: str):
    """Create directory *dir* (including missing parents) if it does not already exist.

    ``exist_ok=True`` makes this atomic with respect to other creators: a
    separate existence check followed by ``makedirs`` could race and raise
    ``FileExistsError`` if the directory appeared in between.

    Raises:
        FileExistsError: if *dir* exists but is not a directory.
        FileNotFoundError: if *dir* is an empty string.
    """
    os.makedirs(dir, exist_ok=True)

class APILogger:
    """Collect run metadata plus rejected/errored API items and persist them to disk.

    ``save_logs`` (re)writes three files under ``base_dir``:
    ``summary.yaml`` (the summary dict followed by the notes as YAML comments),
    ``error_api_list.jsonl`` and ``reject_api_list.jsonl`` (one JSON value per line).
    A snapshot is saved automatically after every ``snap_N`` reject/error additions.
    """

    def __init__(self, base_dir: str, snap_N: int = 10):
        self.base_dir = base_dir
        ensure_dir(base_dir)
        self.summary = {}
        self.notes = "# Note"
        self.snap_N = snap_N
        self.reject_api_list = []
        self.error_api_list = []
        # Number of reject/error additions since the last automatic snapshot.
        self.snap_count = 0

    def set_summary(self, configs: dict):
        """Store a shallow copy of *configs* and append the init time to the notes."""
        self.summary = configs.copy()
        self.notes += "\n" + "# Init Time: " + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())

    def add_notes(self, info: str):
        """Append *info* as a new ``# ``-prefixed line to the notes."""
        self.notes += "\n# " + info

    def _tick_snapshot(self):
        """Count one addition; save all logs once ``snap_N`` additions have accumulated."""
        self.snap_count += 1
        if self.snap_count >= self.snap_N:
            self.save_logs()
            self.snap_count = 0

    def add_reject_item(self, reject_item):
        """Record a rejected item (must be JSON-serializable for ``save_logs``)."""
        self.reject_api_list.append(reject_item)
        self._tick_snapshot()

    def add_error_item(self, error_item):
        """Record an errored item (must be JSON-serializable for ``save_logs``)."""
        self.error_api_list.append(error_item)
        self._tick_snapshot()

    @staticmethod
    def _write_jsonl(path: str, items):
        """Overwrite *path* with one ``json.dumps`` line per item."""
        with open(path, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(item) + "\n" for item in items)

    def save_logs(self):
        """Rewrite the summary, error and reject files with the current state."""
        summary_path = os.path.join(self.base_dir, "summary.yaml")
        # Write the YAML body and the trailing notes through one handle, so the
        # notes cannot be appended before the YAML dump has been flushed/closed.
        with open(summary_path, "w", encoding="utf-8") as f:
            yaml.dump(self.summary, f)
            f.write(f"# {self.notes}\n")
        self._write_jsonl(os.path.join(self.base_dir, "error_api_list.jsonl"), self.error_api_list)
        self._write_jsonl(os.path.join(self.base_dir, "reject_api_list.jsonl"), self.reject_api_list)

def init_configs_lookup(config_file_path: str):
    """Load the YAML config and return all of its entries except ``API_KEY``.

    The returned dict is passed to ``APILogger.set_summary`` by ``main`` and
    ends up in summary.yaml, so the API key is deliberately excluded.
    Also creates the ``SAVE_PATH`` directory.

    Raises:
        ValueError: if the file does not contain a YAML mapping.
        KeyError: if the config has no ``SAVE_PATH`` entry.
    """
    with open(config_file_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    if not isinstance(config, dict):
        raise ValueError(f"Config file {config_file_path!r} must contain a YAML mapping")
    # Ignore API_KEY so the secret is never copied into logged summaries.
    configs_lookup = {key: value for key, value in config.items() if key != "API_KEY"}
    # Ensure the output directory exists.
    ensure_dir(configs_lookup['SAVE_PATH'])
    return configs_lookup

def init_prompt_lookup(config_file_path: str):
    """Load the prompt texts whose file paths are listed in the YAML config.

    For each known prompt key, the config value is treated as a file path. The
    file's contents become the prompt text. If the path is blank or the file does
    not exist, the prompt is ``""``.

    Raises:
        KeyError: if one of the prompt keys is missing from the config.
    """
    prompt_keys = (
        "EXPERT_KNOWLEDGE_COLD_START",
        "EXPERT_KNOWLEDGE_GUIDE",
        "EXTRACTOR_PROMPT",
        "VERIFIER_PROMPT",
        "MAINTAINER_PROMPT",
        "EXPERT_KNOWLEDGE_MATERIALS",
    )
    with open(config_file_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    prompt_lookup = {}
    for key in prompt_keys:
        file_path = config[key]
        # A blank YAML value loads as None; os.path.exists(None) would raise TypeError.
        if file_path and os.path.exists(file_path):
            with open(file_path, 'r', encoding='utf-8') as prompt_file:
                prompt_lookup[key] = prompt_file.read()
        else:
            prompt_lookup[key] = ""
    return prompt_lookup

# TODO: cold start
def cold_start(api_pool, prompt_lookup, client, mode='cold_start'):
    """Placeholder for the cold-start phase (TODO); not implemented yet.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()

def main(
    config_file_path: str = '/mnt/workspace/workgroup/Benchmark/project-code/configs/api_factory_config.yaml',
    expert_guidance_enable: bool = True,
):
    """Run the extract -> verify pipeline over every dialogue, then save and report.

    For each item from ``Data_Provider``, the extractor proposes API info. The
    verifier then checks it against the current pool. Items whose verifier output
    has ``status == 1`` are added to the pool; all others are logged as rejects.
    A per-item exception is logged by index and processing continues. Ctrl+C stops
    the loop early, but the logs and the pool are still saved.

    Args:
        config_file_path: YAML config read by ``init_client``,
            ``init_prompt_lookup`` and ``init_configs_lookup``.
        expert_guidance_enable: append the expert-knowledge guide (with the domain
            materials substituted in) to the extractor system prompt.
    """
    # INIT
    client = init_client(config_file_path)
    prompt_lookup = init_prompt_lookup(config_file_path)
    configs_lookup = init_configs_lookup(config_file_path)
    data_provider = Data_Provider(configs_lookup['DATA_PATH'])
    api_logger = APILogger(configs_lookup['SAVE_PATH'])
    api_pool = API_Pool(os.path.join(configs_lookup['SAVE_PATH'], "pool.jsonl"))

    # Logger: Summary
    api_logger.set_summary(configs_lookup)

    # Agents (optionally with expert guidance in the extractor prompt)
    extractor_sys_prompt = prompt_lookup["EXTRACTOR_PROMPT"]
    if expert_guidance_enable:
        domain_knowledge = prompt_lookup["EXPERT_KNOWLEDGE_MATERIALS"]
        guidance = prompt_lookup["EXPERT_KNOWLEDGE_GUIDE"].replace("{domain_knowledge_materials}", domain_knowledge)
        extractor_sys_prompt = extractor_sys_prompt + "\n" + guidance + "\n"
    api_extractor = API_Extractor(client, extractor_sys_prompt)
    api_verifier = API_Verifier(client, prompt_lookup["VERIFIER_PROMPT"])

    # TODO: cold start, e.g. cold_start(api_pool, prompt_lookup, client, mode="cold_start")

    # RUN
    error_list = []
    reject_list = []
    processed = 0  # items fully attempted (success, reject or error)
    for i in tqdm(range(len(data_provider)), desc='Processing'):
        try:
            chat_data = data_provider.get_data_from_idx(i)
            # >>> API_Extractor
            extractor_input = f"Input dialogue:\n{chat_data}"
            new_api_info = api_extractor.run(extractor_input)
            print(f">> [API_Extractor]: \n{new_api_info}")

            # >>> API_Verifier: compare against the whole current pool.
            # NOTE: a narrower lookup via api_pool.get_api_info_from_keys(intent=..., action=...)
            # was considered as an alternative.
            current_api_info = api_pool.get_api_info()
            verifier_input = f"Current existing Intent-Action-Tool list:\n{current_api_info}\nTarget Input Intent-Action-Tool:\n{new_api_info}"
            output = api_verifier.run(verifier_input)
            print(f">> [API_Verifier]: {output}")
            if output['status'] == 1:
                api_pool.add_item(new_api_info)
            else:
                reject_list.append(new_api_info)
                api_logger.add_reject_item(new_api_info)
            # TODO: every 200 items, run pool maintenance (merge / analyze).
        except KeyboardInterrupt:
            # KeyboardInterrupt derives from BaseException, not Exception, so it
            # needs its own handler. Stop early but still save what we have.
            print(f"Interrupted at index {i}; saving partial results.")
            api_logger.add_notes(f"Interrupted at index {i}")
            break
        except Exception as e:
            error_list.append(i)
            api_logger.add_error_item(i)
            print(f"Error: {i} ({type(e).__name__}: {e})")
        processed += 1

    # END
    api_logger.add_notes(str(api_extractor.get_token_count()))
    api_logger.add_notes(str(api_verifier.get_token_count()))
    api_logger.add_notes(f"End Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
    api_logger.save_logs()
    api_pool.save_pool()

    # REPORT
    print("="*20, "Report", "="*20)
    print("### Tasks Done:")
    print(f"Completed: {processed}, Success: {processed - len(error_list)}, Error: {len(error_list)}, Rejected: {len(reject_list)}")
    print(f"Error list saved in {api_logger.base_dir}")
    print(f"Rejected list saved in {api_logger.base_dir}")
    print("")
    print("### Token Costs")
    print(f"Extractor: {api_extractor.get_token_count()}")
    print(f"Verifier: {api_verifier.get_token_count()}")
    print("")
    print("### API Pools")
    # TODO: print pool statistics, e.g. api_pool.get_pool_basic_info()
    print("="*40, "END", "="*40)

# TODO: process items with multiple threads to improve throughput.
def multi_thread_main():
    """Placeholder for a multi-threaded variant of ``main``; currently does nothing."""
    return None

# Script entry point: run the pipeline with main()'s default config path.
if __name__ == "__main__":
    main()