# -*- coding: utf-8 -*-
"""
Created on Sun Sep 21 11:06:59 2025

@author: baran
"""

import numpy as np
import pickle
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from backpack import extend
from prompt_maker import input_maker
from src.embedding.embed_tele import get_context
#from src.regrets.optimal_rand_seq_tele import opt_eval
#from src.regrets.sum_call_seq import get_summary
from utils.helper import opt_eval, get_summary
from src.regrets.final_rand_tele import final_eval_telecom
import time
from cost_trainer import OnlineCostModelTrainer
import os
import glob


# ─── STEP 1: Telecom dataset ────────────────────────────────────────────────────
# Sequential ("seq") telecom split: the questions, plus the labels/explanations
# that are later handed to opt_eval as references.
dataset = "telecom"
input_reports, labels, explanations = input_maker("seq", dataset, "")

# ─── STEP 2: Description arrays ─────────────────────────────────────────────────
# Five paraphrased task descriptions per stage. Each entry is an index j that
# get_context embeds alongside the question.

# Stage 0 - summarization
summary_description_array = [
    "Summarize the telecommunications question and its options concisely for analysis.",
    "Provide a brief recap of the telecom question and choices for researchers.",
    "You will take the role of a telecom-specialist summarizer. Summarize the question and answer options.",
    "Produce a short summary of the telecom question and all choices.",
    "Present the telecom question and its multiple-choice options in a concise summary.",
]

# Stage 1 - answering the MCQ ("diagnosis")
diagnosis_description_array = [
    "Answer the telecom MCQ strictly 'option {i}' for this question.",
    "Provide the MCQ answer (1–4) for this telecom question.",
    "Output the telecom MCQ response as 'option {i}'.",
    "Select the correct option (1–4) for the telecommunications question.",
    "Choose the telecom MCQ answer and output 'option {i}'.",
]

# Stage 2 - explaining the chosen answer
explanation_description_array = [
    "Explain in detail why the chosen telecom MCQ answer is correct.",
    "Provide a step-by-step rationale for why the selected answer is correct.",
    "As a telecom expert, justify why the chosen MCQ option is right.",
    "Offer a clear explanation for why the selected telecom answer is correct.",
    "Give a detailed rationale for why the chosen option is correct.",
]

# ─── STEP 3: Corpus handed to get_context ───────────────────────────────────────
# Layout: summary descriptions, diagnosis descriptions, explanation descriptions,
# then every question. get_context receives the first two section lengths so it
# can index into this flat list.
documents = [
    *summary_description_array,
    *diagnosis_description_array,
    *explanation_description_array,
    *input_reports,
]

# ─── STEP 5: Deployment instructions per arm ───────────────────────────────────────────
# Each arm key maps to (deployment name, system instruction). Key order matters:
# it fixes the arm index j used by get_context and the super-arm enumeration.
_SUMMARIZER_INSTRUCTION = "You are to summarize a telecom question and its options."
_DIAGNOSER_INSTRUCTION = (
    "You are to answer multiple choice questions related to telecommunications. "
    "Output strictly 'option {i}'."
)
_EXPLAINER_INSTRUCTION = (
    "You are to explain why the MCQ answer for this telecom question is correct. "
    "Provide a detailed rationale."
)

deployments_summarizer = {
    arm: (deployment, _SUMMARIZER_INSTRUCTION)
    for arm, deployment in (
        ("base", "gpt-35-turbo"),
        ("assistants", "Assistant"),
        ("finetune_med", "Med"),
        ("finetune_tele", "Tele"),
        ("finetune_med_new", "Med_New"),
        ("llama", "llama"),
    )
}

# The base diagnoser additionally spells out the allowed option range.
deployments_diagnoser = {
    "base": (
        "gpt-35-turbo",
        "You are to answer multiple choice questions related to telecommunications. Output strictly 'option {i}' where i∈{1,2,3,4}.",
    ),
}
deployments_diagnoser.update(
    (arm, (deployment, _DIAGNOSER_INSTRUCTION))
    for arm, deployment in (
        ("finetune_med", "Med"),
        ("finetune_tele", "Tele"),
        ("finetune_med_new", "Med_New"),
        ("llama", "llama"),
    )
)

