import argparse
import random
import os
import json
from openai import OpenAI
from together import Together
import anthropic
from dotenv import load_dotenv
import yaml

# Populate os.environ from a local .env file (if one is present) before
# reading any credentials below.
load_dotenv()

# Per-provider API keys; os.getenv yields None for any variable that is unset.
OpenAI_API_KEY, Anthropic_API_KEY, Gemini_API_KEY, Together_API_KEY = (
    os.getenv(env_var)
    for env_var in (
        "OpenAI_API_KEY",
        "Anthropic_API_KEY",
        "Gemini_API_KEY",
        "Together_API_KEY",
    )
)

def init_client(args):
    """Return the API client matching ``args.model``.

    The provider is chosen by substring match on the model name, tested in
    this order: "gpt" -> OpenAI, "gemini" -> Gemini's OpenAI-compatible
    endpoint (via the OpenAI client), "claude" -> Anthropic. Any other
    model name falls through to Together.
    """
    model_name = args.model
    if "gpt" in model_name:
        return OpenAI(api_key=OpenAI_API_KEY)
    if "gemini" in model_name:
        return OpenAI(
            api_key=Gemini_API_KEY,
            base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
        )
    if "claude" in model_name:
        return anthropic.Anthropic(api_key=Anthropic_API_KEY)
    return Together(api_key=Together_API_KEY)
    
def chat_completion(args, client, messages, max_tokens=500, temperature=1):
    """Send ``messages`` to ``args.model`` and return the reply text.

    Args:
        args: Parsed CLI namespace; reads ``model``, ``alter_temperature``
            and ``temperature``.
        client: Client built by ``init_client`` for the same ``args``.
        messages: Chat history as a list of ``{"role", "content"}`` dicts.
        max_tokens: Completion-length cap passed to the Anthropic call, which
            requires one. It is not forwarded to the OpenAI-style branch,
            which keeps that provider's default length limit.
        temperature: Sampling temperature; replaced by ``args.temperature``
            when ``args.alter_temperature`` is set.

    Returns:
        The text of the first content block (Anthropic) or of the first
        choice's message (OpenAI-style clients).
    """
    if args.alter_temperature:
        temperature = args.temperature

    if "claude" in args.model:
        # Anthropic Claude API. Previously this hard-coded 500 and silently
        # ignored the caller's max_tokens argument.
        response = client.messages.create(
            max_tokens=max_tokens,
            messages=messages,
            model=args.model,
            temperature=temperature,
        )
        return response.content[0].text

    # OpenAI, Gemini (OpenAI-compatible endpoint) and Together all expose the
    # OpenAI-style chat.completions interface.
    response = client.chat.completions.create(
        model=args.model,
        messages=messages,
        temperature=temperature,
    )
    return response.choices[0].message.content
    
