import torch
from tqdm import tqdm
import numpy as np
from transformer_lens.utils import sample_logits
import logging
import functools
from stitching.losses import get_ignore_mask, get_all_special_tokens

class BaseSAE(torch.nn.Module):
    """Minimal sparse autoencoder built from explicit weight and bias tensors.

    Encoding computes ``activation_fn(x' @ W_enc + b_enc)`` where ``x'`` is
    ``x - b_dec`` when ``apply_b_dec`` is true and ``x`` otherwise; decoding
    computes ``z @ W_dec + b_dec``.  The indexing used below implies ``W_enc``
    is ``(d_in, d_sae)``, ``W_dec`` is ``(d_sae, d_out)`` and ``b_enc`` has
    ``d_sae`` entries.

    Every method taking ``ablate_subset`` accepts ``None`` (the default), a
    list/tuple of latent indices, or anything NumPy accepts as an index into a
    1-D boolean array.  Ablated latents are removed entirely, so encoder
    outputs have ``d_sae - n_ablated`` columns and ``decode`` expects inputs
    of that reduced width.
    """

    def __init__(self, W_enc, W_dec, b_enc, b_dec, activation_fn, requires_grad = False, apply_b_dec=True):
        """Wrap the given tensors as parameters.

        Args:
            W_enc: Encoder weight tensor.
            W_dec: Decoder weight tensor.
            b_enc: Encoder bias; its length defines ``d_sae``.
            b_dec: Decoder bias.
            activation_fn: Callable applied to the encoder pre-activations.
            requires_grad: Whether the parameters should track gradients.
            apply_b_dec: If true, ``b_dec`` is subtracted from the input
                before encoding.
        """
        super().__init__()
        self.W_enc = torch.nn.Parameter(W_enc, requires_grad=requires_grad)
        self.W_dec = torch.nn.Parameter(W_dec, requires_grad=requires_grad)
        self.b_enc = torch.nn.Parameter(b_enc, requires_grad=requires_grad)
        self.b_dec = torch.nn.Parameter(b_dec, requires_grad=requires_grad)
        self.d_sae = len(b_enc)
        self.activation_fn = activation_fn
        self.apply_b_dec = apply_b_dec
        self.requires_grad = requires_grad
        self.dtype = self.W_enc.dtype

    def from_sae(self, sae):
        """Re-initialise this module in place from another SAE-like object.

        ``sae`` must expose ``W_enc``, ``W_dec``, ``b_enc``, ``b_dec``,
        ``activation_fn`` and ``cfg.apply_b_dec_to_input``.  The weights are
        detached, cloned and stored as frozen parameters.
        """
        # Re-running Module.__init__ lets this work on an instance created
        # with ``BaseSAE.__new__`` as well as on an already-built module.
        super().__init__()
        self.W_enc = torch.nn.Parameter(sae.W_enc.detach().clone(), requires_grad=False)
        self.W_dec = torch.nn.Parameter(sae.W_dec.detach().clone(), requires_grad=False)
        self.b_enc = torch.nn.Parameter(sae.b_enc.detach().clone(), requires_grad=False)
        self.b_dec = torch.nn.Parameter(sae.b_dec.detach().clone(), requires_grad=False)
        self.d_sae = len(self.b_enc)
        self.activation_fn = sae.activation_fn
        self.apply_b_dec = sae.cfg.apply_b_dec_to_input
        self.requires_grad = False
        # Keep ``dtype`` in sync with the copied weights, exactly as __init__
        # does; code elsewhere in this file reads ``sae.dtype``.
        self.dtype = self.W_enc.dtype

    def _keep_mask(self, ablate_subset):
        """Return a boolean keep-mask over latents, or None if nothing is ablated."""
        if ablate_subset is None:
            return None
        if isinstance(ablate_subset, (list, tuple)):
            if len(ablate_subset) == 0:
                return None
            # NumPy reads a tuple as a multi-axis index, not a list of
            # positions, so normalise to a list before indexing.
            ablate_subset = list(ablate_subset)
        mask = np.ones(self.d_sae, dtype=bool)
        mask[ablate_subset] = False
        return mask

    def forward(self, x, ablate_subset=None):
        """Encode then decode ``x``, skipping the latents in ``ablate_subset``."""
        return self.decode(self.encode(x, ablate_subset=ablate_subset), ablate_subset=ablate_subset)

    def get_preacts(self, x, ablate_subset=None):
        """Return encoder pre-activations for the non-ablated latents."""
        keep = self._keep_mask(ablate_subset)
        if self.apply_b_dec:
            x = (x - self.b_dec)
        if keep is None:
            # No ablation: use the parameters directly instead of copying
            # them through an all-True mask on every call.
            return x @ self.W_enc + self.b_enc
        return x @ self.W_enc[:, keep] + self.b_enc[keep]

    def encode(self, x, ablate_subset=None):
        """Return ``activation_fn`` applied to the pre-activations of ``x``."""
        return self.activation_fn(self.get_preacts(x, ablate_subset=ablate_subset))

    def decode(self, x, ablate_subset=None):
        """Map latent activations back to the output space.

        ``x`` must only contain columns for the non-ablated latents.
        """
        keep = self._keep_mask(ablate_subset)
        if keep is None:
            return x @ self.W_dec + self.b_dec
        return x @ self.W_dec[keep, :] + self.b_dec

    def expanded_decode(self, x):
        """Return each latent's separate contribution (``b_dec`` is not added)."""
        x = x.unsqueeze(-1) # (batch, d_sae, 1)
        return x * self.W_dec # (batch, d_sae, d_out)

    def normalize_decoder_vectors(self) -> None:
        """Rescale so every decoder row has unit L2 norm.

        The encoder columns and encoder bias are multiplied by the former
        row norms to compensate.  For a positively homogeneous elementwise
        activation such as ReLU this leaves the reconstruction unchanged;
        other activations (e.g. top-k or fixed thresholds) may behave
        differently after rescaling.
        """
        norms = torch.linalg.norm(self.W_dec, axis=-1, keepdims=True).detach()
        self.W_dec.data = self.W_dec.data.detach() / norms
        self.W_enc.data = self.W_enc.data.detach() * (norms.flatten())
        self.b_enc.data = self.b_enc.data * (norms.flatten())

    def get_rid_of_decoder_sub(self) -> None:
        """Fold the ``x - b_dec`` input shift into the encoder bias.

        Uses ``(x - b_dec) @ W_enc + b_enc == x @ W_enc + (b_enc - b_dec @ W_enc)``
        and then clears ``apply_b_dec``.  Logs a warning and does nothing if
        ``apply_b_dec`` is already false.
        """
        # only do this if previously the apply_b_dec was True
        if not(self.apply_b_dec):
            logging.warning("Can only get rid of decoder sub if apply_b_dec is True. Have you already called this once?\nReturning and doing nothing.")
            return
        self.b_enc = torch.nn.Parameter(self.b_enc.data - self.b_dec.data @ self.W_enc.data, requires_grad=self.requires_grad)
        self.apply_b_dec = False

