# -*- coding: utf-8 -*-
"""
Created on Sun Apr 27 17:10:56 2025

@author: baran
"""

import numpy as np
from prompt_maker import input_maker
#import get_regret
from optimal_rand_tele import opt_eval

import scipy as sp
import torch
import torch.nn as nn
import torch.optim as optim
from backpack import backpack, extend
from backpack.extensions import BatchGrad
import argparse
from embed_tele import get_context


from transformers import AutoConfig, AutoTokenizer

# import your regression‐model class
from tok_length_predict import BertRegressionModel  
def get_regret(
    deployments,
    prompt,
    task,
    selected,
    avg_array,
    t,
    all_rewards_sum,
    all_rewards_diag,
    labels,
    dataset,
):
    """Forward every argument, unchanged and in the same order, to ``opt_eval``.

    Returns whatever ``opt_eval`` returns. The caller in this file unpacks the
    result as ``(regret, reward, out_len, avg_array, all_rewards_sum,
    all_rewards_diag)``.
    """
    forwarded = (
        deployments, prompt, task, selected, avg_array,
        t, all_rewards_sum, all_rewards_diag, labels, dataset,
    )
    return opt_eval(*forwarded)
# …later in main(), after you set up device, before the bandit loop:
# ------------------------------------------------------------------

# 1. load tokenizer + config for your saved regression model
# reg_model_name = "bert-base-uncased"        # whatever you used
# reg_config     = AutoConfig.from_pretrained(reg_model_name)
# reg_tokenizer  = AutoTokenizer.from_pretrained(reg_model_name)

# # 2. instantiate & load state_dict
# device = 'cuda'
# token_length_model = BertRegressionModel(reg_config, reg_model_name).to(device)
# token_length_model.load_state_dict(
#     torch.load("bert_regression_model.pth", map_location=device)
# )
# token_length_model.eval()
#from transformers import AutoConfig, AutoTokenizer
#from token_length_predict import BertRegressionModel

# 1) load your length predictor once
# Tokenizer/config of the backbone that the regression head was built on.
reg_model_name  = "bert-base-uncased"
reg_config      = AutoConfig.from_pretrained(reg_model_name)
reg_tokenizer   = AutoTokenizer.from_pretrained(reg_model_name)
import json
# The position of a name in this list is the index that gets set in the
# one-hot model vector fed to the length predictor in the bandit loop.
with open("model_names.json") as f:
    orig_model_names = json.load(f)   # e.g. ["RWKV-4-Raven-14B","alpaca-13b",…,"gpt-4",…,"gpt-3.5-turbo",…,"llama-2-7b-chat",…]
num_models = len(orig_model_names)    
device = "cuda" if torch.cuda.is_available() else "cpu"
# Build the regressor with the same num_models as the name list, then load
# the saved weights onto `device`.
token_length_model = BertRegressionModel(reg_config, reg_model_name,hidden_dim=128,num_models=num_models).to(device)
token_length_model.load_state_dict(
    torch.load("best_length_model.pth", map_location=device)
)
# Inference only: the model is never trained in this script.
token_length_model.eval()

# 2) placeholder price per generated (output) token for each of the 5 arms;
#    multiplied by the output length when accumulating cost in the bandit loop
cost_per_token = dict(
    base=1.5e-6,              # GPT-3.5 Turbo
    finetune_med=1e-5,        # GPT-4
    finetune_tele=1e-5,       # GPT-4
    finetune_med_new=1e-5,    # GPT-4
    llama=7.1e-7,             # Llama-13b
)


class Network(nn.Module):
    """Two-layer ReLU MLP that maps a ``dim``-sized input to a single output.

    The layers are stored in ``self.model`` (an ``nn.Sequential``), so existing
    code that reaches for ``Network(...).model`` keeps working and parameter
    names are unchanged.

    Args:
        dim: size of the input feature vector.
        hidden_size: width of the single hidden layer.
    """

    def __init__(self, dim = 100, hidden_size=100):
        super(Network, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, x):
        """Apply the MLP to ``x``.

        Without this method, calling a ``Network`` instance directly raises
        ``NotImplementedError`` from ``nn.Module``.
        """
        return self.model(x)

