# -*- coding: utf-8 -*-
"""
Created on Tue Jun 24 17:56:38 2025

@author: baran
"""

# -*- coding: utf-8 -*-
"""
Created on Mon Jun 23 14:35:39 2025

@author: baran
"""

# -*- coding: utf-8 -*-
"""
Created on Mon Jun 23 09:47:41 2025

@author: baran
"""

# -*- coding: utf-8 -*-
"""
Created on Sat May 10 17:42:55 2025

@author: baran
"""

# -*- coding: utf-8 -*-
"""
Created on Sun May  4 21:10:52 2025

@author: baran
"""
from azure.core.exceptions import HttpResponseError
import time

import numpy as np
from prompt_maker import input_maker
#from optimal_rand_tele import opt_eval
#from sum_call import get_summary
from optimal_rand_tele import opt_eval    as _opt_eval
from sum_call            import get_summary as _get_summary
import argparse
import pickle
from final_rand_med import final_eval

def retry_after_seconds(error, default=1):
    """Return how long to back off after an Azure throttling error.

    Reads the ``Retry-After`` header from ``error.response`` and returns it
    when it is a finite, non-negative number of seconds (integer or decimal).
    A missing response or header, an HTTP-date value, or any other
    unparsable value falls back to *default* instead of raising. A malformed
    header therefore cannot abort the retry loop.
    """
    response = getattr(error, "response", None)
    headers = getattr(response, "headers", None) or {}
    raw = headers.get("Retry-After")
    if raw is None:
        return default
    try:
        delay = float(raw)
    except (TypeError, ValueError):
        return default
    # Rejects negatives, inf and NaN (NaN fails every comparison).
    if not 0.0 <= delay < float("inf"):
        return default
    return delay


def azure_retry(func, max_retries=5):
    """Wrap *func* so Azure 429 (rate-limit) errors are retried.

    Each successful call is followed by a 0.2 s pause to avoid bursting the
    API. On ``HttpResponseError`` with status 429 the wrapper waits for the
    server's ``Retry-After`` delay (1 s when absent or unusable) and tries
    again. It makes at most *max_retries* calls in total. Any other
    ``HttpResponseError`` is re-raised unchanged.

    Raises:
        RuntimeError: when every attempt was throttled. The last 429 error
            is chained as ``__cause__``.
    """
    import functools

    name = getattr(func, "__name__", repr(func))

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        for attempt in range(1, max_retries + 1):
            try:
                result = func(*args, **kwargs)
            except HttpResponseError as e:
                if e.status_code != 429:
                    # non-rate-limit errors bubble up
                    raise
                if attempt == max_retries:
                    # No further attempt follows, so don't sleep first.
                    raise RuntimeError(f"Azure retries exhausted for {name}") from e
                # honor Retry-After if provided, else default to 1s
                retry_after = retry_after_seconds(e)
                print(f"[Azure 429] retry #{attempt}/{max_retries} after {retry_after}s")
                time.sleep(retry_after)
            else:
                # throttle to avoid bursting the API
                time.sleep(0.2)
                return result
        # Only reachable when max_retries < 1 (func was never called).
        raise RuntimeError(f"Azure retries exhausted for {name}")

    return wrapper
# Rate-limit-aware wrappers around the imported Azure-backed helpers.
# get_regret below calls the wrapped `opt_eval`, not `_opt_eval`.
opt_eval = azure_retry(_opt_eval)
get_summary = azure_retry(_get_summary)
def get_regret(deployments, prompt, task, selected, avg_array, t,
               all_rewards_sum, all_rewards_diag, labels, dataset):
    """Score one round for the *selected* arm by delegating to ``opt_eval``.

    All arguments are forwarded positionally, in the same order, to the
    retry-wrapped ``opt_eval``. Its result is returned as-is; the main loop
    unpacks it as
    ``(regret, reward, out_len, avg_array, all_rewards_sum, all_rewards_diag)``.
    """
    forwarded = (deployments, prompt, task, selected, avg_array, t,
                 all_rewards_sum, all_rewards_diag, labels, dataset)
    return opt_eval(*forwarded)
