"""
UMAP of per-layer activations for different diseases

Many clinical-context prompts spanning 29 specialties x 5 common diseases x N synthetic
patients/disease (gender & age varied). Each prompt simply states the patient's age,
sex and diagnosis, then asks the model to name the disease.

Visualisation: a scatter plot of UMAP embeddings for each layer, coloured by specialty.
"""

# ── Imports ──────────────────────────────────────────────────────────────
import math, random, collections, torch, umap, numpy as np
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap
from transformers import AutoModelForCausalLM, AutoTokenizer
from functools import partial
from act import Act
from saliency import *
import json
import os
import pickle
from utils import runUMAPClustering

class DiseaseAct(Act):
    def __init__(self, tok, model, processor, MODEL_NAME, PRINT_REPLIES, TEMPERATURE, USE_ACCELERATE=False):
        """Set up the base ``Act`` state and register this analysis as "disease".

        Every argument is forwarded unchanged to ``Act.__init__``; see that
        class for their meaning.
        """
        super().__init__(
            tok,
            model,
            processor,
            MODEL_NAME,
            PRINT_REPLIES,
            TEMPERATURE,
            USE_ACCELERATE=USE_ACCELERATE,
        )
        # The name is assigned before set_files() runs; presumably set_files()
        # derives its file paths from it — confirm in Act.
        self.analysis_name = "disease"
        self.set_files()

    # def selectLayersForPlots(self, nrPlots):
    #     """Select 4 evenly-spaced layers for plotting, including first and last layers."""
        
    #     # Always include first and last layers
    #     selected = [0, self.num_layers - 1]
        
    #     # Add 2 intermediate layers
    #     if self.num_layers > 2:
    #         step = (self.num_layers - 1) / 3  # Divide into 3 intervals
    #         for i in range(1, 3):
    #             layer_idx = int(round(i * step))
    #             if layer_idx not in selected:
    #                 selected.append(layer_idx)
        
    #     return sorted(selected)

    def plot_umap(self, embs, diseases_list, specialties_list, SPECIALTIES, MARKERS, ALL_DISEASES, DISEASE_TO_SPEC, out_file, layer_list=None, nr_cols=5, fontsize=20, plot_title=True, drug_row_height=2, drug_row_offset=9):
        """Draw one UMAP scatter panel per layer plus a specialty legend, then save it.

        Points are coloured by specialty and shaped by disease. The figure is
        written to ``out_file``, shown with ``plt.show()`` and then closed.

        Args:
            embs: Per-layer embeddings, indexable by layer index. Each entry is
                indexed as ``emb[idxs, 0]`` / ``emb[idxs, 1]``, i.e. one row per
                prompt and at least two columns.
            diseases_list: Disease name of every prompt, in embedding row order.
            specialties_list: Unused; kept for interface compatibility.
            SPECIALTIES: Ordered mapping ``specialty -> list of disease names``;
                defines the colours and the legend columns.
            MARKERS: Matplotlib marker codes, reused cyclically across diseases.
            ALL_DISEASES: Disease names; their order determines marker assignment.
            DISEASE_TO_SPEC: Mapping ``disease -> specialty`` used to colour points.
            out_file: Path the figure is saved to (300 dpi).
            layer_list: Layer indices to plot; defaults to every entry of ``embs``.
            nr_cols: Number of grid columns when more than 8 layers are plotted.
                Up to 4 layers use a single row; 5-8 layers use a 2 x 4 grid.
            fontsize: Base font size; labels and legend use small offsets from it.
            plot_title: Whether to add a figure-level title.
            drug_row_height: Vertical distance between consecutive legend entries.
            drug_row_offset: Vertical offset of the first legend entry.

        Raises:
            ValueError: If there is no layer to plot.
        """
        print("Plotting …")

        # Use layer_list if provided, otherwise use all layers
        if layer_list is None:
            layer_list = list(range(len(embs)))

        N = len(layer_list)
        if N == 0:
            # A 1 x 0 grid would otherwise fail inside matplotlib with an
            # unrelated-looking gridspec error.
            raise ValueError("plot_umap: no layers to plot (layer_list is empty)")

        # Assign one colour per specialty. plt.get_cmap is used because
        # matplotlib.cm.get_cmap was removed in matplotlib 3.9. "tab20" has 20
        # entries, so colours repeat beyond 20 specialties.
        cmap_spec = plt.get_cmap("tab20")
        spec_colors = {
            spec: cmap_spec(i % cmap_spec.N) for i, spec in enumerate(SPECIALTIES)
        }

        # Dynamic grid rules for layers:
        # - if N ≤ 4      → 1 x N
        # - if 5 ≤ N ≤ 8  → 2 x 4
        # - if N > 8      → nr_cols columns, as many rows as needed
        if N <= 4:
            n_rows, n_cols = 1, N
            hspaceGridSpec = 0.25
            hspaceSubplots = 0.0
        elif N <= 8:
            n_rows, n_cols = 2, 4
            hspaceGridSpec = 0.25
            hspaceSubplots = 0.0
        else:
            n_cols = max(1, int(nr_cols))
            n_rows = math.ceil(N / n_cols)
            hspaceGridSpec = 0.07
            hspaceSubplots = 0.22

        # Figure and gridspec: top area for plots, bottom for legend
        fig_height = 8 + 4.5 * n_rows
        fig = plt.figure(figsize=(25, fig_height))
        gs = fig.add_gridspec(2, 1, height_ratios=[n_rows, 1], hspace=hspaceGridSpec)
        plot_gs = gs[0].subgridspec(n_rows, n_cols, wspace=0.18, hspace=hspaceSubplots)
        legend_gs = gs[1]
        axes = [[fig.add_subplot(plot_gs[r, c]) for c in range(n_cols)] for r in range(n_rows)]

        # Row indices of the prompts belonging to each disease
        indices_by_disease = collections.defaultdict(list)
        for idx, dis in enumerate(diseases_list):
            indices_by_disease[dis].append(idx)

        # Map every disease to one of the markers (cycling when we run out,
        # so duplicates are inevitable with many diseases)
        disease_markers = {
            d: MARKERS[i % len(MARKERS)]
            for i, d in enumerate(ALL_DISEASES)
        }

        for plot_idx, l in enumerate(layer_list):
            ax = axes[plot_idx // n_cols][plot_idx % n_cols]
            emb = embs[l]

            for dis, idxs in indices_by_disease.items():
                colour = spec_colors[DISEASE_TO_SPEC[dis]]
                marker = disease_markers[dis]

                ax.scatter(emb[idxs, 0], emb[idxs, 1],
                           c=[colour] * len(idxs),
                           marker=marker,
                           s=400, edgecolors='k', linewidths=0.9)

            # Connect the points of each disease (in arbitrary order) so that
            # patients sharing a diagnosis are visually grouped
            for dis, idxs in indices_by_disease.items():
                pts = emb[idxs]
                ax.plot(pts[:, 0], pts[:, 1], lw=0.8, color='grey', alpha=0.4)

            ax.set_title(f"Layer {l}", fontsize=fontsize)

            # Axis labels only on the bottom row / leftmost column
            if plot_idx // n_cols == n_rows - 1:
                ax.set_xlabel("UMAP-1", fontsize=fontsize-2)
            if plot_idx % n_cols == 0:
                ax.set_ylabel("UMAP-2", fontsize=fontsize-2)

            ax.tick_params(axis='both', which='major', labelsize=fontsize-2)
            ax.grid(True, ls="--", alpha=0.12)

        # Hide extra panels (if grid has more slots than layers)
        for i in range(N, n_rows * n_cols):
            axes[i // n_cols][i % n_cols].axis("off")

        # ─ Legend ─
        # Bottom area: one column per specialty with a bold title, diseases
        # listed underneath with their markers.
        # NOTE: a disease listed under several specialties is drawn here in each
        # column's colour, while the scatter uses DISEASE_TO_SPEC[disease].
        categories = list(SPECIALTIES.keys())
        legend_ax = fig.add_subplot(legend_gs)
        max_items_per_cat = max(len(SPECIALTIES[c]) for c in categories)
        legend_rows = max_items_per_cat + 2
        legend_ax.set_xlim(0, len(categories) * 1.2)
        legend_ax.set_ylim(0, legend_rows * 1.8)
        legend_ax.axis('off')

        col_width = 1.2
        for cat_idx, cat in enumerate(categories):
            x_pos = cat_idx * col_width + 0.08
            y_pos = legend_rows * 1.8 + 0.5
            legend_ax.text(x_pos, y_pos, cat, ha='left', va='center', fontsize=fontsize-3, fontweight='bold')

            for item_idx, dis in enumerate(SPECIALTIES[cat]):
                if item_idx >= 15:  # at most 15 entries per legend column
                    break
                y_item = legend_rows + drug_row_offset - item_idx * drug_row_height
                col = spec_colors[cat]
                marker = disease_markers[dis]
                legend_ax.scatter(x_pos, y_item + 0.03, c=[col], marker=marker, s=400, edgecolors='k', linewidths=0.9)
                label = dis
                if len(label) > 13:  # truncate long names so columns do not overlap
                    label = label[:13] + "."
                legend_ax.text(x_pos + 0.1, y_item, label, ha='left', va='center', fontsize=fontsize-5)

        if plot_title:
            fig.suptitle(f"{self.MODEL_NAME} - UMAP representations for diseases (selected layers)", y=0.95, fontsize=fontsize+5)
        fig.subplots_adjust(left=0.05, right=0.95, top=0.95, bottom=0.01)

        fig.savefig(out_file, dpi=300)
        print(f"UMAP grid saved → {out_file}")
        plt.show()
        # Release the figure; otherwise every call keeps a large figure alive
        # in pyplot's global registry.
        plt.close(fig)

            


    def runUmap(self, load_cached=True):
        """Build disease prompts, embed per-layer activations with UMAP, plot and score.

        Steps:
            1. Define the specialties -> diseases table and derive a unique
               ``disease -> specialty`` mapping (first occurrence wins).
            2. Either load embeddings and metadata from ``self.umapCacheFile``
               or generate the prompts, run the model with hooks registered,
               fit the UMAP embeddings and write the cache.
            3. Plot a 4-layer selection and a larger selection to ``results/``.
            4. Pass the higher-dimensional embeddings and the per-prompt
               specialty labels to ``runUMAPClustering``.

        Args:
            load_cached: When true and the cache file exists, reuse it instead
                of running the model again. Note that the cached SPECIALTIES,
                MARKERS and mappings then replace the ones defined below.
        """
        # ── Markers ──────────────────────────────────────────────────────────
        # Matplotlib has ~15 easily readable markers.  We’ll cycle through them.
        MARKERS = [
            "o",  # circle
            "s",  # square
            "^",  # up-triangle
            "p",  # pentagon
            "D",  # diamond
            "P",  # plus-filled (a “cross square”)
            # "v",  # down-triangle
            "*",  # star
            "X",  # x-filled
            ">",  # right-triangle
            "h",  # hexagon-1
            "d",  # thin-diamond
            # "1",  # tri-down
            # "2",  # tri-up
            # "<",  # left-triangle
        ]

        # 1. group cancer together
        # 2. take out family medicine, internal medicine, pediatrics, geriatrics, radiology psychiatry, pathology, preventitive medicine
        # 3. spell out the acronyms
        # ── Knowledge base: specialties & diseases ───────────────────────────
        SPECIALTIES = collections.OrderedDict([
            ("Cardiology", [
                "Coronary artery disease", "Heart failure", "Atrial fibrillation", "Systemic hypertension", "Valvular heart disease",
                "Myocardial infarction", "Cardiomyopathy", "Pericarditis", "Arrhythmia", "Aortic aneurysm", "Congenital heart disease"
            ]),
            ("Orthopedics", [
                "Osteoarthritis", "Bone fracture", "Rotator-cuff tear", "Lumbar disk herniation", "ACL injury",
                "Osteoporosis", "Carpal tunnel syndrome", "Scoliosis", "Meniscus tear", "Spinal stenosis", "Tendonitis"
            ]),
            ("Oncology", [
                "Breast cancer", "Lung cancer", "Colorectal cancer", "Prostate cancer", "Pancreatic cancer",
                "Leukemia", "Lymphoma", "Ovarian cancer", "Melanoma", "Gastric cancer", "Bladder cancer"
            ]),
            ("Neurology", [
                "Ischemic stroke", "Epilepsy", "Alzheimer's disease", "Parkinson disease", "Multiple sclerosis",
                "Migraine", "Amyotrophic lateral sclerosis (ALS)", "Huntington's disease", "Peripheral neuropathy", "Meningitis", "Encephalitis" # replace meningitis
            ]),
            ("Obs & Gyn", [
                "Pelvic organ prolapse", "Uterine fibroids", "Polycystic ovary syndrome", "Endometriosis", "Abnormal uterine bleeding",
                "Ectopic pregnancy", "Gestational diabetes", "Preeclampsia", "Ovarian cysts", "Cervical cancer", "Menopause-related disorders" # replace cervical cancer
            ]),
            ("Dermatology", [
                "Acne vulgaris", "Atopic dermatitis", "Psoriasis", "Seborrheic dermatitis", "Rosacea",
                "Melanoma", "Basal cell carcinoma", "Squamous cell carcinoma", "Vitiligo", "Contact dermatitis", "Hidradenitis suppurativa"
            ]),
            ("Gastroent.", [
                "Gastroesophageal reflux disease", "Irritable bowel syndrome (IBS)", "Peptic ulcer", "Inflammatory bowel disease", "Chronic liver disease",
                "Celiac disease", "Hepatitis B", "Hepatitis C", "Gallstones", "Pancreatitis", "Diverticulitis" # replace hepatitis  B and C
            ]),
            ("Pulmonology", [
                "Asthma", "COPD", "Pneumonia", "Pulmonary embolism", "Idiopathic pulmonary fibrosis",
                "Tuberculosis", "Sleep apnea", "Bronchiectasis", "Lung cancer", "Acute respiratory distress syndrome (ARDS)", "Cystic fibrosis"
            ]),
            ("Urology", [
                "Benign prostatic hyperplasia", "Urinary tract infection", "Kidney stone", "Enlarged prostate", "Urinary incontinence",
                "Bladder cancer", "Prostatitis", "Hydronephrosis", "Erectile dysfunction", "Interstitial cystitis", "Renal cell carcinoma" # replace bladder cancer and renal cell carcinoma
            ]),
        ])

        ## Should do this group separately ##
        # ("Ophthalmology",            ["Cataract", "Glaucoma", "Macular degeneration", "Diabetic retinopathy", "Conjunctivitis"]),
        # ("Otolaryngology",          ["Otitis media", "Chronic rhinosinusitis", "Tonsillitis", "Sensorineural hearing loss", "Pharyngitis"]),
        # ("Infectious Diseases",      ["HIV", "Tuberculosis", "Sepsis", "Seasonal influenza", "COVID-19"]),
        # ("Nephrology",               ["Chronic kidney disease", "Acute kidney injury", "Glomerulonephritis", "Hypertensive nephrosclerosis", "Nephrotic syndrome"]),
        # ("Endocrinology",           ["Type 2 diabetes", "Hypothyroidism", "Hyperthyroidism", "Type 1 diabetes", "Cushing's syndrome"]),
        # ("Neurosurgery",             ["Intracranial tumor", "Traumatic brain injury", "Cerebral aneurysm", "Hydrocephalus", "Arteriovenous malformations"]),
        # ("Hematology",               ["Iron-deficiency anemia", "Venous thromboembolism (VTE)", "Acute myeloid leukemia", "Non-Hodgkin lymphoma", "Sickle-cell disease"]),
        # ("Allergy & Immunology",     ["Allergic rhinitis", "Food allergy", "Chronic urticaria", "Anaphylaxis", "Angioedema"]),
        # ("Rheumatology",             ["Rheumatoid arthritis", "Ankylosing spondylitis", "Systemic lupus erythematosus", "Gout", "Bursitis"]),

        ##### do not include these below, they are tricky.
        # ("Geriatrics",               ["Dementia", "Osteoarthritis", "Hypertension", "Osteoporosis", "Urinary incontinence"]),
        # ("Psychiatry",               ["Major depression", "Generalized anxiety", "Bipolar disorder", "Schizophrenia", "Alcohol use disorder"]),
        # ("Anesthesiology",           ["Post-op pain", "Chronic back pain", "Peri-op OSA", "Malignant hyperthermia", "Drug anaphylaxis"]),
        # ("Radiology",                ["Community-acquired pneumonia", "Bone fracture", "Pulmonary embolism", "Intracranial hemorrhage", "Breast cancer"]),
        # ("Internal Medicine",        ["Hypertension", "Type 2 diabetes", "Hyperlipidemia", "Coronary artery disease", "Chronic kidney disease"]),
        # ("Family Medicine",          ["Common cold", "Hypertension", "Mechanical low-back pain", "Type 2 diabetes", "Obesity"]),
        # ("Pediatrics",               ["Acute otitis media", "Viral URI", "Asthma", "Viral gastroenteritis", "Atopic dermatitis"]),
        # ("Emergency Medicine",       ["Acute coronary syndrome", "Appendicitis", "Asthma/COPD exacerbation", "Traumatic fracture", "Sepsis"]),
        # ("Pathology",                ["Breast carcinoma", "Lung adenocarcinoma", "Colorectal adenocarcinoma", "Prostate adenocarcinoma", "DLBCL"]),
        # ("Preventive Medicine",      ["Hypertension", "Hyperlipidemia", "Prediabetes", "Obesity", "ASCVD prevention"]),
        # ("General Surgery",          ["Appendicitis", "Cholecystitis", "Inguinal hernia", "Colorectal cancer", "Breast cancer"]),

        # A disease listed under several specialties is attributed to the first
        # one; later occurrences are only reported.
        DISEASE_TO_SPEC = {}
        duplicates = collections.defaultdict(list)
        for spec, dis in SPECIALTIES.items():
            for d in dis:
                if d not in DISEASE_TO_SPEC:          # keep first occurrence
                    DISEASE_TO_SPEC[d] = spec
                else:                                 # remember duplicates
                    duplicates[d].append(spec)

        if duplicates:
            print("Duplicate diseases:", {k: [DISEASE_TO_SPEC[k], *v]
                                            for k, v in duplicates.items()})

        ALL_DISEASES = list(DISEASE_TO_SPEC)

        # ── Prompt generation settings ───────────────────────────────────────
        # Age/sex variation is currently disabled: every prompt states only the
        # diagnosis, so more than one patient per disease yields identical prompts.
        NR_PATIENTS_PER_DISEASE = 1

        # ── Embeddings: load from cache or compute ───────────────────────────
        cache_path = self.umapCacheFile
        if load_cached and os.path.exists(cache_path):
            # NOTE(security): pickle.load can execute arbitrary code; only load
            # cache files produced by this pipeline.
            with open(cache_path, 'rb') as f:
                cache = pickle.load(f)
            embs = cache.get('embs')
            embsClustering = cache.get('embsClustering')
            diseases_list = cache.get('diseases_list')
            specialties_list = cache.get('specialties_list')
            DISEASE_TO_SPEC = cache.get('DISEASE_TO_SPEC')
            ALL_DISEASES = cache.get('ALL_DISEASES')
            SPECIALTIES = cache.get('SPECIALTIES')
            MARKERS = cache.get('MARKERS')
        else:
            prompts, diseases_list, specialties_list = [], [], []
            for disease in ALL_DISEASES:
                for _ in range(NR_PATIENTS_PER_DISEASE):
                    prompts.append(f"A patient was diagnosed with {disease}. ")
                    diseases_list.append(disease)
                    specialties_list.append(DISEASE_TO_SPEC[disease])

            print(f"Generated {len(prompts)} prompts.")

            # Capture activations while the prompts run. The hooks are removed
            # even if the model run fails, so they cannot leak into later runs.
            handles, layer_buffers = self.register_hooks()
            try:
                self.run_model(prompts)
            finally:
                for h in handles:
                    h.remove()

            # UMAP reductions: default embedding for plotting, 30 components
            # for the clustering analysis
            embs = self.umap_fit(layer_buffers)
            embsClustering = self.umap_fit(layer_buffers, n_components=30)

            # Save cache for reuse
            cache = {
                'embs': embs,
                'embsClustering': embsClustering,
                'prompts': prompts,
                'diseases_list': diseases_list,
                'specialties_list': specialties_list,
                'DISEASE_TO_SPEC': DISEASE_TO_SPEC,
                'ALL_DISEASES': ALL_DISEASES,
                'SPECIALTIES': SPECIALTIES,
                'MARKERS': MARKERS,
            }
            with open(cache_path, 'wb') as f:
                pickle.dump(cache, f)
            print(f"Saved UMAP cache → {cache_path}")

        # The figures below are written into results/; make sure it exists.
        os.makedirs("results", exist_ok=True)

        # ── Plot selected layers ─────────────────────────────────────────────
        selected_layers = self.selectLayersForPlots(nrPlots=4)
        print(f"Selected layers: {selected_layers}")
        out_file_selected = f"results/{self.analysis_name}_umap_{self.MODEL_NAME}_selected.pdf"
        self.plot_umap(embs, diseases_list, specialties_list, SPECIALTIES, MARKERS, ALL_DISEASES, DISEASE_TO_SPEC, out_file_selected, layer_list=selected_layers, nr_cols=4, fontsize=23, plot_title=False)

        # ── Plot the larger layer selection ──────────────────────────────────
        layer_list = self.selectLayersForPlots(nrPlots=40)
        out_file = f"results/{self.analysis_name}_umap_{self.MODEL_NAME}.pdf"
        self.plot_umap(embs, diseases_list, specialties_list, SPECIALTIES, MARKERS, ALL_DISEASES, DISEASE_TO_SPEC, out_file, layer_list=layer_list, nr_cols=5, fontsize=23, plot_title=False)

        # ── Clustering quality per layer (see utils.runUMAPClustering) ───────
        # Labels must have one entry per prompt (= per embedding row), so they
        # come from specialties_list rather than from the list of unique
        # diseases; the two only coincide with one patient per disease.
        int_labels = np.array(specialties_list)
        runUMAPClustering(embsClustering, int_labels, self.umapAggFile)

        
    def gen_saliency_prompts(self):
        prefix = "Answer with a single word only describing a disease."
        # answer should be a single-word disease ideally, or a common acronym.
        return [
            (prefix + "A patient has fever, productive cough, and chest crackles on exam. The most likely diagnosis is _", "Pneumonia", "pneumonia"),
            (prefix + "A patient has wheezing, shortness of breath, and nighttime cough. The most likely diagnosis is _", "Asthma", "asthma"),
            (prefix + "A patient has persistent thirst, frequent urination, and unexplained weight loss. The most likely diagnosis is _", "Diabetes", "diabetes"),
            (prefix + "A patient has tremor at rest, bradykinesia, and rigidity. The most likely diagnosis is _", "Parkinson's", "parkinsons"),
            (prefix + "A patient has severe unilateral throbbing headache with photophobia and nausea. The most likely diagnosis is _", "Migraine", "migraine"),
            (prefix + "A patient has jaundice, right upper quadrant pain, and elevated liver enzymes. The most likely diagnosis is _", "Hepatitis", "hepatitis"),
            (prefix + "A patient has burning epigastric pain relieved by eating and black tarry stools. The most likely diagnosis is _", "Ulcer", "ulcer"),
            (prefix + "A patient has heat intolerance, weight loss, and palpitations with a diffuse goiter. The most likely diagnosis is _", "Hyperthyroidism", "hyperthyroidism"),
            (prefix + "A patient has facial pain, purulent nasal discharge, and maxillary tenderness. The most likely diagnosis is _", "Sinusitis", "sinusitis"),
            (prefix + "A patient has swollen, tender cervical nodes, tonsillar exudates, and fever. The most likely diagnosis is _", "Tonsillitis", "tonsillitis"),
            (prefix + "A patient has severe lower right abdominal pain, fever, and rebound tenderness. The most likely diagnosis is _", "Appendicitis", "appendicitis"),
            (prefix + "A patient has burning urination, urinary frequency, and suprapubic discomfort. The most likely diagnosis is _", "Cystitis", "cystitis"),
            (prefix + "A patient has sudden facial droop, arm weakness, and slurred speech. The most likely diagnosis is _", "Stroke", "stroke"),
            (prefix + "A patient has itchy grouped vesicles on the elbows and gluten sensitivity. The most likely diagnosis is _", "Dermatitis", "dermatitis"),
            (prefix + "A patient has severe left flank pain radiating to the groin and hematuria. The most likely diagnosis is _", "Nephrolithiasis", "nephrolithiasis"),
            (prefix + "A patient has episodic palpitations, sweating, and headaches with hypertension. The most likely diagnosis is _", "Pheochromocytoma", "pheochromocytoma"),
            (prefix + "A patient has chronic productive cough, clubbing, and bronchiectasis on CT. The most likely diagnosis is _", "Bronchiectasis", "bronchiectasis"),
            (prefix + "A patient has proximal muscle weakness, heliotrope rash, and elevated CK. The most likely diagnosis is _", "Dermatomyositis", "dermatomyositis"),
            (prefix + "A patient has fatigue, cold intolerance, weight gain, and dry skin. The most likely diagnosis is _", "Hypothyroidism", "hypothyroidism"),
        ]

    # def runSaliency(self, load_cached=True):
    #     """
    #     Computes and plots saliency maps for three diseases using the same clinical prompt.
    #     Also computes the average saliency across them.
    #     """

    #     # Define diseases list outside the if/else block
    #     diseases = ["COPD", "Alzheimer's disease", "COVID-19"]
    #     all_saliency_dicts = {}
        
    #     # check if cache exists:
    #     if load_cached and os.path.exists(self.saliencyFile):
    #         with open(self.saliencyFile, "rb") as f:
    #             data = pickle.load(f)
    #             print(data.keys())
    #             avg_sal = data["avg_sal"]
    #             avg_sal_per_prompt = data["avg_sal_per_prompt"]
    #             avg_sal_per_head = data["avg_sal_per_head"]
    #             avg_sal_per_mlp = data["avg_sal_per_mlp"]
    #         print(f"Loaded cached saliency data from {self.saliencyFile}")
    #     else:
    #         prompt_template = (
    #             "Someone was diagnosed with {}. "
    #             "What do they suffer from?"
    #         )

    #         all_saliency = []
    #         all_saliency_dicts = {}

    #         for disease in diseases:
    #             prompt = prompt_template.format(disease)
    #             print(f"\nRunning saliency for: {disease}")
                
    #             # Run clamped-free saliency (raw logit)
    #             sal_dict = self.saliency_for(prompt, disease, allowed_ids=None)
    #             all_saliency_dicts[disease] = sal_dict
    #             all_saliency.append(sal_dict)

    #             # Print + plot
    #             lines = print_top_saliencies(sal_dict, title=disease)
    #             out_path = f"results/{self.analysis_name}_saliency_{self.MODEL_NAME}_{disease.replace(' ', '_')}.pdf"
    #             self.plot_saliency_grid(
    #                 self.model_layers, sal_dict,
    #                 out_path,
    #                 f"{self.MODEL_NAME} saliency for {disease}"
    #             )
    #             print(f"Saved saliency plot → {out_path}")

    #         # --- Compute average saliency ---
    #         avg_sal, avg_sal_per_prompt, avg_sal_per_head, avg_sal_per_mlp = merge_saliency_dicts(all_saliency_dicts, self.model)

    #         # save everything in .npz
    #         save_saliency(self.saliencyFile, avg_sal, avg_sal_per_prompt, 
    #                  avg_sal_per_head, avg_sal_per_mlp)
    #         print(f"Saved all saliency data → {self.saliencyFile}")

    #     lines_avg = print_top_saliencies(avg_sal, title="Average")

    #     avg_path = f"results/{self.analysis_name}_saliency_{self.MODEL_NAME}_avg.pdf"
    #     self.plot_saliency_grid(
    #         self.model_layers,
    #         avg_sal,
    #         avg_path,
    #         f"{self.MODEL_NAME} saliency averaged across: {', '.join(diseases)}"
    #     )
    #     print(f"Saved average saliency plot → {avg_path}")

    #     # Save textual summary
    #     txt_path = f"results/{self.analysis_name}_saliency_{self.MODEL_NAME}.txt"
    #     with open(txt_path, "w") as f:
    #         # Only write individual disease saliencies if we computed them (not from cache)
    #         if all_saliency_dicts:
    #             for disease in diseases:
    #                 f.write(f"\nTop saliencies for {disease}:\n")
    #                 f.write("\n".join(print_top_saliencies(all_saliency_dicts[disease], title=disease)))
    #                 f.write("\n\n")
    #         f.write("Average saliency:\n")
    #         f.write("\n".join(lines_avg))
    #     print(f"Saved saliency report → {txt_path}")

    #     # Plot saliency table
    #     table_path = f"results/{self.analysis_name}_2Dsaliency_{self.MODEL_NAME}_avg.pdf"
    #     self.plot_saliency_2D(
    #         avg_sal_per_head, avg_sal_per_mlp, table_path,
    #         f"{self.MODEL_NAME} saliency table for diseases"
    #     )
    #     print(f"Plotted saliency table → {table_path}")


    # Use base runLesioning; provide prompts here
    def generate_lesioning_prompts(self):
        return [
            "A 30-year-old female patient presents with chronic cough, shortness of breath, and wheezing. What is the most likely diagnosis?", # asthma
            "A 45-year-old male patient shows memory loss, confusion, and difficulty with daily tasks. What is the most likely diagnosis?", # alzheimer's disease
            "A 25-year-old patient has fever, dry cough, and loss of taste and smell. What is the most likely diagnosis?", # covid-19
            "A 60-year-old female patient reports excessive thirst, frequent urination, and unexplained weight loss. What is the most likely diagnosis?", # diabetes mellitus
            "A 35-year-old male patient has severe headaches, chest pain, and elevated blood pressure. What is the most likely diagnosis?", # hypertension / hypertensive heart disease
            "A 50-year-old female patient has joint pain, morning stiffness, and swelling in her hands. What is the most likely diagnosis?", # rheumatoid arthritis
            "A 28-year-old male patient experiences sudden seizures, loss of consciousness, and muscle convulsions. What is the most likely diagnosis?", # epilepsy
            "A 55-year-old female patient reports blurred vision, eye pain, and seeing halos around lights. What is the most likely diagnosis?", # glaucoma
            "A 40-year-old male patient has abdominal pain, bloody diarrhea, and significant weight loss. What is the most likely diagnosis?", # crohn's disease
            "A 32-year-old female patient shows a butterfly-shaped rash on her face, joint pain, and extreme fatigue. What is the most likely diagnosis?", # lupus
        ]
        
    def generate_activation_patching_prompts(self):
        """
        Generate 10 prompt pairs for activation patching with single-word disease-related answers.
        
        Returns:
            List of tuples (clean_prompt, corrupt_prompt, clean_answer, corrupt_answer)
        """


        prefix = "Answer with a single word only describing a disease. "

        prompt_pairs = [
            # Heart vs Lung
            ("A patient has chest pain and shortness of breath. The most likely diagnosis is ___", "Angina",
             "A patient has chest pain and coughing. The most likely diagnosis is ___", "Pneumonia"),
            
            # Brain vs Heart
            ("A patient has confusion and weakness. The most likely diagnosis is ___", "Stroke",
             "A patient has chest pain and weakness. The most likely diagnosis is ___", "Infarction"),
            
            # Lupus vs Hepatitis
            ("A patient has arthralgia and fatigue. The most likely diagnosis is ___", "Lupus",
             "A patient has jaundice and fatigue. The most likely diagnosis is ___", "Hepatitis"),
            
            # Diabetes vs Thyroid
            ("A patient has frequent urination and thirst. The most likely diagnosis is ___", "Diabetes",
             "A patient has weight gain and fatigue. The most likely diagnosis is ___", "Hypothyroidism"),
            
            # Cancer vs Infection
            ("A patient has weight loss and fatigue. The most likely diagnosis is ___", "Cancer",
             "A patient has fever and fatigue. The most likely diagnosis is ___", "Infection"),
            
            # Arthritis vs Gout
            ("A patient has joint pain and stiffness. The most likely diagnosis is ___", "Arthritis",
            "A patient has toe pain and swelling. The most likely diagnosis is ___", "Gout"),
            
            # Depression vs Anxiety
            ("A patient has sadness and fatigue. The most likely diagnosis is ___", "Depression",
             "A patient has worry and restlessness. The most likely diagnosis is ___", "Anxiety"),
            
            # Asthma vs COPD
            # ("A patient has wheezing and coughing. The most likely diagnosis is ___", "Asthma",
            #  "A patient has shortness of breath and coughing. The most likely diagnosis is ___", "COPD"), # many models said bronchitis or even asthma also
            
            # Migraine vs Sinusitis
            ("A patient has severe headache and nausea. The most likely diagnosis is ___", "Migraine",
             "A patient has dull headache and pressure. The most likely diagnosis is ___", "Sinusitis"),
            
        ]
        # add the prefix to each prompt
        prompt_pairs = [(prefix + prompt[0], prompt[1], prefix + prompt[2], prompt[3]) for prompt in prompt_pairs]
        
        return prompt_pairs