deployments_explainer = {
    arm: (deployment, _EXPLAINER_INSTRUCTION)
    for arm, deployment in (
        ("base", "gpt-35-turbo"),
        ("finetune_med", "Med"),
        ("finetune_tele", "Tele"),
        ("finetune_med_new", "Med_New"),
        ("llama", "llama"),
    )
}

# ─── STEP 4: Cost-per-token dictionaries ────────────────────────────────────────
# One row per arm: (output $/token, input $/token, LLM name understood by the
# token-length regressor). The three public dicts below are projections of it
# and keep this row order.
_ARM_PRICING = {
    "base"            : (0.000004,   0.0000001,  "gpt-3.5-turbo"),
    "assistants"      : (0.000004,   0.0000001,  "gpt-3.5-turbo"),
    "finetune_med"    : (0.00001,    0.00000025, "gpt-4"),
    "finetune_tele"   : (0.00001,    0.00000025, "gpt-4"),
    "finetune_med_new": (0.00001,    0.00000025, "gpt-4"),
    "llama"           : (0.00000071, 0.00000071, "llama-13b"),
}

cost_per_token = {arm: row[0] for arm, row in _ARM_PRICING.items()}
input_cost_per_token = {arm: row[1] for arm, row in _ARM_PRICING.items()}
arm_to_llm = {arm: row[2] for arm, row in _ARM_PRICING.items()}
    
    
    

def save_bandit_checkpoint(run_id, round_num, u_s, u_d, u_e,
                           regrets, rewards, costs,
                           plays_triplet,
                           avg_acc,
                           all_rewards_diag,
                           cum_reg,
                           tot_r,
                           tot_c,
                           is_final=False):
    """Persist the three stage learners plus run metrics with ``torch.save``.

    Writes ``bandit_checkpoints_neucb_joint_tele/run_{run_id}_final.pth`` when
    ``is_final`` is true, otherwise ``.../run_{run_id}_round_{round_num}.pth``.
    The directory is created if needed.

    Args:
        run_id: index of the independent run.
        round_num: round the snapshot was taken after.
        u_s, u_d, u_e: learners exposing ``net``, ``U``, ``contexts`` and ``rewards``.
        regrets, rewards, costs, plays_triplet, avg_acc, all_rewards_diag,
        cum_reg, tot_r, tot_c: metric objects stored verbatim under the same keys.
        is_final: selects the "final" file name instead of the per-round one.

    Returns:
        The path of the written file.
    """
    def learner_state(learner):
        # Everything needed to rebuild a learner apart from its constructor args.
        return {
            'net_state_dict': learner.net.state_dict(),
            'U': learner.U,
            'contexts': learner.contexts,
            'rewards': learner.rewards,
        }

    out_dir = "bandit_checkpoints_neucb_joint_tele"
    os.makedirs(out_dir, exist_ok=True)

    suffix = "final" if is_final else f"round_{round_num}"
    target = os.path.join(out_dir, f"run_{run_id}_{suffix}.pth")

    payload = {'run_id': run_id, 'round_num': round_num}
    for key, learner in (('u_s_state', u_s), ('u_d_state', u_d), ('u_e_state', u_e)):
        payload[key] = learner_state(learner)
    # Metrics
    payload.update(
        regrets=regrets,
        rewards=rewards,
        costs=costs,
        plays_triplet=plays_triplet,
        avg_acc=avg_acc,
        all_rewards_diag=all_rewards_diag,
        cum_reg=cum_reg,
        tot_r=tot_r,
        tot_c=tot_c,
    )

    torch.save(payload, target)
    print(f"  Saved bandit checkpoint: {target}")
    return target


