import torch
import seaborn as sns
import numpy as np
from tqdm import tqdm

@torch.inference_mode()
def get_activations_for_feature(dataloader, feature_idx, model, layer, sae):
    """Collect the SAE activations of one feature over every batch of a dataloader.

    Each batch is moved to the model's device, run through ``model`` with
    ``stop_at_layer=layer``, and the result is passed to ``sae.encode``.
    Positions whose token id is 0 are dropped, except the first position of
    every row, which is always kept.  (Presumably id 0 is the padding token
    and position 0 holds a BOS token that may share that id -- confirm
    against the tokenizer in use.)

    Args:
        dataloader: Iterable yielding token-id tensors with at least two
            dimensions (the mask is indexed as ``[:, 0]``).
        feature_idx: Index into the last axis of the SAE activations.
        model: Torch module called as ``model(tokens, stop_at_layer=layer)``.
        layer: Forwarded unchanged as ``stop_at_layer``.
        sae: Object exposing ``encode(tensor)``.

    Returns:
        Tuple ``(all_acts, all_toks)`` of NumPy arrays: the flattened feature
        activations and the kept token ids, concatenated over all batches.
        When ``feature_idx`` is a single integer the two arrays are aligned
        element by element.

    Raises:
        ValueError: Raised by ``np.concatenate`` when the dataloader yields
            no batches.
    """
    all_toks = []
    all_acts = []
    device = next(model.parameters()).device
    for sample in tqdm(dataloader):
        # Tensor.to() returns a new tensor instead of moving in place, so the
        # result has to be rebound or the batch stays on its original device.
        sample = sample.to(device)
        hidden = model(sample, stop_at_layer=layer)
        acts = sae.encode(hidden)

        # Drop id-0 tokens but always keep the first position of each row.
        ignore_mask = sample == 0
        ignore_mask[:, 0] = False
        keep_mask = ~ignore_mask

        feature_acts = acts[keep_mask][..., feature_idx].flatten()
        all_acts.append(feature_acts.cpu().numpy())
        # .numpy() only works on CPU tensors, so copy the tokens back first.
        all_toks.append(sample[keep_mask].cpu().numpy())
    all_acts = np.concatenate(all_acts)
    all_toks = np.concatenate(all_toks)
    return all_acts, all_toks

def highlight_tokens(tokens, activations, verbose=False):
    """Print tokens on one line, tinted red in proportion to their activation.

    Activations are min-max normalised over the given window and mapped to
    the red channel (0-255) of a 24-bit ANSI foreground colour.  Newlines
    and tabs inside tokens are replaced by visible glyphs so the output
    stays on a single line.

    Args:
        tokens: Sequence of token strings.
        activations: Sized sequence of numeric activations, paired with
            ``tokens`` by position (extra items on either side are ignored
            by ``zip``).
        verbose: When true, also print the raw token list and the
            normalised activations.

    Returns:
        None.  Prints "Feature dead on sample" and returns early when there
        are no activations or the maximum activation is 0.
    """
    # An empty window has nothing to normalise; report it like a dead feature
    # instead of letting min()/max() raise on an empty sequence.
    if len(activations) == 0:
        print("Feature dead on sample")
        return

    min_activation, max_activation = min(activations), max(activations)
    if max_activation == 0:
        print("Feature dead on sample")
        return

    span = max_activation - min_activation
    if span == 0:
        # Every activation equals the (non-zero) maximum: dividing by the
        # span would fail, so render all tokens at full intensity.
        normalized_activations = [1.0 for _ in activations]
    else:
        normalized_activations = [(a - min_activation) / span for a in activations]

    pieces = []
    for token, norm_activation in zip(tokens, normalized_activations):
        token = token.replace("\n", "↵").replace("\t", "→")
        # Scale the normalised activation to a 0-255 red channel value.
        color_intensity = int(255 * norm_activation)
        pieces.append(f"\033[38;2;{color_intensity};0;0m{token}\033[0m")
    print("".join(pieces))

    if verbose:
        print(tokens)
        print(normalized_activations)


