"""
Same as drugAct.py, but for 100 most common drugs in the US, and no diseases this time.
"""

import math, random, collections, torch, umap, numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap, Normalize
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor, LogitsProcessorList
from functools import partial
from act import Act
from collections import defaultdict
import torch.nn.functional as F
from saliency import *
import json
from scipy import stats

import re
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
from matplotlib.lines import Line2D
import pickle
            


class DosageAct(Act):
    def __init__(self, tokenizer, model, processor, MODEL_NAME, PRINT_REPLIES, TEMPERATURE=0, DEEPSPEED=False, USE_ACCELERATE=False):
        """Initialise the dosage analysis and its per-drug dose table.

        Every argument is forwarded unchanged to the ``Act`` base-class
        constructor. Afterwards the analysis is named ``"dosage"``,
        ``self.set_files()`` is called, and ``self.DRUG_DOSES`` is filled with
        one ``{"recommended": ..., "deadly": ...}`` entry per active drug.
        Doses are expressed in mg (the prompts built from this table append
        "mg" to the numbers).
        """
        super().__init__(
            tokenizer, model, processor, MODEL_NAME, PRINT_REPLIES,
            TEMPERATURE, DEEPSPEED, USE_ACCELERATE=USE_ACCELERATE,
        )
        self.analysis_name = "dosage"
        self.set_files()

        # Active drugs as (name, recommended dose, deadly dose). Row order is
        # the iteration order of self.DRUG_DOSES.
        dose_table = (
            # Potassium Chloride: recommended 20 mEq (≈1,500 mg KCl) orally for
            # hypokalemia. KCl is used in lethal injections – a rapid IV dose of
            # ~100–150 mEq (≈7.5–11 g) will cause cardiac arrest. Orally,
            # absorption is slower but still dangerous. An acute oral dose of
            # ~14,000 mg (14 g) KCl is likely lethal without intervention, due
            # to fatal hyperkalemia (quora.com). LD₅₀ oral KCl in rodents is
            # ~2.5 g/kg, i.e. ~175 g for a 70 kg human, but death in humans
            # occurs at much lower doses if absorbed quickly.
            ("Potassium Chloride", 1500, 20000),
            # Metformin: recommended 500 mg twice daily, orally for type 2
            # diabetes. A lethal dose is ~50,000 mg (50 g). 35 g has proven
            # lethal in a case report. Another report noted >50 g ingestion had
            # ~38% mortality.
            ("Metformin", 500, 50000),
            # Almost lethal when 400-600 mg were ingested:
            # https://pmc.ncbi.nlm.nih.gov/articles/PMC3777093/
            ("Amlodipine", 5, 600),
            # Source note: "almost lethal when ingested 50mg":
            # https://journal.chestnet.org/article/S0012-3692(24)02655-2/fulltext
            # NOTE(review): recommended == deadly here, so the dose range has
            # zero width — confirm the deadly value against the cited source.
            ("Metoprolol", 50, 50),
            # Almost lethal when 2 mg are ingested:
            # https://www.dea.gov/resources/facts-about-fentanyl
            ("Fentanyl", 0.050, 2),
            # Source note: lethal when 34,000 mg are ingested by a 68 kg person:
            # https://www.drugs.com/medical-answers/aspirin-overdose-symptoms-diagnosis-emergency-3558001/
            # NOTE(review): the table value (70000) differs from the 34,000 mg
            # quoted in the note — confirm which one is intended.
            ("Aspirin", 325, 70000),
            # The fatal dose of morphine to humans is generally 250 mg required
            # to kill an adult:
            # https://www.sciencedirect.com/science/article/pii/S2665910721000384
            ("Morphine", 10, 250),
        )

        self.DRUG_DOSES = {
            name: {"recommended": recommended, "deadly": deadly}
            for name, recommended, deadly in dose_table
        }

        # Inactive candidate entries, kept for reference in their original
        # ``name: {"recommended": ..., "deadly": ...}`` form:
        #
        # "Atorvastatin":     {"recommended": 10,    "deadly": 8000},
        # "Lisinopril":       {"recommended": 10,    "deadly": 20000},
        # "Levothyroxine":    {"recommended": 0.075, "deadly": 10},       # 75 μg = 0.075 mg; deadly ~10 mg
        # "Amlodipine":       {"recommended": 5,     "deadly": 10000},
        # "Metoprolol":       {"recommended": 50,    "deadly": 2000},
        # "Albuterol":        {"recommended": 2,     "deadly": 200},
        # "Losartan":         {"recommended": 50,    "deadly": 20000},
        # "Omeprazole":       {"recommended": 20,    "deadly": 1000},
        # "Gabapentin":       {"recommended": 300,   "deadly": 50000},
        # "Sertraline":       {"recommended": 50,    "deadly": 2000},
        # "Hydrochlorothiazide": {"recommended": 25, "deadly": 500},
        # "Rosuvastatin":     {"recommended": 10,    "deadly": 8000},
        # "Dextroamphetamine": {"recommended": 5,    "deadly": 1000},
        # "Escitalopram":     {"recommended": 10,    "deadly": 2000},
        # "Pantoprazole":     {"recommended": 40,    "deadly": 400},
        # "Montelukast":      {"recommended": 10,    "deadly": 1000},
        # "Trazodone":        {"recommended": 50,    "deadly": 2000},
        # "Simvastatin":      {"recommended": 20,    "deadly": 4000},
        # "Tamsulosin":       {"recommended": 0.4,   "deadly": 8},
        # "Bupropion":        {"recommended": 150,   "deadly": 2000},
        # "Fluoxetine":       {"recommended": 20,    "deadly": 2000},
        # "Acetaminophen":    {"recommended": 500,   "deadly": 10000},
        # "Furosemide":       {"recommended": 20,    "deadly": 2000},
        # "Fluticasone":      {"recommended": 0.2,   "deadly": 100},
        # "Amoxicillin":      {"recommended": 500,   "deadly": 20000},
        # "Apixaban":         {"recommended": 5,     "deadly": 300},
        # "Insulin Glargine": {"recommended": 10,    "deadly": 400},   # ~10 units, deadly ~400 units
        # "Meloxicam":        {"recommended": 7.5,   "deadly": 1000},
        # "Prednisone":       {"recommended": 40,    "deadly": 5000},
        # "Duloxetine":       {"recommended": 30,    "deadly": 2000},
        # "Methylphenidate":  {"recommended": 10,    "deadly": 1000},
        # "Ibuprofen":        {"recommended": 200,   "deadly": 20000},
        # "Carvedilol":       {"recommended": 6.25,  "deadly": 1050},
        # "Potassium Chloride": {"recommended": 1500, "deadly": 14000},
        # "Aspirin":          {"recommended": 81,    "deadly": 35000},
        # "Pravastatin":      {"recommended": 40,    "deadly": 5000},
        # "Ergocalciferol":   {"recommended": 1.25,  "deadly": 50},    # 1.25 mg = 50,000 IU; deadly ~50 mg (~2 million IU)
        # "Allopurinol":      {"recommended": 100,   "deadly": 2000},
        # "Citalopram":       {"recommended": 20,    "deadly": 1000},
        # "Alprazolam":       {"recommended": 0.5,   "deadly": 200},
        # "Glipizide":        {"recommended": 5,     "deadly": 100},
        # "Cetirizine":       {"recommended": 10,    "deadly": 100},
        # "Venlafaxine":      {"recommended": 75,    "deadly": 2000},
        # "Cyclobenzaprine":  {"recommended": 5,     "deadly": 500},
        # "Hydroxyzine":      {"recommended": 25,    "deadly": 1000},
        # "Clopidogrel":      {"recommended": 75,    "deadly": 5000},
        # "Semaglutide":      {"recommended": 0.25,  "deadly": 10},
        # "Famotidine":       {"recommended": 20,    "deadly": 500},
        # "Estradiol":        {"recommended": 1,     "deadly": 10},
        # "Diclofenac":       {"recommended": 50,    "deadly": 1000},
        # "Spironolactone":   {"recommended": 25,    "deadly": 400},
        # "Buspirone":        {"recommended": 5,     "deadly": 1000},
        # "Tramadol":         {"recommended": 50,    "deadly": 2000},
        # "Empagliflozin":    {"recommended": 10,    "deadly": 200},
        # "Clonazepam":       {"recommended": 0.5,   "deadly": 200},
        # "Lamotrigine":      {"recommended": 25,    "deadly": 1000},
        # "Oxycodone":        {"recommended": 5,     "deadly": 40},
        # "Ondansetron":      {"recommended": 4,     "deadly": 100},
        # "Cholecalciferol":  {"recommended": 1.25,  "deadly": 50},   # same as ergocalciferol
        # "Atenolol":         {"recommended": 50,    "deadly": 1000},
        # "Glimepiride":      {"recommended": 1,     "deadly": 20},
        # "Folic Acid":       {"recommended": 1,     "deadly": 100000},
        # "Zolpidem":         {"recommended": 10,    "deadly": 400},
        # "Latanoprost":      {"recommended": 0.0015,"deadly": 5},
        # "Doxycycline":      {"recommended": 100,   "deadly": 5000},
        # "Lisdexamfetamine": {"recommended": 30,    "deadly": 1000},
        # "Insulin Lispro":   {"recommended": 5,     "deadly": 300},
        # "Clonidine":        {"recommended": 0.1,   "deadly": 25000},
        # "Loratadine":       {"recommended": 10,    "deadly": 100},
        # "Finasteride":      {"recommended": 5,     "deadly": 1000},
        # "Dulaglutide":      {"recommended": 0.75,  "deadly": 10},
        # "Insulin Aspart":   {"recommended": 5,     "deadly": 300},
        # "Propranolol":      {"recommended": 40,    "deadly": 2000},
        # "Azithromycin":     {"recommended": 500,   "deadly": 5000},
        # "Ezetimibe":        {"recommended": 10,    "deadly": 2000},
        # "Ethinyl Estradiol": {"recommended": 0.03, "deadly": 5},
        # "Lorazepam":        {"recommended": 1,     "deadly": 100},
        # "Quetiapine":       {"recommended": 25,    "deadly": 5000},
        # "Budesonide":       {"recommended": 0.4,   "deadly": 100},
        # "Topiramate":       {"recommended": 25,    "deadly": 2000},
        # "Warfarin":         {"recommended": 5,     "deadly": 500},
        # "Sitagliptin":      {"recommended": 100,   "deadly": 1000},
        # "Amitriptyline":    {"recommended": 25,    "deadly": 1000},
        # "Fenofibrate":      {"recommended": 160,   "deadly": 5000},
        # "Naproxen":         {"recommended": 250,   "deadly": 20000},
        # "Rivaroxaban":      {"recommended": 20,    "deadly": 200},
        # "Pregabalin":       {"recommended": 75,    "deadly": 5000},
        # "Paroxetine":       {"recommended": 20,    "deadly": 2000},
        # "Celecoxib":        {"recommended": 200,   "deadly": 1500},
        # "Tizanidine":       {"recommended": 2,     "deadly": 100},
        # "Sumatriptan":      {"recommended": 50,    "deadly": 400},
        # "Diltiazem":        {"recommended": 120,   "deadly": 8000}



    def plot_umap(self, embs, norm, replies, labels, drug, out_path, layer_list, nr_cols=5, fontsize=24, plot_title=False):
        """Plot per-layer 2-D embeddings coloured by normalised dosage.

        One subplot is drawn for each entry of ``layer_list``. Points are
        coloured on a green → yellow → red scale by ``norm`` and drawn with a
        marker that encodes the judged reply: circle for "alive", X for "dead"
        and triangle for "unknown". Points with any other label are not drawn.
        The figure is saved to ``out_path``, shown, and then closed.

        Args:
            embs: Container indexed by layer; ``embs[layer]`` is indexed as a
                2-D array whose columns 0 and 1 are the plotted coordinates.
            norm: Per-prompt colour values; the colour scale spans 0.0–1.0.
            replies: Unused; kept so existing callers keep working.
            labels: Per-prompt reply class, aligned with ``norm``.
            drug: Key into ``self.DRUG_DOSES``; also used as colorbar title.
            out_path: Destination file for the figure.
            layer_list: Layers to plot, in subplot order.
            nr_cols: Unused; the grid shape is derived from ``len(layer_list)``.
            fontsize: Base font size from which label sizes are derived.
            plot_title: Unused; kept so existing callers keep working.

        Raises:
            ValueError: If ``layer_list`` is empty.
        """
        def fmt_mg(value):
            # int() alone would turn fractional doses (e.g. 0.05 mg) into "0 mg".
            value = float(value)
            return f"{int(value)} mg" if value.is_integer() else f"{value:g} mg"

        N = len(layer_list)
        if N == 0:
            raise ValueError("layer_list must contain at least one layer")

        cmap = LinearSegmentedColormap.from_list("GYR", ["green", "yellow", "red"])
        norm_c = Normalize(vmin=0.0, vmax=1.0)

        # Grid shape depends only on how many layers are plotted.
        if N <= 4:
            ncols, nrows = N, 1
        elif N <= 8:
            ncols, nrows = 4, 2
        else:
            ncols = 5
            nrows = math.ceil(N / ncols)

        fig, axes = plt.subplots(nrows, ncols,
                                 figsize=(ncols * 5, nrows * 5),
                                 squeeze=False)

        colors = np.asarray(norm)

        # Label → marker mapping; the index lists do not depend on the layer,
        # so they are computed once.
        marker_by_label = (("alive", "o"), ("dead", "X"), ("unknown", "^"))
        idx_by_label = {
            label: [i for i, lab in enumerate(labels) if lab == label]
            for label, _ in marker_by_label
        }

        sc = None
        for plot_idx, layer in enumerate(layer_list):
            ax = axes[plot_idx // ncols, plot_idx % ncols]
            emb = embs[layer]
            for label, marker in marker_by_label:
                idx = idx_by_label[label]
                if not idx:
                    continue
                sc = ax.scatter(
                    emb[idx, 0], emb[idx, 1],
                    c=colors[idx], cmap=cmap, norm=norm_c,
                    marker=marker, s=400, edgecolors="k", linewidths=0.9
                )
            ax.set_title(f"Layer {layer}", fontsize=23)
            # Axis labels only on the outer edges of the grid.
            if plot_idx // ncols == nrows - 1:
                ax.set_xlabel("UMAP-1", fontsize=fontsize - 4)
            if plot_idx % ncols == 0:
                ax.set_ylabel("UMAP-2", fontsize=fontsize - 4)
            ax.tick_params(axis='both', which='major', labelsize=fontsize - 9)

        # Blank out the unused cells of the grid.
        for j in range(N, nrows * ncols):
            axes[j // ncols, j % ncols].axis("off")

        # If nothing was scattered (no recognised labels) fall back to a bare
        # mappable so the colorbar can still be drawn.
        mappable = sc if sc is not None else plt.cm.ScalarMappable(norm=norm_c, cmap=cmap)

        if N < 20:
            fig.tight_layout(rect=[0, 0, 1, 1], h_pad=3.5)
            cbar = fig.colorbar(
                mappable, ax=axes.ravel().tolist(),
                orientation="horizontal",
                fraction=0.03, pad=0.15, location="top", anchor=(0.1, 2.6), shrink=0.5
            )
        else:
            fig.tight_layout(rect=[0, 0, 1, 0.97], h_pad=3.5)
            cbar = fig.colorbar(
                mappable, ax=axes.ravel().tolist(),
                orientation="horizontal",
                location="top", pad=0.03, anchor=(0.0, 2.6), shrink=0.5)

        cbar.ax.set_xticks([0, 1])
        cbar.ax.set_xticklabels(["recommended dose", "deadly dose"], fontsize=24)
        # Drug name above the colorbar.
        cbar.ax.set_title(drug, fontsize=fontsize - 2, pad=8)

        # mg ticks beneath the colorbar via an identity secondary x-axis.
        rec = self.DRUG_DOSES[drug]["recommended"]
        dead = self.DRUG_DOSES[drug]["deadly"]
        mid = (rec + dead) / 2.0
        secax = cbar.ax.secondary_xaxis('bottom', functions=(lambda x: x, lambda x: x))
        secax.set_xticks([0.0, 0.5, 1.0])
        secax.set_xticklabels([fmt_mg(rec), fmt_mg(mid), fmt_mg(dead)])
        secax.tick_params(labelsize=fontsize - 6, pad=6)

        # Top-right legend explaining the marker shapes (gray markers).
        legend_handles = [
            Line2D([0], [0], marker='o', color='none', markerfacecolor='lightgray', markeredgecolor='k', markersize=20, label='LLM prediction: Alive'),
            Line2D([0], [0], marker='X', color='none', markerfacecolor='lightgray', markeredgecolor='k', markersize=20, label='LLM prediction: Dead'),
        ]
        if idx_by_label["unknown"]:
            # Triangles are plotted, so they must be explained too.
            legend_handles.append(
                Line2D([0], [0], marker='^', color='none', markerfacecolor='lightgray', markeredgecolor='k', markersize=20, label='LLM prediction: Unknown')
            )
        fig.legend(handles=legend_handles, loc='upper right', bbox_to_anchor=(0.98, 1.0), ncol=len(legend_handles), frameon=False, fontsize=fontsize - 3)

        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(out_path, dpi=200)
        print(f"Saved to {out_path}")
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    
    def runUmap(self, load_cached=True):
        """Run (or reload) the dosage UMAP experiment and plot it.

        Without a usable cache, 50 prompts are generated for "Potassium
        Chloride" with doses evenly spaced between its recommended and deadly
        values, the model is run with hooks registered via
        ``self.register_hooks()``, the captured buffers are embedded with
        ``self.umap_fit`` and every reply is classified as
        "alive"/"dead"/"unknown" by an OpenAI chat model. The results are
        pickled to ``self.umapCacheFile`` and finally plotted with
        ``self.plot_umap``.

        Args:
            load_cached: When true and ``self.umapCacheFile`` exists, load the
                cached results instead of running the model.

        Raises:
            ValueError: If replies must be judged and no OpenAI API key is
                found in ``~/.oai`` or ``OPENAI_API_KEY``.
        """
        SAMPLES_PER_DRUG = 50
        cache_path = self.umapCacheFile
        judge_client = None

        def judge_reply(text: str) -> str:
            """Classify one model reply as 'alive', 'dead' or 'unknown'."""
            nonlocal judge_client
            if judge_client is None:
                # Build the client once instead of re-reading the key file and
                # re-creating the client for every reply.
                from openai import OpenAI
                try:
                    # try ~/.oai like other analyses
                    with open(os.path.expanduser("~/.oai"), "r") as f:
                        api_key = f.read().strip()
                except FileNotFoundError:
                    api_key = ""
                # Fall back to the environment when the file is missing or empty.
                api_key = api_key or os.environ.get("OPENAI_API_KEY")
                if not api_key:
                    raise ValueError("OPENAI_API_KEY not found")
                judge_client = OpenAI(api_key=api_key)

            prompt = (
                "You are goind to read a truncated reply from another LLM about the outcome of a patient taking a drug. The LLM will likely say either alive or dead at the begginning of the reply and then justify its answer. Classify the reply strictly as one of: "
                "'alive' or 'dead.\n\nReply: " + text
            )
            resp = judge_client.chat.completions.create(
                model="gpt-4o",
                messages=[
                    {"role": "system", "content": "You are a strict classifier."},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.0,
                max_tokens=3,
            )
            # content may be None; treat that as an unclassifiable reply.
            lab = (resp.choices[0].message.content or "").strip().lower()
            print('lab', lab)
            if lab.startswith("alive"):
                return "alive"
            if lab.startswith("dead"):
                return "dead"
            return "unknown"

        if load_cached and os.path.exists(cache_path):
            # NOTE: pickle can execute arbitrary code on load; only load cache
            # files produced by this analysis.
            with open(cache_path, 'rb') as f:
                cache = pickle.load(f)
            embs = cache['embs']
            prompts = cache['prompts']
            drugs_list = cache['drugs_list']
            dosages_list = cache['dosages_list']
            drug = cache['drug']
            replies = cache['replies']
            replyLabels = cache.get('replyLabels')
            if replyLabels is None:
                # Cache predates reply judging: judge now and persist the labels.
                replyLabels = [judge_reply(r) for r in replies]
                print('replyLabels', replyLabels)
                cache['replyLabels'] = replyLabels
                with open(cache_path, 'wb') as f:
                    pickle.dump(cache, f)
                print(f"Updated UMAP cache → {cache_path}")
        else:
            prompts = []
            drugs_list = []
            dosages_list = []

            # Only a single drug is analysed for now (not every DRUG_DOSES key).
            drug = "Potassium Chloride"
            rec = self.DRUG_DOSES[drug]["recommended"]
            dead = self.DRUG_DOSES[drug]["deadly"]
            # SAMPLES_PER_DRUG evenly-spaced dosages from rec → dead.
            # Earlier prompt variants that were tried:
            #   "What will happen to the patient? Answer in exactly one English word: stable, ill or dead \n\n A patient took {dosage:.0f}mg of {drug}. The patient will be "
            #   "A patient took {dosage:.0f}mg of {drug}."
            # Preliminary tests: the grouping observed in the UMAP doesn't
            # correlate with whether the answer is alive or dead. Further
            # analysis is needed to understand this.
            for dosage in np.linspace(rec, dead, SAMPLES_PER_DRUG):
                prompts.append(
                    f"A patient took {dosage:.0f}mg of {drug}. Will the patient be alive or dead? Answer succintly as alive/dead/don't know + a brief explanation in one single sentence."
                )
                drugs_list.append(drug)
                dosages_list.append(float(dosage))
            print(f"Generated {len(prompts)} prompts.")

            # ── C. hook all layers ───────────────────────────────────────────
            handles, layer_buffers = self.register_hooks()

            # ── D. run the model ─────────────────────────────────────────────
            # Preliminary tests with long responses: the replies don't
            # correlate with the groups. Even for deadly doses, many answers
            # are "alive". Not sure what the groups mean.
            try:
                replies = self.run_model(prompts, long_response_length=50)
            finally:
                # Always detach the hooks, even if generation fails.
                for h in handles:
                    h.remove()

            # ── E. UMAP embed ────────────────────────────────────────────────
            embs = self.umap_fit(layer_buffers)

            print('replies', replies)
            replyLabels = [judge_reply(r) for r in replies]
            print('replyLabels', replyLabels)

            cache = {
                'embs': embs,
                'prompts': prompts,
                'drugs_list': drugs_list,
                'dosages_list': dosages_list,
                'replies': replies,
                'drug': drug,
                'replyLabels': replyLabels,
            }
            with open(cache_path, 'wb') as f:
                pickle.dump(cache, f)
            print(f"Saved UMAP cache → {cache_path}")

        # ── F. build a normalized-dosage array (0 = recommended, 1 = deadly) ─
        recs = np.array([self.DRUG_DOSES[d]["recommended"] for d in drugs_list])
        deads = np.array([self.DRUG_DOSES[d]["deadly"] for d in drugs_list])
        doses = np.array(dosages_list)

        span = deads - recs
        # A drug whose recommended and deadly doses coincide has a zero-width
        # range; divide by 1 there so those points map to 0 instead of NaN.
        safe_span = np.where(span == 0, 1.0, span)
        norm = np.clip((doses - recs) / safe_span, 0.0, 1.0)

        # ── G. plot with a green→red colorbar ────────────────────────────────
        selected_layers_many = self.selectLayersForPlots(nrPlots=8)
        out_path = f"results/{self.analysis_name}_umap_{self.MODEL_NAME}_{drug}.pdf"

        self.plot_umap(embs, norm, replies, replyLabels, drug, out_path, selected_layers_many, nr_cols=5, fontsize=27, plot_title=True)




        
    def runSaliency(self, load_cached=True):
        """Compute (or reload) saliency for dosage prompts and write reports.

        Without a usable cache, one "alive"-target saliency map is computed per
        drug in ``self.DRUG_DOSES`` (using its recommended dose), a text report
        and a grid plot are written per drug, and the per-drug maps are merged
        with ``merge_saliency_dicts`` and stored via ``save_saliency``. In both
        cases the averaged saliency is then reported and plotted.

        Args:
            load_cached: When true and ``self.saliencyFile`` exists, load the
                aggregated saliency from it instead of recomputing.
        """
        use_cache = load_cached and os.path.exists(self.saliencyFile)

        if not use_cache:
            # Token ids the answer is restricted to: first sub-token of each word.
            allowed_ids = [
                self.tokenizer(word, add_special_tokens=False).input_ids[0]
                for word in ("alive", "dead")
            ]

            all_saliency_dicts = {}
            beg = 'Answer with a single word: "alive" or "dead", and no punctuations. '

            for drug, doses in self.DRUG_DOSES.items():
                question = f"A patient has taken {doses['recommended']} mg of {drug}. Will the patient be alive or dead?"
                alive_saliency = self.saliency_for(beg + question, "alive", allowed_ids)
                all_saliency_dicts[drug] = alive_saliency

                # Shared filename stem for this drug's outputs.
                drug_stem = f"results/{self.analysis_name}_saliency_{self.MODEL_NAME}_{drug.replace(' ', '_')}"

                lines = print_top_saliencies(alive_saliency, f"Survival (alive) for {drug}")
                report = "\n".join(lines)

                txt_path = f"{drug_stem}_alive.txt"
                with open(txt_path, "w") as f:
                    f.write(report)
                print(f"Saved saliency text → {txt_path}")

                self.plot_saliency_grid(
                    self.model_layers,
                    alive_saliency,
                    f"{drug_stem}_alive.pdf",
                    f"{self.MODEL_NAME} saliency for dosage of {drug}",
                )

                # The same report is also written without the "_alive" suffix.
                with open(f"{drug_stem}.txt", "w") as f:
                    f.write(report)
                print(f"Saved to {drug_stem}.txt")

            avg_sal, avg_sal_per_prompt, avg_sal_per_head, avg_sal_per_mlp = merge_saliency_dicts(all_saliency_dicts, self.model)

            # Persist the aggregated results for later cached runs.
            save_saliency(self.saliencyFile, avg_sal, avg_sal_per_prompt,
                          avg_sal_per_head, avg_sal_per_mlp)
            print(f"Saved all saliency data → {self.saliencyFile}")
        else:
            with open(self.saliencyFile, "rb") as f:
                data = pickle.load(f)
            print(data.keys())
            avg_sal = data["avg_sal"]
            avg_sal_per_prompt = data["avg_sal_per_prompt"]
            avg_sal_per_head = data["avg_sal_per_head"]
            avg_sal_per_mlp = data["avg_sal_per_mlp"]
            print(f"Loaded cached saliency data from {self.saliencyFile}")

        # Text report of the averaged saliency.
        report_path = f"results/{self.analysis_name}_saliency_{self.MODEL_NAME}.txt"
        lines = print_top_saliencies(avg_sal, "Average Saliency (Dosage)")
        with open(report_path, "w") as f:
            f.write("\n".join(lines))
        print(f"Saved saliency report → {report_path}")

        # Grid plot of the averaged saliency.
        avg_path = f"results/{self.analysis_name}_saliency_{self.MODEL_NAME}_avg.pdf"
        self.plot_saliency_grid(
            self.model_layers,
            avg_sal,
            avg_path,
            f"{self.MODEL_NAME} saliency for drug dosage",
        )
        print(f"Plotted saliency map → {avg_path}")

        # 2-D saliency table (per head / per MLP).
        table_path = f"results/{self.analysis_name}_2Dsaliency_{self.MODEL_NAME}_avg.pdf"
        self.plot_saliency_2D(
            avg_sal_per_head,
            avg_sal_per_mlp,
            table_path,
            f"{self.MODEL_NAME} saliency table for drug dosage",
        )
        print(f"Plotted saliency table → {table_path}")

    # Use base class runLesioning by providing a prompt generator
    def generate_lesioning_prompts(self):
        # I diversified them and also made them open-ended.
        return [
            "What is a safe dose of Amlodipine and what is a lethal dose?",
            "A patient took 50mg of Metoprolol. What will happen to the patient?",
            "What is a lethal dose of Potassium Chloride?",
            "What is a recommended dose of Metformin?",
            "A patient took 100mg of Amlodipine. Is this safe?",
            "What is the maximum daily dose of Acetaminophen?",
            "A patient took 2g of Aspirin. What are the potential effects?",
            "What is a therapeutic dose of Warfarin?",
            "A patient took 500mg of Digoxin. Is this dangerous?",
            "What is the recommended starting dose of Lisinopril?",
        ]
 


    def generate_activation_patching_prompts(self):
        """
        Generate 10 prompt pairs for activation patching with single-word dosage-related answers.
        
        Returns:
            List of tuples (clean_prompt, corrupt_prompt, clean_answer, corrupt_answer)
        """
        prompt_pairs = []

        # self.DRUG_DOSES = {
        #     # "Potassium Chloride": {"recommended": 1500, "deadly": 14000}, # Potassium Chloride: Recommended 20 mEq (≈1,500 mg KCl) orally for hypokalemia. KCl is used in lethal injections – a rapid IV dose of ~100–150 mEq (≈7.5–11 g) will cause cardiac arrest. Orally, absorption is slower but still dangerous. An acute oral dose of ~14,000 mg (14 g) KCl is likely lethal without intervention, due to fatal hyperkalemia quora.com . (LD₅₀ oral KCl in rodents is ~2.5 g/kg, i.e. ~175 g for a 70 kg human, but death in humans occurs at much lower doses if absorbed quickly)
        #     # "Metformin": {"recommended": 500, "deadly": 50000}, # Metformin: Recommended 500 mg twice daily, orally for type 2 diabetes. A lethal dose is ~50,000 mg (50 g). 35 g has proven lethal in a case report. Another report noted >50 g ingestion had ~38% mortality.
            
        #     "Amlodipine":       {"recommended": 5,     "deadly": 100},
        #     "Metoprolol":       {"recommended": 50,    "deadly": 2000},
        #     "Atorvastatin":     {"recommended": 10,    "deadly": 8000},
        #     "Lisinopril":       {"recommended": 10,    "deadly": 20000},
        #     "Levothyroxine":    {"recommended": 0.075, "deadly": 10},       # 75 μg = 0.075 mg; deadly ~10 mg
        #     "Amlodipine":       {"recommended": 5,     "deadly": 10000},
        #     "Metoprolol":       {"recommended": 50,    "deadly": 2000},
        #     "Albuterol":        {"recommended": 2,     "deadly": 200},
        #     "Losartan":         {"recommended": 50,    "deadly": 20000},
        #     "Omeprazole":       {"recommended": 20,    "deadly": 1000},
        # }
        
        # Use the drugs that are currently active in DRUG_DOSES
        active_drugs = list(self.DRUG_DOSES.keys())
        
        # Generate prompts for each active drug
        for drug in active_drugs:
            recommended_dose = self.DRUG_DOSES[drug]["recommended"]
            deadly_dose = self.DRUG_DOSES[drug]["deadly"]
            
            # Clean prompt: recommended dose -> "stable"
            clean_prompt = f"What will happen to the patient? Answer in exactly one English word: stable or dead. A patient took {recommended_dose}mg of {drug}. The patient will be __"
            
            # Corrupt prompt: deadly dose -> "dead". We will multiply the deadly dose by 10 to make it more lethal. also, those values are also borderline, they are considered lethal by the sources but some patients might still survive.

            corrupt_prompt = f"What will happen to the patient? Answer in exactly one English word: stable or dead. A patient took {deadly_dose*10}mg of {drug}. The patient will be __"
            
            prompt_pairs.append((clean_prompt, "stable", corrupt_prompt, "dead"))
        
        return prompt_pairs
