import os, sys

# NOTE(review): these statements run before `import einops` and appear meant to
# steer einops toward JAX arrays — TODO confirm that einops actually reads the
# EINOPS_BACKEND environment variable.
os.environ["EINOPS_BACKEND"] = "jax"
# Remove any already-imported tensorflow module from the import cache;
# presumably so einops does not pick tensorflow up as a backend — verify.
sys.modules.pop("tensorflow", None) 

import einops

# Guarded with hasattr: set_default_backend is not assumed to exist in every
# einops version, so the module still imports when it is absent.
if hasattr(einops, "set_default_backend"):
    einops.set_default_backend("jax")

from einops.array_api import rearrange

from functools import partial

import jax
import jax.numpy as jnp

from .hooks import HookType

# Evaluated once at import time: True when JAX's default backend is a GPU.
# The attention kernel picks its `implementation` from this flag.
has_cuda = "gpu" == jax.default_backend()
print(f"Cuda processing allowed: {has_cuda}")


def rms_norm(x, weight, eps=1e-6):
    """Root-mean-square normalise ``x`` over its last axis and scale by ``weight``.

    The statistics are computed in float32. The normalised values are cast back
    to ``x``'s original dtype *before* multiplying by ``weight``, so the result
    dtype is whatever ``weight.dtype`` and ``x.dtype`` promote to.

    Args:
        x: Input array; normalised along axis -1.
        weight: Per-feature scale, broadcast against the last axis of ``x``.
        eps: Added to the mean square before the reciprocal square root.
    """
    input_dtype = x.dtype
    x_f32 = x.astype(jnp.float32)
    mean_sq = jnp.mean(x_f32**2, axis=-1, keepdims=True)
    return weight * (x_f32 * jax.lax.rsqrt(mean_sq + eps)).astype(input_dtype)


def rotate_half(x):
    """Split the last axis of ``x`` into halves ``(a, b)`` and return ``(-b, a)``.

    The split point is ``x.shape[-1] // 2``, so for an odd-length last axis the
    second half is the longer one.
    """
    split = x.shape[-1] // 2
    lo, hi = x[..., :split], x[..., split:]
    return jnp.concatenate([-hi, lo], axis=-1)


def rope(x, sin, cos):
    """Apply rotary position embeddings to ``x`` with precomputed ``sin``/``cos``.

    Computes ``x * cos + rotate_half(x) * sin`` and casts the result back to
    ``x.dtype``.

    When ``x`` has more dimensions than ``sin``/``cos`` (e.g. an extra per-head
    axis), exactly ``x.ndim - sin.ndim`` singleton axes are inserted just before
    their last axis. Their leading axes therefore line up with the leading axes
    of ``x``; for example ``sin`` of shape ``(b, t, h)`` becomes ``(b, t, 1, h)``
    for an ``x`` of shape ``(b, t, n, h)``, and ``(t, h)`` becomes ``(t, 1, h)``
    for an ``x`` of shape ``(t, n, h)``.

    Raises:
        ValueError: if the expanded ``sin``/``cos`` cannot be broadcast with ``x``.
    """
    missing = x.ndim - sin.ndim
    if missing > 0:
        # Insert only the missing axes, right before the feature axis. An earlier
        # version appended one more `[..., None, :]` whenever the leading shapes
        # differed, which pushed sin/cos above x.ndim and silently broadcast the
        # output to a wrong, larger shape.
        expand = (Ellipsis,) + (None,) * missing + (slice(None),)
        sin = sin[expand]
        cos = cos[expand]

    # Shapes are static under jit, so this check costs nothing at run time.
    try:
        jnp.broadcast_shapes(x.shape, sin.shape, cos.shape)
    except ValueError:
        raise ValueError(
            f"Cannot broadcast sin/cos shapes {sin.shape} to x shape {x.shape}"
        ) from None

    rotated_x = (x * cos) + (rotate_half(x) * sin)
    return rotated_x.astype(x.dtype)


