# ── Imports ──────────────────────────────────────────────────────────────
import math, torch, umap, numpy as np
import matplotlib.pyplot as plt
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize
from transformers import AutoModelForCausalLM, AutoTokenizer
from act import Act
import pickle, os
            
class DiseaseProgressionAct(Act):
    def __init__(self, tok, model, processor, MODEL_NAME, PRINT_REPLIES, TEMPERATURE, USE_ACCELERATE=False):
        """Set up the disease-progression analysis on top of the base ``Act``.

        All arguments are forwarded unchanged to ``Act.__init__``; their
        meaning is defined there.
        """
        super().__init__(
            tok,
            model,
            processor,
            MODEL_NAME,
            PRINT_REPLIES,
            TEMPERATURE,
            USE_ACCELERATE=USE_ACCELERATE,
        )
        # The name is assigned before set_files() is called; keep this order.
        # NOTE(review): set_files() presumably derives output/cache paths from
        # analysis_name - confirm in Act.
        self.analysis_name = "progression"
        self.set_files()

    def selectLayersForPlots(self, nrPlots=4):
        """Select 5 evenly-spaced layers including first and last"""
        # Calculate step size to get 3 intermediate layers between first and last
        if nrPlots > self.num_layers:
            return list(range(self.num_layers))
        else:
            step = (self.num_layers - 1) / (nrPlots - 1)  # 4 intervals for 5 points
            selected_layers = [0]  # Always include first layer
            for i in range(1, nrPlots-1):  # Add 3 intermediate layers
                layer_idx = int(round(i * step))
                if layer_idx not in selected_layers:  # Avoid duplicates
                    selected_layers.append(layer_idx)
            
            # always include last layer
            selected_layers.append(self.num_layers - 1)
            selected_layers = sorted(selected_layers)  # Ensure they're in order
            return selected_layers

    def plot_umap(self, embs, disease_stages, disease_labels, DISEASES, MARKERS, out_file, layer_list=None, nr_cols=5, fontsize=24, plot_title=False, circularity_subplot=True):
        """Draw a grid of per-layer 2-D embedding plots and save it to disk.

        Each layer panel shows one point per sample, coloured by disease stage
        and shaped by disease, with arrows joining consecutive stages of the
        same disease. Optionally two extra panels show, for every layer of the
        model, which stage lies closest to the first stage and which lies
        closest to the end stage.

        Args:
            embs: Indexable by layer index; each entry is a 2-D array with one
                row per sample and (at least) two columns.
            disease_stages: Per-sample stage value; mapped to colour with a
                0..1 normalisation.
            disease_labels: Per-sample disease name.
            DISEASES: Disease names to plot.
            MARKERS: Matplotlib marker codes, cycled over ``DISEASES``.
            out_file: Path the figure is written to.
            layer_list: Layers to draw; ``None`` draws every layer.
            nr_cols: Number of columns in the subplot grid.
            fontsize: Base font size for titles and labels.
            plot_title: Whether to add a figure-level title.
            circularity_subplot: Whether to add the two circularity panels.

        Returns:
            ``{'closest_to_first_stage': {...}, 'closest_to_end_stage': {...}}``
            where each inner dict maps disease -> {layer index -> 1-based
            stage number or ``None``}. Both inner dicts are empty when
            ``circularity_subplot`` is false.
        """
        # Function-scope imports kept local to this method, as in the original.
        from matplotlib.lines import Line2D
        from mpl_toolkits.axes_grid1.inset_locator import inset_axes

        print("Plotting …")

        circularity_results = {
            'closest_to_first_stage': {},  # {disease: {layer: closest_stage}}
            'closest_to_end_stage': {}     # {disease: {layer: closest_stage}}
        }

        if layer_list is None:
            layer_list = list(range(self.num_layers))
        selected_layers = layer_list
        n_layers_to_plot = len(selected_layers)
        smallPlot = n_layers_to_plot < 9

        # Two extra panels hold the circularity curves when requested.
        total_subplots = n_layers_to_plot + (2 if circularity_subplot else 0)
        n_cols = nr_cols
        # Keep at least one row so the grid (and the colour bar anchor) exists
        # even when there is nothing to plot.
        n_rows = max(1, math.ceil(total_subplots / n_cols))
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols*5.5, n_rows*5), squeeze=False)
        fig.patch.set_facecolor('white')

        cmap_stage = plt.cm.RdYlGn_r  # green = early, red = late
        norm_stage = Normalize(vmin=0, vmax=1)
        disease_markers = {disease: MARKERS[i % len(MARKERS)] for i, disease in enumerate(DISEASES)}

        # Sample indices of every disease in progression order (sorted by
        # stage), computed once and shared by all panels.
        progression = {}
        for disease in DISEASES:
            ordered = sorted(
                (disease_stages[i], i) for i, d in enumerate(disease_labels) if d == disease
            )
            progression[disease] = (
                [stage for stage, _ in ordered],
                [idx for _, idx in ordered],
            )
        # Largest stage count over all diseases; only used for panel labels
        # and axis limits.
        nr_stages = max((len(idxs) for _, idxs in progression.values()), default=0)

        for plot_idx, layer_idx in enumerate(selected_layers):
            ax = axes[plot_idx // n_cols, plot_idx % n_cols]
            emb = embs[layer_idx]

            for disease in DISEASES:
                stages, idxs = progression[disease]
                if not idxs:
                    continue

                ax.scatter(emb[idxs, 0], emb[idxs, 1],
                           c=cmap_stage(norm_stage(stages)), marker=disease_markers[disease],
                           s=240, edgecolors='k', linewidths=0.3, alpha=1)

                # Arrow from each stage to the next, coloured like its source
                # point (stages and idxs are both in sorted order here).
                for k in range(len(idxs) - 1):
                    ax.annotate('', xy=emb[idxs[k + 1]], xytext=emb[idxs[k]],
                                arrowprops=dict(arrowstyle="->",
                                                color=cmap_stage(norm_stage(stages[k])),
                                                lw=5, alpha=0.85))

            ax.set_title(f"Layer {layer_idx}", fontsize=fontsize)
            ax.set_xlabel("UMAP1", fontsize=fontsize-2)
            if plot_idx % n_cols == 0:  # y-label only on the left-most column
                ax.set_ylabel("UMAP2", fontsize=fontsize-2)
            ax.tick_params(axis='both', which='major', labelsize=fontsize-4)
            ax.grid(True, ls="--", alpha=0.1)
            ax.set_facecolor('white')

        if circularity_subplot:
            fontsize_circ = fontsize if smallPlot else fontsize - 3

            # The analysis always covers every layer of the model, not only
            # the layers drawn above.
            for disease in DISEASES:
                _, idxs = progression[disease]
                n = len(idxs)
                # Nearest to the first stage among stages 3..n (the direct
                # neighbour, stage 2, is skipped).
                circularity_results['closest_to_first_stage'][disease] = \
                    self._closest_stage_per_layer(embs, idxs, 0, range(2, n))
                # Nearest to the end stage among stages 1..n-2 (the direct
                # neighbour, the penultimate stage, is skipped). Stage 1 is
                # included, matching the panel title and axis label.
                circularity_results['closest_to_end_stage'][disease] = \
                    self._closest_stage_per_layer(embs, idxs, -1, range(0, n - 2))

            first_panel = n_layers_to_plot
            self._draw_circularity_panel(
                axes[first_panel // n_cols, first_panel % n_cols],
                circularity_results['closest_to_first_stage'],
                DISEASES, disease_markers,
                title="Closest Stage to First Stage (3-%d)" % nr_stages,
                ylabel="Closest Disease Stage (3-%d)" % nr_stages,
                ylim=(0.5, nr_stages+0.5),
                fontsize=fontsize_circ)

            second_panel = n_layers_to_plot + 1
            self._draw_circularity_panel(
                axes[second_panel // n_cols, second_panel % n_cols],
                circularity_results['closest_to_end_stage'],
                DISEASES, disease_markers,
                title="Closest Stage to End Stage (1-%d)" % (nr_stages-2),
                ylabel="Closest Disease Stage (1-%d)" % (nr_stages-2),
                ylim=(0.5, nr_stages-0.5),
                fontsize=fontsize_circ)

        # Hide grid cells that received no plot.
        for i in range(total_subplots, n_rows*n_cols):
            axes[i // n_cols, i % n_cols].axis("off")

        sm = ScalarMappable(cmap=cmap_stage, norm=norm_stage)
        sm.set_array([])

        # Horizontal stage colour bar floating above the top-right panel.
        top_right_ax = axes[0][-1]
        cax = inset_axes(
                top_right_ax,
                width="60%", height="10%",          # % of parent axis
                loc='lower center',
                bbox_to_anchor=(0, 1.2 if smallPlot else 1.3, 1, 1),   # (x, y, w, h) in axes fraction
                bbox_transform=top_right_ax.transAxes,
                borderpad=0)

        cbar = fig.colorbar(sm, cax=cax, orientation='horizontal')
        cbar.set_ticks([0, 1])
        cbar.set_ticklabels(['Early stage', 'Late stage'], fontsize=fontsize+1)
        cbar.ax.tick_params(labelsize=fontsize+1)

        # Figure-level legend mapping marker shape -> disease.
        legend_handles = [
            Line2D([0], [0], marker=disease_markers[d], color='w', markerfacecolor='gray',
                   markeredgecolor='k', markersize=15, label=d)
            for d in DISEASES
        ]
        legend_y = 0.94 if smallPlot else 0.995
        fig.legend(handles=legend_handles, loc='upper left', bbox_to_anchor=(0.1, legend_y),
                   ncol=4, frameon=False, fontsize=fontsize)

        if plot_title:
            fig.suptitle(f"{self.MODEL_NAME} - UMAP for Disease Progression", fontsize=fontsize+4, y=0.995)
        if smallPlot:
            fig.tight_layout(rect=[0,0,0.98,0.90], pad=2.5, w_pad=2.0, h_pad=2.5)
        else:
            fig.tight_layout(rect=[0,0,0.99,0.97], pad=2, w_pad=0.1, h_pad=2)

        fig.savefig(out_file, dpi=300)
        print(f"UMAP grid saved → {out_file}")
        plt.show()

        return circularity_results

    def _closest_stage_per_layer(self, embs, ordered_idxs, anchor_pos, candidate_positions):
        """Find, for every layer, the candidate stage nearest to an anchor stage.

        Args:
            embs: Indexable by layer index; each entry holds one row per sample.
            ordered_idxs: Sample row indices of one disease, sorted by stage.
            anchor_pos: Position within ``ordered_idxs`` of the reference stage.
            candidate_positions: Positions within ``ordered_idxs`` to compare
                against the anchor.

        Returns:
            ``{layer index: stage number}`` for all ``self.num_layers`` layers,
            where the stage number is the winning position plus one. The value
            is ``None`` for every layer when there are no candidates.
        """
        candidate_positions = list(candidate_positions)
        if not candidate_positions:
            return {layer_idx: None for layer_idx in range(self.num_layers)}

        closest = {}
        for layer_idx in range(self.num_layers):
            emb = embs[layer_idx]
            anchor = emb[ordered_idxs[anchor_pos]]
            # (distance, stage number) tuples: ties on distance resolve to the
            # lower stage number.
            distances = [
                (np.linalg.norm(emb[ordered_idxs[pos]] - anchor), pos + 1)
                for pos in candidate_positions
            ]
            closest[layer_idx] = min(distances)[1]
        return closest

    def _draw_circularity_panel(self, ax, closest_by_disease, DISEASES, disease_markers, title, ylabel, ylim, fontsize):
        """Plot one closest-stage-versus-layer line per disease on ``ax``."""
        for color_idx, disease in enumerate(DISEASES):
            per_layer = closest_by_disease[disease]
            layers = [layer for layer, stage in per_layer.items() if stage is not None]
            if not layers:  # nothing valid to draw for this disease
                continue
            ax.plot(layers, [per_layer[layer] for layer in layers],
                    marker=disease_markers[disease], color=plt.cm.Set1(color_idx),
                    linewidth=3, markersize=12, label=disease, alpha=0.8)

        ax.set_title(title, fontsize=fontsize)
        ax.set_ylabel(ylabel, fontsize=fontsize-2)
        ax.set_xlabel("Layer", fontsize=fontsize-2)
        ax.grid(True, ls="--", alpha=0.3)
        ax.legend(fontsize=fontsize-8, loc='lower center', ncol=2)
        ax.set_ylim(*ylim)
        ax.set_facecolor('white')

    def runUmap(self, load_cached=True):
        """Build the disease-progression prompts, embed them and plot the result.

        Embeddings are loaded from ``self.umapCacheFile`` when it exists and
        ``load_cached`` is true; otherwise the prompts are run through the
        model, the per-layer buffers are passed to ``self.umap_fit`` and the
        result is written to the cache file. Two figures are then produced
        (up to 40 layers, and 4 selected layers) and the circularity results
        of the first figure are saved as JSON under ``results/``.

        Args:
            load_cached: Reuse the pickled cache when it is present.

        Returns:
            None.
        """
        import json

        MARKERS = ['s', 'o', '*', '^']

        realistic_prompts = {
            "Alzheimer's": [
                "is healthy with no symptoms",
                "occasionally worries about minor forgetfulness",
                "frequently misplaces items and struggles with recent memory",
                "forgets recent conversations and repeats questions",
                "struggles significantly with language and daily tasks",
                "regularly becomes confused and has trouble recognizing family",
                "has severe cognitive decline and is unable to communicate clearly",
                "is bedridden, minimally responsive, after severe cognitive decline",
                "just died from severe cognitive decline"
            ],
            "COVID-19": [
                "is healthy with no symptoms",
                "has mild fatigue",
                "has mild fatigue, slight fever, dry cough",
                "has persistent cough, fever, mild shortness of breath",
                "has worsening respiratory distress, oxygen saturation 90%",
                "has been receiving supplemental oxygen and is in the ICU",
                "was sedated, intubated and mechanically ventilated",
                "experienced multi-organ failure and declining vital signs",
                "just died from multi-organ failure and declining vital signs"
            ],
            "COPD": [
                "is healthy with no symptoms",
                "experiences shortness of breath with strenuous exercise",
                "experiences shortness of breath when climbing stairs, mild cough",
                "is breathless during everyday tasks, chronic cough",
                "has regular flare-ups, struggles with daily tasks",
                "requires frequent hospitalizations due to breathing difficulty",
                "needs supplemental oxygen at home, has severe breathlessness",
                "experiences chronic respiratory failure, continuous oxygen dependence",
                "just died from chronic respiratory failure, continuous oxygen dependence"
            ],
            "Parkinson's": [
                "is healthy with no symptoms",
                "has mild tremors in one hand",
                "experiences stiffness and slowness of movement",
                "has mild difficulty with balance and coordination",
                "has moderate difficulty with balance and coordination",
                "has severe tremors and muscle rigidity",
                "is unable to walk without assistance",
                "is bedridden with severe motor impairment and requires full-time care",
                "just died from severe motor impairment and complications"
            ]
        }
        # Derived from the prompt table so the two can never drift apart.
        DISEASES = list(realistic_prompts)

        cache_path = self.umapCacheFile
        if load_cached and os.path.exists(cache_path):
            # NOTE(security): pickle.load can execute arbitrary code; only use
            # cache files that were written by this method.
            with open(cache_path, 'rb') as f:
                cache = pickle.load(f)
            embs = cache['embs']
            disease_labels = cache['disease_labels']
            disease_stages = cache['disease_stages']
            # Fall back to the current defaults if the cache lacks these keys.
            DISEASES = cache.get('DISEASES', DISEASES)
            MARKERS = cache.get('MARKERS', MARKERS)
        else:
            prompts, disease_labels, disease_stages = [], [], []
            for disease, descriptions in realistic_prompts.items():
                last = len(descriptions) - 1
                for i, desc in enumerate(descriptions):
                    prompts.append(f"Someone {desc}.")
                    disease_labels.append(disease)
                    # Stage normalised to 0..1; a single-stage disease maps to 0.
                    disease_stages.append(i / last if last else 0.0)

            print(f"Generated {len(prompts)} prompts.")

            handles, layer_buffers = self.register_hooks()
            try:
                self.run_model(prompts)
            finally:
                # Always detach the hooks, even when the model run fails.
                for h in handles:
                    h.remove()

            embs = self.umap_fit(layer_buffers)

            cache = {
                'embs': embs,
                'prompts': prompts,
                'disease_labels': disease_labels,
                'disease_stages': disease_stages,
                'DISEASES': DISEASES,
                'MARKERS': MARKERS,
            }
            with open(cache_path, 'wb') as f:
                pickle.dump(cache, f)
            print(f"Saved UMAP cache → {cache_path}")

        out_file = f"results/{self.analysis_name}_umap_{self.MODEL_NAME}.pdf"
        out_file_selected = f"results/{self.analysis_name}_umap_{self.MODEL_NAME}_selected.pdf"
        circularity_file = f"results/{self.analysis_name}_circularity_{self.MODEL_NAME}.json"
        # Make sure every output directory exists before anything is written.
        for path in (out_file, out_file_selected, circularity_file):
            os.makedirs(os.path.dirname(path) or ".", exist_ok=True)

        # Overview figure: up to 40 evenly spaced layers (every layer for
        # models with at most 40 layers).
        selected_layers = self.selectLayersForPlots(nrPlots=40)
        circularity_results = self.plot_umap(embs, disease_stages, disease_labels, DISEASES, MARKERS, out_file, circularity_subplot=True, plot_title=False, layer_list=selected_layers)

        # Compact figure: 4 selected layers.
        selected_layers = self.selectLayersForPlots(nrPlots=4)
        print(f"Selected layers: {selected_layers}")
        self.plot_umap(embs, disease_stages, disease_labels, DISEASES, MARKERS, out_file_selected, layer_list=selected_layers, nr_cols=3, fontsize=16, plot_title=False, circularity_subplot=True)

        # Persist the circularity results returned by the overview figure.
        with open(circularity_file, 'w') as f:
            json.dump(circularity_results, f, indent=2)
        print(f"Circularity results saved → {circularity_file}")
        
