import numpy as np
import pickle
import argparse
from prompt_maker import input_maker
import random

from utils.helper import opt_eval, get_summary
from src.regrets.final_rand_tele import final_eval_telecom
# Load the telecom data once at import time. The returned sequences are indexed
# by round number in the run loop (input_reports[t]); labels and explanations
# are passed to opt_eval for the diagnosis and explanation subtasks.
# NOTE(review): the meaning of the "seq" and "" arguments is defined by
# prompt_maker.input_maker and is not visible here — confirm there.
input_reports, labels, explanations = input_maker("seq", "telecom","")
# Dataset tag forwarded to every opt_eval call.
dataset = "telecom"

# Arm name -> (deployment name, system prompt), one table per subtask.
# Arms of a subtask share a single prompt (except the base diagnoser), so each
# table is generated from an ordered (arm, deployment) listing. Insertion order
# matters: arm indices are later derived from the key order of these dicts.

_SUMMARIZER_PROMPT = "You are to summarize a telecom question and its options."

deployments_summarizer = {
    arm: (deployment, _SUMMARIZER_PROMPT)
    for arm, deployment in (
        ("base", "gpt-35-turbo"),
        ("assistants", "Assistant"),
        ("finetune_med", "Med"),
        ("finetune_tele", "Tele"),
        ("finetune_med_new", "Med_New"),
        ("llama", "llama"),
    )
}

_DIAGNOSER_PROMPT = "You are to answer multiple choice questions related to telecommunications. Output strictly 'option {i}'."

deployments_diagnoser = {
    # Only the base arm is told the valid option range explicitly.
    "base": (
        "gpt-35-turbo",
        "You are to answer multiple choice questions related to telecommunications. Output strictly 'option {i}' where i ∈ {1,2,3,4}.",
    ),
    **{
        arm: (deployment, _DIAGNOSER_PROMPT)
        for arm, deployment in (
            ("finetune_med", "Med"),
            ("finetune_tele", "Tele"),
            ("finetune_med_new", "Med_New"),
            ("llama", "llama"),
        )
    },
}

_EXPLAINER_PROMPT = "You are to explain why the MCQ answer for this telecom question is correct. Provide a detailed rationale."

deployments_explainer = {
    arm: (deployment, _EXPLAINER_PROMPT)
    for arm, deployment in (
        ("base", "gpt-35-turbo"),
        ("finetune_med", "Med"),
        ("finetune_tele", "Tele"),
        ("finetune_med_new", "Med_New"),
        ("llama", "llama"),
    )
}


def save_checkpoint(run, t, all_regrets, all_rewards, all_costs, all_plays, all_avg_acc,
                   regrets, rewards, costs_list, plays_no, avg_acc,
                   counts_diag, sums_diag, counts_expl, sums_expl,
                   all_rewards_sum, all_rewards_diag, all_rewards_expl,
                   total_cost, total_reward, cum_regret):
    """Pickle a snapshot of the experiment state to the checkpoint file.

    Every argument is stored unchanged in a single dict. ``run`` and ``t`` are
    saved under ``'run'`` and ``'round'``; ``regrets``, ``rewards``,
    ``costs_list``, ``plays_no`` and ``avg_acc`` are saved under
    ``'current_regrets'``, ``'current_rewards'``, ``'current_costs'``,
    ``'current_plays'`` and ``'current_avg_acc'``; all remaining arguments are
    saved under their own parameter name.

    The snapshot is written to ``tele_results/checkpoint_rand_3subtasks.pkl``.
    The directory is created if missing, and the data is first written to a
    temporary sibling file and then moved into place, so an interruption
    mid-write cannot leave a truncated checkpoint behind.

    Returns:
        None.

    Raises:
        OSError: if the directory or file cannot be created or written.
        pickle.PicklingError: if some part of the state cannot be pickled.
    """
    # Local import keeps the function self-contained regardless of where the
    # module-level ``import os`` sits relative to this definition.
    import os

    checkpoint = {
        'run': run,
        'round': t,
        'all_regrets': all_regrets,
        'all_rewards': all_rewards,
        'all_costs': all_costs,
        'all_plays': all_plays,
        'all_avg_acc': all_avg_acc,
        'current_regrets': regrets,
        'current_rewards': rewards,
        'current_costs': costs_list,
        'current_plays': plays_no,
        'current_avg_acc': avg_acc,
        'counts_diag': counts_diag,
        'sums_diag': sums_diag,
        'counts_expl': counts_expl,
        'sums_expl': sums_expl,
        'all_rewards_sum': all_rewards_sum,
        'all_rewards_diag': all_rewards_diag,
        'all_rewards_expl': all_rewards_expl,
        'total_cost': total_cost,
        'total_reward': total_reward,
        'cum_regret': cum_regret
    }

    target = "tele_results/checkpoint_rand_3subtasks.pkl"
    os.makedirs(os.path.dirname(target), exist_ok=True)

    # Write-then-rename: the previous checkpoint stays intact until the new
    # one has been written completely and its file handle closed.
    staging = target + ".tmp"
    with open(staging, "wb") as fh:
        pickle.dump(checkpoint, fh)
    os.replace(staging, target)