class NeuralUCBDiag:
    """Neural bandit scorer for one arm with a diagonal gradient statistic.

    A small MLP predicts the reward of a context. ``self.U`` is a vector with
    one entry per network parameter; it starts at ``lamdba`` and accumulates
    squared gradients, and is used to scale the exploration term.

    Args:
        style: stored on the instance; the scoring rule itself is chosen by
            the ``style`` argument of :meth:`selection`.
        dim: input (context) dimension.
        lamdba: initial value of every entry of ``U``; also the SGD weight
            decay used in :meth:`train`.
        nu: multiplier on the exploration variance.
        hidden: hidden-layer width of the MLP.
    """

    def __init__(self, style, dim, lamdba=1, nu=1, hidden=100):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Use the bare nn.Sequential on every device: it is directly callable,
        # whereas the Network wrapper is only callable if it defines forward().
        self.func = extend(Network(dim, hidden_size=hidden).model.to(self.device))
        self.context_list = []
        self.reward = []
        self.lamdba = lamdba
        self.total_param = sum(p.numel() for p in self.func.parameters() if p.requires_grad)
        self.U = lamdba * torch.ones((self.total_param,), device=self.device)
        self.U_random = lamdba * torch.ones((self.total_param,), device=self.device)
        self.nu = nu
        self.num_rounds = 100
        self.style = style
        self.loss_func = extend(nn.MSELoss().to(self.device))
        self.len = 0

    def update_params(self,g_list):
        """Add the element-wise square of every gradient in ``g_list`` to ``U``.

        Always returns 0.
        """
        for g in g_list:
            self.U += g * g
        return 0

    def selection(self,context,style):
        """Score ``context`` and fold its gradient into ``U``.

        Args:
            context: array-like features for one arm; converted with
                ``np.array`` and cast to float32.
            style: ``"ucb"`` returns ``mu + 0.5 * sigma`` as a Python float;
                any other value returns a tensor sampled from
                ``Normal(1.5 * mu, 0.2 * sigma)``.

        Returns:
            The score described above.
        """
        # Follow self.device instead of calling .cuda() so CPU-only hosts work.
        tensor = torch.from_numpy(np.array(context)).float().to(self.device)
        mu = self.func(tensor)
        self.func.zero_grad()
        mu.backward(retain_graph=True)
        # Flattened gradient of the prediction w.r.t. all parameters.
        g = torch.cat([p.grad.flatten().detach() for p in self.func.parameters()])
        sigma2 = self.lamdba * self.nu * g * g / self.U
        sigma = torch.sqrt(torch.sum(sigma2))
        if style == "ucb":
            sample_r = mu.item() +0.5* sigma.item()
        else:
            sample_r = torch.normal(1.5*mu.view(-1), 0.2*sigma.view(-1))
        # U is updated for every scored context, not only for the chosen arm.
        self.U += g * g
        return sample_r

    def train(self, context, reward):
        """Store ``(context, reward)`` and take a few SGD steps over the history.

        Samples are visited one at a time in a shuffled order that is fixed for
        the whole call. Training stops after 5 gradient steps in total, which
        returns ``tot_loss / 5``. It stops earlier if a full pass over the
        history has a mean squared error of at most 1e-3, which returns that
        mean.

        Args:
            context: numpy array; reshaped to ``(1, -1)`` before being stored.
            reward: scalar target, converted with ``float``.

        Returns:
            A float loss value as described above.
        """
        self.context_list.append(torch.from_numpy(context.reshape(1, -1)).float())
        self.reward.append(reward)
        optimizer = optim.SGD(self.func.parameters(), lr=1e-2, weight_decay=self.lamdba)
        length = len(self.reward)
        index = np.arange(length)
        np.random.shuffle(index)
        max_steps = 5
        cnt = 0
        tot_loss = 0
        while True:
            batch_loss = 0
            for idx in index:
                c = self.context_list[idx]
                r = self.reward[idx]
                optimizer.zero_grad()
                delta = self.func(c.to(self.device)) - torch.tensor(float(r))
                loss = delta * delta
                loss.backward()
                optimizer.step()
                batch_loss += loss.item()
                tot_loss += loss.item()
                cnt += 1
                if cnt >= max_steps:
                    return tot_loss / max_steps
            if batch_loss / length <= 1e-3:
                return batch_loss / length