def topk_activation(latents, k):
    """Keep the ``k`` largest rectified entries along the last dim, zero the rest.

    The input is passed through ReLU first, so negative entries never
    survive; if fewer than ``k`` entries are positive the remaining picks are
    zeros and contribute nothing.
    """
    rectified = torch.nn.functional.relu(latents)
    top_values, top_indices = torch.topk(rectified, k, sorted=False)
    # Write the surviving values into an all-zero tensor shaped like the input.
    return torch.zeros_like(latents).scatter_(-1, top_indices, top_values)
def jumprelu_activation(latents, thresholds):
    """Rectify ``latents`` and zero every entry not strictly above ``thresholds``.

    ``thresholds`` may be anything that broadcasts against ``latents`` in the
    ``>`` comparison.
    """
    above_threshold = latents > thresholds
    rectified = torch.nn.functional.relu(latents)
    return rectified * above_threshold

def convert_eleuther_sae_to_BaseSAE(sae):
    """Build a frozen ``BaseSAE`` from an SAE exposing ``encoder``, ``W_dec``, ``b_dec`` and ``cfg.k``.

    The encoder weight is transposed before being copied, and the activation
    is a top-k with ``k`` taken from ``sae.cfg.k``.  All tensors are detached
    clones, so the source SAE is left untouched.
    """
    encoder_weight = sae.encoder.weight.T.detach().clone()
    decoder_weight = sae.W_dec.detach().clone()
    encoder_bias = sae.encoder.bias.detach().clone()
    decoder_bias = sae.b_dec.detach().clone()
    topk_fn = functools.partial(topk_activation, k=sae.cfg.k)
    return BaseSAE(
        encoder_weight,
        decoder_weight,
        encoder_bias,
        decoder_bias,
        activation_fn=topk_fn,
        apply_b_dec=True,
    )