import os
# Location of the periodic checkpoint written by save_checkpoint.
checkpoint_path = "tele_results/checkpoint_rand_3subtasks.pkl"
start_run = 0
start_round = 0
if os.path.exists(checkpoint_path):
    # A checkpoint from an earlier, interrupted invocation exists: load it.
    # NOTE(review): resuming is not implemented. The loaded state is not
    # applied anywhere, and start_run/start_round stay 0 because the run loop
    # always iterates from the beginning.
    # The file is produced by this script itself; do not point
    # checkpoint_path at untrusted data, since unpickling can execute code.
    try:
        with open(checkpoint_path, "rb") as checkpoint_file:
            checkpoint = pickle.load(checkpoint_file)
    except (pickle.UnpicklingError, EOFError) as exc:
        # A truncated or corrupt checkpoint must not block a fresh start,
        # since nothing below depends on its contents.
        print(f"Warning: ignoring unreadable checkpoint {checkpoint_path}: {exc}")


# Ordered arm names per subtask (key order of the deployment tables). The run
# loop draws a random position in these lists and uses it to index the play
# counters.
models_summarizer = [*deployments_summarizer]
models_diagnoser = [*deployments_diagnoser]
models_explainer = [*deployments_explainer]

# Per-token cost applied to the *output* length of each subtask call, by arm.
cost_per_token = dict(
    base=4e-6,
    assistants=4e-6,
    finetune_med=1e-5,
    finetune_tele=1e-5,
    finetune_med_new=1e-5,
    llama=7.1e-7,
)

# Per-token cost applied to the *input* (prompt) length of each call, by arm.
input_cost_per_token = dict(
    base=1e-7,
    assistants=1e-7,
    finetune_med=2.5e-7,
    finetune_tele=2.5e-7,
    finetune_med_new=2.5e-7,
    llama=7.1e-7,
)

# Arm -> model name used to choose the tokenizer when counting summarizer
# tokens for the cost estimate.
arm_to_llm = dict(
    base="gpt-3.5-turbo",
    assistants="gpt-3.5-turbo",
    finetune_med="gpt-4",
    finetune_tele="gpt-4",
    finetune_med_new="gpt-4",
    llama="llama-13b",
)
import tiktoken
# Command-line options controlling the size of the experiment.
parser = argparse.ArgumentParser(
    description="Random baseline (telecom, 3 subtasks)",
)
parser.add_argument("--size", type=int, default=150, help="number of rounds")
parser.add_argument("--no_runs", type=int, default=3, help="how many independent runs")
# parse_known_args: flags this script does not define are ignored, not fatal.
args, _ = parser.parse_known_args()

