import Util_import
from utils.Util_print import print_elapsed_time, print_separator
from openai import OpenAI
import json, sys

# Shared OpenAI client used by query_GPT for every request. It is constructed
# with no arguments, so credentials/endpoint come from the SDK's default
# configuration (NOTE(review): presumably the OPENAI_API_KEY environment
# variable -- confirm for the deployment environment).
client = OpenAI()
# Other candidate models (commented out):
# MODEL_NAME = "gpt-4o-mini"
# MODEL_NAME = "gpt-4o"
# MODEL_NAME = "gpt-4-turbo"

# Default model for query_GPT; a dated snapshot name (presumably pinned so that
# results stay comparable across experiment runs).
MODEL_NAME = "gpt-4o-2024-08-06"
# MODEL_NAME = "gpt-4-turbo-2024-04-09"


@print_elapsed_time
def query_GPT(conversation_history, model_name=MODEL_NAME, max_tokens=2048, temperature=0):
    """Send a conversation to the OpenAI chat-completions API in JSON mode.

    Args:
        conversation_history: Non-empty list of ``{"role": ..., "content": ...}``
            message dicts, passed through unchanged as ``messages``.
        model_name: Model identifier; defaults to the module-level ``MODEL_NAME``.
        max_tokens: Upper bound on the number of tokens generated (default 2048).
        temperature: Sampling temperature (default 0).

    Returns:
        The raw completion object returned by ``client.chat.completions.create``;
        callers read the reply text from ``.choices[0].message.content``.

    Raises:
        ValueError: If ``conversation_history`` is empty.
    """
    # Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
    if not conversation_history:
        raise ValueError("conversation_history must be a non-empty list of messages")
    return client.chat.completions.create(
        model=model_name,
        messages=conversation_history,
        # Request a JSON-object reply; callers parse the content with json.loads.
        response_format={"type": "json_object"},
        max_tokens=max_tokens,
        n=1,  # Generate a single completion.
        temperature=temperature,
    )


from prompts_bargaining_restriction import (
    agent_assumption_system_prompt,
    bargaining_task_description_system_prompt,
    ultimatum_system_prompt,
    ultimatum_but_may_meet_again_system_prompt,
    role_descriptions_system_prompt,
    proposal_template,
    proposer_role_assignment_user_prompt,
    responder_role_assignment_user_prompt,
)


@print_elapsed_time
def ultimatum_game(may_meet_again=False):
    """Play one ultimatum game between two LLM agents.

    A proposer agent is asked for an offer ``x``. The resulting proposal is
    then shown to a responder agent that starts from the same system prompts,
    and the responder answers "yes" (accept) or "no" (reject).

    Args:
        may_meet_again: If True, use the system prompt telling the agents they
            may meet again in the future; otherwise use the plain ultimatum prompt.

    Returns:
        Tuple ``(x, proposer_wants_payoff, proposer_payoff, result)``:
        ``proposer_wants_payoff = (1 + 2x) / 3``, ``result`` is 1 if the offer
        was accepted and 0 otherwise, and
        ``proposer_payoff = proposer_wants_payoff * result``.

    Raises:
        ValueError: If the responder's decision is neither "yes" nor "no".
        KeyError: If a reply lacks the expected "Proposer's offer" /
            "Responder's decision" key.
        SystemExit: With status 1 if a reply cannot be decoded as JSON.
    """

    def parse_reply(text):
        # Decode an agent's JSON reply. Valid JSON is parsed as-is. Only if that
        # fails are all backslashes doubled and parsing retried (presumably to
        # tolerate stray non-JSON escapes such as LaTeX in the reply).
        # Doubling unconditionally would corrupt legitimate escapes like \" and
        # make valid replies fail to decode.
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            pass
        try:
            return json.loads(text.replace("\\", "\\\\"))
        except json.JSONDecodeError:
            print("Failed to decode JSON.")
            # Non-zero status: a bare sys.exit() would report success (0).
            sys.exit(1)

    init_conversation_history = [
        {"role": "system", "content": agent_assumption_system_prompt},
        {"role": "system", "content": bargaining_task_description_system_prompt},
        {
            "role": "system",
            "content": (
                ultimatum_but_may_meet_again_system_prompt
                if may_meet_again
                else ultimatum_system_prompt
            ),
        },
        {"role": "system", "content": role_descriptions_system_prompt},
    ]

    # --- Proposer turn ---
    proposer_conversation_history = init_conversation_history + [
        {"role": "user", "content": proposer_role_assignment_user_prompt},
    ]
    proposer_output_text = query_GPT(proposer_conversation_history).choices[0].message.content
    print(proposer_output_text)

    x = parse_reply(proposer_output_text)["Proposer's offer"]
    proposer_wants_payoff = (1 + 2 * x) / 3
    proposal = proposal_template.format(
        x=x, proposer_payoff=proposer_wants_payoff, responder_payoff=(1 - 2 * x) / 3
    )
    print("Proposal:\n", proposal)

    print_separator("-")

    # --- Responder turn: same system prompts, then the proposal, then its role ---
    responder_conversation_history = init_conversation_history + [
        {"role": "user", "content": proposal},
        {"role": "user", "content": responder_role_assignment_user_prompt},
    ]
    responder_output_text = query_GPT(responder_conversation_history).choices[0].message.content
    print(responder_output_text)

    responder_decision = parse_reply(responder_output_text)["Responder's decision"]
    if responder_decision == "yes":
        result = 1
    elif responder_decision == "no":
        result = 0
    else:
        raise ValueError(f"LLM output error. Unexpected decision: {responder_decision!r}")

    proposer_payoff = proposer_wants_payoff * result
    return (x, proposer_wants_payoff, proposer_payoff, result)


