import os
from modules import constants
os.environ["HF_HOME"] = os.path.join(constants.DATA_DIR, "raw")
import torch
import pandas as pd

from sae_lens import SAE
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification

from transformer_lens import HookedTransformer

import argparse

from modules.helpers import shuffle_ngrams
import pickle
from functools import partial
import numpy as np


# argparser = argparse.ArgumentParser()
# argparser.add_argument(
#     "--ngram",
#     type=str,
#     default="normal",
#     help="shuffle to use",
# )
# args = argparser.parse_args()
# ngram = args.ngram
# Shuffle condition for this run. "normal" leaves the passage untouched; any
# other value is converted with int() later in the script and handed to
# shuffle_ngrams as n.
ngram = "normal"


# Generator under study, loaded through TransformerLens so that forward hooks
# can be attached to its activations.
model = HookedTransformer.from_pretrained("gemma-2b-it", device="cuda")

# Classifier used to score (source, summary) pairs via hhem.predict(...).
hhem = AutoModelForSequenceClassification.from_pretrained(
    "vectara/hallucination_evaluation_model",
    trust_remote_code=True,
    device_map="cuda",
)

# All inputs for this condition live under <DATA_DIR>/interim/sae_activations/<ngram>.
output_dir = os.path.join(constants.DATA_DIR, "interim", "sae_activations", str(ngram))
os.makedirs(output_dir, exist_ok=True)

# Source passages; the first CSV column is used as the index.
source_df = pd.read_csv(os.path.join(output_dir, "gemma-2b-it", "source_text.csv"), index_col=0)

# Previously computed scores for the unpatched model (reads the "hall_score" column later).
normal_hall_df = pd.read_csv(os.path.join(output_dir, "summary-gemma-2b-it-FIXED.csv"))

# Prompt template; "<PASSAGE>" is substituted with the (optionally shuffled) source text.
prompt = (
    "You are a chat bot answering questions using data. "
    "You must stick to the answers provided solely by the text in the passage provided. "
    "You are asked the question 'Provide a concise summary of the following passage, "
    "covering the core pieces of information described.' <PASSAGE>\n"
)


# NOTE(review): `ngrams` is not referenced anywhere else in the visible code.
ngrams = ["normal", 1, 6, 10, 30]
# Accumulators filled by the generation loop below.
output_list = []
hall_score_list = []

# Sparse autoencoder for the layer-10 residual stream (see
# sae_lens/pretrained_saes.yaml for other releases). The sae_id happens to be
# a hook point name here, which is not always the case.
# NOTE(review): cfg_dict and sparsity are not used in the visible code.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gemma-2b-res-jb",
    sae_id="blocks.10.hook_resid_post",
    device="cuda",
)

def prompt_only_patch(patched_tensor: torch.Tensor):
    """
    Build a forward hook that swaps in ``patched_tensor`` for multi-token passes.

    The returned hook looks at dimension 1 of the incoming activation
    (treated as the sequence length): when it is greater than 1 -- i.e. the
    whole prompt is being processed at once -- the activation is replaced by
    ``patched_tensor``; otherwise (the 1-token generation steps) the model's
    own activation is passed through untouched.

    Args:
        patched_tensor: Replacement activation returned for multi-token passes.

    Returns:
        A callable ``(act, hook)`` suitable for use as a forward hook.
    """
    def _patch(act, hook):
        # Single-token step: leave the network's own activation in place.
        if act.size(1) <= 1:
            return act
        # Full-prompt pass: substitute the precomputed activation.
        return patched_tensor

    return _patch


#%%
# SAE feature slots to suppress, loaded from the PLS-VIP output file.
# NOTE(review): allow_pickle=True unpickles arbitrary objects stored in the
# file -- only load .npy files produced by this project.
top_features = np.load(
    os.path.join(constants.DATA_DIR, "interim", "sae_activations", f"{ngram}", "pls_vip_gemma-2b-it.npy"),
    allow_pickle=True
)

# Selection threshold: only samples whose baseline score lies in the bottom
# quartile are re-generated. It does not change during the loop, so compute it
# once instead of on every iteration.
hall_score_threshold = normal_hall_df["hall_score"].quantile(0.25)