# Command-line options for the experiment. parse_known_args() is used so that
# unrecognised arguments are ignored instead of aborting the script.
parser = argparse.ArgumentParser(description='NeuralUCB')

parser.add_argument('--size', default=250, type=int, help='number of rounds')
parser.add_argument('--nu', type=float, default=1, metavar='v', help='nu for control variance')
parser.add_argument('--lamdba', type=float, default=1, metavar='l', help='lambda for regularzation')
parser.add_argument('--hidden', type=int, default=50, help='network hidden size')
parser.add_argument('--style', default='ts', metavar='ts|ucb', help='TS or UCB')
# type=int keeps a CLI-supplied value consistent with the integer default;
# without it a value passed on the command line would arrive as a str.
parser.add_argument('--number_tasks', default=2, type=int, help='number of tasks')
parser.add_argument('--no_runs',   default=5,   type=int, help='how many independent runs')
parser.add_argument('--alpha',   default=125,   type=int, help='cost accuracy tradeoff weight')


args, unknown = parser.parse_known_args()
no_tasks = args.number_tasks
no_runs = args.no_runs
alpha = args.alpha

# Per-run result collectors; each run of the experiment appends one entry:
#   all_regrets / all_rewards / all_costs -> cumulative series over the rounds
#   all_plays                             -> per-arm play counts
#   all_avg_arrays                        -> the run's avg_array dict
# The generator yields five distinct empty lists.
all_regrets, all_rewards, all_costs, all_plays, all_avg_arrays = (
    [] for _ in range(5)
)
class DynamicVariables:
    """Empty attribute bag; the run loop attaches per-arm objects via ``setattr``."""

# Number of bandit rounds per run, taken from --size.
num_rounds = args.size
# One natural-language description per arm. Index j of this list is paired with
# the bandit `l_j` and with models[j] in the run loop below.
diagnosis_description_array = ["General use LLM which does not specialize in any task specifically.","LLM specializing on medical reports and trained to do medical diagnosis for research purposes.",
                               "LLM specializing for answering multiple choice telecommunications questions.", "LLM specializing on medical reports and trained to do medical diagnosis for research purposes.", "General use LLM which does not specialize in any task specifically."]#diagnosis_description_array = ["You will take the role of an medical diagnosis agent whose primary goal is to give diagnosis based on medical reports. For experimentation purposes only.",
# Arm key -> (deployment name, system prompt). The deployment names are the
# same strings used as keys of `avg_array` in the run loop. The prompt of the
# "llama" entry is worded slightly differently ("as there are 4 options").
deployments_1 = {"base" : ("gpt-35-turbo","You are to answer multiple choice questions related to telecommunications. Output your answer strictly as option {i} where i is between 1-4 if there are 4 options for each question. Do not output an answer like 4 to indicate option 4."), "finetune_med" : ("Med","You are to answer multiple choice questions related to telecommunications. Output your answer strictly as option {i} where i is between 1-4 if there are 4 options for each question. Do not output an answer like 4 to indicate option 4."), "finetune_tele" : ("Tele","You are to answer multiple choice questions related to telecommunications. Output your answer strictly as option {i} where i is between 1-4 if there are 4 options for each question. Do not output an answer like 4 to indicate option 4."),"finetune_med_new" : ("Med_New","You are to answer multiple choice questions related to telecommunications. Output your answer strictly as option {i} where i is between 1-4 if there are 4 options for each question. Do not output an answer like 4 to indicate option 4."),"llama": ("llama","You are to answer multiple choice questions related to telecommunications. Output your answer strictly as option {i} where i is between 1-4 as there are 4 options for each question. Do not output an answer like 4 to indicate option 4.")}

