import os
import yaml
import json
import time
import random
import pandas as pd
from tqdm import tqdm
from queue import Queue
import threading

from LLMClient import init_client
from Agents import ProblemWeaver, SentenceAgent, MergeSentenceAgent
from API_Manager import API_Pool


def ensure_dir(dir: str):
    """Create directory ``dir`` (including missing parents) if it does not exist yet.

    ``os.makedirs(..., exist_ok=True)`` is used instead of a separate existence
    check, so two threads creating the same directory cannot race between the
    check and the creation. An empty path (e.g. ``os.path.dirname("out.jsonl")``)
    denotes the current directory and is a no-op.

    Raises:
        FileExistsError: if ``dir`` exists but is not a directory.
    """
    if not dir:
        return
    os.makedirs(dir, exist_ok=True)

"""
Load product info
"""
class ProductSampler:
    """Random access to, and sampling from, a JSON-lines product file.

    The whole file is loaded into a pandas DataFrame with one row per line.
    """

    def __init__(self, file_path: str):
        # Each line of ``file_path`` is parsed as one JSON record (one product).
        self.items = pd.read_json(file_path, lines=True)

    def get_item_num(self) -> int:
        """Return the number of loaded products."""
        return len(self.items)

    def get_sampling_iterator(self, N: int, random_seed: int = 42):
        """Return an iterator over ``N`` products sampled without replacement.

        The size check runs immediately when this method is called, not lazily
        on the first ``next()``.

        Args:
            N: number of products to draw.
            random_seed: ``random_state`` passed to ``DataFrame.sample`` so the
                draw is reproducible.

        Returns:
            An iterator of dicts, one per sampled row.

        Raises:
            ValueError: if ``N`` exceeds the number of loaded products.
        """
        item_num = self.get_item_num()
        if N > item_num:
            raise ValueError(f"N ({N}) is larger than the number of items ({item_num})")
        sampled_items = self.items.sample(N, random_state=random_seed).to_dict(orient="records")
        return iter(sampled_items)

    def get_item_from_idx(self, idx: int) -> dict:
        """Return the product at positional index ``idx`` as a dict.

        Raises:
            ValueError: if ``idx`` is outside ``[0, item_num)``. Negative indices
                are rejected instead of silently counting from the end.
        """
        item_num = self.get_item_num()
        if not 0 <= idx < item_num:
            raise ValueError(f"Index {idx} is out of range ({item_num})")
        return self.items.iloc[idx].to_dict()

