import json
import os
import threading
from queue import Queue
from typing import Iterator

import pandas as pd
import yaml
from tqdm import tqdm

from LLMClient import init_client
from Agents import EvaluationAgent

"""
Ensure directory exists
"""
def ensure_dir(dir:str):
    if not os.path.exists(dir):
        os.makedirs(dir)

"""
Load configs from config file
"""
def init_configs_lookup(config_file_path:str):
    configs_lookup = {}

    config = yaml.safe_load(open(config_file_path, 'r'))
    for key in config.keys():
        # ignore API_KEY
        if key in ["API_KEY"]:
            continue
        file_path = config[key]
        configs_lookup[key] = file_path
    # ensure directory exists
    # ensure_dir(configs_lookup['SAVE_PATH'])
    return configs_lookup

"""
Load action space
"""
class APISampler:
    def __init__(self, file_path:str):
        self.action_space = self.load_actions(file_path)
    def load_actions(self, file_path):
        ## load action from .jsonl files
        # action_space = pd.read_json(file_path, lines=True)
        # return action_space

        ## load action from aggregation .json file
        action_space = json.load(open(file_path, "r"))
        self.raw_action_space = action_space
        action_space_list = []
        for intent in action_space:
            for action in action_space[intent]:
                for tool in action_space[intent][action]:
                    action_item = {
                        "Intent": intent,
                        "Action": action,
                        "Tool": tool,
                        "Tool_Desc": action_space[intent][action][tool]["description"],
                        "Tool_Params": action_space[intent][action][tool]["parameters"],
                    }
                    action_space_list.append(action_item)
        df_action_space = pd.DataFrame(action_space_list)
        return df_action_space
    def get_action_num(self) -> int:
        return len(self.action_space)
    def sample_action_list(self, N:int) -> list:
        sampled_actions = self.action_space.sample(N)
        sampled_action_list = []
        for index, row in sampled_actions.iterrows():
            action_dict = row.to_dict()
            item = {
                "Intent": action_dict["Intent"],
                "Action": action_dict["Action"],
                "Tool": action_dict["Tool"],
                "Tool_Desc": action_dict["Tool_Desc"],
            }
            sampled_action_list.append(item)
        return sampled_action_list
    def get_raw_actions(self) -> str:
        return self.raw_action_space
    
    def get_idx_from_key_value(self, key, value) -> int:
        try:
            idx = self.action_space[self.action_space[key] == value].index.item()
            return idx
        except:
            return -1

"""
Load Dataset
"""
class DataLoader:
    def __init__(self, file_path:str):
        self.data = pd.read_json(file_path, lines=True)

    def get_data_num(self) -> int:
        return len(self.data)

    def get_data_iterator(self) -> list:
        for index, row in self.data.iterrows():
            yield index, row.to_dict()


def _write_answer_csv(path: str, rows: list) -> None:
    """Write ``[question_index, answer]`` pairs to *path* in the legacy set-literal CSV layout.

    The header is ``question_index, action_choice_answer``. Each answer is stringified, then
    ``[``/``]`` become ``{``/``}`` and all spaces are removed, so ``"[3, -1]"`` is written as
    ``{3,-1}``.
    """
    with open(path, "w", encoding="utf-8") as f:
        f.write("question_index, action_choice_answer\n")
        for index, answer in rows:
            formatted = str(answer).replace("[", "{").replace("]", "}").replace(" ", "")
            f.write(f"{index}, " + formatted + "\n")