def load_bandit_checkpoint(filepath, dim, lamdba, nu, hidden, bandit_style=None):
    """Rebuild the three stage learners from a file written by ``save_bandit_checkpoint``.

    Args:
        filepath: path to the ``.pth`` checkpoint.
        dim, lamdba, nu, hidden: NeuralUCBDiag constructor arguments. They must
            match the architecture that was saved, or ``load_state_dict`` fails.
        bandit_style: 'ucb' / 'ts' style for the rebuilt learners. The checkpoint
            does not store it. Defaults to the module-level ``style`` parsed
            from ``--style``.

    Returns:
        ``(u_s, u_d, u_e, checkpoint)``: the restored summarizer, diagnoser and
        explainer learners, and the raw checkpoint dict (metrics included).
    """
    if bandit_style is None:
        bandit_style = style

    # The checkpoint is more than tensors: plays_triplet is a numpy array, and
    # the metrics are Python lists/dicts. torch.load's weights_only=True default
    # (torch >= 2.6) rejects numpy arrays, so ask for full unpickling explicitly.
    # SECURITY: this executes pickle code. Only load checkpoints this script wrote.
    checkpoint = torch.load(filepath, map_location=device, weights_only=False)

    learners = []
    for key in ('u_s_state', 'u_d_state', 'u_e_state'):
        learner = NeuralUCBDiag(bandit_style, dim, lamdba, nu, hidden)
        state = checkpoint[key]
        learner.net.load_state_dict(state['net_state_dict'])
        learner.U = state['U']
        learner.contexts = state['contexts']
        learner.rewards = state['rewards']
        learners.append(learner)
    u_s, u_d, u_e = learners

    print(f"  Loaded bandit checkpoint from: {filepath}")
    print(f"     Run: {checkpoint['run_id']}, Round: {checkpoint['round_num']}")
    return u_s, u_d, u_e, checkpoint


def cleanup_bandit_checkpoints(run_id):
    """Delete one run's intermediate (per-round) bandit checkpoints.

    Removes files matching ``run_{run_id}_round_*.pth`` in
    ``bandit_checkpoints_neucb_joint_tele``. It leaves alone the run's
    ``run_{run_id}_final.pth`` and files of other runs, e.g. ``run_10_*`` when
    ``run_id`` is 1. A missing directory simply matches nothing.

    Returns:
        int: how many files were deleted. Existing callers ignore the value.
    """
    checkpoint_dir = "bandit_checkpoints_neucb_joint_tele"
    pattern = os.path.join(checkpoint_dir, f"run_{run_id}_round_*.pth")

    deleted_count = 0
    for cp_path in glob.glob(pattern):
        try:
            os.remove(cp_path)
        except OSError as e:
            # Best-effort cleanup: a file we cannot remove must not abort the run.
            print(f"  Warning: Could not delete {cp_path}: {e}")
        else:
            deleted_count += 1

    if deleted_count > 0:
        print(f"  Cleaned up {deleted_count} intermediate bandit checkpoint(s) for run {run_id}")
    return deleted_count










# ─── STEP 5: Token-length predictor ─────────────────────────────────────────────
from transformers import AutoConfig, AutoTokenizer
import json
from src.token_prediction.tok_length_predict import BertRegressionModel
import tiktoken

reg_model_name = "bert-base-uncased"
reg_config     = AutoConfig.from_pretrained(reg_model_name)
reg_tokenizer  = AutoTokenizer.from_pretrained(reg_model_name)
# LLM names the length regressor knows. The position of a name in this list is
# the one-hot slot the regressor expects for that LLM.
# Explicit UTF-8: open()'s default encoding is locale-dependent.
with open("src/token_prediction/model_names.json", encoding="utf-8") as f:
    orig_model_names = json.load(f)
num_models = len(orig_model_names)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# BERT-based regressor: (prompt tokens, one-hot LLM id) -> predicted output length.
token_length_model = BertRegressionModel(
    reg_config, reg_model_name,
    hidden_dim=128,
    num_models=num_models
).to(device)
token_length_model.load_state_dict(torch.load("best_length_model.pth", map_location=device))
token_length_model.eval()

# Refines the length regressor online from observed (predicted, actual) lengths.
online_trainer = OnlineCostModelTrainer(
    model=token_length_model,
    tokenizer=reg_tokenizer,
    orig_model_names=orig_model_names,
    arm_to_llm=arm_to_llm,
    cost_per_token=cost_per_token,
    input_cost_per_token=input_cost_per_token,
    device=device,
    checkpoint_dir = "cost_model_checkpoints_neucb_joint",
    lr=1e-6,
    update_freq=5
)
all_cost_model_summaries = []  # one {'overall', 'per_model'} dict appended per run
checkpoint_frequency = 20      # rounds between intermediate checkpoints




from sentence_transformers import SentenceTransformer
# Sentence embedder used by get_context to build bandit contexts.
inp_model = SentenceTransformer("paraphrase-MiniLM-L6-v2")

