import Util_import
from utils.Util_print import print_elapsed_time, print_separator
from openai import OpenAI
import json, sys

# Module-level OpenAI client shared by every query_GPT call in this file.
# Constructed with no arguments, so configuration comes from the OpenAI SDK's
# defaults — NOTE(review): presumably OPENAI_API_KEY from the environment; confirm.
client = OpenAI()
# Alternative model choices (commented out):
# MODEL_NAME = "gpt-4o-mini"
# MODEL_NAME = "gpt-4o"
# MODEL_NAME = "gpt-4-turbo"

# Default model used by query_GPT (a dated model snapshot).
MODEL_NAME = "gpt-4o-2024-08-06"
# MODEL_NAME = "gpt-4-turbo-2024-04-09"


# Prior over the sender's private type, used by calculate_payoffs:
# 1/3 excellent (student) / good (product), 2/3 weak / bad.
prior_excellent = 1 / 3
prior_weak = 2 / 3


@print_elapsed_time
def query_GPT(conversation_history, model_name=MODEL_NAME):
    """Send a chat conversation to the OpenAI Chat Completions endpoint.

    Decoding is deterministic (temperature 0), a single choice is requested,
    and the reply is capped at 2048 tokens.  JSON mode (``response_format``)
    is not requested, so the reply text is not guaranteed to be strict JSON.

    Args:
        conversation_history: non-empty list of chat message dicts
            (``{"role": ..., "content": ...}``).
        model_name: model identifier; defaults to the module's MODEL_NAME.

    Returns:
        The raw completion object; callers read ``.choices[0].message.content``.
    """
    assert conversation_history
    request_kwargs = dict(
        model=model_name,
        messages=conversation_history,
        max_tokens=2048,
        n=1,
        temperature=0,
    )
    return client.chat.completions.create(**request_kwargs)


from prompts_persuasion_revelation import (
    agent_assumption_system_prompt,
    persuasion_student_task_description_system_prompt,
    commitment_student_system_prompt,
    ultimatum_but_may_meet_again_system_prompt,
    revelation_student_system_prompt,
    sender_proposal_user_prompt,
    proposal_template,
    receiver_check_student_user_prompt,
    receiver_check_product_user_prompt
)

from prompts_persuasion_revelation import (
    persuasion_product_task_description_system_prompt,
    commitment_product_system_prompt,
    revelation_product_system_prompt,
)


def extract_data(text):
    """Parse the JSON object contained in an LLM reply.

    The reply may be wrapped in a Markdown code fence (```json ... ```); the
    fence is stripped before parsing.  Every backslash is doubled before
    decoding, so LaTeX-style sequences such as ``\\frac`` survive as literal
    text instead of being rejected as invalid JSON escapes.

    Args:
        text: raw message content returned by the model (may be None).

    Returns:
        The decoded JSON value (normally a dict).

    Exits:
        Prints "Failed to decode JSON." and terminates the process with exit
        status 1 when the text cannot be decoded.
    """
    payload = (text or "").strip()
    # Without JSON mode the model often fences its answer; drop the opening
    # fence line (which may carry a language tag) and the closing fence.
    if payload.startswith("```") and payload.endswith("```") and "\n" in payload:
        payload = payload[payload.index("\n") + 1 : -3].strip()
    try:
        json_data = json.loads(payload.replace("\\", "\\\\"))
    except json.JSONDecodeError:
        print("Failed to decode JSON.")
        # Non-zero status so scripts/CI see the run as failed.
        sys.exit(1)
    return json_data


