import requests
import time
import yaml
import json
import json_repair
import pandas as pd
import pickle
import os
import re
import argparse
import wandb
from tqdm import tqdm

import multiprocessing
from functools import partial

import random
# Seed the global RNG once so the paragraph-order coin flips made while
# building prompts are reproducible from run to run.
random.seed(42)

# Run state shared with play_multiprocess (which rebinds these via `global`):
# `index` numbers successive runs in their names; each flag stays True until
# that role's summary CSV has been initialised in a non-resumed experiment.
index, summary_ask_flag, summary_answer_flag = 1, True, True

def load_config():
    """Parse ``config.yaml`` from the working directory with ``yaml.safe_load``.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file (normally a dict).
    """
    with open("config.yaml", "r") as stream:
        parsed = yaml.safe_load(stream)
    return parsed

def load_conversations():
    """Read ``conversations_cleaned.json`` from the working directory.

    Returns:
        The decoded JSON document.
    """
    with open("conversations_cleaned.json", "r") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)

def post_message(messages, config):
    """POST a chat-completion request and return the decoded JSON body.

    Args:
        messages: chat messages placed in the request payload.
        config: experiment config. ``config["model"]`` names the model, and
            ``config[config["model"]]`` supplies ``api_key`` and ``api_base``
            (the URL the request is posted to). An optional top-level
            ``request_timeout`` (seconds, default 300) bounds each request.

    Returns:
        The response body parsed with ``response.json()``. Error payloads are
        returned as-is rather than raised.

    Raises:
        requests.RequestException: on connection failure or timeout.
        ValueError: if the body is not valid JSON.
    """
    model_config = config[config["model"]]
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {model_config['api_key']}",
    }
    payload = {
        "model": config["model"],
        "messages": messages,
    }
    # Without a timeout a stalled connection blocks this worker indefinitely,
    # and the caller's retry loop never gets a chance to run.
    timeout = config.get("request_timeout", 300)
    response = requests.post(model_config["api_base"], headers=headers, json=payload, timeout=timeout)
    return response.json()

def parse_response(content, levels):
    """Extract the ``{"analysis": ..., "choice": ...}`` object from an LLM reply.

    Parsing proceeds in three stages:

    1. Take the body of a fenced block: a ```json fence first, otherwise any
       fence whose body is a ``{...}`` object. Use the whole reply when there
       is no fence.
    2. Parse that text as JSON, strictly first and then with ``json_repair``.
       Accept the result if it is a dict containing ``"choice"``.
    3. Otherwise salvage the ``"choice"`` entry on its own. Recover the
       analysis as the raw text between the ``analysis`` key and the
       ``choice`` key.

    Args:
        content: raw text of the model's message.
        levels: unused; kept so existing call sites keep working.

    Returns:
        A dict with at least ``"choice"`` and ``"analysis"``. ``"analysis"``
        defaults to ``""`` when the model omitted it.

    Raises:
        SystemError: if no ``choice`` entry can be recovered.
    """
    match = re.search(r"```json\s+(.+?)\s+```", content, re.DOTALL)
    if match is None:
        # The prompt's format example uses a bare ``` fence, so also accept
        # any fenced block whose body is a {...} object.
        match = re.search(r"```[A-Za-z]*\s*(\{.*?\})\s*```", content, re.DOTALL)
    json_str = match.group(1) if match else content

    parsed_obj = _parse_json_dict(json_str)
    if parsed_obj is None or "choice" not in parsed_obj:
        parsed_obj = _salvage_choice(json_str)
    if parsed_obj is None:
        print(content)
        raise SystemError("LLM Response Structure Error!!!")

    # Downstream code reads answer["analysis"] unconditionally.
    parsed_obj.setdefault("analysis", "")
    return parsed_obj


