# -*- coding: utf-8 -*-
"""
Created on Fri Jul  4 16:40:06 2025

@author: baran
"""

# -*- coding: utf-8 -*-
"""
Created on Sat May 10 17:42:55 2025

@author: baran
"""

# -*- coding: utf-8 -*-
"""
Created on Sun May  4 21:10:52 2025

@author: baran
"""
from azure.core.exceptions import HttpResponseError
import time

import numpy as np
from prompt_maker import input_maker
#from optimal_rand_tele import opt_eval
#from sum_call import get_summary
from optimal_rand_tele import opt_eval    as _opt_eval
from sum_call            import get_summary as _get_summary
import argparse
import pickle
from final_rand_med import final_eval as _final_eval
from azure.core.exceptions import AzureError  # <-- Update this import

# def azure_retry(func):
#     def wrapper(*args, **kwargs):
#         max_retries = 5
#         for attempt in range(1, max_retries + 1):
#             try:
#                 result = func(*args, **kwargs)
#                 # throttle to avoid bursting the API
#                 time.sleep(0.2)
#                 return result
#             except AzureError  as e:
#                 if e.status_code == 429:
#                     # honor Retry-After if provided, else default to 1s
#                     retry_after = int(e.response.headers.get("Retry-After", 1))
#                     print(f"[Azure 429] retry #{attempt}/{max_retries} after {retry_after}s")
#                     time.sleep(retry_after)
#                 else:
#                     # non-rate-limit errors bubble up
#                     raise
#         raise RuntimeError(f"Azure retries exhausted for {func.__name__}")
#     return wrapper

def azure_retry_timed(func, max_retries=5, fallback_delay=2):
    """Wrap *func* so that Azure SDK errors are retried before giving up.

    HTTP 429 (rate-limit) errors wait for the server's ``Retry-After`` header
    (1 s when the header is absent or not a number); any other ``AzureError``
    waits ``fallback_delay`` seconds. Exceptions that are not ``AzureError``
    propagate immediately without a retry.

    Args:
        func: Callable to wrap.
        max_retries: Total number of attempts (default 5).
        fallback_delay: Seconds to wait after a non-429 Azure error.

    Returns:
        A wrapper taking the same arguments as *func*; ``functools.wraps``
        preserves its name and docstring.

    Raises:
        RuntimeError: When every attempt failed. The last ``AzureError`` is
            chained as ``__cause__`` so its details are not lost.
    """
    import functools

    func_name = getattr(func, "__name__", repr(func))

    def _retry_after_seconds(error):
        # The header may be missing, a number of seconds, or an HTTP-date, and
        # `error.response` itself may be None. Fall back to 1 s in those cases
        # instead of raising a second exception from inside the handler.
        response = getattr(error, "response", None)
        headers = getattr(response, "headers", None) or {}
        try:
            return max(0.0, float(headers.get("Retry-After", 1)))
        except (TypeError, ValueError):
            return 1.0

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        last_error = None
        for attempt in range(1, max_retries + 1):
            try:
                return func(*args, **kwargs)
            except AzureError as e:
                last_error = e
                if getattr(e, "status_code", None) == 429:
                    delay = _retry_after_seconds(e)
                    print(f"[Azure 429] retry #{attempt}/{max_retries} after {delay:g}s")
                else:
                    delay = fallback_delay
                    print(f"[Azure Error] {e}. Retrying #{attempt}/{max_retries}...")
                # Sleeping after the final attempt would only delay the failure.
                if attempt < max_retries:
                    time.sleep(delay)
        raise RuntimeError(f"Azure retries exhausted for {func_name}") from last_error

    return wrapper