num_rounds, no_runs = args.size, args.no_runs

# Cross-run accumulators; the run loop appends one entry per completed run.
# regrets: cumulative regret per round; rewards: cumulative reward per round;
# costs: cumulative cost per round.
all_regrets, all_rewards, all_costs = [], [], []
# plays: play counters over the concatenated arms of the three subtasks;
# avg_acc: the accuracy dict handed back by opt_eval.
all_plays, all_avg_acc = [], []
# Mean reward per arm for the summary / diagnosis / explanation subtasks.
all_summary_avg, all_diag_avg, all_expl_avg = [], [], []
from transformers import AutoTokenizer

# ─── TOKENIZERS (built once; used only to count summarizer tokens) ───────────
# These objects do not depend on the run or the round, so they are loaded a
# single time here instead of being rebuilt on every round.

# Pre-instantiate a fallback tokenizer (BERT) to use if LLaMA loading fails:
_fallback_tok = AutoTokenizer.from_pretrained("bert-base-uncased")

openai_models = {"gpt-3.5-turbo", "gpt-4"}
encodings = {m: tiktoken.encoding_for_model(m) for m in openai_models}

try:
    llm_llama_tok = AutoTokenizer.from_pretrained("openlm-research/open_llama_13b")
except Exception as exc:
    # Best effort: token counts (and therefore costs) for the llama arm will
    # come from a different vocabulary, so make the substitution visible.
    print(f"Warning: could not load the LLaMA tokenizer ({exc}); using bert-base-uncased for token counts.")
    llm_llama_tok = _fallback_tok