@torch.inference_mode()
def forward_modified(x, model, layer, steer, omega, preserve_norm = False, method='steer',sae=None):
    """Run ``model`` on ``x`` while editing the activations fed into ``layer``.

    The model is run up to (not including) ``layer``, the resulting
    activations are edited according to ``method``, and the model is resumed
    from ``layer``.

    Methods:
        ``'steer'``: add ``omega * steer``.
        ``'clamp'``: remove the component along ``steer`` (the code treats
            it as a unit vector) and then add ``omega * steer``.
        ``'sae_clamp'``: encode with ``sae``, set latent index ``steer`` to
            ``omega`` and decode, adding back the SAE's reconstruction error.
        Any other value leaves the activations untouched.

    If ``preserve_norm`` is true the edited activations are rescaled so each
    position keeps its original L2 norm.

    Raises:
        ValueError: for ``'sae_clamp'`` when ``sae`` is None or ``steer`` is
            not an ``int``.
    """
    # applies right before layer
    resid = model.forward(x, stop_at_layer=layer)

    norms_before = None
    if preserve_norm:
        norms_before = torch.linalg.norm(resid, dim=-1, keepdim=True)

    if method == 'steer':
        resid += omega * steer
    elif method == 'clamp':
        # Coefficient of every position along the steering direction.
        along_steer = resid @ steer.reshape(-1, 1)
        resid -= along_steer * steer
        resid += omega * steer
    elif method == 'sae_clamp':
        if sae is None:
            raise ValueError("Sae must not be None to use method sae_clamp")
        if not isinstance(steer, int):
            raise ValueError("Steer should be the feature index when using method sae_clamp")

        latents = sae.encode(resid)
        # What the SAE fails to reconstruct; added back after the edit.
        reconstruction_error = resid - sae.decode(latents)
        latents[..., steer] = omega
        resid = reconstruction_error + sae.decode(latents)

    if preserve_norm:
        norms_after = torch.linalg.norm(resid, dim=-1, keepdim=True)
        resid *= norms_before / norms_after
    return model.forward(resid, start_at_layer=layer)

@torch.inference_mode()
def generate_modified(
    x,
    model,
    layer,
    steer,
    omega,
    max_new_tokens,
    method=None,
    stop_at_eos=True,
    temperature=1.0,
    preserve_norm=False,
    do_sample=True,
    verbose=True,
    top_k = None,
    top_p = None,
    return_type='str',
    stop_tokens = None,
    sae=None
):
    """Autoregressively generate tokens, optionally with modified activations.

    Args:
        x: Prompt string (tokenised with ``model.to_tokens`` and a BOS token)
            or an already tokenised input, which is used as-is.
            # assumes a (batch, seq) token tensor in the latter case - TODO confirm
        model: Model used for the forward passes and tokenisation.
        layer, steer, omega, preserve_norm, sae: Forwarded to
            ``forward_modified`` when ``method`` is not None.
        max_new_tokens: Maximum number of tokens to append.
        method: ``None`` for an unmodified forward pass, otherwise the
            ``forward_modified`` method name.
        stop_at_eos: Currently unused; stopping is governed by ``stop_tokens``.
        temperature, top_k, top_p: Sampling settings passed to ``sample_logits``.
        do_sample: Sample if true, otherwise take the argmax.
        verbose: Show a progress bar.
        return_type: ``'str'`` to decode the first sequence; anything else
            returns the token tensor.  Forced to ``'tokens'`` when ``x`` is
            not a string.
        stop_tokens: Token id(s) that end generation; defaults to the
            tokenizer's EOS id.

    Returns:
        The decoded first sequence, or the full token tensor.

    Note:
        Only the first sequence in the batch is checked against
        ``stop_tokens``; the stop token itself is not appended.
    """
    if isinstance(x, str):
        tokens = model.to_tokens(x, prepend_bos=True)
    else:
        # Previously ``tokens`` was never bound on this path, so any
        # non-string prompt failed with a NameError.
        tokens = x
        return_type = 'tokens'
    if stop_tokens is None:
        stop_tokens = model.tokenizer.eos_token_id
    # Compare on the device the sampled tokens live on; also accepts plain
    # ints or lists of ids.
    stop_tokens = torch.as_tensor(stop_tokens, device=tokens.device)

    for _ in tqdm(range(max_new_tokens), disable=not verbose):
        if method is None:
            logits = model(tokens)
        else:
            logits = forward_modified(tokens, model, layer, steer, omega, preserve_norm, method=method, sae=sae)
        final_logits = logits[:, -1, :]
        if do_sample:
            sampled_tokens = sample_logits(
                final_logits,
                top_k=top_k,
                top_p=top_p,
                temperature=temperature,
                freq_penalty=0.0,
                tokens=tokens,
            )
        else:
            sampled_tokens = final_logits.argmax(-1)

        finished_batches = torch.isin(sampled_tokens, stop_tokens)
        if finished_batches[0]:
            break
        tokens = torch.cat([tokens, sampled_tokens.unsqueeze(-1)], dim=-1)

    if return_type == 'str':
        return model.tokenizer.decode(tokens[0], skip_special_tokens=True)
    else:
        return tokens