import tiktoken
openai_models = {"gpt-3.5-turbo", "gpt-4"}
encodings = { m: tiktoken.encoding_for_model(m) for m in openai_models }
from transformers import AutoTokenizer as HFTokenizer
try:
    llama_tok = HFTokenizer.from_pretrained("openlm-research/open_llama_13b")
except Exception as exc:
    # Best-effort fallback keeps the run going. Say so, because llama-arm token
    # counts (and therefore costs) become approximations from BERT's tokenizer.
    print(f"  Warning: could not load llama tokenizer ({exc}); "
          f"falling back to {reg_model_name} tokenizer for llama token counts")
    llama_tok = reg_tokenizer

# arm -> tokenizer used to count prompt/response tokens for cost accounting.
# OpenAI-backed arms use tiktoken. Everything else uses llama_tok.
arm_encoders = {
    mk: encodings.get(llm_name, llama_tok)
    for mk, llm_name in arm_to_llm.items()
}

# ─── STEP 6: NeuralUCB Bandit ──────────────────────────────────────────────────
class NeuralUCBDiag:
    """Neural UCB / Thompson-sampling learner with a diagonal design matrix.

    A two-layer MLP maps a context vector to a reward estimate ``mu``. The
    exploration width ``sigma`` comes from the gradient of ``mu`` with respect
    to every network parameter, scaled elementwise by ``1 / U``. ``U`` is a
    running diagonal sum of squared gradients, initialised to ``lamdba``.
    """

    def __init__(self, style, dim, lamdba=1, nu=1, hidden=100):
        """
        Args:
            style: 'ucb' for optimistic scores; any other value means Thompson
                sampling. Used as the default style of :meth:`selection`.
            dim: context dimensionality (network input size).
            lamdba: weight decay for training and initial value of ``U``.
            nu: exploration scale applied to ``sigma``.
            hidden: hidden-layer width.
        """
        self.device = device
        self.net    = extend(nn.Sequential(nn.Linear(dim, hidden), nn.ReLU(), nn.Linear(hidden,1)).to(self.device))
        self.lamdba = lamdba
        self.nu     = nu
        p_count     = sum(p.numel() for p in self.net.parameters())
        self.U      = lamdba * torch.ones(p_count, device=self.device)
        self.contexts = []
        self.rewards  = []
        self.style    = style

    def selection(self, context, style=None, *, update_U=True):
        """Score one context and, by default, fold its gradient into ``U``.

        Args:
            context: 1-D numpy array of length ``dim``.
            style: 'ucb' -> ``0.2*mu + 2*sigma``; otherwise a draw from
                ``Normal(mu, 0.05*sigma)``. ``None`` uses ``self.style``.
            update_U: when False, ``U`` is left untouched. This lets callers
                score many candidate arms and commit only the played one, as
                NeuralUCB prescribes.

        Returns:
            float score.
        """
        if style is None:
            style = self.style
        x = torch.from_numpy(context).float().to(self.device).unsqueeze(0)
        mu = self.net(x)
        self.net.zero_grad()
        mu.backward()  # graph is not reused, so no retain_graph needed
        grads = torch.cat([p.grad.detach().flatten() for p in self.net.parameters()])
        sigma = torch.sqrt(torch.sum(self.lamdba * self.nu * grads * grads / self.U))
        if style == 'ucb':
            score = 0.2 * mu.item() + 2 * sigma.item()
        else:
            score = torch.normal(mu.detach().view(-1), 0.05 * sigma.view(-1)).item()
        if update_U:
            self.U += grads * grads
        return score

    def train(self, context, reward, max_steps=5):
        """Store ``(context, reward)`` and take SGD steps on the newest samples.

        One squared-error SGD step is taken per sample, over the most recent
        ``max_steps`` stored samples (oldest first). The earlier loop stopped
        after the first five samples ever stored, so nothing observed after the
        fifth call was ever learned.

        Returns:
            Mean pre-step loss over the steps taken (0 if none were taken).
        """
        c = torch.from_numpy(context).float().to(self.device).unsqueeze(0)
        self.contexts.append(c)
        self.rewards.append(float(reward))
        optimizer = optim.SGD(self.net.parameters(), lr=1e-4, weight_decay=self.lamdba)

        start = max(0, len(self.contexts) - max(0, max_steps))
        tot_loss = 0
        cnt = 0
        for ctx, r in zip(self.contexts[start:], self.rewards[start:]):
            optimizer.zero_grad()
            pred = self.net(ctx).view(-1)[0]
            loss = (pred - r) ** 2
            loss.backward()
            optimizer.step()
            tot_loss += loss.item()
            cnt += 1
        return tot_loss / cnt if cnt > 0 else 0