# ─── RANDOM BASELINE: no_runs independent runs of num_rounds rounds ──────────
# Each round picks a uniformly random arm for each of the three subtasks
# (summary -> diagnosis -> explanation) and accumulates reward, regret and cost.
for run in range(no_runs):
    print(f"\n===== Starting random-run {run+1}/{no_runs} =====")
    regrets = []      # cumulative regret, one entry per round
    rewards = []      # cumulative (diagnosis + explanation) reward, one entry per round
    costs_list = []   # cumulative cost, one entry per round
    # One counter per arm; layout is [summarizer arms | diagnoser arms | explainer arms].
    # NOTE(review): counters start at 1, not 0, so stored play counts are one
    # higher than the number of selections — confirm this is intended.
    plays_no = np.ones(len(models_summarizer) + len(models_diagnoser) + len(models_explainer))
    # Passed to opt_eval and replaced by the dict it returns.
    avg_acc = {
        "gpt-35-turbo": 0.0,
        "Med"         : 0.0,
        "Tele"        : 0.0,
        "Med_New"     : 0.0,
        "llama"       : 0.0
    }

    # Track per-arm counts and rewards (for accuracy averaging)
    counts_diag = {arm: 0 for arm in models_diagnoser}
    sums_diag = {arm: 0.0 for arm in models_diagnoser}
    counts_expl = {arm: 0 for arm in models_explainer}
    sums_expl = {arm: 0.0 for arm in models_explainer}
    # Track per-model summary rewards
    counts_sum = {arm: 0 for arm in models_summarizer}
    sums_sum = {arm: 0.0 for arm in models_summarizer}

    # Per-model mean reward (filled in at the end of the run)
    summary_avg_array = {arm: 0.0 for arm in models_summarizer}
    diag_avg_array = {arm: 0.0 for arm in models_diagnoser}
    expl_avg_array = {arm: 0.0 for arm in models_explainer}
    # Reward histories handed to the evaluation helpers.
    all_rewards_sum = []
    all_rewards_diag = []
    all_rewards_expl = []

    total_cost = 0.0
    total_reward = 0.0
    cum_regret = 0.0

    for t in range(num_rounds):
        print(f"Round {t+1}")
        # ────── Subtask 0: Summarization (random) ──────────────────────────────
        prompt = input_reports[t].replace("\n", " ")
        # Randomly select summarizer arm
        arm_idx_sum = random.randrange(len(models_summarizer))
        arm_sum = models_summarizer[arm_idx_sum]
        print(f"Selected summarizer {arm_sum}")
        plays_no[arm_idx_sum] += 1

        # Generate summary via get_summary
        summary = get_summary(input_reports[t], arm_sum, "tele")
        print(f"Summary is {summary}")

        # Token counts for the summarizer call, using the tokenizer that
        # matches the selected arm's model.
        llm_name = arm_to_llm[arm_sum]
        summary_text = summary if isinstance(summary, str) else str(summary)
        prompt_text = input_reports[t] if isinstance(input_reports[t], str) else str(input_reports[t])

        if llm_name in encodings:
            out_len_sum = len(encodings[llm_name].encode(summary_text))
            in_len_sum = len(encodings[llm_name].encode(prompt_text))
        else:
            out_len_sum = len(llm_llama_tok(summary_text, truncation=True, padding=False)["input_ids"])
            in_len_sum = len(llm_llama_tok(prompt_text, truncation=True, padding=False)["input_ids"])

        deployment_sum = deployments_summarizer[arm_sum]
        reward_sum, _, all_rewards_sum, _ = final_eval_telecom(
            deployment_sum,
            arm_sum,
            summary,
            'summary',
            all_rewards_sum,
            all_rewards_diag,
            summary,
            input_reports[t]
        )

        # Track summary rewards
        # NOTE(review): all_rewards_sum is also returned by final_eval_telecom;
        # verify the helper does not already append this reward.
        all_rewards_sum.append(reward_sum)
        counts_sum[arm_sum] += 1
        sums_sum[arm_sum] += reward_sum

        # Cost for summarizer using proper token counts
        total_cost += input_cost_per_token[arm_sum] * in_len_sum + cost_per_token[arm_sum] * out_len_sum
        costs_list.append(total_cost)

        # Prepare prompt for diagnosis including summary. summary_text is used
        # so a non-string summary cannot break the concatenation.
        diag_prompt = prompt + "\n\nSummary:\n" + summary_text + "\nPlease give the correct option in the format: option [number]."

        # ────── Subtask 1: Diagnosis (random) ──────────────────────────────────
        arm_idx_diag = random.randrange(len(models_diagnoser))
        arm_diag = models_diagnoser[arm_idx_diag]
        print(f"Selected diagnoser {arm_diag}")
        # Play index offset for diagnosis arms
        plays_no[len(models_summarizer) + arm_idx_diag] += 1

        # Get regret and reward for diagnosis
        reg1, reward1, out_len1, avg_acc, _, _ = opt_eval(
            deployments_diagnoser,
            diag_prompt,
            "diagnosis",
            arm_diag,
            avg_acc,
            t,
            all_rewards_sum,
            all_rewards_diag,
            labels,
            dataset
        )
        # Update cumulative regret
        cum_regret += reg1
        regrets.append(cum_regret)
        total_reward += reward1

        # Update diagnosis rewards and counts
        all_rewards_diag.append(reward1)
        counts_diag[arm_diag] += 1
        sums_diag[arm_diag] += reward1
        print(f"Diagnosis reward: {reward1}")

        # Cost for diagnosis
        in_len_diag = len(diag_prompt.split())  # approximation: word count as token proxy
        total_cost += input_cost_per_token[arm_diag] * in_len_diag + cost_per_token[arm_diag] * out_len1
        costs_list[-1] = total_cost  # update cost at this round

        # ────── Subtask 2: Explanation (random) ────────────────────────────────
        # Build explanation prompt
        # NOTE(review): "Answer chosen" is filled with the diagnosis *reward*,
        # not the option the diagnoser picked — confirm whether opt_eval
        # exposes the chosen answer and use that instead.
        explanation_prompt = (
            input_reports[t] + "\nAnswer chosen: " + str(reward1) + " Please give an explanation for why this answer choice is correct."
        )
        arm_idx_expl = random.randrange(len(models_explainer))
        arm_expl = models_explainer[arm_idx_expl]
        print(f"Selected explainer {arm_expl}")
        # Play index offset for explainer arms
        plays_no[len(models_summarizer) + len(models_diagnoser) + arm_idx_expl] += 1

        # Generate explanation text; it is used only for the output-length
        # estimate in the cost below, not for the reward.
        explanation_text = get_summary(explanation_prompt, arm_expl, "tele")
        explanation_str = explanation_text if isinstance(explanation_text, str) else str(explanation_text)
        out_len2 = len(explanation_str.split())  # word count as token proxy

        # Get regret and reward for explanation
        reg2, reward2, _, avg_acc, _, _ = opt_eval(
            deployments_explainer,
            explanation_prompt,
            "explanation",
            arm_expl,
            avg_acc,
            t,
            all_rewards_sum,
            all_rewards_expl,
            explanations,
            dataset
        )
        # Update cumulative regret
        cum_regret += reg2
        regrets[-1] = cum_regret  # reflect both subtasks
        total_reward += reward2

        # Update explanation rewards and counts
        rewards.append(total_reward)
        all_rewards_expl.append(reward2)
        counts_expl[arm_expl] += 1
        sums_expl[arm_expl] += reward2
        print(f"Explanation reward: {reward2:.3f}")

        # Cost for explanation
        in_len_expl = len(explanation_prompt.split())
        total_cost += input_cost_per_token[arm_expl] * in_len_expl + cost_per_token[arm_expl] * out_len2
        costs_list[-1] = total_cost

        # End of round: report and checkpoint every 10 rounds
        if (t + 1) % 10 == 0:
            print(f"Round {t+1}: total-reward={total_reward:.3f}, reward-expl={reward2:.3f}, total-cost={total_cost:.3f}")
            save_checkpoint(run, t, all_regrets, all_rewards, all_costs, all_plays, all_avg_acc,
                   regrets, rewards, costs_list, plays_no, avg_acc,
                   counts_diag, sums_diag, counts_expl, sums_expl,
                   all_rewards_sum, all_rewards_diag, all_rewards_expl,
                   total_cost, total_reward, cum_regret)

    # Mean reward per arm; arms that were never selected keep their 0.0 default.
    for arm in models_summarizer:
        if counts_sum[arm] > 0:
            summary_avg_array[arm] = sums_sum[arm] / counts_sum[arm]

    for arm in models_diagnoser:
        if counts_diag[arm] > 0:
            diag_avg_array[arm] = sums_diag[arm] / counts_diag[arm]

    for arm in models_explainer:
        if counts_expl[arm] > 0:
            expl_avg_array[arm] = sums_expl[arm] / counts_expl[arm]

    # Store in lists for aggregation
    all_summary_avg.append(summary_avg_array.copy())
    all_diag_avg.append(diag_avg_array.copy())
    all_expl_avg.append(expl_avg_array.copy())

    all_regrets.append(np.array(regrets))
    all_rewards.append(np.array(rewards))
    all_costs.append(np.array(costs_list))
    all_plays.append(plays_no.copy())
    all_avg_acc.append(avg_acc.copy())