# Command-line configuration for the experiment.
parser = argparse.ArgumentParser(description='NeuralUCB')
parser.add_argument('--size', default=100, type=int, help='number of rounds')
# type=int so a value given on the command line is an int, like the default,
# rather than a str.
parser.add_argument('--number_tasks', default=2, type=int,
                    help='number of tasks (not read elsewhere in this script)')
parser.add_argument('--no_runs',   default=5,   type=int, help='how many independent runs')

def get_reward(deployment, cat, prompt, task, all_rewards_sum, all_rewards_diag, summary):
    """Return the final reward for one diagnosis, as computed by ``final_eval``.

    Every argument is forwarded positionally to ``final_eval``, which
    (per the original note) calls the Azure API. Unlike ``opt_eval``, it is
    not wrapped with ``azure_retry`` here.
    """
    forwarded = (deployment, cat, prompt, task,
                 all_rewards_sum, all_rewards_diag, summary)
    return final_eval(*forwarded)

def _load_pickle(path):
    """Unpickle and return the single object stored at *path*.

    pickle.load can execute arbitrary code, so only use this on trusted,
    locally produced files.
    """
    with open(path, 'rb') as file:
        return pickle.load(file)


# Evaluation data (presumably 100 cases, going by the file names).
diagnoses = _load_pickle('diagnoses_100.pkl')
input_reports = _load_pickle('input_reports_100.pkl')

# Ground-truth labels later handed to opt_eval through get_regret.
new_labels = diagnoses

args, unknown = parser.parse_known_args()

# Run configuration taken from the command line.
num_rounds = args.size
no_runs = args.no_runs
no_tasks = args.number_tasks  # not read elsewhere in this script

# Cross-run accumulators; the main loop appends one entry per run.
all_regrets = []             # per run: cumulative regret after each round
all_rewards = []             # per run: cumulative reward after each round
all_costs = []               # per run: cumulative cost after each round
all_costs_summarizer = []
all_plays = []               # per run: play count per arm
all_avg_arrays = []          # per run: the avg_array dict returned via get_regret
all_avg_summary_arrays = []
all_runs_diag_stats = []