def _parse_json_dict(text):
    """Parse *text* as JSON (strictly, then via json_repair); return it only if it is a dict."""
    try:
        obj = json.loads(text)
    except ValueError:
        try:
            obj = json_repair.loads(text)
        except Exception:  # json_repair is best-effort; any failure means "unparseable"
            return None
    # A str/list result would make `"choice" in obj` a substring/element test.
    return obj if isinstance(obj, dict) else None


def _salvage_choice(json_str):
    """Recover ``choice`` (plus raw ``analysis`` text) from JSON too broken to parse whole.

    Returns a dict, or None if no ``choice`` entry can be parsed.
    """
    # Prefer the last quoted key, so the word "choice" inside the analysis
    # text is not mistaken for the key. Fall back to the bare word.
    key_pos = json_str.rfind('"choice"')
    choice_start = key_pos + 1 if key_pos != -1 else json_str.find("choice")
    if choice_start == -1:
        return None
    choice_end = json_str.find("}", choice_start)
    if choice_end == -1:  # unterminated object: take the rest of the text
        choice_end = len(json_str)
    parsed_obj = _parse_json_dict('{"' + json_str[choice_start:choice_end] + "}")
    if parsed_obj is None or "choice" not in parsed_obj:
        return None

    analysis_start = json_str.find("analysis", 0, choice_start)
    value_start = json_str.find(":", analysis_start, choice_start) if analysis_start != -1 else -1
    if value_start == -1:
        parsed_obj["analysis"] = ""
    else:
        # Raw span up to the character just before the choice key.
        parsed_obj["analysis"] = _strip_json_value(json_str[value_start + 1:choice_start - 1])
    return parsed_obj


def _strip_json_value(raw):
    """Trim whitespace, one trailing comma and one pair of surrounding double quotes."""
    value = raw.strip()
    if value.endswith(","):
        value = value[:-1].rstrip()
    if len(value) >= 2 and value[0] == value[-1] == '"':
        value = value[1:-1]
    return value

def do_query(messages, levels, config):
    """Send *messages* to the model, retrying until the reply parses.

    Each attempt posts the request, records its token cost, and validates the
    reply with parse_response. Any exception (network error, an error payload
    without ``usage``, unparseable content) is printed. The request is then
    retried after a pause, up to ``config["max_retry"]`` attempts in total.

    Args:
        messages: chat messages forwarded to post_message.
        levels: forwarded to parse_response.
        config: experiment config. Reads ``max_retry``, ``model`` and the
            per-model ``cost_input`` / ``cost_output`` rates. An optional
            ``retry_delay`` (seconds, default 5) sets the pause between
            attempts.

    Returns:
        ``(response, costs)``. ``response`` is the decoded JSON of the
        successful attempt. ``costs`` holds the cost of every attempt that
        returned usage data, including attempts whose content failed to parse.

    Raises:
        SystemError: if every attempt fails, or if ``max_retry`` is below 1.
    """
    max_retry = config["max_retry"]
    if max_retry < 1:
        # Previously this silently returned ("", []), which crashed the caller.
        raise SystemError("Experiment interrupted: max_retry must be at least 1!!!")

    retry_delay = config.get("retry_delay", 5)
    costs = []
    for attempt in range(1, max_retry + 1):
        response = ""  # reset so a failed request never reports a stale reply
        try:
            response = post_message(messages, config)
            model_config = config[config["model"]]
            usage = response["usage"]
            costs.append(
                usage["prompt_tokens"] * model_config["cost_input"]
                + usage["completion_tokens"] * model_config["cost_output"]
            )
            # Only validates; the caller re-parses the content itself.
            parse_response(response["choices"][0]["message"]["content"], levels)
            return response, costs
        except Exception as e:
            print(e)
            print(response if response != "" else "No response!!!")
            if attempt == max_retry:
                raise SystemError("Experiment interrupted with max_retry reached!!!") from e
            time.sleep(retry_delay)

