# -*- coding: utf-8 -*-
"""
Created on Thu Feb 13 11:33:47 2025

@author: baran
"""
# -*- coding: utf-8 -*-
from client import get_client
from openai import AssistantEventHandler, OpenAI
from azure.ai.inference import ChatCompletionsClient
from azure.ai.inference.models import SystemMessage, UserMessage
import re
import tiktoken
from transformers import AutoTokenizer

'''
deployments variable structure

{
    assistants : [(names,context)],
    finetune : [(names,context)],
    base : [(names,context)]
}

Structured this way since the different categories of LLMs were created in
different resource groups in Azure.

opt_eval collects its per-model scores in a dict with the model name as key and
the evaluation score as value.
'''

# Arm key -> model family whose tokenizer is used to measure the reply length.
_ARM_TO_LLM = {
    "base": "gpt-3.5-turbo",
    "assistants": "gpt-3.5-turbo",
    "finetune_med": "gpt-4",
    "finetune_tele": "gpt-4",
    "finetune_med_new": "gpt-4",
    "llama": "llama-13b",
}

# Model families that are tokenized with tiktoken; anything else falls back to
# the Open-LLaMA tokenizer.
_OPENAI_TOKENIZER_MODELS = frozenset({"gpt-3.5-turbo", "gpt-4"})
_LLAMA_TOKENIZER_ID = "openlm-research/open_llama_13b"

# Arms that are queried through ``client.complete`` (azure.ai.inference style).
_AZURE_INFERENCE_ARMS = frozenset({"small", "llama", "deepseek", "phi"})

# Fine-tuned arms all share the "finetune" client; the value is the short tag
# that is also compared against ``selected`` when the reward is picked.
_FINETUNE_TAGS = {
    "finetune_med": "Med",
    "finetune_med_new": "Med_New",
    "finetune_tele": "Tele",
}

_SUPPORTED_DATASETS = ("telecom", "medical")

# The lowercase spelling is tried first and wins even if a capitalised
# "Option N" appears earlier in the reply.
_OPTION_LOWER = re.compile(r"option \s*(\d+)")
_OPTION_CAPITALISED = re.compile(r"Option \s*(\d+)")

# Loaded tokenizers, kept across calls so the Llama tokenizer is built once.
_TOKENIZER_CACHE = {}


def _as_text(message):
    """Coerce a model reply to text; ``None`` becomes the empty string."""
    if message is None:
        return ""
    if isinstance(message, str):
        return message
    if hasattr(message, "content"):
        return message.content
    if hasattr(message, "text"):
        return message.text
    return str(message)


def _count_output_tokens(message, selected):
    """Return the token count of ``message`` for the tokenizer of arm ``selected``.

    Raises ``KeyError`` when ``selected`` has no entry in ``_ARM_TO_LLM``.
    """
    llm_name = _ARM_TO_LLM[selected]
    if llm_name in _OPENAI_TOKENIZER_MODELS:
        return len(tiktoken.encoding_for_model(llm_name).encode(message))
    tokenizer = _TOKENIZER_CACHE.get(_LLAMA_TOKENIZER_ID)
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(_LLAMA_TOKENIZER_ID)
        _TOKENIZER_CACHE[_LLAMA_TOKENIZER_ID] = tokenizer
    return len(tokenizer(message, truncation=True, padding=False)["input_ids"])


def _normalise_telecom_reply(message):
    """Rewrite terse replies such as ``"3"``, ``"{3}"`` or ``"3: ..."`` to ``"option 3"``."""
    if message.isnumeric():
        message = f"option {message}"
    if "{" in message:
        message = "option " + message.replace("{", "").replace("}", "")
    lines = message.splitlines()
    last_line = lines[-1] if lines else ""
    # The first digit (in 1..5 order) that matches any heuristic wins.
    for digit in "12345":
        if (
            message.startswith(digit)
            or f"{digit}:" in message
            or last_line.startswith(digit)
        ):
            return f"option {digit}"
    return message


def _score_telecom(message, label_option):
    """Return 1/0 for a right/wrong option, or ``None`` if no option is found."""
    found = _OPTION_LOWER.search(message) or _OPTION_CAPITALISED.search(message)
    if found is None:
        return None
    return 1 if int(found.group(1)) == label_option else 0