#input_reports,labels = input_maker('rand',"med",0)[0:args.size]
# deployments_1 = {"base" : ("gpt-35-turbo","You will take the role of an medical agent whose primary goal is to give diagnosis based on medical reports for experimentation and research purposes only. You need to make inferences based on the provided report to make diagnosis predictions. Output at most 2 diagnoses. If you identify multiple diagnosis output them in a comma separated format like heart failure,colon cancer. Your given diagnoses for the patient must be one of the following: diabetes mellitus,huntington's disease,sepsis,encephalopathy,pulmonary embolism,pulmonary edema,tamponade,mitral stenosis,congestive heart failure,chronic obstructive pulmonary disease,abdominal thoracic aneurysm,neurosarcoidosis,renal failure,svc syndrome,urosepsis,acute myocardial infarction,acute coronary syndrome,carotid stenosis,aortic stenosis,coronary artery disease,osteoarthritis,aortic insufficiency,unstable angina/cath,hyperlipidemia,syncope,complete heart block,intravascular coagulation,septic shock,hepatic failure,pneumonia,pancreatitis,anemia,catheter tip infection,coma,urinary tract infection,wound infection,cerebral artery infarction,hyponatremia,cardiomyopathy,vertebral/basilar stenosis,hypoxia,sick sinus syndrome,pulmonary congestion,aseptic meningitis,neutropenia,cellulitis,cirrhosis of liver,liver failure,pericardial effusion,aortic valve dysfunction,venous thrombosis,respiratory failure,benzodiazepine overdose,vena cava obstruction,valvular heart disease,v-tach,aortic dissection,opiate intoxication"), 
#                  "finetune_med" : ("Med","You will take the role of an medical agent whose primary goal is to give diagnosis based on medical reports for experimentation and research purposes only. You need to make inferences based on the provided report to make diagnosis predictions. Output at most 2 diagnoses. If you identify multiple diagnosis output them in a comma separated format like heart failure,colon cancer. Your given diagnoses for the patient must be one of the following: diabetes mellitus,huntington's disease,sepsis,encephalopathy,pulmonary embolism,pulmonary edema,tamponade,mitral stenosis,congestive heart failure,chronic obstructive pulmonary disease,abdominal thoracic aneurysm,neurosarcoidosis,renal failure,svc syndrome,urosepsis,acute myocardial infarction,acute coronary syndrome,carotid stenosis,aortic stenosis,coronary artery disease,,osteoarthritis,aortic insufficiency,unstable angina/cath,hyperlipidemia,syncope,complete heart block,intravascular coagulation,septic shock,hepatic failure,pneumonia,pancreatitis,anemia,catheter tip infection,coma,urinary tract infection,wound infection,cerebral artery infarction,hyponatremia,cardiomyopathy,vertebral/basilar stenosis,hypoxia,sick sinus syndrome,pulmonary congestion,aseptic meningitis,neutropenia,cellulitis,cirrhosis of liver,liver failure,pericardial effusion,aortic valve dysfunction,venous thrombosis,respiratory failure,benzodiazepine overdose,vena cava obstruction,valvular heart disease,v-tach,aortic dissection,opiate intoxication"),
#                  "finetune_tele" : ("Tele","You will take the role of an medical agent whose primary goal is to give diagnosis based on medical reports for experimentation and research purposes only. You need to make inferences based on the provided report to make diagnosis predictions. Output at most 2 diagnoses. If you identify multiple diagnosis output them in a comma separated format like heart failure,colon cancer. Your given diagnoses for the patient must be one of the following: diabetes mellitus,huntington's disease,sepsis,encephalopathy,pulmonary embolism,pulmonary edema,tamponade,mitral stenosis,congestive heart failure,chronic obstructive pulmonary disease,abdominal thoracic aneurysm,neurosarcoidosis,renal failure,svc syndrome,urosepsis,acute myocardial infarction,acute coronary syndrome,carotid stenosis,aortic stenosis,coronary artery disease,,osteoarthritis,aortic insufficiency,unstable angina/cath,hyperlipidemia,syncope,complete heart block,intravascular coagulation,septic shock,hepatic failure,pneumonia,pancreatitis,anemia,catheter tip infection,coma,urinary tract infection,wound infection,cerebral artery infarction,hyponatremia,cardiomyopathy,vertebral/basilar stenosis,hypoxia,sick sinus syndrome,pulmonary congestion,aseptic meningitis,neutropenia,cellulitis,cirrhosis of liver,liver failure,pericardial effusion,aortic valve dysfunction,venous thrombosis,respiratory failure,benzodiazepine overdose,vena cava obstruction,valvular heart disease,v-tach,aortic dissection,opiate intoxication"), 
#                  "finetune_med_new": ("Med_New","You will take the role of an medical agent whose primary goal is to give diagnosis based on medical reports for experimentation and research purposes only. You need to make inferences based on the provided report to make diagnosis predictions. Output at most 2 diagnoses. If you identify multiple diagnosis output them in a comma separated format like heart failure,colon cancer. Your given diagnoses for the patient must be one of the following: diabetes mellitus,huntington's disease,sepsis,encephalopathy,pulmonary embolism,pulmonary edema,tamponade,mitral stenosis,congestive heart failure,chronic obstructive pulmonary disease,abdominal thoracic aneurysm,neurosarcoidosis,renal failure,svc syndrome,urosepsis,acute myocardial infarction,acute coronary syndrome,carotid stenosis,aortic stenosis,coronary artery disease,osteoarthritis,aortic insufficiency,unstable angina/cath,hyperlipidemia,syncope,complete heart block,intravascular coagulation,septic shock,hepatic failure,pneumonia,pancreatitis,anemia,catheter tip infection,coma,urinary tract infection,wound infection,cerebral artery infarction,hyponatremia,cardiomyopathy,hypoxia,vertebral/basilar stenosis,sick sinus syndrome,pulmonary congestion,aseptic meningitis,neutropenia,cellulitis,cirrhosis of liver,liver failure,pericardial effusion,aortic valve dysfunction,venous thrombosis,respiratory failure,benzodiazepine overdose,vena cava obstruction,valvular heart disease,v-tach,aortic dissection,opiate intoxication"),
#                  "llama": ("llama","You will take the role of an medical agent whose primary goal is to give diagnosis based on medical reports for experimentation and research purposes only. You need to make inferences based on the provided report to make diagnosis predictions. Output at most 2 diagnoses. If you identify multiple diagnosis output them in a comma separated format like heart failure,colon cancer. Your given diagnoses for the patient must be one of the following: diabetes mellitus,huntington's disease,sepsis,encephalopathy,pulmonary embolism,pulmonary edema,tamponade,mitral stenosis,congestive heart failure,chronic obstructive pulmonary disease,abdominal thoracic aneurysm,neurosarcoidosis,renal failure,svc syndrome,urosepsis,acute myocardial infarction,acute coronary syndrome,carotid stenosis,aortic stenosis,coronary artery disease,osteoarthritis,aortic insufficiency,unstable angina/cath,hyperlipidemia,syncope,complete heart block,intravascular coagulation,septic shock,hepatic failure,pneumonia,pancreatitis,anemia,catheter tip infection,coma,urinary tract infection,wound infection,cerebral artery infarction,hyponatremia,cardiomyopathy,hypoxia,vertebral/basilar stenosis,sick sinus syndrome,pulmonary congestion,aseptic meningitis,neutropenia,cellulitis,cirrhosis of liver,liver failure,pericardial effusion,aortic valve dysfunction,venous thrombosis,respiratory failure,benzodiazepine overdose,vena cava obstruction,valvular heart disease,v-tach,aortic dissection,opiate intoxication")}