# get_summary = azure_retry(_get_summary)
# opt_eval    = azure_retry(_opt_eval)
# final_eval = azure_retry(_final_eval)
# Retry-wrapped versions of the project evaluators, bound in one pass:
#   get_summary_core <- _get_summary
#   opt_eval         <- _opt_eval
#   final_eval_core  <- _final_eval
get_summary_core, opt_eval, final_eval_core = map(
    azure_retry_timed, (_get_summary, _opt_eval, _final_eval)
)
# def timed_get_summary(prompt, arm):
#     start = time.time()
#     result = get_summary_core(prompt, arm)
#     end = time.time()
#     return result, end - start
def timed_get_summary(prompt, arm):
    """Call the summariser and measure its latency and output length.

    Only the successful attempt is timed. Back-off sleeps and failed attempts
    inside the Azure retry wrapper are excluded, so the latency reflects one
    model call rather than rate-limit waiting time.

    Args:
        prompt: Text forwarded to ``get_summary``.
        arm: Arm key forwarded to ``get_summary``.

    Returns:
        Tuple ``(result, latency_seconds, output_token_count)``. Tokens are
        counted with the shared gpt-3.5-turbo encoder (``unified_encoder``) for
        every arm, so lengths are comparable across models.

    Raises:
        RuntimeError: If every retry attempt failed (from the retry wrapper).
    """
    import functools

    attempt_latencies = []

    # wraps() keeps the evaluator's name in the retry wrapper's messages.
    @functools.wraps(_get_summary)
    def _timed_attempt(*call_args):
        # perf_counter is monotonic and high-resolution; time.time() can jump.
        started = time.perf_counter()
        output = _get_summary(*call_args)
        attempt_latencies.append(time.perf_counter() - started)
        return output

    result = azure_retry_timed(_timed_attempt)(prompt, arm)
    output_len = len(unified_encoder.encode(str(result)))
    # The retry wrapper only returns after a successful attempt, which always
    # records its latency, so the list is non-empty here.
    return result, attempt_latencies[-1], output_len

# def timed_final_eval(dep, arm, prompt, task, all_rewards_sum, all_rewards_diag, summary):
#     start = time.time()
#     result = final_eval_core(dep, arm, prompt, task, all_rewards_sum, all_rewards_diag, summary)
#     end = time.time()
#     return result, end - start
def timed_final_eval(dep, arm, prompt, task, all_rewards_sum, all_rewards_diag, summary):
    """Run the final (diagnosis) evaluator and measure latency and output length.

    Only the successful attempt is timed. Back-off sleeps and failed attempts
    inside the Azure retry wrapper are excluded, so the recorded latency
    reflects one model call rather than rate-limit waiting time.

    Args:
        dep, arm, prompt, task, all_rewards_sum, all_rewards_diag, summary:
            Forwarded positionally, unchanged, to ``final_eval``.

    Returns:
        Tuple ``(result, latency_seconds, output_token_count)``. Tokens are
        counted with the shared gpt-3.5-turbo encoder (``unified_encoder``).

    Raises:
        RuntimeError: If every retry attempt failed (from the retry wrapper).
    """
    import functools

    attempt_latencies = []

    # wraps() keeps the evaluator's name in the retry wrapper's messages.
    @functools.wraps(_final_eval)
    def _timed_attempt(*call_args):
        # perf_counter is monotonic and high-resolution; time.time() can jump.
        started = time.perf_counter()
        output = _final_eval(*call_args)
        attempt_latencies.append(time.perf_counter() - started)
        return output

    result = azure_retry_timed(_timed_attempt)(
        dep, arm, prompt, task, all_rewards_sum, all_rewards_diag, summary
    )
    # NOTE(review): str() is applied because the evaluator's return type is not
    # visible here; if it returns a structured object, this counts the tokens of
    # its repr rather than only the model's text — confirm.
    output_len = len(unified_encoder.encode(str(result)))
    # The retry wrapper only returns after a successful attempt, which always
    # records its latency, so the list is non-empty here.
    return result, attempt_latencies[-1], output_len