@print_elapsed_time
def execute_experiment(game, args=(), times=1):
    """Run ``game(*args)`` ``times`` times and collect the results.

    A separator and the iteration index are printed before each run, and one
    more separator is printed after the last run.

    Args:
        game: Callable to run, e.g. ``ultimatum_game``.
        args: Positional arguments passed to every call of ``game``. Defaults
            to an empty tuple; an immutable default avoids the shared
            mutable-default-argument pitfall. Lists are still accepted.
        times: Number of runs.

    Returns:
        List with one entry per run, holding whatever ``game`` returned.
    """
    results = []
    for i in range(times):
        print_separator()
        print("i: ", i)
        results.append(game(*args))
    print_separator()
    return results


# ========================================

from prompts_bargaining_restriction import (
    alternating_offer_system_prompt,
    alternating_offer_history_record_prompt,
)
import random
import numpy as np


def initialize_t_max(*, shape=1.0, scale=1.2, least_step=5, max_extra_steps=10, rng=None):
    """Sample the maximum number of timesteps for an alternating-offer game.

    A Gamma(shape, scale) sample is rounded up to an integer ``k`` and mapped
    to ``k + least_step - 1``, then clipped to
    ``[least_step, least_step + max_extra_steps]``. With the defaults the
    result lies in 5..15.

    Args:
        shape: Gamma shape parameter (k).
        scale: Gamma scale parameter (theta).
        least_step: Minimum number of timesteps.
        max_extra_steps: Largest number of timesteps allowed above ``least_step``.
        rng: Optional NumPy ``Generator`` (or ``RandomState``) to sample from,
            e.g. for reproducible runs. ``None`` uses NumPy's global random
            state, as the original implementation did.

    Returns:
        Float ndarray of shape ``(1,)`` holding an integer-valued horizon.
    """
    sampler = np.random if rng is None else rng
    samples = sampler.gamma(shape, scale, 1)
    # Rounding up means any positive sample contributes at least one step.
    t_max = np.ceil(samples) + least_step - 1
    return np.clip(t_max, least_step, least_step + max_extra_steps)