# Diagnoser arms: arm key -> (deployment/model name, instruction prompt).
# The first element is presumably the deployment name opt_eval sends requests
# to -- confirm against optimal_rand_tele.
# NOTE(review): the prompts are not identical across arms:
#   * "finetune_med" and "finetune_tele" contain "coronary artery disease,,osteoarthritis"
#     (a doubled comma, i.e. an empty list entry);
#   * "finetune_med_new" and "llama" list "hypoxia" before "vertebral/basilar stenosis".
# This may be deliberate (e.g. matching prompts used for fine-tuning). Confirm
# before normalising, because any change alters what each model receives.
deployments_1 = {"base" : ("gpt-35-turbo","You will take the role of an medical agent whose primary goal is to give diagnosis based on medical reports for experimentation and research purposes only. You need to make inferences based on the provided report to make diagnosis predictions. Output at most 2 diagnoses. If you identify multiple diagnosis output them in a comma separated format like heart failure,colon cancer. Your given diagnoses for the patient must be one of the following: diabetes mellitus,huntington's disease,sepsis,encephalopathy,pulmonary embolism,pulmonary edema,tamponade,mitral stenosis,congestive heart failure,chronic obstructive pulmonary disease,abdominal thoracic aneurysm,neurosarcoidosis,renal failure,svc syndrome,urosepsis,acute myocardial infarction,acute coronary syndrome,carotid stenosis,aortic stenosis,coronary artery disease,osteoarthritis,aortic insufficiency,unstable angina/cath,hyperlipidemia,syncope,complete heart block,intravascular coagulation,septic shock,hepatic failure,pneumonia,pancreatitis,anemia,catheter tip infection,coma,urinary tract infection,wound infection,cerebral artery infarction,hyponatremia,cardiomyopathy,vertebral/basilar stenosis,hypoxia,sick sinus syndrome,pulmonary congestion,aseptic meningitis,neutropenia,cellulitis,cirrhosis of liver,liver failure,pericardial effusion,aortic valve dysfunction,venous thrombosis,respiratory failure,benzodiazepine overdose,vena cava obstruction,valvular heart disease,v-tach,aortic dissection,opiate intoxication"), 
                 "finetune_med" : ("Med","You will take the role of an medical agent whose primary goal is to give diagnosis based on medical reports for experimentation and research purposes only. You need to make inferences based on the provided report to make diagnosis predictions. Output at most 2 diagnoses. If you identify multiple diagnosis output them in a comma separated format like heart failure,colon cancer. Your given diagnoses for the patient must be one of the following: diabetes mellitus,huntington's disease,sepsis,encephalopathy,pulmonary embolism,pulmonary edema,tamponade,mitral stenosis,congestive heart failure,chronic obstructive pulmonary disease,abdominal thoracic aneurysm,neurosarcoidosis,renal failure,svc syndrome,urosepsis,acute myocardial infarction,acute coronary syndrome,carotid stenosis,aortic stenosis,coronary artery disease,,osteoarthritis,aortic insufficiency,unstable angina/cath,hyperlipidemia,syncope,complete heart block,intravascular coagulation,septic shock,hepatic failure,pneumonia,pancreatitis,anemia,catheter tip infection,coma,urinary tract infection,wound infection,cerebral artery infarction,hyponatremia,cardiomyopathy,vertebral/basilar stenosis,hypoxia,sick sinus syndrome,pulmonary congestion,aseptic meningitis,neutropenia,cellulitis,cirrhosis of liver,liver failure,pericardial effusion,aortic valve dysfunction,venous thrombosis,respiratory failure,benzodiazepine overdose,vena cava obstruction,valvular heart disease,v-tach,aortic dissection,opiate intoxication"),
                 "finetune_tele" : ("Tele","You will take the role of an medical agent whose primary goal is to give diagnosis based on medical reports for experimentation and research purposes only. You need to make inferences based on the provided report to make diagnosis predictions. Output at most 2 diagnoses. If you identify multiple diagnosis output them in a comma separated format like heart failure,colon cancer. Your given diagnoses for the patient must be one of the following: diabetes mellitus,huntington's disease,sepsis,encephalopathy,pulmonary embolism,pulmonary edema,tamponade,mitral stenosis,congestive heart failure,chronic obstructive pulmonary disease,abdominal thoracic aneurysm,neurosarcoidosis,renal failure,svc syndrome,urosepsis,acute myocardial infarction,acute coronary syndrome,carotid stenosis,aortic stenosis,coronary artery disease,,osteoarthritis,aortic insufficiency,unstable angina/cath,hyperlipidemia,syncope,complete heart block,intravascular coagulation,septic shock,hepatic failure,pneumonia,pancreatitis,anemia,catheter tip infection,coma,urinary tract infection,wound infection,cerebral artery infarction,hyponatremia,cardiomyopathy,vertebral/basilar stenosis,hypoxia,sick sinus syndrome,pulmonary congestion,aseptic meningitis,neutropenia,cellulitis,cirrhosis of liver,liver failure,pericardial effusion,aortic valve dysfunction,venous thrombosis,respiratory failure,benzodiazepine overdose,vena cava obstruction,valvular heart disease,v-tach,aortic dissection,opiate intoxication"), 
                 "finetune_med_new": ("Med_New","You will take the role of an medical agent whose primary goal is to give diagnosis based on medical reports for experimentation and research purposes only. You need to make inferences based on the provided report to make diagnosis predictions. Output at most 2 diagnoses. If you identify multiple diagnosis output them in a comma separated format like heart failure,colon cancer. Your given diagnoses for the patient must be one of the following: diabetes mellitus,huntington's disease,sepsis,encephalopathy,pulmonary embolism,pulmonary edema,tamponade,mitral stenosis,congestive heart failure,chronic obstructive pulmonary disease,abdominal thoracic aneurysm,neurosarcoidosis,renal failure,svc syndrome,urosepsis,acute myocardial infarction,acute coronary syndrome,carotid stenosis,aortic stenosis,coronary artery disease,osteoarthritis,aortic insufficiency,unstable angina/cath,hyperlipidemia,syncope,complete heart block,intravascular coagulation,septic shock,hepatic failure,pneumonia,pancreatitis,anemia,catheter tip infection,coma,urinary tract infection,wound infection,cerebral artery infarction,hyponatremia,cardiomyopathy,hypoxia,vertebral/basilar stenosis,sick sinus syndrome,pulmonary congestion,aseptic meningitis,neutropenia,cellulitis,cirrhosis of liver,liver failure,pericardial effusion,aortic valve dysfunction,venous thrombosis,respiratory failure,benzodiazepine overdose,vena cava obstruction,valvular heart disease,v-tach,aortic dissection,opiate intoxication"),
                 "llama": ("llama","You will take the role of an medical agent whose primary goal is to give diagnosis based on medical reports for experimentation and research purposes only. You need to make inferences based on the provided report to make diagnosis predictions. Output at most 2 diagnoses. If you identify multiple diagnosis output them in a comma separated format like heart failure,colon cancer. Your given diagnoses for the patient must be one of the following: diabetes mellitus,huntington's disease,sepsis,encephalopathy,pulmonary embolism,pulmonary edema,tamponade,mitral stenosis,congestive heart failure,chronic obstructive pulmonary disease,abdominal thoracic aneurysm,neurosarcoidosis,renal failure,svc syndrome,urosepsis,acute myocardial infarction,acute coronary syndrome,carotid stenosis,aortic stenosis,coronary artery disease,osteoarthritis,aortic insufficiency,unstable angina/cath,hyperlipidemia,syncope,complete heart block,intravascular coagulation,septic shock,hepatic failure,pneumonia,pancreatitis,anemia,catheter tip infection,coma,urinary tract infection,wound infection,cerebral artery infarction,hyponatremia,cardiomyopathy,hypoxia,vertebral/basilar stenosis,sick sinus syndrome,pulmonary congestion,aseptic meningitis,neutropenia,cellulitis,cirrhosis of liver,liver failure,pericardial effusion,aortic valve dysfunction,venous thrombosis,respiratory failure,benzodiazepine overdose,vena cava obstruction,valvular heart disease,v-tach,aortic dissection,opiate intoxication"),}