def generate_answer_messages(topic, name1, background1, level1, level2, dialogue1, dialogue2, levels, config):
    """Build the "answer" pairwise prompt: which paragraph is spoken by *level1*?

    Each paragraph concatenates every line of one dialogue that was NOT
    spoken by *name1*, each followed by a blank line. Paragraph 1 starts as
    *dialogue1*'s text and Paragraph 2 as *dialogue2*'s. One draw from the
    module ``random`` state swaps them with probability 0.5, and the returned
    label follows the swap.

    Args:
        topic, background1, levels: unused here.
        name1: speaker whose lines are excluded.
        level1: label named in the question and used for the few-shot example.
        level2: also looked up in the few-shot table when few-shot is on.
        dialogue1, dialogue2: dicts whose ``"conversation"`` is a list of
            ``{"speaker", "text"}`` dicts.
        config: reads ``few_shot``. When truthy, ``./few_shot.csv`` rows with
            role ``"answer"`` supply the example.

    Returns:
        ``(messages, gt)``: a one-element list holding a single user message,
        and ``"Paragraph 1"`` or ``"Paragraph 2"`` naming the paragraph that
        holds *dialogue1*'s text.
    """
    def listener_text(dialogue):
        pieces = [f"{turn['text']}\n\n" for turn in dialogue["conversation"] if turn["speaker"] != name1]
        return "".join(pieces)

    first_text = listener_text(dialogue1)
    second_text = listener_text(dialogue2)

    swap = random.random() < 0.5
    gt = "Paragraph 2" if swap else "Paragraph 1"
    if swap:
        first_text, second_text = second_text, first_text

    examples_block = ""
    if config["few_shot"]:
        table = pd.read_csv("./few_shot.csv")
        table = table[table["role"] == "answer"]

        def example_for(level):
            return table[table["level"] == level]["content"].values[0]

        level1_example = example_for(level1)
        example_for(level2)  # looked up but unused, as before: still fails if level2 has no example

        examples_block = (
            " You can refer to the following examples to help you make the judgment.\n\n# Examples\n\n"
            f"## Level: {level1}\n"
            f"{level1_example}\n\n"
        )

    prompt_text = f"""# Requirement
You are an expression and content evaluation assistant. Your task is to determine which of the following two paragraphs is spoken by a {level1}. The given paragraphs both contain all responses of the speaker from a conversation. Analyze the speaking style and the knowledge level presented in the content to make your judgment. Return the most appropriate paragraph in the following format.{examples_block}

# Paragraph 1
{first_text}

# Paragraph 2
{second_text}

# Response Format
```
{{
    "analysis": "the analysis of the paragraphs",
    "choice": "Paragraph 1 or Paragraph 2"
}}
```"""

    return [{"role": "user", "content": prompt_text}], gt

