import numpy as np
from src.regrets.final_rand_med import final_eval
from src.prompts.prompt_maker import input_maker
from utils.helper import opt_eval, get_summary
import pickle
from transformers import AutoConfig, AutoTokenizer
from src.token_prediction.tok_length_predict import BertRegressionModel  
import torch
from cost_trainer import OnlineCostModelTrainer
import os
import glob
import time

# NOTE(review): start_time is not read in the reviewed portion of this file;
# presumably used for an end-of-run timing report further down.
start_time = time.time()


def get_optimal_super_arm_reward(deployments, prompt, task):
    """Return ``opt_eval`` evaluated on just ``(deployments, prompt, task)``.

    NOTE(review): ``get_regret`` calls ``opt_eval`` with ten positional
    arguments; whether this three-argument call is valid depends on the
    defaults of ``opt_eval`` in utils.helper — confirm before relying on it.
    """
    evaluated = opt_eval(deployments, prompt, task)
    return evaluated

def get_reward(deployment, cat, prompt, task, all_rewards_sum, all_rewards_debate, all_rewards_diag, summary):
    """Score one task output by delegating to ``final_eval``.

    Every argument is forwarded positionally and unchanged. A previous
    comment here said this path calls the Azure API — presumably inside
    ``final_eval``; not visible from this file.

    The summarizer branch of the main loop unpacks the result as
    ``(reward, out_len, all_rewards_sum, _, all_rewards_diag)``.
    """
    outcome = final_eval(
        deployment, cat, prompt, task,
        all_rewards_sum, all_rewards_debate, all_rewards_diag,
        summary,
    )
    return outcome
    

def get_regret(deployments, prompt, task, selected, avg_array, t, all_rewards_sum, all_rewards_diag, labels, dataset):
    """Compute the regret of the chosen diagnoser by delegating to ``opt_eval``.

    All ten arguments are forwarded positionally in the order received. The
    diagnosis branch of the main loop unpacks the result as
    ``(reg, reward, out_len, avg_array, all_rewards_sum, all_rewards_diag)``.
    """
    forwarded = (deployments, prompt, task, selected, avg_array, t,
                 all_rewards_sum, all_rewards_diag, labels, dataset)
    return opt_eval(*forwarded)


# Ground-truth diagnoses; exposed as `new_labels` and passed to get_regret as
# the labels in the main loop (presumably aligned index-for-index with
# input_reports — confirm).
# NOTE(review): pickle.load can execute arbitrary code — only load trusted local files.
with open('diagnoses_150.pkl', 'rb') as file: 
    diagnoses = pickle.load(file)

# Raw medical reports; the main loop reads input_reports[t] in round t.
with open('input_reports_150.pkl', 'rb') as file: 
    input_reports = pickle.load(file)

new_labels = diagnoses
        
# Config/tokenizer for the token-length regression model (BertRegressionModel),
# whose predictions feed the cost penalty used when selecting arms.
reg_model_name  = "bert-base-uncased"
reg_config      = AutoConfig.from_pretrained(reg_model_name)
reg_tokenizer   = AutoTokenizer.from_pretrained(reg_model_name)


# Bandit arm key -> underlying LLM name. The LLM name is located in
# orig_model_names to build the one-hot input of the token-length predictor.
arm_to_llm = dict(
    base="gpt-3.5-turbo",
    assistants="gpt-3.5-turbo",
    finetune_med="gpt-4",
    finetune_tele="gpt-4",
    finetune_med_new="gpt-4",
    llama="llama-13b",
)

# Price per *output* token for each arm; multiplied by predicted or actual
# output lengths when estimating and accounting cost.
cost_per_token = dict(
    base=4e-06,
    assistants=4e-06,
    finetune_med=1e-05,
    finetune_tele=1e-05,
    finetune_med_new=1e-05,
    llama=7.1e-07,
)

# Price per *input* (prompt) token for each arm; multiplied by prompt lengths.
input_cost_per_token = dict(
    base=1e-07,
    assistants=1e-07,
    finetune_med=2.5e-07,
    finetune_tele=2.5e-07,
    finetune_med_new=2.5e-07,
    llama=7.1e-07,
)



import json
# Ordered LLM names known to the token-length predictor; an arm's LLM (via
# arm_to_llm) is located in this list to build the predictor's one-hot input.
with open("model_names.json") as f:
    orig_model_names = json.load(f)   # e.g. ["RWKV-4-Raven-14B","alpaca-13b",…,"gpt-4",…,"gpt-3.5-turbo",…,"llama-2-7b-chat",…]
num_models = len(orig_model_names)    
device = "cuda" if torch.cuda.is_available() else "cpu"
# Token-length regressor conditioned on a num_models-wide model one-hot;
# weights are restored from best_length_model.pth and it is put in eval mode.
token_length_model = BertRegressionModel(reg_config, reg_model_name,hidden_dim=128,num_models=num_models).to(device)
token_length_model.load_state_dict(
    torch.load("best_length_model.pth", map_location=device)
)
token_length_model.eval()
# Wraps the same model instance; the main loop feeds it (predicted, actual)
# output lengths via add_observation and reads compute_metrics periodically.
# NOTE(review): lr/update_freq suggest it fine-tunes the model online — the
# behaviour lives in cost_trainer.OnlineCostModelTrainer; confirm there.
online_trainer = OnlineCostModelTrainer(
    model=token_length_model,
    tokenizer=reg_tokenizer,
    orig_model_names=orig_model_names,
    arm_to_llm=arm_to_llm,
    cost_per_token=cost_per_token,
    input_cost_per_token=input_cost_per_token,
    device=device,
    checkpoint_dir="cost_model_checkpoints_neulinucb_med",
    lr=1e-6,
    update_freq=5
)
# NOTE(review): not referenced in the reviewed portion of this file.
all_cost_model_summaries = []
# Bandit checkpoints are written every `checkpoint_frequency` rounds.
checkpoint_frequency = 20



    
import scipy as sp
import torch.nn as nn
import torch.optim as optim
from backpack import backpack, extend
from backpack.extensions import BatchGrad
import argparse
from src.embedding.embed_tele import get_context