def get_regret(deployments, prompt, task, selected, avg_array, t,
               all_rewards_sum, all_rewards_diag, labels, dataset):
    """Delegate the per-round regret computation to ``opt_eval``.

    Thin pass-through: every argument is forwarded positionally, in the same
    order, to the retry-wrapped ``opt_eval``, and its result is returned as-is.
    """
    forwarded = (deployments, prompt, task, selected, avg_array, t,
                 all_rewards_sum, all_rewards_diag, labels, dataset)
    return opt_eval(*forwarded)
parser = argparse.ArgumentParser(description='NeuralUCB')
parser.add_argument('--size', default=100, type=int, help='number of rounds')
# type=int so a value given on the command line has the same type as the
# default (without it, "--number_tasks 3" would yield the string "3").
parser.add_argument('--number_tasks', default=2, type=int, help='number of tasks')
parser.add_argument('--no_runs', default=5, type=int, help='how many independent runs')

# def get_reward(deployment,cat,prompt,task,all_rewards_sum,all_rewards_diag,summary):
    
#     #call Azure API to get final reward
    
#     return final_eval(deployment, cat, prompt,task,all_rewards_sum,all_rewards_diag,summary)
    

# Load the cached experiment inputs from the working directory.
# NOTE: pickle.load can execute arbitrary code while deserializing; only load
# these files if they come from a trusted source.
with open('diagnoses_100.pkl', 'rb') as diagnoses_file:
    diagnoses = pickle.load(diagnoses_file)

with open('input_reports_100.pkl', 'rb') as reports_file:
    input_reports = pickle.load(reports_file)

# NOTE(review): new_labels and no_tasks are not read elsewhere in this script.
new_labels = diagnoses

# parse_known_args() returns unrecognised argv entries in `unknown` instead of
# exiting with an error.
args, unknown = parser.parse_known_args()
num_rounds, no_runs, no_tasks = args.size, args.no_runs, args.number_tasks