def alternating_offer_bargaining():
    """Run one alternating-offer bargaining game between two LLM agents (0 and 1).

    The first proposer is chosen uniformly at random, and the roles swap every
    timestep. At each timestep the proposer makes an offer ``x`` and the other
    agent accepts ("yes") or rejects ("no") it. Both agents get a record of the
    round appended to their own conversation history. The game stops at the
    first accepted offer or after ``initialize_t_max()`` timesteps.

    Returns:
        Tuple ``(x, proposer_wants_payoff, proposer_payoff, deal)`` for the last
        offer made: ``proposer_wants_payoff = (1 + 2x) / 3``, ``deal`` is 1 if
        that offer was accepted and 0 otherwise, and
        ``proposer_payoff = proposer_wants_payoff * deal``.

    Raises:
        ValueError: If a responder's decision is neither "yes" nor "no".
        KeyError: If a reply lacks the expected "Proposer's offer" /
            "Responder's decision" key.
        SystemExit: With status 1 if a reply cannot be decoded as JSON.
    """

    def initial_history(own_index):
        # Shared system prompts plus a reminder of which index this agent has.
        return [
            {"role": "system", "content": agent_assumption_system_prompt},
            {"role": "system", "content": bargaining_task_description_system_prompt},
            {"role": "system", "content": alternating_offer_system_prompt},
            {"role": "system", "content": role_descriptions_system_prompt},
            {
                "role": "system",
                "content": f"Remember the agent indices: You are agent {own_index}, and your opponent is agent {1 - own_index}.",
            },
        ]

    def parse_reply(text):
        # Decode an agent's JSON reply. Valid JSON is parsed as-is. Only if that
        # fails are all backslashes doubled and parsing retried (presumably to
        # tolerate stray non-JSON escapes such as LaTeX in the reply).
        # Doubling unconditionally would corrupt legitimate escapes like \" and
        # make valid replies fail to decode.
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            pass
        try:
            return json.loads(text.replace("\\", "\\\\"))
        except json.JSONDecodeError:
            print("Failed to decode JSON.")
            # Non-zero status: a bare sys.exit() would report success (0).
            sys.exit(1)

    # Each agent keeps its own, independently growing history list.
    agents_conversation_history = {str(i): initial_history(i) for i in (0, 1)}

    proposer_index = random.randint(0, 1)
    responder_index = 1 - proposer_index

    deal = False
    timestep = 0
    # initialize_t_max() returns a size-1 array; reduce it to a plain int so the
    # loop condition is an ordinary scalar comparison rather than array truthiness.
    t_max = int(np.asarray(initialize_t_max()).item())
    while not deal and timestep < t_max:

        # --- Proposer turn ---
        proposer_query_prompt = agents_conversation_history[str(proposer_index)] + [
            {"role": "user", "content": f"The current timestep is {timestep}."},
            {"role": "user", "content": proposer_role_assignment_user_prompt},
        ]
        proposer_output_text = query_GPT(proposer_query_prompt).choices[0].message.content
        print(f"Agent {proposer_index}:\n")
        print(proposer_output_text)

        x = parse_reply(proposer_output_text)["Proposer's offer"]
        proposer_wants_payoff = (1 + 2 * x) / 3
        responder_payoff = (1 - 2 * x) / 3
        proposal = proposal_template.format(
            x=x, proposer_payoff=proposer_wants_payoff, responder_payoff=responder_payoff
        )
        print("Proposal:\n", proposal)
        print_separator(".")

        # --- Responder turn ---
        responder_query_prompt = agents_conversation_history[str(responder_index)] + [
            {"role": "user", "content": f"The current timestep is {timestep}."},
            {"role": "user", "content": proposal},
            {"role": "user", "content": responder_role_assignment_user_prompt},
        ]
        responder_output_text = query_GPT(responder_query_prompt).choices[0].message.content
        print(f"Agent {responder_index}:\n")
        print(responder_output_text)
        print_separator("-")

        responder_decision = parse_reply(responder_output_text)["Responder's decision"]
        if responder_decision == "yes":
            decision_verb = "accepted"
            deal = True
        elif responder_decision == "no":
            decision_verb = "rejected"
        else:
            raise ValueError(f"LLM output error. Unexpected decision: {responder_decision!r}")

        # Record this round in both histories, each from that agent's perspective.
        for agent_index, proposer_who, responder_who in (
            (proposer_index, "you", "your opponent"),
            (responder_index, "your opponent", "you"),
        ):
            agents_conversation_history[str(agent_index)].append(
                {
                    "role": "system",
                    "content": alternating_offer_history_record_prompt.format(
                        timestep=timestep,
                        proposer_index=proposer_index,
                        proposer_who=proposer_who,
                        responder_index=responder_index,
                        responder_who=responder_who,
                        x=x,
                        proposer_payoff=proposer_wants_payoff,
                        responder_payoff=responder_payoff,
                        decision_verb=decision_verb,
                    ),
                }
            )

        # Swap roles for the next timestep.
        proposer_index, responder_index = responder_index, proposer_index
        timestep += 1

    proposer_payoff = proposer_wants_payoff * int(deal)
    return (x, proposer_wants_payoff, proposer_payoff, int(deal))


if __name__ == "__main__":
    # Choose the experiment by (un)commenting the corresponding call.
    #
    # Bargaining (1): one-shot ultimatum game.
    # results = execute_experiment(ultimatum_game, times=30)
    #
    # Bargaining (2): ultimatum game where the agents are told they might meet
    # again in the future.
    # results = execute_experiment(ultimatum_game, [True], times=30)
    #
    # Bargaining (3): alternating-offer bargaining.
    results = execute_experiment(alternating_offer_bargaining, times=30)

    # Each game returns a 4-tuple; print one field of every run at a time,
    # in tuple order, each under its own heading.
    field_titles = (
        "Results: x",
        "Results: proposer wants payoff",
        "Results: proposer payoff",
        "Results: deal",
    )
    for field_index, title in enumerate(field_titles):
        print_separator()
        print(title)
        for run in results:
            print(run[field_index])