def inv_sherman_morrison(u, A_inv):
    """Apply a rank-one Sherman-Morrison update to an inverse matrix, in place.

    Given ``A_inv`` (the inverse of some matrix ``A``) and a vector ``u``,
    overwrite ``A_inv`` with ``(A + u u^T)^{-1}`` and return the same array
    object. The formula uses ``outer(A_inv u, A_inv u)``, which is only
    correct when ``A_inv`` is symmetric.

    The denominator ``1 + u^T A_inv u`` is clamped to at least 1e-6 so the
    update never divides by zero or a negative number.
    """
    projected = np.dot(A_inv, u)
    scale = max(1 + np.dot(u.T, projected), 1e-6)
    A_inv -= np.outer(projected, projected) / scale
    return A_inv

emb_size = 384


class Network(nn.Module):
    """Two-layer MLP feature extractor: ``dim`` -> ``hidden_size`` -> ``emb_size``.

    The output width is read from the module-level ``emb_size`` when the
    network is constructed, not from ``dim``. A ReLU sits between the layers.
    """

    def __init__(self, dim, hidden_size=100):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_size)
        self.activate = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, emb_size)

    def forward(self, x):
        hidden = self.activate(self.fc1(x))
        return self.fc2(hidden)

class NeuralLinearUCB:
    """Neural-linear UCB bandit with a cost-penalised arm score.

    A shared ``Network`` maps each arm's context row to a feature vector.
    Each arm keeps ridge-regression statistics over those features:
    ``A_inv`` (inverse design matrix), ``b`` (reward-weighted feature sum)
    and ``theta`` (``A_inv @ b``, the per-arm linear head).

    Args:
        dim: Input width of ``Network`` and the size of every arm's
            ``theta``, ``b`` and ``A_inv``. NOTE(review): ``Network`` outputs
            ``emb_size`` features, so the linear statistics only line up when
            ``dim == emb_size`` (true for the callers in this file).
        lamdba: Weight decay for the SGD optimizer used in ``train``.
        nu: Stored for interface compatibility; not used by this class.
        hidden: Hidden width of ``Network``.
        n_arm: Number of arms.
        device: Torch device for the network. ``None`` (default) picks CUDA
            when available and CPU otherwise; previously ``.cuda()`` was
            hard-coded, which failed on CPU-only machines.
    """

    def __init__(self, dim, lamdba=1, nu=1, hidden=100, n_arm=2, device=None):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.n_arm = n_arm
        self.func = Network(dim, hidden_size=hidden).to(self.device)
        self.context_list = []
        self.arm_list = []
        self.reward = []
        self.lamdba = lamdba
        self.nu = nu
        # Random initial heads; replaced by A_inv @ b on the first update_model.
        self.theta = np.random.uniform(-1, 1, (self.n_arm, dim))
        self.b = np.zeros((self.n_arm, dim))
        self.A_inv = np.array([np.eye(dim) for _ in range(self.n_arm)])

    def _features(self, context):
        """Run a numpy ``context`` through the network (no autograd) and return numpy features."""
        tensor = torch.from_numpy(context).float().to(self.device)
        with torch.no_grad():
            return self.func(tensor).cpu().numpy()

    def select(self, context, pred_lengths, models, in_lengths, alpha):
        """Return the index of the arm with the best cost-adjusted UCB score.

        Args:
            context: Numpy array whose row ``a`` is arm ``a``'s context.
            pred_lengths: Predicted output token counts, aligned with ``models``.
            models: Arm keys into ``cost_per_token`` / ``input_cost_per_token``,
                aligned with the context rows.
            in_lengths: Input token counts, aligned with ``models``.
            alpha: Weight of the token-cost penalty.

        The score per arm is ``0.01 * mu + 0.1 * ucb - alpha * (output cost +
        input cost)``. Ties are broken uniformly at random.
        """
        features = self._features(context)
        # f^T A_inv f per arm. If A_inv loses positive-definiteness numerically
        # this can dip below zero; clamp so sqrt never yields NaN (a NaN score
        # would leave no argmax candidate and np.random.choice would raise).
        quad = np.array([
            np.dot(features[a, :], np.dot(self.A_inv[a], features[a, :]))
            for a in range(self.n_arm)
        ])
        ucb = np.sqrt(np.maximum(quad, 0.0))
        mu = np.array([np.dot(features[a, :], self.theta[a]) for a in range(self.n_arm)])
        xx = 0.01 * mu + 0.1 * ucb
        print(f'Accuracy UCB: {xx}')
        cost = []
        for i, mk in enumerate(models):
            out_cost = alpha * cost_per_token[mk] * pred_lengths[i]
            in_cost = alpha * input_cost_per_token[mk] * in_lengths[i]
            xx[i] = xx[i] - out_cost - in_cost
            cost.append(out_cost + in_cost)
        print(f"Cost: {cost}")
        print(f"Cost adjusted UCB: {xx}")
        arm = np.random.choice(np.where(xx == xx.max())[0])
        return arm

    def train(self, context, arm_select, reward):
        """Record one sample and take 5 SGD steps on the feature network.

        Each step regresses ``features(c) . theta[a]`` onto the stored reward
        for one sample from the shuffled history (cycling through it when
        fewer than 5 samples exist). Returns the mean squared error over the
        5 steps.
        """
        self.context_list.append(torch.from_numpy(context.reshape(1, -1)).float())
        self.arm_list.append(arm_select)
        self.reward.append(reward)
        optimizer = optim.SGD(self.func.parameters(), lr=1e-5, weight_decay=self.lamdba)
        index = np.arange(len(self.reward))
        np.random.shuffle(index)
        n_steps = 5
        tot_loss = 0
        for step in range(n_steps):
            idx = index[step % len(index)]
            c = self.context_list[idx]
            a = self.arm_list[idx]
            r = self.reward[idx]
            optimizer.zero_grad()
            features = self.func(c.to(self.device))
            theta_a = torch.from_numpy(self.theta[a]).float().to(self.device)
            mu = (features * theta_a).sum(dim=1, keepdim=True)
            delta = mu - r
            loss = delta * delta
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.func.parameters(), max_norm=1.0)
            optimizer.step()
            tot_loss += loss.item()
        return tot_loss / n_steps

    def update_model(self, context, arm_select, reward):
        """Fold one observation into the selected arm's linear statistics.

        Updates ``b`` and ``A_inv`` (Sherman-Morrison) for ``arm_select`` using
        the current network features, then recomputes every arm's ``theta``.
        """
        features = self._features(context)
        x = features[arm_select, :]
        self.b[arm_select] += x * reward
        self.A_inv[arm_select] = inv_sherman_morrison(x, self.A_inv[arm_select])
        # Recompute theta *after* the update so select/train see the
        # observation just added (previously it was computed first, leaving
        # theta one observation behind).
        self.theta = np.array([np.matmul(self.A_inv[a], self.b[a]) for a in range(self.n_arm)])