#deployments_0 = {"base" : ("gpt-35-turbo","You are to summarize an inputted medical report, this summary will be used for research purposes only."), "assistants" : ("Assistant","You are to summarize an inputted medical report, this summary will be used for research purposes only."),"finetune_med" : ("Med","You are to summarize an inputted medical report, this summary will be used for research purposes only."),"finetune_tele" : ("Tele","You are to summarize an inputted medical report, this summary will be used for research purposes only."),"llama" : ("llama","You are to summarize an inputted medical report, this summary will be used for research purposes only."),}

#total_len = len(summary_description_array)+len(diagnosis_description_array)

# emb_size: embedding width, presumably matching the sentence encoder loaded
#           later; it is not read elsewhere in this script.
# deploy:   list of deployment tables (currently only the diagnoser table).
# cat:      arm category placeholder, reassigned per round in the main loop.
emb_size, deploy, cat = 384, [deployments_1], ''

#input_reports = list(input_reports)
# One row per arm: (model whose tokenizer is used for token counting,
#                   input cost per token, output cost per token).
# Costs are presumably USD per token -- confirm against the pricing source.
# NOTE(review): the fine-tuned arms are labelled GPT-4 but carry a lower input
# price than GPT-3.5 Turbo; verify these figures.
_ARM_COST_TABLE = {
    "base"            : ("gpt-3.5-turbo", 0.0000005,  0.0000015),   # GPT-3.5 Turbo
    "assistants"      : ("gpt-3.5-turbo", 0.0000005,  0.0000015),   # GPT-3.5 Turbo
    "finetune_med"    : ("gpt-4",         0.00000025, 0.00001),     # GPT-4
    "finetune_tele"   : ("gpt-4",         0.00000025, 0.00001),     # GPT-4
    "finetune_med_new": ("gpt-4",         0.00000025, 0.00001),     # GPT-4
    "llama"           : ("llama-13b",     0.00000071, 0.00000071),  # Llama-13b
}