"""
Load action space
"""
class APISampler:
    """Action (tool) space flattened into a DataFrame, plus a sampler for dialogue tasks."""

    # Fields copied from an action row into the dicts returned by sample_action_list.
    _CHOICE_FIELDS = ["Intent", "Action", "Tool", "Tool_Desc"]

    def __init__(self, file_path: str):
        self.action_space = self.load_actions(file_path)

    def load_actions(self, file_path):
        """Flatten an aggregated ``{intent: {action: {tool: spec}}}`` JSON file.

        Args:
            file_path: path to the UTF-8 JSON file; every ``spec`` must provide
                ``"description"`` and ``"parameters"``.

        Returns:
            DataFrame with one row per tool and columns ``Intent``, ``Action``,
            ``Tool``, ``Tool_Desc`` and ``Tool_Params``. The columns are present
            even when the file defines no tools.
        """
        with open(file_path, "r", encoding="utf-8") as f:
            action_space = json.load(f)
        rows = []
        for intent, actions in action_space.items():
            for action, tools in actions.items():
                for tool, spec in tools.items():
                    rows.append(
                        {
                            "Intent": intent,
                            "Action": action,
                            "Tool": tool,
                            "Tool_Desc": spec["description"],
                            "Tool_Params": spec["parameters"],
                        }
                    )
        return pd.DataFrame(
            rows, columns=["Intent", "Action", "Tool", "Tool_Desc", "Tool_Params"]
        )

    def get_action_num(self) -> int:
        """Return the number of (intent, action, tool) rows in the action space."""
        return len(self.action_space)

    @classmethod
    def _choice_records(cls, df) -> list:
        """Project sampled rows onto the fields exposed to the dialogue generator."""
        return df.loc[:, cls._CHOICE_FIELDS].to_dict(orient="records")

    def sample_action_list(self, N: int, max_choice_space: int = 20) -> tuple:
        """Sample ``N`` ground-truth actions and a shuffled candidate list containing them.

        Args:
            N: number of ground-truth actions to draw.
            max_choice_space: target size of the candidate list. When ``N`` is
                smaller, the list is padded with distractor rows not among the
                ground truth; otherwise it holds exactly the ``N`` ground-truth
                actions.

        Returns:
            ``(sampled_action_list, candidate_actions)``: two lists of dicts with
            keys ``Intent``, ``Action``, ``Tool`` and ``Tool_Desc``. Candidate tool
            names are pairwise distinct.

        Raises:
            ValueError: if the action space has fewer distinct tool names than
                the candidate list needs, so no duplicate-free draw exists.
        """
        required = max(N, max_choice_space)
        distinct_tools = self.action_space["Tool"].nunique()
        if distinct_tools < required:
            raise ValueError(
                f"Need {required} distinct tools for the candidate list, "
                f"but the action space only has {distinct_tools}"
            )

        # Rejection sampling: redraw until no tool name appears twice among the candidates.
        while True:
            sampled_actions = self.action_space.sample(N)
            sampled_action_list = self._choice_records(sampled_actions)
            candidate_actions = list(sampled_action_list)
            if N < max_choice_space:
                # Distractors come only from rows not already drawn as ground truth.
                remaining_df = self.action_space[
                    ~self.action_space.index.isin(sampled_actions.index)
                ]
                additional_actions = remaining_df.sample(max_choice_space - N)
                candidate_actions.extend(self._choice_records(additional_actions))
            # Shuffle in every case so the candidate order never mirrors the ground truth.
            random.shuffle(candidate_actions)
            tool_names = [c["Tool"] for c in candidate_actions]
            if len(set(tool_names)) == len(tool_names):
                return sampled_action_list, candidate_actions


"""
Load configs from config file
"""
def init_configs_lookup(config_file_path: str):
    """Return every top-level entry of the YAML config except ``API_KEY``.

    Args:
        config_file_path: path to the UTF-8 YAML config file.

    Returns:
        dict mapping config keys to their values. An empty YAML file yields ``{}``.
    """
    with open(config_file_path, "r", encoding="utf-8") as f:
        # safe_load returns None for an empty document.
        config = yaml.safe_load(f) or {}
    # API_KEY is intentionally left out of the lookup; presumably init_client
    # reads it straight from the config file — confirm.
    return {key: value for key, value in config.items() if key != "API_KEY"}

"""
Load prompt / text from files
"""
def init_prompt_lookup(config_file_path: str):
    """Load the prompt templates whose file paths are listed in the YAML config.

    Only the keys below are looked up. A key that is absent from the config, has
    an empty or null value, or does not point to an existing regular file maps
    to ``""``.

    Args:
        config_file_path: path to the UTF-8 YAML config file.

    Returns:
        dict mapping each prompt key to the UTF-8 text of its file (or ``""``).
    """
    prompt_keys = (
        "PROMPT_WEAVER_PROMPT",
        "CANDIDATE_SYSTEM_PROMPT",
        "MULTITURN_PROMPT",
        "PROMPT_SENTENCE_AGENT",
        "PROMPT_SENTENCE_MERGE",
    )
    with open(config_file_path, "r", encoding="utf-8") as f:
        # safe_load returns None for an empty document.
        config = yaml.safe_load(f) or {}

    prompt_lookup = {}
    for key in prompt_keys:
        # `or ""` also covers keys present in the YAML with a null value.
        file_path = config.get(key) or ""
        if file_path and os.path.isfile(file_path):
            with open(file_path, "r", encoding="utf-8") as prompt_file:
                prompt_lookup[key] = prompt_file.read()
        else:
            prompt_lookup[key] = ""
    return prompt_lookup