# ─── FINAL AGGREGATION AND PICKLE SAVING ─────────────────────────────────────────
import pandas as pd

# Stack the per-run series so statistics can be taken across runs (axis 0).
regrets_arr = np.stack(all_regrets, axis=0)  # shape (no_runs, num_rounds)
rewards_arr = np.stack(all_rewards, axis=0)
costs_arr   = np.stack(all_costs, axis=0)
plays_arr   = np.stack(all_plays, axis=0)      # shape (no_runs, total_arms)
avg_df      = pd.DataFrame(all_avg_acc)        # one row per run

# Compute means and stds
# NOTE(review): the numpy .std() calls use the population formula (ddof=0)
# while the pandas .std() calls below default to the sample formula (ddof=1),
# which also yields NaN when there is a single run. Left unchanged so results
# stay comparable with earlier outputs — decide which convention is intended.
regrets_mean = regrets_arr.mean(axis=0)
regrets_std  = regrets_arr.std(axis=0)
rewards_mean = rewards_arr.mean(axis=0)
rewards_std  = rewards_arr.std(axis=0)
costs_mean    = costs_arr.mean(axis=0)
costs_std     = costs_arr.std(axis=0)
plays_mean    = plays_arr.mean(axis=0)
plays_std     = plays_arr.std(axis=0)
avg_mean      = avg_df.mean(axis=0).to_dict()
avg_std       = avg_df.std(axis=0).to_dict()

