import numpy as np
from src.regrets.final_rand_med import final_eval
from src.prompts.prompt_maker_seq import input_maker
from src.regrets.optimal_rand_tele import opt_eval
from src.regrets.sum_call import get_summary
import pickle
from transformers import AutoConfig, AutoTokenizer
import scipy as sp
import torch
import torch.nn as nn
import torch.optim as optim
from backpack import backpack, extend
from backpack.extensions import BatchGrad
import argparse
from src.embedding.embed_tele import get_context
from src.token_prediction.tok_length_predict import BertRegressionModel  
from cost_trainer import OnlineCostModelTrainer
import os
import glob
import time

start_time = time.time()

# Bandit arm -> underlying LLM name.  The LLM name is later looked up in
# `orig_model_names` (model_names.json) to build the length-regressor one-hot.
arm_to_llm = dict(
    base="gpt-3.5-turbo",
    assistants="gpt-3.5-turbo",
    finetune_med="gpt-4",
    finetune_tele="gpt-4",
    finetune_med_new="gpt-4",
    llama="llama-13b",
)

# Price per *output* token of each arm (currency presumably USD -- confirm).
cost_per_token = dict(
    base=4e-06,
    assistants=4e-06,
    finetune_med=1e-05,
    finetune_tele=1e-05,
    finetune_med_new=1e-05,
    llama=7.1e-07,
)

# Price per *input* token of each arm (same unit as cost_per_token).
input_cost_per_token = dict(
    base=1e-07,
    assistants=1e-07,
    finetune_med=2.5e-07,
    finetune_tele=2.5e-07,
    finetune_med_new=2.5e-07,
    llama=7.1e-07,
)

# Labels and raw medical reports.  NOTE: pickle.load can execute arbitrary
# code, so these files must come from a trusted source.
with open('diagnoses_150.pkl', 'rb') as file:
    diagnoses = pickle.load(file)
with open('input_reports_150.pkl', 'rb') as file:
    input_reports = pickle.load(file)

# Labels handed to the diagnosis-stage evaluator (presumably ground truth).
new_labels = diagnoses

def get_optimal_super_arm_reward(deployments,prompt,task):
    """Return the best achievable reward for ``prompt`` on ``task``.

    Thin wrapper around ``opt_eval`` which, per the original note, calls the
    Azure API and tries every deployment combination.

    NOTE(review): ``get_regret`` calls ``opt_eval`` with ten arguments; confirm
    the three-argument form used here is still accepted.
    """
    optimal = opt_eval(deployments, prompt, task)
    return optimal

def get_reward(deployment,cat,prompt,task,all_rewards_sum,all_rewards_debate,all_rewards_diag,summary):
    """Evaluate one deployment (via the Azure API) and return ``final_eval``'s result as-is.

    The caller in this script unpacks five values from it:
    (reward, output length, all_rewards_sum, <unused>, all_rewards_diag).
    """
    eval_args = (
        deployment, cat, prompt, task,
        all_rewards_sum, all_rewards_debate, all_rewards_diag, summary,
    )
    return final_eval(*eval_args)
    

def get_regret(deployments,prompt,task,selected,avg_array,t,all_rewards_sum,all_rewards_diag,labels,dataset):
    """Evaluate the ``selected`` deployment and return ``opt_eval``'s result unchanged.

    The caller in this script unpacks six values from it:
    (regret, reward, output length, avg_array, all_rewards_sum, all_rewards_diag).
    """
    return opt_eval(
        deployments, prompt, task, selected, avg_array, t,
        all_rewards_sum, all_rewards_diag, labels, dataset,
    )

# --- Output-length regressor used to predict per-arm token costs -------------
reg_model_name = "bert-base-uncased"
reg_config = AutoConfig.from_pretrained(reg_model_name)
reg_tokenizer = AutoTokenizer.from_pretrained(reg_model_name)

import json

# LLM names known to the regressor; its model one-hot has one slot per entry.
with open("model_names.json") as f:
    orig_model_names = json.load(f)
num_models = len(orig_model_names)

device = "cuda" if torch.cuda.is_available() else "cpu"

token_length_model = BertRegressionModel(
    reg_config,
    reg_model_name,
    hidden_dim=128,
    num_models=num_models,
).to(device)
token_length_model.load_state_dict(torch.load("best_length_model.pth", map_location=device))
token_length_model.eval()

# Online updates of the length regressor from observed output lengths.
online_trainer = OnlineCostModelTrainer(
    model=token_length_model,
    tokenizer=reg_tokenizer,
    orig_model_names=orig_model_names,
    arm_to_llm=arm_to_llm,
    cost_per_token=cost_per_token,
    input_cost_per_token=input_cost_per_token,
    device=device,
    checkpoint_dir="cost_model_checkpoints_neucb_joint_med",
    lr=1e-6,
    update_freq=5,
)

all_cost_model_summaries = []  # one cost-model summary dict appended per run
checkpoint_frequency = 50      # save bandit checkpoints every N rounds