parser = argparse.ArgumentParser(description='NeuralUCB')

parser.add_argument('--size', default=150, type=int, help='number of rounds')
parser.add_argument('--nu', type=float, default=1, metavar='v', help='nu for control variance')
parser.add_argument('--lamdba', type=float, default=1, metavar='l', help='lambda for regularzation')
parser.add_argument('--hidden', type=int, default=50, help='network hidden size')
parser.add_argument('--style', default='ucb', metavar='ts|ucb', help='TS or UCB')
# type=int is required: without it a value given on the command line stays a
# str and `range(no_tasks)` in the main loop raises TypeError.
parser.add_argument('--number_tasks', default=2, type=int)
parser.add_argument('--no_runs',   default=3,   type=int, help='how many independent runs')
# The trade-off weights only multiply predicted costs, so accept real values
# (e.g. 0.5); integer inputs such as 100 give the same products as before.
parser.add_argument('--alpha_diag',   default=100,   type=float, help='cost accuracy tradeoff weight')
parser.add_argument('--alpha_sum',   default=1,   type=float, help='cost accuracy tradeoff weight')

# parse_known_args collects unrecognised flags in `unknown` instead of erroring out.
args, unknown = parser.parse_known_args()
no_tasks = args.number_tasks
alpha_diag = args.alpha_diag
alpha_sum = args.alpha_sum

no_runs = args.no_runs
# Per-run result accumulators (presumably filled after each run, beyond the reviewed portion).
all_regrets     = []   # list of length no_runs, each an array of length num_rounds
all_rewards     = []
all_costs       = []
all_costs_summarizer = []
all_costs_sum = []
all_plays       = []   # list of length no_runs, each an array of length num_arms
all_avg_arrays  = []   # list of dicts
all_avg_summary_arrays = []
all_runs_diag_stats = []


class DynamicVariables:
    """Attribute bag; the main loop stores the per-task bandits on it as l_0, l_1, ..."""


dyn_vars = DynamicVariables()
num_rounds = args.size


# Task/dataset identifier passed to input_maker and get_context.
dataset = "medical"
# NOTE(review): input_maker("seq", ...) reformats the raw reports; its exact
# output is defined in src.prompts.prompt_maker — not visible here.
input_reports = input_maker("seq",dataset,input_reports)

# One description per summarizer arm. The "(for ...)" suffixes follow the
# summarizer `models` order used in the main loop: assistants, base,
# finetune_med, finetune_tele, llama. These strings become part of
# `documents`, from which get_context builds per-arm contexts.
# NOTE(review): "an diagnosis analyst" is a typo in runtime text; left as-is
# because editing it changes the embedded context.
summary_description_array = ["Summarize the main points of the medical report of a patient, this summary will be used for research purposes only (for assistants).",
                             "You will take the role of an diagnosis analyst. Use your knowledge base to summarize inputted medical reports (for base).","Summarize the main points of the medical report of a patient, this summary will be used for research purposes only (for finetune_med).",
                             "You will take the role of an diagnosis analyst. Use your knowledge base to summarize inputted medical reports (for finetune_tele).","You will take the role of an diagnosis analyst. Use your knowledge base to summarize inputted medical reports (for llama)."]

# One description per diagnoser arm; by content it appears to follow the
# diagnoser `models` order: base, finetune_med, finetune_tele,
# finetune_med_new, llama.
diagnosis_description_array = ["General use LLM which does not specialize in any task specifically.","LLM specializing on medical reports and trained to do medical diagnosis for research purposes.",
                               "LLM specializing for answering multiple choice telecommunications questions.", "LLM specializing on medical reports and trained to do medical diagnosis for research purposes.", "General use LLM which does not specialize in any task specifically."]

