"""
Same as drugAct.py, but for 100 most common drugs in the US, and no diseases this time.
"""

import math, random, collections, torch, umap, numpy as np
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap
from transformers import AutoModelForCausalLM, AutoTokenizer
from functools import partial
import matplotlib.colors as mcolors
from scipy import stats
import json
from saliency import *
import os
import pickle, collections
from utils import runUMAPClustering

from act import Act

class DrugAct(Act):
    def __init__(self, tok, model, processor, MODEL_NAME, PRINT_REPLIES, TEMPERATURE, USE_ACCELERATE=False):
        """Initialise the drug analysis on top of the shared ``Act`` base class.

        Every argument is forwarded unchanged to ``Act.__init__`` (defined in
        ``act.py``, not visible from this file). Afterwards the analysis is
        tagged as ``"drugs"`` and ``set_files()`` is invoked.
        """
        super().__init__(
            tok,
            model,
            processor,
            MODEL_NAME,
            PRINT_REPLIES,
            TEMPERATURE,
            USE_ACCELERATE=USE_ACCELERATE,
        )

        # The name is assigned before set_files() runs; presumably set_files()
        # derives its output/cache paths from it -- TODO confirm in Act.
        self.analysis_name = "drugs"
        self.set_files()

    # def run(self, do_umap=True, do_saliency=True, do_maps=True, do_lesioning=False, do_activation_patching=False, do_activation_patching_finegrained=False):
    #     """
    #     Override the parent run method to include activation patching.
    #     """
    #     if do_umap:
    #         self.runUmap()
    #     if do_saliency:
    #         self.runSaliency()
    #     if do_lesioning:
    #         self.runLesioning()
    #     if do_activation_patching:
    #         self.runActivationPatching()
    #     if do_activation_patching_finegrained:
    #         self.runActivationPatchingFinegrained()
    #     if do_maps:
    #         self.plotMap()

    # plot helper
    def plot_umap(self, mapping, title, out_path, ALL_DRUGS, embs, idxs_by_drug, drug_markers, layer_list=None, nr_cols=5, fontsize=24, plot_title=True, drug_row_height=3, drug_row_offset=25.5, legend_cattitles_two_rows=False):
        """Plot per-layer 2-D embeddings of the drugs with a categorized legend.

        One scatter panel is drawn per selected layer. Every drug is coloured
        by its category (``mapping[drug]``) and drawn with its own marker. A
        text legend below the panels lists the drugs grouped by category. The
        figure is written to ``out_path`` and then closed.

        Args:
            mapping: Dict from drug name to its category label.
            title: Text appended to the figure title ("... colored by <title>").
            out_path: Destination passed to ``Figure.savefig``.
            ALL_DRUGS: Drug names; their order fixes both the category order
                (first occurrence) and the order of drugs inside the legend.
            embs: Indexable by layer; ``embs[layer]`` is indexed as
                ``emb[idxs, 0]`` / ``emb[idxs, 1]``, i.e. a 2-D array whose
                first two columns are the plotted coordinates.
            idxs_by_drug: Dict from drug name to the row indices of that drug
                in each ``embs[layer]``.
            drug_markers: Dict from drug name to a matplotlib marker.
            layer_list: Layers to plot. ``None`` plots every layer in
                ``range(self.num_layers)``.
            nr_cols: Unused; kept only for backward compatibility with callers.
            fontsize: Base font size that all text sizes are derived from.
            plot_title: Whether to draw the figure-level title.
            drug_row_height: Vertical distance between consecutive legend rows.
            drug_row_offset: Vertical offset of the first legend row.
            legend_cattitles_two_rows: Split category titles over two rows.

        Raises:
            ValueError: If no layer is selected for plotting.
        """
        # Group the drugs per category once (insertion order == ALL_DRUGS order)
        drugs_by_category = {}
        for drug in ALL_DRUGS:
            drugs_by_category.setdefault(mapping[drug], []).append(drug)
        categories = list(drugs_by_category)

        cmap = get_cmap("tab20")
        catcol = {cat: cmap(i % cmap.N) for i, cat in enumerate(categories)}

        # Determine which layers to plot
        if layer_list is None:
            layers_to_plot = list(range(self.num_layers))
        else:
            layers_to_plot = list(layer_list)
        n_layers_to_plot = len(layers_to_plot)
        if n_layers_to_plot == 0:
            # Fail early with a clear message instead of an opaque GridSpec error.
            raise ValueError("plot_umap needs at least one layer to plot")

        # Grid layout depends on the number of panels; the bottom cell of the
        # outer 2x1 grid always holds the legend.
        if n_layers_to_plot == 4:
            # single row of 4 panels
            nrows, ncols = 1, 4
            fig = plt.figure(figsize=(25, 12))
            gs = fig.add_gridspec(2, 1, height_ratios=[1, 1], hspace=0.1)
            plot_gs = gs[0].subgridspec(1, 4, wspace=0.1)
            largePlot = False
            drugsPerCategoryTruncation = 15
        elif 5 <= n_layers_to_plot <= 8:
            # two rows of 4 panels
            nrows, ncols = 2, 4
            fig = plt.figure(figsize=(25, 24))
            gs = fig.add_gridspec(2, 1, height_ratios=[2, 1], hspace=0.1)
            plot_gs = gs[0].subgridspec(2, 4, wspace=0.1)
            largePlot = False
            drugsPerCategoryTruncation = 19
        else:
            # N rows of 5 panels for any other number of layers
            ncols = 5
            nrows = (n_layers_to_plot + ncols - 1) // ncols  # Ceiling division
            fig = plt.figure(figsize=(25, 4 + nrows * 5))
            gs = fig.add_gridspec(2, 1, height_ratios=[nrows, 1.6], hspace=0.065)
            plot_gs = gs[0].subgridspec(nrows, ncols, wspace=0.2, hspace=0.3)
            largePlot = True
            drugsPerCategoryTruncation = 19
        legend_gs = gs[1]

        # Plot the layers, filling the grid row by row
        for plot_idx, layer in enumerate(layers_to_plot):
            grid_row, grid_col = divmod(plot_idx, ncols)
            ax = fig.add_subplot(plot_gs[grid_row, grid_col])

            emb = embs[layer]
            for drug, idxs in idxs_by_drug.items():
                color = catcol[mapping[drug]]
                ax.scatter(emb[idxs, 0], emb[idxs, 1],
                           c=[color] * len(idxs),
                           marker=drug_markers[drug],
                           s=400, edgecolors='k', linewidths=0.9)
            ax.set_title(f"Layer {layer}", fontsize=fontsize - 3)

            # Axis labels only on the bottom grid row and the leftmost column
            if grid_row == nrows - 1:
                ax.set_xlabel("UMAP-1", fontsize=fontsize - 2)
            else:
                ax.set_xlabel("")
            if grid_col == 0:
                ax.set_ylabel("UMAP-2", fontsize=fontsize - 2)

            ax.tick_params(axis='both', which='major',
                           labelsize=fontsize - (13 if largePlot else 9))

        # Blank out the unused cells of the 5xN grid
        if n_layers_to_plot > 8:
            for cell in range(n_layers_to_plot, nrows * ncols):
                fig.add_subplot(plot_gs[cell // ncols, cell % ncols]).axis('off')

        # Categorized legend: one column per category, one row per drug
        legend_cols = len(categories)
        max_drugs_per_cat = max(len(members) for members in drugs_by_category.values())
        legend_rows = max_drugs_per_cat + 2  # Extra space for title and spacing

        row_height = 1.8
        legend_ax = fig.add_subplot(legend_gs)
        legend_ax.set_xlim(0, legend_cols * 1.2)
        legend_ax.set_ylim(0, legend_rows * row_height)
        legend_ax.axis('off')

        col_width = 1.25
        fontReduction = 6 if largePlot else 2
        for cat_idx, cat in enumerate(categories):
            x_pos = cat_idx * col_width + 0.05
            y_pos = legend_rows + drug_row_offset + 2

            # Category title in bold.
            if legend_cattitles_two_rows:
                # With 3+ words the first two go on the first row and the rest
                # on the second; with two words there is one word per row. A
                # one-word title has no second row (this used to raise
                # IndexError).
                y_pos += 2
                cat_words = cat.split()
                split_at = 2 if len(cat_words) > 2 else 1
                first_row = ' '.join(cat_words[:split_at])
                second_row = ' '.join(cat_words[split_at:])
                legend_ax.text(x_pos, y_pos, first_row, ha='left', va='center',
                               fontsize=fontsize - fontReduction, fontweight='bold')
                if second_row:
                    legend_ax.text(x_pos, y_pos - 2.1, second_row, ha='left', va='center',
                                   fontsize=fontsize - fontReduction, fontweight='bold')
            else:
                legend_ax.text(x_pos, y_pos, cat, ha='left', va='center',
                               fontsize=fontsize - fontReduction, fontweight='bold')

            # Drug names in this category, capped at the truncation limit
            drugs_in_cat = drugs_by_category[cat]
            for drug_idx, drug in enumerate(drugs_in_cat[:drugsPerCategoryTruncation]):
                y_drug = legend_rows + drug_row_offset - drug_idx * drug_row_height

                # Coloured marker identical to the one used in the panels
                legend_ax.scatter(x_pos, y_drug, c=[catcol[cat]], marker=drug_markers[drug],
                                  s=200, edgecolors='k', linewidths=0.9)

                # Drug name left-aligned next to the marker; names longer than
                # 12 characters are cut to 12 characters plus a trailing dot.
                drugAbbr = drug
                if len(drug) > 12:
                    drugAbbr = drug[:12] + "."
                legend_ax.text(x_pos + 0.07, y_drug, drugAbbr, ha='left', va='center',
                               fontsize=fontsize - fontReduction)
            if len(drugs_in_cat) > drugsPerCategoryTruncation:
                print(f"Truncated legend for category '{cat}': showing {drugsPerCategoryTruncation} of {len(drugs_in_cat)} drugs")

        if plot_title:
            fig.suptitle(f"{self.MODEL_NAME} - Drugs UMAP colored by {title}",
                         y=0.95, fontsize=fontsize + 6)

        # Minimize outer padding around the figure
        fig.subplots_adjust(left=0.00, right=1, top=1, bottom=0.00)

        try:
            fig.savefig(out_path, dpi=200)
            print(f"Saved to {out_path}")
        finally:
            # pyplot keeps every figure alive until it is closed explicitly;
            # release it so repeated calls do not accumulate open figures.
            plt.close(fig)

    # def selectLayersForPlots(self, nrPlots=4):
    #     """
    #     Select 4 evenly-spaced layers for plotting, including first and last layers.
    #     """
    #     if self.num_layers <= nrPlots:
    #         return list(range(self.num_layers))
        
    #     # Always include first and last layers
    #     selected = [0, self.num_layers - 1]
        
    #     # Add 2 intermediate layers
    #     if self.num_layers > nrPlots-1:
    #         step = (self.num_layers - 1) / 3  # Divide into 3 intervals
    #         for i in range(1, 3):
    #             layer_idx = int(round(i * step))
    #             if layer_idx not in selected:
    #                 selected.append(layer_idx)
        
    #     return sorted(selected)


# Paper feedback from medical professional - Mechanism of action
# Receptor agonists and antagonists should be listed next to each other
# List of abbreviated drugs for figure legend:
# Isosorbide Mon - Isosorbide Mononitrate
# Fludrocortison - Fludrocortisone
# Cyclobenzaprin - Cyclobenzaprine
# Dextroamphetam - Dextroamphetamine
# Methylphenidat - Methylphenidate
# Lisdexamfetami - Lisdexamfetamine
# hydrochlorothi - hydrochlorothiazide
# Potassium Chlo - Potassium Chloride
# Cholecalcifero - Cholecalciferol
# Chloramphenico - Chloramphenicol
# Insulin glargi - Insulin glargine
# Ethinyl Estrad - Ethinyl Estradiol
# 3. Drug categories:
# Enzyme inhibitor:
# Clopidogrel - take out
# Enzyme activators:
# Take out enzyme activator category (very difficult class of drugs - most actually have other primary mechanisms of action)
# —> if you check the figures it makes sense because they do not cluster at all
# Receptor antagonists:
# Lisinopril - ACE inhibitor - take out
# Amlodipine - Calcium channel blocker - take out
# Diltiazem - Calcium channel blocker - take out
# Hormone modulators:
# Levothyroxine - is not a hormone modulator, but a synthetic hormone replacement for thyroxine (T4) - take out
# Testosterone - take out
# Desmopressin - hormone analog - take out
# Ion channel modulators:
# Buspirone - acts on serotonin 5-HT1A receptor - take out
# Tizanidine - Receptor agonist - take out
# Transporter modulator:
# Lisdexamfetami - take out
# Cholestyramine - take out
# —> very interesting class - most of them modulate the serotonin transporter —> they also cluster nicely in later layers
# Signal transduction modulators:
# Hydrochlorothiazide - take out
# Furosemide - take out
# Apixaban - take out (anticoagulant)
# Potassium Chlo - do you mean Potassium Chloride? Then it is ok
# Cholecalcifero - Cholecalciferol? Then ok
# Rivaroxaban - primarily anticoagulant and secondary signal transduction - better to take out
# DNA/RNA interferer:
# Amoxicillin - take out (disrupts cell wall synthesis)
# Doxycycline - take out (inhibits protein synthesis)
# Azithromycin - take out
# Vancomycin - take out
# Vincristine - take out
# Paclitaxel - take out
# Chloramphenico - take out
# If you want to include more - you can take a few from here: https://www.researchgate.net/figure/List-of-clinically-approved-RNA-drugs-by-the-United-States-Food-and-Drug-Administration_tbl1_341644556
# Receptor agonists:
# All good



# Paper feedback - drug by speciality
# Cardiology:
# Aspirin - primary application pain relief - but we can leave it in - just be aware of that
# Endocrinology:
# All good
# Gastro:
# All good
# Neurology:
# All good
# Psychiatry:
# All good
# Urology:
# All good
# Rheumatology:
# —> all good in principle - but most of the drugs overlap with other specialities
# —> few of them share mechanism of action, hence they still cluster in later layers probably
# Infectious Diseases:
# All good
# Oncology:
# All good


# def runDrugOnlyAnalysis(tok, model, MODEL_NAME, PRINT_REPLIES=False):
    def runUmap(self, load_cached):
                

        # # top 100 most administered drugs in the US in 2022
        # DRUG_TO_MECH = {
        # "Atorvastatin":                     "Cardio.",
        # "Metformin":                        "Antidiabetic",
        # "Lisinopril":                       "Cardio.",
        # "Levothyroxine":                    "Endocrine",
        # "Amlodipine":                       "Cardio.",
        # "Metoprolol":                       "Cardio.",
        # "Albuterol":                        "Respiratory",
        # "Losartan":                         "Cardio.",
        # "Omeprazole":                       "Gastro.",
        # "Gabapentin":                       "CNS",
        # "Sertraline":                       "CNS",
        # "Hydrochlorothiazide":              "Cardio.",
        # "Rosuvastatin":                     "Cardio.",
        # "Dextroamphetamine":                "CNS",
        # "Escitalopram":                     "CNS",
        # "Pantoprazole":                     "Gastro.",
        # "Montelukast":                      "Respiratory",
        # "Trazodone":                        "CNS",
        # "Simvastatin":                      "Cardio.",
        # "Tamsulosin":                       "Cardio.",
        # "Bupropion":                        "CNS",
        # "Fluoxetine":                       "CNS",
        # "Acetaminophen":                    "CNS",
        # "Furosemide":                       "Cardio.",
        # "Fluticasone":                      "Respiratory",
        # "Amoxicillin":                      "Antibacterial",
        # "Apixaban":                         "Cardio.",
        # "Insulin Glargine":                 "Antidiabetic",
        # "Meloxicam":                        "Anti-inflam.",
        # "Prednisone":                       "Anti-inflam.",
        # "Duloxetine":                       "CNS",
        # "Methylphenidate":                  "CNS",
        # "Ibuprofen":                        "Anti-inflam.",
        # "Carvedilol":                       "Cardio.",
        # "Potassium Chloride":               "Miscellaneous",
        # "Aspirin":                          "Cardio.",
        # "Pravastatin":                      "Cardio.",
        # "Ergocalciferol":                   "Miscellaneous",
        # "Allopurinol":                      "Miscellaneous",
        # "Citalopram":                       "CNS",
        # "Alprazolam":                       "CNS",
        # "Glipizide":                        "Antidiabetic",
        # "Cetirizine":                       "Respiratory",
        # "Venlafaxine":                      "CNS",
        # "Cyclobenzaprine":                  "CNS",
        # "Hydroxyzine":                      "CNS",
        # "Clopidogrel":                      "Cardio.",
        # "Semaglutide":                      "Antidiabetic",
        # "Famotidine":                       "Gastro.",
        # "Estradiol":                        "Endocrine",
        # "Diclofenac":                       "Anti-inflam.",
        # "Spironolactone":                   "Cardio.",
        # "Hydrochlorothiazide":              "Cardio.",
        # "Buspirone":                        "CNS",
        # "Tramadol":                         "CNS",
        # "Empagliflozin":                    "Antidiabetic",
        # "Clonazepam":                       "CNS",
        # "Lamotrigine":                      "CNS",
        # "Fluticasone":                      "Respiratory",
        # "Oxycodone":                        "CNS",
        # "Ondansetron":                      "Gastro.",
        # "Cholecalciferol":                  "Miscellaneous",
        # "Atenolol":                         "Cardio.",
        # "Glimepiride":                      "Antidiabetic",
        # "Folic Acid":                       "Miscellaneous",
        # "Zolpidem":                         "CNS",
        # "Latanoprost":                      "Miscellaneous",
        # "Doxycycline":                      "Antibacterial",
        # "Lisdexamfetamine":                 "CNS",
        # "Insulin Lispro":                   "Antidiabetic",
        # "Clonidine":                        "Cardio.",
        # "Loratadine":                       "Respiratory",
        # "Finasteride":                      "Miscellaneous",
        # "Dulaglutide":                      "Antidiabetic",
        # "Hydrochlorothiazide":              "Cardio.",
        # "Insulin Aspart":                   "Antidiabetic",
        # "Propranolol":                      "Cardio.",
        # "Azithromycin":                     "Antibacterial",
        # "Ezetimibe":                        "Cardio.",
        # "Ethinyl Estradiol":                "Endocrine",
        # "Lorazepam":                        "CNS",
        # "Quetiapine":                       "CNS",
        # "Budesonide":                       "Respiratory",
        # "Topiramate":                       "CNS",
        # "Warfarin":                         "Cardio.",
        # "Sitagliptin":                      "Antidiabetic",
        # "Amitriptyline":                    "CNS",
        # "Fenofibrate":                      "Cardio.",
        # "Naproxen":                         "Anti-inflam.",
        # "Rivaroxaban":                      "Cardio.",
        # "Pregabalin":                       "CNS",
        # "Paroxetine":                       "CNS",
        # "Celecoxib":                        "Anti-inflam.",
        # "Tizanidine":                       "CNS",
        # "Sumatriptan":                      "CNS",
        # "Amoxicillin":                      "Antibacterial",
        # "Olmesartan":                       "Cardio.",
        # "Acetaminophen":                    "CNS",
        # "Ethinyl Estradiol":                "Endocrine",
        # "Diltiazem":                        "Cardio.",
        # }

        # took out Allergy and Pulmonology and Ophthalmology cause they had fewest drugs and I had too many categories.
        drugs = [
        # --- Original list (96) ---
        ("Atorvastatin", "Enzyme Inhibitor", "Cardiology"),
        # ("Metformin", "Enzyme Activator", "Endocrinology"), # our medical professional said to take out the entire enzyme activator category as it's a very difficult category to analyze and interpret
        # ("Lisinopril", "Receptor Antagonist", "Cardiology"),
        # ("Levothyroxine", "Hormone Modulator", "Endocrinology"),
        # ("Amlodipine", "Receptor Antagonist", "Cardiology"),
        ("Metoprolol", "Receptor Antagonist", "Cardiology"),
        # ("Albuterol", "Receptor Agonist", "Pulmonology"),  # asthma/COPD
        ("Losartan", "Receptor Antagonist", "Cardiology"),
        ("Omeprazole", "Enzyme Inhibitor", "Gastro."),
        ("Gabapentin", "Ion Channel Modulator", "Neurology"),
        ("Sertraline", "Transporter Modulator", "Psychiatry"),
        # ("Hydrochlorothiazide", "Signal Transduc. Modulator", "Cardiology"),
        ("Rosuvastatin", "Enzyme Inhibitor", "Cardiology"),
        ("Dextroamphetamine", "Transporter Modulator", "Psychiatry"),
        ("Escitalopram", "Transporter Modulator", "Psychiatry"),
        ("Pantoprazole", "Enzyme Inhibitor", "Gastro."),
        # ("Montelukast", "Receptor Antagonist", "Pulmonology"),  # allergy overlap
        ("Trazodone", "Transporter Modulator", "Psychiatry"),
        ("Simvastatin", "Enzyme Inhibitor", "Cardiology"),
        ("Tamsulosin", "Receptor Antagonist", "Urology"),
        ("Bupropion", "Transporter Modulator", "Psychiatry"),
        ("Fluoxetine", "Transporter Modulator", "Psychiatry"),
        ("Acetaminophen", "Signal Transduc. Modulator", "Rheumatology"),  # pain
        # ("Furosemide", "Signal Transduc. Modulator", "Cardiology"),  # nephrology overlap
        # ("Fluticasone", "Signal Transduc. Modulator", "Pulmonology"),  # allergy overlap
        # ("Amoxicillin", "DNA/RNA Interferer", "Infectious Dis."),
        # ("Apixaban", "Signal Transduc. Modulator", "Cardiology"),
        ("Insulin Glargine", "Receptor Agonist", "Endocrinology"),
        ("Meloxicam", "Enzyme Inhibitor", "Rheumatology"),
        ("Prednisone", "Signal Transduc. Modulator", "Rheumatology"),  # also pulm/allergy
        ("Duloxetine", "Transporter Modulator", "Psychiatry"),
        ("Methylphenidate", "Transporter Modulator", "Psychiatry"),
        ("Ibuprofen", "Enzyme Inhibitor", "Rheumatology"),
        ("Carvedilol", "Receptor Antagonist", "Cardiology"),
        ("Potassium Chloride", "Signal Transduc. Modulator", "Cardiology"),  # nephrology overlap
        ("Aspirin", "Enzyme Inhibitor", "Cardiology"),
        ("Pravastatin", "Enzyme Inhibitor", "Cardiology"),
        ("Ergocalciferol", "Signal Transduc. Modulator", "Endocrinology"),
        ("Allopurinol", "Enzyme Inhibitor", "Rheumatology"),
        ("Citalopram", "Transporter Modulator", "Psychiatry"),
        ("Alprazolam", "Ion Channel Modulator", "Psychiatry"),
        # ("Glipizide", "Enzyme Activator", "Endocrinology"),
        # ("Cetirizine", "Receptor Antagonist", "Allergy"),  # pulm overlap
        ("Venlafaxine", "Transporter Modulator", "Psychiatry"),
        ("Cyclobenzaprine", "Ion Channel Modulator", "Neurology"),
        ("Hydroxyzine", "Receptor Antagonist", "Psychiatry"),  # also allergy
        # ("Clopidogrel", "Enzyme Inhibitor", "Cardiology"), # our medical professional said to take it out
        ("Semaglutide", "Receptor Agonist", "Endocrinology"),
        ("Famotidine", "Receptor Antagonist", "Gastro."),
        ("Estradiol", "Receptor Agonist", "Endocrinology"),
        ("Diclofenac", "Enzyme Inhibitor", "Rheumatology"),
        ("Spironolactone", "Hormone Modulator", "Cardiology"),  # also endocrinology
        # ("Buspirone", "Ion Channel Modulator", "Psychiatry"),
        ("Tramadol", "Ion Channel Modulator", "Neurology"),
        # ("Empagliflozin", "Enzyme Activator", "Endocrinology"),
        ("Clonazepam", "Ion Channel Modulator", "Psychiatry"),
        ("Lamotrigine", "Ion Channel Modulator", "Neurology"),
        ("Oxycodone", "Ion Channel Modulator", "Neurology"),
        ("Ondansetron", "Receptor Antagonist", "Gastro."),
        ("Cholecalciferol", "Signal Transduc. Modulator", "Endocrinology"),
        ("Atenolol", "Receptor Antagonist", "Cardiology"),
        # ("Glimepiride", "Enzyme Activator", "Endocrinology"),
        ("Folic Acid", "Signal Transduc. Modulator", "Endocrinology"),
        ("Zolpidem", "Ion Channel Modulator", "Psychiatry"),
        # ("Latanoprost", "Receptor Agonist", "Ophthalmology"),
        # ("Doxycycline", "DNA/RNA Interferer", "Infectious Dis."),
        # ("Lisdexamfetamine", "Transporter Modulator", "Psychiatry"),
        ("Insulin Lispro", "Receptor Agonist", "Endocrinology"),
        ("Clonidine", "Receptor Agonist", "Cardiology"),  # neuro overlap
        # ("Loratadine", "Receptor Antagonist", "Allergy"),
        ("Finasteride", "Enzyme Inhibitor", "Urology"),
        ("Dulaglutide", "Receptor Agonist", "Endocrinology"),
        ("Insulin Aspart", "Receptor Agonist", "Endocrinology"),
        ("Propranolol", "Receptor Antagonist", "Cardiology"),
        # ("Azithromycin", "DNA/RNA Interferer", "Infectious Dis."),
        ("Ezetimibe", "Enzyme Inhibitor", "Cardiology"),
        ("Ethinyl Estradiol", "Receptor Agonist", "Endocrinology"),
        ("Lorazepam", "Ion Channel Modulator", "Psychiatry"),
        ("Quetiapine", "Receptor Antagonist", "Psychiatry"),
        # ("Budesonide", "Signal Transduc. Modulator", "Pulmonology"),
        ("Topiramate", "Ion Channel Modulator", "Neurology"),
        ("Warfarin", "Enzyme Inhibitor", "Cardiology"),
        ("Sitagliptin", "Enzyme Inhibitor", "Endocrinology"),
        ("Amitriptyline", "Transporter Modulator", "Psychiatry"),
        ("Fenofibrate", "Enzyme Inhibitor", "Cardiology"),
        ("Naproxen", "Enzyme Inhibitor", "Rheumatology"),
        # ("Rivaroxaban", "Signal Transduc. Modulator", "Cardiology"),
        ("Pregabalin", "Ion Channel Modulator", "Neurology"),
        ("Paroxetine", "Transporter Modulator", "Psychiatry"),
        ("Celecoxib", "Enzyme Inhibitor", "Rheumatology"),
        # ("Tizanidine", "Ion Channel Modulator", "Neurology"),
        ("Sumatriptan", "Receptor Agonist", "Neurology"),
        ("Olmesartan", "Receptor Antagonist", "Cardiology"),
        # ("Diltiazem", "Receptor Antagonist", "Cardiology"),

        # --- First expansion (24) ---
        # ("Nitroglycerin", "Enzyme Activator", "Cardiology"),
        # ("Minoxidil", "Enzyme Activator", "Cardiology"),
        # ("Riluzole", "Enzyme Activator", "Neurology"),
        ("Ciprofloxacin", "DNA/RNA Interferer", "Infectious Dis."),
        ("Rifampin", "DNA/RNA Interferer", "Infectious Dis."),
        ("Methotrexate", "DNA/RNA Interferer", "Rheumatology"),  # also oncology
        ("Doxorubicin", "DNA/RNA Interferer", "Oncology"),
        ("Tamoxifen", "Hormone Modulator", "Oncology"),
        ("Leuprolide", "Hormone Modulator", "Oncology"),
        ("Fludrocortisone", "Hormone Modulator", "Endocrinology"),
        # ("Testosterone", "Hormone Modulator", "Endocrinology"),
        ("Sucralfate", "Signal Transduc. Modulator", "Gastro."),
        ("Mesalamine", "Signal Transduc. Modulator", "Gastro."),
        ("Loperamide", "Receptor Agonist", "Gastro."),
        ("Ursodiol", "Signal Transduc. Modulator", "Gastro."),
        # ("Vancomycin", "DNA/RNA Interferer", "Infectious Dis."),
        ("Acyclovir", "DNA/RNA Interferer", "Infectious Dis."),
        ("Amphotericin B", "Signal Transduc. Modulator", "Infectious Dis."),
        ("Isoniazid", "DNA/RNA Interferer", "Infectious Dis."),
        # ("Timolol", "Receptor Antagonist", "Ophthalmology"),
        # ("Pilocarpine", "Receptor Agonist", "Ophthalmology"),
        ("Oxybutynin", "Receptor Antagonist", "Urology"),
        ("Mirabegron", "Receptor Agonist", "Urology"),

        # --- Second expansion (51) ---
        # ("Hydralazine", "Enzyme Activator", "Cardiology"),
        # ("Isosorbide Mononitrate", "Enzyme Activator", "Cardiology"),
        # ("Cinacalcet", "Enzyme Activator", "Endocrinology"),
        # ("Selexipag", "Enzyme Activator", "Cardiology"),
        # ("Cilostazol", "Enzyme Activator", "Cardiology"),
        # # ("Milrinone", "Enzyme Activator", "Cardiology"),
        # ("Linaclotide", "Enzyme Activator", "Gastro."),
        ("5-Fluorouracil", "DNA/RNA Interferer", "Oncology"),
        ("Cytarabine", "DNA/RNA Interferer", "Oncology"),
        # ("Vincristine", "DNA/RNA Interferer", "Oncology"),
        # ("Paclitaxel", "DNA/RNA Interferer", "Oncology"),
        # ("Chloramphenicol", "DNA/RNA Interferer", "Infectious Dis."),
        ("Linezolid", "DNA/RNA Interferer", "Infectious Dis."),
        ("Zidovudine", "DNA/RNA Interferer", "Infectious Dis."),
        ("Anastrozole", "Hormone Modulator", "Oncology"),
        ("Mifepristone", "Hormone Modulator", "Oncology"),
        ("Cabergoline", "Hormone Modulator", "Endocrinology"),
        # ("Desmopressin", "Hormone Modulator", "Endocrinology"),
        ("Octreotide", "Hormone Modulator", "Endocrinology"),
        ("Eplerenone", "Hormone Modulator", "Cardiology"),
        ("Clomiphene", "Hormone Modulator", "Endocrinology"),
        ("Rifaximin", "DNA/RNA Interferer", "Gastro."),
        ("Metoclopramide", "Receptor Antagonist", "Gastro."),
        ("Granisetron", "Receptor Antagonist", "Gastro."),
        # ("Cholestyramine", "Transporter Modulator", "Gastro."),
        ("Lactulose", "Signal Transduc. Modulator", "Gastro."),
        ("Dicyclomine", "Receptor Antagonist", "Gastro."),
        # ("Pancrelipase", "Enzyme Activator", "Gastro."),
        ("Meropenem", "DNA/RNA Interferer", "Infectious Dis."),
        ("Ceftriaxone", "DNA/RNA Interferer", "Infectious Dis."),
        ("Daptomycin", "Signal Transduc. Modulator", "Infectious Dis."),
        ("Oseltamivir", "DNA/RNA Interferer", "Infectious Dis."),
        ("Efavirenz", "DNA/RNA Interferer", "Infectious Dis."),
        ("Caspofungin", "Signal Transduc. Modulator", "Infectious Dis."),
        ("Metronidazole", "DNA/RNA Interferer", "Infectious Dis."),
        # ("Brimonidine", "Receptor Agonist", "Ophthalmology"),
        # ("Dorzolamide", "Enzyme Inhibitor", "Ophthalmology"),
        # ("Travoprost", "Receptor Agonist", "Ophthalmology"),
        ("Darifenacin", "Receptor Antagonist", "Urology"),
        ("Tolterodine", "Receptor Antagonist", "Urology"),
        ("Tadalafil", "Enzyme Inhibitor", "Urology"),
        ("Alfuzosin", "Receptor Antagonist", "Urology"),
        ("Sirolimus", "Signal Transduc. Modulator", "Oncology"),  # also transplant
        ("Everolimus", "Signal Transduc. Modulator", "Oncology"),
        ("Imatinib", "Signal Transduc. Modulator", "Oncology"),
        # ("Theophylline", "Enzyme Inhibitor", "Pulmonology"),
        # ("Omalizumab", "Signal Transduc. Modulator", "Pulmonology"),
        # ("Mepolizumab", "Signal Transduc. Modulator", "Pulmonology"),
        # ("Adalimumab", "Signal Transduc. Modulator", "Rheumatology"),
        # ("Colchicine", "Signal Transduc. Modulator", "Rheumatology"),
        ]

        # print how many drugs are in each mechanism and specialty
        mechanism_counts = collections.Counter([mechanism for (_name, mechanism, _specialty) in drugs])
        specialty_counts = collections.Counter([specialty for (_name, _mechanism, specialty) in drugs])
        print("Number of drugs in each mechanism:")
        for mech, cnt in sorted(mechanism_counts.items(), key=lambda x: (-x[1], x[0])):
            print(f"  {mech}: {cnt}")
        print("Number of drugs in each specialty:")
        for spec, cnt in sorted(specialty_counts.items(), key=lambda x: (-x[1], x[0])):
            print(f"  {spec}: {cnt}")


        # Build lists/mappings from tuple list
        ALL_DRUGS = [name for (name, _mechanism, _specialty) in drugs]
        DRUG_TO_CATEGORY = {name: specialty for (name, _mechanism, specialty) in drugs}
        DRUG_TO_MECHANISM = {name: mechanism for (name, mechanism, _specialty) in drugs}


        # A. seeds & markers
        MARKERS = ["o","s","^","p","D","P","*","X",">","h","d"]

        NR_PATIENTS_PER_DRUG = 1

        # B. simple adult age sampler
        def sample_age(): return random.randint(20,60)

        cache_path = self.umapCacheFile
        if load_cached and os.path.exists(cache_path):
            print(f"Loading cached UMAP data from {cache_path}")
            with open(cache_path, 'rb') as f:
                cache = pickle.load(f)
                embs = cache.get('embs'); 
                embsClustering = cache.get('embsClustering')
                prompts = cache.get('prompts')
                drugs_list = cache.get('drugs_list'); 
                DRUG_TO_CATEGORY = cache.get('DRUG_TO_CATEGORY', DRUG_TO_CATEGORY)
                DRUG_TO_MECHANISM = cache.get('DRUG_TO_MECHANISM', DRUG_TO_MECHANISM)
                ALL_DRUGS = cache.get('ALL_DRUGS'); 
                MARKERS = cache.get('MARKERS')
        else:    
            # C. generate prompts
            # GENDERS = ["male","female"]
            # GENDERS = ["female"] # actually let's do 100% females as I get two groups sometimes and I'm not sure if it's gender. 
            prompts, drugs_list = [], []
            for drug in ALL_DRUGS:
                for _ in range(NR_PATIENTS_PER_DRUG):
                    # g = random.choice(GENDERS)
                    a = sample_age()
                    # P = "He" if g=="male" else "She"
                    # p = P.lower()
                    # prompts.append(f"A {a}-year-old {g} patient takes {drug}. What drug does {p} take?")
                    prompts.append(f"A patient takes {drug}. What drug does the patient take?") # Raz: so it seems that without anchoring with the second part of the prompt "What drug does the patient take?", the clustering is not as good!! Not sure why ...
                    drugs_list.append(drug)
            print(f"Generated {len(prompts)} prompts.")

            # D. hook activations
            handles, layer_buffers = self.register_hooks()

            replies = self.run_model(prompts)
            for h in handles: h.remove()

            # E. UMAP
            embs = self.umap_fit(layer_buffers)
            embsClustering = self.umap_fit(layer_buffers, n_components=30)

            # Save cache
            cache = {
                'embs': embs,
                'embsClustering': embsClustering,
                'prompts': prompts,
                'drugs_list': drugs_list,
                'DRUG_TO_CATEGORY': DRUG_TO_CATEGORY,
                'DRUG_TO_MECHANISM': DRUG_TO_MECHANISM,
                'ALL_DRUGS': ALL_DRUGS,
                'MARKERS': MARKERS,
            }
            with open(cache_path, 'wb') as f:
                pickle.dump(cache, f)
            print(f"Saved UMAP cache → {cache_path}")


        # F. group indices & markers
        idxs_by_drug = collections.defaultdict(list)
        for idx,d in enumerate(drugs_list):
            idxs_by_drug[d].append(idx)
        drug_markers = {d: MARKERS[i % len(MARKERS)] for i,d in enumerate(ALL_DRUGS)}

        # selected layers plot
        selected_layers = self.selectLayersForPlots(nrPlots=4)
        out_file_selected = f"results/{self.analysis_name}_umap-specialty_{self.MODEL_NAME}_selected.pdf"
        self.plot_umap(DRUG_TO_CATEGORY, "", out_file_selected, ALL_DRUGS, embs, idxs_by_drug, drug_markers, layer_list=selected_layers, nr_cols=4, fontsize=23, plot_title=False, drug_row_height=2.2, drug_row_offset=11.5)
        plt.show()

        # single plot by category (medical specialty)
        selected_layers_many = self.selectLayersForPlots(nrPlots=40)
        self.plot_umap(DRUG_TO_CATEGORY, f"UMAP by medical specialty: {self.MODEL_NAME}", f"results/{self.analysis_name}_umap-specialty_{self.MODEL_NAME}.pdf", ALL_DRUGS, embs, idxs_by_drug, drug_markers, nr_cols=5, fontsize=27, plot_title=False, layer_list=selected_layers_many, drug_row_height=1.8, drug_row_offset=14)
        plt.show()



        # # # Convert string labels to integer labels for silhouette score
        # int_labels = np.array([DRUG_TO_CATEGORY[d] for d in ALL_DRUGS])
        # runUMAPClustering(embsClustering, int_labels, self.umapDrugSpecFile)

        # mechanism plots (selected layers)
        out_file_selected_mech = f"results/{self.analysis_name}_umap-mechanism_{self.MODEL_NAME}_selected.pdf"
        self.plot_umap(DRUG_TO_MECHANISM, "", out_file_selected_mech, ALL_DRUGS, embs, idxs_by_drug, drug_markers, layer_list=selected_layers, nr_cols=4, fontsize=23, plot_title=False, legend_cattitles_two_rows=True, drug_row_height=2.2, drug_row_offset=11.5)
        plt.show()



        # mechanism plots (all layers)
        self.plot_umap(DRUG_TO_MECHANISM, f"UMAP by mechanism: {self.MODEL_NAME}", f"results/{self.analysis_name}_umap-mechanism_{self.MODEL_NAME}.pdf", ALL_DRUGS, embs, idxs_by_drug, drug_markers, layer_list=selected_layers_many, nr_cols=5, fontsize=27, plot_title=False, legend_cattitles_two_rows=True,drug_row_height=1.8, drug_row_offset=14.0)
        plt.show()


        
        # # H. Compute clustering coefficient per layer using silhouette score with confidence intervals
        # mech_labels = np.array([DRUG_TO_MECHANISM[d] for d in ALL_DRUGS])
        # runUMAPClustering(embsClustering, mech_labels, self.umapDrugMechFile)

    def set_files(self):
        """Set the result-file path attributes used by this analysis.

        Every path has the form
        ``results/{analysis_name}_{artifact}_{MODEL_NAME}.{ext}``, so
        ``self.analysis_name`` and ``self.MODEL_NAME`` must already be set.
        """
        def result_path(artifact, ext):
            # Single place that encodes the naming scheme shared by all outputs.
            return f"results/{self.analysis_name}_{artifact}_{self.MODEL_NAME}.{ext}"

        # UMAP clustering outputs (by mechanism / by medical specialty)
        self.umapDrugMechFile = result_path("umap-mechanism", "npz")
        self.umapDrugSpecFile = result_path("umap-specialty", "npz")

        # Pickled intermediate results
        self.umapCacheFile = result_path("umap_cache", "pkl")
        self.saliencyFile = result_path("saliency", "pkl")

        # JSON outputs of the lesioning and activation-patching experiments
        self.lesioningFile = result_path("lesioning", "json")
        self.lesioningFinegrainedFile = result_path("lesioning_finegrained", "json")
        self.activationPatchingFile = result_path("activation_patching", "json")
        self.activationPatchingFinegrainedFile = result_path("activation_patching_finegrained", "json")


    def gen_saliency_prompts(self):
        """Build the prompts used for the saliency analysis.

        Each prompt is an instruction prefix followed by a one-sentence
        clinical vignette that ends with "What drug should the patient take?".

        Returns:
            List of ``(prompt, drug, drug_lowercase)`` tuples, where ``drug`` is
            the capitalised drug name and ``drug_lowercase`` is the same name in
            lower case.
            NOTE(review): presumably the two spellings are the accepted target
            answers for the saliency code -- confirm against the consumer.
        """
        # The trailing space matters: without it the instruction runs straight
        # into the vignette ("...a drug.A patient suffers...").
        prefix = "Answer with a single word only describing a drug. "

        # (condition description, expected drug)
        cases = [
            ("high LDL cholesterol and has atherosclerotic disease", "Atorvastatin"),
            ("type 2 diabetes with elevated A1C despite diet changes", "Metformin"),
            ("hypothyroidism with fatigue and cold intolerance", "Levothyroxine"),
            ("asthma with wheezing and exercise-induced bronchospasm", "Albuterol"),
            ("gastroesophageal reflux with heartburn", "Omeprazole"),
            ("hypertension with chronic kidney disease", "Lisinopril"),
            ("heart failure with reduced ejection fraction and fluid overload", "Furosemide"),
            ("atrial fibrillation with high stroke risk", "Warfarin"),
            ("acute chest pain due to myocardial ischemia", "Aspirin"),
            ("neuropathic pain after shingles", "Gabapentin"),
            ("bacterial pneumonia requiring oral therapy", "Azithromycin"),
            ("generalized anxiety with panic episodes", "Sertraline"),
            ("major depression with low energy", "Fluoxetine"),
            ("schizophrenia with auditory hallucinations", "Risperidone"),
            ("epilepsy with focal seizures", "Lamotrigine"),
            ("rheumatoid arthritis with swollen painful joints", "Methotrexate"),
            ("gout with acute podagra", "Colchicine"),
            ("hypertriglyceridemia and pancreatitis risk", "Fenofibrate"),
            ("osteoporosis with vertebral fractures", "Alendronate"),
            ("type 2 diabetes needing rapid mealtime glucose control", "Insulin"),
        ]

        return [
            (
                f"{prefix}A patient suffers from {condition}. What drug should the patient take?",
                drug,
                drug.lower(),
            )
            for condition, drug in cases
        ]


    # Use base runLesioning; provide prompts here
    def generate_lesioning_prompts(self):
        """Return the open-ended drug-knowledge prompts used for lesioning.

        Every prompt has the form "A patient takes <drug>. <question>".
        """
        # (drug, follow-up question); trailing comment names the drug's indication/class
        drug_questions = (
            ("Atorvastatin", "What is this drug used for?"),             # hyperlipidemia
            ("Metformin", "What are the side effects?"),                 # diabetes mellitus
            ("Levothyroxine", "How should it be taken?"),                # hypothyroidism
            ("Albuterol", "When is it prescribed?"),                     # asthma/COPD
            ("Omeprazole", "What conditions does it treat?"),            # GERD / peptic ulcer disease
            ("Warfarin", "What are the important precautions?"),         # anticoagulation
            ("Lisinopril", "What is the mechanism of action?"),          # ACE inhibitor
            ("Prednisone", "What are the common side effects?"),         # corticosteroid
            ("Digoxin", "What conditions does it treat?"),               # heart failure
            ("Furosemide", "How does this medication work?"),            # diuretic
        )
        lesioning_prompts = []
        for drug_name, question in drug_questions:
            lesioning_prompts.append(f"A patient takes {drug_name}. {question}")
        return lesioning_prompts

    def generate_activation_patching_prompts(self):
        """
        Generate prompt pairs for activation patching with single-word drug answers.

        Each pair contrasts two conditions using the same sentence template, and
        the same instruction prefix is prepended to both prompts.

        Returns:
            List of tuples (clean_prompt, clean_answer, corrupt_prompt, corrupt_answer).
            NOTE(review): this is the order the data has always been returned in;
            the previous docstring listed a different order. Confirm the consumer
            of these tuples unpacks them this way.
        """

        print("Running activation patching analysis for top 100 drugs...")

        prefix = "Answer with a single word only describing a drug given in the United States. "

        # Each entry: (clean prompt, clean answer, corrupt prompt, corrupt answer)
        prompt_pairs = [
            # ADHD vs Depression
            ("A patient has ADHD. A common medication is ___", "Adderall",
             "A patient has depression. A common medication is ___", "Sertraline"),

            # Hypertension vs Diabetes
            ("A patient has hypertension. A common medication is ___", "Lisinopril",
             "A patient has diabetes. A common medication is ___", "Metformin"),

            # High cholesterol vs Asthma
            ("A patient has high cholesterol. A common medication is ___", "Atorvastatin",  # or simvastatin
             "A patient has asthma. A common medication is ___", "Albuterol"),

            # # Anxiety vs Headache
            # ("A patient has anxiety. A common medication is ___", "Prozac",
            #  "A patient has headache. A common medication is ___", "Tylenol"),

            # Heart disease vs Heart failure
            # (the answer previously carried a stray ")" and could never match)
            ("A patient has heart disease. A common medication is ___", "Aspirin",
             "A patient has heart failure. A common medication is ___", "Furosemide"),

            # Insomnia vs Circadian rhythm sleep disorder
            ("A patient has insomnia. A common medication is ___", "Zolpidem",
             "A patient has circadian rhythm sleep disorder. A common medication is ___", "Melatonin"),

            # Arthritis vs Migraine
            ("A patient has arthritis. A common medication is ___", "Celebrex",
             "A patient has migraine. A common medication is ___", "Imitrex"),

            # Infection vs Allergies
            ("A patient has infection. A common medication is ___", "Amoxicillin",
             "A patient has allergies. A common medication is ___", "Loratadine"),  # Amoxicillin is fine here

            # Epilepsy vs Schizophrenia
            ("A patient has epilepsy. A common medication is ___", "Levetiracetam",
             "A patient has schizophrenia. A common medication is ___", "Risperidone"),

            # # High blood pressure vs Heart failure
            # ("A patient has high blood pressure. A common medication is ___",
            #  "Amlodipine",
            #  "A patient has heart failure. A common medication is ___",
            #  "Digoxin")
        ]

        # Prepend the instruction prefix to both prompts; answers are left untouched.
        prompt_pairs = [
            (prefix + clean_prompt, clean_answer, prefix + corrupt_prompt, corrupt_answer)
            for clean_prompt, clean_answer, corrupt_prompt, corrupt_answer in prompt_pairs
        ]

        print("Generated activation patching prompts:")
        # Unpack in the tuple's real order so each prompt is printed with its own answer.
        for i, (clean, clean_ans, corrupt, corrupt_ans) in enumerate(prompt_pairs, start=1):
            print(f"  {i}. Clean: '{clean}' → '{clean_ans}'")
            print(f"     Corrupt: '{corrupt}' → '{corrupt_ans}'")

        return prompt_pairs