def calculate_payoffs(
    prob_score1_weak, prob_score1_excellent, prob_hire_score0, prob_hire_score1
):
    """Compute the expected payoffs of a signaling scheme plus action rule.

    For each type, the joint probability of being hired/bought is the prior
    of that type times P(hire | type), where P(hire | type) averages the
    receiver's action over the two possible scores.

    Args:
        prob_score1_weak: P(score 1 | weak/bad type).
        prob_score1_excellent: P(score 1 | excellent/good type).
        prob_hire_score0: P(hire/buy | score 0).
        prob_hire_score1: P(hire/buy | score 1).

    Returns:
        ``(sender_payoff, receiver_payoff)``.  The sender's payoff is the total
        probability of a hire; the receiver's is P(hire, excellent) minus
        P(hire, weak).  Priors come from the module-level prior_excellent and
        prior_weak.
    """

    def _joint_hire(prior, prob_score1):
        # prior * P(hire | type), split over the score-1 and score-0 outcomes.
        return (
            prior * prob_score1 * prob_hire_score1
            + prior * (1 - prob_score1) * prob_hire_score0
        )

    hire_excellent = _joint_hire(prior_excellent, prob_score1_excellent)
    hire_weak = _joint_hire(prior_weak, prob_score1_weak)
    return hire_excellent + hire_weak, hire_excellent - hire_weak


@print_elapsed_time
def sender_propose(init_conversation_history):
    """Ask the sender LLM to commit to a signaling scheme and render the proposal.

    The sender is shown the shared history plus the proposal instruction.  It
    must answer with JSON containing a "Signaling Scheme" value ``eta``:
    the probability that a weak/bad item is given score 1.

    Args:
        init_conversation_history: list of chat messages (not mutated).

    Returns:
        ``(proposal_text, eta)``.  ``proposal_text`` is ``proposal_template``
        filled with eta and the payoffs ``(1 + 2*eta)/3`` and ``(1 - 2*eta)/3``.
        Those equal ``calculate_payoffs(eta, 1, 0, 1)`` under the 1/3–2/3
        priors, i.e. the payoffs if the receiver acts exactly on score 1.
        ``eta`` is returned as a float.
    """
    sender_query_input = init_conversation_history + [
        {"role": "user", "content": sender_proposal_user_prompt}
    ]

    sender_output_raw = query_GPT(sender_query_input)
    sender_output_text = sender_output_raw.choices[0].message.content
    print(sender_output_text)

    sender_output = extract_data(sender_output_text)
    # The model may emit the number as a JSON string (e.g. "0.5"); without
    # coercion, 2 * eta would be string repetition and 1 + ... a TypeError.
    eta = float(sender_output["Signaling Scheme"])

    proposal_text = proposal_template.format(
        sender_payoff=(1 + 2 * eta) / 3, receiver_payoff=(1 - 2 * eta) / 3, eta=eta
    )
    return proposal_text, eta


def receiver_reacts(task_type, init_conversation_history, proposal_text):
    """Ask the receiver LLM how it would act under the sender's proposal.

    Args:
        task_type: ``"student"`` selects the hiring prompt and JSON keys; any
            other value selects the product (buying) prompt and keys.
        init_conversation_history: list of chat messages (not mutated; a new
            list is built for the query).
        proposal_text: the sender's rendered proposal.

    Returns:
        ``(prob_hire_score0, prob_hire_score1)`` as floats: the receiver's
        probability of hiring/buying after seeing score 0 and score 1.
    """
    if task_type == "student":
        check_prompt = receiver_check_student_user_prompt
        score0_key = "Probability of hiring upon scoring 0"
        score1_key = "Probability of hiring upon scoring 1"
    else:
        check_prompt = receiver_check_product_user_prompt
        score0_key = "Probability of buying upon scoring 0"
        score1_key = "Probability of buying upon scoring 1"

    responder_query_input = init_conversation_history + [
        {"role": "user", "content": proposal_text},
        {"role": "user", "content": check_prompt},
    ]

    receiver_output_raw = query_GPT(responder_query_input)
    receiver_output_text = receiver_output_raw.choices[0].message.content
    print(receiver_output_text)

    receiver_output = extract_data(receiver_output_text)
    # The model may emit probabilities as JSON strings; coerce so the payoff
    # arithmetic in calculate_payoffs stays numeric.
    prob_hire_score0 = float(receiver_output[score0_key])
    prob_hire_score1 = float(receiver_output[score1_key])
    return prob_hire_score0, prob_hire_score1


