import os
import sys
import ast
import json 
import yaml
import argparse
from tqdm import tqdm
from pathlib import Path
from random import random
from dataclasses import dataclass
from typing import Literal, Optional, Union, Tuple
import random
from rich import print
import time
from openai import OpenAI
from utils import make_chat_call


def inference(args, client, template=None, queries=None, groupa1=None, groupa2=None):
    """Query the model once per item in ``queries`` and write results as JSONL.

    One JSON object per line is written to ``<args.output_dir>/output.jsonl``.
    Its keys are ``query``, ``prompt`` (the template formatted with the query),
    ``generated_text`` and ``groudtruth``. ``groudtruth`` is "A1" if the query
    is in ``groupa1``, otherwise "A2".

    Args:
        args: Parsed CLI namespace; ``output_dir``, ``model_name`` and
            ``max_tokens`` are read from it.
        client: API client forwarded unchanged to ``make_chat_call``.
        template: Prompt string formatted with ``template.format(query=...)``.
        queries: Items sent to the model, in iteration order.
        groupa1: Collection of queries whose ground-truth label is "A1".
        groupa2: Accepted for interface compatibility but not read; every query
            not found in ``groupa1`` is labelled "A2".

    Raises:
        ValueError: if ``template``, ``queries`` or ``groupa1`` is None.

    If the output file already exists, the user is asked on stdin whether to
    overwrite it. Any answer other than "y"/"Y" exits the process via
    ``sys.exit()`` without writing.
    """
    # Fail before touching the filesystem or prompting the user.
    if template is None or queries is None or groupa1 is None:
        raise ValueError("inference() requires `template`, `queries` and `groupa1`.")

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = output_dir / "output.jsonl"
    if output_file.exists():
        check = input(f"File {output_file} already exists. Overwrite? (y/n): ")
        if check.strip().lower() != "y":
            sys.exit()

    with open(output_file, "w", encoding="utf-8") as f:
        for query in tqdm(queries, total=len(queries)):
            prompt = template.format(query=query)
            # Key spelling "groudtruth" is preserved so existing consumers of
            # output.jsonl keep working.
            groudtruth = "A1" if query in groupa1 else "A2"
            outputs = make_chat_call(client, args.model_name, prompt, args.max_tokens)

            # Reset per query: with an empty ``choices`` list the name would
            # otherwise be unbound (first query) or hold the previous answer.
            generated_text = ""
            for out in outputs.choices:
                # ``message.content`` is Optional in the OpenAI SDK (e.g. refusals
                # or tool calls), so guard before stripping.
                generated_text = (out.message.content or "").strip()
                print(f"generated_text: {generated_text}")
            # With several choices, every one is printed but only the last
            # one is recorded.

            output = dict(
                query=query,
                prompt=prompt,
                generated_text=generated_text,
                groudtruth=groudtruth,
            )
            f.write(json.dumps(output) + "\n")
    
    
if __name__ == "__main__":
    # CLI entry point: load prompt/query configs, build an OpenAI client and
    # run ``inference`` over the shuffled A1 + A2 queries.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="gpt-4o-2024-05-13")
    parser.add_argument("--output_dir", type=str, default="GPT_output")
    parser.add_argument("--task", type=str, default="agl")
    parser.add_argument("--max_tokens", type=int, default=1000)
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Seed for shuffling the queries (default: unseeded, order varies per run).",
    )
    args = parser.parse_args()

    # Fail fast with a clear message instead of silently doing nothing (unknown
    # task) or crashing later with NameError on an unassigned ``client``.
    if args.task != "agl":
        parser.error(f"unsupported --task {args.task!r}; only 'agl' is implemented")
    if args.model_name != "gpt-4o-2024-05-13":
        parser.error(
            f"unsupported --model_name {args.model_name!r}; "
            "no client is configured for it"
        )

    with open("configs/prompts.yaml", "r", encoding="utf-8") as file:
        template = yaml.safe_load(file)

    prompts_templates = (
        "## Instruction: " + template["instruction"] + "\n\n"
        + template["study_stimuli"] + "\n\n"
        + template["question"]
    )

    with open("configs/agl.yaml", "r", encoding="utf-8") as file:
        test = yaml.safe_load(file)

    group_a1 = test["queries"]["A1"]
    group_a2 = test["queries"]["A2"]
    queries = group_a1 + group_a2
    # Shuffle so the two groups are interleaved. A dedicated Random instance
    # keeps the global RNG untouched; ``--seed`` makes the order reproducible
    # (``Random(None)`` seeds from the OS, matching an unseeded shuffle).
    random.Random(args.seed).shuffle(queries)

    api_key_path = 'redacted'
    with open(os.path.expanduser(api_key_path), "r", encoding="utf-8") as f:
        api_key = f.read().strip()
    client = OpenAI(api_key=api_key)

    inference(args, client, prompts_templates, queries, group_a1, group_a2)
   
    
    
    
    

