import os
import sys
import ast
import json 
import yaml
import argparse
from tqdm import tqdm
from pathlib import Path
from random import random
from dataclasses import dataclass
from typing import Literal, Optional, Union, Tuple
import random
from rich import print
import time
# from src.openai_utils import OpenAI
from openai import OpenAI
# from utils import make_chat_call

def score_response(output, is_genuine):
    """Score one model reply against the ground-truth label of its example.

    Args:
        output: Text returned by the model, or ``None`` when the API response
            carried no message content.
        is_genuine: ``True`` when the example came from the datapoint's
            ``test`` list (a "yes" reply is correct), ``False`` when it came
            from ``fakes`` (a "no" reply is correct).

    Returns:
        1 when the reply matches the expected answer, otherwise 0. A reply
        containing neither "yes" nor "no" always scores 0.
    """
    # The API may return no content; count that as an unanswered reply
    # instead of crashing on ``None.lower()``.
    reply = (output or "").lower()
    # NOTE(review): this is plain substring matching with "yes" taking
    # precedence, so words such as "not" or "know" count as "no". Kept as-is
    # so scores stay comparable with earlier runs.
    if 'yes' in reply:
        return 1 if is_genuine else 0
    if 'no' in reply:
        return 0 if is_genuine else 1
    return 0


def safe_mean(values):
    """Return the arithmetic mean of *values*, or 0.0 when it is empty."""
    return sum(values) / len(values) if values else 0.0


if __name__ == "__main__":
    # Zero-shot evaluation: for each datapoint, ask the model about every
    # genuine ("test") and fake example, score the yes/no replies, and write
    # per-datapoint results plus the overall mean accuracy.
    parser = argparse.ArgumentParser()
    # NOTE(review): only --output_dir is read below; --model_name,
    # --prompt_name and --base_model_name are accepted but currently unused
    # (the template key and the model are hard-coded).
    parser.add_argument("--model_name", type=str, default="redacted")
    parser.add_argument("--output_dir", type=str, default="redacted")
    parser.add_argument("--prompt_name", type=str, default="Zeroshot_Prompt_Memorize")
    parser.add_argument("--base_model_name", type=str, default="gpt4o")
    args = parser.parse_args()

    with open('redacted', encoding="utf-8") as f:
        dataset = json.load(f)

    with open("redacted", "r", encoding="utf-8") as file:
        template = yaml.safe_load(file)

    prompt_template = template["Zeroshot_Prompt"]
    print(f"prompt_template: {prompt_template}")

    output_dir = Path(args.output_dir, "Zeroshot_Prompt_o1")
    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = output_dir / "outputs.jsonl"
    eval_file = output_dir / "eval.jsonl"

    # Only the first 10 datapoints are evaluated.
    dataset = dataset[0:10]

    # Build the client once and hand the key to the constructor: a bare
    # OpenAI() raises when OPENAI_API_KEY is unset, before a later
    # ``client.api_key = ...`` assignment could run.
    # NOTE(review): hard-coded credential; prefer the OPENAI_API_KEY
    # environment variable.
    client = OpenAI(api_key="redacted")

    acc = []
    with open(output_file, "w", encoding="utf-8") as f:
        for datapoint in tqdm(dataset, total=len(dataset)):
            test = datapoint['test']
            fake = datapoint['fakes']
            # Attach the ground-truth label before shuffling so scoring does
            # not depend on a list-membership lookup afterwards (which would
            # misclassify an example present in both lists).
            labeled_examples = [(example, True) for example in test]
            labeled_examples += [(example, False) for example in fake]
            random.shuffle(labeled_examples)
            test_examples = [example for example, _ in labeled_examples]
            train_examples = '\n'.join(datapoint['train'])

            completions = []
            accuracy = []
            for test_prompt, is_genuine in tqdm(
                labeled_examples, total=len(labeled_examples), desc="Test examples"
            ):
                prompt = prompt_template.format(
                    train_examples=train_examples, test_example=test_prompt
                )
                response = client.chat.completions.create(
                    model="o1-preview",
                    messages=[
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "text",
                                    "text": prompt
                                },
                            ],
                        }
                    ],
                )

                output = response.choices[0].message.content
                print(f"output: {output}")
                completions.append(output)
                accuracy.append(score_response(output, is_genuine))

            # safe_mean avoids ZeroDivisionError for a datapoint without
            # any test/fake examples.
            final_accuracy = safe_mean(accuracy)
            acc.append(final_accuracy)
            print(f"final_accuracy: {final_accuracy}")
            record = dict(
                test_examples=test_examples,
                train_examples=train_examples,
                completions=completions,
                accuracy=accuracy,
                final_accuracy=final_accuracy
            )
            f.write(json.dumps(record) + "\n")

    # Written after the outputs file is closed; an empty dataset yields 0.0.
    metrics = dict(
        accuracy=safe_mean(acc)
    )
    with open(eval_file, "w", encoding="utf-8") as f:
        f.write(json.dumps(metrics) + "\n")
        
