# OpenAI model wrappers for the benchmark tester, plus a batch driver when run as a script.

from typing import Literal
from openai import OpenAI
import time
from pprint import pprint
from benchmark.api import BenchmarkManager
from const import COST_DICT
from pipeline.utils import find_id_in_label, print_colored
import os
from tqdm import tqdm
import time
import json
from pathlib import Path
from argparse import ArgumentParser
from ..model_prototype import ModelPrototype
# Benchmark samples ("<scenario>-<id>") found to contain a file whose type is not
# listed in OpenAIModel.support_file_format; appended to by the __main__ driver
# and printed / written to log.txt at the end of the batch.
not_supported_id = []


class OpenAIModel(ModelPrototype):
    """Run one prompt against a throwaway OpenAI Assistant with files attached.

    Each `run` call uploads the files, creates an assistant (code_interpreter,
    plus file_search over a vector store when enabled) and a thread, polls the
    run until it leaves a pending state, and finally deletes the remote objects
    it created (best effort).
    """

    # Run statuses that mean the run is still being worked on; anything else is terminal.
    _PENDING_RUN_STATUSES = ("queued", "in_progress", "cancelling")

    def __init__(self, model_name = "gpt-4-0125-preview", enable_vector_store = False, max_prompt_tokens=50000, max_completion_tokens=50000) -> None:
        """
        Args:
            model_name: OpenAI model id; also the key used to look up prices in COST_DICT.
            enable_vector_store: also index the files in a vector store and enable the file_search tool.
            max_prompt_tokens: prompt-token cap passed to each run.
            max_completion_tokens: completion-token cap passed to each run.
        """
        super().__init__()
        self.client = OpenAI()
        # self.client = OpenAI(base_url = "https://api.gptsapi.net/v1") # use wildcard domain
        self.model_name = model_name
        self.enable_vector_store = enable_vector_store
        # File extensions (or bare names such as LICENSE) that callers may upload;
        # this class does not enforce the list itself.
        self.support_file_format = ["c", "cpp", "css", "csv", "docx", "gif", "html", "java", "jpeg", "jpg", "js", "json", "md", "pdf", "php", "png", "pptx", "py", "rb", "tar", "tex", "ts", "txt", "xlsx", "xml", "zip", "yaml", "LICENSE", "ipynb"]
        self.max_prompt_tokens = max_prompt_tokens
        self.max_completion_tokens = max_completion_tokens

    def _upload_files(self, file_paths):
        """Upload each path with purpose='assistants' and return the new file ids in order."""
        uploaded_files_id = []

        for file_path in file_paths:
            with open(file_path, 'rb') as file:
                uploaded_file = self.client.files.create(file=file, purpose='assistants')
                print("uploaded file ======")
                print_colored(f"id: {uploaded_file.id} | name: {uploaded_file.filename}", "green")
                print("========")
                uploaded_files_id.append(uploaded_file.id)

        return uploaded_files_id

    def _create_vector_store(self, file_paths):
        """Create a vector store holding `file_paths`, wait while it is indexing, and return its id."""
        from contextlib import ExitStack

        vector_store = self.client.beta.vector_stores.create(name="files storage")

        # The streams are only needed for the batch upload; close them as soon as it finishes.
        with ExitStack() as stack:
            file_streams = [stack.enter_context(open(path, "rb")) for path in file_paths]
            # upload_and_poll uploads the files, adds them to the store and waits for the batch.
            file_batch = self.client.beta.vector_stores.file_batches.upload_and_poll(
                vector_store_id=vector_store.id, files=file_streams
            )

        print(file_batch.status)
        print(file_batch.file_counts)

        status = vector_store.status
        # Only "in_progress" is worth waiting on; a terminal non-completed status
        # (e.g. "expired") would otherwise make this loop spin forever.
        while status == 'in_progress':
            time.sleep(1)
            print(status)
            status = self.client.beta.vector_stores.retrieve(vector_store.id).status
        return vector_store.id

    def run(self, input, file_paths):
        """Answer `input` with a fresh assistant that can read `file_paths`.

        Args:
            input: user prompt text.
            file_paths: local files to upload and expose to the tools.

        Returns:
            (output, cost): the assistant's replies after the prompt, oldest
            first, and the cost computed from COST_DICT. Returns
            ("Failed Output", 0) when the run does not complete or its result
            cannot be parsed. The remote assistant, thread, uploaded files and
            vector store are deleted before returning (best effort).
        """
        uploaded_files_id = self._upload_files(file_paths)
        vector_store_id = None
        assistant_id = None
        thread_id = None
        try:
            tools = [{"type": "code_interpreter"}]
            tool_resources = {"code_interpreter": {"file_ids": uploaded_files_id}}
            if self.enable_vector_store:
                vector_store_id = self._create_vector_store(file_paths)
                tools = [{"type": "file_search"}, {"type": "code_interpreter"}]
                tool_resources["file_search"] = {"vector_store_ids": [vector_store_id]}

            self.assistant = self.client.beta.assistants.create(
                model=self.model_name,
                tools=tools,
                tool_resources=tool_resources,
                temperature=1,  # 0~2, default to 1
            )
            assistant_id = self.assistant.id

            thread = self.client.beta.threads.create(
                messages=[{"role": "user", "content": input}]
            )
            thread_id = thread.id

            run = self.client.beta.threads.runs.create(
                thread_id=thread_id,
                assistant_id=assistant_id,
                max_prompt_tokens=self.max_prompt_tokens,
                max_completion_tokens=self.max_completion_tokens,
            )
            run = self._wait_for_run(thread_id, run)
            if run.status != 'completed':
                # Return the failure pair so callers can always unpack `output, cost`.
                print(f"run finished with status: {run.status}")
                return "Failed Output", 0
            return self._collect_result(run, thread_id)
        finally:
            try:
                self.clean_files(
                    assistant_id,
                    thread_id,
                    file_ids=uploaded_files_id,
                    vector_store_id=vector_store_id,
                )
            except Exception as err:  # cleanup is best effort; don't mask the run's result
                print(f"cleanup failed: {err!r}")

    def _wait_for_run(self, thread_id, run):
        """Poll `run` once per second until it leaves a pending status; return the final run object."""
        while run.status in self._PENDING_RUN_STATUSES:
            time.sleep(1)  # Wait for 1 second
            run = self.client.beta.threads.runs.retrieve(
                thread_id=thread_id,
                run_id=run.id,
            )
        return run

    def _collect_result(self, run, thread_id):
        """Turn a completed run into (output, cost), or ("Failed Output", 0) if it can't be parsed."""
        messages = self.client.beta.threads.messages.list(thread_id=thread_id)
        try:
            # The listing is newest-first (API default order="desc"), so its last entry is
            # the user prompt (assuming the thread fits in one page). Drop it and reverse
            # to get the replies in chronological order; non-text parts (e.g. images
            # produced by code_interpreter) are skipped.
            replies = [
                "".join(
                    part.text.value
                    for part in message.content
                    if getattr(part, "type", None) == "text"
                )
                for message in reversed(messages.data[:-1])
            ]
            output = "".join(replies)

            # COST_DICT prices are applied per 1000 tokens.
            prices = COST_DICT[self.model_name]
            completion_cost = run.usage.completion_tokens * prices["output"] / 1000
            prompt_cost = run.usage.prompt_tokens * prices["input"] / 1000
            total_cost = completion_cost + prompt_cost + COST_DICT['code_interpreter']
        except (AttributeError, KeyError, TypeError) as err:
            print(f"failed to parse run result: {err!r}")
            return "Failed Output", 0

        print(f"cost: {total_cost}")
        print(f"output: {output}")
        return output, total_cost

    def clean_files(self, assistant_id, thread_id, file_ids=None, vector_store_id=None):
        """Delete the remote objects created for one `run` call.

        Args:
            assistant_id: assistant to delete; skipped when None.
            thread_id: thread to delete (after deleting its message files, when the
                SDK exposes them); skipped when None.
            file_ids: ids of uploaded files to delete; optional.
            vector_store_id: vector store to delete; optional.
        """
        # `threads.messages.files` is absent in some SDK versions, hence the hasattr guard.
        if thread_id is not None and hasattr(self.client.beta.threads.messages, "files"):
            messages = self.client.beta.threads.messages.list(thread_id=thread_id)
            for message in messages:
                message_files = self.client.beta.threads.messages.files.list(
                    thread_id=thread_id,
                    message_id=message.id,
                )
                for file in message_files:
                    self.client.files.delete(file_id=file.id)

        if assistant_id is not None:
            self.client.beta.assistants.delete(assistant_id=assistant_id)
        if thread_id is not None:
            self.client.beta.threads.delete(thread_id=thread_id)
        for file_id in file_ids or ():
            self.client.files.delete(file_id=file_id)
        if vector_store_id is not None:
            self.client.beta.vector_stores.delete(vector_store_id=vector_store_id)
        return


