import requests
import time
import yaml
import json
import json_repair
import pandas as pd
import pickle
import os
import re
import argparse
import wandb
from tqdm import tqdm

import multiprocessing
from functools import partial

# Module-level run state; play_multiprocess declares these `global` and updates them.
# index: counter embedded in every run name, incremented after each (including skipped) run.
index = 1
# One-shot flags: while True (and not resuming), the next run for that role
# (re)creates its summary CSV from scratch and then clears the flag.
summary_ask_flag = summary_answer_flag = True

def load_config(path="config.yaml"):
    """Load the experiment configuration from a YAML file.

    Args:
        path: location of the YAML file. Defaults to ``config.yaml`` in the
            current working directory, the location the script always used.

    Returns:
        Whatever ``yaml.safe_load`` produces (a dict for a mapping document).
    """
    # Explicit UTF-8 so decoding does not depend on the platform's locale encoding.
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)

def load_conversations(path="conversations_cleaned.json"):
    """Load the conversation dataset from a JSON file.

    Args:
        path: location of the JSON file. Defaults to
            ``conversations_cleaned.json`` in the current working directory,
            the location the script always used.

    Returns:
        The decoded JSON document.
    """
    # Explicit UTF-8: the dataset may contain non-ASCII text, and the platform's
    # default encoding (e.g. cp1252 on Windows) would fail to decode it.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def post_message(messages, config):
    """POST a chat request for ``messages`` to the configured model endpoint.

    Args:
        messages: list of chat message dicts, each with at least a ``content``
            key. The caller's list and dicts are left unmodified.
        config: experiment config. ``config["model"]`` names the model and
            ``config[config["model"]]`` supplies ``api_key`` and ``api_base``.
            An optional ``config["request_timeout"]`` (seconds, default 300)
            bounds the HTTP call so a stalled connection cannot hang a worker.

    Returns:
        The decoded JSON body of the HTTP response. The status code is not
        checked; callers inspect the body.
    """
    model = config["model"]
    model_cfg = config[model]
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {model_cfg['api_key']}",
    }
    if model == "abab6.5s-chat":
        # This model's payload carries sender_type/sender_name/text on every
        # message. Build copies rather than mutating the caller's dicts, which
        # are re-sent on retry and later logged as the prompt.
        bot_messages = [
            {**message, "sender_type": "USER", "sender_name": "user", "text": message["content"]}
            for message in messages
        ]
        payload = {
            "model": model,
            # Persona for the reply bot; the text is sent verbatim.
            "bot_setting": [
                {
                    "bot_name": "MM智能助理",
                    "content": "MM智能助理是一款由MiniMax自研的，没有调用其他产品的接口的大型语言模型。MiniMax是一家中国科技公司，一直致力于进行大模型相关的研究。",
                }
            ],
            "reply_constraints": {"sender_type": "BOT", "sender_name": "MM智能助理"},
            "messages": bot_messages,
        }
    else:
        payload = {"model": model, "messages": messages}
    response = requests.post(
        model_cfg["api_base"],
        headers=headers,
        json=payload,
        timeout=config.get("request_timeout", 300),
    )
    return response.json()

def _salvage_level_object(json_str):
    # Fallback parser: rebuild {"level": ...} from the text after the "level" key
    # and take the raw text between the "analysis" key and "level" as the analysis.
    quoted = json_str.rfind('"level"')
    # Prefer the last quoted "level" key: the free-text analysis before it may
    # itself contain the word "level".
    level_start = quoted + 1 if quoted != -1 else json_str.find("level")
    if level_start == -1:
        raise SystemError("LLM Response Structure Error!!!")
    level_end = json_str.find("}", level_start)
    if level_end == -1:
        # No closing brace: keep the whole tail instead of dropping its last char.
        level_end = len(json_str)
    # The tail normally reads `level": "..."`; prefixing `{"` re-opens the object
    # and json_repair supplies the missing closing brace.
    parsed_obj = json_repair.loads("{\"" + json_str[level_start:level_end])
    if not isinstance(parsed_obj, dict) or "level" not in parsed_obj:
        raise SystemError("LLM Response Structure Error!!!")

    analysis_start = json_str.find("analysis")
    analysis_end = json_str.find(":", analysis_start) if analysis_start != -1 else -1
    if analysis_end == -1:
        parsed_obj["analysis"] = ""
    else:
        # Raw slice up to the character before "level"; it may still carry JSON
        # quotes and a trailing comma.
        parsed_obj["analysis"] = json_str[analysis_end + 1:level_start - 1]
    return parsed_obj


