# Hard-coded absolute path on the original author's machine.
# NOTE(review): presumably the project/data root; nothing in the visible code
# reads it -- confirm it is still needed before relying on it.
BASE_DIR = "/mnt/sda1/jengels/gemma_2b_sae_scaling"
import torch
from torch import nn
import os
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from matplotlib import cm
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.metrics import mean_squared_error, r2_score, accuracy_score
from sae_lens import SAE
import pickle

# Module-level SAE width constant.
# NOTE(review): presumably the number of latents of the "width_16k" SAEs; the
# functions visible in this file do not read it (the classes take their own
# `d_sae` argument) -- confirm against callers.
d_sae = 16384

def get_elephants_top_k(freqs,k=10):
    """Return the indices of the ``k`` largest entries of ``freqs``.

    ``freqs`` is sorted in descending order and the first ``k`` positions of
    the resulting index tensor are returned, so indices run from the highest
    value to the lowest.
    """
    ranking = torch.sort(freqs, descending=True).indices
    return ranking[:k]

def get_elephants_thres(freqs,thres=0.1):
    """Return the indices whose value in ``freqs`` is at least ``thres``.

    The returned index tensor is ordered by the corresponding value in
    ``freqs``, largest first. It is empty when nothing reaches the threshold.
    """
    candidates = torch.where(freqs >= thres)[0]
    order = torch.argsort(freqs[candidates], descending=True)
    return candidates[order]


def print_elephants_with_pairs(freqs, saes, thres, pair_sim=-0.85, pair_thres=0, to_print=False, flat=True):
    """Collect high-frequency features and, where present, their most opposed partner.

    Features with ``freqs >= thres`` are visited from most to least frequent.
    For each one, the row of ``saes`` with the lowest cosine similarity to it
    is looked up:

    * lowest similarity not below ``pair_sim``: the feature is kept on its own;
    * otherwise the two form a pair, which is kept only if the sum of their
      frequencies exceeds ``pair_thres`` and the pair was not recorded before
      (in either order). Pairs failing the frequency test are reported with
      a ``LOW`` suffix and are not added to the result.

    Args:
        freqs: per-feature frequencies, indexable by feature id.
        saes: matrix passed to ``cosine_similarity`` (one row per feature).
        thres: minimum frequency for a feature to be considered.
        pair_sim: similarity below which the closest-to-opposite row counts
            as a partner.
        pair_thres: minimum combined frequency for a pair to be kept.
        to_print: when true, print one line per visited feature.
        flat: when true, return a flat list of feature ids; otherwise a list
            of 1-tuples (unpaired) and 2-tuples (pairs).

    Returns:
        The list of kept features / pairs, in visiting order.
    """
    collected = []
    frequent = get_elephants_thres(freqs, thres)
    sims = cosine_similarity(saes, saes)
    seen_pairs = set()

    for entry in frequent:
        feat = entry.item()
        row = sims[feat]
        lowest = np.min(row)

        if not (lowest < pair_sim):
            # Nothing is opposed enough: keep the feature by itself.
            if to_print:
                print(f'#{feat}, f={freqs[feat]:.3f}, NO PAIR, min sim={lowest:.3f}')
            collected.append(feat if flat else (feat,))
            continue

        partner = np.argmin(row).item()
        if (partner, feat) in seen_pairs or (feat, partner) in seen_pairs:
            # Already handled when the other member of the pair was visited.
            continue

        f_feat = freqs[feat]
        f_partner = freqs[partner]
        combined = f_feat + f_partner

        if not (combined > pair_thres):
            if to_print:
                print(f'#{feat}, f={f_feat:.3f} and #{partner}, f={f_partner:.3f}. sim={lowest:.3f}, total f={combined:.3f} LOW')
            continue

        seen_pairs.add((feat, partner))
        if to_print:
            print(f'#{feat}, f={f_feat:.3f} and #{partner}, f={f_partner:.3f}. sim={lowest:.3f}, total f={combined:.3f}')
        if flat:
            collected.extend([feat, partner])
        else:
            collected.append((feat, partner))

    return collected