def _score_medical(message, keywords):
    """Return the fraction of ``keywords`` that occur in the lowercased reply."""
    # NOTE(review): the keywords themselves are not lowercased, so a keyword
    # containing capitals can never match -- confirm labels are lowercase.
    lowered = message.lower()
    hits = sum(1 for keyword in keywords if keyword in lowered)
    return hits / len(keywords)


def _run_assistant(client, model_name, context, prompt):
    """Run one file-search assistant over ``prompt`` and return its reply text.

    The prompt is written to ``data/input.txt`` and uploaded; the reply is
    written to ``data/outputs.txt`` by the stream handler and read back.
    """
    assistant = client.beta.assistants.create(
        name="Diagnosis Summarizer",
        instructions=context,
        tools=[{"type": "file_search"}],
        model=model_name,
    )
    vector_store = client.beta.vector_stores.create(name="Diagnosis Reports")

    with open("data/input.txt", "w", encoding="utf-8") as handle:
        handle.write(prompt)

    # Upload the prompt file into the vector store and wait for completion.
    with open("data/input.txt", "rb") as stream:
        client.beta.vector_stores.file_batches.upload_and_poll(
            vector_store_id=vector_store.id, files=[stream]
        )

    assistant = client.beta.assistants.update(
        assistant_id=assistant.id,
        tool_resources={"file_search": {"vector_store_ids": [vector_store.id]}},
    )

    with open("data/input.txt", "rb") as stream:
        message_file = client.files.create(file=stream, purpose="assistants")

    thread = client.beta.threads.create(
        messages=[
            {
                "role": "user",
                "content": prompt,
                "attachments": [
                    {"file_id": message_file.id, "tools": [{"type": "file_search"}]}
                ],
            }
        ]
    )

    class EventHandler(AssistantEventHandler):
        def on_message_done(self, message) -> None:
            # Replace inline citation markers with "[i]" and persist the text.
            message_content = message.content[0].text
            for index, annotation in enumerate(message_content.annotations):
                message_content.value = message_content.value.replace(
                    annotation.text, f"[{index}]"
                )
            with open("data/outputs.txt", "w", encoding="utf-8") as out:
                out.write(message_content.value)

    # NOTE(review): run-level ``instructions`` look like quick-start boilerplate
    # and may override the assistant's own ``context`` -- confirm this is wanted.
    with client.beta.threads.runs.stream(
        thread_id=thread.id,
        assistant_id=assistant.id,
        instructions="Please address the user as Jane Doe. The user has a premium account.",
        event_handler=EventHandler(),
    ) as stream:
        stream.until_done()

    with open("data/outputs.txt", "r", encoding="utf-8") as handle:
        return handle.read()


def _query_deployment(client, arm, model_name, context, prompt):
    """Send ``prompt`` to one deployment using the API that matches ``arm``."""
    if arm == "assistants":
        return _run_assistant(client, model_name, context, prompt)
    if arm in _AZURE_INFERENCE_ARMS:
        response = client.complete(
            messages=[
                SystemMessage(content=context),
                UserMessage(content=prompt),
            ],
            max_tokens=2048,
            temperature=0.8,
            top_p=0.1,
            model=model_name,
        )
    else:
        response = client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "system", "content": context},
                {"role": "user", "content": prompt},
            ],
        )
    return response.choices[0].message.content