@partial(jax.vmap, in_axes=(0, 0, 0))
def update_dynamic(arr, update, start_idx):
    """Write ``update`` into ``arr`` along axis 0, starting at ``start_idx``.

    The vmap decorator maps over the leading axis of all three arguments. Each
    argument therefore carries a batch axis, and ``start_idx`` supplies one
    start index per batch row. Per the JAX documentation,
    ``dynamic_update_slice`` clamps the start index so that the update fits
    inside ``arr``.
    """
    return jax.lax.dynamic_update_slice_in_dim(arr, update, start_idx, 0)


@partial(jax.jit, static_argnames=["layer_id", "config"])
def attention(layer_id, inputs, cache, params, config):
    """Self-attention for one layer, reading and updating that layer's KV cache.

    Args:
        layer_id: Static layer index. It selects ``cache[layer_id]`` and the
            entry of ``cache.layers`` that is replaced.
        inputs: Hidden states whose last two axes are ``(t, hidden)``. The
            vmapped cache write also needs a leading batch axis — presumably
            ``(batch, t, hidden)``; confirm against callers.
        cache: Cache object. Reads ``cur_pos``, ``sin``, ``cos``, ``dirty``,
            ``layers`` and ``cache[layer_id].keys`` / ``.values``; a new cache
            is built with ``cache.replace``.
        params: Mapping with projection matrices ``W_q``, ``W_k``, ``W_v`` and
            ``W_o``.
        config: Static (hashable) mapping with ``num_heads`` and optionally
            ``num_kv_heads``, which defaults to ``num_heads``.

    Returns:
        ``(out, new_cache)``: the output projection, and ``cache`` with this
        layer's keys/values replaced by the updated buffers.
    """
    num_heads = config["num_heads"]
    # Same default as create(): fall back to full multi-head attention.
    num_kv_heads = config.get("num_kv_heads", num_heads)
    num_new = inputs.shape[-2]  # tokens processed in this step

    # Valid length per batch row after this step: largest position + 1.
    # NOTE(review): assumes cur_pos holds the positions of the current tokens.
    seq_lens = jnp.max(cache.cur_pos, axis=-1).astype(jnp.int32) + 1
    # The num_new fresh K/V rows must end exactly at seq_lens so they fall in
    # the range the kernel attends to ([0, seq_lens)). Starting the write at
    # the largest position is only correct for a single token; for a
    # multi-token step (prefill) it placed all but the first new key beyond
    # seq_lens, where they were masked out.
    start_idx = seq_lens - num_new

    sin, cos = cache.sin, cache.cos

    query = rearrange(inputs @ params["W_q"], "... t (n h) -> ... t n h", n=num_heads)
    key = rearrange(inputs @ params["W_k"], "... t (n h) -> ... t n h", n=num_kv_heads)
    value = rearrange(inputs @ params["W_v"], "... t (n h) -> ... t n h", n=num_kv_heads)
    query = rope(query, sin, cos)
    key = rope(key, sin, cos)

    layer_cache = cache[layer_id]
    full_key = update_dynamic(layer_cache.keys, key, start_idx)
    full_value = update_dynamic(layer_cache.values, value, start_idx)

    implementation = "cudnn" if has_cuda else "xla"

    def attend(is_causal):
        # Queries and keys are both masked to the first seq_lens positions.
        return jax.nn.dot_product_attention(
            query=query,
            key=full_key,
            value=full_value,
            is_causal=is_causal,
            query_seq_lengths=seq_lens,
            key_value_seq_lengths=seq_lens,
            implementation=implementation,
        )

    # A dirty cache uses non-causal attention; otherwise a causal mask is
    # applied. Presumably this is because the causal mask is anchored at index
    # 0 of both query and key, which would hide earlier cached keys from
    # queries of a later step — confirm.
    x = jax.lax.cond(cache.dirty, lambda: attend(False), lambda: attend(True))

    x = rearrange(x, "... t n h -> ... t (n h)")
    x = x @ params["W_o"]

    new_layers = tuple(
        layer if i != layer_id else layer.replace(keys=full_key, values=full_value)
        for i, layer in enumerate(cache.layers)
    )

    return x, cache.replace(layers=new_layers)