# Derived per-arm lookups (insertion order follows the table above).
arm_to_llm = {arm: row[0] for arm, row in _ARM_COST_TABLE.items()}
input_cost_per_token = {arm: row[1] for arm, row in _ARM_COST_TABLE.items()}
cost_per_token = {arm: row[2] for arm, row in _ARM_COST_TABLE.items()}

        
# Dataset tag, forwarded to input_maker here and to opt_eval via get_regret.
dataset = "medical"
# Build the per-round prompts from the raw reports; "rand" is passed as
# input_maker's mode argument (presumably random ordering -- confirm).
input_reports = input_maker("rand",dataset,input_reports)

import random
import tiktoken
from sentence_transformers import SentenceTransformer
from transformers import AutoConfig, AutoTokenizer

# NOTE(review): inp_model (like `random` and `AutoConfig`) is not used anywhere
# else in this script, yet loading it costs start-up time -- confirm it is needed.
inp_model = SentenceTransformer("paraphrase-MiniLM-L6-v2")
#avg_array = {"gpt-35-turbo":0,"Med":0,"Tele":0,"Med_New":0,"llama":0}

# --- Loop-invariant setup (previously rebuilt at the start of every run) ---

# Diagnoser arms, in the same order as deployments_1 and plays_no.
models = ["base","finetune_med","finetune_tele","finetune_med_new","llama"]
# Arms whose `cat` is "finetune" rather than their own key.
finetune_arms = {"finetune_med", "finetune_tele", "finetune_med_new"}
# This script always plays the same diagnoser (an index into `models`).
fixed_arm = 3