#input_reports,labels = input_maker('rand',"med",0)[0:args.size]
# Diagnosis-stage arms: arm key -> (name passed through as `dep[0]`, presumably
# the Azure deployment/model name — TODO confirm, system prompt). The main loop
# picks the arm by its insertion-order index, so do not reorder these entries.
# NOTE(review): the system prompts are NOT byte-identical across arms:
#   - "finetune_med" and "finetune_tele" contain "coronary artery disease,,osteoarthritis"
#     (an empty entry in the allowed-label list);
#   - "finetune_med_new" and "llama" list "hypoxia" before "vertebral/basilar stenosis",
#     whereas the other arms list it after.
# Confirm these differences are intentional before comparing arms against each other.
deployments_1 = {"base" : ("gpt-35-turbo","You will take the role of an medical agent whose primary goal is to give diagnosis based on medical reports for experimentation and research purposes only. You need to make inferences based on the provided report to make diagnosis predictions. Output at most 2 diagnoses. If you identify multiple diagnosis output them in a comma separated format like heart failure,colon cancer. Your given diagnoses for the patient must be one of the following: diabetes mellitus,huntington's disease,sepsis,encephalopathy,pulmonary embolism,pulmonary edema,tamponade,mitral stenosis,congestive heart failure,chronic obstructive pulmonary disease,abdominal thoracic aneurysm,neurosarcoidosis,renal failure,svc syndrome,urosepsis,acute myocardial infarction,acute coronary syndrome,carotid stenosis,aortic stenosis,coronary artery disease,osteoarthritis,aortic insufficiency,unstable angina/cath,hyperlipidemia,syncope,complete heart block,intravascular coagulation,septic shock,hepatic failure,pneumonia,pancreatitis,anemia,catheter tip infection,coma,urinary tract infection,wound infection,cerebral artery infarction,hyponatremia,cardiomyopathy,vertebral/basilar stenosis,hypoxia,sick sinus syndrome,pulmonary congestion,aseptic meningitis,neutropenia,cellulitis,cirrhosis of liver,liver failure,pericardial effusion,aortic valve dysfunction,venous thrombosis,respiratory failure,benzodiazepine overdose,vena cava obstruction,valvular heart disease,v-tach,aortic dissection,opiate intoxication"), 
                 "finetune_med" : ("Med","You will take the role of an medical agent whose primary goal is to give diagnosis based on medical reports for experimentation and research purposes only. You need to make inferences based on the provided report to make diagnosis predictions. Output at most 2 diagnoses. If you identify multiple diagnosis output them in a comma separated format like heart failure,colon cancer. Your given diagnoses for the patient must be one of the following: diabetes mellitus,huntington's disease,sepsis,encephalopathy,pulmonary embolism,pulmonary edema,tamponade,mitral stenosis,congestive heart failure,chronic obstructive pulmonary disease,abdominal thoracic aneurysm,neurosarcoidosis,renal failure,svc syndrome,urosepsis,acute myocardial infarction,acute coronary syndrome,carotid stenosis,aortic stenosis,coronary artery disease,,osteoarthritis,aortic insufficiency,unstable angina/cath,hyperlipidemia,syncope,complete heart block,intravascular coagulation,septic shock,hepatic failure,pneumonia,pancreatitis,anemia,catheter tip infection,coma,urinary tract infection,wound infection,cerebral artery infarction,hyponatremia,cardiomyopathy,vertebral/basilar stenosis,hypoxia,sick sinus syndrome,pulmonary congestion,aseptic meningitis,neutropenia,cellulitis,cirrhosis of liver,liver failure,pericardial effusion,aortic valve dysfunction,venous thrombosis,respiratory failure,benzodiazepine overdose,vena cava obstruction,valvular heart disease,v-tach,aortic dissection,opiate intoxication"),
                 "finetune_tele" : ("Tele","You will take the role of an medical agent whose primary goal is to give diagnosis based on medical reports for experimentation and research purposes only. You need to make inferences based on the provided report to make diagnosis predictions. Output at most 2 diagnoses. If you identify multiple diagnosis output them in a comma separated format like heart failure,colon cancer. Your given diagnoses for the patient must be one of the following: diabetes mellitus,huntington's disease,sepsis,encephalopathy,pulmonary embolism,pulmonary edema,tamponade,mitral stenosis,congestive heart failure,chronic obstructive pulmonary disease,abdominal thoracic aneurysm,neurosarcoidosis,renal failure,svc syndrome,urosepsis,acute myocardial infarction,acute coronary syndrome,carotid stenosis,aortic stenosis,coronary artery disease,,osteoarthritis,aortic insufficiency,unstable angina/cath,hyperlipidemia,syncope,complete heart block,intravascular coagulation,septic shock,hepatic failure,pneumonia,pancreatitis,anemia,catheter tip infection,coma,urinary tract infection,wound infection,cerebral artery infarction,hyponatremia,cardiomyopathy,vertebral/basilar stenosis,hypoxia,sick sinus syndrome,pulmonary congestion,aseptic meningitis,neutropenia,cellulitis,cirrhosis of liver,liver failure,pericardial effusion,aortic valve dysfunction,venous thrombosis,respiratory failure,benzodiazepine overdose,vena cava obstruction,valvular heart disease,v-tach,aortic dissection,opiate intoxication"), 
                 "finetune_med_new": ("Med_New","You will take the role of an medical agent whose primary goal is to give diagnosis based on medical reports for experimentation and research purposes only. You need to make inferences based on the provided report to make diagnosis predictions. Output at most 2 diagnoses. If you identify multiple diagnosis output them in a comma separated format like heart failure,colon cancer. Your given diagnoses for the patient must be one of the following: diabetes mellitus,huntington's disease,sepsis,encephalopathy,pulmonary embolism,pulmonary edema,tamponade,mitral stenosis,congestive heart failure,chronic obstructive pulmonary disease,abdominal thoracic aneurysm,neurosarcoidosis,renal failure,svc syndrome,urosepsis,acute myocardial infarction,acute coronary syndrome,carotid stenosis,aortic stenosis,coronary artery disease,osteoarthritis,aortic insufficiency,unstable angina/cath,hyperlipidemia,syncope,complete heart block,intravascular coagulation,septic shock,hepatic failure,pneumonia,pancreatitis,anemia,catheter tip infection,coma,urinary tract infection,wound infection,cerebral artery infarction,hyponatremia,cardiomyopathy,hypoxia,vertebral/basilar stenosis,sick sinus syndrome,pulmonary congestion,aseptic meningitis,neutropenia,cellulitis,cirrhosis of liver,liver failure,pericardial effusion,aortic valve dysfunction,venous thrombosis,respiratory failure,benzodiazepine overdose,vena cava obstruction,valvular heart disease,v-tach,aortic dissection,opiate intoxication"),
                 "llama": ("llama","You will take the role of an medical agent whose primary goal is to give diagnosis based on medical reports for experimentation and research purposes only. You need to make inferences based on the provided report to make diagnosis predictions. Output at most 2 diagnoses. If you identify multiple diagnosis output them in a comma separated format like heart failure,colon cancer. Your given diagnoses for the patient must be one of the following: diabetes mellitus,huntington's disease,sepsis,encephalopathy,pulmonary embolism,pulmonary edema,tamponade,mitral stenosis,congestive heart failure,chronic obstructive pulmonary disease,abdominal thoracic aneurysm,neurosarcoidosis,renal failure,svc syndrome,urosepsis,acute myocardial infarction,acute coronary syndrome,carotid stenosis,aortic stenosis,coronary artery disease,osteoarthritis,aortic insufficiency,unstable angina/cath,hyperlipidemia,syncope,complete heart block,intravascular coagulation,septic shock,hepatic failure,pneumonia,pancreatitis,anemia,catheter tip infection,coma,urinary tract infection,wound infection,cerebral artery infarction,hyponatremia,cardiomyopathy,hypoxia,vertebral/basilar stenosis,sick sinus syndrome,pulmonary congestion,aseptic meningitis,neutropenia,cellulitis,cirrhosis of liver,liver failure,pericardial effusion,aortic valve dysfunction,venous thrombosis,respiratory failure,benzodiazepine overdose,vena cava obstruction,valvular heart disease,v-tach,aortic dissection,opiate intoxication")}