def parse_response(content, levels):
    """Extract the ``{"analysis": ..., "level": ...}`` object from an LLM reply.

    The JSON object may be wrapped in a Markdown code fence, either tagged
    (```json ... ```) or bare (``` ... ```, the form shown in the prompt).
    Parsing uses ``json_repair`` to tolerate minor syntax errors. If that does
    not yield a dict with a "level" key, a text-slicing fallback is tried.

    Args:
        content: raw text of the model's reply.
        levels: the allowed level labels.

    Returns:
        dict with a "level" key whose value is in ``levels`` and an "analysis"
        key ("" when the reply provides none).

    Raises:
        SystemError: no usable "level" field could be extracted.
        ValueError: the extracted level is not one of ``levels``.
    """
    match = re.search(r"```(?:json)?\s*(.+?)\s*```", content, re.DOTALL)
    json_str = match.group(1) if match else content

    try:
        parsed_obj = json_repair.loads(json_str)
    except Exception:
        parsed_obj = None
    # json_repair.loads is not guaranteed to return a dict (e.g. for a list or
    # non-JSON input), so check the type explicitly rather than via `assert`.
    if not isinstance(parsed_obj, dict) or "level" not in parsed_obj:
        try:
            parsed_obj = _salvage_level_object(json_str)
        except Exception as e:
            print(e)
            print(content)
            raise SystemError("LLM Response Structure Error!!!") from e

    # Callers read "analysis" unconditionally; the model may omit it.
    parsed_obj.setdefault("analysis", "")
    if parsed_obj["level"] not in levels:
        raise ValueError(f"The level {parsed_obj['level']} is not in the list of levels: {levels}")
    return parsed_obj

def do_query(messages, levels, config):
    """Query the model until it returns a parseable classification.

    Each attempt sends ``messages`` via ``post_message``, records the token cost
    from the response's ``usage`` block, and validates the reply with
    ``parse_response``. Any exception (network error, unexpected response
    shape, unparseable or out-of-range level) is printed and retried after
    ``config.get("retry_delay", 5)`` seconds, up to ``config["max_retry"]``
    attempts in total.

    Args:
        messages: chat messages to send.
        levels: allowed level labels, used to validate the reply.
        config: experiment config (model, per-model pricing, max_retry).

    Returns:
        ``(response, costs)``: the raw JSON of the successful attempt and the
        cost of every attempt whose usage information could be read.

    Raises:
        ValueError: ``config["max_retry"]`` is less than 1.
        SystemError: every attempt failed.
    """
    max_retry = config["max_retry"]
    if max_retry < 1:
        # Without this guard the loop never runs and `response` is unbound.
        raise ValueError(f"config['max_retry'] must be >= 1, got {max_retry}")

    costs = []
    for attempt in range(max_retry):
        try:
            response = post_message(messages, config)
            model_cfg = config[config["model"]]
            if config["model"] == "abab6.5s-chat":
                # This branch prices a single total_tokens count at cost_input.
                cost = response["usage"]["total_tokens"] * model_cfg["cost_input"]
                content = response["reply"]
            else:
                usage = response["usage"]
                cost = (usage["prompt_tokens"] * model_cfg["cost_input"]
                        + usage["completion_tokens"] * model_cfg["cost_output"])
                content = response["choices"][0]["message"]["content"]
            # Record the cost before validating: the tokens are spent even if
            # the reply turns out to be unusable.
            costs.append(cost)
            parse_response(content, levels)
            return response, costs
        except Exception as e:
            print(e)
            if attempt == max_retry - 1:
                # No point sleeping before giving up.
                raise SystemError("Experiment interrupted with max_retry reached!!!") from e
            time.sleep(config.get("retry_delay", 5))