# Token counters: tiktoken encoders for the OpenAI models and a HuggingFace
# tokenizer for llama. Loaded once instead of once per run.
openai_models = {"gpt-3.5-turbo","gpt-4"}
encodings = { m: tiktoken.encoding_for_model(m) for m in openai_models }
llama_tok = AutoTokenizer.from_pretrained("openlm-research/open_llama_13b")
# Per arm: its model's tiktoken encoder if one was built, otherwise llama_tok.
arm_encoders = {mk: encodings.get(llm, llama_tok) for mk, llm in arm_to_llm.items()}


def _count_input_tokens(encoder, text):
    """Return how many prompt tokens *encoder* produces for *text*.

    HuggingFace tokenizers also expose ``.encode``, so a ``hasattr`` test
    cannot tell the two kinds apart. Dispatch on the tiktoken type instead,
    so HF tokenizers take the ``__call__`` path with ``truncation=True``.
    """
    if isinstance(encoder, tiktoken.Encoding):
        return len(encoder.encode(text))
    return len(encoder(text, truncation=True, padding=False)["input_ids"])


# Fail before spending any API calls, rather than with an IndexError part-way
# through a run.
if num_rounds > len(input_reports):
    raise ValueError(
        f"--size={num_rounds} exceeds the {len(input_reports)} available input reports"
    )