def main(config_file_path: str = "/mnt/workspace/workgroup/Benchmark/project-code/configs/scripts2task_config.yaml"):
    """Generate dialogues sequentially with one ProblemWeaver and save them as JSON lines.

    Each of the ``GENERATION_NUM`` iterations:

    1. Picks a random product.
    2. Samples ground-truth and candidate actions.
    3. Asks the ProblemWeaver for a buyer question (singleturn) or a multi-turn
       dialogue (multiturn).

    All records are written to ``DIALOGUE_SAVE_PATH`` (UTF-8, one JSON object per
    line) at the end.

    Args:
        config_file_path: YAML config read by ``init_client``,
            ``init_prompt_lookup`` and ``init_configs_lookup``.

    Raises:
        ValueError: if ``MODE`` is neither "singleturn" nor "multiturn".
    """
    # INIT
    client = init_client(config_file_path)
    prompt_lookup = init_prompt_lookup(config_file_path)
    configs_lookup = init_configs_lookup(config_file_path)

    # MODE defaults to singleturn; MIN_TURNS / MAX_TURNS are only used in multiturn mode.
    mode = configs_lookup.get("MODE", "singleturn").lower()
    if mode not in ("singleturn", "multiturn"):
        # Fail fast instead of after loading data and initialising the agent.
        raise ValueError(f"Unknown MODE: {mode}")
    MIN_TURNS = configs_lookup.get("MIN_TURNS", 2)
    MAX_TURNS = configs_lookup.get("MAX_TURNS", 4)

    # Init components
    product_sampler = ProductSampler(configs_lookup["PRODUCT_FILE_PATH"])
    action_sampler = APISampler(configs_lookup["ACTION_FILE_PATH"])
    weaver_prompt_key = "MULTITURN_PROMPT" if mode == "multiturn" else "PROMPT_WEAVER_PROMPT"
    problem_weaver = ProblemWeaver(client, prompt_lookup[weaver_prompt_key])

    # RUN
    generation_num = configs_lookup["GENERATION_NUM"]
    action_choice_space = configs_lookup["ACTION_CHOICE_SPACE"]
    item_num = product_sampler.get_item_num()
    # Identical for every dialogue, so it is built once.
    candidate_system_prompt = prompt_lookup["CANDIDATE_SYSTEM_PROMPT"].replace(
        "{action_choice_space}", str(action_choice_space)
    )

    dialogue_collection = []
    for _ in tqdm(range(generation_num), desc="Generating Dialogues"):
        # Product sampling: uniform over all products, independently per dialogue.
        item_idx = random.randint(0, item_num - 1)
        item_info = product_sampler.get_item_from_idx(item_idx)

        # Ground-truth length: one action per turn in multiturn,
        # 1..ACTION_CHOICE_SPACE actions in singleturn.
        if mode == "multiturn":
            action_list_length = random.randint(MIN_TURNS, MAX_TURNS)
        else:
            action_list_length = random.randint(1, action_choice_space)

        # NOTE(review): the candidate list size is fixed at 20 here, independently
        # of the ACTION_CHOICE_SPACE value injected into the prompt — confirm intended.
        action_list, candidate_actions = action_sampler.sample_action_list(
            action_list_length, max_choice_space=20
        )

        # Ground-truth actions tagged with their 1-based round number.
        ground_truth_action = [
            {"Round": idx + 1, "Tool": item["Tool"], "Description": item["Tool_Desc"]}
            for idx, item in enumerate(action_list)
        ]
        choice_space = [
            {"Tool": item["Tool"], "Description": item["Tool_Desc"]}
            for item in candidate_actions
        ]
        dialogue_item = {
            "candidate_system_prompt": candidate_system_prompt,
            "product_info": item_info,
            "ground_truth_action": ground_truth_action,
            "choice_space": choice_space,
        }

        if mode == "singleturn":
            pw_input = (
                prompt_lookup["PROMPT_WEAVER_PROMPT"]
                .replace("{product_info}", json.dumps(item_info, ensure_ascii=False))
                .replace("{ground_truth_action}", json.dumps(action_list, ensure_ascii=False))
            )
            pw_output = problem_weaver.run(pw_input)
            dialogue_item["buyer_question"] = pw_output["buyer_question"]
        else:
            # Tool names handed to the multi-turn prompt, one per turn.
            ground_truth_conditions = [item["Tool"] for item in action_list]
            pw_input = (
                prompt_lookup["MULTITURN_PROMPT"]
                .replace("{product_info}", json.dumps(item_info, ensure_ascii=False))
                .replace("{ground_truth_conditions}", json.dumps(ground_truth_conditions, ensure_ascii=False))
                .replace("{number_of_turns}", str(len(ground_truth_conditions)))
            )
            pw_output = problem_weaver.run(pw_input)
            # The agent may return raw text; parse it, or keep the text if it is not JSON.
            if isinstance(pw_output, str):
                try:
                    pw_output = json.loads(pw_output)
                except json.JSONDecodeError:
                    print(f"[Warning] LLM return multi-turn isn't valid JSON, full text stored in: {pw_output}")
            dialogue_item["multi_turn_dialogue"] = pw_output

        dialogue_collection.append(dialogue_item)

    # Save as UTF-8 .jsonl with non-ASCII text kept readable (same as main_multi_thread).
    save_path = configs_lookup["DIALOGUE_SAVE_PATH"]
    with open(save_path, "w", encoding="utf-8") as f:
        for item in dialogue_collection:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")