def load_sae_lens(sae_id, model_name, layer_type="res"):
    """Load a pretrained SAE through ``sae_lens`` onto the CPU.

    Args:
        sae_id: identifier of the SAE inside the release, passed through to
            ``SAE.from_pretrained``.
        model_name: model family; only ``"gemma_2_2b"`` is supported.
        layer_type: inserted into the release name
            ``gemma-scope-2b-pt-{layer_type}-canonical``.

    Returns:
        The first element of what ``SAE.from_pretrained`` returns.

    Raises:
        ValueError: if ``model_name`` is not supported.
    """
    # Reject unknown models up front, with a message that names the offender.
    if model_name != "gemma_2_2b":
        raise ValueError(
            f"Unsupported model_name {model_name!r}; only 'gemma_2_2b' is supported."
        )

    loaded = SAE.from_pretrained(
        release = f"gemma-scope-2b-pt-{layer_type}-canonical",
        sae_id = sae_id,
        device = "cpu",
    )
    return loaded[0]

def get_mydata(layer,freqs=False,saes=False,elephant_acts=False,proj_data=False,all_data=False):
    """Load the requested per-layer artifacts and return them as a list.

    Each flag that is true adds one item; items always appear in this order:
    ``freqs``, ``saes``, ``elephant_acts``, ``proj_data``, ``all_data``.
    With every flag false the result is an empty list.

    Args:
        layer: layer number, interpolated into the file names / SAE id.
        freqs: load ``gemmascope/layer{layer}_freqs.pt`` with ``torch.load``.
        saes: load the 16k canonical SAE for the layer and take its
            ``W_dec`` as a NumPy array.
        elephant_acts: unpickle ``gemmascope/layer{layer}_elephant_acts_0-5000.pkl``.
        proj_data: load ``gemmascope/temp/layer{layer}_proj.pt`` with ``torch.load``.
        all_data: unpickle ``gemmascope/layer{layer}_data_0-5000.pkl``.

    Note:
        Paths are relative to the current working directory. The pickle files
        are deserialized with ``pickle.load``, so they must be trusted.
    """
    loaded = []

    if freqs:
        loaded.append(torch.load(f'gemmascope/layer{layer}_freqs.pt',weights_only=True))

    if saes:
        sae = load_sae_lens(f'layer_{layer}/width_16k/canonical','gemma_2_2b')
        loaded.append(sae.W_dec.detach().numpy())

    if elephant_acts:
        with open(f'gemmascope/layer{layer}_elephant_acts_0-5000.pkl','rb') as handle:
            loaded.append(pickle.load(handle))

    if proj_data:
        loaded.append(torch.load(f'gemmascope/temp/layer{layer}_proj.pt',weights_only=True))

    if all_data:
        with open(f'gemmascope/layer{layer}_data_0-5000.pkl','rb') as handle:
            loaded.append(pickle.load(handle))

    return loaded


def load_pickle(path):
    """Return the object stored in the pickle file at ``path``.

    The file is opened in binary mode and closed before returning.
    Unpickling can execute arbitrary code, so only use this on trusted files.
    """
    with open(path, mode="rb") as stream:
        payload = pickle.load(stream)
    return payload

class JumpReLUSAE(nn.Module):
    """Sparse autoencoder whose latents pass through a per-latent threshold gate.

    A latent is kept (with its ReLU'd pre-activation value) only where the
    pre-activation is strictly greater than that latent's ``threshold``;
    everywhere else it is zero. All parameters are created zero-filled.
    """

    def __init__(self, d_model, d_sae):
        super().__init__()
        # Encoder maps d_model -> d_sae, decoder maps d_sae -> d_model.
        self.W_enc = nn.Parameter(torch.zeros(d_model, d_sae))
        self.W_dec = nn.Parameter(torch.zeros(d_sae, d_model))
        # One gate threshold per latent.
        self.threshold = nn.Parameter(torch.zeros(d_sae))
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.b_dec = nn.Parameter(torch.zeros(d_model))

    def encode(self, input_acts):
        """Map model activations to gated latent activations."""
        pre = torch.matmul(input_acts, self.W_enc) + self.b_enc
        gate = pre > self.threshold
        return gate * torch.relu(pre)

    def decode(self, acts):
        """Map latent activations back to model space."""
        return torch.matmul(acts, self.W_dec) + self.b_dec

    def forward(self, acts):
        """Encode then decode ``acts`` and return the reconstruction."""
        return self.decode(self.encode(acts))

