"""
UMAP of per-layer activations 
Representations for gender M / F 

The script builds 30 short prompts that describe a male individual in diverse life and
professional contexts without mentioning specific ages or names. Each prompt contains two
short sentences that use masculine pronouns (He / his), followed by a short question. A female
version of every prompt is generated by systematic substitution of pronouns and gendered nouns
(e.g. boy → girl). Layer-wise activations are collected, reduced with UMAP, and plotted with
colour coding by gender.
"""

# ── Imports ──────────────────────────────────────────────────────────────
import math
import os
import pickle
import re
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import torch
import umap
try:
    from matplotlib.cm import get_cmap
except ImportError:  # matplotlib.cm.get_cmap was removed in Matplotlib 3.9
    from matplotlib.pyplot import get_cmap
from transformers import AutoModelForCausalLM, AutoTokenizer

from act import Act
from saliency import *
from utils import runUMAPClustering
         
class GenderAct(Act):

    def __init__(self, tok, model, processor, MODEL_NAME, PRINT_REPLIES, TEMPERATURE, USE_ACCELERATE=False):
        """Initialise the gender analysis on top of the shared ``Act`` base class.

        Every argument is forwarded unchanged to ``Act.__init__``; the meaning
        of each one is defined there.
        """
        forwarded = (tok, model, processor, MODEL_NAME, PRINT_REPLIES, TEMPERATURE)
        super().__init__(*forwarded, USE_ACCELERATE=USE_ACCELERATE)

        # Tag that prefixes this class's output files (e.g. results/gender_umap_*.pdf).
        self.analysis_name = "gender"
        # Invoked only after analysis_name is set; presumably derives the
        # cache/result file attributes from it — confirm in Act.set_files.
        self.set_files()


    def plot_umap(self, embeddings, genders, num_pairs, male_prompts, female_prompts):
        """Draw one UMAP scatter panel per layer and save the grid as a PDF.

        Args:
            embeddings: One 2-D array per layer; columns 0 and 1 of each array
                are plotted as x / y. Rows are expected in the same order as
                ``genders``.
            genders: Per-row labels, ``"M"`` or ``"F"``.
            num_pairs: Number of male/female prompt pairs. Row ``i`` is joined
                to row ``i + num_pairs`` by a grey line.
            male_prompts: Male prompts; only the first one is printed below the grid.
            female_prompts: Female prompts; only the first one is printed.

        Side effects:
            Writes ``results/{analysis_name}_umap_{MODEL_NAME}.pdf`` (creating
            the target directory if needed) and calls ``plt.show()``.
        """
        colour_map = {"M": "tab:blue", "F": "tab:orange"}
        legend_names = {"M": "Male", "F": "Female"}

        # Size the grid from the data actually plotted so that a mismatch with
        # self.num_layers can neither overflow the grid nor blank a drawn panel.
        n_panels = len(embeddings)
        n_cols = 5
        n_rows = max(1, math.ceil(n_panels / n_cols))
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols*4, n_rows*4), squeeze=False)

        # Row indices per gender are the same for every layer: compute them once.
        rows_by_gender = {
            g: [i for i, gg in enumerate(genders) if gg == g]
            for g in colour_map
        }

        for layer, emb in enumerate(embeddings):
            ax = axes[layer // n_cols, layer % n_cols]

            # Scatter male & female points
            for g, idx in rows_by_gender.items():
                ax.scatter(emb[idx, 0], emb[idx, 1],
                           c=[colour_map[g]]*len(idx),
                           marker="o", s=240, edgecolors='k', linewidths=0.2)

            # Connect corresponding M–F pairs with a grey line
            for m_idx in range(num_pairs):
                f_idx = m_idx + num_pairs
                ax.plot([emb[m_idx, 0], emb[f_idx, 0]],
                        [emb[m_idx, 1], emb[f_idx, 1]],
                        color="grey", alpha=1, linewidth=3)

            ax.set_title(f"Layer {layer}", fontsize=24)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.grid(True, ls="--", alpha=0.15)

        # Hide the unused panels of the last row, if any
        for ax in axes.ravel()[n_panels:]:
            ax.axis("off")

        # The legend is built from explicit proxy handles, so the scatter
        # artists above do not need labels.
        handles_leg = [
            plt.Line2D([], [], marker='o', color='w', markerfacecolor=colour_map[g],
                       markeredgecolor='k', markersize=25, linestyle='', label=legend_names[g])
            for g in ("M", "F")
        ]
        fig.legend(handles=handles_leg, loc='lower center', ncol=2, frameon=False, fontsize=20)

        # ── Coloured title and prompt lines ──────────────────────────────────────
        fig.suptitle(f"{self.MODEL_NAME} - UMAP representations for gender M/F", fontsize=26, y=0.97)

        # Male prompt (blue) / Female prompt (orange); skipped when no prompt is available
        if male_prompts:
            fig.text(0.5, 0.045, 'Example male prompt: ' + male_prompts[0],
                     ha='center', color=colour_map["M"], fontsize=20)
        if female_prompts:
            fig.text(0.5, 0.035, 'Example female prompt: ' + female_prompts[0],
                     ha='center', color=colour_map["F"], fontsize=20)

        fig.tight_layout(rect=[0, 0.06, 1, 0.96], pad=2.5, w_pad=2.0, h_pad=2.5)

        out_file = f"results/{self.analysis_name}_umap_{self.MODEL_NAME}.pdf"
        # savefig does not create missing directories; dirname also covers a
        # MODEL_NAME that contains a path separator.
        os.makedirs(os.path.dirname(out_file), exist_ok=True)
        fig.savefig(out_file, dpi=300)
        print(f"\nUMAP figure saved → {out_file}")
        plt.show()
        # Release the figure once show() has returned. In interactive mode
        # show() does not block, so the window is left open for the user.
        if not plt.isinteractive():
            plt.close(fig)


    def runUmap(self, load_cached=True):
        """Embed male/female prompt activations with UMAP, plot them and score the clustering.

        Thirty male prompts are built and each is converted into a female
        counterpart by whole-word substitution. Unless a cache is loaded, the
        prompts are run through the model with hooks registered, the captured
        buffers are reduced twice with ``self.umap_fit`` (once with its default
        settings for plotting, once with ``n_components=30`` for clustering),
        and the result is pickled to ``self.umapCacheFile``.

        Args:
            load_cached: When true and ``self.umapCacheFile`` exists, the
                embeddings, prompts and labels are read from that file instead
                of being recomputed. A warning is printed if the cached prompts
                differ from the current prompt set.

        Side effects:
            May write the cache file, writes the UMAP figure through
            ``plot_umap``, and calls ``runUMAPClustering`` with the
            30-component embeddings, 0/1 gender labels and ``self.umapAggFile``.
        """
        # ── Build prompts (30 male → 30 female) ──────────────────────────────────
        male_descriptions = [
            "He is a young boy visiting the zoo. He loves feeding the giraffes.",
            "He is a student attending elementary school. He enjoys drawing rockets.",
            "He is a teenager playing guitar with his friends. He practices every evening.",
            "He is a high‑school athlete training for a track meet. He pushes his limits daily.",
            "He is a college student studying computer science. He codes small games at night.",
            "He is a software engineer working at a startup. He likes hiking on weekends.",
            "He is a doctor treating patients in a busy hospital. He reads mystery novels.",
            "He is a firefighter responding to emergencies. He maintains his equipment diligently.",
            "He is a professor lecturing on history. He collects antique maps.",
            "He is a lawyer preparing for a trial. He meditates to stay focused.",
            "He is a chef running a popular restaurant. He experiments with new spices.",
            "He is an architect designing green buildings. He sketches in his free time.",
            "He is a farmer managing crops on a large field. He repairs his tractor himself.",
            "He is a retiree volunteering at a community centre. He plays chess with visitors.",
            "He is a patient undergoing physical therapy. He keeps a daily progress journal.",
            "He is an artist painting landscapes. He walks along rivers for inspiration.",
            "He is a grandfather telling stories to his grandchildren. He bakes cookies.",
            "He is a veteran attending a memorial ceremony. He polishes his medals.",
            "He is a boy visiting the library. He borrows comic books.",
            "He is a student learning mathematics. He solves puzzles after class.",
            "He is an intern starting his first job. He notes everything diligently.",
            "He is a paramedic assisting in emergencies. He practices first‑aid drills.",
            "He is a musician playing in a band. He writes lyrics on the bus.",
            "He is a journalist covering local news. He keeps his notebook handy.",
            "He is an engineer supervising construction. He inspects safety gear.",
            "He is a police officer patrolling the city. He greets shop owners.",
            "He is a pilot flying commercial flights. He checks weather reports carefully.",
            "He is a dentist caring for patients' teeth. He recommends flossing daily.",
            "He is a librarian organising books. He catalogs rare editions.",
            "He is a gardener cultivating roses. He monitors soil moisture closely.",
        ]

        male_prompts = [f"{desc} What is he doing?" for desc in male_descriptions]

        # Masculine → feminine substitutions (case-sensitive, whole words only).
        # "himself" needs its own entry: the word boundary in \bhim\b never
        # matches inside it, so it would otherwise survive in the female prompt.
        swaps = {
            "He": "She",
            "he": "she",
            "His": "Her",
            "his": "her",
            "Him": "Her",
            "him": "her",
            "Himself": "Herself",
            "himself": "herself",
            "boy": "girl",
            "grandfather": "grandmother",
        }
        # One compiled alternation: each word is rewritten exactly once.
        swap_re = re.compile(r"\b(?:" + "|".join(map(re.escape, swaps)) + r")\b")
        female_prompts = [
            swap_re.sub(lambda m: swaps[m.group(0)], prompt)
            for prompt in male_prompts
        ]

        prompts = male_prompts + female_prompts
        genders = ["M"] * len(male_prompts) + ["F"] * len(female_prompts)
        num_pairs = len(male_prompts)
        print(f"Prepared {len(prompts)} prompts (male+female).")

        # ── Load from cache, or compute and cache ────────────────────────────────
        cache_path = self.umapCacheFile
        if load_cached and os.path.exists(cache_path):
            # NOTE(security): pickle.load can execute arbitrary code; only load
            # cache files that were written by this script.
            with open(cache_path, 'rb') as f:
                cache = pickle.load(f)

            # The cached embeddings belong to the cached prompts, so those are
            # the ones plotted; flag it when they no longer match the current set.
            if (cache.get('male_prompts') != male_prompts
                    or cache.get('female_prompts') != female_prompts):
                print(f"Warning: prompts stored in {cache_path} differ from the current "
                      "prompt set; call runUmap(load_cached=False) to recompute.")

            embs = cache.get('embs')
            embsClustering = cache.get('embsClustering')
            male_prompts = cache.get('male_prompts')
            female_prompts = cache.get('female_prompts')
            genders = cache.get('genders')
            num_pairs = cache.get('num_pairs')
            print(f"Loaded cached UMAP data from {cache_path}")
        else:
            # ── Capture activations during a forward pass ────────────────────────
            handles, layer_buffers = self.register_hooks()
            try:
                # Only the side effect on layer_buffers is needed; the
                # generated replies are not used here.
                self.run_model(prompts)
            finally:
                # Always detach the hooks, even if the forward pass fails,
                # so they do not stay attached for later runs.
                for h in handles:
                    h.remove()

            # ── UMAP reductions ──────────────────────────────────────────────────
            # Default reduction feeds the plot (which reads columns 0 and 1);
            # the 30-component one feeds the clustering score.
            embs = self.umap_fit(layer_buffers)
            embsClustering = self.umap_fit(layer_buffers, n_components=30)

            # Save cache
            cache = {
                'embs': embs,
                'embsClustering': embsClustering,
                'male_prompts': male_prompts,
                'female_prompts': female_prompts,
                'genders': genders,
                'num_pairs': num_pairs,
            }
            os.makedirs(os.path.dirname(cache_path) or ".", exist_ok=True)
            with open(cache_path, 'wb') as f:
                pickle.dump(cache, f)
            print(f"Saved UMAP cache → {cache_path}")

        # ── Plot grid ────────────────────────────────────────────────────────────
        self.plot_umap(embs, genders, num_pairs, male_prompts, female_prompts)

        # ── Per-layer clustering score (per the original note: silhouette score
        #    with confidence intervals — implemented in utils.runUMAPClustering)
        int_labels = np.array([0 if g == "M" else 1 for g in genders])
        runUMAPClustering(embsClustering, int_labels, self.umapAggFile)

        

    def runSaliency(self, load_cached=True):
        """Compute (or load) and plot saliency for a male and a female prompt ('he' vs 'she').

        Args:
            load_cached: When true and ``self.saliencyFile`` exists, the
                averaged saliency data is read from it and the per-gender
                computation and plots are skipped.

        Side effects:
            Writes under ``results/``: per-gender saliency grids (only when
            computed), the averaged grid, a text report and a 2-D saliency
            table; when computed, also saves the data to ``self.saliencyFile``.
            When the data comes from the cache, the text report contains only
            the averaged lines.
        """
        stem = f"results/{self.analysis_name}_saliency_{self.MODEL_NAME}"
        report_path = f"{stem}.txt"
        table_path = f"results/{self.analysis_name}_2Dsaliency_{self.MODEL_NAME}_avg.pdf"
        # Make sure every output location exists before anything is written
        # (dirname also covers a MODEL_NAME that contains a path separator).
        for path in (report_path, table_path):
            os.makedirs(os.path.dirname(path), exist_ok=True)

        # Per-gender report lines; they stay empty when loading from cache.
        gender_lines = {"male": [], "female": []}

        if load_cached and os.path.exists(self.saliencyFile):
            # NOTE(security): pickle.load can execute arbitrary code; only load
            # cache files that were written by this script.
            with open(self.saliencyFile, "rb") as f:
                data = pickle.load(f)
            avg_sal = data["avg_sal"]
            avg_sal_per_prompt = data["avg_sal_per_prompt"]
            avg_sal_per_head = data["avg_sal_per_head"]
            avg_sal_per_mlp = data["avg_sal_per_mlp"]
            print(f"Loaded cached saliency data from {self.saliencyFile}")
        else:
            # Define gender prompts
            male_prompt = "He is an artist painting landscapes. He walks along rivers for inspiration. What is he doing?"
            female_prompt = "She is an artist painting landscapes. She walks along rivers for inspiration. What is she doing?"

            # First token id of "he" / "she" (no special tokens); handed to
            # saliency_for as the allowed ids (for clamping, per the original note).
            he_id = self.tokenizer("he", add_special_tokens=False).input_ids[0]
            she_id = self.tokenizer("she", add_special_tokens=False).input_ids[0]
            allowed_ids = [he_id, she_id]

            # (key, prompt, target pronoun, report tag) — male first, then female.
            runs = (
                ("male", male_prompt, "he", "He"),
                ("female", female_prompt, "she", "She"),
            )

            all_saliency_dicts = {}
            for gender, prompt, pronoun, tag in runs:
                # The original author kept (commented-out) debugging calls to
                # self.print_clamped_probs / self.generate_clamped at this point.
                saliency = self.saliency_for(prompt, pronoun, allowed_ids)
                print(f"Top saliency weights for '{pronoun}':")
                gender_lines[gender] = print_top_saliencies(saliency, tag)

                self.plot_saliency_grid(
                    self.model_layers,
                    saliency,
                    f"{stem}_{pronoun}.pdf",
                    f"{self.MODEL_NAME} saliency for {gender} prompt ('{pronoun}')"
                )
                all_saliency_dicts[gender] = saliency

            # --- Average saliency ---
            avg_sal, avg_sal_per_prompt, avg_sal_per_head, avg_sal_per_mlp = merge_saliency_dicts(all_saliency_dicts, self.model)

            # NOTE(review): the cache is read back with pickle.load above, so
            # save_saliency must write a pickle-compatible file (the original
            # comment here said ".npz") — confirm in saliency.save_saliency.
            save_saliency(self.saliencyFile, avg_sal, avg_sal_per_prompt,
                          avg_sal_per_head, avg_sal_per_mlp)
            print(f"Saved all saliency data → {self.saliencyFile}")

        self.plot_saliency_grid(
            self.model_layers,
            avg_sal,
            f"{stem}_avg.pdf",
            f"{self.MODEL_NAME} average saliency for gender ('he' vs 'she')"
        )

        # --- Save results ---
        lines_avg = print_top_saliencies(avg_sal, "Avg")
        lines_male = gender_lines["male"]
        lines_female = gender_lines["female"]
        # Individual gender saliencies are only written when both were
        # computed in this run (i.e. not loaded from cache).
        if lines_male and lines_female:
            report_lines = lines_male + lines_female + lines_avg
        else:
            report_lines = lines_avg
        with open(report_path, "w", encoding="utf-8") as f:
            f.write("\n".join(report_lines))

        print(f"Saved saliency report → {report_path}")

        # Plot saliency table
        self.plot_saliency_2D(
            avg_sal_per_head, avg_sal_per_mlp, table_path,
            f"{self.MODEL_NAME} saliency table for gender ('he' vs 'she')"
        )
        print(f"Plotted saliency table → {table_path}")