# Per-arm mean reward for each subtask, one row per run.
summary_avg_df = pd.DataFrame(all_summary_avg)
diag_avg_df = pd.DataFrame(all_diag_avg)
expl_avg_df = pd.DataFrame(all_expl_avg)

summary_mean = summary_avg_df.mean(axis=0).to_dict()
summary_std = summary_avg_df.std(axis=0).to_dict()
diag_mean = diag_avg_df.mean(axis=0).to_dict()
diag_std = diag_avg_df.std(axis=0).to_dict()
expl_mean = expl_avg_df.mean(axis=0).to_dict()
expl_std = expl_avg_df.std(axis=0).to_dict()

# Save pickles for comparison. Each file is opened in a `with` block so its
# handle is flushed and closed deterministically, and the output directory is
# created up front so a finished experiment cannot fail at the very last step.
os.makedirs("tele_results", exist_ok=True)
_result_pickles = {
    "tele_results/regrets_mean_tele_rand_3subtasks_2.pkl": regrets_mean,
    "tele_results/regrets_std_tele_rand_3subtasks_2.pkl": regrets_std,
    "tele_results/rewards_mean_tele_randgpt_3subtasks_2.pkl": rewards_mean,
    "tele_results/rewards_std_tele_randgpt_3subtasks_2.pkl": rewards_std,
    "tele_results/costs_mean_tele_rand_3subtasks_2.pkl": costs_mean,
    "tele_results/costs_std_tele_rand_3subtasks_2.pkl": costs_std,
    "tele_results/plays_mean_tele_rand_3subtasks_2.pkl": plays_mean,
    "tele_results/plays_std_tele_rand_3subtasks_2.pkl": plays_std,
    "tele_results/avg_accuracy_mean_tele_rand_3subtasks_2.pkl": avg_mean,
    "tele_results/avg_accuracy_std_tele_rand_3subtasks_2.pkl": avg_std,
    # Per-model accuracies for each subtask
    "tele_results/sum_avg_accuracy_mean_tele_rand_3subtasks_2.pkl": summary_mean,
    "tele_results/sum_avg_accuracy_std_tele_rand_3subtasks_2.pkl": summary_std,
    "tele_results/diag_accuracy_mean_tele_rand_3subtasks_2.pkl": diag_mean,
    "tele_results/diag_accuracy_std_tele_rand_3subtasks_2.pkl": diag_std,
    "tele_results/expl_accuracy_mean_tele_rand_3subtasks_2.pkl": expl_mean,
    "tele_results/expl_accuracy_std_tele_rand_3subtasks_2.pkl": expl_std,
}
for _pickle_path, _payload in _result_pickles.items():
    with open(_pickle_path, "wb") as _fh:
        pickle.dump(_payload, _fh)

# All runs finished and results are on disk: the checkpoint is no longer needed.
if os.path.exists(checkpoint_path):
    os.remove(checkpoint_path)

print(f"Final cum-regret (mean): {regrets_mean[-1]:.4f}")
print("Random baseline runs complete. Pickles written for 3-subtask telecom.")