class TopKSAE(nn.Module):
    """Encoder side of a top-k sparse autoencoder.

    Inputs are centred with ``b_dec``, projected with ``W_enc`` (stored as
    ``d_sae x d_model``), shifted by ``b_enc`` and passed through a ReLU.
    All parameters are created zero-filled. No decode/forward is defined here.
    """

    def __init__(self, d_model, d_sae, k):
        super().__init__()
        self.W_enc = nn.Parameter(torch.zeros(d_sae, d_model))
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.W_dec = nn.Parameter(torch.zeros(d_sae, d_model))
        self.b_dec = nn.Parameter(torch.zeros(d_model))
        # Number of latents kept by `encode`.
        self.k = k

    def pre_acts(self,x):
        """Return the ReLU'd encoder activations for the first entry of ``x``.

        Despite the name, the returned values are taken after the ReLU.
        The trailing ``[0]`` keeps only index 0 along the leading dimension.
        NOTE(review): presumably ``x`` is a batch holding a single item --
        confirm against callers.
        """
        centered = x - self.b_dec
        hidden = torch.matmul(centered, self.W_enc.T) + self.b_enc
        return torch.relu(hidden)[0]

    def encode(self, x):
        """Return ``topk(k, sorted=False)`` of :meth:`pre_acts` (values and indices)."""
        return self.pre_acts(x).topk(self.k,sorted=False)
 
def get_histogram(values,bins=50,density=False):
    """Histogram ``values`` and return ``(bin_values, bin_centers)``.

    ``bins`` and ``density`` are forwarded to ``np.histogram``. Each centre is
    the midpoint of the two edges of its bin, so both returned arrays have
    one entry per bin.
    """
    counts, edges = np.histogram(values, bins=bins, density=density)
    lower, upper = edges[:-1], edges[1:]
    midpoints = (upper + lower) / 2
    return counts, midpoints


def linear_probe(X,Y):
    """Fit a linear regression from ``X`` to ``Y`` and score it on held-out data.

    The data is split 80/20 with a fixed seed (``random_state=42``); the
    model is fitted on the larger part and R^2 is computed on the rest.

    Returns:
        ``(model, r2, coefs, intercept)`` where ``coefs`` and ``intercept``
        are the fitted model's ``coef_`` and ``intercept_``.
    """
    X_fit, X_held, Y_fit, Y_held = train_test_split(X,Y,test_size=0.2,random_state=42)
    regressor = LinearRegression()
    regressor.fit(X_fit,Y_fit)
    held_out_r2 = r2_score(Y_held,regressor.predict(X_held))
    return regressor, held_out_r2, regressor.coef_, regressor.intercept_

def linear_class_probe(X,Y):
    """Fit a logistic-regression classifier from ``X`` to ``Y`` and score it.

    The data is split 70/30 with a fixed seed (``random_state=42``); the
    classifier (``max_iter=1000``) is fitted on the larger part and accuracy
    is computed on the rest.

    Returns:
        ``(model, accuracy, coefs, intercept)`` where ``coefs`` and
        ``intercept`` are the fitted model's ``coef_`` and ``intercept_``.
    """
    X_fit, X_held, Y_fit, Y_held = train_test_split(X,Y,test_size=0.3,random_state=42)
    classifier = LogisticRegression(max_iter=1000)
    classifier.fit(X_fit,Y_fit)
    held_out_acc = accuracy_score(Y_held, classifier.predict(X_held))
    return classifier, held_out_acc, classifier.coef_, classifier.intercept_