# Diagnosis-task arms: arm key -> (deployment name, system prompt). Used as
# deploy[1] and passed to get_regret in the main loop.
# NOTE(review): the prompts are not identical across arms — "finetune_med" and
# "finetune_tele" contain an empty list entry (",,") after "coronary artery
# disease", and "finetune_med_new"/"llama" list "hypoxia" before
# "vertebral/basilar stenosis". Left unchanged in case the fine-tuned
# deployments were trained on these exact prompts; confirm before unifying.
deployments_1 = {"base" : ("gpt-35-turbo","You will take the role of an medical agent whose primary goal is to give diagnosis based on medical reports for experimentation and research purposes only. You need to make inferences based on the provided report to make diagnosis predictions. Output at most 2 diagnoses. If you identify multiple diagnosis output them in a comma separated format like heart failure,colon cancer. Your given diagnoses for the patient must be one of the following: diabetes mellitus,huntington's disease,sepsis,encephalopathy,pulmonary embolism,pulmonary edema,tamponade,mitral stenosis,congestive heart failure,chronic obstructive pulmonary disease,abdominal thoracic aneurysm,neurosarcoidosis,renal failure,svc syndrome,urosepsis,acute myocardial infarction,acute coronary syndrome,carotid stenosis,aortic stenosis,coronary artery disease,osteoarthritis,aortic insufficiency,unstable angina/cath,hyperlipidemia,syncope,complete heart block,intravascular coagulation,septic shock,hepatic failure,pneumonia,pancreatitis,anemia,catheter tip infection,coma,urinary tract infection,wound infection,cerebral artery infarction,hyponatremia,cardiomyopathy,vertebral/basilar stenosis,hypoxia,sick sinus syndrome,pulmonary congestion,aseptic meningitis,neutropenia,cellulitis,cirrhosis of liver,liver failure,pericardial effusion,aortic valve dysfunction,venous thrombosis,respiratory failure,benzodiazepine overdose,vena cava obstruction,valvular heart disease,v-tach,aortic dissection,opiate intoxication"), 
                 "finetune_med" : ("Med","You will take the role of an medical agent whose primary goal is to give diagnosis based on medical reports for experimentation and research purposes only. You need to make inferences based on the provided report to make diagnosis predictions. Output at most 2 diagnoses. If you identify multiple diagnosis output them in a comma separated format like heart failure,colon cancer. Your given diagnoses for the patient must be one of the following: diabetes mellitus,huntington's disease,sepsis,encephalopathy,pulmonary embolism,pulmonary edema,tamponade,mitral stenosis,congestive heart failure,chronic obstructive pulmonary disease,abdominal thoracic aneurysm,neurosarcoidosis,renal failure,svc syndrome,urosepsis,acute myocardial infarction,acute coronary syndrome,carotid stenosis,aortic stenosis,coronary artery disease,,osteoarthritis,aortic insufficiency,unstable angina/cath,hyperlipidemia,syncope,complete heart block,intravascular coagulation,septic shock,hepatic failure,pneumonia,pancreatitis,anemia,catheter tip infection,coma,urinary tract infection,wound infection,cerebral artery infarction,hyponatremia,cardiomyopathy,vertebral/basilar stenosis,hypoxia,sick sinus syndrome,pulmonary congestion,aseptic meningitis,neutropenia,cellulitis,cirrhosis of liver,liver failure,pericardial effusion,aortic valve dysfunction,venous thrombosis,respiratory failure,benzodiazepine overdose,vena cava obstruction,valvular heart disease,v-tach,aortic dissection,opiate intoxication"),
                 "finetune_tele" : ("Tele","You will take the role of an medical agent whose primary goal is to give diagnosis based on medical reports for experimentation and research purposes only. You need to make inferences based on the provided report to make diagnosis predictions. Output at most 2 diagnoses. If you identify multiple diagnosis output them in a comma separated format like heart failure,colon cancer. Your given diagnoses for the patient must be one of the following: diabetes mellitus,huntington's disease,sepsis,encephalopathy,pulmonary embolism,pulmonary edema,tamponade,mitral stenosis,congestive heart failure,chronic obstructive pulmonary disease,abdominal thoracic aneurysm,neurosarcoidosis,renal failure,svc syndrome,urosepsis,acute myocardial infarction,acute coronary syndrome,carotid stenosis,aortic stenosis,coronary artery disease,,osteoarthritis,aortic insufficiency,unstable angina/cath,hyperlipidemia,syncope,complete heart block,intravascular coagulation,septic shock,hepatic failure,pneumonia,pancreatitis,anemia,catheter tip infection,coma,urinary tract infection,wound infection,cerebral artery infarction,hyponatremia,cardiomyopathy,vertebral/basilar stenosis,hypoxia,sick sinus syndrome,pulmonary congestion,aseptic meningitis,neutropenia,cellulitis,cirrhosis of liver,liver failure,pericardial effusion,aortic valve dysfunction,venous thrombosis,respiratory failure,benzodiazepine overdose,vena cava obstruction,valvular heart disease,v-tach,aortic dissection,opiate intoxication"), 
                 "finetune_med_new": ("Med_New","You will take the role of an medical agent whose primary goal is to give diagnosis based on medical reports for experimentation and research purposes only. You need to make inferences based on the provided report to make diagnosis predictions. Output at most 2 diagnoses. If you identify multiple diagnosis output them in a comma separated format like heart failure,colon cancer. Your given diagnoses for the patient must be one of the following: diabetes mellitus,huntington's disease,sepsis,encephalopathy,pulmonary embolism,pulmonary edema,tamponade,mitral stenosis,congestive heart failure,chronic obstructive pulmonary disease,abdominal thoracic aneurysm,neurosarcoidosis,renal failure,svc syndrome,urosepsis,acute myocardial infarction,acute coronary syndrome,carotid stenosis,aortic stenosis,coronary artery disease,osteoarthritis,aortic insufficiency,unstable angina/cath,hyperlipidemia,syncope,complete heart block,intravascular coagulation,septic shock,hepatic failure,pneumonia,pancreatitis,anemia,catheter tip infection,coma,urinary tract infection,wound infection,cerebral artery infarction,hyponatremia,cardiomyopathy,hypoxia,vertebral/basilar stenosis,sick sinus syndrome,pulmonary congestion,aseptic meningitis,neutropenia,cellulitis,cirrhosis of liver,liver failure,pericardial effusion,aortic valve dysfunction,venous thrombosis,respiratory failure,benzodiazepine overdose,vena cava obstruction,valvular heart disease,v-tach,aortic dissection,opiate intoxication"),
                 "llama": ("llama","You will take the role of an medical agent whose primary goal is to give diagnosis based on medical reports for experimentation and research purposes only. You need to make inferences based on the provided report to make diagnosis predictions. Output at most 2 diagnoses. If you identify multiple diagnosis output them in a comma separated format like heart failure,colon cancer. Your given diagnoses for the patient must be one of the following: diabetes mellitus,huntington's disease,sepsis,encephalopathy,pulmonary embolism,pulmonary edema,tamponade,mitral stenosis,congestive heart failure,chronic obstructive pulmonary disease,abdominal thoracic aneurysm,neurosarcoidosis,renal failure,svc syndrome,urosepsis,acute myocardial infarction,acute coronary syndrome,carotid stenosis,aortic stenosis,coronary artery disease,osteoarthritis,aortic insufficiency,unstable angina/cath,hyperlipidemia,syncope,complete heart block,intravascular coagulation,septic shock,hepatic failure,pneumonia,pancreatitis,anemia,catheter tip infection,coma,urinary tract infection,wound infection,cerebral artery infarction,hyponatremia,cardiomyopathy,hypoxia,vertebral/basilar stenosis,sick sinus syndrome,pulmonary congestion,aseptic meningitis,neutropenia,cellulitis,cirrhosis of liver,liver failure,pericardial effusion,aortic valve dysfunction,venous thrombosis,respiratory failure,benzodiazepine overdose,vena cava obstruction,valvular heart disease,v-tach,aortic dissection,opiate intoxication")}


# Summary-task arms: arm key -> (deployment name, system prompt); used as deploy[0].
deployments_0 = {"base" : ("gpt-35-turbo","You are to summarize an inputted medical report, this summary will be used for research purposes only."), "assistants" : ("Assistant","You are to summarize an inputted medical report, this summary will be used for research purposes only."),"finetune_med" : ("Med","You are to summarize an inputted medical report, this summary will be used for research purposes only."),"finetune_tele" : ("Tele","You are to summarize an inputted medical report, this summary will be used for research purposes only."),"llama" : ("llama","You are to summarize an inputted medical report, this summary will be used for research purposes only."),}