@partial(
    jax.jit,
    static_argnames=["config", "hooks_to_return", "hooks_to_stream", "editor"],
)
def run(
    inputs,
    cache,
    params,
    config,
    *,
    hooks_to_return=frozenset(),
    hooks_to_stream=frozenset(),
    editor=None,
):
    """Full forward pass: embeddings, all decoder layers, final norm and LM head.

    Args:
        inputs: Token ids used to index ``params["embed_table"]``.
        cache: KV cache threaded through every layer's attention call.
        params: Parameter tree with the layout produced by ``create``.
        config: Static mapping (norm_eps, act_fn, io_tying, and bos/eos/pad ids
            when an editor is given, plus whatever ``attention`` reads).
        hooks_to_return: Static set of ``(layer_id, HookType)`` tags whose
            tensors are returned.
        hooks_to_stream: Static set of tags whose tensor shapes are printed via
            ``jax.debug.callback``.
        editor: Optional static object whose ``apply(layer=, kind=, x=,
            token_mask=)`` result replaces the tensor at every hook.

    Returns:
        ``(logits, cache, captured)``. ``captured`` is a tuple of the tensors
        for the tags in ``hooks_to_return``, in the order the hooks fire, and
        recorded *before* the editor is applied.
    """
    captured = []

    # The editable-token mask depends only on `inputs`, so build it once
    # rather than on every hook call. bos/eos/pad tokens are never edited.
    edit_mask = None
    if editor is not None:
        edit_mask = (
            (inputs != config["bos_id"])
            & (inputs != config["eos_id"])
            & (inputs != config["pad_id"])
        )

    def handle_hook(tag, tensor):
        # Capture and/or stream `tensor` for `tag`; return it, possibly edited.
        # Callers must use the return value or edits are lost.
        if tag in hooks_to_return:
            captured.append(tensor)
        if tag in hooks_to_stream:
            jax.debug.callback(
                lambda x, t=tag: print(f"stream {t}: {x.shape}"),
                tensor,
                has_side_effect=True,
            )
        if editor is not None:
            tensor = editor.apply(
                layer=tag[0], kind=tag[1], x=tensor, token_mask=edit_mask
            )
        return tensor

    # NOTE(review): fill_value=-1e6 presumably fills rows for out-of-range ids
    # (jnp.take "fill" mode) — confirm this is the intended padding behaviour.
    x = jnp.take(params["embed_table"], inputs, axis=0, fill_value=-1e6)

    for layer_id, layer_params in enumerate(params["layers"]):
        # -------- resid_pre (after LN, before Attention) ------------------ #
        y = rms_norm(x, layer_params["input_norm"], eps=config["norm_eps"])
        y = handle_hook((layer_id, HookType.RESID_PRE), y)  # edits feed Attention

        # ---------------------- Multi-Head Attention ---------------------- #
        attn_out, cache = attention(layer_id, y, cache, layer_params["attn"], config)
        # Keep the hook's return value: editor results here used to be dropped.
        attn_out = handle_hook((layer_id, HookType.ATTN_OUT), attn_out)

        # -------- resid_mid (x += attn_out) ------------------------------- #
        x = handle_hook((layer_id, HookType.RESID_MID), x + attn_out)

        # ----------------------------- MLP -------------------------------- #
        y = rms_norm(x, layer_params["post_attn_norm"], eps=config["norm_eps"])
        ffn = layer_params["ffn"]

        act = config["act_fn"](y @ ffn["W_gate"])
        act = handle_hook((layer_id, HookType.MLP_ACT), act)  # gate activations

        up = y @ ffn["W_up"]
        ffn_out = (act * up) @ ffn["W_down"]
        ffn_out = handle_hook((layer_id, HookType.MLP_OUT), ffn_out)

        # -------- resid_post (x += ffn_out) ------------------------------- #
        x = handle_hook((layer_id, HookType.RESID_POST), x + ffn_out)

    # --------------------------------------------------------------------- #
    # `cache.dynamic` is used in a Python `if` under jit, so it must be a
    # concrete (static) value rather than a traced array.
    if cache.dynamic:
        cache = cache.replace(dirty=True)

    x = rms_norm(x, params["out_norm"], eps=config["norm_eps"])
    # Tied models reuse the embedding matrix as the output projection.
    lm_head = (
        params["embed_table"].T if config.get("io_tying", False) else params["lm_head"]
    )
    logits = x @ lm_head

    return logits, cache, tuple(captured)