def play_answer_messages(topic, name1, background1, name2, background2, conversation, levels, config):
    """Build "answer"-side classification prompts from a conversation.

    Turns spoken by ``name1`` are skipped, so each passage contains the other
    participant's lines. ``topic``, ``name2`` and the backgrounds are accepted
    for interface compatibility but not used.

    With ``config["all_turns"]``, a single prompt holds every such turn.
    Otherwise one prompt is built per starting turn, containing
    ``config["turns"]`` turns taken at every second position of
    ``conversation`` (this assumes speakers alternate). Generation stops at
    the first window that would run past the end.

    When ``config["few_shot"]`` is set, the "answer" rows of
    ``./few_shot.csv`` (columns role, level, content) are appended to the
    instructions as examples.

    Returns:
        list of ``[instructions, passage]`` message lists, both with role "user".
    """
    turn = "all" if config["all_turns"] else config["turns"]

    few_shot_prompt = ""
    if config["few_shot"]:
        few_shot_prompt = " You can refer to the following examples to help you make the judgment.\n\n# Examples\n\n"
        few_shot_df = pd.read_csv("./few_shot.csv")
        few_shot_df = few_shot_df[few_shot_df["role"] == "answer"]
        # Iterate the column values directly: after filtering, the index labels
        # are no longer 0..n-1, so label-based `.loc[i]` over range(len) would
        # raise KeyError or pick the wrong rows.
        for level, content in zip(few_shot_df["level"], few_shot_df["content"]):
            few_shot_prompt += f"## Level: {level}\n"
            few_shot_prompt += f"{content}\n\n"

    system_prompt = f"""# Requirement
You are an expression and content evaluation assistant. Your task is to determine which of the following levels the speaker of a given passage belongs to: {levels}. The given passage contains {turn} responses of the speaker from a conversation. Analyze the speaking style and the knowledge level presented in the content to make your judgment. Return the most appropriate level in the following format.{few_shot_prompt}

# Response Format
```
{{
    "analysis": "the analysis of the passage",
    "level": "the most appropriate level"
}}
```"""

    def build(passage):
        # Both the instructions and the passage are sent as "user" messages.
        return [
            {"role": "user", "content": system_prompt},
            {"role": "user", "content": "# The given passage\n" + passage},
        ]

    if config["all_turns"]:
        passage = "".join(f"\n{sentence['text']}\n" for sentence in conversation if sentence["speaker"] != name1)
        return [build(passage)]

    message_set = []
    window_span = (config["turns"] - 1) * 2
    for i, sentence in enumerate(conversation):
        if sentence["speaker"] == name1:
            continue
        if i + window_span >= len(conversation):
            break
        passage = "".join(f"\n{conversation[i + j * 2]['text']}\n" for j in range(config["turns"]))
        message_set.append(build(passage))

    return message_set

def play_ask_messages(topic, name1, background1, name2, background2, conversation, levels, config):
    """Build "ask"-side classification prompts from a conversation.

    Turns spoken by ``name2`` are skipped, so each passage consists of the
    remaining speaker's lines. ``topic``, ``name1`` and the backgrounds are
    accepted for interface compatibility but not used.

    In all-turns mode (``config["all_turns"]``), one prompt carries every
    kept turn. Otherwise, for each kept turn a window of ``config["turns"]``
    turns spaced two apart is taken (this assumes alternating speakers).
    Iteration ends at the first window that would overflow the conversation.

    If ``config["few_shot"]`` is set, the "ask" rows of ``./few_shot.csv``
    are appended to the instructions as examples.

    Returns:
        list of ``[instructions, passage]`` message lists, both with role "user".
    """
    turn = "all" if config["all_turns"] else config["turns"]

    few_shot_prompt = ""
    if config["few_shot"]:
        examples = pd.read_csv("./few_shot.csv")
        examples = examples[examples["role"] == "ask"]
        shots = "".join(
            f"## Level: {level}\n{content}\n\n"
            for level, content in zip(examples["level"], examples["content"])
        )
        few_shot_prompt = " You can refer to the following examples to help you make the judgment.\n\n# Examples\n\n" + shots

    system_prompt = f"""# Requirement
You are an expression and content evaluation assistant. Your task is to determine which of the following levels a given passage is intended for: {levels}. The given passage contains {turn} responses of the speaker from a conversation. Analyze the speaking style and the knowledge level presented in the content to make your judgment. Return the most appropriate level in the following format.{few_shot_prompt}

# Response Format
```
{{
    "analysis": "the analysis of the passage",
    "level": "the most appropriate level"
}}
```"""

    header = "# The given passage\n"

    def as_prompt(body):
        return [
            {"role": "user", "content": system_prompt},
            {"role": "user", "content": header + body},
        ]

    if config["all_turns"]:
        kept = [utterance["text"] for utterance in conversation if utterance["speaker"] != name2]
        return [as_prompt("".join(f"\n{text}\n" for text in kept))]

    prompts = []
    span = (config["turns"] - 1) * 2
    total = len(conversation)
    for start, utterance in enumerate(conversation):
        if utterance["speaker"] == name2:
            continue
        if start + span >= total:
            break
        window = [conversation[start + 2 * k]["text"] for k in range(config["turns"])]
        prompts.append(as_prompt("".join(f"\n{text}\n" for text in window)))
    return prompts