# NOTE(review): the [0:args.size] slice is applied to the value returned by
# input_maker, not to the reports themselves. If input_maker returns a 2-tuple
# (reports, labels), the slice is a no-op for size >= 2 and nothing is
# truncated — confirm the intent against input_maker.
input_reports,labels = input_maker('seq',"telecom",0)[0:args.size]
# 2) the 5 arms -> the LLM name used to look up a tokenizer and the row of the
#    length predictor's one-hot model vector
arm_to_llm = dict(
    base="gpt-3.5-turbo",
    finetune_med="gpt-4",
    finetune_tele="gpt-4",
    finetune_med_new="gpt-4",
    llama="llama-13b",
)

# Price per prompt (input) token for each arm; multiplied by the tokenized
# prompt length when accumulating cost in the bandit loop.
input_cost_per_token = dict(
    base=5e-7,
    finetune_med=2.5e-7,
    finetune_tele=2.5e-7,
    finetune_med_new=2.5e-7,
    llama=7.1e-7,
)

# Materialise the reports as a list so they can be concatenated and indexed.
input_reports = list(input_reports)
# Arm descriptions first, then the prompts; this combined list is what
# get_context receives in the bandit loop.
documents = diagnosis_description_array+input_reports
from sentence_transformers import SentenceTransformer
# Sentence encoder handed to get_context. The bandit networks below are built
# with emb_size = 384 — presumably this model's embedding width; confirm.
inp_model = SentenceTransformer("paraphrase-MiniLM-L6-v2")
# ── Bandit experiment ───────────────────────────────────────────────────────
# For each independent run: create one NeuralUCBDiag per arm, then for every
# round score all arms on the current prompt, pick one, observe reward / regret
# through get_regret, accumulate cost, and train the chosen arm's network.
import tiktoken

# Arm order used throughout: index j selects bandit `l_j`, models[j] and
# diagnosis_description_array[j]. It never changes, so it is defined once.
models = ["base","finetune_med","finetune_tele","finetune_med_new","llama"]

# Tokenizers are identical for every run, so they are loaded once here instead
# of being re-created at the start of each run.
openai_models = {"gpt-3.5-turbo","gpt-4"}
encodings = { m: tiktoken.encoding_for_model(m) for m in openai_models }

# for llama we use HuggingFace:
llama_tok = AutoTokenizer.from_pretrained("openlm-research/open_llama_13b")

# lookup arm key -> tiktoken encoder (OpenAI models) or the HF llama tokenizer
arm_encoders = {
    mk: encodings[llm] if llm in encodings else llama_tok
    for mk, llm in arm_to_llm.items()
}