@torch.inference_mode()
def gemma_generate_with_hooks(
    model,
    toks,
    max_tokens_generated: int = 64,
    fwd_hooks = None,
    verbose: bool = False,
    return_decoded=True
):
    """Greedily generate up to ``max_tokens_generated`` tokens with hooks active.

    Args:
        model: Model exposing ``hooks(fwd_hooks=...)``, ``__call__`` and
            ``tokenizer``.
        toks: Prompt tokens, indexed as ``(batch, seq)``.
        max_tokens_generated: Upper bound on newly generated tokens.
        fwd_hooks: Forward hooks applied on every step (``None`` means none).
        verbose: Show a progress bar.
        return_decoded: If true return the decoded continuations (prompt
            excluded); otherwise return prompt plus continuation tokens.

    Note:
        Only the first sequence in the batch is inspected for the stop
        conditions; when it stops, generation stops for the whole batch.
    """
    if fwd_hooks is None:
        fwd_hooks = []
    batch_size, prompt_len = toks.shape[0], toks.shape[1]
    all_toks = torch.zeros((batch_size, prompt_len + max_tokens_generated), dtype=torch.long, device=toks.device)
    all_toks[:, :prompt_len] = toks
    steps = range(max_tokens_generated)
    p_bar = tqdm(steps) if verbose else steps
    # Number of tokens actually written.  The loop index cannot be used for
    # the final truncation: it is one short when the loop runs to completion
    # and undefined when ``max_tokens_generated`` is 0.
    n_generated = 0
    for step in p_bar:
        cur_len = prompt_len + step
        with model.hooks(fwd_hooks=fwd_hooks):
            logits = model(all_toks[:, :cur_len])
        next_tokens = logits[:, -1, :].argmax(dim=-1) # greedy decoding
        first = next_tokens[0]
        # 32007 is presumably an additional end-of-turn token id for a
        # specific tokenizer - TODO confirm.
        if first == model.tokenizer.eos_token_id or first == 32007:
            break
        # Presumably detects a generated "Q" followed by ":" (the start of a
        # new question) - TODO confirm the ids.  The previous token is
        # blanked to 0 before stopping.
        if cur_len > 0 and first == 235292 and all_toks[0, cur_len - 1] == 235368:
            all_toks[0, cur_len - 1] = 0
            break
        all_toks[:, cur_len] = next_tokens
        n_generated = step + 1
    # truncate the tensor to remove padding
    all_toks = all_toks[:, :prompt_len + n_generated]
    if return_decoded:
        return model.tokenizer.batch_decode(all_toks[:, prompt_len:], skip_special_tokens=True)
    else:
        return all_toks