def main(config_file: str = "/mnt/workspace/workgroup/project-code/configs/evaluate_config.yaml"):
    """Legacy single-threaded evaluation loop.

    For every dataset row, sends the row's system prompt plus a user prompt containing the
    product info, the buyer question and the full raw action space to ``client.chat``. It parses
    ``action_list`` out of the JSON reply and maps predicted and ground-truth tool names to row
    indices of the flattened action table (``APISampler``; -1 for unknown or ambiguous names).
    Both index lists are written with ``_write_answer_csv`` to ``RESULT_SAVE_PATH`` and
    ``GROUND_TRUTH_PATH``.

    Args:
        config_file: path of the YAML config (also passed to ``init_client``).

    Note:
        There is no retry here: a reply that is not JSON with an ``action_list`` key aborts the run.
    """
    configs_lookup = init_configs_lookup(config_file)
    client = init_client(config_file)
    dataloader = DataLoader(configs_lookup["DATASET_PATH"])
    api_sampler = APISampler(configs_lookup["ACTION_SPACE_PATH"])

    # The prompt embeds the nested action-space dict as-is.
    action_space = api_sampler.get_raw_actions()

    results = []
    ground_truth = []
    for index, data in tqdm(dataloader.get_data_iterator(), desc="Evaluating...", total=dataloader.get_data_num()):
        system_prompt = data["candidate_system_prompt"]
        user_prompt = f"""
            ## Product Info:
            {data['product_info']}
            """ + "\n" + \
            f"""
            ## Buyer's Question:
            {data['buyer_question']}
            """ + "\n" + \
            f"""
            ## action_choice_space:
            {action_space}
            """ + "\n" + \
            """
            For example, if you want to choose below action, please respond "API_CheckAvailability"
            "Inquire_Product_Availability": {
            "Check_Availability": {
            "API_CheckAvailability": {
            "description": "Check whether the item is still available for purchase",
            "parameters": {
            "available": { "type": "boolean" }
            }
            },

            ## Output Format:
            please output in json format, including following keys:
            - action_list: list, including action name
            """
        message = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        response = client.chat(message)
        predicted_tools = json.loads(response["choices"][0]["message"]["content"])["action_list"]
        # Predicted entries are bare tool names; ground-truth entries are dicts with a "Tool" field.
        results.append([index, str(convert_action_to_index(api_sampler, predicted_tools))])
        ground_truth.append(
            [index, str(convert_action_to_index(api_sampler, data["ground_truth_action"], key="Tool"))]
        )

    _write_answer_csv(configs_lookup["RESULT_SAVE_PATH"], results)
    _write_answer_csv(configs_lookup["GROUND_TRUTH_PATH"], ground_truth)
    

def make_user_prompt(data:dict, mode:str, round_idx:int=None, prev_intents:list=None):
    """Build the user message for one evaluation query.

    Args:
        data: one dataset record. Reads ``product_info`` and ``choice_space``, plus either
            ``buyer_question`` (single-turn) or ``multi_turn_dialogue`` (any other mode).
        mode: ``"singleturn"`` selects the single-question layout; any other value is
            treated as multi-turn.
        round_idx: multi-turn only. The ``round`` value of the dialogue entry to render.
            The first matching entry is used.
        prev_intents: optional earlier intent choices. Rendered only when non-empty.

    Returns:
        The prompt: markdown-style sections separated by blank lines, ending with the
        output-format instruction. The "Previous Context" section (the entry's
        ``prev_context`` lines joined by newlines) appears only when non-empty.

    Raises:
        ValueError: in multi-turn mode, when no dialogue entry has ``round == round_idx``.
    """
    if mode == "singleturn":
        question = data['buyer_question']
        prev_context_str = ""
    else:
        round_data = next((d for d in data["multi_turn_dialogue"] if d["round"] == round_idx), None)
        if round_data is None:
            raise ValueError(f"no entry with round == {round_idx!r} in data['multi_turn_dialogue']")
        question = round_data["buyer_question"]
        prev_context_str = "\n".join(round_data.get("prev_context") or [])

    sections = [f"## Product Info:\n{data['product_info']}"]
    if prev_context_str:
        sections.append(f"## Previous Context:\n{prev_context_str}")
    sections.append(f"## Buyer's Question:\n{question}")
    sections.append(f"## action_choice_space:\n{data['choice_space']}")
    if prev_intents:
        sections.append(f"## Previous Intent Choices:\n{prev_intents}")
    sections.append("Please output in JSON with key 'action_list', listing tool name(s).")
    return "\n\n".join(sections)


def convert_action_to_index(api_sampler:APISampler, action_list:list, key=None):
    """Translate each action into its row index in ``api_sampler``'s flattened tool table.

    With ``key=None`` every entry of *action_list* is itself a tool name. Otherwise the
    tool name is read as ``entry[key]``. Each name is resolved with
    ``api_sampler.get_idx_from_key_value("Tool", name)``, so misses map to whatever that
    method returns for them.
    """
    def _tool_name_of(entry):
        return entry if key is None else entry[key]

    return [
        api_sampler.get_idx_from_key_value("Tool", _tool_name_of(entry))
        for entry in action_list
    ]