#deployments_0 = {"base" : ("gpt-35-turbo","You are to summarize an inputted medical report, this summary will be used for research purposes only."), "assistants" : ("Assistant","You are to summarize an inputted medical report, this summary will be used for research purposes only."),"finetune_med" : ("Med","You are to summarize an inputted medical report, this summary will be used for research purposes only."),"finetune_tele" : ("Tele","You are to summarize an inputted medical report, this summary will be used for research purposes only."),"llama" : ("llama","You are to summarize an inputted medical report, this summary will be used for research purposes only."),}

#total_len = len(summary_description_array)+len(diagnosis_description_array)

# NOTE(review): emb_size, deploy and cat are not read anywhere else in this
# script — confirm whether they are still needed.
emb_size = 384
deploy = [deployments_1]
cat = ''

# Arm key -> underlying LLM family; used to build `arm_encoders`.
# ("assistants" has no matching entry in deployments_1.)
arm_to_llm = dict(
    base="gpt-3.5-turbo",
    assistants="gpt-3.5-turbo",
    finetune_med="gpt-4",
    finetune_tele="gpt-4",
    finetune_med_new="gpt-4",
    llama="llama-13b",
)

# Per-arm measurements for the diagnosis stage (one fresh list per arm):
#   input_lengths_diag: {arm: [prompt + output token counts]}
#   latencies_diag:     {arm: [call latencies in seconds]}
input_lengths_diag, latencies_diag = (
    {arm_key: [] for arm_key in deployments_1} for _ in range(2)
)


        
# Turn the raw pickled reports into diagnoser prompts ("rand" mode).
dataset = "medical"
input_reports = input_maker("rand", dataset, input_reports)