@torch.inference_mode()
def feature_activations(dataloader, model, layer, sae, feature_subset, preacts=False, return_next_tokens=False):
    """Collect SAE latent values for ``feature_subset`` over a token dataset.

    For every batch the model is run up to ``layer`` on all but the last
    token, and the SAE is applied to the positions that are kept.  A position
    is dropped when ``get_ignore_mask`` flags either its own token or the
    following one.

    Args:
        dataloader: Iterable of token batches, indexed as ``(batch, seq)``.
        model: Model supporting ``stop_at_layer`` / ``start_at_layer``.
        layer: Layer whose input activations are fed to the SAE.
        sae: Object exposing ``encode`` and ``get_preacts``.
        feature_subset: Index applied to the last (latent) dimension.
        preacts: Return pre-activations instead of post-activation values.
        return_next_tokens: Also return the true next tokens and the model's
            argmax predictions at the kept positions.

    Returns:
        A NumPy array of activations, or the tuple
        ``(activations, true_next_tokens, predicted_next_tokens)`` when
        ``return_next_tokens`` is true.
    """
    device = next(model.parameters()).device
    spec_tokens = get_all_special_tokens(model.tokenizer)
    read_latents = sae.get_preacts if preacts else sae.encode
    cached_activations = []
    all_next_tokens = []
    pred_next_tokens = []
    # Wrap the dataloader itself (not ``enumerate``) so tqdm can read its
    # length and display a total; the batch index was never used.
    for sample in tqdm(dataloader):
        sample = sample.to(device)
        cur_tokens = sample[:, :-1]
        next_tokens = sample[:, 1:]
        resid = model(cur_tokens, stop_at_layer=layer) # (b, nseq-1, d)
        ignore_mask = get_ignore_mask(sample, spec_tokens) # (b, nseq)
        # (b, nseq-1): drop a position if this token or the next one is flagged
        ignore_mask = torch.logical_or(ignore_mask[:, 1:], ignore_mask[:, :-1])
        keep = ~ignore_mask
        cached_activations.append(read_latents(resid[keep])[..., feature_subset].cpu().numpy())
        if return_next_tokens:
            # Finish the forward pass from the same activations to obtain the
            # model's own next-token prediction.
            pred_next_tokens.append(model(resid, start_at_layer=layer)[keep].argmax(dim=-1).cpu().numpy())
            all_next_tokens.append(next_tokens[keep].cpu().numpy())

    if return_next_tokens:
        return np.concatenate(cached_activations, axis=0), np.concatenate(all_next_tokens), np.concatenate(pred_next_tokens)
    else:
        return np.concatenate(cached_activations, axis=0)

@torch.inference_mode()
def get_densities(dataloader, model, layer, sae, ctx_size=128):
    """Estimate how often each SAE latent fires.

    Each batch is truncated to the first ``ctx_size`` positions, run through
    the model up to ``layer``, cast to ``sae.dtype`` and encoded.  Positions
    flagged by ``get_ignore_mask`` are excluded.

    Returns:
        A CPU tensor with one entry per latent: the fraction of kept
        positions where the latent's activation is strictly positive.
        Entries are NaN if no position was kept.
    """
    device = next(model.parameters()).device
    spec_tokens = get_all_special_tokens(model.tokenizer)
    # Firing counts are integers; keeping them in int64 stays exact, whereas
    # a float32 accumulator stops representing every integer above 2**24.
    fire_counts = torch.zeros(sae.W_enc.shape[1], dtype=torch.long, device=device)
    total_toks = 0
    # Wrap the dataloader itself (not ``enumerate``) so tqdm can show a total.
    for sample in tqdm(dataloader):
        sample = sample.to(device)[..., :ctx_size]
        resid = model(sample, stop_at_layer=layer).to(sae.dtype)
        acts = sae.encode(resid)
        ignore_mask = get_ignore_mask(sample, spec_tokens)
        acts = acts[~ignore_mask]
        fire_counts += (acts > 0).sum(dim=0)
        total_toks += acts.shape[0]
    # True division of an integer tensor yields a floating-point result.
    return (fire_counts / total_toks).cpu()


# NOTE: the helper functions below this point are not really used.
def precision_recall_f1(activations_true, activations_pred, axis=0):
    """Score predicted activations against reference ones, treating ``> 0`` as "on".

    Counts are reduced along ``axis``.  Any ratio whose denominator is zero
    is reported as 0 rather than NaN.

    Returns:
        ``(precision, recall, f1)`` as float arrays.
    """
    true_on = activations_true > 0
    pred_on = activations_pred > 0

    hits = (true_on & pred_on).sum(axis=axis)
    n_predicted = pred_on.sum(axis=axis)
    n_actual = true_on.sum(axis=axis)

    def _ratio_or_zero(numerator, denominator):
        # Divide only where the denominator is positive; elsewhere keep 0.
        return np.divide(
            numerator,
            denominator,
            out=np.zeros_like(hits, dtype=float),
            where=denominator > 0,
        )

    precision = _ratio_or_zero(hits, n_predicted)
    recall = _ratio_or_zero(hits, n_actual)
    f1 = _ratio_or_zero(2 * precision * recall, precision + recall)
    return precision, recall, f1