def load_yaml_file(file_path):
    """Parse the YAML document at ``file_path`` and return the Python object.

    ``yaml.safe_load`` is used, so only standard YAML types (mappings,
    sequences, scalars) are constructed. The file is decoded as UTF-8
    explicitly instead of with the platform's locale encoding, so prompt
    text containing non-ASCII characters loads the same on every OS.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)

# Prompt templates, loaded once at import time. get_opinions reads the keys
# "job_opinion_prompt_initial" and "job_opinion_prompt_subsequent" from it.
# The path is relative, so it is resolved against the current working directory.
prompt_path = "prompts.yaml"
prompts = load_yaml_file(prompt_path)

def get_opinions(args, client):
    """Ask the model about each occupation within one growing conversation.

    The fixed list of occupations is shuffled and then presented one at a
    time. The first job uses ``prompts["job_opinion_prompt_initial"]`` and
    every later job uses ``prompts["job_opinion_prompt_subsequent"]``. Each
    prompt and reply is appended to the shared ``messages`` history, so later
    questions are asked with the earlier answers in context. Prompts and
    replies are also printed to stdout.

    Returns:
        A ``(job_prob_judgments, messages)`` tuple: a list of
        ``{"job": ..., "model_response": ...}`` dicts in the order asked, and
        the full chat transcript.
    """
    job_list = ["Lawyers", "Financial Advisors", "Managers", "Bankers", "Politicians", "Doctors", "Psychiatrists", "Veterinarians", "Teachers", "Professors", "Childcare Aides", "Receptionists", "Rehabilitation Counselors", "Waiters", "Homemakers", "Janitors", "Custodians", "Garbage Collectors", "Dishwashers", "Cashiers"]
    # A single shuffle already produces a random permutation; the previous
    # seven consecutive shuffles added no extra randomness.
    random.shuffle(job_list)

    job_prob_judgments = []
    messages = []

    for job_num, job in enumerate(job_list):
        prompt_key = (
            "job_opinion_prompt_initial" if job_num == 0 else "job_opinion_prompt_subsequent"
        )
        # Dropping the final character singularizes the title ("Lawyers" ->
        # "lawyer"); every entry in job_list is a regular plural ending in "s".
        job_prompt = prompts[prompt_key].format(job=job[:-1].lower())

        messages.append({"role": "user", "content": job_prompt})
        print(job_prompt)
        model_response = chat_completion(args, client, messages)
        print(model_response)
        messages.append({"role": "assistant", "content": model_response})

        job_prob_judgments.append({"job": job, "model_response": model_response})

    return job_prob_judgments, messages
    
def parse_args():
    """Parse the command-line options for an experiment sweep.

    Options:
        --model: model name (default "gpt-4o"); also selects the API client
            and names the output folder.
        --alter_temperature: flag; when set, --temperature overrides the
            default sampling temperature.
        --temperature: float (default 1.0); only used with --alter_temperature.
        --num_runs: int (default 30); total number of run files to produce.
    """
    cli = argparse.ArgumentParser()
    option_specs = (
        ("--model", {"type": str, "default": "gpt-4o"}),
        ("--alter_temperature", {"action": "store_true"}),
        ("--temperature", {"type": float, "default": 1.0}),
        ("--num_runs", {"type": int, "default": 30}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

def _existing_run_ids(folder_path_name):
    """Return the integer ids of files named ``run_<digits>.json`` in the folder.

    Files that match the ``run_*.json`` pattern but have a non-numeric middle
    part (e.g. ``run_old.json``) are ignored rather than crashing ``int()``.
    """
    run_ids = []
    for filename in os.listdir(folder_path_name):
        if not (filename.startswith("run_") and filename.endswith(".json")):
            continue
        stem = filename[len("run_"):-len(".json")]
        if stem.isdecimal():
            run_ids.append(int(stem))
    return run_ids


def main(args):
    """Run trials for ``args.model`` until run ids up to ``args.num_runs - 1`` are written.

    Each trial's output is saved as
    ``job_success_probability_guesses/<model>/run_<id>.json``. Numbering
    resumes after the highest existing run id, so an interrupted sweep can be
    restarted without repeating finished runs.
    """
    client = init_client(args)

    folder_path_name = os.path.join("job_success_probability_guesses", args.model)
    os.makedirs(folder_path_name, exist_ok=True)

    existing_runs = _existing_run_ids(folder_path_name)
    # Resume after the highest id found; gaps below it are not refilled.
    completed_runs = max(existing_runs) + 1 if existing_runs else 0

    for run_id in range(completed_runs, args.num_runs):
        print(f"Starting run {run_id}")
        job_prob_judgments, messages = get_opinions(args, client)

        run_data = {
            "job_prob_judgments": job_prob_judgments,
            "messages": messages,
        }
        run_path = os.path.join(folder_path_name, f"run_{run_id}.json")
        with open(run_path, "w", encoding="utf-8") as f:
            json.dump(run_data, f, indent=2)

if __name__ == "__main__":
    # Script entry point: read the command-line options, then run all pending trials.
    args = parse_args()
    main(args=args)