import time
import numpy as np
import pickle
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from backpack import extend
from prompt_maker import input_maker
#from src.embedding.embed_tele import get_context
from utils.helper import opt_eval,get_summary
#from src.regrets.sum_call_seq import get_summary
start_time = time.time()
# ─── STEP 1: Telecom dataset ────────────────────────────────────────────────────
input_reports, labels, explanations = input_maker("seq", "telecom","")
dataset = "telecom"

# ─── STEP 2: Description arrays ─────────────────────────────────────────────────
summary_description_array = [
    "Summarize the telecommunications question and its options concisely for analysis.",
    "Provide a brief recap of the telecom question and choices for researchers.",
    "You will take the role of a telecom-specialist summarizer. Summarize the question and answer options.",
    "Produce a short summary of the telecom question and all choices.",
    "Present the telecom question and its multiple-choice options in a concise summary."
]

diagnosis_description_array = [
    "Answer the telecom MCQ strictly 'option {i}' for this question.",
    "Provide the MCQ answer (1–4) for this telecom question.",
    "Output the telecom MCQ response as 'option {i}'.",
    "Select the correct option (1–4) for the telecommunications question.",
    "Choose the telecom MCQ answer and output 'option {i}'."
]

explanation_description_array = [
    "Explain in detail why the chosen telecom MCQ answer is correct.",
    "Provide a step-by-step rationale for why the selected answer is correct.",
    "As a telecom expert, justify why the chosen MCQ option is right.",
    "Offer a clear explanation for why the selected telecom answer is correct.",
    "Give a detailed rationale for why the chosen option is correct."
]

# ─── STEP 3: Deployment instructions per arm ────────────────────────────────────
documents = summary_description_array+ diagnosis_description_array+ explanation_description_array+ list(input_reports)

# ─── STEP 5: Deployment instructions per arm ───────────────────────────────────────────
deployments_summarizer = {
    "base"            : ("gpt-35-turbo", "You are to summarize a telecom question and its options."),
    "assistants"      : ("Assistant",     "You are to summarize a telecom question and its options."),
    "finetune_med"    : ("Med",           "You are to summarize a telecom question and its options."),
    "finetune_tele"   : ("Tele",          "You are to summarize a telecom question and its options."),
    "finetune_med_new": ("Med_New",       "You are to summarize a telecom question and its options."),
    "llama"           : ("llama",         "You are to summarize a telecom question and its options."),
}

deployments_diagnoser = {
    "base"            : ("gpt-35-turbo",
                         "You are to answer multiple choice questions related to telecommunications. Output strictly 'option {i}' where i∈{1,2,3,4}."),
    "finetune_med"    : ("Med",
                         "You are to answer multiple choice questions related to telecommunications. Output strictly 'option {i}'."),
    "finetune_tele"   : ("Tele",
                         "You are to answer multiple choice questions related to telecommunications. Output strictly 'option {i}'."),
    "finetune_med_new": ("Med_New",
                         "You are to answer multiple choice questions related to telecommunications. Output strictly 'option {i}'."),
    "llama"           : ("llama",
                         "You are to answer multiple choice questions related to telecommunications. Output strictly 'option {i}'.")
}

deployments_explainer = {
    "base"            : ("gpt-35-turbo", 
                         "You are to explain why the MCQ answer for this telecom question is correct. Provide a detailed rationale."),
    "finetune_med"    : ("Med", 
                         "You are to explain why the MCQ answer for this telecom question is correct. Provide a detailed rationale."),
    "finetune_tele"   : ("Tele", 
                         "You are to explain why the MCQ answer for this telecom question is correct. Provide a detailed rationale."),
    "finetune_med_new": ("Med_New", 
                         "You are to explain why the MCQ answer for this telecom question is correct. Provide a detailed rationale."),
    "llama"           : ("llama", 
                         "You are to explain why the MCQ answer for this telecom question is correct. Provide a detailed rationale.")
}