def worker(input_queue, output_queue, agent, pbar, mode):
    """Process jobs from ``input_queue`` with ``agent`` until a ``None`` sentinel arrives.

    Each job is a dict ``{"inputs": <agent input>, "outputs": <record dict>}``.
    The agent result is stored in the record and the record is put on
    ``output_queue``:

    * ``mode == "singleturn"``: ``record["buyer_question"]``;
    * ``mode == "multiturn"``: ``record["multi_turn_dialogue"]``. String output
      is parsed as JSON (raw text is kept if parsing fails), and every dict turn
      gets ``prev_intent_choice`` reset to ``[]``.

    A job that raises (agent failure, malformed output, unknown mode) is reported
    and dropped. ``task_done()`` and ``pbar.update(1)`` always run, so one failed
    LLM call cannot kill the thread and leave ``input_queue.join()`` blocked forever.
    """
    while True:
        item = input_queue.get()
        if item is None:
            # Shutdown sentinel (the caller enqueues one per thread).
            input_queue.task_done()
            break
        try:
            agent_output = agent.run(item["inputs"])
            output_item = item["outputs"]

            if mode == "singleturn":
                output_item["buyer_question"] = agent_output["buyer_question"]
            elif mode == "multiturn":
                parsed = agent_output
                if isinstance(agent_output, str):
                    try:
                        parsed = json.loads(agent_output)
                    except json.JSONDecodeError:
                        # Keep the raw text so the generation is not lost.
                        print(f"[Warning] Invalid JSON from LLM: {agent_output}")

                # Clear prev_intent_choice; turns that are not dicts are left as-is.
                if isinstance(parsed, list):
                    for turn in parsed:
                        if isinstance(turn, dict):
                            turn["prev_intent_choice"] = []

                output_item["multi_turn_dialogue"] = parsed
            else:
                raise ValueError(f"Unknown MODE: {mode}")

            output_queue.put(output_item)
        except Exception as err:
            # Thread boundary: report and skip this job instead of letting the
            # exception terminate the worker.
            print(f"[Warning] Skipping item after worker error: {err!r}")
        finally:
            pbar.update(1)
            input_queue.task_done()