def get_principal_angles(vectors1, vectors2):
    '''
    Principal angles between the row spaces of two n x d arrays.

    Each array is transposed and QR-factorised to get an orthonormal basis;
    the singular values of the product of the two bases are the cosines of
    the principal angles.

    Returns (angles_in_degrees, similarity), where similarity is the mean of
    the squared singular values.
    '''
    basis_a = np.linalg.qr(vectors1.T)[0]
    basis_b = np.linalg.qr(vectors2.T)[0]

    cosines = np.linalg.svd(basis_a.T @ basis_b, compute_uv=False)
    # Clip guards arccos against values a hair outside [-1, 1].
    angles_rad = np.arccos(np.clip(cosines, -1.0, 1.0))
    similarity = np.sum(cosines**2) / len(cosines)
    return np.rad2deg(angles_rad), similarity

def clean_strings(string_list, mode):
    """Strip space characters from every string, optionally lower-casing too.

    ``mode='space'`` removes all ``" "`` characters; ``mode='space&lower'``
    additionally lower-cases the result. Any other mode returns ``None``
    without touching ``string_list``.
    """
    lowercase = mode == 'space&lower'
    if not lowercase and mode != 'space':
        return None

    cleaned = []
    for text in string_list:
        squeezed = text.replace(" ", "")
        cleaned.append(squeezed.lower() if lowercase else squeezed)
    return cleaned

def get_closest_words(word,wvmodel,n=5):
    """Print the ``n`` words most similar to ``word`` according to ``wvmodel``.

    Args:
        word: query word.
        wvmodel: object exposing ``wvmodel.wv`` with membership testing and
            ``most_similar(word, topn=...)`` returning ``(word, similarity)``
            pairs.
        n: how many neighbours to request and announce in the header line.

    Prints a "not in the vocabulary" notice when ``word`` is unknown.
    Returns ``None``.
    """
    if word not in wvmodel.wv:
        print(f"'{word}' not in the vocabulary.")
        return

    similar_words = wvmodel.wv.most_similar(word,topn=n)
    # The header reports the requested count instead of a hard-coded 5.
    print(f"Top {n} words most similar to '{word}':")
    for w, sim in similar_words:
        print(f"  {w}: {sim:.4f}")

def hex_to_rgb(hex_color):
    """
    Turn a hex colour such as '#ADD8E6' into an (R, G, B) tuple of floats in [0, 1].

    Leading '#' characters are stripped; the first six remaining characters
    are read as three two-digit hexadecimal channels.
    """
    digits = hex_color.lstrip('#')
    return tuple(int(digits[start:start + 2], 16) / 255.0 for start in (0, 2, 4))

def rgb_to_hex(rgb):
    """
    Convert an (R,G,B) tuple in [0,1] to a hex string '#RRGGBB'.

    Each component is scaled by 255 and rounded to the nearest integer
    (rather than truncated), so a product that lands a hair below an integer
    because of floating-point error is not pushed down to the next lower
    level. The result is then clamped to 0..255 so that slightly out-of-range
    inputs still yield a well-formed six-digit colour.
    """
    levels = []
    for component in (rgb[0], rgb[1], rgb[2]):
        level = int(round(float(component) * 255))
        levels.append(min(255, max(0, level)))
    r, g, b = levels
    return f"#{r:02X}{g:02X}{b:02X}"

def blend_color(alpha, base_color="#FFFFFF", highlight_color="#ADD8E6"):
    """
    Mix two hex colours channel by channel and return the result as hex.

    alpha=0 gives base_color, alpha=1 gives highlight_color; values in
    between interpolate linearly.
    """
    base_rgb = hex_to_rgb(base_color)
    high_rgb = hex_to_rgb(highlight_color)
    mixed = tuple(
        (1 - alpha) * low + alpha * high
        for low, high in zip(base_rgb, high_rgb)
    )
    return rgb_to_hex(mixed)