def generate_ask_messages(topic, name1, background1, level1, level2, dialogue1, dialogue2, levels, config):
    """Build the "ask" pairwise prompt: which paragraph is addressed to *level1*?

    Each paragraph concatenates every line of one dialogue that was NOT
    spoken by that dialogue's own other speaker, each followed by a blank
    line. Paragraph 1 starts as *dialogue1*'s text and Paragraph 2 as
    *dialogue2*'s. One draw from the module ``random`` state swaps them with
    probability 0.5, and the returned label follows the swap.

    Args:
        topic, name1, background1, levels: unused here.
        level1: label named in the question and used for the few-shot example.
        level2: unused here.
        dialogue1, dialogue2: dicts with ``"other_speaker"`` (having
            ``"name"``) and ``"conversation"``, a list of
            ``{"speaker", "text"}`` dicts.
        config: reads ``few_shot``. When truthy, the ``./few_shot.csv`` row
            with role ``"ask"`` and level *level1* supplies the example.

    Returns:
        ``(messages, gt)``: a one-element list holding a single user message,
        and ``"Paragraph 1"`` or ``"Paragraph 2"`` naming the paragraph that
        holds *dialogue1*'s text.
    """
    def collect(dialogue):
        # Each dialogue has its own other speaker. Previously dialogue2's
        # name was also used to filter dialogue1, which leaked dialogue1's
        # other speaker into Paragraph 1 whenever the names differed.
        other = dialogue["other_speaker"]["name"]
        return "".join(
            f"{sentence['text']}\n\n"
            for sentence in dialogue["conversation"]
            if sentence["speaker"] != other
        )

    paragraph1 = collect(dialogue1)
    paragraph2 = collect(dialogue2)

    gt = "Paragraph 1"
    if random.random() < 0.5:
        gt = "Paragraph 2"
        paragraph1, paragraph2 = paragraph2, paragraph1

    few_shot_prompt = ""
    if config["few_shot"]:
        few_shot_df = pd.read_csv("./few_shot.csv")
        few_shot_df = few_shot_df[few_shot_df["role"] == "ask"]
        few_shot_paragraph1 = few_shot_df[few_shot_df["level"] == level1]["content"].values[0]
        few_shot_prompt = (
            " You can refer to the following examples to help you make the judgment.\n\n# Examples\n\n"
            f"## Level: {level1}\n"
            f"{few_shot_paragraph1}\n\n"
        )

    system_prompt = f"""# Requirement
You are an expression and content evaluation assistant. Your task is to determine which of the following two paragraphs is addressed to a {level1}. The given paragraphs both contain all responses of the speaker from a conversation. Analyze the speaking style and the knowledge level presented in the content to make your judgment. Return the most appropriate paragraph in the following format.{few_shot_prompt}

# Paragraph 1
{paragraph1}

# Paragraph 2
{paragraph2}

# Response Format
```
{{
    "analysis": "the analysis of the paragraphs",
    "choice": "Paragraph 1 or Paragraph 2"
}}
```"""

    messages = [{"role": "user", "content": system_prompt}]
    return messages, gt

def process_message(enum_messages, config, topic, message_set, levels):
    """Run one classification query and package the outcome as a result row.

    Args:
        enum_messages: ``(i, (messages, gt, level1, level2))``, as produced by
            ``enumerate`` over the message set. ``i`` is unused.
        config: forwarded to do_query; ``config["model"]`` is recorded.
        topic: copied into the row.
        message_set: unused; accepted because the pool binds it.
        levels: forwarded to do_query / parse_response.

    Returns:
        A dict with the keys topic, prompt, response, answer, level1, level2,
        analysis, accuracy (1 if the model's choice equals ``gt``, else 0),
        model, ground_truth and cost (the summed cost of every attempt).
    """
    _, (messages, gt, level1, level2) = enum_messages
    response, costs = do_query(messages, levels, config)
    content = response["choices"][0]["message"]["content"]
    answer = parse_response(content, levels)

    choice = answer["choice"]
    if isinstance(choice, str):
        # Stray whitespace such as "Paragraph 1\n" should not count as a miss.
        choice = choice.strip()

    return {
        "topic": topic,
        "prompt": str(messages),
        "response": content,
        "answer": choice,
        "level1": level1,
        "level2": level2,
        # The parsed reply may lack "analysis"; default to empty, not KeyError.
        "analysis": answer.get("analysis", ""),
        "accuracy": 1 if choice == gt else 0,
        "model": config["model"],
        "ground_truth": gt,
        "cost": sum(costs),
    }