for run in range(no_runs):
    # Per-run state. opt_eval receives these lists/dict and returns updated ones.
    all_rewards_sum = []
    all_rewards_diag = []
    print(f"\n===== Starting run {run+1}/{no_runs} =====")
    regrets = []
    # NOTE(review): play counts start at 1 per arm, so the played arm ends at
    # num_rounds + 1 in the saved plays -- confirm this is intended.
    plays_no = np.ones(len(deployments_1))
    summ = 0      # cumulative regret
    rewards = 0   # cumulative reward
    costs = 0     # cumulative token cost
    rewards_list = []
    costs_list = []
    avg_array = {"gpt-35-turbo":0,"Med":0,"Tele":0,"Med_New":0,"llama":0}

    for t in range(num_rounds):
        fin_prompt = input_reports[t]
        #arm = random.randint(0,len(models)-1)
        arm = fixed_arm
        plays_no[arm] += 1
        arm_select = models[arm]
        task = 'diagnosis'
        print(f"Selected diagnoser: {arm_select}")
        cat = "finetune" if arm_select in finetune_arms else arm_select

        # Only the selected arm's prompt length enters the cost, so tokenize once.
        in_len = _count_input_tokens(arm_encoders[arm_select], fin_prompt)

        reg,reward,out_len,avg_array,all_rewards_sum,all_rewards_diag = get_regret(deployments_1,fin_prompt,task,arm_select,avg_array,t,all_rewards_sum,all_rewards_diag,new_labels,dataset)
        # NOTE(review): int() truncates fractional rewards. That is fine if
        # opt_eval returns 0/1, lossy otherwise -- confirm.
        rewards += int(reward)
        rewards_list.append(rewards)
        costs += input_cost_per_token[arm_select]*in_len + cost_per_token[arm_select]*out_len
        costs_list.append(costs)

        print(f"Reward (diagnoser): {reward}")
        print(f"Regret (diagnoser): {reg}")
        print(plays_no)
        summ += reg
        regrets.append(summ)

        if (t+1) % 5 == 0:
            print('{}: {:.3f}, {:.3f}'.format(t+1, summ, rewards))

    all_regrets.append(regrets)
    all_rewards.append(rewards_list)
    all_costs.append(costs_list)
    all_plays.append(plays_no.copy())
    all_avg_arrays.append(avg_array.copy())

import pandas as pd

# Stack the per-run histories collected by the main loop (one row per run).
regrets_arr = np.array(all_regrets)     # shape (no_runs, num_rounds)
rewards_arr = np.array(all_rewards)
costs_arr   = np.array(all_costs)
costs_summarizer_arr   = np.array(all_costs_summarizer)
plays_arr   = np.array(all_plays)       # shape (no_runs, num_arms)
avg_df      = pd.DataFrame(all_avg_arrays)  # columns=model names

# Across-run mean / standard deviation, per round (or per arm / model).
regrets_mean = regrets_arr.mean(axis=0)
regrets_std  = regrets_arr.std(axis=0)
rewards_mean = rewards_arr.mean(axis=0)
rewards_std  = rewards_arr.std(axis=0)
costs_mean   = costs_arr.mean(axis=0)
costs_std    = costs_arr.std(axis=0)
plays_mean = plays_arr.mean(axis=0)
plays_std  = plays_arr.std(axis=0)
# NOTE(review): DataFrame.std uses ddof=1, while the numpy .std calls above use
# ddof=0, so avg_std is not directly comparable to the other stds.
avg_mean = avg_df.mean(axis=0).to_dict()
avg_std  = avg_df.std(axis=0).to_dict()

import pickle

# Suffix shared by every result file of this experiment variant. Change it in
# one place to avoid overwriting another variant's pickles.
RESULTS_SUFFIX = "_med_rand_med_III_2"


def _dump_result(obj, stem):
    """Pickle *obj* to ``<stem><RESULTS_SUFFIX>.pkl``.

    The ``with`` block flushes and closes each file immediately, instead of
    leaving an anonymous file object open until garbage collection.
    """
    with open(f"{stem}{RESULTS_SUFFIX}.pkl", "wb") as fh:
        pickle.dump(obj, fh)


_dump_result(regrets_mean, "regrets_mean")
_dump_result(regrets_std,  "regrets_std")
_dump_result(rewards_mean, "rewards_mean")
_dump_result(rewards_std,  "rewards_std")
_dump_result(costs_mean,   "costs_mean")
_dump_result(costs_std,    "costs_std")
_dump_result(plays_mean,   "plays_mean")
_dump_result(plays_std,    "plays_std")
_dump_result(avg_mean,     "avg_accuracy_mean")
_dump_result(avg_std,      "avg_accuracy_std")

print(f"Final mean regret: {regrets_mean[-1]}")
print(f"Final mean reward: {rewards_mean[-1]}")
print(f"Final mean cost: {costs_mean[-1]}")
print(f"Final mean plays: {plays_mean}")
print(f"Final mean average array: {avg_mean}")

print("All runs complete. Summary pickles written.")