import os
import shutil

import jax
import jax.numpy as jnp
import numpy as np
from huggingface_hub import HfFileSystem
from safetensors.flax import load_file, save_file

# Module-level dict, empty at import time.
# NOTE(review): presumably memoizes loaded SAE parameter dicts; the code that
# reads/writes it is not visible in this part of the file -- confirm key scheme.
sae_cache = {}

...

def sae_encode(sae, vector, **kwargs):
    """Encode ``vector`` with ``sae`` and reconstruct it.

    The variant is picked from the keys present in ``sae``:

    * ``"threshold"`` -> ``sae_encode_threshold``
    * ``"s_gate"``    -> ``sae_encode_gated``
    * otherwise       -> plain ReLU encoder implemented here

    Args:
        sae: Mapping of parameter names to arrays. The plain path reads
            ``W_enc``, ``b_enc``, ``W_dec``, ``b_dec`` and, if present,
            ``scaling_factor``.
        vector: Input array; its last axis is contracted with ``W_enc``.
        **kwargs: Forwarded to the threshold/gated encoders. They are not
            used by the plain ReLU path.

    Returns:
        Tuple ``(pre_relu, post_relu, decoded)``.
    """
    # Specialised encoders take precedence over the plain ReLU path.
    if "threshold" in sae:
        return sae_encode_threshold(sae, vector, **kwargs)
    if "s_gate" in sae:
        return sae_encode_gated(sae, vector, **kwargs)

    pre_acts = vector @ sae["W_enc"] + sae["b_enc"]
    acts = jax.nn.relu(pre_acts)
    if "scaling_factor" in sae:
        # Optional per-feature rescaling of the activations before decoding.
        acts = acts * sae["scaling_factor"]

    reconstruction = acts @ sae["W_dec"] + sae["b_dec"]
    return pre_acts, acts, reconstruction

def resids_to_weights(vector, sae):
    """Map ``vector`` to gated feature weights.

    A feature is "on" when ``inputs @ W_enc + b_enc`` is strictly positive;
    its value is ``relu(inputs @ W_enc * softplus(s_gate) * scaling_factor
    + b_gate)``. Features that are not "on" are zeroed.

    Args:
        vector: Input array; its last axis is contracted with ``W_enc``.
        sae: Mapping that must provide ``W_enc``, ``b_enc``, ``s_gate``,
            ``scaling_factor`` and ``b_gate``. ``norm_factor`` is optional
            and, when present, multiplies the input first.

    Returns:
        Array of feature weights.
    """
    scaled = vector * sae["norm_factor"] if "norm_factor" in sae else vector

    # The encoder projection feeds both the on/off test and the magnitude.
    projected = scaled @ sae["W_enc"]

    is_active = jax.nn.relu(projected + sae["b_enc"]) > 0

    gate_scale = jax.nn.softplus(sae["s_gate"])
    magnitude = jax.nn.relu(
        projected * gate_scale * sae["scaling_factor"] + sae["b_gate"]
    )

    return is_active * magnitude

def weights_to_resid(weights, sae):
    """Decode feature weights back into a reconstruction.

    Negative weights are clipped to zero before decoding. The result is cast
    back to the dtype of the (clipped) weights.

    Args:
        weights: Array whose last axis indexes features (axis ``f`` of
            ``W_dec`` in the einsum below); leading axes are preserved.
        sae: Mapping providing ``W_dec`` and ``b_dec``; when
            ``out_norm_factor`` is present the reconstruction is divided by it.

    Returns:
        Reconstruction array with the same dtype as ``weights``.
    """
    activations = jax.nn.relu(weights)

    decoded = jnp.einsum("fv,...f->...v", sae["W_dec"], activations) + sae["b_dec"]
    if "out_norm_factor" in sae:
        decoded = decoded / sae["out_norm_factor"]

    # Parameters may be stored in a wider dtype; hand back the caller's dtype.
    return decoded.astype(activations.dtype)

