
# LLM Inference for Symptom Extraction
import os
import json
import time
import pandas as pd
import openai
import tiktoken
from dotenv import load_dotenv
from openai import AzureOpenAI


def load_environment(env_path):
    """Read Azure OpenAI settings from a dotenv file and build a client.

    The dotenv file at *env_path* is loaded, then ``API_KEY``,
    ``API_VERSION`` and ``RESOURCE_ENDPOINT`` are read from the process
    environment. The same values are also written to the module-level
    ``openai`` configuration attributes, and the API version and endpoint
    are printed.

    Args:
        env_path: Path of the dotenv file to load.

    Returns:
        An ``AzureOpenAI`` client configured with the loaded settings.
    """
    load_dotenv(env_path)

    key, version, endpoint = (
        os.environ.get(name)
        for name in ('API_KEY', 'API_VERSION', 'RESOURCE_ENDPOINT')
    )

    # Mirror the settings onto the module-level openai configuration.
    openai.api_key = key
    openai.api_version = version
    openai.api_base = endpoint
    openai.api_type = "azure"

    # The key is deliberately not printed; only the version and endpoint are.
    print(version, endpoint)

    return AzureOpenAI(
        api_key=key,
        api_version=version,
        azure_endpoint=endpoint,
    )

def num_tokens(string: str, encoding_name = "cl100k_base") -> int:
    """Count the tokens that *string* encodes to.

    Args:
        string: Text to tokenize.
        encoding_name: Name of the tiktoken encoding to look up.

    Returns:
        The length of the encoded token sequence.
    """
    return len(tiktoken.get_encoding(encoding_name).encode(string))

def _request_completion(client, deployment_id, system_prompt, user_text,
                        completion_budget, max_attempts=3, retry_delay=5):
    """Request one chat completion, retrying when an attempt fails.

    Args:
        client: Object exposing ``chat.completions.create``.
        deployment_id: Value passed as ``model`` to the API.
        system_prompt: Content of the system message.
        user_text: Content of the user message.
        completion_budget: Value passed as ``max_tokens`` to the API.
        max_attempts: Maximum number of requests to make.
        retry_delay: Seconds to wait between failed attempts.

    Returns:
        ``(output_text, output_tokens)`` for the first successful attempt,
        or ``None`` when every attempt failed.
    """
    for attempt in range(1, max_attempts + 1):
        try:
            response = client.chat.completions.create(
                model=deployment_id,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_text},
                ],
                top_p=1,
                seed=8152,
                temperature=0.5,
                max_tokens=completion_budget,
            )
            output_text = response.choices[0].message.content
            # Counting tokens inside the try means a missing (None) content
            # is treated as a failed attempt and retried.
            return output_text, num_tokens(output_text)
        except Exception as e:
            # Deliberately broad: one bad request must not abort the whole
            # batch; the failure is reported and the request retried.
            print("An error occurred:", e)
        # No point waiting after the final attempt.
        if attempt < max_attempts:
            time.sleep(retry_delay)
    return None


def run_inference(client, file_path, prompt_path, output_path):
    """Run the extraction prompt over every report in a CSV and save results.

    For each row of the CSV at *file_path*, the ``'Report Text'`` value is
    sent as the user message together with the system prompt read from
    *prompt_path*. On success the row gains ``response_json`` (the message
    content, stored verbatim and not validated as JSON), ``input_tokens``
    and ``output_tokens``. The resulting frame is written to *output_path*.

    Rows are skipped, with a message, when the report text is missing or
    when it leaves no room for a completion within the token budget.

    Args:
        client: Client exposing ``chat.completions.create``.
        file_path: CSV file containing a ``'Report Text'`` column.
        prompt_path: Text file holding the system prompt.
        output_path: Destination CSV path.
    """
    test_df = pd.read_csv(file_path)
    max_tokens = 8192
    model_dict = {
        'gpt-4o': 'gpt-4o-2024-08-06'
    }

    # Context manager so the prompt file handle is closed promptly.
    with open(prompt_path) as prompt_file:
        prompt_instructions = prompt_file.read()

    # Create the output directory up front: discovering that it is missing
    # only at to_csv() time would throw away every completed API call.
    if isinstance(output_path, (str, os.PathLike)):
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

    for deployment_id in model_dict.values():
        for index, row in test_df.iterrows():
            input_text = row['Report Text']
            # Empty CSV cells are read as NaN (a float), which the tokenizer
            # cannot encode; skip them instead of aborting the run.
            if not isinstance(input_text, str):
                print(f"Skipping row {index}: 'Report Text' is missing or not text")
                continue

            input_tokens = num_tokens(input_text)
            # Strictly less than: an input that fills the whole budget would
            # request a completion of zero tokens.
            if input_tokens >= max_tokens:
                print(f"Skipping row {index}: {input_tokens} input tokens "
                      f"leave no room within the {max_tokens}-token budget")
                continue

            start = time.time()
            result = _request_completion(
                client,
                deployment_id,
                prompt_instructions,
                input_text,
                max_tokens - input_tokens,
            )
            if result is None:
                continue

            output_text, output_tokens = result
            test_df.at[index, 'response_json'] = output_text
            test_df.at[index, 'input_tokens'] = input_tokens
            test_df.at[index, 'output_tokens'] = output_tokens
            end = time.time()
            print("The time of execution is :", (end - start), "s")
        test_df.to_csv(output_path)


if __name__ == "__main__":
    # Script entry point: build the client from the local .env file, then
    # run the binary-encoding prompt over the test set.
    client = load_environment(".env")
    run_inference(
        client=client,
        file_path="data/test.csv",
        prompt_path="prompts/ontreatment/binary_encoding_prompt.txt",
        output_path="output/ontreatment/binary_encoding_output.csv",
    )