@torch.inference_mode()
def identify_dead_indices(dataloader, model, layer, W_enc, b_enc):
    """Find which encoder latents ever get a positive pre-activation.

    NOTE: despite the name, the returned mask is True for latents that DID
    activate on at least one position; dead latents are the False entries
    (i.e. ``~result``).

    Args:
        dataloader: Iterable of batches accepted by ``model``.
        model: Model supporting ``stop_at_layer``.
        layer: Layer whose input activations are projected.
        W_enc: Encoder weight; batches are moved to its device.
        b_enc: Encoder bias; its shape defines the shape of the result.

    Returns:
        A CPU boolean tensor shaped like ``b_enc``.
    """
    ever_active = torch.zeros(b_enc.shape, dtype=torch.bool, device='cpu')
    n_latents = b_enc.shape[-1]
    for sample in tqdm(dataloader):
        sample = sample.to(W_enc.device)
        resid = model(sample, stop_at_layer=layer)
        # Flatten all leading dimensions so every position is one row.
        preacts = (resid @ W_enc + b_enc).reshape(-1, n_latents)
        # ``any`` is the direct boolean reduction (the original took a max
        # over a bool tensor) and is also defined for zero rows.
        fired = (preacts > 0).any(dim=0).cpu()
        ever_active = ever_active | fired
    return ever_active

@torch.inference_mode()
def max_csim_transfer_to_orig(orig_sae, transferred_sae, attr = 'decoder'):
    """For each transferred feature, return its best cosine similarity to any original feature.

    Args:
        orig_sae: SAE providing the reference feature vectors.
        transferred_sae: SAE whose features are matched against the reference.
        attr: ``'decoder'`` compares rows of ``W_dec``; ``'encoder'``
            compares rows of ``W_enc.mT``.

    Returns:
        A NumPy array with one maximum cosine similarity per transferred
        feature.

    Raises:
        ValueError: if ``attr`` is neither ``'decoder'`` nor ``'encoder'``.
    """
    if attr == 'decoder':
        transfer_mat = transferred_sae.W_dec
        orig_mat = orig_sae.W_dec
    elif attr == 'encoder':
        transfer_mat = transferred_sae.W_enc.mT
        orig_mat = orig_sae.W_enc.mT
    else:
        raise ValueError(f"'attr' cannot be {attr}, must be one of 'decoder' or 'encoder'")
    # The reference norms do not depend on the loop variable; compute once.
    orig_norms = orig_mat.norm(dim=-1)
    maxes = []
    for i in tqdm(range(transfer_mat.shape[0])):
        transferred_feature_vec = transfer_mat[i]
        cos_sims = (orig_mat @ transferred_feature_vec) / orig_norms / transferred_feature_vec.norm()
        maxes.append(cos_sims.cpu().numpy().max())
    return np.array(maxes)

@torch.inference_mode()
def argmax_csim_for_subset(orig_sae, transferred_sae, subset, attr = 'decoder'):
    """For each transferred feature in ``subset``, return the index of the most similar original feature.

    Args:
        orig_sae: SAE providing the reference feature vectors.
        transferred_sae: SAE whose features are matched against the reference.
        subset: Iterable of feature indices into the transferred SAE.
        attr: ``'decoder'`` compares rows of ``W_dec``; ``'encoder'``
            compares rows of ``W_enc.mT``.

    Returns:
        A NumPy array holding, for each entry of ``subset``, the index of
        the original feature with the highest cosine similarity.

    Raises:
        ValueError: if ``attr`` is neither ``'decoder'`` nor ``'encoder'``.
    """
    if attr == 'decoder':
        transfer_mat = transferred_sae.W_dec
        orig_mat = orig_sae.W_dec
    elif attr == 'encoder':
        transfer_mat = transferred_sae.W_enc.mT
        orig_mat = orig_sae.W_enc.mT
    else:
        raise ValueError(f"'attr' cannot be {attr}, must be one of 'decoder' or 'encoder'")
    # The reference norms do not depend on the loop variable; compute once.
    orig_norms = orig_mat.norm(dim=-1)
    best_indices = []
    for i in tqdm(subset):
        transferred_feature_vec = transfer_mat[i]
        cos_sims = (orig_mat @ transferred_feature_vec) / orig_norms / transferred_feature_vec.norm()
        best_indices.append(cos_sims.cpu().numpy().argmax())
    return np.array(best_indices)