# ─── STEP 7: Args ─────────────────────────────────────────────────────────────
parser = argparse.ArgumentParser()
parser.add_argument('--size', default=150, type=int, help='number of rounds')
parser.add_argument('--nu', type=float, default=1, metavar='v', help='nu for control variance')
parser.add_argument('--lamdba', type=float, default=1, metavar='l', help='lambda for regularization')
parser.add_argument('--hidden', type=int, default=50, help='network hidden size')
parser.add_argument('--style', default='ts', metavar='ts|ucb', help='TS or UCB')
parser.add_argument('--number_tasks', default=3, type=int, help='number of subtasks')
parser.add_argument('--no_runs', default=3, type=int, help='how many independent runs')
# float, not int: per-token prices (cost_per_token) are tiny, so useful
# trade-off weights are often fractional. Integer inputs still parse.
parser.add_argument('--alpha', default=10, type=float, help='cost accuracy tradeoff weight')
args = parser.parse_args()

# Each round consumes one question (input_reports[t]). Fail fast here rather
# than with an IndexError halfway through a run. At least one run is also
# needed for the final statistics.
if not 1 <= args.size <= len(input_reports):
    parser.error(f"--size must be between 1 and {len(input_reports)} "
                 f"(number of available questions); got {args.size}")
if args.no_runs < 1:
    parser.error(f"--no_runs must be at least 1; got {args.no_runs}")

size, nu, lamdba, hidden, style, number_tasks, no_runs, alpha = (
    args.size, args.nu, args.lamdba, args.hidden, args.style,
    args.number_tasks, args.no_runs, args.alpha
)
num_rounds = size

# ─── STEP 8: Prepare models & containers ───────────────────────────────────────
# Arm lists follow each deployment dict's key order. super_arms enumerates every
# (summarizer, diagnoser, explainer) combination in nested-loop order.
models_summarizer = [*deployments_summarizer]
models_diagnoser  = [*deployments_diagnoser]
models_explainer  = [*deployments_explainer]

all_regrets = []
all_rewards = []
all_costs = []

super_arms = [
    (summ, diag, expl)
    for summ in models_summarizer
    for diag in models_diagnoser
    for expl in models_explainer
]
num_triplets = len(super_arms)

all_plays = np.zeros((args.no_runs, num_triplets))  # play counts: run x triplet
all_avg_arrays = []


# ─── RUN SIMULATIONS ───────────────────────────────────────────────────────────
def _count_tokens(encoder, text):
    """Return how many tokens *encoder* produces for *text*.

    Objects with ``.encode`` (e.g. the tiktoken encodings) are asked directly.
    Anything else is called like a Hugging Face tokenizer, and the length of its
    ``input_ids`` is used.
    NOTE(review): HF tokenizers also define ``.encode``, so the second branch
    may never run in practice. Confirm.
    """
    if hasattr(encoder, "encode"):
        return len(encoder.encode(text))
    return len(encoder(text, truncation=True)["input_ids"])