def process_message(enum_messages, config, topic, message_set, gt, levels):
    """Classify one prompt and return its result row.

    Args:
        enum_messages: ``(index, messages)`` pair as yielded by ``enumerate``;
            only ``messages`` is used.
        config: experiment config; ``config["model"]`` selects how the reply
            text is extracted from the response.
        topic: topic label copied into the row.
        message_set: unused; accepted for call compatibility.
        gt: ground-truth level used to score the prediction.
        levels: allowed levels passed to the parser.

    Returns:
        dict with keys topic, prompt, response, answer, analysis, accuracy
        (1 if the predicted level equals ``gt`` else 0), model, ground_truth
        and cost (summed over all query attempts).
    """
    _, messages = enum_messages
    response, costs = do_query(messages, levels, config)
    if config["model"] == "abab6.5s-chat":
        content = response["reply"]
    else:
        content = response["choices"][0]["message"]["content"]
    # do_query already parsed this same reply successfully before returning it.
    answer = parse_response(content, levels)
    predicted = answer["level"]

    return {
        "topic": topic,
        "prompt": str(messages),
        "response": content,
        "answer": predicted,
        # The reply may omit "analysis"; a KeyError here would abort the whole run.
        "analysis": answer.get("analysis", ""),
        "accuracy": 1 if predicted == gt else 0,
        "model": config["model"],
        "ground_truth": gt,
        "cost": sum(costs),
    }