def opt_eval(deployments, prompt,task,selected,avg_array,t,all_rewards_sum,all_rewards_diag,labels,dataset):
    """Query every deployment with ``prompt``, score the replies, return the regret.

    Args:
        deployments: Mapping of arm key to deployment config. ``"assistants"``
            maps to an iterable of entries; every other key maps to a single
            entry. An entry is indexed as ``entry[0]`` (model name) and
            ``entry[1]`` (system context). Arms with an empty config are skipped.
        prompt: User prompt sent to every deployment.
        task: Unused; kept for interface compatibility.
        selected: Arm key chosen by the caller. Its score becomes the reward
            and its reply is token-counted.
        avg_array: Mapping of model name to running average score. Updated in
            place; each model name must already be present.
        t: Index into ``labels`` and number of samples already in the averages.
        all_rewards_sum: Returned untouched.
        all_rewards_diag: List that receives every recorded score.
        labels: ``labels[t]`` is the ground truth. For ``"telecom"`` it is a
            string containing ``"option N"``; for ``"medical"`` it is a
            collection of substrings searched for in the lowercased reply.
        dataset: ``"telecom"`` or ``"medical"``.

    Returns:
        ``(regret, reward, out_len, avg_array, all_rewards_sum,
        all_rewards_diag)``. ``regret`` is the best score of this round minus
        the reward. ``out_len`` is the token count of the selected arm's reply,
        or ``None`` if no arm key equalled ``selected``.

    Raises:
        ValueError: If ``dataset`` is unsupported, if a telecom label has no
            ``"option N"``, or if no score was recorded for ``selected``.
    """
    if dataset not in _SUPPORTED_DATASETS:
        # Without this check the retry loop below could never terminate.
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of {_SUPPORTED_DATASETS}."
        )

    print(labels[t])

    label_option = None
    if dataset == "telecom":
        label_match = _OPTION_LOWER.search(labels[t])
        if label_match is None:
            raise ValueError(f"Label {labels[t]!r} does not contain 'option <n>'.")
        label_option = int(label_match.group(1))

    result = {}
    reward = None
    out_len = None

    for arm in deployments:
        entry = deployments[arm]
        if len(entry) == 0:
            continue

        # Recomputed per arm so a fine-tune tag never leaks into later arms.
        actual = _FINETUNE_TAGS.get(arm, "")
        client = get_client("finetune" if actual else arm)
        targets = entry if arm == "assistants" else [entry]

        for target in targets:
            name = target[0]
            context = target[1]

            # NOTE(review): telecom replies are retried without an upper bound
            # until one contains a parseable option.
            while True:
                message = _as_text(
                    _query_deployment(client, arm, name, context, prompt)
                )
                if arm == selected:
                    if arm != "assistants":
                        print(message)
                    out_len = _count_output_tokens(message, selected)

                if dataset == "telecom":
                    if arm in _AZURE_INFERENCE_ARMS:
                        message = _normalise_telecom_reply(message)
                    number = _score_telecom(message, label_option)
                else:
                    number = _score_medical(message, labels[t])
                if number is not None:
                    break

            all_rewards_diag.append(number)
            avg_array[name] = (avg_array[name] * t + number) / (t + 1)
            result[name] = number
            if arm == selected or actual == selected:
                reward = number

    print(result)
    print("avg:")
    print(avg_array)

    if reward is None:
        raise ValueError("Reward was not calculated for the selected deployment.")

    if dataset == "telecom":
        best_score = max((int(score) for score in result.values()), default=0)
        return int(best_score)-int(reward),int(reward),out_len,avg_array,all_rewards_sum,all_rewards_diag

    best_score = max((float(score) for score in result.values()), default=0)
    return best_score-reward,reward,out_len,avg_array,all_rewards_sum,all_rewards_diag

#reg,reward,avg_array,all_rewards_sum,all_rewards_diag = opt_eval(deployments_1,fin_prompt,task,selected,avg_array,t,all_rewards_sum,all_rewards_diag,labels)

# deployment = {'assistants' : [("Assistant", "You are an medical diagnosis agent whose primary goal is to give diagnosis based on medical reports. For experimentation purposes only.")], 'finetune' : [("Med","You are an medical diagnosis agent whose primary goal is to give diagnosis based on medical reports. For experimentation purposes only.")], 'base' : []}
# prompt = """{"question": "What is the purpose of the Nmfaf_3daDataManagement_Deconfigure service operation? [3GPP Release 18]",
# 		"option 1": "To configure the MFAF to map data or analytics received by the MFAF to out-bound notification endpoints",
# 		"option 2": "To configure the MFAF to stop mapping data or analytics received by the MFAF to out-bound notification endpoints",
# 		"option 3": "To supply data or analytics from the MFAF to notification endpoints",
# 		"option 4": "To fetch data or analytics from the MFAF based on fetch instructions",
# 		"answer": "option 2: To configure the MFAF to stop mapping data or analytics received by the MFAF to out-bound notification endpoints",
# 		"explanation": "The Nmfaf_3daDataManagement_Deconfigure service operation is used to stop mapping data or analytics received by the MFAF to one or more out-bound notification endpoints.",
# 		"category": "Standards specifications"}"""

# print(opt_eval(deployment, prompt))