# ─── STEP 4: Cost-per-token dictionaries ────────────────────────────────────────
cost_per_token = {
    "base"            : 0.000004,
    "assistants"      : 0.000004,
    "finetune_med"    : 0.00001,
    "finetune_tele"   : 0.00001,
    "finetune_med_new": 0.00001,
    "llama"           : 0.00000071
}

input_cost_per_token = {
    "base"            : 0.0000001,
    "assistants"      : 0.0000001,
    "finetune_med"    : 0.00000025,
    "finetune_tele"   : 0.00000025,
    "finetune_med_new": 0.00000025,
    "llama"           : 0.00000071
}
# ─── STEP 5: Token-length predictor ─────────────────────────────────────────────

import tiktoken
openai_models = {"gpt-3.5-turbo", "gpt-4"}
encodings = { m: tiktoken.encoding_for_model(m) for m in openai_models }
from transformers import AutoTokenizer as HFTokenizer
try:
    llama_tok = HFTokenizer.from_pretrained("openlm-research/open_llama_13b")
except Exception:
    llama_tok = reg_tokenizer


arm_to_llm = {
        "base"            : "gpt-3.5-turbo",
        "assistants"      : "gpt-3.5-turbo",
        "finetune_med"    : "gpt-4",
        "finetune_tele"   : "gpt-4",
        "finetune_med_new": "gpt-4",
        "llama"           : "llama-13b"
    }

arm_encoders = {}
for mk, llm_name in arm_to_llm.items():
    if llm_name in encodings:
        arm_encoders[mk] = encodings[llm_name]
    else:
        arm_encoders[mk] = llama_tok

# ─── STEP 7: Args ─────────────────────────────────────────────────────────────
parser = argparse.ArgumentParser()
parser.add_argument('--size', default=150, type=int, help='number of rounds')
parser.add_argument('--nu', type=float, default=1, metavar='v', help='nu for control variance')
parser.add_argument('--lamdba', type=float, default=1, metavar='l', help='lambda for regularization')
parser.add_argument('--hidden', type=int, default=50, help='network hidden size')
parser.add_argument('--style', default='ts', metavar='ts|ucb', help='TS or UCB')
parser.add_argument('--number_tasks', default=3, type=int, help='number of subtasks')
parser.add_argument('--no_runs', default=3, type=int, help='how many independent runs')
parser.add_argument('--alpha', default=10, type=int, help='cost accuracy tradeoff weight')
args = parser.parse_args()
size, nu, lamdba, hidden, style, number_tasks, no_runs, alpha = (
    args.size, args.nu, args.lamdba, args.hidden, args.style,
    args.number_tasks, args.no_runs, args.alpha
)
num_rounds = size

# ─── STEP 8: Prepare models & containers ───────────────────────────────────────
models_summarizer = list(deployments_summarizer.keys())
models_diagnoser  = list(deployments_diagnoser.keys())
models_explainer  = list(deployments_explainer.keys())
all_regrets, all_rewards, all_costs = [], [], []
super_arms = [(s, d, e) for s in models_summarizer for d in models_diagnoser for e in models_explainer]
num_triplets = len(super_arms)
all_plays = np.zeros((args.no_runs,num_triplets))
all_avg_arrays = []


def save_checkpoint(run, t, all_regrets, all_rewards, all_costs, all_plays, all_avg_arrays, 
                   regrets, rewards, costs, plays_triplet, avg_array,
                   all_rewards_sum, all_rewards_diag, all_rewards_debate,
                   actual_total_cost, tot_reward, cum_reg):
    checkpoint = {
    'run': run,
    'round': t,
    'all_regrets': all_regrets,
    'all_rewards': all_rewards,
    'all_costs': all_costs,
    'all_plays': all_plays,
    'all_avg_arrays': all_avg_arrays,
    'current_regrets': regrets,
    'current_rewards': rewards,
    'current_costs': costs,
    'current_plays_triplet': plays_triplet,
    'current_avg_array': avg_array,
    'all_rewards_sum': all_rewards_sum,
    'all_rewards_diag': all_rewards_diag,
    'all_rewards_debate': all_rewards_debate,
    'actual_total_cost': actual_total_cost,
    'tot_reward': tot_reward,
    'cum_reg': cum_reg
}
    pickle.dump(checkpoint, open("tele_results/checkpoint_finetune_tele_3subtasks.pkl", "wb"))