def sae_encode_threshold(sae, vector, pre_relu=None, keep_features=None, ablate_to=0, post_relu=None):
    """Encode and reconstruct with a per-feature activation threshold.

    Args:
        sae: Mapping providing ``W_enc``, ``b_enc``, ``threshold``, ``W_dec``
            and ``b_dec``.
        vector: Input array; only used when neither ``pre_relu`` nor
            ``post_relu`` is supplied.
        pre_relu: Optional precomputed encoder pre-activations.
        keep_features: Optional multiplicative mask over features; entries
            where it is 0 are replaced by ``ablate_to``.
        ablate_to: Value written into masked-out features.
        post_relu: Optional precomputed activations; when given, the encoder
            and the threshold are skipped entirely.

    Returns:
        Tuple ``(pre_relu, post_relu, recon)``. ``pre_relu`` is ``None`` when
        ``post_relu`` was supplied without ``pre_relu``.
    """
    if post_relu is None:
        if pre_relu is None:
            pre_relu = vector @ sae["W_enc"] + sae["b_enc"]

        # A feature fires only if its pre-activation exceeds its threshold.
        fires = pre_relu > sae["threshold"]
        post_relu = jax.nn.relu(pre_relu) * fires

    if keep_features is not None:
        dropped = 1 - keep_features
        post_relu = post_relu * keep_features + ablate_to * dropped

    recon = post_relu @ sae["W_dec"] + sae["b_dec"]
    return pre_relu, post_relu, recon

def sae_encode_gated(sae, vector, ablate_features=None, keep_features=None, pre_relu=None, ablate_to=0):
    """Encode and reconstruct ``vector`` with the gated SAE parameters.

    The pre-activation is ``(inputs @ W_enc + b_enc) * s`` with
    ``s = softplus(s_gate) * scaling_factor``. After the ReLU, a feature is
    kept only if it strictly exceeds ``max(0, b_gate - b_enc * s)``.

    Args:
        sae: Mapping providing ``s_gate``, ``scaling_factor``, ``W_enc``,
            ``b_enc``, ``b_gate``, ``W_dec`` and ``b_dec``. ``norm_factor``
            (multiplies the input) and ``out_norm_factor`` (divides the
            reconstruction) are optional.
        vector: Input array; its last axis is contracted with ``W_enc``.
            Ignored when ``pre_relu`` is supplied.
        ablate_features: Optional integer index or index/boolean array naming
            features to zero after thresholding. It is applied along the last
            (feature) axis, so it acts on every example of a batched input.
        keep_features: Optional multiplicative mask over features, applied to
            the pre-activations; entries where it is 0 become ``ablate_to``.
        pre_relu: Optional precomputed (already scaled) pre-activations; when
            given, the encoder and ``norm_factor`` are skipped.
        ablate_to: Pre-activation value written into masked-out features.

    Returns:
        Tuple ``(pre_relu, post_relu, recon)``.
    """
    # Per-feature scale shared by the pre-activation and the threshold.
    s = jax.nn.softplus(sae["s_gate"]) * sae["scaling_factor"]

    if pre_relu is None:
        inputs = vector
        if "norm_factor" in sae:
            inputs = inputs * sae["norm_factor"]
        pre_relu = (inputs @ sae["W_enc"] + sae["b_enc"]) * s

    if keep_features is not None:
        pre_relu = pre_relu * keep_features + ablate_to * (1 - keep_features)

    post_relu = jax.nn.relu(pre_relu)
    threshold = jnp.maximum(0, sae["b_gate"] - sae["b_enc"] * s)
    post_relu = (post_relu > threshold) * post_relu

    if ablate_features is not None:
        # Index the trailing feature axis explicitly: a bare
        # ``.at[ablate_features]`` addresses the leading axis and would zero
        # whole examples of a batched input instead of the named features.
        post_relu = post_relu.at[..., ablate_features].set(0)

    recon = post_relu @ sae["W_dec"] + sae["b_dec"]
    if "out_norm_factor" in sae:
        recon = recon / sae["out_norm_factor"]

    return pre_relu, post_relu, recon

def name_bf16(fname):
    """Return ``fname`` with every ``.safetensors`` replaced by ``_bf16.safetensors``.

    Names that do not contain ``.safetensors`` are returned unchanged.
    """
    old_suffix = ".safetensors"
    new_suffix = "_bf16.safetensors"
    return fname.replace(old_suffix, new_suffix)

def convert_to_bf16(fname, fname_out):
    """Move ``fname`` to ``fname_out`` and rewrite it with bfloat16 tensors.

    The source file no longer exists at ``fname`` afterwards: it is moved
    first, then every tensor loaded from ``fname_out`` is cast to
    ``"bfloat16"`` and saved back over ``fname_out``.

    Args:
        fname: Path of the existing safetensors file.
        fname_out: Destination path for the converted file.
    """
    shutil.move(fname, fname_out)

    tensors = load_file(fname_out)
    # The full converted dict is built before the file is overwritten.
    converted = {name: tensor.astype("bfloat16") for name, tensor in tensors.items()}
    save_file(converted, fname_out)