def worker_multi_agents(input_queue, output_queue, sentence_agent, merge_agent, pbar, mode):
    """Process jobs with a two-stage agent pipeline until a ``None`` sentinel arrives.

    Each job is ``{"inputs": {"item_info": ..., "action_list": ...}, "outputs": <record dict>}``.
    The pipeline has two stages:

    1. ``sentence_agent`` is prompted with the product info and ground-truth
       actions.
    2. The values of its result (presumably a dict keyed per action — confirm)
       are passed as a sentence list to ``merge_agent``.

    When ``mode == "singleturn"``, the merge result's ``"question"`` becomes
    ``record["buyer_question"]``. In any other mode the record is forwarded
    unchanged; multiturn jobs are handled by ``worker``.

    A job that raises is reported and dropped. ``task_done()`` and
    ``pbar.update(1)`` always run, so ``input_queue.join()`` cannot hang on a
    failed LLM call.
    """
    while True:
        item = input_queue.get()
        if item is None:
            # Shutdown sentinel (the caller enqueues one per thread).
            input_queue.task_done()
            break
        try:
            job_inputs = item["inputs"]
            output_item = item["outputs"]

            # Prompt layout, including the trailing newline + 8 spaces, is part of
            # the text sent to the agents; keep it unchanged.
            sentence_agent_input = (
                "\n## Product Info\n"
                f"{job_inputs['item_info']}\n"
                "\n## Ground Truth Action\n"
                f"{job_inputs['action_list']}\n"
                "        "
            )
            agent_output = sentence_agent.run(sentence_agent_input)
            sentence_list = [agent_output[key] for key in agent_output]

            merge_agent_input = (
                "\n## Product Info\n"
                f"{job_inputs['item_info']}\n"
                "\n## Sentence List:\n"
                f"{sentence_list}\n"
                "        "
            )
            merge_agent_output = merge_agent.run(merge_agent_input)
            if mode == "singleturn":
                output_item["buyer_question"] = merge_agent_output["question"]

            output_queue.put(output_item)
        except Exception as err:
            # Thread boundary: report and skip this job instead of letting the
            # exception terminate the worker.
            print(f"[Warning] Skipping item after worker error: {err!r}")
        finally:
            pbar.update(1)
            input_queue.task_done()