import os
checkpoint_path = "tele_results/checkpoint_finetune_tele_3subtasks.pkl"
start_run = 0
start_round = 0
if os.path.exists(checkpoint_path):
    # Load and resume
    checkpoint = pickle.load(open(checkpoint_path, "rb"))
    # Extract all state variables from checkpoint




# ─── RUN SIMULATIONS ───────────────────────────────────────────────────────────
for run in range(args.no_runs):
    print(f"=== Run {run+1}/{args.no_runs} ===")
    actual_total_cost = 0


    plays_triplet = np.zeros(num_triplets, dtype=int)
    #super_arms = [(s, d, e) for s in models_summarizer for d in models_diagnoser for e in models_explainer]
    #num_triplets = len(super_arms)
    # plays_s = np.zeros(len(deployments_summarizer), int)
    #plays_d = np.zeros(len(deployments_diagnoser), int)
    #plays_e = np.zeros(len(deployments_explainer), int)
    regrets, rewards, costs = [], [], []
    tot_reward = 0; cum_reg = 0
    avg_array = {"gpt-35-turbo":0,"Med":0,"Tele":0,"Med_New":0,"llama":0}
    i = 0
    documents = (
        summary_description_array
        + diagnosis_description_array
        + explanation_description_array
        + list(input_reports)
    )

    all_rewards_sum = []
    all_rewards_debate = []
    all_rewards_diag = []

    arm_to_llm = {
        "base"            : "gpt-3.5-turbo",
        "assistants"      : "gpt-3.5-turbo",
        "finetune_med"    : "gpt-4",
        "finetune_tele"   : "gpt-4",
        "finetune_med_new": "gpt-4",
        "llama"           : "llama-13b"
    }
    
    for t in range(args.size):
        question = input_reports[t]