# Context width passed as `dim` to NeuralLinearUCB in the main loop
# (presumably the embedding size of the SentenceTransformer below).
emb_size = 384

# Arm counts. In the shared plays_no array the summarizer arms come first and
# the diagnoser arms follow, offset by sum_len.
sum_len = len(summary_description_array)
diag_len = len(diagnosis_description_array)
total_len = sum_len + diag_len

# Task index -> deployment table (0 = summary, 1 = diagnosis).
deploy = [deployments_0, deployments_1]
# Most recently selected arm key; reassigned by the main loop and read by its
# diagnosis branch when fetching the summary.
cat = ''

# Corpus handed to get_context: arm descriptions followed by the reports
# (the main loop appends each round's diagnosis prompt to it).
input_reports = list(input_reports)
documents = [*summary_description_array, *diagnosis_description_array, *input_reports]

from sentence_transformers import SentenceTransformer
inp_model = SentenceTransformer("paraphrase-MiniLM-L6-v2")

# Per-arm history of (predicted - actual) cost errors.
cost_error_dict = {arm: [] for arm in arm_to_llm}

import tiktoken



def _bandit_state(bandit):
    """Snapshot one NeuralLinearUCB bandit as CPU tensors / plain objects for torch.save."""
    return {
        'func_state_dict': {k: v.cpu() for k, v in bandit.func.state_dict().items()},
        'theta': bandit.theta,
        'b': bandit.b,
        'A_inv': bandit.A_inv,
        'context_list': [c.cpu() if isinstance(c, torch.Tensor) else c for c in bandit.context_list],
        'arm_list': bandit.arm_list,
        'reward': bandit.reward,
    }