for run in range(no_runs):
    print(f"\n===== Starting run {run+1}/{no_runs} =====")
    regrets = []        # cumulative regret after each round
    costs_list = []     # cumulative cost after each round
    rewards_list = []   # cumulative reward after each round
    dyn_vars = DynamicVariables()

    summ = 0            # running regret
    rewards = 0         # running reward
    costs = 0           # running cost
    emb_size = 384
    total_len = len(diagnosis_description_array)
    # NOTE(review): `i` keeps its final value (total_len - 1) after this loop and
    # is passed unchanged to get_context in every round below. Confirm that this
    # is what get_context expects for its third argument.
    for i in range(total_len):
        setattr(dyn_vars, f'l_{i}', NeuralUCBDiag(args.style,emb_size, args.lamdba, args.nu, args.hidden))
        setattr(dyn_vars, f'normalized_{i}', [])

    all_rewards_sum = []
    all_rewards_diag = []
    plays_no = np.ones(total_len)   # play counters start at 1 for every arm
    avg_array = {"gpt-35-turbo":0,"Med":0,"Tele":0,"Med_New":0,"llama":0}

    for t in range(num_rounds):
        context = []
        prompt_to_model = input_reports[t]
        task = 'summary'

        # Score every arm on this round's prompt. selection() also updates the
        # arm's gradient statistic as a side effect.
        values = []
        for j in range(total_len):
            l = getattr(dyn_vars, f'l_{j}')
            cont = get_context(documents,t,i,j,0,len(diagnosis_description_array),len(input_reports),inp_model,"telecom")
            context.append(cont)
            values.append(l.selection(cont,args.style))

        # Predicted output length per arm. Currently only consumed by the
        # disabled cost-penalty code below.
        toks = reg_tokenizer(
            prompt_to_model,
            truncation=True,
            padding="max_length",
            max_length=256,
            return_tensors="pt"
        ).to(device)

        pred_lengths = []
        for mk in models:                   # models = ["base",…,"llama"]
            llm_name = arm_to_llm[mk]       # e.g. "gpt-4"
            idx      = orig_model_names.index(llm_name)
            onehot   = torch.zeros(len(orig_model_names), device=device)
            onehot[idx] = 1.0
            onehot   = onehot.unsqueeze(0)  # add a leading batch dimension

            with torch.no_grad():
                pred = token_length_model(
                    toks["input_ids"],
                    toks["attention_mask"],
                    onehot
                )
            pred_lengths.append(pred.item())

        # Prompt length in tokens under each arm's own tokenizer.
        in_lengths = []
        for mk in models:
            enc = arm_encoders[mk]
            # NOTE(review): HuggingFace tokenizers presumably expose .encode as
            # well, in which case the else branch is never taken — confirm.
            if hasattr(enc, "encode"):
                # tiktoken Encoding
                in_len = len(enc.encode(prompt_to_model))
            else:
                # HuggingFace tokenizer
                in_len = len(enc(
                    prompt_to_model,
                    truncation=True,
                    padding=False
                )["input_ids"])
            in_lengths.append(in_len)

        values_f = [float(v) for v in values]
        print(f"Accuracy UCB: {values_f}")
        # Disabled cost-aware variant (penalises each arm's score by its
        # alpha-weighted predicted cost):
        # for i, mk in enumerate(models):
        #     values[i] = (
        #         values[i]- alpha*cost_per_token[mk]* pred_lengths[i]- alpha*input_cost_per_token[mk] * in_lengths[i]
        #     )
        #     cost.append(alpha*cost_per_token[mk]* pred_lengths[i] + alpha*input_cost_per_token[mk] * in_lengths[i])
        # print(f"Budget cost: {cost}")

        # selection() treats every style other than "ucb" as Thompson sampling,
        # so do the same here; previously any other style left `arm` unbound.
        if args.style == "ucb":
            arm = np.argmax(values)
        else:
            values_np = [v.cpu().detach().numpy() if isinstance(v, torch.Tensor) else v for v in values]
            # break ties between maximal scores uniformly at random
            arm = np.random.choice(np.where(np.array(values_np) == np.array(values_np).max())[0])

        plays_no[arm] += 1
        arm_select = models[arm]
        print(f"Selected arm: {arm_select}")

        reg,reward,out_len,avg_array,all_rewards_sum,all_rewards_diag = get_regret(deployments_1,prompt_to_model,task,arm_select,avg_array,t,all_rewards_sum,all_rewards_diag,labels,"telecom")
        # cost = input price * prompt tokens + output price * generated tokens
        costs += input_cost_per_token[arm_select]*in_lengths[arm]+ cost_per_token[arm_select]* out_len
        rewards += int(reward)
        rewards_list.append(rewards)
        costs_list.append(costs)
        summ+= reg
        regrets.append(summ)
        print(f"Reward: {reward}")
        print(f"Regret: {reg}")
        print(plays_no)
        l = getattr(dyn_vars, f'l_{arm}')

        # Train the chosen arm on a standardised reward.
        if t ==0:
            # Append the opposite binary value in the first round — presumably so
            # the standard deviation below is non-zero; this relies on get_regret
            # having already recorded the observed reward in all_rewards_diag.
            all_rewards_diag.append(1 if reward == 0 else 0)
            index = 0
        else:
            # first position whose stored value equals this round's reward
            index = all_rewards_diag.index(reward)
        new_rews = (all_rewards_diag-np.mean(all_rewards_diag))/np.std(all_rewards_diag)
        loss = l.train(context[arm], new_rews[index])

        if (t+1) % 5 == 0:
            print('{}: {:.3f}, {:.3f}, {:.3f}'.format(t+1, summ, rewards,loss))
    all_regrets.append(regrets)
    all_rewards.append(rewards_list)
    all_costs.append(costs_list)
    all_plays.append(plays_no.copy())
    all_avg_arrays.append(avg_array.copy())