#        toks_q = reg_tokenizer(question, truncation=True, padding="max_length", max_length=256, return_tensors="pt").to(device)
        TELE_KEY = "finetune_tele"
        s_arm = np.random.choice(models_summarizer)   # random summarizer each round
        d_arm = TELE_KEY                              # always Tele for diagnosis
        e_arm = TELE_KEY
        best_idx = super_arms.index((s_arm, d_arm, e_arm))
        plays_triplet[best_idx] += 1
        print(f"[Round {t+1}] Selected triplet -> Summarizer: {s_arm} | Diagnoser: {d_arm} | Explainer: {e_arm}")
        question = input_reports[t]
        enc_s = arm_encoders[s_arm]
        summary = get_summary(question, s_arm, "tele")
        summary_clean = summary.replace("\n","")
        enc_s = arm_encoders[s_arm]
        
        # Calculate actual input length for summarizer
        in_len_sum_actual = len(enc_s.encode(question)) if hasattr(enc_s,"encode") else len(enc_s(question, truncation=True)["input_ids"])
        
        # Calculate actual output length for summarizer (the generated summary)
        summary_text = summary if isinstance(summary, str) else str(summary)
        out_len_sum_actual = len(enc_s.encode(summary_text)) if hasattr(enc_s,"encode") else len(enc_s(summary_text, truncation=True)["input_ids"])
        
        
        
        prompt_d = summary_clean + " Please give the correct option in the format: option [correct option number]."
        reg1, reward1, out_len_diag_actual, avg_array, _, _ = opt_eval(
            deployments_diagnoser, prompt_d, "diagnosis",
            d_arm, avg_array, t, [], [], labels, dataset
        )
        enc_d = arm_encoders[d_arm]
        in_len_diag_actual = len(enc_d.encode(prompt_d)) if hasattr(enc_d,"encode") else len(enc_d(prompt_d, truncation=True)["input_ids"])
        answer_text = f"option {reward1}" if isinstance(reward1,(int,str)) else "option 1"
        prompt_e = question + " Answer chosen: " + str(answer_text)
        reg2, reward2, out_len_exp_actual, avg_array, _, _ = opt_eval(
            deployments_explainer, prompt_e, "explanation",
            e_arm, avg_array, t, [], [], explanations, dataset
        )
        enc_e = arm_encoders[e_arm]
        in_len_exp_actual = len(enc_e.encode(prompt_e)) if hasattr(enc_e,"encode") else len(enc_e(prompt_e, truncation=True)["input_ids"])

        actual_total_cost += (
            input_cost_per_token[s_arm]*in_len_sum_actual + cost_per_token[s_arm]*out_len_sum_actual +
            input_cost_per_token[d_arm]*in_len_diag_actual + cost_per_token[d_arm]*out_len_diag_actual +
            input_cost_per_token[e_arm]*in_len_exp_actual + cost_per_token[e_arm]*out_len_exp_actual
        )
        sum_cost_actual = (
            input_cost_per_token[s_arm]*in_len_sum_actual
          + cost_per_token[s_arm]*out_len_sum_actual
        )
        
        # Update metrics
        cum_reg += (reg1 + reg2)
        tot_reward += (reward1 + reward2)
        regrets.append(cum_reg)
        rewards.append(tot_reward)
        #print(f"Reward: {tot_reward} | Regret: {cum_reg} | Actual total cost: {actual_total_cost}")
        costs.append(actual_total_cost)
        if (t + 1) % 10 == 0:
            print(f"Round {t+1}: total-reward={tot_reward:.3f}, reward-expl={reward2:.3f}, total-cost={actual_total_cost:.3f}")
            save_checkpoint(run, t, all_regrets, all_rewards, all_costs, all_plays, all_avg_arrays,
             regrets, rewards, costs, plays_triplet, avg_array,
             all_rewards_sum, all_rewards_diag, all_rewards_debate,
             actual_total_cost, tot_reward, cum_reg)
    
    all_regrets.append(regrets)
    all_rewards.append(rewards)
    all_costs.append(costs)
    all_plays[run,:] = plays_triplet
    all_avg_arrays.append(avg_array.copy())

import pandas as pd
avg_df      = pd.DataFrame(all_avg_arrays)
avg_mean = avg_df.mean(axis=0).to_dict()
avg_std  = avg_df.std(axis=0).to_dict()
plays_mean = all_plays.mean(axis=0)
# ─── STEP 9: Save metrics ───────────────────────────────────────────────────────
pickle.dump(np.mean(all_regrets,axis=0),open("regrets_mean_tele_finetune.pkl","wb"))
pickle.dump(np.std(all_regrets,axis=0), open("regrets_std_tele_finetune.pkl","wb"))
pickle.dump(np.mean(all_rewards,axis=0),open("rewards_mean_tele_finetune.pkl","wb"))
pickle.dump(np.std(all_rewards,axis=0), open("rewards_std_tele_finetune.pkl","wb"))
pickle.dump(np.mean(all_costs,axis=0),   open("costs_mean_tele_finetune.pkl","wb"))
pickle.dump(np.std(all_costs,axis=0),    open("costs_std_tele_finetune.pkl","wb"))
pickle.dump(plays_mean,open("plays_tele_finetune.pkl","wb"))
print(f"Final mean regret: {np.mean(all_regrets,axis=0)[-1]}")
print(f"Final mean reward: {np.mean(all_rewards,axis=0)[-1]}")
print(f"Final mean cost: {np.mean(all_costs,axis=0)[-1]}")
#print(f"Final mean summarizer cost: {costs_summarizer_mean[-1]}")
print(f"Final mean plays: {plays_mean}")
if os.path.exists(checkpoint_path):
    os.remove(checkpoint_path)
end_time = time.time()
print(f"Code runetime: {end_time-start_time} seconds")
print("All runs complete. Summary pickles written.")