def get_idx_from_choice_space(choice_space:list, action_list:list, key=None):
    """Map actions to their positions in *choice_space*.

    Args:
        choice_space: list of mappings, each with a ``"Tool"`` name.
        action_list: actions to look up. These are tool names when *key* is None;
            otherwise they are mappings whose ``action[key]`` is the tool name.
        key: optional field from which to extract each action's tool name.

    Returns:
        A list parallel to *action_list*. Each element is the position of the FIRST
        choice-space entry with that tool name. It is -1 when the name is absent, or when
        it cannot be extracted or looked up (e.g. a malformed entry from the model).
    """
    # First-occurrence position per tool name (same result as list.index, but O(1) lookups).
    first_position = {}
    for position, choice in enumerate(choice_space):
        first_position.setdefault(choice["Tool"], position)

    idx_list = []
    for action in action_list:
        try:
            tool_name = action if key is None else action[key]
            idx_list.append(first_position.get(tool_name, -1))
        except (KeyError, IndexError, TypeError):
            # Missing field, non-subscriptable entry, or an unhashable name.
            idx_list.append(-1)
    return idx_list

def _run_agent_with_retries(client, system_prompt, agent_inputs, max_retries):
    """Query a fresh EvaluationAgent up to *max_retries* times; return ``(action_list, agent)``.

    Each attempt builds a new agent. On success the agent's ``action_list`` is returned. If
    every attempt fails, ``action_list`` is ``[]`` (so downstream indexing still works) and
    ``agent`` is the last agent successfully constructed, or ``None`` if none was.
    """
    agent = None
    for attempt in range(1, max_retries + 1):
        try:
            agent = EvaluationAgent(client, system_prompt)
            agent_output = agent.run(agent_inputs)
            return agent_output["action_list"], agent
        except Exception as err:  # any failure (API error, bad JSON, missing key) triggers a retry
            print(f"[Evaluation] Retry {attempt} / {max_retries} times: {err!r}")
    print(f"[Evaluation] Giving up after {max_retries} attempts; recording an empty action_list")
    return [], agent


def worker(input_queue, output_queue, client, pbar, max_retries=3):
    """Thread body: evaluate work items from *input_queue* until a ``None`` sentinel arrives.

    For each item (a dict produced by ``main_multi_thread``), run the agent with retries and
    put a result dict on *output_queue*. The result's ``action_list`` is ``[]`` if all retries
    failed. Every item is marked ``task_done`` and counted on *pbar*, even when processing
    raises, so ``input_queue.join()`` cannot hang. Token usage of the last agent per item is
    summed and printed when the worker exits.

    Args:
        max_retries: attempts per item before giving up (default 3).
    """
    token_count = {
        "completion_tokens": 0,
        "prompt_tokens": 0,
        "total_tokens": 0,
    }
    while True:
        item = input_queue.get()
        if item is None:
            input_queue.task_done()
            break
        try:
            agent_inputs = item["inputs"]
            action_list, agent = _run_agent_with_retries(
                client, item["system_prompt"], agent_inputs, max_retries
            )
            output_queue.put({
                "index": item["index"],
                "round": item.get("round"),
                "mode": item.get("mode"),
                "system_prompt": item["system_prompt"],
                "user_prompt": agent_inputs,
                "action_list": action_list,
                "ground_truth": item["ground_truth"],
                "choice_space": item["choice_space"],
            })
            if agent is not None:
                for key, value in agent.get_token_count().items():
                    token_count[key] = token_count.get(key, 0) + value
        except Exception as err:
            # Never let one bad item kill the thread. A dead worker would leave its sentinel
            # unconsumed, and input_queue.join() in the caller would block forever.
            print(f"[Evaluation] Error while processing a work item: {err!r}")
        finally:
            pbar.update(1)
            input_queue.task_done()
    print(token_count)