def highlight_tokens_scaled_html(tokens, acts1, acts2=None,
                                color1="#0000FF",    # pure blue for Feature 1
                                color2="#FF0000",    # pure red for Feature 2
                                fixed_width=600):
    """Render tokens as an HTML page with per-token background highlighting.

    Each token becomes a ``<span>`` whose background blends from white towards
    a feature colour in proportion to that token's activation divided by the
    largest activation of the same feature (clamped to [0, 1]).

    Args:
        tokens: sequence of token strings.
        acts1: activations of feature 1, one per token.
        acts2: optional activations of feature 2, one per token. When given,
            a token is coloured with ``color1`` if its scaled feature-1
            activation is strictly larger than its scaled feature-2
            activation, and with ``color2`` otherwise.
        color1: hex colour for feature 1.
        color2: hex colour for feature 2.
        fixed_width: width of the text container in pixels.

    Returns:
        The complete HTML document as a single string. Empty input yields a
        document with an empty container.

    Raises:
        AssertionError: if the activations do not match ``tokens`` in length.
    """
    assert len(tokens) == len(acts1), "Tokens and acts1 must have the same length."
    if acts2 is not None:
        assert len(tokens) == len(acts2), "Tokens and acts2 must have the same length."

    # `default=0` keeps empty input from raising; the 1e-9 floor keeps the
    # divisions below safe when every activation is zero or negative.
    max_act1 = max(1e-9, max(acts1, default=0))
    max_act2 = max(1e-9, max(acts2, default=0)) if acts2 is not None else 1e-9

    style_block = f"""
    <style>
    /* Fix container width and ensure wrapping within it */
    .highlight-container {{
        width: {fixed_width}px;
        white-space: pre-wrap;  /* wraps text while preserving whitespace/newlines */
        font-family: 'Open Sans', sans-serif;
        font-size: 13px;
        border: 1px solid #CCC; /* optional border */
        padding: 10px;          /* optional padding */
        margin: 0 auto;         /* center horizontally, if you like */
        background-color: #FFFFFF;
    }}
    </style>
    """

    html_parts = [
        "<!DOCTYPE html>",
        "<html>",
        "<head>",
        '<meta charset="utf-8">',
        style_block,
        "</head>",
        "<body>",
        '<div class="highlight-container">',
    ]

    for i, token in enumerate(tokens):
        alpha1 = min(1.0, max(0.0, acts1[i] / max_act1))

        if acts2 is None:
            bg_color = blend_color(alpha1, "#FFFFFF", color1)
        else:
            alpha2 = min(1.0, max(0.0, acts2[i] / max_act2))
            # The stronger (scaled) feature decides the colour; ties go to feature 2.
            if alpha1 > alpha2:
                bg_color = blend_color(alpha1, "#FFFFFF", color1)
            else:
                bg_color = blend_color(alpha2, "#FFFFFF", color2)

        # Escape the characters that are significant inside element content.
        token_escaped = (token
                         .replace("&", "&amp;")
                         .replace("<", "&lt;")
                         .replace(">", "&gt;"))

        if token_escaped == "\n":
            # Show a visible "\n" marker for a bare newline token, then break the line.
            span_html = f'<span style="background-color: {bg_color}; color:black;">\\n</span><br/>'
        else:
            span_html = f'<span style="background-color: {bg_color}; color:black;">{token_escaped}</span>'

        html_parts.append(span_html)

    html_parts.append('</div>')
    html_parts.append("</body>")
    html_parts.append("</html>")

    return "".join(html_parts)

def scale_acts(act_vector):
    """Divide ``act_vector`` by its own maximum (``act_vector.max()``).

    No guard is applied: a maximum of zero leads to a division by zero.
    """
    peak = act_vector.max()
    return act_vector / peak

def correlation_matrix(data1, data2, binary=False, device=None):
    '''
    Pearson correlation between every row of data1 and every row of data2.

    data1 is (n1 x N) and data2 is (n2 x N): each ROW is one variable observed
    over the same N samples (rows are mean-centred and normalised along
    axis 1). The result is an (n1 x n2) NumPy array. A row with zero variance
    has zero norm, so its correlations come out as NaN.

    binary: if True, inputs are first replaced by the indicator (value > 0);
        this path uses ``.astype`` and so expects NumPy arrays.
    device: torch device used for the computation. When None, a module-level
        ``device`` is used if the module defines one, otherwise the CPU.
    '''
    if device is None:
        # Keeps working whether or not the surrounding module defines `device`.
        device = globals().get("device", "cpu")

    if binary:
        data1 = (data1>0).astype(int)
        data2 = (data2>0).astype(int)

    rows1 = torch.as_tensor(data1,device=device,dtype=torch.float32)
    rows2 = torch.as_tensor(data2,device=device,dtype=torch.float32)

    rows1 = rows1 - torch.mean(rows1, axis=1,keepdims=True)
    rows2 = rows2 - torch.mean(rows2, axis=1,keepdims=True)

    # Norms of the centred rows; their outer product normalises the dot products.
    norm1 = torch.linalg.norm(rows1, axis=1, keepdims=True)
    norm2 = torch.linalg.norm(rows2, axis=1, keepdims=True)

    corr_matrix = ((rows1 @ rows2.T) / (norm1 @ norm2.T))
    return corr_matrix.cpu().numpy()