@print_elapsed_time
def persuasion_ultimatum(task_type):
    """Play one ultimatum round of the persuasion game between two LLM agents.

    The sender commits to a signaling scheme (``sender_propose``).  The
    receiver answers with an action rule (``receiver_reacts``).  The resulting
    expected payoffs are then printed and returned.

    Args:
        task_type: ``"student"`` (hiring framing) or ``"product"`` (buying framing).

    Returns:
        Tuple ``(result, sender_payoff, receiver_payoff, social_welfare)``.
        ``result`` holds the signaling scheme and action rule under
        framing-specific key names.
    """
    assert task_type in ["student", "product"]
    is_student = task_type == "student"

    if is_student:
        framing_prompts = (
            persuasion_student_task_description_system_prompt,
            commitment_student_system_prompt,
            revelation_student_system_prompt,
        )
    else:
        framing_prompts = (
            persuasion_product_task_description_system_prompt,
            commitment_product_system_prompt,
            revelation_product_system_prompt,
        )
    task_prompt, commitment_prompt, revelation_prompt = framing_prompts

    # System context shared by both agents, in a fixed order.
    system_contents = (
        agent_assumption_system_prompt,
        task_prompt,
        commitment_prompt,
        ultimatum_but_may_meet_again_system_prompt,
        revelation_prompt,
    )
    init_conversation_history = [
        {"role": "system", "content": content} for content in system_contents
    ]

    proposal_text, eta = sender_propose(init_conversation_history)
    # Excellent/good items always score 1; weak/bad items score 1 w.p. eta.
    prob_score1_weak, prob_score1_excellent = eta, 1

    print_separator("=")
    print("Proposal:\n", proposal_text)
    print_separator("=")

    prob_hire_score0, prob_hire_score1 = receiver_reacts(
        task_type, init_conversation_history, proposal_text
    )
    print_separator("=")

    if is_student:
        scheme_names = ("prob_score1_weak", "prob_score1_excellent")
        action_names = ("prob_hire_score0", "prob_hire_score1")
    else:
        scheme_names = ("prob_score1_bad", "prob_score1_good")
        action_names = ("prob_buy_score0", "prob_buy_score1")
    result = {
        "signaling scheme": dict(
            zip(scheme_names, (prob_score1_weak, prob_score1_excellent))
        ),
        "action rule": dict(zip(action_names, (prob_hire_score0, prob_hire_score1))),
    }
    print("Final decisions:\n", json.dumps(result, indent=4))

    sender_payoff, receiver_payoff = calculate_payoffs(
        prob_score1_weak, prob_score1_excellent, prob_hire_score0, prob_hire_score1
    )
    social_welfare = sender_payoff + receiver_payoff
    print(f"Sender's expected payoff: {sender_payoff:.3f}")
    print(f"Receiver's expected payoff: {receiver_payoff:.3f}")
    print(f"Social welfare: {social_welfare:.3f}")

    return (result, sender_payoff, receiver_payoff, social_welfare)


@print_elapsed_time
def execute_experiment(game, args=(), times=1):
    """Run ``game(*args)`` repeatedly and collect its return values.

    Args:
        game: callable implementing one run of an experiment.
        args: positional arguments forwarded to ``game`` on every run.  The
            default is an immutable empty tuple, so no state can leak between
            calls through a shared default object.
        times: number of repetitions.

    Returns:
        List of ``game``'s return values, in run order.
    """
    results = []
    for i in range(times):
        print_separator()
        print("i: ", i)
        results.append(game(*args))
    print_separator()
    return results


# ========================================


from prompts_persuasion_revelation import (
    alternating_offer_system_prompt,
    alternating_offer_history_record_prompt,
)
import random
import numpy as np


def initialize_t_max(
    shape=1.0, scale=1.2, least_step=5, max_extra_steps=10, rng=None
):
    """Draw the random negotiation horizon ``t_max`` for the alternating-offer game.

    A Gamma(shape, scale) sample is rounded up to an integer ``k``, mapped to
    ``k + least_step - 1``, and clipped to
    ``[least_step, least_step + max_extra_steps]``.  With the defaults the
    horizon lies between 5 and 15 rounds.

    Args:
        shape: Gamma shape parameter (k).
        scale: Gamma scale parameter (theta).
        least_step: minimum horizon.
        max_extra_steps: largest number of rounds added on top of ``least_step``.
        rng: optional ``numpy.random.Generator`` / ``RandomState`` for
            reproducible draws.  Defaults to NumPy's global legacy random state.

    Returns:
        A float NumPy array of shape (1,) holding the integer-valued horizon.
    """
    size = 1
    sampler = np.random if rng is None else rng
    samples = sampler.gamma(shape, scale, size)
    rounded_samples = np.ceil(samples)
    return np.clip(
        rounded_samples + least_step - 1, least_step, least_step + max_extra_steps
    )