def main_multi_thread(config_file_path: str = "/mnt/workspace/workgroup/Benchmark/project-code/configs/scripts2task_config.yaml"):
    """Generate dialogues with a pool of worker threads and save them as JSON lines.

    All jobs are prepared up front: a product, sampled actions, and a partially
    filled output record. They are pushed onto an input queue, followed by one
    ``None`` sentinel per thread, and processed concurrently:

    * singleturn: ``worker_multi_agents`` (SentenceAgent then MergeSentenceAgent)
      fills ``buyer_question``;
    * multiturn: ``worker`` (ProblemWeaver with MULTITURN_PROMPT) fills
      ``multi_turn_dialogue``.

    Records are written to ``DIALOGUE_SAVE_PATH`` (UTF-8, one JSON object per
    line) in completion order, not preparation order.

    Args:
        config_file_path: YAML config read by ``init_client``,
            ``init_prompt_lookup`` and ``init_configs_lookup``.

    Raises:
        ValueError: if ``MODE`` is unknown or ``THREADING_NUM`` is below 1.
    """
    # INIT
    client = init_client(config_file_path)
    prompt_lookup = init_prompt_lookup(config_file_path)
    configs_lookup = init_configs_lookup(config_file_path)

    mode = configs_lookup.get("MODE", "singleturn").lower()
    if mode not in ("singleturn", "multiturn"):
        # Validate up front so an unknown mode fails with a clear error
        # before any data loading or job preparation.
        raise ValueError(f"Unknown MODE: {mode}")
    MIN_TURNS = configs_lookup.get("MIN_TURNS", 2)
    MAX_TURNS = configs_lookup.get("MAX_TURNS", 4)

    N_thread = configs_lookup["THREADING_NUM"]
    if N_thread < 1:
        # Without workers nothing drains the queue and input_queue.join() never returns.
        raise ValueError(f"THREADING_NUM must be at least 1, got {N_thread}")

    # INIT components
    product_sampler = ProductSampler(configs_lookup["PRODUCT_FILE_PATH"])
    action_sampler = APISampler(configs_lookup["ACTION_FILE_PATH"])

    if mode == "multiturn":
        problem_weaver = ProblemWeaver(client, prompt_lookup["MULTITURN_PROMPT"])
    else:
        sentence_agent = SentenceAgent(client, prompt_lookup["PROMPT_SENTENCE_AGENT"])
        merge_agent = MergeSentenceAgent(client, prompt_lookup["PROMPT_SENTENCE_MERGE"])

    # RUN
    generation_num = configs_lookup["GENERATION_NUM"]
    action_choice_space = configs_lookup["ACTION_CHOICE_SPACE"]
    item_num = product_sampler.get_item_num()
    # Identical for every dialogue, so it is built once.
    candidate_system_prompt = prompt_lookup["CANDIDATE_SYSTEM_PROMPT"].replace(
        "{action_choice_space}", str(action_choice_space)
    )

    input_queue = Queue()
    output_queue = Queue()

    # Prepare all jobs
    for _ in tqdm(range(generation_num), desc="Preparing for data"):
        item_idx = random.randint(0, item_num - 1)
        item_info = product_sampler.get_item_from_idx(item_idx)

        if mode == "multiturn":
            action_list_length = random.randint(MIN_TURNS, MAX_TURNS)
        else:
            action_list_length = random.randint(1, action_choice_space)

        # NOTE(review): the candidate list size is fixed at 20 here, independently
        # of the ACTION_CHOICE_SPACE value injected into the prompt — confirm intended.
        action_list, candidate_actions = action_sampler.sample_action_list(
            action_list_length, max_choice_space=20
        )

        # Ground-truth actions tagged with their 1-based round number.
        ground_truth_action = [
            {"Round": idx + 1, "Tool": item["Tool"], "Description": item["Tool_Desc"]}
            for idx, item in enumerate(action_list)
        ]
        choice_space = [
            {"Tool": item["Tool"], "Description": item["Tool_Desc"]}
            for item in candidate_actions
        ]

        if mode == "singleturn":
            # The agents receive JSON strings, while the saved record keeps the
            # product as a dict (same shape as the multiturn record).
            agent_inputs = {
                "item_info": json.dumps(item_info, ensure_ascii=False),
                "action_list": json.dumps(action_list, ensure_ascii=False),
            }
            dialogue_item = {
                "candidate_system_prompt": candidate_system_prompt,
                "product_info": item_info,
                "ground_truth_action": ground_truth_action,
                "choice_space": choice_space,
                "buyer_question": "",
            }
        else:
            # Tool names handed to the multi-turn prompt, one per turn.
            ground_truth_conditions = [item["Tool"] for item in action_list]
            agent_inputs = (
                prompt_lookup["MULTITURN_PROMPT"]
                .replace("{product_info}", json.dumps(item_info, ensure_ascii=False))
                .replace("{ground_truth_conditions}", json.dumps(ground_truth_conditions, ensure_ascii=False))
                .replace("{number_of_turns}", str(len(ground_truth_conditions)))
            )
            dialogue_item = {
                "candidate_system_prompt": candidate_system_prompt,
                "product_info": item_info,
                "ground_truth_action": ground_truth_action,
                "choice_space": choice_space,
                "multi_turn_dialogue": [],
            }

        input_queue.put({"inputs": agent_inputs, "outputs": dialogue_item})

    # One shutdown sentinel per worker thread.
    for _ in range(N_thread):
        input_queue.put(None)

    # Threading
    pbar = tqdm(total=generation_num, desc="Generating...", unit="item")
    if mode == "singleturn":
        target = worker_multi_agents
        worker_args = (input_queue, output_queue, sentence_agent, merge_agent, pbar, mode)
    else:
        target = worker
        worker_args = (input_queue, output_queue, problem_weaver, pbar, mode)
    threads = [
        threading.Thread(target=target, args=worker_args, daemon=True)
        for _ in range(N_thread)
    ]
    for t in threads:
        t.start()

    # Wait until every job and sentinel has been marked done; by then each
    # thread has consumed its sentinel, so joining returns promptly.
    input_queue.join()
    for t in threads:
        t.join()
    pbar.close()

    # Collect results and save as UTF-8 .jsonl
    final_outputs = list(output_queue.queue)
    save_path = configs_lookup["DIALOGUE_SAVE_PATH"]
    with open(save_path, "w", encoding="utf-8") as f:
        for item in final_outputs:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")


if __name__ == "__main__":
    # Script entry point: run the multi-threaded generator.
    # The sequential main() is the single-threaded alternative.
    main_multi_thread()