def distance_to_previous_one(arr, context_length=1024):
    """For each position, count the steps since the most recent entry equal to 1.

    ``arr`` is treated as consecutive windows of ``context_length`` entries;
    the search never looks past the start of the current window. Positions
    holding a 1 get distance 0; positions with no earlier 1 in their window
    are NaN.

    Returns:
        A float array of length ``len(arr)``.
    """
    total = len(arr)
    distances = np.full(total, np.nan)
    steps_since = None  # None: no 1 seen yet in the current window

    for pos in range(total):
        if pos % context_length == 0:
            # A new window starts here; forget anything seen before it.
            steps_since = None

        if arr[pos] == 1:
            steps_since = 0
        elif steps_since is not None:
            steps_since += 1
        else:
            # Nothing to measure against: the pre-filled NaN stays.
            continue

        distances[pos] = steps_since

    return distances

def project_subspace(data,subspace_vecs):
    """
    Project the rows of data onto the span of the rows of subspace_vecs.

    subspace_vecs is n x d; data holds d-dimensional row vectors.
    An orthonormal basis is taken from the QR factorisation of
    subspace_vecs.T and turned into a d x d projection matrix.
    """
    basis = np.linalg.qr(subspace_vecs.T)[0]
    projector = basis @ basis.T
    return data @ projector

def vec_in_subspace_norm(vector,subspace_vecs):
    """Return the norm of ``vector``'s projection onto the subspace, divided by the norm of ``vector``.

    ``vector`` is reshaped to a single row before being projected with
    ``project_subspace``.
    """
    as_row = vector.reshape(1,-1)
    inside = project_subspace(as_row,subspace_vecs)
    return np.linalg.norm(inside) / np.linalg.norm(vector)

def vecs_in_subspace_norm(vectors,subspace_vecs):
    """Row-wise version of ``vec_in_subspace_norm``.

    Returns, for each row of ``vectors``, the norm of its projection onto the
    subspace divided by the norm of the row itself.
    """
    inside = project_subspace(vectors,subspace_vecs)
    inside_norms = np.linalg.norm(inside,axis=1)
    full_norms = np.linalg.norm(vectors,axis=1)
    return inside_norms / full_norms

def subspace_overlap(subspace_vecs1, subspace_vecs2):
    """Return the mean squared singular value between two subspaces' bases.

    Each argument's rows span a subspace; orthonormal bases come from the QR
    factorisation of the transposed inputs, and the singular values of the
    product of those bases are squared and averaged.
    """
    basis1 = np.linalg.qr(subspace_vecs1.T)[0]
    basis2 = np.linalg.qr(subspace_vecs2.T)[0]
    singular_values = np.linalg.svd(basis1.T @ basis2,compute_uv=False)
    return np.sum(singular_values**2) / len(singular_values)

def remove_bos_acts(act_vector,context_length=1024):
    """Zero every ``context_length``-th entry of ``act_vector``, starting at index 0.

    The input is modified in place and the same object is returned.
    NOTE(review): presumably these positions are the first (BOS) token of
    each context -- confirm against how the activations are laid out.
    """
    window_starts = slice(None, None, context_length)
    act_vector[window_starts] = 0
    return act_vector

def get_acts_contextj(act_vector,j,context_length=1024):
    """Return the ``j``-th window of ``context_length`` entries from ``act_vector``.

    Window ``j`` covers indices ``context_length * j`` up to (but excluding)
    ``context_length * (j + 1)``.
    """
    window = slice(context_length * j, context_length * (j+1))
    return act_vector[window]