def _predict_output_length(prompt, arm):
    """Predict the output token count of *arm*'s backing LLM on *prompt*."""
    onehot = torch.zeros(num_models, device=device)
    onehot[orig_model_names.index(arm_to_llm[arm])] = 1.0
    toks = reg_tokenizer(
        text=prompt,
        truncation=True,
        padding="max_length",
        max_length=256,
        return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        return token_length_model(
            toks["input_ids"],
            toks["attention_mask"],
            onehot.unsqueeze(0)
        ).item()


def _score_without_U_update(learner, context):
    """Return ``learner.selection(context)`` while leaving ``learner.U`` unchanged.

    NeuralUCB grows U only with the gradient of the arm that is actually
    played. ``NeuralUCBDiag.selection`` always adds to U, so U is snapshotted and
    restored around the scoring call. The played arm is committed separately
    after the choice is made.
    """
    saved_U = learner.U.clone()
    score = learner.selection(context, args.style)
    learner.U = saved_U
    return float(score)


def _stage_estimates(learner, arms, task_id, t, prompt):
    """Compute each arm's context, predicted output length, predicted cost and score for one stage.

    All of these depend only on (round, stage, arm), so each arm is evaluated
    once per round rather than once per triplet that contains it. *prompt* is
    used both for the input-token count and for the output-length prediction.
    """
    estimates = {}
    for j, arm in enumerate(arms):
        ctx = get_context(documents, t, task_id, j, len(summary_description_array),
                          len(diagnosis_description_array), 0, inp_model, dataset)
        in_len = _count_tokens(arm_encoders[arm], prompt)
        out_len = _predict_output_length(prompt, arm)
        estimates[arm] = {
            "ctx": ctx,
            "out_len": out_len,
            "cost": input_cost_per_token[arm] * in_len + cost_per_token[arm] * out_len,
            "score": _score_without_U_update(learner, ctx),
        }
    return estimates


def _zscore(value, history):
    """Standardise *value* by *history*'s mean/std; return it unchanged if history is empty or has zero std."""
    arr = np.asarray(history, dtype=float)
    if arr.size == 0:
        return value
    std = arr.std()
    if std == 0:
        return value
    return (value - arr.mean()) / std


for run in range(args.no_runs):
    print(f"=== Run {run+1}/{args.no_runs} ===")
    # One learner per stage, each scoring 384-d contexts for its own sub-task.
    u_s = NeuralUCBDiag(args.style, 384, args.lamdba, args.nu, args.hidden)
    u_d = NeuralUCBDiag(args.style, 384, args.lamdba, args.nu, args.hidden)
    u_e = NeuralUCBDiag(args.style, 384, args.lamdba, args.nu, args.hidden)
    actual_total_cost = 0

    plays_triplet = np.zeros(num_triplets, dtype=int)
    regrets, rewards, costs = [], [], []
    tot_reward = 0; cum_reg = 0
    avg_array = {"gpt-35-turbo":0,"Med":0,"Tele":0,"Med_New":0,"llama":0}

    all_rewards_sum = []
    all_rewards_diag = []
    all_rewards_expl = []

    for t in range(args.size):
        print(f"Round {t+1}")
        question = input_reports[t]

        # Approximate downstream prompts, used only for pre-selection cost estimates.
        diag_prompt_approx = (
            question
            + " Please give the correct option in the format: option [correct option number]."
        ).replace("\n", "")
        expl_prompt_approx = (
            question + " Answer chosen: option [x]."
        ).replace("\n", "")

        est_s = _stage_estimates(u_s, models_summarizer, 0, t, question)
        est_d = _stage_estimates(u_d, models_diagnoser, 1, t, diag_prompt_approx)
        est_e = _stage_estimates(u_e, models_explainer, 2, t, expl_prompt_approx)

        # Triplet value = sum of stage scores - alpha * predicted pipeline cost.
        triplet_scores = [
            (est_s[s]["score"] + est_d[d]["score"] + est_e[e]["score"])
            - args.alpha * (est_s[s]["cost"] + est_d[d]["cost"] + est_e[e]["cost"])
            for (s, d, e) in super_arms
        ]
        best_idx = int(np.argmax(triplet_scores)) if args.style=='ucb' else \
           int(np.random.choice(np.flatnonzero(np.array(triplet_scores)==np.max(triplet_scores))))
        s_arm, d_arm, e_arm = super_arms[best_idx]
        plays_triplet[best_idx] += 1
        print(f"[Round {t+1}] Selected triplet -> "
              f"Summarizer: {s_arm} | Diagnoser: {d_arm} | Explainer: {e_arm}")

        # Commit only the played arms' gradients to U (returned scores are discarded).
        u_s.selection(est_s[s_arm]["ctx"], args.style)
        u_d.selection(est_d[d_arm]["ctx"], args.style)
        u_e.selection(est_e[e_arm]["ctx"], args.style)

        # ---------- Summarization stage ----------
        summary = get_summary(question, s_arm, "tele")
        summary_clean = summary.replace("\n","")
        summary_text = summary if isinstance(summary, str) else str(summary)
        enc_s = arm_encoders[s_arm]
        in_len_sum_actual = _count_tokens(enc_s, question)
        out_len_sum_actual = _count_tokens(enc_s, summary_text)

        # Evaluate summary quality with GPT-4o
        deployment_sum = deployments_summarizer[s_arm]
        reward_sum, _, all_rewards_sum, _ = final_eval_telecom(
            deployment_sum,
            s_arm,
            summary,
            'summary',
            all_rewards_sum,
            all_rewards_diag,
            summary,
            question
        )
        # NOTE(review): all_rewards_sum is passed to and returned by
        # final_eval_telecom and then appended to here. If final_eval_telecom
        # also appends, each summary reward enters the z-score history twice.
        # Confirm.
        all_rewards_sum.append(reward_sum)

        online_trainer.add_observation(
            prompt=question,
            model_name=s_arm,
            actual_output_length=out_len_sum_actual,
            predicted_length=est_s[s_arm]["out_len"]
        )

        # ---------- Diagnosis stage ----------
        prompt_d = summary_clean + " Please give the correct option in the format: option [correct option number]."
        reg1, reward1, out_len_diag_actual, avg_array, _, _ = opt_eval(
            deployments_diagnoser, prompt_d, "diagnosis",
            d_arm, avg_array, t, [], [], labels, dataset
        )
        all_rewards_diag.append(reward1)
        in_len_diag_actual = _count_tokens(arm_encoders[d_arm], prompt_d)

        # ---------- Explanation stage ----------
        # NOTE(review): reward1 is the diagnosis *reward* from opt_eval, not the
        # option the diagnoser chose, so the explainer is told "option <reward>".
        # Confirm whether opt_eval exposes the chosen option.
        answer_text = f"option {reward1}" if isinstance(reward1,(int,str)) else "option 1"
        prompt_e = question + " Answer chosen: " + str(answer_text)
        reg2, reward2, out_len_exp_actual, avg_array, _, _ = opt_eval(
            deployments_explainer, prompt_e, "explanation",
            e_arm, avg_array, t, [], [], explanations, dataset
        )
        all_rewards_expl.append(reward2)

        online_trainer.add_observation(
            prompt=prompt_d,
            model_name=d_arm,
            actual_output_length=out_len_diag_actual,
            predicted_length=est_d[d_arm]["out_len"]
        )
        online_trainer.add_observation(
            prompt=prompt_e,
            model_name=e_arm,
            actual_output_length=out_len_exp_actual,
            predicted_length=est_e[e_arm]["out_len"]
        )

        in_len_exp_actual = _count_tokens(arm_encoders[e_arm], prompt_e)

        # Realised spend of this round (input + output tokens for all three stages).
        round_cost = (
            input_cost_per_token[s_arm]*in_len_sum_actual + cost_per_token[s_arm]*out_len_sum_actual +
            input_cost_per_token[d_arm]*in_len_diag_actual + cost_per_token[d_arm]*out_len_diag_actual +
            input_cost_per_token[e_arm]*in_len_exp_actual + cost_per_token[e_arm]*out_len_exp_actual
        )
        actual_total_cost += round_cost

        # Update metrics: regrets/rewards are cumulative, costs are per round.
        cum_reg += (reg1 + reg2)
        tot_reward += (reward1 + reward2)
        regrets.append(cum_reg)
        rewards.append(tot_reward)
        costs.append(round_cost)

        # Training targets are z-scored against each stage's reward history.
        # Summary rewards are first rescaled from [0, 100] to [0, 1].
        target_s = _zscore(reward_sum / 100.0, np.asarray(all_rewards_sum, dtype=float) / 100.0)
        target_d = _zscore(float(reward1), all_rewards_diag)
        target_e = _zscore(float(reward2), all_rewards_expl)

        # Train on the chosen super-arm's contexts
        u_s.train(est_s[s_arm]["ctx"], target_s)
        u_d.train(est_d[d_arm]["ctx"], target_d)
        u_e.train(est_e[e_arm]["ctx"], target_e)

        if (t + 1) % checkpoint_frequency == 0:
            online_trainer.save_checkpoint(
                run_id=run,
                round_num=t+1,
                is_final=False
            )

            save_bandit_checkpoint(
                run_id=run,
                round_num=t+1,
                u_s=u_s, u_d=u_d, u_e=u_e,
                regrets=regrets,
                rewards=rewards,
                costs=costs,
                plays_triplet=plays_triplet,
                avg_acc=avg_array,
                all_rewards_diag=all_rewards_diag,
                cum_reg=cum_reg,
                tot_r=tot_reward,
                tot_c=actual_total_cost,
                is_final=False
            )

        # Print cost model metrics every 10 rounds
        if (t + 1) % 10 == 0:
            print(f"Round {t+1}: Total reward - {tot_reward}, Total cost - {actual_total_cost}")
            metrics = online_trainer.compute_metrics(round_num=t+1)
            if metrics:
                print(f"  [Round {t+1}] Cost Model - MAE: {metrics['mae']:.2f}, "
                    f"MAPE: {metrics['mape']:.2f}%, R²: {metrics['r2']:.4f}")

    all_regrets.append(regrets)
    all_rewards.append(rewards)
    all_costs.append(costs)
    all_plays[run,:] = plays_triplet
    all_avg_arrays.append(avg_array.copy())
    print(f"Run {run+1} complete: Final reward = {tot_reward:.4f}")

    # Save final bandit checkpoint
    save_bandit_checkpoint(
        run_id=run,
        round_num=args.size,
        u_s=u_s, u_d=u_d, u_e=u_e,
        regrets=regrets,
        rewards=rewards,
        costs=costs,
        plays_triplet=plays_triplet,
        avg_acc=avg_array,
        all_rewards_diag=all_rewards_diag,
        cum_reg=cum_reg,
        tot_r=tot_reward,
        tot_c=actual_total_cost,
        is_final=True
    )

    cleanup_bandit_checkpoints(run_id=run)

    # Cost model final update + save
    online_trainer.final_update()
    online_trainer.save_checkpoint(run_id=run, round_num=args.size, is_final=True)
    online_trainer.cleanup_intermediate_checkpoints(run_id=run)
    online_trainer.print_summary()
    online_trainer.save_stats(f"tele_results/cost_model_run_{run}_neucb_joint")

    run_cost_summary = {
        'overall': online_trainer.compute_metrics(),
        'per_model': online_trainer.compute_per_model_metrics()
    }
    all_cost_model_summaries.append(run_cost_summary)


import pandas as pd
# Per-deployment tallies returned by opt_eval (presumably accuracy counts).
# Mean/std are computed, but their pickles are currently not written.
avg_df   = pd.DataFrame(all_avg_arrays)
avg_mean = avg_df.mean(axis=0).to_dict()
avg_std  = avg_df.std(axis=0).to_dict()
plays_mean = all_plays.mean(axis=0)

# ─── STEP 9: Save metrics ───────────────────────────────────────────────────────
regrets_mean = np.mean(all_regrets, axis=0)
rewards_mean = np.mean(all_rewards, axis=0)
costs_mean   = np.mean(all_costs, axis=0)

os.makedirs("tele_results", exist_ok=True)
metric_pickles = [
    ("regrets_mean_neucb_budgeted_joint_tele_1.pkl", regrets_mean),
    ("regrets_std_neucb_budgeted_joint_tele_1.pkl",  np.std(all_regrets, axis=0)),
    ("rewards_mean_neucb_budgeted_joint_tele_1.pkl", rewards_mean),
    ("rewards_std_neucb_budgeted_joint_tele_1.pkl",  np.std(all_rewards, axis=0)),
    ("costs_mean_neucb_budgeted_joint_tele_1.pkl",   costs_mean),
    ("costs_std_neucb_budgeted_joint_tele_1.pkl",    np.std(all_costs, axis=0)),
    ("plays_neucb_joint_budgeted_tele_1.pkl",        plays_mean),
    ("tele_results/cost_model_all_runs_neucb_joint.pkl", all_cost_model_summaries),
]
# `with` closes each file deterministically instead of leaking the handle.
for out_path, payload in metric_pickles:
    with open(out_path, "wb") as fh:
        pickle.dump(payload, fh)

print(f"Final mean regret: {regrets_mean[-1]}")
print(f"Final mean reward: {rewards_mean[-1]}")
# costs[r] is the spend of round r alone, while regrets/rewards are cumulative.
# So report the last round's cost and the whole-run total separately.
print(f"Final mean cost (last round): {costs_mean[-1]}")
print(f"Final mean total cost per run: {np.mean(np.sum(all_costs, axis=1))}")
print(f"Final mean plays: {plays_mean}")

print("All runs complete. Summary pickles written.")