def play_multiprocess(message_set, role, levels, topic, config):
    """Evaluate one topic/role message set in parallel and record the results.

    Steps:
      1. When resuming and this run's CSV already exists, only bump the
         global ``index`` and return.
      2. Start a wandb run and evaluate every message with process_message in
         a ``multiprocessing.Pool``.
      3. Write the rows to ``<dir>/<run_name>.csv`` and ``.pkl``. Log the
         table, total cost, mean accuracy and sample count to wandb.
      4. Fold the rows into the per-role summary CSV, which has one row per
         ``(level1, level2)`` pair plus a ``total`` row. Then finish the run.

    Args:
        message_set: list of ``(messages, gt, level1, level2)`` tuples.
        role: used in paths and run names. ``"ask"`` selects the ask summary
            flag; any other value selects the answer flag.
        levels: level labels used to seed a freshly created summary.
        topic: passed to process_message and embedded in the run name.
        config: reads ``few_shot``, ``model``, ``resume``, ``num_processes``
            and ``wandb.entity``, plus whatever process_message needs.

    Side effects:
        Rebinds the module globals ``index``, ``summary_ask_flag`` and
        ``summary_answer_flag``, and writes files under ``./results``.
    """
    global index, summary_ask_flag, summary_answer_flag

    few_shot_path = "_few-shot" if config["few_shot"] else ""
    dir_path = f"./results/classification2{few_shot_path}/{config['model']}/{role}"
    os.makedirs(dir_path, exist_ok=True)

    run_name = f"classification2{few_shot_path}_{index}_{role}_{topic}_{config['model']}"
    run_name = run_name.replace(" ", "")
    csv_path = f"{dir_path}/{run_name}.csv"

    if config["resume"] and os.path.exists(csv_path):
        index += 1
        return

    project = f"llm_classification2{few_shot_path}_{config['model']}"
    wandb.init(project=project, name=run_name, entity=config["wandb"]["entity"])

    with multiprocessing.Pool(processes=config["num_processes"]) as pool:
        process_func = partial(process_message, config=config, topic=topic, message_set=message_set, levels=levels)
        chunksize = max(1, len(message_set) // (4 * config["num_processes"]))
        results = list(tqdm(pool.imap(process_func, enumerate(message_set), chunksize=chunksize), total=len(message_set)))

    columns = ["model", "topic", "level1", "level2", "prompt", "response", "analysis", "answer", "ground_truth", "accuracy", "cost"]
    # One constructor call instead of concatenating row by row (quadratic).
    df = pd.DataFrame(results, columns=columns)
    total_cost = sum(result["cost"] for result in results)

    df.to_csv(csv_path, index=False)

    table = wandb.Table(data=df)
    wandb.log({f"results_{run_name}": table}, commit=True)

    with open(f"{dir_path}/{run_name}.pkl", "wb") as pkl_file:
        pickle.dump(results, pkl_file)

    wandb.log({"total_cost": total_cost}, commit=True)
    wandb.log({"accuracy": df["accuracy"].mean()}, commit=True)
    wandb.log({"num_samples": len(df)}, commit=True)

    # Strip spaces from the file name only; the directory above was created
    # with its spaces intact.
    summary_name = f"summary_classification2{few_shot_path}_{role}_{config['model']}.csv".replace(" ", "")
    summary_path = f"{dir_path}/{summary_name}"

    summary_flag = summary_ask_flag if role == "ask" else summary_answer_flag
    # A non-resumed experiment re-initialises each role's summary on its first
    # run. A resumed one with no summary file also starts fresh instead of
    # crashing in read_csv.
    if (not config["resume"] and summary_flag) or not os.path.exists(summary_path):
        summary = _new_summary(levels)
        summary.to_csv(summary_path, index=False)
        if role == "ask":
            summary_ask_flag = False
        else:
            summary_answer_flag = False
    else:
        summary = pd.read_csv(summary_path)

    summary = _accumulate_summary(summary, df)
    summary.to_csv(summary_path, index=False)
    wandb.finish()
    index += 1


def _new_summary(levels):
    """Return a zeroed summary: one row per ordered pair of distinct levels, then a "total" row."""
    unique_levels = list(dict.fromkeys(levels))  # duplicates would create duplicate rows
    rows = [[l1, l2, 0.0, 0, 0.0] for l1 in unique_levels for l2 in unique_levels if l1 != l2]
    rows.append(["total", "total", 0.0, 0, 0.0])
    return pd.DataFrame(rows, columns=["level1", "level2", "accuracy", "num_samples", "total_cost"])


def _accumulate_summary(summary, df):
    """Fold each result row of *df* into *summary* and refresh its "total" row.

    Per-pair accuracy is a running mean weighted by ``num_samples``. Pairs
    missing from *summary*, which would otherwise raise IndexError, are added
    just before the "total" row. Returns the updated frame.
    """
    rows = df[["level1", "level2", "accuracy", "cost"]].itertuples(index=False, name=None)
    for level1, level2, accuracy, cost in rows:
        mask = (summary["level1"] == level1) & (summary["level2"] == level2)
        if not mask.any():
            is_total = summary["level1"] == "total"
            new_row = pd.DataFrame([[level1, level2, 0.0, 0, 0.0]], columns=summary.columns)
            summary = pd.concat([summary[~is_total], new_row, summary[is_total]], ignore_index=True)
            mask = (summary["level1"] == level1) & (summary["level2"] == level2)

        old_num_samples = summary.loc[mask, "num_samples"].values[0]
        old_accuracy = summary.loc[mask, "accuracy"].values[0]
        summary.loc[mask, "accuracy"] = (old_accuracy * old_num_samples + accuracy) / (old_num_samples + 1)
        summary.loc[mask, "num_samples"] += 1
        summary.loc[mask, "total_cost"] += cost

    # The total depends only on the final pair rows, so compute it once.
    pairs = summary["level1"] != "total"
    total = ~pairs
    total_samples = summary.loc[pairs, "num_samples"].sum()
    summary.loc[total, "num_samples"] = total_samples
    summary.loc[total, "total_cost"] = summary.loc[pairs, "total_cost"].sum()
    if total_samples > 0:
        weighted = (summary.loc[pairs, "accuracy"] * summary.loc[pairs, "num_samples"]).sum()
        summary.loc[total, "accuracy"] = weighted / total_samples
    return summary

def main(args):
    """Run the ask/answer classification experiments for every topic.

    The CLI values (model, resume, few_shot, num_processes) are written over
    the config loaded from ``config.yaml``. For each topic, one answer prompt
    and one ask prompt are built for every ordered pair of distinct dialogues.
    Each non-empty message set is then evaluated with play_multiprocess:
    "ask" first, then "answer".
    """
    config = load_config()
    conversations = load_conversations()

    for key in ("model", "resume", "few_shot", "num_processes"):
        config[key] = getattr(args, key)

    for conversation in conversations["conversations_by_topic"]:
        topic = conversation["topic"]
        speaker = conversation["main_speaker"]
        name1 = speaker["name"]
        background1 = speaker["background"]
        dialogues = conversation["dialogues"]
        levels = [dialogue["other_speaker"]["class"] for dialogue in dialogues]

        count = len(levels)
        ordered_pairs = [(a, b) for a in range(count) for b in range(count) if a != b]

        ask_messages = []
        answer_messages = []
        for a, b in ordered_pairs:
            shared_args = (topic, name1, background1, levels[a], levels[b], dialogues[a], dialogues[b], levels, config)
            # Keep answer-then-ask order: both generators draw from the shared
            # random stream, so reordering would change the paragraph shuffles.
            answer_message, gt = generate_answer_messages(*shared_args)
            answer_messages.append((answer_message, gt, levels[a], levels[b]))
            ask_message, gt = generate_ask_messages(*shared_args)
            ask_messages.append((ask_message, gt, levels[a], levels[b]))

        if ask_messages:
            play_multiprocess(ask_messages, "ask", levels, topic, config)
        if answer_messages:
            play_multiprocess(answer_messages, "answer", levels, topic, config)

if __name__ == "__main__":
    # NOTE(review): the description text does not describe this script; kept verbatim.
    cli = argparse.ArgumentParser(description="Process state information.")
    cli.add_argument("--model", type=str, default="gpt-4o", help="Model to use")
    cli.add_argument("--num_processes", type=int, default=10, help="Number of processes to use")
    for flag, help_text in (("--resume", "Resume the experiment"), ("--few_shot", "Use few-shot prompt")):
        cli.add_argument(flag, action="store_true", help=help_text)
    main(cli.parse_args())