
# LLM Inference for Symptom Extraction
import os
import json
import time
import pandas as pd
import openai
import tiktoken
from dotenv import load_dotenv
from openai import AzureOpenAI


def load_environment(env_path):
    """Load Azure OpenAI settings from a dotenv file and build a client.

    After loading ``env_path`` with ``load_dotenv``, the ``API_KEY``,
    ``API_VERSION`` and ``RESOURCE_ENDPOINT`` environment variables are read,
    copied onto the module-level ``openai`` attributes, and used to construct
    the returned ``AzureOpenAI`` client. The API version and endpoint are
    printed; the key is not.
    """
    load_dotenv(env_path)

    env = os.environ
    api_key = env.get('API_KEY')
    api_version = env.get('API_VERSION')
    endpoint = env.get('RESOURCE_ENDPOINT')

    # Mirror the same settings onto the openai module attributes.
    openai.api_type = "azure"
    openai.api_base = endpoint
    openai.api_version = api_version
    openai.api_key = api_key
    print(api_version, endpoint)

    return AzureOpenAI(
        api_key=api_key,
        api_version=api_version,
        azure_endpoint=endpoint,
    )

def num_tokens(string: str, encoding_name = "cl100k_base") -> int:
    """Count the tokens ``string`` encodes to under the named tiktoken encoding."""
    tokenizer = tiktoken.get_encoding(encoding_name)
    token_ids = tokenizer.encode(string)
    return len(token_ids)

import math

def parse_logprob(token_logprobs):
    """Convert per-token logprobs into per-symptom probabilities for '0' and '1'.

    The completion tokens are scanned in order. Every token whose text is
    exactly ``'0'`` or ``'1'`` is treated as the answer for the next symptom in
    the fixed (A)..(W) list below, so the result is only meaningful when the
    model answers the symptoms in that order.

    Args:
        token_logprobs: Iterable of objects exposing ``token`` and
            ``top_logprobs``; each entry of ``top_logprobs`` exposes ``token``
            and ``logprob``.

    Returns:
        A dict keyed by symptom label. Each value is a dict mapping ``'0'``
        and/or ``'1'`` to ``exp(logprob)`` for whichever of those tokens
        appeared among the alternatives. Symptoms with no matching answer
        token keep an empty dict.
    """
    symptom_keys = [
        "(A): Diarrhea", "(B): Constipation", "(C): Nausea", "(D): Vomiting",
        "(E): Abdominal Pain", "(F): Abdominal Distension", "(G): Fatigue",
        "(H): Allergic reaction", "(I): Weight loss", "(J): Erythema",
        "(K): Hair loss", "(L): Neutropenia", "(M): Anemia",
        "(N): Abnormal liver function", "(O): Dyspnea", "(P): Appetite Loss",
        "(Q): Fever", "(R): Chills", "(S): Jaundice", "(T): Thrombocytopenia",
        "(U): Sensory Neuropathy", "(V): Motor Neuropathy", "(W): Cold-induced Neuropathy"
    ]
    binary_tokens = ('0', '1')

    prob_json = {symptom: {} for symptom in symptom_keys}
    counter = 0
    for token_info in token_logprobs:
        if token_info.token not in binary_tokens:
            continue
        if counter >= len(symptom_keys):
            # Every symptom already has an answer; additional '0'/'1' tokens
            # have no slot and would otherwise index past the end of the list.
            break
        symptom_probs = prob_json[symptom_keys[counter]]
        for alt in token_info.top_logprobs:
            if alt.token in binary_tokens:
                symptom_probs[alt.token] = math.exp(alt.logprob)
        counter += 1
    return prob_json

def run_inference(client, file_path, prompt_path, output_path):
    """Run the extraction prompt over every report in a CSV and save the results.

    Each row's ``Report Text`` is sent as the user message, with the contents
    of ``prompt_path`` as the system message, to every deployment in the model
    table. On success the row receives four columns: ``response_json`` (the
    message content exactly as returned), ``probability_json`` (the string form
    of ``parse_logprob``'s result), ``input_tokens`` and ``output_tokens``.
    The frame is written to ``output_path`` after each deployment finishes.

    Rows are skipped (with a printed message) when the report text is missing
    or is longer than ``max_tokens`` tokens. A failing request is retried up to
    three times with a five second pause between attempts; rows that never
    succeed are left without the result columns.

    Args:
        client: Object exposing ``chat.completions.create`` (the value returned
            by ``load_environment``).
        file_path: CSV to read; must contain a ``Report Text`` column.
        prompt_path: Text file holding the system prompt.
        output_path: Destination CSV path.
    """
    test_df = pd.read_csv(file_path)
    # Upper bound on the report text length, as counted by num_tokens.
    # NOTE(review): only the report text is counted, not the system prompt.
    max_tokens = 8192
    max_attempts = 3
    retry_delay_s = 5
    model_dict = {
        'gpt-4o': 'gpt-4o-2024-08-06'
    }

    # Context manager so the prompt file handle is closed right after reading.
    with open(prompt_path) as prompt_file:
        prompt_instructions = prompt_file.read()

    for deployment_id in model_dict.values():
        for index, row in test_df.iterrows():
            input_text = row['Report Text']
            if not isinstance(input_text, str):
                # Empty CSV cells are read as NaN; tokenising one would raise
                # outside the retry loop and abort the run before anything is
                # saved.
                print(f"Skipping row {index}: 'Report Text' is missing or not a string")
                continue

            input_tokens = num_tokens(input_text)
            if input_tokens > max_tokens:
                print(f"Skipping row {index}: {input_tokens} input tokens exceeds {max_tokens}")
                continue

            start = time.time()
            for attempt in range(1, max_attempts + 1):
                try:
                    # Send the chat request to the deployment being iterated.
                    response = client.chat.completions.create(
                        model=deployment_id,
                        messages=[
                            {"role": "system", "content": prompt_instructions},
                            {"role": "user", "content": input_text},
                        ],
                        temperature=0.0,
                        top_p=0,
                        max_tokens=8000,
                        logprobs=True,
                        top_logprobs=2,
                        stream=False,
                        n=1,
                    )
                    choice = response.choices[0]
                    output_text = choice.message.content
                    # Raises (and therefore triggers a retry) when the
                    # response carries no text content.
                    output_tokens = num_tokens(output_text)
                    prob_json = parse_logprob(choice.logprobs.content)

                    # NOTE(review): the content is stored verbatim and is not
                    # validated as JSON, despite the column name.
                    test_df.at[index, 'response_json'] = output_text
                    test_df.at[index, 'probability_json'] = str(prob_json)
                    test_df.at[index, 'input_tokens'] = input_tokens
                    test_df.at[index, 'output_tokens'] = output_tokens

                    end = time.time()
                    print("The time of execution is :", (end - start), "s")
                    break
                except Exception as e:
                    # Retry boundary: report the failure and try again.
                    print("An error occurred:", e)

                # Back off before the next attempt, but not after the last one.
                if attempt < max_attempts:
                    time.sleep(retry_delay_s)
        test_df.to_csv(output_path)


if __name__ == "__main__":
    # Script entry point: build the client from the local .env file, then run
    # the binary-encoding prompt over the ontreatment CSV.
    client = load_environment(".env")
    run_inference(
        client,
        "data/additional_data_ontreatment.csv",
        "prompts/ontreatment/binary_encoding_prompt.txt",
        "output/ontreatment/binary_encoding_output.csv",
    )