# ── Now compute means & stds ──
# Aggregate across runs (axis 0) and write each summary to its own pickle.
import pandas as pd
regrets_arr = np.array(all_regrets)     # shape (no_runs, num_rounds)
rewards_arr = np.array(all_rewards)
costs_arr   = np.array(all_costs)
plays_arr   = np.array(all_plays)       # shape (no_runs, num_arms)
avg_df      = pd.DataFrame(all_avg_arrays)  # columns=model names

regrets_mean = regrets_arr.mean(axis=0)
regrets_std  = regrets_arr.std(axis=0)
rewards_mean = rewards_arr.mean(axis=0)
rewards_std  = rewards_arr.std(axis=0)
costs_mean   = costs_arr.mean(axis=0)
costs_std    = costs_arr.std(axis=0)
plays_mean = plays_arr.mean(axis=0)
plays_std  = plays_arr.std(axis=0)
avg_mean = avg_df.mean(axis=0).to_dict()
avg_std  = avg_df.std(axis=0).to_dict()
import pickle

# (object, output path) pairs. Each file is opened in a `with` block so the
# handle is flushed and closed, instead of being left to the garbage collector.
_summary_outputs = [
    (regrets_mean, "regrets_mean_tele_normal_seqgpt_2.pkl"),
    (regrets_std,  "regrets_std_tele_normal_seqgpt_2.pkl"),
    (rewards_mean, "rewards_mean_tele_normal_seqgpt_2.pkl"),
    (rewards_std,  "rewards_std_tele_normal_seqgpt_2.pkl"),
    (costs_mean,   "costs_mean_tele_normal_seqgpt_2.pkl"),
    (costs_std,    "costs_std_tele_normal_seqgpt_2.pkl"),
    (plays_mean,   "plays_mean_tele_normal_seqgpt_2.pkl"),
    (plays_std,    "plays_std_tele_normal_seqgpt_2.pkl"),
    (avg_mean,     "avg_accuracy_mean_tele_normal_seqgpt_2.pkl"),
    (avg_std,      "avg_accuracy_std_tele_normal_seqgpt_2.pkl"),
]
for _summary_obj, _summary_path in _summary_outputs:
    with open(_summary_path, "wb") as _summary_file:
        pickle.dump(_summary_obj, _summary_file)

print(f"Final mean regret: {regrets_mean[-1]}")
print(f"Final mean reward: {rewards_mean[-1]}")
print(f"Final mean cost: {costs_mean[-1]}")
print(f"Final mean plays: {plays_mean}")
print(f"Final mean average array: {avg_mean}")

print("All runs complete. Summary pickles written.")
# with open("regrets_seqgpt_tele_4_250_budget.pkl", "wb") as file:
#     pickle.dump(regrets, file)
# with open("rewards_seqgpt_tele_4_250_budget.pkl", "wb") as file:
#     pickle.dump(rewards_list, file)
# with open("costs_seqgpt_tele_4_250_budget.pkl", "wb") as file:
#     pickle.dump(costs_list, file)
# with open("arm_plays_seqgpt_tele_4_250_budget.pkl", "wb") as file:
#     pickle.dump(list(plays_no), file)
# with open("avg_array_acc_llms_4_tele_budget.pkl", "wb") as file:
#     pickle.dump(avg_array, file)