import math, random, collections, torch, umap, numpy as np
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap
from transformers import AutoModelForCausalLM, AutoTokenizer
from functools import partial
from act import Act
from saliency import *
import json
import pickle
import os
from utils import runUMAPClustering
            

class SymptomAct(Act):
    """Symptom-focused analyses built on top of the shared ``Act`` pipeline.

    This class supplies the symptom vocabulary and the prompts used by the
    UMAP, saliency, lesioning and activation-patching analyses. Model plumbing
    such as ``register_hooks``, ``run_model``, ``umap_fit``, ``set_files`` and
    ``selectLayersForPlots`` is not defined here and is expected to come from
    ``Act``.
    """

    def __init__(self, tok, model, processor, MODEL_NAME, PRINT_REPLIES, TEMPERATURE, USE_ACCELERATE=False):
        super().__init__(tok, model, processor, MODEL_NAME, PRINT_REPLIES, TEMPERATURE, USE_ACCELERATE=USE_ACCELERATE)
        # Used as a prefix for all output/cache file names of this analysis.
        self.analysis_name = "symptom"
        self.set_files()

    def plot_umap(self, embs, symptoms_list, symptom_groups_list, SYMPTOM_GROUPS, MARKERS, ALL_SYMPTOMS, SYMPTOM_TO_GROUP, out_file, layer_list, nr_cols=5, fontsize=20, plot_title=True, drug_row_height=4, drug_row_offset=25.5):
        """Plot 2-D UMAP embeddings of the symptom prompts for several layers.

        The figure has a grid of scatter plots on top, one per entry of
        ``layer_list``. Below it is a legend area with one column per symptom
        group. Points are coloured by group and given a per-symptom marker.

        Args:
            embs: Per-layer 2-D embeddings indexed by layer number. Each
                ``embs[layer]`` must support ``emb[idxs, 0]`` / ``emb[idxs, 1]``
                and have one row per entry of ``symptoms_list``.
            symptoms_list: Symptom name of each embedded row.
            symptom_groups_list: Group name of each embedded row. It is
                currently unused; groups are looked up via ``SYMPTOM_TO_GROUP``.
            SYMPTOM_GROUPS: Ordered mapping from group name to a list of
                symptoms. Its order defines group colours and legend columns.
            MARKERS: Marker styles, cycled over ``ALL_SYMPTOMS``.
            ALL_SYMPTOMS: All symptoms, used to assign one marker per symptom.
            SYMPTOM_TO_GROUP: Mapping from symptom to group, used for colouring.
            out_file: Destination path of the saved figure. Missing parent
                directories are created.
            layer_list: Layer indices to plot, in plot order.
            nr_cols: Number of grid columns when more than 8 layers are
                plotted. With 4 or fewer layers a single row is used; with
                5-8 layers a 2x4 grid is used.
            fontsize: Base font size for titles, labels and ticks.
            plot_title: Whether to add a figure-level title.
            drug_row_height: Vertical spacing between legend entries, in legend
                data units.
            drug_row_offset: Vertical offset of the first legend entry, in
                legend data units.

        Raises:
            ValueError: If ``layer_list`` is empty.
        """
        print("Plotting …")
        if not layer_list:
            raise ValueError("layer_list must contain at least one layer")

        # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap remains.
        cmap_groups = plt.get_cmap("tab10")
        group_colors = {grp: cmap_groups(i % cmap_groups.N) for i, grp in enumerate(SYMPTOM_GROUPS)}

        # Row indices of each symptom in the embedding matrices.
        indices_by_symptom = collections.defaultdict(list)
        for idx, symp in enumerate(symptoms_list):
            indices_by_symptom[symp].append(idx)
        symptom_markers = {sym: MARKERS[i % len(MARKERS)] for i, sym in enumerate(ALL_SYMPTOMS)}

        N = len(layer_list)
        if N <= 4:
            n_rows, n_cols = 1, N
            hspaceGridSpec = 0.25
            hspaceSubplots = 0.0
            largePlot = False
        elif N <= 8:
            n_rows, n_cols = 2, 4
            hspaceGridSpec = 0.25
            hspaceSubplots = 0.0
            largePlot = False
        else:
            n_cols = max(1, nr_cols)
            n_rows = math.ceil(N / n_cols)
            hspaceGridSpec = 0.07
            hspaceSubplots = 0.22
            largePlot = True

        # Figure and gridspec: top area for the per-layer plots, bottom area for the legend.
        fig_height = 8 + 4.5 * n_rows
        fig = plt.figure(figsize=(25, fig_height))
        gs = fig.add_gridspec(2, 1, height_ratios=[n_rows, 1], hspace=hspaceGridSpec)
        plot_gs = gs[0].subgridspec(n_rows, n_cols, wspace=0.18, hspace=hspaceSubplots)
        legend_gs = gs[1]

        # One scatter plot per selected layer.
        for plot_idx, layer in enumerate(layer_list):
            row, col = divmod(plot_idx, n_cols)
            ax = fig.add_subplot(plot_gs[row, col])
            emb = embs[layer]
            for symp, idxs in indices_by_symptom.items():
                colour = group_colors[SYMPTOM_TO_GROUP[symp]]
                ax.scatter(emb[idxs, 0], emb[idxs, 1], c=[colour] * len(idxs), marker=symptom_markers[symp],
                           s=400, edgecolors='k', linewidths=0.9)
            ax.set_title(f"Layer {layer}", fontsize=fontsize)

            if row == n_rows - 1:  # bottom row
                ax.set_xlabel("UMAP-1", fontsize=fontsize-2)
            if col == 0:  # leftmost column
                ax.set_ylabel("UMAP-2", fontsize=fontsize-2)

            ax.tick_params(axis='both', which='major', labelsize=fontsize-2)
            ax.grid(True, ls="--", alpha=0.12)

        # Bottom legend area: one column per symptom group.
        categories = list(SYMPTOM_GROUPS.keys())
        n_categories = len(categories)
        # Show each symptom once, under the group SYMPTOM_TO_GROUP assigns it to.
        # A symptom listed several times in a group, or under several groups,
        # has only one colour in the scatter plots, so the legend must match it.
        legend_syms = {
            cat: [s for s in dict.fromkeys(SYMPTOM_GROUPS[cat]) if SYMPTOM_TO_GROUP.get(s) == cat]
            for cat in categories
        }
        legend_ax = fig.add_subplot(legend_gs)
        max_items_per_cat = max((len(syms) for syms in legend_syms.values()), default=0)
        legend_rows = max_items_per_cat + 2
        col_width = 1.2
        legend_ax.set_xlim(0, n_categories * col_width)
        legend_ax.set_ylim(0, legend_rows * 1.8)
        legend_ax.axis('off')

        for cat_idx, cat in enumerate(categories):
            x_pos = cat_idx * col_width + 0.07
            y_pos = legend_rows * 1.8
            # Bold category title at the top of the column.
            legend_ax.text(x_pos, y_pos, cat, ha='left', va='center', fontsize=fontsize-(0 if largePlot else 3), fontweight='bold')
            # At most 15 symptom entries per column.
            for s_idx, sym in enumerate(legend_syms[cat][:15]):
                y_item = legend_rows + drug_row_offset - s_idx * drug_row_height
                legend_ax.scatter(x_pos + 0.03, y_item, c=[group_colors[cat]], marker=symptom_markers[sym], s=400, edgecolors='k', linewidths=0.9)
                # Truncate labels longer than 11 characters to 11 characters plus a trailing dot.
                label = sym if len(sym) <= 11 else sym[:11] + "."
                legend_ax.text(x_pos + 0.13, y_item, label, ha='left', va='center', fontsize=fontsize-(2 if largePlot else 5))

        if plot_title:
            fig.suptitle(f"{self.MODEL_NAME} - UMAP representations for symptoms (selected layers)", y=0.95, fontsize=fontsize+5)
        fig.subplots_adjust(left=0.05, right=0.95, top=0.95, bottom=0.01)

        out_dir = os.path.dirname(out_file)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(out_file, dpi=300)
        print(f"UMAP grid saved → {out_file}")
        plt.show()
        # Release the figure; otherwise every call leaves an open figure behind.
        plt.close(fig)

    def runUmap(self, load_cached=True):
        """Embed one prompt per symptom, project every layer with UMAP, and plot.

        Steps:
        1. Build prompts of the form "A patient has {symptom}.".
        2. Record layer activations through the hooks returned by
           ``register_hooks``.
        3. Fit a 2-D UMAP (for plotting) and a 30-D UMAP (for clustering).
        4. Cache all of the above in ``self.umapCacheFile``.
        5. Write two figures under ``results/`` and pass the 30-D embeddings
           to ``runUMAPClustering``.

        Args:
            load_cached: If True and the cache file exists, reuse the cached
                embeddings and symptom metadata instead of running the model.
                The cache is a pickle, so only load files you produced yourself.
        """
        MARKERS = ["o", "s", "^", "p", "D", "P", "*", "X", ">", "h"]

        # Groups are listed in plotting/legend order. A symptom listed under
        # several groups is assigned to the LAST one by SYMPTOM_TO_GROUP below;
        # e.g. "muscle weakness" ends up in "Musculosk.".
        SYMPTOM_GROUPS = collections.OrderedDict({
            "Pain": [
                "headaches", "chest pain", "abdominal pain", "joint pain", "back pain", "muscle aches", "pelvic pain", "toothache", "pleuritic chest pain"
            ],
            "Neurological": [
                "dizziness", "light-headedness", "paresthesia", "muscle weakness", "seizures", "confusion", "syncope", "tremor", "insomnia"
            ],
            "Respiratory": [
                "cough", "shortness of breath", "wheezing", "sore throat", "nasal congestion", "sneezing", "hemoptysis", "stridor", "chest tightness"
            ],
            "Digestive": [
                "nausea", "vomiting", "diarrhea", "constipation", "heartburn", "bloating", "dysphagia", "dyspepsia", "flatulence"
            ],
            # A "General" group (fever, chills, fatigue, malaise, ...) is intentionally excluded.
            "Cardio.": [ # CARDIOVASCULAR
                "palpitations", "tachycardia", "bradycardia", "arrhythmia", "hypertension", "hypotension", "claudication", "edema", "orthopnea"
            ],
            "Dermatol.": [ # DERMATOLOGICAL
                "rash", "itching", "hives", "dry skin", "blistering", "erythema", "petechiae", "purpura", "ecchymosis"
            ],
            "Musculosk.": [ # MUSCULOSKELETAL
                "muscle cramps", "muscle weakness", "joint swelling", "tendon pain", "muscle stiffness", "reduced range of motion", "muscle spasms", "bone pain", "myalgia"
            ],
            "Psych.": [ # PSYCHIATRIC/BEHAVIORAL
                "anxiety", "depression", "irritability", "apathy", "panic attacks", "phobias", "hallucinations", "delusions", "suicidal ideation", "obsessive thoughts"
            ],
            "Genitourin.": [ # GENITOURINARY
                "dysuria", "hematuria", "urinary frequency", "urinary incontinence", "urinary retention", "polyuria", "nocturia", "hematospermia", "urethral discharge"
            ],
            "ENT & Ocular": [
                "hearing loss", "tinnitus", "ear pain", "nosebleeds", "nasal discharge", "blurred vision", "double vision", "eye redness"
            ]
        })

        SYMPTOM_TO_GROUP = {sym: grp for grp, syms in SYMPTOM_GROUPS.items() for sym in syms}
        # Unique symptoms, in order of first appearance.
        ALL_SYMPTOMS = list(SYMPTOM_TO_GROUP.keys())

        cache_path = self.umapCacheFile
        if load_cached and os.path.exists(cache_path):
            # NOTE: pickle can execute arbitrary code; only load trusted cache files.
            with open(cache_path, 'rb') as f:
                cache = pickle.load(f)
            embs = cache['embs']
            embsClustering = cache['embsClustering']
            prompts = cache['prompts']
            symptoms_list = cache['symptoms_list']
            symptom_groups_list = cache['symptom_groups_list']
            SYMPTOM_GROUPS = cache['SYMPTOM_GROUPS']
            MARKERS = cache['MARKERS']
            ALL_SYMPTOMS = cache['ALL_SYMPTOMS']
            SYMPTOM_TO_GROUP = cache['SYMPTOM_TO_GROUP']
        else:
            prompts, symptoms_list, symptom_groups_list = [], [], []
            NR_PATIENTS_PER_SYMPTOM = 1

            for symptom in ALL_SYMPTOMS:
                for _ in range(NR_PATIENTS_PER_SYMPTOM):
                    text = f"A patient has {symptom}."
                    prompts.append(text)
                    symptoms_list.append(symptom)
                    symptom_groups_list.append(SYMPTOM_TO_GROUP[symptom])

            print(f"Generated {len(prompts)} prompts.")

            handles, layer_buffers = self.register_hooks()
            try:
                # Only the activations captured by the hooks are needed, not the replies.
                self.run_model(prompts)
            finally:
                # Detach the hooks even if the model run fails.
                for h in handles:
                    h.remove()

            embs = self.umap_fit(layer_buffers)
            embsClustering = self.umap_fit(layer_buffers, n_components=30)

            cache = {
                'embs': embs,
                'embsClustering': embsClustering,
                'prompts': prompts,
                'symptoms_list': symptoms_list,
                'symptom_groups_list': symptom_groups_list,
                'SYMPTOM_GROUPS': SYMPTOM_GROUPS,
                'MARKERS': MARKERS,
                'ALL_SYMPTOMS': ALL_SYMPTOMS,
                'SYMPTOM_TO_GROUP': SYMPTOM_TO_GROUP,
            }
            with open(cache_path, 'wb') as f:
                pickle.dump(cache, f)
            print(f"Saved UMAP cache → {cache_path}")

        # Compact figure: a few selected layers on top, legend below.
        selected_layers = self.selectLayersForPlots(nrPlots=4)
        out_file_selected = f"results/{self.analysis_name}_umap_{self.MODEL_NAME}_selected.pdf"
        self.plot_umap(
            embs,
            symptoms_list,
            symptom_groups_list,
            SYMPTOM_GROUPS,
            MARKERS,
            ALL_SYMPTOMS,
            SYMPTOM_TO_GROUP,
            out_file=out_file_selected,
            layer_list=selected_layers,
            nr_cols=4,
            fontsize=23,
            plot_title=False,
            drug_row_height=2,
            drug_row_offset=9,
        )

        # Large figure: up to 40 layers in a 5-column grid.
        selected_layers_many = self.selectLayersForPlots(nrPlots=40)
        out_file = f"results/{self.analysis_name}_umap_{self.MODEL_NAME}.pdf"
        self.plot_umap(
            embs,
            symptoms_list,
            symptom_groups_list,
            SYMPTOM_GROUPS,
            MARKERS,
            ALL_SYMPTOMS,
            SYMPTOM_TO_GROUP,
            out_file=out_file,
            layer_list=selected_layers_many,
            nr_cols=5,
            fontsize=20,
            plot_title=False,
            drug_row_height=2,
            drug_row_offset=9,
        )

        # Per-layer clustering quality on the 30-D embeddings, delegated to
        # utils.runUMAPClustering. The labels are group-name strings, not ints.
        group_labels = np.array([SYMPTOM_TO_GROUP[s] for s in symptoms_list])
        runUMAPClustering(embsClustering, group_labels, self.umapAggFile)

    def gen_saliency_prompts(self):
        """Return (prompt, target word, label) triples for the saliency analysis.

        Each prompt asks for the main symptom of a disease and is prefixed with
        an instruction to answer with a single word.
        """
        prefix = "Answer with a single word only describing a symptom."
        return [
            (prefix + " What is the main symptom present in Asthma?", "wheezing", "wheezing"),
            (prefix + " What is the main symptom present in Pneumonia?", "cough", "cough"),
            (prefix + " What is the main symptom present in Influenza?", "fever", "fever"),
            (prefix + " What is the main symptom present in Diabetes?", "thirst", "thirst"),
            (prefix + " What is the main symptom present in Hyperthyroidism?", "palpitations", "palpitations"),
            (prefix + " What is the main symptom present in Hypothyroidism?", "fatigue", "fatigue"),
            # (prefix + " What is the main symptom present in Angina?", "chest", "angina_chest"),
            (prefix + " What is the main symptom present in Heart failure?", "dyspnea", "dyspnea"),
            (prefix + " What is the main symptom present in Gastroesophageal reflux?", "heartburn", "heartburn"),
            # (prefix + " What is the main symptom present in Peptic ulcer?", "epigastric", "pud_epigastric"),
            # TODO(review): the disease name "c" looks truncated; confirm the intended disease.
            (prefix + " What is the main symptom present in c?", "cramping", "cramping"),
            # (prefix + " What is the main symptom present in Rheumatoid arthritis?", "stiffness", "ra_stiffness"),
            # (prefix + " What is the main symptom present in Gout?", "podagra", "gout_podagra"),
            (prefix + " What is the main symptom present in Psoriasis?", "rash", "rash"),
            (prefix + " What is the main symptom present in Eczema?", "itching", "itching"),
            # (prefix + " What is the main symptom present in Tuberculosis?", "sweats", "tb_sweats"),
            # (prefix + " What is the main symptom present in Depression?", "sadness", "depression_sadness"),
            # (prefix + " What is the main symptom present in Anxiety?", "worry", "anxiety_worry"),
            # (prefix + " What is the main symptom present in Kidney stones?", "colic", "stones_colic"),
        ]

    # Use base runLesioning; provide prompts here
    def generate_lesioning_prompts(self):
        """Return the disease -> symptom questions used by the base lesioning run."""
        return [
            "A patient has Alzheimer's disease. What are the common symptoms?",
            "A patient has diabetes. What symptoms should they watch for?",
            "A patient has asthma. What are the typical symptoms?",
            "A patient has multiple sclerosis. What symptoms might they experience?",
            "A patient has rheumatoid arthritis. What are the main symptoms?",
            "A patient has glaucoma. What symptoms indicate this condition?",
            "A patient has epilepsy. What symptoms occur during seizures?",
            "A patient has Crohn's disease. What are the common symptoms?",
            "A patient has lupus. What symptoms are associated with this disease?",
            "A patient has bipolar disorder. What are the characteristic symptoms?",
        ]

    def generate_activation_patching_prompts(self):
        """
        Generate disease -> symptom prompt pairs for activation patching.

        Every answer is a single word. Each prompt is prefixed with an
        instruction to answer with a single word describing a symptom.

        Returns:
            List of 7 tuples (clean_prompt, clean_answer, corrupt_prompt, corrupt_answer).
        """
        print("Running activation patching analysis for symptoms...")

        prefix = "Answer with a single word only describing a symptom. "

        # Pairs whose answers need two words (e.g. "chest pain", "joint pain") are commented out.
        prompt_pairs = [ # Disease -> Symptom activation patching pairs
            # Neurological diseases
            ("For migraine, a common symptom is ___", "headache",
             "For epilepsy, a common symptom is ___", "seizures"),

            # Respiratory diseases
            ("For asthma, a common symptom is ___", "wheezing",
             "For pneumonia, a common symptom is ___", "cough"),

            # # Cardiovascular diseases
            # ("For angina, a common symptom is ___", "chest pain", # need to replace
            #  "For heart failure, a common symptom is ___", "shortness of breath"),

            # Gastrointestinal diseases
            ("For gastritis, a common symptom is ___", "nausea",
             "For irritable bowel syndrome, a common symptom is ___", "bloating"),

            # # Musculoskeletal diseases
            # ("For arthritis, a common symptom is ___", "pain",
            #  "For fibromyalgia, a common symptom is ___", "muscle pain"),

            # Dermatological diseases
            ("For eczema, a common symptom is ___", "itching",
             "For psoriasis, a common symptom is ___", "rash"), # need to replace

            # Endocrine diseases
            ("For diabetes, a common symptom is ___", "thirst",
             "For hyperthyroidism, a common symptom is ___", "tremor"), # weight loss, rapid heartbeat, anxiety, heat intolerance,

            # Infectious diseases
            ("For influenza, a common symptom is ___", "fever",
             "For tuberculosis, a common symptom is ___", "cough"),

            # Psychiatric diseases
            ("For depression, a common symptom is ___", "sadness",
             "For anxiety disorder, a common symptom is ___", "worry"),

            # # Urological diseases
            # ("For urinary tract infection, a common symptom is ___", "burning urination",
            #  "For kidney stones, a common symptom is ___", "flank pain"),
        ]

        # Prepend the instruction to both prompts of each pair.
        return [
            (prefix + clean_prompt, clean_answer, prefix + corrupt_prompt, corrupt_answer)
            for clean_prompt, clean_answer, corrupt_prompt, corrupt_answer in prompt_pairs
        ]