# Alternative (disabled) output prices per token:
#   base / assistants (GPT-3.5 Turbo): 1.5e-6
#   finetune_med / finetune_tele / finetune_med_new (GPT-4): 1e-5
#   llama (Llama-13b): 7.1e-7
def save_bandit_checkpoint(run_id, round_num, u_sum, u_diag,
                          regrets, rewards, costs,
                          plays, avg_acc, avg_summary,
                          all_rewards_sum, all_rewards_diag,
                          cum_reg, tot_r, tot_c,
                          is_final=False, *,
                          checkpoint_dir="bandit_checkpoints_neucb_joint_med"):
    """Save both NeuralUCB learners and the run's training state with ``torch.save``.

    Args:
        run_id: Index of the current run (used in the file name).
        round_num: Round being checkpointed (used in non-final file names).
        u_sum, u_diag: Summarizer / diagnoser learners; each must expose
            ``func`` (nn.Module), ``U`` (tensor), ``context_list`` and ``reward``.
        regrets, rewards, costs, plays, avg_acc, avg_summary,
        all_rewards_sum, all_rewards_diag, cum_reg, tot_r, tot_c:
            Training-state values stored as-is (``plays`` is copied if it is
            a NumPy array).
        is_final: Write ``run_<id>_final.pth`` instead of
            ``run_<id>_round_<n>.pth``.
        checkpoint_dir: Output directory, created if missing.

    Returns:
        Path of the written checkpoint file.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)

    filename = f"run_{run_id}_final.pth" if is_final else f"run_{run_id}_round_{round_num}.pth"
    filepath = os.path.join(checkpoint_dir, filename)

    checkpoint = {
        'run_id': run_id,
        'round_num': round_num,
        'u_sum_state': _bandit_state_to_cpu(u_sum),
        'u_diag_state': _bandit_state_to_cpu(u_diag),
        'regrets': regrets,
        'rewards': rewards,
        'costs': costs,
        # Copy so later in-place play-count updates do not alias the snapshot.
        'plays': plays.copy() if isinstance(plays, np.ndarray) else plays,
        'avg_acc': avg_acc,
        'avg_summary': avg_summary,
        'all_rewards_sum': all_rewards_sum,
        'all_rewards_diag': all_rewards_diag,
        'cum_reg': cum_reg,
        'tot_r': tot_r,
        'tot_c': tot_c
    }

    torch.save(checkpoint, filepath)
    print(f"  Saved joint bandit checkpoint: {filepath}")
    return filepath


def _bandit_state_to_cpu(bandit):
    """Snapshot one learner (weights, U, training history) with tensors moved to CPU."""
    return {
        'net_state_dict': {k: v.cpu() for k, v in bandit.func.state_dict().items()},
        'U': bandit.U.cpu(),
        'context_list': [c.cpu() if isinstance(c, torch.Tensor) else c for c in bandit.context_list],
        'reward': bandit.reward,
    }

def cleanup_bandit_checkpoints(run_id, *, checkpoint_dir="bandit_checkpoints_neucb_joint_med"):
    """Delete a run's intermediate ``run_<id>_round_*.pth`` checkpoints, keeping the final one.

    Best-effort: a file that cannot be removed is reported and skipped, so a
    cleanup failure never aborts the experiment.

    Args:
        run_id: Run whose intermediate checkpoints are removed.
        checkpoint_dir: Directory to search; a missing directory is a no-op.
    """
    # Escape the directory so glob metacharacters in its name match literally.
    pattern = os.path.join(glob.escape(checkpoint_dir), f"run_{run_id}_round_*.pth")

    for cp_path in glob.glob(pattern):
        try:
            os.remove(cp_path)
        except OSError as e:
            print(f"  Warning: Could not delete {cp_path}: {e}")


        
# Dataset tag shared by prompt construction, context embedding and evaluation.
dataset = 'medical'
# Re-shape the raw reports with the "seq" prompt format of input_maker.
input_reports = input_maker('seq', dataset, input_reports)

class Network(nn.Module):
    """Two-layer MLP reward estimator: Linear(dim, hidden) -> ReLU -> Linear(hidden, 1).

    The layers live in ``self.model`` (an ``nn.Sequential``).  ``forward``
    delegates to it, so the wrapper can be called directly as well.
    """

    def __init__(self, dim = 100, hidden_size=100):
        super().__init__()

        self.model = nn.Sequential(nn.Linear(dim, hidden_size), nn.ReLU(), nn.Linear(hidden_size, 1))

    def forward(self, x):
        """Return one reward estimate per input row (trailing dimension of size 1)."""
        return self.model(x)


class NeuralUCBDiag:
    """Neural contextual-bandit learner with a diagonal gradient-covariance bonus.

    A small MLP estimates the reward ``mu`` of a context.  Exploration uses a
    diagonal matrix ``U`` (initialised to ``lamdba``) accumulating squared
    parameter gradients ``g``; the uncertainty is
    ``sigma = sqrt(sum(lamdba * nu * g**2 / U))``.

    Args:
        style: ``'ucb'`` -> score ``mu + 0.5 * sigma``; any other value samples
            ``Normal(mu, max(0.7 * sigma, 1e-6))`` (Thompson-sampling style).
        dim: Context dimensionality (MLP input width).
        lamdba: Initial value of ``U`` and SGD weight decay in :meth:`train`.
        nu: Exploration scale.
        hidden: Hidden layer width.
    """

    def __init__(self, style, dim, lamdba=1, nu=1, hidden=100):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Use the bare nn.Sequential on every device: calling the Network
        # wrapper itself would rely on it defining forward().
        net = Network(dim, hidden_size=hidden).model.to(self.device)
        self.func = extend(net) if self.device == 'cuda' else net
        self.context_list = []  # training contexts, one (1, dim) CPU tensor each
        self.reward = []        # training targets aligned with context_list
        self.lamdba = lamdba

        self.total_param = sum(p.numel() for p in self.func.parameters() if p.requires_grad)
        self.U = lamdba * torch.ones((self.total_param,), device=self.device)
        self.U_random = lamdba * torch.ones((self.total_param,), device=self.device)
        self.nu = nu
        self.num_rounds = 100
        self.style = style

        self.loss_func = extend(nn.MSELoss().to(self.device))
        self.len = 0

    def update_params(self,g_list):
        """Add the squared gradients in ``g_list`` to ``U``; always returns 0."""
        for g in g_list:
            self.U += g * g
        return 0

    def selection(self,context,style):
        """Score one context.

        Returns a float (``style == 'ucb'``) or a 1-element tensor sample
        (any other style).  Side effect: ``g * g`` of this context is added to ``U``.
        """
        tensor = torch.from_numpy(np.array(context)).float().to(self.device)
        mu = self.func(tensor)
        self.func.zero_grad()
        mu.backward(retain_graph=True)
        # Flattened gradient of the network output w.r.t. every parameter.
        g = torch.cat([p.grad.flatten().detach() for p in self.func.parameters()])
        sigma2 = self.lamdba * self.nu * g * g / self.U
        sigma = torch.sqrt(torch.sum(sigma2))
        if style == 'ucb':
            sample_r = mu.item() + 0.5 * sigma.item()
        else:
            std = (0.7 * sigma).clamp(min=1e-6)
            sample_r = torch.normal(mu.view(-1), std.view(-1))
        # NOTE(review): U is updated for every scored context, not only for the
        # arm eventually chosen (update_params() exists for that but is unused);
        # confirm this is intended.
        self.U += g * g
        return sample_r

    def train(self, context, reward):
        """Append ``(context, reward)`` to the history and take a few SGD steps on it.

        ``context`` must be a NumPy array (it is reshaped to ``(1, -1)``).
        Training visits the shuffled history one sample at a time. It stops
        after 5 steps in total and returns their mean squared error. It stops
        earlier if a full pass over a history shorter than 5 has mean loss
        <= 1e-3, returning that pass's mean loss.
        """
        self.context_list.append(torch.from_numpy(context.reshape(1, -1)).float())
        self.reward.append(reward)
        optimizer = optim.SGD(self.func.parameters(), lr=1e-2, weight_decay=self.lamdba)
        length = len(self.reward)
        index = np.arange(length)
        np.random.shuffle(index)
        cnt = 0
        tot_loss = 0
        while True:
            batch_loss = 0
            for idx in index:
                c = self.context_list[idx].to(self.device)
                r = self.reward[idx]
                optimizer.zero_grad()
                delta = self.func(c) - torch.tensor(float(r), device=self.device)
                loss = delta * delta
                loss.backward()
                optimizer.step()
                batch_loss += loss.item()
                tot_loss += loss.item()
                cnt += 1
                if cnt >= 5:
                    return tot_loss / 5
            if batch_loss / length <= 1e-3:
                return batch_loss / length

parser = argparse.ArgumentParser(description='NeuralUCB')

parser.add_argument('--size', default=150, type=int, help='number of rounds')
parser.add_argument('--nu', type=float, default=1, metavar='v', help='nu for control variance')
parser.add_argument('--lamdba', type=float, default=1, metavar='l', help='lambda for regularization')
parser.add_argument('--hidden', type=int, default=50, help='network hidden size')
parser.add_argument('--style', default='ts', metavar='ts|ucb', help='TS or UCB')
# type=int so a value given on the command line is an int like the default, not a str.
parser.add_argument('--number_tasks', default=2, type=int, help='number of tasks')
parser.add_argument('--no_runs',   default=3,   type=int, help='how many independent runs')
# type=float so fractional trade-off weights are accepted; the default stays 100.
parser.add_argument('--alpha',   default=100,   type=float, help='cost accuracy tradeoff weight')

# parse_known_args: unrecognised extra flags are collected in `unknown` instead of aborting.
args, unknown = parser.parse_known_args()
no_tasks = args.number_tasks
no_runs = args.no_runs
alpha = args.alpha

# Per-run results; one entry is appended per run.
all_regrets     = []   # each: cumulative regret after every round
all_rewards     = []   # each: cumulative reward after every round
all_costs       = []   # each: cumulative actual cost after every round
all_costs_summarizer = []
all_plays       = []   # each: play counts, one per super arm
all_avg_arrays  = []   # each: avg_array dict
all_avg_summary_arrays = []

class DynamicVariables:
    """Empty attribute container; attributes are attached dynamically (e.g. with setattr)."""

num_rounds = args.size

# Text descriptions of the summarizer arms, fed into the context embeddings.
# Presumably index-aligned with `summary_models` -- confirm against get_context.
summary_description_array = [
    "Summarize the main points of the medical report of a patient, this summary will be used for research purposes only.",
    "You will take the role of an diagnosis analyst. Use your knowledge base to summarize inputted medical reports.",
    "Summarize the main points of the medical report of a patient, this summary will be used for research purposes only.",
    "You will take the role of an diagnosis analyst. Use your knowledge base to summarize inputted medical reports.",
    "Summarize the main points of the medical report of a patient, this summary will be used for research purposes only.",
]

# Text descriptions of the diagnoser arms (presumably index-aligned with
# `diagnosis_models`: base, finetune_med, finetune_tele, finetune_med_new, llama).
diagnosis_description_array = [
    "General use LLM which does not specialize in any task specifically.",
    "LLM specializing on medical reports and trained to do medical diagnosis for research purposes.",
    "LLM specializing for answering multiple choice telecommunications questions.",
    "LLM specializing on medical reports and trained to do medical diagnosis for research purposes.",
    "General use LLM which does not specialize in any task specifically.",
]

# Diagnosis-stage deployments: arm key -> (deployment name, instruction text).
# Every arm gets a near-identical instruction to output at most two diagnoses
# from a fixed label list.  Unlike deployments_0 there is no "assistants" arm.
# NOTE(review): the prompts are NOT byte-identical across arms.
#   - "finetune_med" and "finetune_tele" contain ",," after "coronary artery disease".
#   - "finetune_med_new" and "llama" list "hypoxia" before "vertebral/basilar stenosis".
# Confirm whether these differences are intentional.
deployments_1 = {"base" : ("gpt-35-turbo","You will take the role of an medical agent whose primary goal is to give diagnosis based on medical reports for experimentation and research purposes only. You need to make inferences based on the provided report to make diagnosis predictions. Output at most 2 diagnoses. If you identify multiple diagnosis output them in a comma separated format like heart failure,colon cancer. Your given diagnoses for the patient must be one of the following: diabetes mellitus,huntington's disease,sepsis,encephalopathy,pulmonary embolism,pulmonary edema,tamponade,mitral stenosis,congestive heart failure,chronic obstructive pulmonary disease,abdominal thoracic aneurysm,neurosarcoidosis,renal failure,svc syndrome,urosepsis,acute myocardial infarction,acute coronary syndrome,carotid stenosis,aortic stenosis,coronary artery disease,osteoarthritis,aortic insufficiency,unstable angina/cath,hyperlipidemia,syncope,complete heart block,intravascular coagulation,septic shock,hepatic failure,pneumonia,pancreatitis,anemia,catheter tip infection,coma,urinary tract infection,wound infection,cerebral artery infarction,hyponatremia,cardiomyopathy,vertebral/basilar stenosis,hypoxia,sick sinus syndrome,pulmonary congestion,aseptic meningitis,neutropenia,cellulitis,cirrhosis of liver,liver failure,pericardial effusion,aortic valve dysfunction,venous thrombosis,respiratory failure,benzodiazepine overdose,vena cava obstruction,valvular heart disease,v-tach,aortic dissection,opiate intoxication"), 
                 "finetune_med" : ("Med","You will take the role of an medical agent whose primary goal is to give diagnosis based on medical reports for experimentation and research purposes only. You need to make inferences based on the provided report to make diagnosis predictions. Output at most 2 diagnoses. If you identify multiple diagnosis output them in a comma separated format like heart failure,colon cancer. Your given diagnoses for the patient must be one of the following: diabetes mellitus,huntington's disease,sepsis,encephalopathy,pulmonary embolism,pulmonary edema,tamponade,mitral stenosis,congestive heart failure,chronic obstructive pulmonary disease,abdominal thoracic aneurysm,neurosarcoidosis,renal failure,svc syndrome,urosepsis,acute myocardial infarction,acute coronary syndrome,carotid stenosis,aortic stenosis,coronary artery disease,,osteoarthritis,aortic insufficiency,unstable angina/cath,hyperlipidemia,syncope,complete heart block,intravascular coagulation,septic shock,hepatic failure,pneumonia,pancreatitis,anemia,catheter tip infection,coma,urinary tract infection,wound infection,cerebral artery infarction,hyponatremia,cardiomyopathy,vertebral/basilar stenosis,hypoxia,sick sinus syndrome,pulmonary congestion,aseptic meningitis,neutropenia,cellulitis,cirrhosis of liver,liver failure,pericardial effusion,aortic valve dysfunction,venous thrombosis,respiratory failure,benzodiazepine overdose,vena cava obstruction,valvular heart disease,v-tach,aortic dissection,opiate intoxication"),
                 "finetune_tele" : ("Tele","You will take the role of an medical agent whose primary goal is to give diagnosis based on medical reports for experimentation and research purposes only. You need to make inferences based on the provided report to make diagnosis predictions. Output at most 2 diagnoses. If you identify multiple diagnosis output them in a comma separated format like heart failure,colon cancer. Your given diagnoses for the patient must be one of the following: diabetes mellitus,huntington's disease,sepsis,encephalopathy,pulmonary embolism,pulmonary edema,tamponade,mitral stenosis,congestive heart failure,chronic obstructive pulmonary disease,abdominal thoracic aneurysm,neurosarcoidosis,renal failure,svc syndrome,urosepsis,acute myocardial infarction,acute coronary syndrome,carotid stenosis,aortic stenosis,coronary artery disease,,osteoarthritis,aortic insufficiency,unstable angina/cath,hyperlipidemia,syncope,complete heart block,intravascular coagulation,septic shock,hepatic failure,pneumonia,pancreatitis,anemia,catheter tip infection,coma,urinary tract infection,wound infection,cerebral artery infarction,hyponatremia,cardiomyopathy,vertebral/basilar stenosis,hypoxia,sick sinus syndrome,pulmonary congestion,aseptic meningitis,neutropenia,cellulitis,cirrhosis of liver,liver failure,pericardial effusion,aortic valve dysfunction,venous thrombosis,respiratory failure,benzodiazepine overdose,vena cava obstruction,valvular heart disease,v-tach,aortic dissection,opiate intoxication"), 
                 "finetune_med_new": ("Med_New","You will take the role of an medical agent whose primary goal is to give diagnosis based on medical reports for experimentation and research purposes only. You need to make inferences based on the provided report to make diagnosis predictions. Output at most 2 diagnoses. If you identify multiple diagnosis output them in a comma separated format like heart failure,colon cancer. Your given diagnoses for the patient must be one of the following: diabetes mellitus,huntington's disease,sepsis,encephalopathy,pulmonary embolism,pulmonary edema,tamponade,mitral stenosis,congestive heart failure,chronic obstructive pulmonary disease,abdominal thoracic aneurysm,neurosarcoidosis,renal failure,svc syndrome,urosepsis,acute myocardial infarction,acute coronary syndrome,carotid stenosis,aortic stenosis,coronary artery disease,osteoarthritis,aortic insufficiency,unstable angina/cath,hyperlipidemia,syncope,complete heart block,intravascular coagulation,septic shock,hepatic failure,pneumonia,pancreatitis,anemia,catheter tip infection,coma,urinary tract infection,wound infection,cerebral artery infarction,hyponatremia,cardiomyopathy,hypoxia,vertebral/basilar stenosis,sick sinus syndrome,pulmonary congestion,aseptic meningitis,neutropenia,cellulitis,cirrhosis of liver,liver failure,pericardial effusion,aortic valve dysfunction,venous thrombosis,respiratory failure,benzodiazepine overdose,vena cava obstruction,valvular heart disease,v-tach,aortic dissection,opiate intoxication"),
                 "llama": ("llama","You will take the role of an medical agent whose primary goal is to give diagnosis based on medical reports for experimentation and research purposes only. You need to make inferences based on the provided report to make diagnosis predictions. Output at most 2 diagnoses. If you identify multiple diagnosis output them in a comma separated format like heart failure,colon cancer. Your given diagnoses for the patient must be one of the following: diabetes mellitus,huntington's disease,sepsis,encephalopathy,pulmonary embolism,pulmonary edema,tamponade,mitral stenosis,congestive heart failure,chronic obstructive pulmonary disease,abdominal thoracic aneurysm,neurosarcoidosis,renal failure,svc syndrome,urosepsis,acute myocardial infarction,acute coronary syndrome,carotid stenosis,aortic stenosis,coronary artery disease,osteoarthritis,aortic insufficiency,unstable angina/cath,hyperlipidemia,syncope,complete heart block,intravascular coagulation,septic shock,hepatic failure,pneumonia,pancreatitis,anemia,catheter tip infection,coma,urinary tract infection,wound infection,cerebral artery infarction,hyponatremia,cardiomyopathy,hypoxia,vertebral/basilar stenosis,sick sinus syndrome,pulmonary congestion,aseptic meningitis,neutropenia,cellulitis,cirrhosis of liver,liver failure,pericardial effusion,aortic valve dysfunction,venous thrombosis,respiratory failure,benzodiazepine overdose,vena cava obstruction,valvular heart disease,v-tach,aortic dissection,opiate intoxication")}


# Summarization-stage deployments: arm key -> (deployment name, instruction text).
# Every summarizer arm receives the same instruction.
deployments_0 = {
    arm: (deployment_name, "You are to summarize an inputted medical report, this summary will be used for research purposes only.")
    for arm, deployment_name in (
        ("base", "gpt-35-turbo"),
        ("assistants", "Assistant"),
        ("finetune_med", "Med"),
        ("finetune_tele", "Tele"),
        ("llama", "llama"),
    )
}


import tiktoken
from collections import defaultdict
from sentence_transformers import SentenceTransformer

# Width of the context vectors given to each NeuralUCBDiag learner
# (768 was presumably for concatenated summary+diagnosis contexts).
emb_size = 384
emb_size_pair = 384

sum_len = len(summary_description_array)
total_len = sum_len + len(diagnosis_description_array)

# Alternative (disabled) input prices per token: base / assistants 5e-7,
# finetune_med / finetune_tele / finetune_med_new 2.5e-7, llama 7.1e-7.

deploy = [deployments_0, deployments_1]  # index 0: summarization, index 1: diagnosis
cat = ''

input_reports = list(input_reports)
# Corpus used by get_context: arm descriptions followed by the reports.
documents = [*summary_description_array, *diagnosis_description_array, *input_reports]
inp_model = SentenceTransformer("paraphrase-MiniLM-L6-v2")

cost_error_dict = {arm: [] for arm in arm_to_llm}

summary_models = ["base", "assistants", "finetune_med", "finetune_tele", "llama"]
diagnosis_models = ["base", "finetune_med", "finetune_tele", "finetune_med_new", "llama"]
# Every (summarizer, diagnoser) combination is one super arm.
super_arms = [(s, d) for s in summary_models for d in diagnosis_models]

# Online ratio of diagnosis output/input tokens, defaulting to 0.5 per key.
diag_out_in_ratio = defaultdict(lambda: 0.5)

# Input-token encoders per arm.  They do not depend on the run, so they are
# built once here rather than at the start of every run.
openai_models = {"gpt-3.5-turbo", "gpt-4"}
encodings = {m: tiktoken.encoding_for_model(m) for m in openai_models}
# Every non-OpenAI LLM is tokenized with the HuggingFace OpenLLaMA tokenizer.
llama_tok = AutoTokenizer.from_pretrained("openlm-research/open_llama_13b")
arm_encoders = {mk: encodings.get(llm, llama_tok) for mk, llm in arm_to_llm.items()}


def _count_tokens(encoder, text):
    """Return the number of tokens ``encoder`` produces for ``text``.

    Objects with an ``encode`` method are used through it; anything else is
    called like a HuggingFace tokenizer and its ``input_ids`` are counted.
    NOTE(review): HuggingFace tokenizers also define ``encode``, so the second
    branch is probably never taken -- confirm before relying on it.
    """
    if hasattr(encoder, "encode"):
        return len(encoder.encode(text))
    return len(encoder(text, truncation=True, padding=False)["input_ids"])


def _predict_output_length(toks, llm_name):
    """Predict the output token count of ``llm_name`` for the tokenized report ``toks``."""
    onehot = torch.zeros(1, len(orig_model_names), device=device)
    onehot[0, orig_model_names.index(llm_name)] = 1.0
    with torch.no_grad():
        return token_length_model(toks["input_ids"], toks["attention_mask"], onehot).item()


for run in range(no_runs):
    print(f"\n===== Starting run {run+1}/{no_runs} =====")
    regrets = []
    costs_list = []
    costs = 0
    costs_summarizer = 0
    costs_list_summarizer = []
    dyn_vars = DynamicVariables()
    all_rewards_sum = []
    all_rewards_diag = []
    plays_no = np.ones(len(super_arms))
    summary_reward_sums = {arm: 0.0 for arm in deployments_0.keys()}
    summary_counts = {arm: 0 for arm in deployments_0.keys()}
    summary_avg_array = {arm: 0.0 for arm in deployments_0.keys()}
    summ = 0
    rew = 0
    rewards_list = []
    rewards = 0
    total_reward = 0
    avg_array = {"gpt-35-turbo": 0, "Med": 0, "Tele": 0, "Med_New": 0, "llama": 0}
    # Separate learners score the summarizer and diagnoser halves of a pair.
    u_sum = NeuralUCBDiag(args.style, emb_size_pair, args.lamdba, args.nu, args.hidden)
    u_diag = NeuralUCBDiag(args.style, emb_size_pair, args.lamdba, args.nu, args.hidden)

    for t in range(num_rounds):
        report_text = input_reports[t]
        toks_report = reg_tokenizer(report_text, truncation=True, padding="max_length",
                                    max_length=256, return_tensors="pt").to(device)

        # Token counts and length predictions depend on one arm only, so they
        # are computed once per arm instead of once per (summarizer, diagnoser) pair.
        in_len_sum_by_arm = {s: _count_tokens(arm_encoders[s], report_text) for s in summary_models}
        out_len_sum_pred_by_arm = {
            s: _predict_output_length(toks_report, arm_to_llm[s]) for s in summary_models
        }
        out_len_diag_pred_by_arm = {
            d: _predict_output_length(toks_report, arm_to_llm.get(d, "gpt-3.5-turbo"))
            for d in diagnosis_models
        }

        pair_contexts_s = []
        pair_contexts_d = []
        pair_scores = []
        pair_costs_pred = []
        for (s_arm, d_arm) in super_arms:
            cont_s = get_context(documents, t, 0, summary_models.index(s_arm),
                                 len(summary_description_array), len(diagnosis_description_array),
                                 0, inp_model, dataset)
            cont_d = get_context(documents, t, 1, diagnosis_models.index(d_arm),
                                 len(summary_description_array), len(diagnosis_description_array),
                                 0, inp_model, dataset)
            pair_contexts_s.append(cont_s)
            pair_contexts_d.append(cont_d)

            # The diagnoser reads the summary, so its input length is estimated
            # by the predicted summary length.
            out_len_sum_pred = out_len_sum_pred_by_arm[s_arm]
            in_len_diag_est = int(round(out_len_sum_pred))
            pred_cost_pair = (
                input_cost_per_token[s_arm] * in_len_sum_by_arm[s_arm]
                + cost_per_token[s_arm] * out_len_sum_pred
                + input_cost_per_token[d_arm] * in_len_diag_est
                + cost_per_token[d_arm] * out_len_diag_pred_by_arm[d_arm]
            )
            pair_costs_pred.append(pred_cost_pair)

            # Budgeted joint score: summed learner scores minus alpha * predicted cost.
            v_s = u_sum.selection(cont_s, args.style)
            v_d = u_diag.selection(cont_d, args.style)
            pair_scores.append((float(v_s) + float(v_d)) - alpha * pred_cost_pair)

        scores = np.array(pair_scores, dtype=float)
        if args.style == "ts":
            # Break exact ties between top-scoring pairs uniformly at random.
            best_idx = int(np.random.choice(np.flatnonzero(scores == scores.max())))
        else:  # ucb
            best_idx = int(np.argmax(scores))

        s_arm, d_arm = super_arms[best_idx]
        # Length predictions for the CHOSEN pair.  After the scoring loop the
        # loop-local prediction variables refer to the last pair scored, which
        # previously leaked into the online cost-model observations.
        out_len_sum_pred = out_len_sum_pred_by_arm[s_arm]
        out_len_diag_pred = out_len_diag_pred_by_arm[d_arm]
        plays_no[best_idx] += 1
        print(f"Chosen pair: summarizer={s_arm}, diagnoser={d_arm}, pred_cost={pair_costs_pred[best_idx]}")

        # --- Stage 1: summarize the report with the chosen summarizer ---------
        summary_dep = deployments_0[s_arm]
        summary = get_summary(report_text, s_arm)
        summary_clean = summary.replace('\n', '')
        in_len_sum_actual = in_len_sum_by_arm[s_arm]

        reward_sum, out_len_sum_actual, all_rewards_sum, _, all_rewards_diag = get_reward(
            summary_dep, s_arm, report_text, 'summary', all_rewards_sum, "", all_rewards_diag, summary
        )
        online_trainer.add_observation(
            prompt=report_text,
            model_name=s_arm,
            actual_output_length=out_len_sum_actual,
            predicted_length=out_len_sum_pred,
        )

        # --- Stage 2: diagnose from the summary with the chosen diagnoser -----
        fin_prompt_diag = summary_clean
        reg, reward_diag, out_len_diag_actual, avg_array, all_rewards_sum, all_rewards_diag = get_regret(
            deployments_1, fin_prompt_diag, 'diagnosis', d_arm, avg_array, t,
            all_rewards_sum, all_rewards_diag, new_labels, dataset
        )
        online_trainer.add_observation(
            prompt=fin_prompt_diag,
            model_name=d_arm,
            actual_output_length=out_len_diag_actual,
            predicted_length=out_len_diag_pred,
        )
        in_len_diag_actual = _count_tokens(arm_encoders[d_arm], summary_clean)

        actual_cost_pair = (
            input_cost_per_token[s_arm] * in_len_sum_actual
            + cost_per_token[s_arm] * out_len_sum_actual
            + input_cost_per_token[d_arm] * in_len_diag_actual
            + cost_per_token[d_arm] * out_len_diag_actual
        )
        costs += actual_cost_pair
        costs_list.append(costs)

        rewards += int(reward_diag)
        rewards_list.append(rewards)
        summ += reg
        regrets.append(summ)
        print(f"Reward: {reward_diag} | Regret: {reg} | Actual total cost: {actual_cost_pair}")

        if (t + 1) % checkpoint_frequency == 0:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            save_bandit_checkpoint(
                run_id=run,
                round_num=t+1,
                u_sum=u_sum,
                u_diag=u_diag,
                regrets=regrets,
                rewards=rewards_list,
                costs=costs_list,
                plays=plays_no,
                avg_acc=avg_array,
                avg_summary=summary_avg_array,
                all_rewards_sum=all_rewards_sum,
                all_rewards_diag=all_rewards_diag,
                cum_reg=summ,
                tot_r=rewards,
                tot_c=costs,
                is_final=False
            )

        if (t+1) % 10 == 0:
            metrics = online_trainer.compute_metrics(round_num=t+1)
            if metrics:
                print(f"  [Round {t+1}] Cost Model - MAE: {metrics['mae']:.2f}, "
                      f"MAPE: {metrics['mape']:.2f}%, R²: {metrics['r2']:.4f}")

        # Standardize the reward histories (z-score) once there is variance;
        # the newest standardized value is the training target for each learner.
        if t != 0 and len(all_rewards_diag) > 1 and np.std(all_rewards_diag) > 0:
            norm_diag = (np.asarray(all_rewards_diag) - np.mean(all_rewards_diag)) / np.std(all_rewards_diag)
        else:
            norm_diag = all_rewards_diag if len(all_rewards_diag) else [reward_diag]

        if t != 0 and len(all_rewards_sum) > 1 and np.std(all_rewards_sum) > 0:
            norm_sum = (np.asarray(all_rewards_sum) - np.mean(all_rewards_sum)) / np.std(all_rewards_sum)
        else:
            norm_sum = all_rewards_sum if len(all_rewards_sum) else [reward_sum/100]

        idx_diag = len(all_rewards_diag) - 1
        idx_sum = len(all_rewards_sum) - 1
        _ = u_sum.train(pair_contexts_s[best_idx], norm_sum[idx_sum])
        _ = u_diag.train(pair_contexts_d[best_idx], norm_diag[idx_diag])

        if (t+1) % 5 == 0:
            print('{}: {:.3f}, {:.3f}'.format(t+1, summ, rewards))

    all_regrets.append(regrets)
    all_rewards.append(rewards_list)
    all_costs.append(costs_list)
    all_costs_summarizer.append(costs_list_summarizer)
    all_plays.append(plays_no.copy())
    all_avg_arrays.append(avg_array.copy())
    all_avg_summary_arrays.append(summary_avg_array.copy())

    save_bandit_checkpoint(
        run_id=run,
        round_num=num_rounds,
        u_sum=u_sum,
        u_diag=u_diag,
        regrets=regrets,
        rewards=rewards_list,
        costs=costs_list,
        plays=plays_no,
        avg_acc=avg_array,
        avg_summary=summary_avg_array,
        all_rewards_sum=all_rewards_sum,
        all_rewards_diag=all_rewards_diag,
        cum_reg=summ,
        tot_r=rewards,
        tot_c=costs,
        is_final=True
    )

    cleanup_bandit_checkpoints(run_id=run)

    online_trainer.final_update()
    online_trainer.print_summary()

    run_cost_summary = {
        'overall': online_trainer.compute_metrics(),
        'per_model': online_trainer.compute_per_model_metrics()
    }
    all_cost_model_summaries.append(run_cost_summary)


import pickle
import pandas as pd
# Mean of the recorded cost-estimation errors for each arm. Arms whose error
# list is empty are left out instead of dividing by zero.
avg_error = {
    arm_name: sum(err_list) / len(err_list)
    for arm_name, err_list in cost_error_dict.items()
    if err_list
}

# Stack the per-run traces so that runs lie along axis 0:
#   regrets / rewards / costs -> (no_runs, num_rounds)
#   plays                     -> (no_runs, num_arms)
regrets_arr = np.array(all_regrets)
rewards_arr = np.array(all_rewards)
costs_arr = np.array(all_costs)
plays_arr = np.array(all_plays)


def _across_runs(stacked):
    """Return ``(mean, std)`` of *stacked* over axis 0 (the run axis).

    NumPy's ``std`` is used with its default ``ddof=0`` (population std).
    """
    return stacked.mean(axis=0), stacked.std(axis=0)


regrets_mean, regrets_std = _across_runs(regrets_arr)
rewards_mean, rewards_std = _across_runs(rewards_arr)
costs_mean, costs_std = _across_runs(costs_arr)
plays_mean, plays_std = _across_runs(plays_arr)

# Summary averages: one row per run, appended once per run above. The columns
# are presumably model/arm names; confirm against summary_avg_array.
# NOTE(review): pandas' DataFrame.std() defaults to ddof=1 (sample std), while
# the NumPy statistics above use ddof=0. Confirm which one is intended before
# comparing the two.
summary_avg_df = pd.DataFrame(all_avg_summary_arrays)
summary_mean = summary_avg_df.mean(axis=0).to_dict()
summary_std = summary_avg_df.std(axis=0).to_dict()

# Persist the cross-run statistics as one pickle per array under med_results/.
# Each file is written inside a ``with`` block so the handle is flushed and
# closed deterministically. The previous ``pickle.dump(x, open(...))`` form
# never closed its files explicitly.
# The directory is created up front so a long experiment does not crash at
# the very end with FileNotFoundError.
os.makedirs("med_results", exist_ok=True)

_results_to_save = {
    "regrets_mean": regrets_mean,
    "regrets_std": regrets_std,
    "rewards_mean": rewards_mean,
    "rewards_std": rewards_std,
    "costs_mean": costs_mean,
    "costs_std": costs_std,
    "plays_mean": plays_mean,
    "plays_std": plays_std,
}
for _stat_name, _stat_value in _results_to_save.items():
    # e.g. med_results/regrets_mean_med_budgeted_neucb_joint_2.pkl
    _out_path = f"med_results/{_stat_name}_med_budgeted_neucb_joint_2.pkl"
    with open(_out_path, "wb") as _out_file:
        pickle.dump(_stat_value, _out_file)

# Cost-model accuracy summaries: one dict per run with 'overall' and
# 'per_model' entries from the online cost trainer. Written to the current
# working directory, not med_results/.
with open("cost_model_all_runs_neucb_joint_med.pkl", "wb") as _out_file:
    pickle.dump(all_cost_model_summaries, _out_file)

#if all_cost_model_summaries:
#    print("\n" + "="*70)
#    print("AGGREGATE COST MODEL PERFORMANCE ACROSS ALL RUNS")
#    print("="*70)
#    
#    aggregate_metrics = {}
#    for metric in ['mae', 'rmse', 'mape', 'r2']:
#        values = [run['overall'][metric] for run in all_cost_model_summaries 
#                  if run['overall']]
#        if values:
#            aggregate_metrics[f'{metric}_mean'] = np.mean(values)
#            aggregate_metrics[f'{metric}_std'] = np.std(values)
#    
#    if aggregate_metrics:
#        print(f"\nMAE:  {aggregate_metrics['mae_mean']:.2f} ± "
#              f"{aggregate_metrics['mae_std']:.2f} tokens")
#        print(f"MAPE: {aggregate_metrics['mape_mean']:.2f} ± "
#              f"{aggregate_metrics['mape_std']:.2f}%")
#        print(f"R²:   {aggregate_metrics['r2_mean']:.4f} ± "
#              f"{aggregate_metrics['r2_std']:.4f}")
#        
#        pickle.dump(aggregate_metrics, 
#                open("cost_model_aggregate_neucb_joint_med.pkl", "wb"))

# Wall-clock time for the whole script. start_time is taken right after the
# imports, so this includes data loading as well as every run.
end_time = time.time()
print(f"Total runtime {args.no_runs} runs: {end_time-start_time:.2f} seconds")

# Headline numbers: the last entry of each across-run mean trace, plus the
# mean play count per arm. Each line is printed once; "Final mean regret"
# used to be emitted twice.
print(f"Final mean regret: {regrets_mean[-1]}")
print(f"Final mean reward: {rewards_mean[-1]}")
print(f"Final mean cost: {costs_mean[-1]}")
print(f"Final mean plays: {plays_mean}")

print("All runs complete. Summary pickles written.")