"""
Same analysis as drugAct.py, but for the ~100 most commonly prescribed drugs in the US, using drug-only prompts (no diseases).
"""

import math, random, collections, torch, umap, numpy as np
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap
from transformers import AutoModelForCausalLM, AutoTokenizer
from functools import partial
import matplotlib.colors as mcolors
from scipy import stats
import json
from saliency import *
import os
import pickle, collections
from utils import runUMAPClustering

from act import Act

class Top100DrugAct(Act):
    """Activation analyses over the most commonly prescribed US drugs (drug-only prompts).

    Provides:
      * ``runUmap``      -- per-layer UMAP of prompt activations, coloured by drug class,
                            plus a clustering score per layer;
      * ``runSaliency``  -- saliency maps for a handful of drugs and their average;
      * prompt sets for the lesioning / activation-patching routines, which are
        presumably implemented on ``Act`` (not visible here).
    """

    # Top 100 most administered drugs in the US in 2022, mapped to a coarse
    # therapeutic class used for colouring and clustering. Each drug appears
    # once (repeated dict keys would silently collapse), so this holds 94 drugs.
    _DRUG_TO_MECH = {
        "Atorvastatin":        "Cardio.",
        "Metformin":           "Antidiabetic",
        "Lisinopril":          "Cardio.",
        "Levothyroxine":       "Endocrine",
        "Amlodipine":          "Cardio.",
        "Metoprolol":          "Cardio.",
        "Albuterol":           "Respiratory",
        "Losartan":            "Cardio.",
        "Omeprazole":          "Gastro.",
        "Gabapentin":          "CNS",
        "Sertraline":          "CNS",
        "Hydrochlorothiazide": "Cardio.",
        "Rosuvastatin":        "Cardio.",
        "Dextroamphetamine":   "CNS",
        "Escitalopram":        "CNS",
        "Pantoprazole":        "Gastro.",
        "Montelukast":         "Respiratory",
        "Trazodone":           "CNS",
        "Simvastatin":         "Cardio.",
        "Tamsulosin":          "Cardio.",
        "Bupropion":           "CNS",
        "Fluoxetine":          "CNS",
        "Acetaminophen":       "CNS",
        "Furosemide":          "Cardio.",
        "Fluticasone":         "Respiratory",
        "Amoxicillin":         "Antibacterial",
        "Apixaban":            "Cardio.",
        "Insulin Glargine":    "Antidiabetic",
        "Meloxicam":           "Anti-inflam.",
        "Prednisone":          "Anti-inflam.",
        "Duloxetine":          "CNS",
        "Methylphenidate":     "CNS",
        "Ibuprofen":           "Anti-inflam.",
        "Carvedilol":          "Cardio.",
        "Potassium Chloride":  "Miscellaneous",
        "Aspirin":             "Cardio.",
        "Pravastatin":         "Cardio.",
        "Ergocalciferol":      "Miscellaneous",
        "Allopurinol":         "Miscellaneous",
        "Citalopram":          "CNS",
        "Alprazolam":          "CNS",
        "Glipizide":           "Antidiabetic",
        "Cetirizine":          "Respiratory",
        "Venlafaxine":         "CNS",
        "Cyclobenzaprine":     "CNS",
        "Hydroxyzine":         "CNS",
        "Clopidogrel":         "Cardio.",
        "Semaglutide":         "Antidiabetic",
        "Famotidine":          "Gastro.",
        "Estradiol":           "Endocrine",
        "Diclofenac":          "Anti-inflam.",
        "Spironolactone":      "Cardio.",
        "Buspirone":           "CNS",
        "Tramadol":            "CNS",
        "Empagliflozin":       "Antidiabetic",
        "Clonazepam":          "CNS",
        "Lamotrigine":         "CNS",
        "Oxycodone":           "CNS",
        "Ondansetron":         "Gastro.",
        "Cholecalciferol":     "Miscellaneous",
        "Atenolol":            "Cardio.",
        "Glimepiride":         "Antidiabetic",
        "Folic Acid":          "Miscellaneous",
        "Zolpidem":            "CNS",
        "Latanoprost":         "Miscellaneous",
        "Doxycycline":         "Antibacterial",
        "Lisdexamfetamine":    "CNS",
        "Insulin Lispro":      "Antidiabetic",
        "Clonidine":           "Cardio.",
        "Loratadine":          "Respiratory",
        "Finasteride":         "Miscellaneous",
        "Dulaglutide":         "Antidiabetic",
        "Insulin Aspart":      "Antidiabetic",
        "Propranolol":         "Cardio.",
        "Azithromycin":        "Antibacterial",
        "Ezetimibe":           "Cardio.",
        "Ethinyl Estradiol":   "Endocrine",
        "Lorazepam":           "CNS",
        "Quetiapine":          "CNS",
        "Budesonide":          "Respiratory",
        "Topiramate":          "CNS",
        "Warfarin":            "Cardio.",
        "Sitagliptin":         "Antidiabetic",
        "Amitriptyline":       "CNS",
        "Fenofibrate":         "Cardio.",
        "Naproxen":            "Anti-inflam.",
        "Rivaroxaban":         "Cardio.",
        "Pregabalin":          "CNS",
        "Paroxetine":          "CNS",
        "Celecoxib":           "Anti-inflam.",
        "Tizanidine":          "CNS",
        "Sumatriptan":         "CNS",
        "Olmesartan":          "Cardio.",
        "Diltiazem":           "Cardio.",
    }

    # Matplotlib markers cycled over drugs so same-coloured drugs stay distinguishable.
    _MARKERS = ["o", "s", "^", "p", "D", "P", "*", "X", ">", "h", "d"]

    # Number of prompts generated per drug for the UMAP analysis.
    _NR_PATIENTS_PER_DRUG = 1

    def __init__(self, tok, model, processor, MODEL_NAME, PRINT_REPLIES, TEMPERATURE, USE_ACCELERATE=False):
        super().__init__(tok, model, processor, MODEL_NAME, PRINT_REPLIES, TEMPERATURE, USE_ACCELERATE=USE_ACCELERATE)
        self.analysis_name = "top100-drugs"
        self.set_files()

    @staticmethod
    def _scatter_layer(ax, emb, mapping, catcol, idxs_by_drug, drug_markers):
        """Scatter one layer's 2-D embedding: one call per drug, coloured by the drug's category."""
        for drug, idxs in idxs_by_drug.items():
            ax.scatter(emb[idxs, 0], emb[idxs, 1],
                       c=[catcol[mapping[drug]]] * len(idxs),
                       marker=drug_markers[drug],
                       s=400, edgecolors='k', linewidths=0.2)

    def plot_umap(self, mapping, title, out_path, ALL_DRUGS, embs, idxs_by_drug, drug_markers,
                  layer_list=None, nr_cols=5, fontsize=24, plot_title=True):
        """Plot per-layer 2-D UMAP embeddings coloured by category and save the figure.

        Args:
            mapping: dict drug -> category label; selects each point's colour.
            title: inserted into the suptitle ("... colored by {title}").
            out_path: path passed to ``fig.savefig``.
            ALL_DRUGS: ordered drug names; fixes category order and legend order.
            embs: indexable by layer; ``embs[layer]`` is indexed as ``[idxs, 0]`` and
                ``[idxs, 1]`` (assumed an array of shape (n_prompts, >=2) -- confirm).
            idxs_by_drug: dict drug -> row indices into ``embs[layer]``.
            drug_markers: dict drug -> matplotlib marker.
            layer_list: layers to plot; ``None`` plots ``range(self.num_layers)``.
                Exactly 4 explicit layers selects a compact 1x4 layout with a
                per-category text legend underneath.
            nr_cols: number of columns in the regular grid layout.
            fontsize: base font size; titles and labels are offsets of it.
            plot_title: whether to add a figure suptitle.
        """
        categories = list(dict.fromkeys(mapping[d] for d in ALL_DRUGS))
        cmap = get_cmap("tab20")
        catcol = {cat: cmap(i % cmap.N) for i, cat in enumerate(categories)}
        # Drugs grouped by category, preserving ALL_DRUGS order within each group.
        drugs_by_cat = {cat: [d for d in ALL_DRUGS if mapping[d] == cat] for cat in categories}

        layers_to_plot = list(range(self.num_layers)) if layer_list is None else layer_list
        n_layers_to_plot = len(layers_to_plot)

        if layer_list is not None and n_layers_to_plot == 4:
            # Compact layout: 1x4 plots in the top half, categorised legend in the bottom half.
            fig = plt.figure(figsize=(25, 12))
            gs = fig.add_gridspec(2, 1, height_ratios=[1, 1], hspace=0.23)
            plot_gs = gs[0].subgridspec(1, 4, wspace=0.1)
            legend_gs = gs[1]

            for plot_idx, layer in enumerate(layers_to_plot):
                ax = fig.add_subplot(plot_gs[0, plot_idx])
                self._scatter_layer(ax, embs[layer], mapping, catcol, idxs_by_drug, drug_markers)
                ax.set_title(f"Layer {layer}", fontsize=fontsize-3)
                ax.set_xlabel("UMAP-1", fontsize=fontsize-4)
                if plot_idx == 0:
                    # Only the leftmost plot carries the y-axis label.
                    ax.set_ylabel("UMAP-2", fontsize=fontsize-4)
                else:
                    ax.set_yticks([])

            # Legend panel: one column per category, category title on top, drugs below.
            col_width = 1.2
            legend_cols = len(categories)
            max_drugs_per_cat = max((len(ds) for ds in drugs_by_cat.values()), default=0)
            legend_rows = max_drugs_per_cat + 2  # extra rows for the title and spacing

            legend_ax = fig.add_subplot(legend_gs)
            legend_ax.set_xlim(0, legend_cols * col_width)
            legend_ax.set_ylim(0, legend_rows * 1.8)
            legend_ax.axis('off')

            y_title = legend_rows * 1.8 - 0.5
            for cat_idx, cat in enumerate(categories):
                x_pos = cat_idx * col_width + 0.05
                legend_ax.text(x_pos, y_title, cat, ha='left', va='center',
                               fontsize=fontsize-5, fontweight='bold')

                for drug_idx, drug in enumerate(drugs_by_cat[cat]):
                    # NOTE(review): hand-tuned offsets. With the y-limits set above
                    # (0 .. legend_rows * 1.8), rows where drug_idx * 4 exceeds
                    # legend_rows + 18.5 fall below the lower limit, so very large
                    # categories overflow the legend panel.
                    y_drug = legend_rows + 18.5 - drug_idx * 4
                    legend_ax.scatter(x_pos, y_drug, c=[catcol[cat]], marker=drug_markers[drug],
                                      s=100, edgecolors='k', linewidths=0.5)
                    # Names longer than 11 characters are truncated and marked with a dot.
                    label = drug if len(drug) <= 11 else drug[:11] + "."
                    legend_ax.text(x_pos + 0.07, y_drug, label, ha='left', va='center',
                                   fontsize=fontsize-6)

            if plot_title:
                fig.suptitle(f"{self.MODEL_NAME} - Drugs UMAP colored by {title}",
                             y=0.95, fontsize=fontsize+6)

        else:
            # Regular grid layout for all layers (or any layer_list not of length 4).
            ncols, nrows = nr_cols, math.ceil(n_layers_to_plot / nr_cols)
            fig, axes = plt.subplots(nrows, ncols, figsize=(ncols*8, nrows*6), squeeze=False)

            for plot_idx, layer in enumerate(layers_to_plot):
                ax = axes[plot_idx // ncols, plot_idx % ncols]
                self._scatter_layer(ax, embs[layer], mapping, catcol, idxs_by_drug, drug_markers)
                ax.set_title(f"Layer {layer}", fontsize=fontsize-3)
                ax.set_xticks([])
                ax.set_yticks([])

            # Hide the unused trailing grid cells.
            for i in range(n_layers_to_plot, nrows*ncols):
                axes[i // ncols, i % ncols].axis("off")

            from matplotlib.lines import Line2D
            handles = [
                Line2D([0], [0],
                       marker=drug_markers[d],
                       color='w',
                       markerfacecolor=catcol[cat],
                       markeredgecolor='k',
                       markersize=25,
                       label=f"{d} ({cat})")
                for cat in categories
                for d in drugs_by_cat[cat]
            ]
            fig.legend(handles=handles,
                       loc='center left',
                       bbox_to_anchor=(0.7, 0.5),
                       frameon=False,
                       fontsize=fontsize-1,
                       handletextpad=1.0)

            if plot_title:
                fig.suptitle(f"{self.MODEL_NAME} - Drugs UMAP colored by {title}", y=0.98, fontsize=fontsize+6)
            # Reserve the right 30% of the figure for the legend.
            fig.tight_layout(
                rect=[0, 0, 0.7, 0.975],
                pad=2.5,
                w_pad=2.0,
                h_pad=2.5
            )

        fig.savefig(out_path, dpi=200)
        print(f"Saved to {out_path}")

    def selectLayersForPlots(self):
        """Return up to 4 sorted, evenly spaced layer indices including the first and last.

        With ``self.num_layers <= 4`` every layer is returned. Otherwise the
        range [0, num_layers - 1] is split into three equal intervals and the
        (rounded) interval boundaries are returned.
        """
        n = self.num_layers
        if n <= 4:
            return list(range(n))
        step = (n - 1) / 3
        # For n > 4 the four indices are always distinct; the set only guards the invariant.
        return sorted({0, round(step), round(2 * step), n - 1})

    def runUmap(self, load_cached=True):
        """Embed drug-only prompts per layer with UMAP, plot them, and score clustering.

        One prompt ``"A patient takes {drug}."`` is built per drug (repeated
        ``_NR_PATIENTS_PER_DRUG`` times). Activations are collected via the hooks
        from ``self.register_hooks()`` while ``self.run_model`` runs, then reduced
        with ``self.umap_fit`` to 2-D (plots) and 30-D (clustering). Results are
        pickled to ``self.umapCacheFile`` and reused when ``load_cached`` is True
        and that file exists.

        Saves two PDFs under ``results/`` (all layers, and the layers from
        ``selectLayersForPlots``) and passes the 30-D embeddings, one class label
        per prompt, and ``self.umapAggFile`` to ``runUMAPClustering``.
        """
        drug_to_mech = self._DRUG_TO_MECH
        all_drugs = list(drug_to_mech)
        markers = self._MARKERS

        cache_path = self.umapCacheFile
        if load_cached and os.path.exists(cache_path):
            print(f"Loading cached UMAP data from {cache_path}")
            # NOTE: pickle can execute code on load -- only use locally generated caches.
            with open(cache_path, 'rb') as f:
                cache = pickle.load(f)
            embs = cache.get('embs')
            embsClustering = cache.get('embsClustering')
            prompts = cache.get('prompts')
            drugs_list = cache.get('drugs_list')
            drug_to_mech = cache.get('DRUG_TO_MECH')
            all_drugs = cache.get('ALL_DRUGS')
            markers = cache.get('MARKERS')
        else:
            prompts, drugs_list = [], []
            for drug in all_drugs:
                for _ in range(self._NR_PATIENTS_PER_DRUG):
                    prompts.append(f"A patient takes {drug}.")
                    drugs_list.append(drug)
            print(f"Generated {len(prompts)} prompts.")

            handles, layer_buffers = self.register_hooks()
            try:
                # Presumably the hooks fill layer_buffers during this forward pass;
                # the model replies themselves are not needed here.
                self.run_model(prompts)
            finally:
                # Always detach the hooks, even if the model run fails.
                for h in handles:
                    h.remove()

            embs = self.umap_fit(layer_buffers)
            embsClustering = self.umap_fit(layer_buffers, n_components=30)

            cache = {
                'embs': embs,
                'embsClustering': embsClustering,
                'prompts': prompts,
                'drugs_list': drugs_list,
                'DRUG_TO_MECH': drug_to_mech,
                'ALL_DRUGS': all_drugs,
                'MARKERS': markers,
            }
            with open(cache_path, 'wb') as f:
                pickle.dump(cache, f)
            print(f"Saved UMAP cache → {cache_path}")

        # Row indices of each drug's prompts, and one marker per drug.
        idxs_by_drug = collections.defaultdict(list)
        for idx, d in enumerate(drugs_list):
            idxs_by_drug[d].append(idx)
        drug_markers = {d: markers[i % len(markers)] for i, d in enumerate(all_drugs)}

        # All layers, coloured by mechanism.
        self.plot_umap(drug_to_mech, "mechanism", f"results/{self.analysis_name}_umap_{self.MODEL_NAME}.pdf",
                       all_drugs, embs, idxs_by_drug, drug_markers, nr_cols=5, fontsize=27, plot_title=True)
        plt.show()

        # Selected layers only.
        selected_layers = self.selectLayersForPlots()
        out_file_selected = f"results/{self.analysis_name}_umap_{self.MODEL_NAME}_selected.pdf"
        self.plot_umap(drug_to_mech, "mechanism", out_file_selected, all_drugs, embs, idxs_by_drug,
                       drug_markers, layer_list=selected_layers, nr_cols=4, fontsize=23, plot_title=False)
        plt.show()

        # Clustering score per layer. Labels must align with the rows of
        # embsClustering, i.e. one label per prompt (drugs_list), not per unique drug.
        mech_labels = np.array([drug_to_mech[d] for d in drugs_list])
        runUMAPClustering(embsClustering, mech_labels, self.umapAggFile)

    def runSaliency(self, load_cached=True):
        """Compute, plot and summarise saliency maps for six common drugs.

        Each drug is inserted into the same question template and
        ``self.saliency_for`` is called with the drug name as the desired token
        and ``allowed_ids=None`` (unconstrained output). Per-drug PDF plots and
        text summaries are written to ``results/``. The per-drug results are
        merged with ``merge_saliency_dicts``, and the averages are plotted,
        summarised, and shown as a head/MLP table.

        The merged results are pickled to ``self.saliencyFile`` and reloaded,
        skipping the per-drug computation, when ``load_cached`` is True and
        that file exists.
        """
        drugs = ["Atorvastatin", "Metformin", "Levothyroxine", "Albuterol", "Omeprazole", "Gabapentin"]
        prompt_template = "A 30-year old patient takes {}. What drug does the patient take?"

        if load_cached and os.path.exists(self.saliencyFile):
            # NOTE: pickle can execute code on load -- only use locally generated caches.
            with open(self.saliencyFile, "rb") as f:
                data = pickle.load(f)
            print(data.keys())
            avg_sal = data["avg_sal"]
            avg_sal_per_prompt = data["avg_sal_per_prompt"]
            avg_sal_per_head = data["avg_sal_per_head"]
            avg_sal_per_mlp = data["avg_sal_per_mlp"]
            print(f"Loaded cached saliency data from {self.saliencyFile}")
        else:
            all_saliency_dicts = {}
            for drug in drugs:
                prompt = prompt_template.format(drug)
                print(f"\nRunning saliency for: {drug}")

                sal_dict = self.saliency_for(prompt, desired_token=drug, allowed_ids=None)
                all_saliency_dicts[drug] = sal_dict
                lines = print_top_saliencies(sal_dict, title=drug)

                stem = f"results/{self.analysis_name}_saliency_{self.MODEL_NAME}_{drug.replace(' ', '_')}"
                plot_path = f"{stem}.pdf"
                self.plot_saliency_grid(
                    self.model_layers,
                    sal_dict,
                    plot_path,
                    f"{self.MODEL_NAME} saliency for drug: {drug}"
                )
                print(f"Saved saliency plot → {plot_path}")

                txt_path = f"{stem}.txt"
                with open(txt_path, "w") as f:
                    f.write("\n".join(lines))
                print(f"Saved saliency text → {txt_path}")

            avg_sal, avg_sal_per_prompt, avg_sal_per_head, avg_sal_per_mlp = merge_saliency_dicts(all_saliency_dicts, self.model)

            # Persist in exactly the layout the cache branch above reads. Serialise
            # first, so an unpicklable result neither aborts the run nor leaves a
            # truncated cache file behind.
            payload = {
                "avg_sal": avg_sal,
                "avg_sal_per_prompt": avg_sal_per_prompt,
                "avg_sal_per_head": avg_sal_per_head,
                "avg_sal_per_mlp": avg_sal_per_mlp,
            }
            try:
                blob = pickle.dumps(payload)
            except (pickle.PicklingError, TypeError, AttributeError) as err:
                print(f"Could not pickle saliency results; cache not written: {err}")
            else:
                with open(self.saliencyFile, "wb") as f:
                    f.write(blob)
                print(f"Saved saliency cache → {self.saliencyFile}")

        # Average saliency: summary text, grid plot, and head/MLP table.
        lines_avg = print_top_saliencies(avg_sal, title="Average (Drugs)")

        avg_plot_path = f"results/{self.analysis_name}_saliency_{self.MODEL_NAME}_avg.pdf"
        self.plot_saliency_grid(
            self.model_layers,
            avg_sal,
            avg_plot_path,
            f"{self.MODEL_NAME} average saliency across selected drugs"
        )
        print(f"Saved average saliency plot → {avg_plot_path}")

        avg_txt_path = f"results/{self.analysis_name}_saliency_{self.MODEL_NAME}.txt"
        with open(avg_txt_path, "w") as f:
            f.write("\n".join(lines_avg))
        print(f"Saved average saliency text → {avg_txt_path}")

        table_path = f"results/{self.analysis_name}_saliency_table_{self.MODEL_NAME}_avg.pdf"
        self.plot_saliency_2D(
            avg_sal_per_head, avg_sal_per_mlp, table_path,
            f"{self.MODEL_NAME} saliency table for selected drugs"
        )
        print(f"Plotted saliency table → {table_path}")

    def generate_lesioning_prompts(self):
        """Return the prompts used by the lesioning analysis (base ``runLesioning`` presumably consumes them)."""
        return [
            "A patient takes Atorvastatin. What is this drug used for?",
            "A patient takes Metformin. What are the side effects?",
            "A patient takes Levothyroxine. How should it be taken?",
            "A patient takes Albuterol. When is it prescribed?",
            "A patient takes Omeprazole. What conditions does it treat?",
        ]

    def generate_activation_patching_prompts(self):
        """Return condition-swap prompt pairs for activation patching.

        Each pair differs only in the stated condition; the expected answers
        are single-word drug names (generic or brand).

        Returns:
            List of tuples ``(clean_prompt, corrupt_prompt, clean_answer, corrupt_answer)``.
        """
        print("Running activation patching analysis for top 100 drugs...")

        prompt_pairs = [
            # Epilepsy vs Depression
            ("The patient has epilepsy. A common medication is ___",
             "The patient has depression. A common medication is ___",
             "Valproate", "Sertraline"),

            # Hypertension vs Diabetes
            ("The patient has hypertension. A common medication is ___",
             "The patient has diabetes. A common medication is ___",
             "Lisinopril", "Metformin"),

            # High cholesterol vs Asthma
            ("The patient has high cholesterol. A common medication is ___",
             "The patient has asthma. A common medication is ___",
             "Atorvastatin", "Albuterol"),

            # Anxiety vs Headache
            ("The patient has anxiety. A common medication is ___",
             "The patient has headache. A common medication is ___",
             "Prozac", "Tylenol"),

            # Heart disease vs Pain
            ("The patient has heart disease. A common medication is ___",
             "The patient has chronic pain. A common medication is ___",
             "Metoprolol", "Morphine"),

            # Insomnia vs Nausea
            ("The patient has insomnia. A common medication is ___",
             "The patient has nausea. A common medication is ___",
             "Ambien", "Zofran"),

            # Arthritis vs Migraine
            ("The patient has arthritis. A common medication is ___",
             "The patient has migraine. A common medication is ___",
             "Celebrex", "Imitrex"),

            # Infection vs Allergies
            ("The patient has infection. A common medication is ___",
             "The patient has allergies. A common medication is ___",
             "Amoxicillin", "Claritin"),

            # Seizures vs Anxiety
            ("The patient has seizures. A common medication is ___",
             "The patient has anxiety. A common medication is ___",
             "Keppra", "Xanax"),
        ]

        print("Generated activation patching prompts:")
        for i, (clean, corrupt, clean_ans, corrupt_ans) in enumerate(prompt_pairs, start=1):
            print(f"  {i}. Clean: '{clean}' → '{clean_ans}'")
            print(f"     Corrupt: '{corrupt}' → '{corrupt_ans}'")

        return prompt_pairs