def display_acts(toks, acts, model, k=10, ctx=10, upper_cap=None, display_density=False, verbose=False):
    """Print the strongest activations of a feature together with their context.

    Positions are visited from highest to lowest activation.  For each of the
    first ``k`` positions that pass the ``upper_cap`` filter, the activation
    value is printed followed by the tokens in ``[index - ctx, index + ctx)``
    (clipped to the array bounds), rendered through ``highlight_tokens``.

    Args:
        toks: Flat sequence of token ids, indexable by slice.
        acts: Flat array of activations positionally aligned with ``toks``;
            must support ``argsort`` and boolean-mask indexing (presumably a
            NumPy array -- confirm against the caller).
        model: Object exposing ``to_str_tokens`` to turn ids into strings.
        k: Maximum number of examples to print.
        ctx: Number of positions taken before the index; the window stops
            just before ``index + ctx``.
        upper_cap: When not ``None``, positions whose activation exceeds this
            value are skipped.
        display_density: When true, plot a histogram of the positive
            activations titled with the percentage of positions that fire.
        verbose: Forwarded to ``highlight_tokens``.
    """
    shown = 0
    # argsort is ascending, so reverse it to walk from the largest activation.
    for pos in acts.argsort()[::-1]:
        exceeds_cap = upper_cap is not None and acts[pos] > upper_cap
        if exceeds_cap:
            continue
        if shown >= k:
            break

        start = max(0, pos - ctx)
        stop = min(len(toks), pos + ctx)
        print(f"{acts[pos]:.2f}")

        window_tokens = model.to_str_tokens(toks[start:stop])
        highlight_tokens(window_tokens, acts[start:stop], verbose=verbose)
        shown += 1

    if not display_density:
        return
    ax = sns.histplot(acts[acts > 0])
    ax.set(title = f"density {((acts>0).sum() * 100/ len(toks)):.4f}%")

@torch.inference_mode()
def get_top_features(sample, model, layer, sae, k, measurement='max'):
    """Return the ``k`` SAE features that respond most strongly to ``sample``.

    The sample is moved to the model's device, run through ``model`` with
    ``stop_at_layer=layer`` and encoded with ``sae.encode``.  Positions whose
    token id is 0 are discarded before features are ranked.

    Args:
        sample: Token-id tensor.
        model: Torch module called as ``model(tokens, stop_at_layer=layer)``.
        layer: Forwarded unchanged as ``stop_at_layer``.
        sae: Object exposing ``encode(tensor)``.
        k: Number of top features to return.
        measurement: ``'max'`` ranks features by their maximum activation
            over the kept positions, ``'mean'`` by their mean activation.

    Returns:
        Tuple ``(feature_indices, activations, tokens)``:
        ``feature_indices`` is a NumPy array with the selected feature
        indices in descending score order, ``activations`` is a NumPy array
        with the activations of those features at every kept position, and
        ``tokens`` is a CPU tensor with the kept token ids.

    Raises:
        ValueError: If ``measurement`` is neither ``'max'`` nor ``'mean'``.
    """
    # Validate before the forward pass so a typo fails fast with a clear
    # message rather than as an unbound local after the model has run.
    if measurement not in ('max', 'mean'):
        raise ValueError(
            f"measurement must be 'max' or 'mean', got {measurement!r}"
        )

    device = next(model.parameters()).device
    sample = sample.to(device)
    hidden = model(sample, stop_at_layer=layer)
    acts = sae.encode(hidden)

    # NOTE(review): get_activations_for_feature keeps position 0 even when
    # its id is 0; this function drops it -- confirm that is intended.
    ignore_mask = sample == 0
    keep_mask = ~ignore_mask
    acts = acts[keep_mask]

    # Collapse the position axis into one score per feature.
    if measurement == 'max':
        feature_scores = acts.max(dim=0).values
    else:
        feature_scores = acts.mean(dim=0)

    top_features = feature_scores.argsort(dim=-1, descending=True)[..., :k]
    activations_over_text = acts[..., top_features]

    return top_features.cpu().numpy(), activations_over_text.cpu().numpy(), sample[keep_mask].cpu()