def alternating_offer_persuasion(task_type):
    """Run one alternating-offer persuasion negotiation between two LLM agents.

    At every timestep the sender proposes a signaling scheme ``eta``: the
    probability that a weak/bad item gets score 1 (excellent/good items
    always score 1).  The receiver replies with its action rule and an
    accept/reject decision.  A record of each round is appended to both
    agents' conversation histories.  The loop stops at the first acceptance
    or after ``t_max`` rounds, with ``t_max`` drawn by ``initialize_t_max``.

    Args:
        task_type: ``"student"`` (hiring framing) or ``"product"`` (buying framing).

    Returns:
        Tuple ``(result, sender_payoff, receiver_payoff, social_welfare)`` for
        the last round played.  NOTE: if no offer was accepted before
        ``t_max``, this describes the final *rejected* proposal (a notice is
        printed in that case).

    Raises:
        ValueError: if the receiver's accept/reject answer is not recognised.
    """
    assert task_type in ["student", "product"]

    # Select the framing-specific prompts and JSON keys once, up front.
    if task_type == "student":
        task_description_prompt = persuasion_student_task_description_system_prompt
        commitment_prompt = commitment_student_system_prompt
        revelation_prompt = revelation_student_system_prompt
        receiver_check_prompt = receiver_check_student_user_prompt
        score0_key = "Probability of hiring upon scoring 0"
        score1_key = "Probability of hiring upon scoring 1"
    else:
        task_description_prompt = persuasion_product_task_description_system_prompt
        commitment_prompt = commitment_product_system_prompt
        revelation_prompt = revelation_product_system_prompt
        receiver_check_prompt = receiver_check_product_user_prompt
        score0_key = "Probability of buying upon scoring 0"
        score1_key = "Probability of buying upon scoring 1"

    init_conversation_history = [
        {"role": "system", "content": agent_assumption_system_prompt},
        {"role": "system", "content": task_description_prompt},
        {"role": "system", "content": commitment_prompt},
        {"role": "system", "content": alternating_offer_system_prompt},
        {"role": "system", "content": revelation_prompt},
    ]

    # Each agent owns a separate list so round records are appended per agent.
    agents_conversation_history = {
        "sender": init_conversation_history.copy(),
        "receiver": init_conversation_history.copy(),
    }

    deal = False
    timestep = 0
    # initialize_t_max returns a size-1 NumPy array; reduce it to a plain int
    # so the loop condition is an ordinary bool comparison.
    t_max = int(np.ravel(initialize_t_max())[0])
    while not deal and timestep < t_max:
        timestep_prompt = f"The current timestep is {timestep}."

        # --- Sender proposes a signaling scheme ---
        proposer_query_prompt = agents_conversation_history["sender"] + [
            {"role": "user", "content": timestep_prompt},
            {"role": "user", "content": sender_proposal_user_prompt},
        ]
        sender_output_raw = query_GPT(proposer_query_prompt)
        sender_output_text = sender_output_raw.choices[0].message.content
        print(sender_output_text)

        sender_output = extract_data(sender_output_text)
        # The model may return the number as a JSON string; coerce to float.
        eta = float(sender_output["Signaling Scheme"])

        prob_score1_weak = eta
        prob_score1_excellent = 1

        proposal_text = proposal_template.format(
            sender_payoff=(1 + 2 * eta) / 3, receiver_payoff=(1 - 2 * eta) / 3, eta=eta
        )
        print_separator(".")

        # --- Receiver answers with an action rule and accept/reject ---
        responder_query_prompt = agents_conversation_history["receiver"] + [
            {"role": "user", "content": timestep_prompt},
            {"role": "user", "content": proposal_text},
            {"role": "user", "content": receiver_check_prompt},
        ]
        receiver_output_raw = query_GPT(responder_query_prompt)
        receiver_output_text = receiver_output_raw.choices[0].message.content
        print(receiver_output_text)
        print_separator("-")

        receiver_output = extract_data(receiver_output_text)
        prob_hire_score0 = float(receiver_output[score0_key])
        prob_hire_score1 = float(receiver_output[score1_key])

        sender_payoff, receiver_payoff = calculate_payoffs(
            prob_score1_weak,
            prob_score1_excellent,
            prob_hire_score0,
            prob_hire_score1,
        )

        responder_decision = receiver_output[
            "Satisfied with the current committed signaling scheme and its corresponding reward outcome"
        ]
        # Normalise case/whitespace so e.g. "Yes" or " no " does not abort the
        # run; JSON booleans are accepted as well.
        normalized_decision = str(responder_decision).strip().lower()
        if normalized_decision in ("yes", "true"):
            decision_verb = "accepted"
            deal = True
        elif normalized_decision in ("no", "false"):
            decision_verb = "rejected"
        else:
            raise ValueError(
                f"LLM output error: unexpected receiver decision {responder_decision!r}."
            )

        record = {
            "role": "system",
            "content": alternating_offer_history_record_prompt.format(
                timestep=timestep,
                eta=eta,
                sender_payoff=sender_payoff,
                receiver_payoff=receiver_payoff,
                receiver_verb=decision_verb,
            ),
        }

        agents_conversation_history["sender"].append(record.copy())
        agents_conversation_history["receiver"].append(record.copy())

        timestep += 1

    # With its default settings initialize_t_max() never returns less than 5
    # (np.clip lower bound), so the loop ran at least once and every
    # per-round variable used below is defined.
    if task_type == "student":
        result = {
            "signaling scheme": {
                "prob_score1_weak": prob_score1_weak,
                "prob_score1_excellent": prob_score1_excellent,
            },
            "action rule": {
                "prob_hire_score0": prob_hire_score0,
                "prob_hire_score1": prob_hire_score1,
            },
        }
    else:
        result = {
            "signaling scheme": {
                "prob_score1_bad": prob_score1_weak,
                "prob_score1_good": prob_score1_excellent,
            },
            "action rule": {
                "prob_buy_score0": prob_hire_score0,
                "prob_buy_score1": prob_hire_score1,
            },
        }
    if not deal:
        print(
            f"No offer accepted within t_max={t_max} timesteps; "
            "reporting the last (rejected) proposal."
        )
    print("Final decisions:\n", json.dumps(result, indent=4))

    # sender_payoff / receiver_payoff already hold the last round's values.
    print(f"Sender's expected payoff: {sender_payoff:.3f}")
    print(f"Receiver's expected payoff: {receiver_payoff:.3f}")
    social_welfare = sender_payoff + receiver_payoff
    print(f"Social welfare: {social_welfare:.3f}")

    return (result, sender_payoff, receiver_payoff, social_welfare)


if __name__ == "__main__":
    # Framing of the game: "student" = (a), "product" = (b).
    # task_type = "student"
    task_type = "product"
    assert task_type in ["student", "product"]

    # persuasion (2), ultimatum variant:
    # results = execute_experiment(persuasion_ultimatum, [task_type], times=30) # (1b)

    # persuasion (3), alternating-offer variant:
    results = execute_experiment(alternating_offer_persuasion, [task_type], times=30) # (2b)

    # Each run returns (result, sender_payoff, receiver_payoff, social_welfare);
    # print one column per quantity across all runs.
    eta_key = "prob_score1_weak" if task_type == "student" else "prob_score1_bad"
    report_columns = (
        ("Results: eta", lambda run: run[0]["signaling scheme"][eta_key]),
        ("Results: sender_payoff", lambda run: run[1]),
        ("Results: receiver_payoff", lambda run: run[2]),
        ("Results: social_welfare", lambda run: run[3]),
    )
    for title, pick in report_columns:
        print_separator()
        print(title)
        for run in results:
            print(pick(run))