import tiktoken
from transformers import AutoConfig, AutoTokenizer
from sentence_transformers import SentenceTransformer

# NOTE(review): inp_model is not referenced anywhere else in this script, yet
# loading it costs a model download and memory on every run — confirm it is
# still needed.
inp_model = SentenceTransformer("paraphrase-MiniLM-L6-v2")

import random

# One tiktoken encoding per OpenAI model family.
openai_models = {"gpt-3.5-turbo", "gpt-4"}
encodings = {
    model_name: tiktoken.encoding_for_model(model_name)
    for model_name in openai_models
}
# Every token count in this script (prompt and output, all arms) uses this
# single encoder, so lengths are comparable across arms.
unified_encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")

# HuggingFace tokenizer for the non-OpenAI (llama) arm.
llama_tok = AutoTokenizer.from_pretrained("openlm-research/open_llama_13b")
import time

# Arm key -> tokenizer: the tiktoken encoding when the arm's LLM is an OpenAI
# model, otherwise the llama tokenizer.
# NOTE(review): arm_encoders is not read elsewhere in this script.
arm_encoders = {
    arm_key: encodings[llm_name] if llm_name in encodings else llama_tok
    for arm_key, llm_name in arm_to_llm.items()
}

# How many times each diagnosis arm is selected, accumulated over all runs.
diagnosis_selection_counts = {arm: 0 for arm in deployments_1}

# Only the diagnosis stage is benchmarked here (the summarisation stage and
# deployments_0 are disabled). The diagnosis arm is fixed for the whole
# experiment — index 3 in deployments_1's insertion order ("finetune_med_new") —
# so it is resolved once instead of on every round.
models = list(deployments_1.keys())
arm_idx = 3
arm = models[arm_idx]
dep = deployments_1[arm]

# Fail before any API call if there are not enough reports for the requested
# number of rounds. Otherwise input_reports[t] raises IndexError mid-experiment,
# and every measurement collected so far is lost, because results are only
# saved after the loop finishes.
if num_rounds > len(input_reports):
    raise ValueError(
        f"--size={num_rounds} exceeds the {len(input_reports)} available input reports"
    )

for run in range(no_runs):
    print(f"\n===== Starting run {run+1}/{no_runs} =====")
    for t in range(num_rounds):
        print(f"Starting round {t}")
        diagnosis_selection_counts[arm] += 1
        print(f"Chosen model for diag: {arm}")
        fin_prompt = input_reports[t]
        print(f"Input to diagnoser: {fin_prompt}")

        # Prompt tokens, counted with the same shared encoder as the output.
        in_len = len(unified_encoder.encode(fin_prompt))

        # No reward history and an empty summary are passed to the evaluator.
        _, latency, out_len = timed_final_eval(dep, arm, fin_prompt, "diagnosis", [], [], '')
        total_len = in_len + out_len

        # input_lengths_diag stores prompt + output tokens per call.
        input_lengths_diag[arm].append(total_len)
        latencies_diag[arm].append(latency)



import pandas as pd

import pickle
import numpy as np


print("\nDiagnosis model selection counts:")
for model, count in diagnosis_selection_counts.items():
    print(f"{model}: {count}")

# Persist the collected statistics. `with` guarantees each file is flushed and
# closed deterministically instead of leaving the handle to the garbage collector.
_result_files = {
    "input_lengths_diag_med_III_1.pkl": input_lengths_diag,
    "latencies_diag_med_III_1.pkl": latencies_diag,
    "diagnosis_selection_counts_med_III_1.pkl": diagnosis_selection_counts,
}
for _path, _payload in _result_files.items():
    with open(_path, "wb") as _fh:
        pickle.dump(_payload, _fh)


print("\nAll runs complete. Token/latency stats and selection counts saved.")