def create(key, config, dtype=jnp.bfloat16, stddev=0.006):
    """Randomly initialise a parameter tree for the decoder in this module.

    Args:
        key: JAX PRNG key.
        config: Mapping with the following keys:
            - ``hidden_size``, ``vocab_size``, ``num_layers``, ``num_heads`` and
              ``intermediate_size``;
            - optional ``num_kv_heads`` (default ``num_heads``);
            - optional ``head_dim`` (default ``hidden_size // num_heads``);
            - optional ``io_tying`` (default False).
        dtype: dtype of the embedding table and all projection matrices. Norm
            weights are always float32.
        stddev: Standard deviation of the embedding table. Projection matrices
            use ``1 / sqrt(fan_in)`` instead.

    Returns:
        A dict with these entries:
            - ``embed_table`` of shape ``(vocab_size, hidden_size)``;
            - ``layers``, a list of per-layer dicts (``input_norm``, ``attn``,
              ``post_attn_norm``, ``ffn``);
            - ``out_norm``;
            - ``lm_head`` of shape ``(hidden_size, vocab_size)``, only when
              ``io_tying`` is false. Tied models project with
              ``embed_table.T`` instead.
    """

    def normal_init(key, shape, dim_in, stddev=None, dtype=jnp.float32):
        # Scaled standard normal; the scale defaults to 1/sqrt(fan_in). The
        # final cast pins the result to `dtype` even if the scale promoted it.
        if stddev is None:
            stddev = 1.0 / jnp.sqrt(dim_in)
        return (stddev * jax.random.normal(key, shape, dtype=dtype)).astype(dtype)

    def ones_init(shape, dtype=jnp.float32):
        return jnp.ones(shape, dtype=dtype)

    def dense(key, fan_in, fan_out):
        # (fan_in, fan_out) projection in the model dtype, scaled by 1/sqrt(fan_in).
        return normal_init(key, (fan_in, fan_out), fan_in, dtype=dtype)

    hidden_size = config["hidden_size"]
    vocab_size = config["vocab_size"]
    num_layers = config["num_layers"]
    num_heads = config["num_heads"]
    num_kv_heads = config.get("num_kv_heads", num_heads)
    intermediate_size = config["intermediate_size"]
    head_dim = config.get("head_dim", hidden_size // num_heads)

    # Only keys[0], keys[1] and keys[-1] are used. The split count is kept
    # unchanged so seeded initialisations remain reproducible.
    keys = jax.random.split(key, num_layers + 3)

    params = {
        "embed_table": normal_init(
            keys[0],
            (vocab_size, hidden_size),
            hidden_size,
            stddev=stddev,
            dtype=dtype,
        ),
        "layers": [],
    }

    layer_keys = jax.random.split(keys[1], num_layers)
    for layer_key in layer_keys:
        attn_key, ffn_key = jax.random.split(layer_key, 2)
        q_key, k_key, v_key, o_key = jax.random.split(attn_key, 4)
        gate_key, up_key, down_key = jax.random.split(ffn_key, 3)

        params["layers"].append(
            {
                "input_norm": ones_init((hidden_size,), dtype=jnp.float32),
                "attn": {
                    "W_q": dense(q_key, hidden_size, num_heads * head_dim),
                    "W_k": dense(k_key, hidden_size, num_kv_heads * head_dim),
                    "W_v": dense(v_key, hidden_size, num_kv_heads * head_dim),
                    "W_o": dense(o_key, num_heads * head_dim, hidden_size),
                },
                "post_attn_norm": ones_init((hidden_size,), dtype=jnp.float32),
                "ffn": {
                    "W_gate": dense(gate_key, hidden_size, intermediate_size),
                    "W_up": dense(up_key, hidden_size, intermediate_size),
                    "W_down": dense(down_key, intermediate_size, hidden_size),
                },
            }
        )

    params["out_norm"] = ones_init((hidden_size,), dtype=jnp.float32)

    # Untied models need their own output projection. Tied models project with
    # embed_table.T, so no lm_head is created. This condition used to be
    # inverted: untied models then had no lm_head (KeyError in the forward
    # pass) and tied models carried an unused one.
    if not config.get("io_tying", False):
        params["lm_head"] = dense(keys[-1], hidden_size, vocab_size)

    return params