# Name of the activation the SAE reads from / the hook point that gets patched.
hook_name = sae.cfg.hook_name

hall_score_pairs = []
for i in range(0, source_df.shape[0]):
    normal_hall_score = normal_hall_df.iloc[i]["hall_score"]
    # Skip everything that is not strictly below the bottom-quartile threshold
    # (written as `not (a < b)` so NaN scores are skipped too, as before).
    if not (normal_hall_score < hall_score_threshold):
        continue

    with torch.no_grad():
        print(f"Processing {i} of {source_df.shape[0]}")
        # NOTE(review): this is a label-based lookup on source_df's index while
        # normal_hall_df is read positionally -- confirm both stay aligned.
        source = source_df["source"][i]
        if ngram == "normal":
            new_prompt = prompt.replace("<PASSAGE>", source)
        else:
            ngram = int(ngram)
            shuffled_passage = shuffle_ngrams(source, n=ngram)
            new_prompt = prompt.replace("<PASSAGE>", shuffled_passage)

        tokens = model.to_tokens(new_prompt)

        # Capture only the activation the SAE was trained on.
        _, cache = model.run_with_cache(
            tokens,
            prepend_bos=True,
            names_filter=[hook_name],
        )
        feature_acts = sae.encode(cache[hook_name])

        # Reconstruction residual of the SAE on the untouched features; it is
        # subtracted again below so that only the effect of the suppressed
        # features differs from the original activation.
        non_suppressed_recon_cache = sae.decode(feature_acts)
        diff_recon_cache = non_suppressed_recon_cache - cache[hook_name]

        # Zero the selected feature slots for every batch element and position
        # (a slice over the batch axis replaces the explicit batch loop).
        for f in top_features:                      # usually length 1
            feature_acts[:, :, f] = 0
        suppressed_recon_cache = sae.decode(feature_acts)

        suppressed_base_cache = suppressed_recon_cache - diff_recon_cache

        # Alternatives that were tried: patching with cache[hook_name]
        # (identity) or with suppressed_recon_cache (no residual correction).
        patch_fn = prompt_only_patch(suppressed_base_cache)

        # Unpatched reference generation for side-by-side comparison.
        output_normal = model.generate(new_prompt, temperature=0, max_new_tokens=300, top_p=1)

        # Patched generation; gradients are already disabled by the enclosing
        # torch.no_grad(), so no second no_grad context is needed here.
        with model.hooks(fwd_hooks=[(hook_name, patch_fn)], reset_hooks_end=True):
            output = model.generate(new_prompt, temperature=0, max_new_tokens=300, top_p=1)

        output_list.append(output)

        # Keep only the text that follows the prompt in each generation.
        summary = output.split(new_prompt)[-1]
        summary_normal = output_normal.split(new_prompt)[-1]

        # Score the (source, patched summary) pair with the HHEM classifier.
        pair = [(source, summary)]
        hall_score = hhem.predict(pair)[0]
        hall_score_value = hall_score.item()
        hall_score_list.append(hall_score_value)
        print(f"Hallucination score: {hall_score_value}")

        hall_score_pairs.append(
            {
                "before": normal_hall_score,
                "after": hall_score_value,
                "source": source,
                "summary": summary,
                "summary_normal": summary_normal,
            }
        )

#%%

# Collect the per-sample records and add the change in score caused by the
# suppression (after - before).
hall_score_pairs_df = pd.DataFrame(hall_score_pairs)
hall_score_pairs_df["diff"] = hall_score_pairs_df["after"].sub(hall_score_pairs_df["before"])

import matplotlib.pyplot as plt

# Overlaid histograms of the scores before and after suppression.
for column, label in (("before", "Before"), ("after", "After")):
    plt.hist(hall_score_pairs_df[column], bins=10, alpha=0.5, label=label)
plt.xlabel("Hallucination Score")
plt.ylabel("Count")
plt.title("Hallucination Score Before and After Suppression")
plt.legend()
plt.show()