def play_multiprocess(message_set, role, gt, levels, topic, name1, background1, name2, background2, conversation, config):
    """Classify every prompt in ``message_set`` with a process pool and record the results.

    Args:
        message_set: list of chat-message lists, one per sample.
        role: "ask" or "answer"; selects the output sub-directory and which
            summary flag applies.
        gt: ground-truth level shared by every sample in this set.
        levels: candidate levels; the summary CSV keeps one row per level plus
            a "total" row.
        topic, name2: embedded in the run name. ``name1``, ``background1``,
            ``background2`` and ``conversation`` are accepted but unused.
        config: experiment config (model, turns, all_turns, few_shot, resume,
            num_processes, wandb.entity, ...).

    Side effects:
        * When ``config["resume"]`` is set and this run's result CSV exists,
          only increments ``index`` and returns.
        * Writes ``<run_name>.csv`` and ``<run_name>.pkl`` under the results dir.
        * Logs the results table, total cost, accuracy and sample count to wandb.
        * Creates or updates the per-role summary CSV (per-level and total
          accuracy, sample count and cost).
        * Increments the module-level ``index`` and may clear
          ``summary_ask_flag`` / ``summary_answer_flag``.
    """
    global index, summary_ask_flag, summary_answer_flag

    few_shot_path = "_few-shot" if config["few_shot"] else ""
    # "all" or the configured turn count; shared by run name, directory and project.
    turns_label = "all" if config["all_turns"] else config["turns"]
    model = config["model"]

    run_name = (
        f"classification{few_shot_path}_{index}_{role}_{turns_label}_turns_"
        f"{topic}_{model}_{name2}"
    ).replace(" ", "")
    dir_path = f"./results/classification{few_shot_path}/{model}/{turns_label}_turns/{role}"
    os.makedirs(dir_path, exist_ok=True)
    csv_path = f"{dir_path}/{run_name}.csv"

    if config["resume"] and os.path.exists(csv_path):
        index += 1
        return

    project = f"llm_classification{few_shot_path}_{turns_label}_turns_{model}"
    wandb.init(project=project, name=run_name, entity=config["wandb"]["entity"])

    columns = ["topic", "model", "prompt", "response", "analysis", "answer", "ground_truth", "accuracy", "cost"]
    with multiprocessing.Pool(processes=config["num_processes"]) as pool:
        process_func = partial(process_message, config=config, topic=topic, message_set=message_set, gt=gt, levels=levels)
        chunksize = max(1, len(message_set) // (4 * config["num_processes"]))
        results = list(tqdm(pool.imap(process_func, enumerate(message_set), chunksize=chunksize), total=len(message_set)))

    # Build the table once; growing it with pd.concat per row is quadratic.
    df = pd.DataFrame(results, columns=columns)
    total_cost = sum(result["cost"] for result in results)

    df.to_csv(csv_path, index=False)
    wandb.log({f"results_{run_name}": wandb.Table(data=df)}, commit=True)

    # Context manager so the pickle file is flushed and closed deterministically.
    with open(f"{dir_path}/{run_name}.pkl", "wb") as f:
        pickle.dump(results, f)

    acc = df["accuracy"].mean()
    wandb.log({"total_cost": total_cost}, commit=True)
    wandb.log({"accuracy": acc}, commit=True)
    wandb.log({"num_samples": len(df)}, commit=True)

    summary_path = f"{dir_path}/summary_classification{few_shot_path}_{role}_{model}.csv".replace(" ", "")
    summary_flag = summary_ask_flag if role == "ask" else summary_answer_flag
    # Start a fresh summary on the first run of this role. When resuming, reuse
    # the existing file, but still create one if it does not exist yet
    # (read_csv on a missing file would raise FileNotFoundError).
    if summary_flag and (not config["resume"] or not os.path.exists(summary_path)):
        summary = pd.DataFrame(
            [[level, 0.0, 0, 0.0] for level in [*levels, "total"]],
            columns=["level", "accuracy", "num_samples", "total_cost"],
        )
        summary.to_csv(summary_path, index=False)
        if role == "ask":
            summary_ask_flag = False
        else:
            summary_answer_flag = False
    else:
        summary = pd.read_csv(summary_path)

    n_new = len(df)
    row = summary["level"] == gt
    old_num_samples = summary.loc[row, "num_samples"].values[0]
    old_accuracy = summary.loc[row, "accuracy"].values[0]
    # Running, sample-weighted mean accuracy for this level.
    summary.loc[row, "accuracy"] = (old_accuracy * old_num_samples + acc * n_new) / (old_num_samples + n_new)
    summary.loc[row, "num_samples"] += n_new
    summary.loc[row, "total_cost"] += total_cost

    # Recompute the "total" row from the per-level rows.
    per_level = summary["level"] != "total"
    total_row = ~per_level
    total_samples = summary.loc[per_level, "num_samples"].sum()
    summary.loc[total_row, "num_samples"] = total_samples
    summary.loc[total_row, "total_cost"] = summary.loc[per_level, "total_cost"].sum()
    weighted_correct = (summary.loc[per_level, "accuracy"] * summary.loc[per_level, "num_samples"]).sum()
    summary.loc[total_row, "accuracy"] = weighted_correct / total_samples

    summary.to_csv(summary_path, index=False)
    wandb.finish()
    index += 1

def main(args):
    """Run the classification sweep over every topic and dialogue.

    Loads ``config.yaml`` and ``conversations_cleaned.json``, and copies the CLI
    options into the config. For each dialogue it then builds "ask" prompts
    (turns not spoken by the other speaker) and "answer" prompts (turns not
    spoken by the main speaker), and evaluates each non-empty prompt set with
    ``play_multiprocess``.

    The ground truth is the other speaker's "class". The candidate levels are
    the distinct classes of all dialogues under the same topic, in first-seen
    order.

    Args:
        args: parsed CLI namespace with model, turns, all_turns, resume,
            few_shot and num_processes attributes.
    """
    config = load_config()
    conversations = load_conversations()

    # CLI options override any same-named entries from config.yaml.
    for key in ("model", "turns", "all_turns", "resume", "few_shot", "num_processes"):
        config[key] = getattr(args, key)

    for conversation in conversations["conversations_by_topic"]:
        topic = conversation["topic"]
        name1 = conversation["main_speaker"]["name"]
        background1 = conversation["main_speaker"]["background"]
        dialogues = conversation["dialogues"]

        # De-duplicate while keeping order: a repeated label would appear twice
        # in the prompt and create duplicate summary rows that are both updated
        # (and double-counted in the total).
        levels = list(dict.fromkeys(dialogue["other_speaker"]["class"] for dialogue in dialogues))

        for dialogue in dialogues:
            other = dialogue["other_speaker"]
            name2 = other["name"]
            background2 = other["background"]
            dialogue_turns = dialogue["conversation"]
            gt = other["class"]

            answer_messages = play_answer_messages(topic, name1, background1, name2, background2, dialogue_turns, levels, config)
            ask_messages = play_ask_messages(topic, name1, background1, name2, background2, dialogue_turns, levels, config)

            if ask_messages:
                play_multiprocess(ask_messages, "ask", gt, levels, topic, name1, background1, name2, background2, dialogue_turns, config)
            if answer_messages:
                play_multiprocess(answer_messages, "answer", gt, levels, topic, name1, background1, name2, background2, dialogue_turns, config)

if __name__ == "__main__":
    # Command-line entry point; the guard also keeps multiprocessing workers
    # that re-import this module from starting a new sweep.
    cli = argparse.ArgumentParser(description="Process state information.")
    cli.add_argument("--model", type=str, default="gpt-4o", help="Model to use")
    cli.add_argument("--num_processes", type=int, default=10, help="Number of processes to use")
    cli.add_argument("--turns", type=int, default=5, help="Number of turns to evaluate")
    for flag, help_text in (
        ("--all_turns", "Evaluate all turns"),
        ("--resume", "Resume the experiment"),
        ("--few_shot", "Use few-shot prompt"),
    ):
        cli.add_argument(flag, action="store_true", help=help_text)

    main(cli.parse_args())