def _build_work_items(dataloader: DataLoader, mode: str) -> list:
    """Expand the dataset into worker items: one per row (single-turn) or one per dialogue round."""
    work_items = []
    for index, data in tqdm(dataloader.get_data_iterator(), desc="Loading Data", total=dataloader.get_data_num()):
        if mode == "singleturn":
            work_items.append({
                "index": index,
                "round": 1,
                "system_prompt": data["candidate_system_prompt"],
                "inputs": make_user_prompt(data, mode),
                "ground_truth": data["ground_truth_action"],
                "choice_space": data["choice_space"],
                "mode": mode,
            })
            continue
        # NOTE(review): history_predictions is never filled. All rounds are queued up front and
        # answered concurrently, so earlier-round predictions do not exist yet. As a result,
        # make_user_prompt never renders "Previous Intent Choices". Confirm that this is intended.
        history_predictions = []
        for round_data in data["multi_turn_dialogue"]:
            r_idx = round_data["round"]
            work_items.append({
                "index": index,
                "round": r_idx,
                "system_prompt": data["candidate_system_prompt"],
                "inputs": make_user_prompt(data, mode, round_idx=r_idx, prev_intents=history_predictions),
                "ground_truth": [gt for gt in data["ground_truth_action"] if gt["Round"] == r_idx],
                "choice_space": data["choice_space"],
                "mode": mode,
                "history_predictions": history_predictions,
            })
    return work_items


def _to_answer_row(output_item: dict) -> dict:
    """Turn one worker output into a result row with choice-space indices for prediction and truth."""
    choice_space = output_item["choice_space"]
    action_list = output_item["action_list"]
    ground_truth = output_item["ground_truth"]
    return {
        "index": output_item["index"],
        "round": output_item.get("round", 1),
        "system_prompt": output_item["system_prompt"],
        "user_prompt": output_item["user_prompt"],
        "action": action_list,
        "action_idx": get_idx_from_choice_space(choice_space, action_list),
        "ground_truth": ground_truth,
        # Ground-truth entries (single- and multi-turn alike) are dicts carrying a "Tool" name.
        "ground_truth_idx": get_idx_from_choice_space(choice_space, ground_truth, key="Tool"),
        "choice_space": choice_space,
    }


def main_multi_thread(config_file: str = "/mnt/workspace/workgroup/Benchmark/project-code/configs/evaluate_config.yaml"):
    """Threaded evaluation pipeline: query the agent for every item and write one results CSV.

    Reads ``MODE`` (default ``"singleturn"``), ``DATASET_PATH``, ``THREADING_NUM`` and
    ``RESULT_SAVE_PATH`` from the config; *config_file* is also passed to ``init_client``.
    Rows are written sorted by dataset ``index`` and then ``round``.

    Raises:
        ValueError: ``THREADING_NUM`` is less than 1 (no worker would drain the queue).
    """
    configs_lookup = init_configs_lookup(config_file)
    client = init_client(config_file)
    mode = configs_lookup.get("MODE", "singleturn").lower()
    dataloader = DataLoader(configs_lookup["DATASET_PATH"])
    n_threads = int(configs_lookup["THREADING_NUM"])
    if n_threads < 1:
        raise ValueError(f"THREADING_NUM must be >= 1, got {n_threads}")

    work_items = _build_work_items(dataloader, mode)
    input_queue = Queue()
    output_queue = Queue()
    for work_item in work_items:
        input_queue.put(work_item)

    pbar = tqdm(total=len(work_items), desc="Evaluating...", unit="item")
    threads = [
        threading.Thread(target=worker, args=(input_queue, output_queue, client, pbar), daemon=True)
        for _ in range(n_threads)
    ]
    for t in threads:
        t.start()
    # One sentinel per worker: each worker exits after consuming exactly one None.
    for _ in threads:
        input_queue.put(None)
    input_queue.join()
    # Every sentinel has been processed, so all workers are exiting. Join them so their
    # final token-count prints finish before we continue.
    for t in threads:
        t.join()
    pbar.close()

    # No producers remain, so draining the output queue cannot race with a worker.
    final_outputs = []
    while not output_queue.empty():
        final_outputs.append(output_queue.get_nowait())

    answer_columns = [
        "index", "round", "system_prompt", "user_prompt", "action",
        "action_idx", "ground_truth", "ground_truth_idx", "choice_space",
    ]
    # Explicit columns keep an empty result writable (a header-only CSV) instead of crashing.
    answer_df = pd.DataFrame([_to_answer_row(o) for o in final_outputs], columns=answer_columns)
    # Workers finish out of order; sort by round as well so each dialogue's rounds are ordered.
    answer_df = answer_df.sort_values(["index", "round"])
    answer_df.to_csv(configs_lookup["RESULT_SAVE_PATH"], index=False)

if __name__ == "__main__":
    # main()
    main_multi_thread()