def save_bandit_checkpoint(run_id, round_num, l_0, l_1,
                           regrets, rewards, costs, costs_summarizer,
                           plays, avg_acc, avg_summary,
                           all_rewards_sum, all_rewards_diag,
                           cum_reg, tot_r, tot_c, tot_c_sum,
                           is_final=False,
                           checkpoint_dir="bandit_checkpoints_neulinucb_med"):
    """Save NeuralLinearUCB bandit models and training state.

    Args:
        run_id: Run index; part of the filename.
        round_num: Number of completed rounds; part of the filename unless
            ``is_final``.
        l_0, l_1: The summary-task and diagnosis-task bandits. Their network
            weights, linear statistics and sample history are saved.
        regrets ... tot_c_sum: Run bookkeeping, stored under same-named keys
            (``cum_reg``, ``tot_r``, ``tot_c``, ``tot_c_sum`` are running totals).
        is_final: Write ``run_{run_id}_final.pth`` instead of
            ``run_{run_id}_round_{round_num}.pth``.
        checkpoint_dir: Output directory, created if missing.

    Returns:
        Path of the written checkpoint.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)

    filename = f"run_{run_id}_final.pth" if is_final else f"run_{run_id}_round_{round_num}.pth"
    filepath = os.path.join(checkpoint_dir, filename)

    checkpoint = {
        'run_id': run_id,
        'round_num': round_num,
        'l_0_state': _bandit_state(l_0),
        'l_1_state': _bandit_state(l_1),
        'regrets': regrets,
        'rewards': rewards,
        'costs': costs,
        'costs_summarizer': costs_summarizer,
        'plays': plays,
        'avg_acc': avg_acc,
        'avg_summary': avg_summary,
        'all_rewards_sum': all_rewards_sum,
        'all_rewards_diag': all_rewards_diag,
        'cum_reg': cum_reg,
        'tot_r': tot_r,
        'tot_c': tot_c,
        'tot_c_sum': tot_c_sum
    }

    # Write to a temporary file and rename atomically so an interrupted save
    # never leaves a truncated checkpoint under the final name.
    tmp_path = filepath + ".tmp"
    torch.save(checkpoint, tmp_path)
    os.replace(tmp_path, filepath)
    print(f"  Saved bandit checkpoint: {filepath}")
    return filepath

def cleanup_bandit_checkpoints(run_id, checkpoint_dir="bandit_checkpoints_neulinucb_med"):
    """Delete a run's intermediate checkpoints, keeping only its final one.

    Removes every ``run_{run_id}_round_*.pth`` in ``checkpoint_dir``. Deletion
    is best-effort: a file that cannot be removed is reported and skipped.
    Other runs' files (e.g. ``run_10_...`` when ``run_id`` is 1) are not
    matched because ``_round_`` follows the id in the pattern.
    """
    # Escape the directory so glob metacharacters in its name are taken literally.
    pattern = os.path.join(glob.escape(checkpoint_dir), f"run_{run_id}_round_*.pth")
    for cp_path in glob.glob(pattern):
        try:
            os.remove(cp_path)
        except OSError as e:
            print(f"  Warning: Could not delete {cp_path}: {e}")





for run in range(no_runs):
    print(f"\n===== Starting run {run+1}/{no_runs} =====")
    for k in range(no_tasks):
        if k == 0:
            setattr(dyn_vars, f'l_{k}', NeuralLinearUCB(emb_size, args.lamdba, args.nu, args.hidden,len(summary_description_array)))
        else:
            setattr(dyn_vars, f'l_{k}', NeuralLinearUCB(emb_size, args.lamdba, args.nu, args.hidden,len(diagnosis_description_array)))
    regrets = []
    summ = 0
    rew = 0
    rewards_list = []
    total_reward = 0
    rewards = 0
    all_rewards_sum = []
    all_rewards_diag = []
    costs_list = []
    costs = 0
    costs_summarizer = 0
    total_costs = 0
    costs_list_summarizer = []
    total_costs_list = []
    summary_reward_sums = { arm: 0.0 for arm in deployments_0.keys() }
    summary_counts      = { arm:   0   for arm in deployments_0.keys() }
    summary_avg_array   = { arm: 0.0 for arm in deployments_0.keys() }
    diag_reward_by_summ   = { arm: []  for arm in deployments_0.keys() }
    diag_regret_by_summ   = { arm: []  for arm in deployments_0.keys() }
    plays_no = np.ones(total_len)
    avg_array = {"gpt-35-turbo":0,"Med":0,"Tele":0,"Med_New":0,"llama":0}
    dataset = "medical"
    import random
    openai_models = {"gpt-3.5-turbo","gpt-4"}
    encodings = { m: tiktoken.encoding_for_model(m) for m in openai_models }
    
    # for llama we use HuggingFace:
    llama_tok = AutoTokenizer.from_pretrained("openlm-research/open_llama_13b")
    arm_encoders = {}
    for mk,llm in arm_to_llm.items():
        if llm in encodings:
            arm_encoders[mk] = encodings[llm]
        else:
            arm_encoders[mk] = llama_tok
    for t in range(num_rounds):
        for i in range(no_tasks):
            context = []
            if i==0:
                in_lengths = []
                models = ["assistants","base","finetune_med","finetune_tele","llama"]
                prompt_to_model = input_reports[t]

                for mk in models:
                    enc = arm_encoders[mk]
                    # If this is a tiktoken Encoding, use .encode(...)
                    if hasattr(enc, "encode"):
                        # tiktoken Encoding
                        in_len = len(enc.encode(prompt_to_model))
                    else:
                        # HuggingFace tokenizer
                        in_len = len(enc(
                            prompt_to_model,
                            truncation=True,
                            padding=False
                        )["input_ids"])
                    in_lengths.append(in_len)
                toks = reg_tokenizer(
                    prompt_to_model,
                    truncation=True,
                    padding="max_length",
                    max_length=256,
                    return_tensors="pt"
                ).to(device) 
                pred_lengths = []
                for mk in models:                   # models = ["base",…,"llama"]
                   llm_name = arm_to_llm[mk]       # e.g. "gpt-4"
                   idx      = orig_model_names.index(llm_name)
                   onehot   = torch.zeros(len(orig_model_names), device=device)
                   onehot[idx] = 1.0
                   onehot   = onehot.unsqueeze(0)  # [1×25]
            
                   with torch.no_grad():
                       pred = token_length_model(
                           toks["input_ids"],
                           toks["attention_mask"],
                           onehot
                       )
                   pred_lengths.append(pred.item())
                pred_cost = [
                    input_cost_per_token[mk] * in_lengths[j]
                    + cost_per_token[mk]  * pred_lengths[j]
                    for j, mk in enumerate(models)
                ]
                print(f"Pred cost summarizer: {pred_cost}")
                task = 'summary'
                values = []
                l = getattr(dyn_vars, f'l_{i}')
                for j in range(len(summary_description_array)):
                    cont = get_context(documents,t,i,j,len(summary_description_array),len(diagnosis_description_array),len(input_reports),inp_model,dataset)
                    context.append(cont)
                if t < 2 * sum_len:
                    arm = t % sum_len
                else:
                    arm = l.select(np.array(context),pred_lengths,models,in_lengths,alpha_sum)
                plays_no[arm] += 1 
                arm_select = models[arm]
                summarizer_choice = arm_select
                print(f"Selected summarizer: {arm_select}")
                summary = get_summary(input_reports[t], arm_select)
                prompt_to_model = (
                    input_reports[t]
                    + "\n\nBelow is the summary of this report:\n\n"
                    + summary
                )
                
            
            else:
                models = ["base","finetune_med","finetune_tele","finetune_med_new","llama"]
                values=[]
                task = 'diagnosis'
                prompt_to_model = get_summary(input_reports[t],cat)
                prompt_to_model = prompt_to_model.replace('\n','')
                in_lengths = []
                for mk in models:
                    enc = arm_encoders[mk]
                    # If this is a tiktoken Encoding, use .encode(...)
                    if hasattr(enc, "encode"):
                        # tiktoken Encoding
                        in_len = len(enc.encode(prompt_to_model))
                    else:
                        # HuggingFace tokenizer
                        in_len = len(enc(
                            prompt_to_model,
                            truncation=True,
                            padding=False
                        )["input_ids"])
                    in_lengths.append(in_len)
                toks = reg_tokenizer(
                    prompt_to_model,
                    truncation=True,
                    padding="max_length",
                    max_length=256,
                    return_tensors="pt"
                ).to(device) 
                pred_lengths = []
                for mk in models:                   # models = ["base",…,"llama"]
                   llm_name = arm_to_llm[mk]       # e.g. "gpt-4"
                   idx      = orig_model_names.index(llm_name)
                   onehot   = torch.zeros(len(orig_model_names), device=device)
                   onehot[idx] = 1.0
                   onehot   = onehot.unsqueeze(0)  # [1×25]
            
                   with torch.no_grad():
                       pred = token_length_model(
                           toks["input_ids"],
                           toks["attention_mask"],
                           onehot
                       )
                   pred_lengths.append(pred.item())
                pred_cost = [
                    input_cost_per_token[mk] * in_lengths[j]
                    + cost_per_token[mk]  * pred_lengths[j]
                    for j, mk in enumerate(models)
                ]
                print(f"Pred cost diagnoser: {pred_cost}")
                documents = documents+[prompt_to_model]
                l = getattr(dyn_vars, f'l_{i}')
                for j in range(len(diagnosis_description_array)):
                    cont = get_context(documents,t,i,j,len(summary_description_array),len(diagnosis_description_array),len(input_reports),inp_model,dataset)
                    context.append(cont)
                arm = l.select(np.array(context),pred_lengths,models,in_lengths,alpha_diag)
                plays_no[arm+sum_len] += 1 
                arm_select = models[arm]
                print(f"Selected diagnoser: {arm_select}")
                
            cat = models[arm]
            dep = deploy[i]
        
            selected = arm_select
            fin_prompt = prompt_to_model
            deployment = dep[selected]
        
            if i==no_tasks-1:
                reg,reward,out_len,avg_array,all_rewards_sum,all_rewards_diag = get_regret(deployments_1,fin_prompt,task,selected,avg_array,t,all_rewards_sum,all_rewards_diag,new_labels, dataset) 
                online_trainer.add_observation(
                    prompt=fin_prompt,
                    model_name=selected,
                    actual_output_length=out_len,
                    predicted_length=pred_lengths[arm]
                )
                costs += input_cost_per_token[arm_select]*in_lengths[arm]+ cost_per_token[arm_select]* out_len
                costs_list.append(costs)
                total_costs += input_cost_per_token[arm_select]*in_lengths[arm]+ cost_per_token[arm_select]* out_len
                total_costs_list.append(total_costs)
                actual_cost = (
                    input_cost_per_token[arm_select] * in_lengths[arm]
                    + cost_per_token[arm_select] * out_len
                )
                print(f"Actual cost diagnoser: {actual_cost}")
                error = (pred_cost[arm] - actual_cost)
                cost_error_dict[arm_select].append(error)
                diag_reward_by_summ[summarizer_choice].append(reward)
                diag_regret_by_summ[summarizer_choice].append(reg)
                rewards += int(reward)
                rewards_list.append(rewards)
                print(f'Reward: {reward}')
                print(f'Regret: {reg}')
                print(plays_no)
                print("Done")
                summ+= reg
                regrets.append(summ)
                
                
                if (t + 1) % checkpoint_frequency == 0:
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    
                    save_bandit_checkpoint(
                        run_id=run,
                        round_num=t+1,
                        l_0=getattr(dyn_vars, 'l_0'),
                        l_1=getattr(dyn_vars, 'l_1'),
                        regrets=regrets,
                        rewards=rewards_list,
                        costs=costs_list,
                        costs_summarizer=costs_list_summarizer,
                        plays=plays_no,
                        avg_acc=avg_array,
                        avg_summary=summary_avg_array,
                        all_rewards_sum=all_rewards_sum,
                        all_rewards_diag=all_rewards_diag,
                        cum_reg=summ,
                        tot_r=rewards,
                        tot_c=costs,
                        tot_c_sum=costs_summarizer,
                        is_final=False
                    )
                
                if (t+1) % 10 == 0:
                    metrics = online_trainer.compute_metrics(round_num=t+1)
                    if metrics:
                        print(f"  [Round {t+1}] Cost Model - MAE: {metrics['mae']:.2f}, "
                              f"MAPE: {metrics['mape']:.2f}%, R²: {metrics['r2']:.4f}")
                
                l = getattr(dyn_vars, f'l_{i}')
                
                
                
                
                
                l = getattr(dyn_vars, f'l_{i}')
                l.update_model(np.array(context), arm, reward)
                new_rews = all_rewards_diag
                if t!= 0:
                    new_rews = (new_rews-np.mean(all_rewards_diag))/np.std(all_rewards_diag)
                index = all_rewards_diag.index(reward)
                loss = l.train(context[arm], arm, new_rews[index])
            else:
                reward,out_len,all_rewards_sum,_,all_rewards_diag = get_reward(deployment,cat,fin_prompt,task,all_rewards_sum,"",all_rewards_diag,summary)
                online_trainer.add_observation(
                    prompt=fin_prompt,
                    model_name=arm_select,
                    actual_output_length=out_len,
                    predicted_length=pred_lengths[arm]
                )
                actual_cost = (input_cost_per_token[arm_select] * in_lengths[arm]+ cost_per_token[arm_select] * out_len
                )
                print(f"Actual cost summarizer: {actual_cost}")
                error = (pred_cost[arm] - actual_cost)
                cost_error_dict[arm_select].append(error)
                l = getattr(dyn_vars, f'l_{i}')
                l.update_model(np.array(context), arm, reward)
                costs_summarizer += input_cost_per_token[arm_select]*in_lengths[arm]+ cost_per_token[arm_select]* out_len
                total_costs += input_cost_per_token[arm_select]*in_lengths[arm]+ cost_per_token[arm_select]* out_len
                costs_list_summarizer.append(costs_summarizer)
                summary_reward_sums[arm_select] += reward
                summary_counts[arm_select]      += 1
                summary_avg_array[arm_select]    = (
                    summary_reward_sums[arm_select]
                    / summary_counts[arm_select]
                )
                if t ==0:
                    all_rewards_sum.append(reward+1)
                    new_rews = all_rewards_sum
                    index = all_rewards_sum.index(reward)
                    loss = l.train(context[arm], arm, new_rews[index]/max(new_rews))
                else:
                    new_rews = all_rewards_sum
                    new_rews = (new_rews-np.mean(all_rewards_sum))/np.std(all_rewards_sum)
                    index = all_rewards_sum.index(reward)
                    loss = l.train(context[arm],arm, new_rews[index])
        
        print(f"Round {t+1} summary averages: {summary_avg_array}")

        if (t+1) % 5 == 0:
            print('{}: {:.3f}, {:.3f}, {:.3f}, {:.3f}'.format(t+1, summ, rewards, loss, total_costs)) 
    all_regrets.append(regrets)
    all_rewards.append(rewards_list)
    all_costs.append(costs_list)
    all_costs_summarizer.append(costs_list_summarizer)
    all_costs_sum.append(total_costs_list)
    all_plays.append(plays_no.copy())
    all_avg_arrays.append(avg_array.copy())
    all_avg_summary_arrays.append(summary_avg_array.copy()) 
    all_diag_reward_means = {
        arm: np.mean(diag_reward_by_summ[arm]) 
            if diag_reward_by_summ[arm] else 0.0
        for arm in diag_reward_by_summ
    }
    all_diag_reward_stds  = {
        arm: np.std(diag_reward_by_summ[arm])  
            if diag_reward_by_summ[arm] else 0.0
        for arm in diag_reward_by_summ
    }
    # pack into one dict (or DataFrame row) and save
    all_runs_diag_stats.append({
        arm: (all_diag_reward_means[arm], all_diag_reward_stds[arm])
        for arm in all_diag_reward_means
    })  
    
    save_bandit_checkpoint(
        run_id=run,
        round_num=num_rounds,
        l_0=getattr(dyn_vars, 'l_0'),
        l_1=getattr(dyn_vars, 'l_1'),
        regrets=regrets,
        rewards=rewards_list,
        costs=costs_list,
        costs_summarizer=costs_list_summarizer,
        plays=plays_no,
        avg_acc=avg_array,
        avg_summary=summary_avg_array,
        all_rewards_sum=all_rewards_sum,
        all_rewards_diag=all_rewards_diag,
        cum_reg=summ,
        tot_r=rewards,
        tot_c=costs,
        tot_c_sum=costs_summarizer,
        is_final=True
        )
    cleanup_bandit_checkpoints(run_id=run)
    online_trainer.final_update()
    online_trainer.print_summary()
    online_trainer.save_stats(f"med_results/cost_model_run_{run}_neulinucb")
    
    run_cost_summary = {
        'overall': online_trainer.compute_metrics(),
        'per_model': online_trainer.compute_per_model_metrics()
    }
    all_cost_model_summaries.append(run_cost_summary)

import pickle
import pandas as pd

# ---------------------------------------------------------------------------
# Cross-run aggregation of the per-run results collected above.
# ---------------------------------------------------------------------------

# Mean signed cost-prediction error (pred_cost - actual_cost) per arm.
# Arms for which no error was ever recorded are omitted.
avg_error = {
    arm: sum(errs) / len(errs)
    for arm, errs in cost_error_dict.items()
    if errs
}

# Stack the per-run trajectories. np.array only yields a proper 2-D array when
# every run recorded the same number of entries.
# NOTE(review): confirm all runs always append the same number of rounds.
regrets_arr = np.array(all_regrets)     # shape (no_runs, num_rounds)
rewards_arr = np.array(all_rewards)
costs_arr   = np.array(all_costs)
costs_summarizer_arr   = np.array(all_costs_summarizer)
costs_sum_arr   = np.array(all_costs_sum)
plays_arr   = np.array(all_plays)       # shape (no_runs, num_arms)
avg_df      = pd.DataFrame(all_avg_arrays)  # columns=model names

# One row per run, one column per summarizer arm. Each cell holds that run's
# mean (mean_df) / std (std_df) of the diagnosis rewards that were obtained
# after the given summarizer was chosen.
mean_df = pd.DataFrame(
    {arm: [stats[arm][0] for stats in all_runs_diag_stats]
     for arm in diag_reward_by_summ}
)
std_df = pd.DataFrame(
    {arm: [stats[arm][1] for stats in all_runs_diag_stats]
     for arm in diag_reward_by_summ}
)
overall_mean = mean_df.mean(axis=0).to_dict()
# Average the per-run standard deviations, mirroring overall_mean above.
# The previous std_df.std(axis=0) measured the spread *of the stds* across
# runs (and is all-NaN for a single run because pandas uses ddof=1), which is
# not a standard deviation of the diagnosis reward itself.
overall_std = std_df.mean(axis=0).to_dict()


def _across_runs(table):
    """Return ``(mean, std)`` of *table* reduced along axis 0 (i.e. over runs)."""
    return table.mean(axis=0), table.std(axis=0)


# Element-wise mean / std over runs for every stacked numpy trajectory.
regrets_mean, regrets_std = _across_runs(regrets_arr)
rewards_mean, rewards_std = _across_runs(rewards_arr)
costs_mean, costs_std = _across_runs(costs_arr)
costs_summarizer_mean, costs_summarizer_std = _across_runs(costs_summarizer_arr)
costs_total_mean, costs_total_std = _across_runs(costs_sum_arr)
plays_mean, plays_std = _across_runs(plays_arr)

# The per-arm accuracy tables are DataFrames; reduce them the same way and
# convert each resulting Series into a plain {arm: value} dict.
avg_mean, avg_std = (series.to_dict() for series in _across_runs(avg_df))

summary_avg_df = pd.DataFrame(all_avg_summary_arrays)
summary_mean, summary_std = (
    series.to_dict() for series in _across_runs(summary_avg_df)
)

import pickle

# Every aggregate written to disk, keyed by its output filename. Insertion
# order is preserved, so files are written in the same order as before.
pickle_outputs = {
    "regrets_mean_med_budgeted_neulinucb_2.pkl": regrets_mean,
    "regrets_std_med_budgeted_neulinucb_2.pkl": regrets_std,
    "rewards_mean_med_budgeted_neulinucb_2.pkl": rewards_mean,
    "rewards_std_med_budgeted_neulinucb_2.pkl": rewards_std,
    "costs_mean_med_budgeted_neulinucb_2.pkl": costs_mean,
    "costs_std_med_budgeted_neulinucb_2.pkl": costs_std,
    "costs_summarizer_mean_med_budgeted_neulinucb_2.pkl": costs_summarizer_mean,
    "costs_summarizer_std_med_budgeted_neulinucb_2.pkl": costs_summarizer_std,
    "costs_total_mean_med_budgeted_neulinucb_2.pkl": costs_total_mean,
    "costs_total_std_med_budgeted_neulinucb_2.pkl": costs_total_std,
    "plays_mean_med_budgeted_neulinucb_2.pkl": plays_mean,
    "plays_std_med_budgeted_neulinucb_2.pkl": plays_std,
    "avg_accuracy_mean_med_budgeted_neulinucb_2.pkl": avg_mean,
    "avg_accuracy_std_med_budgeted_neulinucb_2.pkl": avg_std,
    "sum_avg_accuracy_mean_med_budgeted_neulinucb_2.pkl": summary_mean,
    "sum_avg_accuracy_std_med_budgeted_neulinucb_2.pkl": summary_std,
    "avg_cost_err_med_budgeted_neulinucb_2.pkl": avg_error,
    "summ_diag_med_eff_mean_rew_2_neulinucb.pkl": overall_mean,
    "summ_diag_med_eff_std_rew_2_neulinucb.pkl": overall_std,
}

# Use a context manager for every file so each handle is flushed and closed
# deterministically instead of relying on garbage collection.
for out_path, payload in pickle_outputs.items():
    with open(out_path, "wb") as out_file:
        pickle.dump(payload, out_file)
# Headline numbers: the last entry of each cumulative per-round trajectory,
# followed by the per-arm aggregates.
final_report = (
    ("Final mean regret", regrets_mean[-1]),
    ("Final mean reward", rewards_mean[-1]),
    ("Final mean cost total", costs_total_mean[-1]),
    ("Final mean cost (diag)", costs_mean[-1]),
    ("Final mean summarizer cost", costs_summarizer_mean[-1]),
    ("Final mean plays", plays_mean),
    ("Final mean average array", avg_mean),
    ("Final mean summary average array", summary_mean),
    ("Final mean cost error", avg_error),
    ("Final mean summary diag effect", overall_mean),
)
for caption, figure in final_report:
    print(f"{caption}: {figure}")


print("All runs complete. Summary pickles written.")