class OpenAIModelRephraser(ModelPrototype):
    """Single-turn chat-completion wrapper (used as a rephraser, per the class name).

    `run` sends one system and one user message and returns the reply together
    with its cost computed from COST_DICT.
    """

    def __init__(self, api_key = None, model_id: Literal["gpt-4-turbo-2024-04-09", "gpt-4-0125-preview", "gpt-3.5-turbo"] = "gpt-4-0125-preview"):
        """
        Args:
            api_key: OpenAI API key. When None, the client falls back to its
                default lookup (the OPENAI_API_KEY environment variable).
            model_id: chat model to query; must also be a key of COST_DICT.
        """
        super().__init__()
        self.api_key = api_key
        self.temperature = 0  # the API default is 1; 0 keeps rephrasing as stable as possible
        # self.api = OpenAI(base_url = "https://api.gptsapi.net/v1") # use wildcard domain
        # Forward the key; the original ignored it and always used the environment.
        self.api = OpenAI(api_key=api_key)
        self.model = model_id  # "gpt-4-turbo-2024-04-09" # "gpt-4-0125-preview"
        #TODO: pass model config (model, cache_dir, temperature) as config dict

    def run(self, system_msg, user_msg):
        """Send one system + one user message and return (reply_text, cost).

        The cost applies COST_DICT[self.model] prices per 1000 tokens.
        """
        input_message = [
            {"role": "system", "content": system_msg},
            {"role": "user", "content": user_msg},
        ]
        completion = self.api.chat.completions.create(
            model=self.model,
            messages=input_message,
            temperature=self.temperature,
        )
        prices = COST_DICT[self.model]
        completion_cost = completion.usage.completion_tokens * prices["output"] / 1000
        prompt_cost = completion.usage.prompt_tokens * prices["input"] / 1000
        reply = completion.choices[-1].message.content
        return reply, completion_cost + prompt_cost
    
    
    
    