# Histogram of the per-sample score difference.
plt.hist(hall_score_pairs_df["diff"], bins=10, alpha=0.5, label="Diff")
plt.xlabel("Hallucination Score Difference")
plt.ylabel("Count")
plt.title("Hallucination Score Difference")
plt.legend()
plt.show()

#%%
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import font_manager
import numpy as np

# Publication figure: before/after histograms (plotted as 1 - score) with
# dashed lines at the two means and an arrow annotated with their difference.

# Set font as DejaVu Sans
mpl.rcParams['font.family'] = 'DejaVu Sans'
mpl.rcParams['font.sans-serif'] = 'DejaVu Sans'

# Select data: samples whose baseline score is below the baseline median,
# with scores flipped to 1 - score for plotting.
subset = hall_score_pairs_df[hall_score_pairs_df["before"] < hall_score_pairs_df["before"].median()]
before_scores = 1 - subset["before"]
after_scores = 1 - subset["after"]

# Shared bins so both histograms are directly comparable
all_scores = pd.concat([before_scores, after_scores])
bins = np.histogram_bin_edges(all_scores, bins=10)

# One colour per condition, reused for the bars and the mean lines
before_color = plt.get_cmap("Greys")(0.6)
after_color = plt.get_cmap("Blues")(0.6)

# Plot
fig, ax = plt.subplots()
fig.set_size_inches(6, 4)
ax.hist(before_scores, bins=bins, alpha=0.7, color=before_color, label="Before")
ax.hist(after_scores, bins=bins, alpha=0.7, color=after_color, label="After")

# Clean spines
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

# Labels
ax.set_xlabel("Gemma 2B-IT Hallucination Score", size=11*1.3)
ax.set_ylabel("Number of Samples", size=11*1.3)

# Legend is built here, before the mean lines are added, so it lists only the
# two histograms.
legend = ax.legend(title="Suppression", loc='upper left', title_fontsize='13', fontsize='11', frameon=False)

# Vertical lines at the means (labels previously said "Median", but the
# plotted values are means).
mean_before = 1 - subset["before"].mean()
mean_after = 1 - subset["after"].mean()
ax.axvline(mean_before, color=before_color, linestyle='--', label='Mean Before')
ax.axvline(mean_after, color=after_color, linestyle='--', label='Mean After')

# Calculate histogram heights for accurate y position
hist_vals_before, _ = np.histogram(before_scores, bins=bins)
hist_vals_after, _ = np.histogram(after_scores, bins=bins)

# Get bar tops near each mean
bin_centers = (bins[:-1] + bins[1:]) / 2
before_idx = np.argmin(np.abs(bin_centers - mean_before))
after_idx = np.argmin(np.abs(bin_centers - mean_after))
arrow_y = max(hist_vals_before[before_idx], hist_vals_after[after_idx])

# Arrow height in data coordinates: one count above the taller nearby bar
ax_y = arrow_y + 1

# Arrow between means
arrow_props = dict(arrowstyle="<->", color="black", lw=1.2)
ax.annotate(
    "",
    xy=(mean_before, ax_y),
    xytext=(mean_after, ax_y),
    arrowprops=arrow_props
)

# Set tick labels to size 13
ax.tick_params(axis='both', which='major', labelsize=13)
ax.tick_params(axis='both', which='minor', labelsize=13)

# Annotate the mean difference just above the arrow. The sign is flipped on
# display, i.e. the text shows mean_before - mean_after.
mean_diff = mean_after - mean_before
ax.text(
    (mean_before + mean_after) / 2,
    ax_y + 1,  # a bit above the arrow
    f"Δ = {-mean_diff:.2f}",
    ha="center",
    va="bottom",
    fontsize=12
)

plt.tight_layout()

# Save the figure as SVG before showing it, so the file is written regardless
# of what the backend does with the figure once the window is closed. Create
# the target directory first, mirroring how output_dir is handled above.
os.makedirs(constants.VIZ_DIR, exist_ok=True)
fig.savefig(os.path.join(constants.VIZ_DIR, "hall_score_diff.svg"), format='svg', dpi=300)

plt.show()

#%%