if __name__ == "__main__":
    # Batch driver: run OpenAIModel on every written benchmark prompt and save each
    # answer to <output_dir>/<model_name>/<scenario>/<id>/hint_level_<n>/output.txt.

    parser = ArgumentParser(description=__doc__)
    parser.add_argument(
        "--max_files",
        type=int,
        help="The maximum number of files that can be processed at once",
        default=10,
    )
    parser.add_argument(
        "--scenarios",
        nargs='+',
        help="The scenarios to process (default: every scenario known to BenchmarkManager), usage ex. --scenarios Kaggle TensorFlow HuggingFace",
        default=None,
    )

    parser.add_argument(
        "--model_name",
        type=str,
        help="The model name to use",
        default="gpt-4-0125-preview",
    )

    parser.add_argument(
        "--output_dir",
        type=str,
        help="The output directory",
        default="pipeline/output",
    )

    args = parser.parse_args()

    output_dir = args.output_dir
    model_name = args.model_name

    manager = BenchmarkManager()
    # Pre-seeded entry, presumably a sample already known to have too many files.
    too_much_file = ['BigBench-4d8529a4-fc22-435e-b2fb-d13bc307b773']

    # modified_hints = ["c8b658f6-f5da-493c-9448-a1c7ad3579db", "453e8db5-4921-4c70-92d1-5cf7a888b497","49b43f03-7181-45ee-8c16-1f6a52c99281","3b02d0a3-d839-4743-bb67-44f969a217c2","9a0b60a0-ed2b-4e93-b6c8-b4f980c22716",
    #                   "cd5a39c9-f14b-4672-a58d-eafd98558832","5130db91-144c-472d-8650-f5b96d695c17","0a9c2e90-fcca-48c1-8afd-e1393d149388"]

    scenarios = args.scenarios if args.scenarios else manager.get_scenarios()
    excluded_ids = []

    test_sample_path = "eval_test_samples.json"

    # Only consulted by the (currently disabled) test-sample filter below.
    with open(test_sample_path, "r") as f:
        test_samples = json.load(f)

    total_cost = 0
    total_model_time = 0
    process_sample_num = 0

    # One model/client for the whole batch; it keeps no per-scenario state.
    model = OpenAIModel(model_name=model_name)
    supported_formats = set(model.support_file_format)

    for scenario in scenarios:
        # Finished hint levels are skipped one by one below (by output-dir existence),
        # so only the explicit exclusion list filters whole ids here.
        inputs_id = [
            sample_id
            for sample_id in manager.get_written_input_ids(scenario)
            if sample_id not in excluded_ids
        ]

        id_bar = tqdm(inputs_id, total=len(inputs_id))
        # id_bar = tqdm(modified_hints, total=len(modified_hints))
        for sample_id in id_bar:
            files = manager.get_files(scenario, sample_id, flat=True)
            # Skip the whole sample if any file has a type the model can't take.
            # Extension-less files (e.g. LICENSE) are matched by their base name.
            if any(
                (Path(file).suffix[1:] or Path(file).name) not in supported_formats
                for file in files
            ):
                not_supported_id.append(f"{scenario}-{sample_id}")
                continue
            inputs_prompt = manager.get_written_input_of_id(scenario, sample_id)
            if len(files) > args.max_files:
                too_much_file.append(f"{scenario}-{sample_id}")
                continue
            for input_path in inputs_prompt:
                # Prompt files end in "_<hint_level>.<ext>".
                hint_level = input_path.split("_")[-1].split(".")[0]
                # if sample_id not in test_samples[str(hint_level)][scenario]:
                #     continue
                output_path = os.path.join(output_dir, model_name, scenario, sample_id, f"hint_level_{hint_level}")
                if os.path.exists(output_path):
                    pprint(f"skip {scenario} {sample_id} hint_level_{hint_level}")
                    continue
                id_bar.set_description(f"Processing {scenario} {sample_id} hint_level_{hint_level}")
                with open(input_path, "r") as f:
                    prompt = f.read()
                start = time.time()
                result = model.run(input=prompt, file_paths=files)
                elapsed = time.time() - start
                total_model_time += elapsed
                # Guard against run() returning None instead of an (output, cost) pair.
                output, cost = result if result is not None else ("Failed Output", 0)

                os.makedirs(output_path, exist_ok=True)
                with open(os.path.join(output_path, "output.txt"), "w", encoding="utf-8") as f:
                    f.write(output)

                total_cost += cost
                process_sample_num += 1
                id_bar.set_postfix(cost=total_cost, average_cost=total_cost / process_sample_num, time=elapsed, average_time=total_model_time / process_sample_num)

    print("too much: ", too_much_file)
    print("file-format not support:", not_supported_id)
    # save_to_log (appended, one entry per line so consecutive runs stay separable)
    with open("log.txt", "a") as f:
        f.write("too much file: ")
        f.write(str(too_much_file))
        f.write("\n")
        f.write("file-format not support:")
        f.write(str(not_supported_id))
        f.write("\n")
            

