import dataclasses
import enum
import json
import logging
import os
import pathlib
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Optional

import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax import struct
from flax.core import FrozenDict, freeze
from jax import random
from tokenizers import Tokenizer
from tqdm.auto import tqdm

from ..impl import config as config_lib
from ..impl.hooks import HookType
from ..tools import compat, train_utils
from . import utils
from .jump_relu import jumprelu, step

# Module-import side effect: set JAX's process-wide default matmul precision
# to "float32", i.e. request full-precision matmuls instead of any faster
# reduced-precision mode an accelerator might otherwise use by default.
jax.config.update("jax_default_matmul_precision", "float32")


@struct.dataclass
class SAEConfig:
    """Hyper-parameters of a single sparse autoencoder (SAE).

    Fields declared with ``pytree_node=False`` are static metadata (usable in
    Python control flow under ``jit``); the remaining fields are pytree leaves
    and become tracers when a config is passed as a regular jitted argument.
    """

    # ``(layer_id, placement)`` selects which captured model activation the
    # SAE is trained on.
    layer_id: int
    hidden_size: int
    # latent_size = hidden_size * latent_multiplier
    latent_multiplier: int

    # Coefficient of the L1 sparsity penalty in ``loss_fn``.
    aux_coeff: float
    # Activation, called as ``act_fn(latent, config)``.
    act_fn: Callable = struct.field(pytree_node=False)
    # Loss differentiated by ``train_step``.
    loss_fn: Callable = struct.field(pytree_node=False)
    # Weight of the ghost-gradient loss term.
    ghost_coeff: float = 1e-3
    # Steps without firing after which a latent counts as dead.
    batches_until_dead: int = 500

    # k used by ``per_token_topk``.
    sparsity_k: Optional[int] = struct.field(default=None, pytree_node=False)
    # Initial value of the ``thresholds`` parameter for jumprelu/step.
    jump_threshold: Optional[float] = struct.field(default=None, pytree_node=False)
    # NOTE(review): not read in this file; presumably used by the jump_relu
    # activations — confirm.
    ste_bandwidth: Optional[float] = struct.field(default=None, pytree_node=False)

    placement: HookType = struct.field(default=HookType.RESID_POST, pytree_node=False)
    # dtype of the SAE parameters created by ``create``.
    dtype: Any = struct.field(default=jnp.float32, pytree_node=False)
    # Subtract ``b_dec`` from the inputs before encoding.
    pre_enc_bias: bool = struct.field(default=True, pytree_node=False)
    # Standardise inputs per token before encoding (undone after decoding).
    rescale_inputs: bool = struct.field(default=True, pytree_node=False)

    @property
    def latent_size(self):
        """Number of SAE latents: ``hidden_size * latent_multiplier``."""
        return self.hidden_size * self.latent_multiplier

    def save(self, path: str | os.PathLike) -> None:
        """Write the config to ``path`` as sorted, indented JSON.

        Enums are stored by member name and dtypes by numpy dtype name.
        JSON-native values are stored as-is; anything else (for example the
        ``act_fn``/``loss_fn`` callables) is stored via ``str()`` and cannot be
        restored by :meth:`from_file`. Parent directories are created as needed.
        """

        def _convert(key, value):
            if isinstance(value, enum.Enum):
                return value.name
            if key == "dtype":
                try:
                    return np.dtype(value).name
                except TypeError:
                    return str(value)
            if isinstance(value, (np.dtype, jnp.dtype)):
                return value.name
            try:
                json.dumps(value)
            except TypeError:
                return str(value)
            return value

        # Read the fields directly instead of using dataclasses.asdict, which
        # would deep-copy every value (including the jitted callables) first.
        data = {
            field.name: _convert(field.name, getattr(self, field.name))
            for field in dataclasses.fields(self)
        }

        path = pathlib.Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open("w", encoding="utf-8") as fp:
            json.dump(data, fp, indent=2, sort_keys=True)

        logging.info("SAEConfig saved to %s", path)

    @classmethod
    def from_file(cls, path: str | os.PathLike, **overrides) -> "SAEConfig":
        """Load a config written by :meth:`save`.

        ``placement`` is restored from its enum name and ``dtype`` from its
        dtype name. ``act_fn``/``loss_fn`` were saved as ``str()`` and come back
        as plain strings. Pass the real callables through ``overrides`` (e.g.
        ``SAEConfig.from_file(p, act_fn=relu, loss_fn=None)``). Every keyword in
        ``overrides`` replaces the loaded value of the same name.

        Raises:
            FileNotFoundError: if ``path`` does not exist.
        """
        path = pathlib.Path(path)
        if not path.exists():
            raise FileNotFoundError(path)

        with path.open("r", encoding="utf-8") as fp:
            raw = json.load(fp)

        def _restore(key, value):
            if key == "placement":
                return HookType[value]
            if key == "dtype":
                return jnp.dtype(value)
            return value

        cfg_kwargs = {k: _restore(k, v) for k, v in raw.items()}
        cfg_kwargs.update(overrides)
        return cls(**cfg_kwargs)


@jax.jit
def normalize(x):
    """Standardise ``x`` along its last axis.

    Returns ``(x_norm, x_mean, x_std)``. The mean and standard deviation keep a
    trailing singleton axis so they broadcast back against ``x``. A 1e-6
    epsilon is added to the standard deviation to avoid division by zero.
    """
    reduce_kwargs = dict(axis=-1, keepdims=True)
    centre = jnp.mean(x, **reduce_kwargs)
    spread = jnp.std(x, **reduce_kwargs)
    standardised = (x - centre) / (spread + 1e-6)
    return standardised, centre, spread


@jax.jit
def rescale(x, x_mean, x_std):
    """Map standardised values back to the original scale: ``x * x_std + x_mean``.

    This approximately inverts ``normalize``, which divides by ``x_std + 1e-6``.
    """
    stretched = x * x_std
    return stretched + x_mean


def decode_latent(latent, *, params, config, x_mean, x_std, with_act=False):
    """Activate pre-activation ``latent`` codes and decode them to input space.

    Applies ``config.act_fn(latent, config)`` and the affine decoder
    ``act @ W_dec + b_dec``. The result is then mapped back with
    ``rescale(..., x_mean, x_std)`` and cast to ``latent.dtype``.

    Returns:
        The reconstruction, or ``(reconstruction, activations)`` when
        ``with_act`` is True.
    """
    activations = config.act_fn(latent, config)
    decoded = activations @ params["W_dec"] + params["b_dec"]
    reconstruction = rescale(decoded, x_mean, x_std).astype(latent.dtype)
    return (reconstruction, activations) if with_act else reconstruction


@partial(jax.jit, static_argnames=("return_latents", "return_act"))
def run(
    inputs,
    params,
    config: SAEConfig,
    latent_mask=None,
    return_latents=False,
    return_act=False,
):
    """Encode and decode ``inputs`` with the SAE described by ``config``.

    Steps: optional per-token standardisation (``config.rescale_inputs``),
    optional subtraction of ``b_dec`` (``config.pre_enc_bias``), linear
    encoding with ``W_enc``, optional elementwise ``latent_mask``, then
    activation and decoding via ``decode_latent``.

    Returns:
        A tuple starting with the reconstruction. It is followed by the
        (masked) pre-activation latents if ``return_latents``, then by the
        activations if ``return_act``.
    """
    # Identity statistics by default, so the final rescale is a no-op.
    mean, std = 0.0, 1.0
    x = inputs
    if config.rescale_inputs:
        x, mean, std = normalize(x)
    if config.pre_enc_bias:
        x = x - params["b_dec"]

    pre_acts = x @ params["W_enc"]
    if latent_mask is not None:
        pre_acts = pre_acts * latent_mask

    reconstruction, activations = decode_latent(
        pre_acts,
        params=params,
        config=config,
        x_mean=mean,
        x_std=std,
        with_act=True,
    )

    outputs = [reconstruction]
    if return_latents:
        outputs.append(pre_acts)
    if return_act:
        outputs.append(activations)
    return tuple(outputs)


# ----------------------------------
# SAE Activations
# ----------------------------------


@jax.jit
def relu(x, config):
    """Elementwise ReLU activation.

    ``config`` is unused. It is accepted so every activation can be called as
    ``act_fn(latent, config)``.
    """
    del config  # unused; kept for the shared activation signature
    return jax.nn.relu(x)


@jax.jit
def per_token_topk(x, config):
    """Keep the ``config.sparsity_k`` largest ReLU activations along the last axis.

    Entries strictly below the k-th largest rectified value are zeroed. Every
    entry equal to the k-th value survives, so ties can keep more than k.

    Raises:
        ValueError: if ``config.sparsity_k`` is None.
    """
    k = config.sparsity_k
    if k is None:
        raise ValueError("Top-k activation requires SAEConfig.sparsity_k")
    rectified = jax.nn.relu(x)
    top_values = jax.lax.top_k(rectified, k)[0]
    cutoff = top_values[..., -1:]
    return jnp.where(rectified < cutoff, 0.0, rectified)


# ----------------------------------
# Initialization
# ----------------------------------


def create(random_key, config):
    """Initialise the parameters of an SAE described by ``config``.

    Returns a dict with:
      * ``W_enc`` (hidden_size, latent_size): He-uniform initialised encoder.
      * ``W_dec`` (latent_size, hidden_size): ``W_enc.T`` with each row scaled
        to unit L2 norm.
      * ``b_dec`` (hidden_size,): zeros.
      * ``thresholds`` (latent_size,): filled with ``config.jump_threshold``
        when ``config.act_fn`` is ``jumprelu`` or ``step``, otherwise ``None``.

    All arrays use ``config.dtype``.

    Raises:
        ValueError: if a threshold-based activation (``jumprelu``/``step``) is
            configured but ``config.jump_threshold`` is None.
    """
    he_init = jax.nn.initializers.he_uniform()
    W_enc = he_init(
        random_key,
        (config.hidden_size, config.latent_size),
        dtype=config.dtype,
    )
    W_dec = W_enc.T
    W_dec = W_dec / jnp.linalg.norm(W_dec, axis=-1, keepdims=True)
    b_dec = jnp.zeros((config.hidden_size,), dtype=config.dtype)

    thresholds = None
    if config.act_fn is jumprelu or config.act_fn is step:
        # Fail early with a clear message instead of deep inside jnp.full.
        if config.jump_threshold is None:
            raise ValueError(
                "jumprelu/step activations require SAEConfig.jump_threshold"
            )
        thresholds = jnp.full(
            (config.latent_size,), config.jump_threshold, config.dtype
        )
    return {"W_enc": W_enc, "W_dec": W_dec, "b_dec": b_dec, "thresholds": thresholds}


# ----------------------------------
# TrainState and Configuration
# ----------------------------------


@dataclasses.dataclass
class TrainConfig:
    """Settings for ``train``: the SAEs to train, the host language model, and run options.

    ``sae_configs`` and ``optimizers`` are paired positionally (one optimizer per
    SAE, see ``create_states``). ``optimizers`` and ``checkpoint_root`` may be
    ``None`` when the config is only used for inference, as ``SAEKit.load`` does.
    """

    # A list or tuple with any number of SAE configs.
    sae_configs: list[SAEConfig] | tuple[SAEConfig, ...]
    lm_params: Any
    lm_config: FrozenDict
    tokenizer: Tokenizer
    optimizers: Optional[list[optax.GradientTransformation]]
    checkpoint_root: Optional[str]

    wandb_project: str = "MolecularSAEs"
    wandb_run_name: Optional[str] = None

    # Presumably step intervals consumed by utils.maybe_log / utils.maybe_save
    # — TODO confirm.
    log_every: int = 100
    checkpoint_every: int = 500
    # Stop ``train`` once this many mask-counted tokens have been processed;
    # None means no limit.
    token_limit: Optional[int] = None


@struct.dataclass
class TrainState:
    """Training state of one SAE: parameters, optimizer state and firing history.

    ``config`` and ``optimizer`` are static (non-pytree) fields. ``opt_state``,
    ``params`` and ``last_fired_batch`` are pytree leaves carried through jit.
    """

    config: SAEConfig = struct.field(pytree_node=False)
    optimizer: optax.GradientTransformation = struct.field(pytree_node=False)
    opt_state: Any
    params: Any
    # Per-latent step at which the latent last fired; initialised to -1.
    last_fired_batch: jnp.ndarray

    @classmethod
    def create(cls, rng, config: SAEConfig, optimizer):
        """Initialise fresh parameters (module-level ``create``) and wrap them in a state."""
        params = create(rng, config)
        return cls(
            config=config,
            optimizer=optimizer,
            opt_state=optimizer.init(params),
            params=params,
            last_fired_batch=jnp.full((config.latent_size,), -1, dtype=jnp.int32),
        )


def create_states(random_key, config: TrainConfig):
    """Build one fresh ``TrainState`` per (SAE config, optimizer) pair.

    Each SAE receives its own subkey, split off ``random_key`` in order.

    Raises:
        ValueError: if ``config.sae_configs`` and ``config.optimizers`` have
            different lengths, instead of silently dropping the extras.
    """
    rng = random_key
    states = []
    for sae_config, optimizer in zip(
        config.sae_configs, config.optimizers, strict=True
    ):
        rng, subkey = random.split(rng)
        states.append(TrainState.create(subkey, sae_config, optimizer))
    return states


# ----------------------------------
# Loss functions
# ----------------------------------


def _broadcast_mask(mask: jnp.ndarray, ref: jnp.ndarray) -> jnp.ndarray:
    if mask is None:
        return None
    mask = mask.astype(ref.dtype)
    while mask.ndim < ref.ndim:
        mask = mask[..., None]
    return mask


@partial(jax.jit, static_argnames=["sae_config"])
def loss_fn(params, _state, inputs, loss_mask, sae_config, step):
    """Compute the SAE training loss and a dictionary of diagnostics.

    ``total_loss = mse + aux_coeff * l1 + ghost_loss``. Every term is normalised
    by the number of unmasked tokens.

    Args:
        params: SAE parameters (``W_enc``, ``W_dec``, ``b_dec``, ...).
        _state: ``TrainState``; only ``last_fired_batch`` is read, to find dead latents.
        inputs: activations to reconstruct; the last axis is the feature axis.
        loss_mask: per-token mask (trailing axes are added to broadcast), or None.
        sae_config: static ``SAEConfig``.
        step: current training step.

    Returns:
        ``(total_loss, metrics)`` with scalar metrics for logging.
    """
    # -----------------------------------------------------------
    # Forward pass
    # -----------------------------------------------------------
    rec, acts = run(inputs, params, sae_config, return_act=True)
    # Reduce over every axis except the trailing latent axis (e.g. (B, T)).
    token_axes = tuple(range(acts.ndim - 1))

    # Masks
    mask_tok = _broadcast_mask(loss_mask, rec)  # token mask + trailing singleton axis
    mask_2d = None if mask_tok is None else jnp.squeeze(mask_tok, -1)
    mask_el = None if mask_tok is None else jnp.broadcast_to(mask_tok, rec.shape)

    # -----------------------------------------------------------
    # Reconstruction loss – squared L2 error per token
    # -----------------------------------------------------------
    sq_err_tok = jnp.sum((inputs - rec) ** 2, axis=-1)

    if mask_2d is None:
        num_tokens = sq_err_tok.size
    else:
        sq_err_tok = sq_err_tok * mask_2d
        acts = acts * mask_tok
        mask_count = jnp.sum(mask_2d)
        # An all-padding batch would otherwise give 0/0, i.e. NaN loss and
        # NaN gradients. All masked numerators are zero then, so losses are 0.
        num_tokens = jnp.where(mask_count > 0, mask_count, 1.0)

    mse_loss = jnp.sum(sq_err_tok) / num_tokens

    # -----------------------------------------------------------
    # Sparsity terms
    # -----------------------------------------------------------
    l1_loss = jnp.sum(jnp.abs(acts)) / num_tokens
    scaled_l1_loss = sae_config.aux_coeff * l1_loss
    norm_l1_loss = l1_loss / sae_config.latent_size
    l0_loss = jnp.sum((acts != 0).sum(-1)) / num_tokens
    total_loss = mse_loss + scaled_l1_loss

    # -----------------------------------------------------------
    # NMSE (token-level variance around the global mean)
    # -----------------------------------------------------------
    if mask_el is None:
        mu = inputs.mean()
        var_tok = jnp.sum((inputs - mu) ** 2, axis=-1)
        var_x = var_tok.mean()
    else:
        el_count = jnp.sum(mask_el)
        mu = jnp.sum(inputs * mask_el) / jnp.where(el_count > 0, el_count, 1.0)
        var_tok = jnp.sum(((inputs - mu) ** 2) * mask_el, axis=-1)
        var_x = jnp.sum(var_tok) / num_tokens

    nmse = mse_loss / (var_x + 1e-8)

    # -----------------------------------------------------------
    # Health metrics
    # -----------------------------------------------------------
    fires = acts != 0
    # Per-latent firing rate over *unmasked* tokens. Padded positions were
    # zeroed above and must not count in the denominator.
    fire_rate = jnp.sum(fires, axis=token_axes) / num_tokens
    dead_frac = 1.0 - fires.any(axis=token_axes).mean()
    dense_frac_001 = (fire_rate > 0.01).mean()
    dense_frac_005 = (fire_rate > 0.05).mean()
    l1_over_mse = norm_l1_loss / (mse_loss + 1e-8)

    # -----------------------------------------------------------
    # Ghost gradient (masked): penalise the residual's projection onto the
    # decoder directions of latents silent for >= batches_until_dead steps.
    # -----------------------------------------------------------
    delta = rec - inputs
    if mask_tok is not None:
        delta = delta * mask_tok
    act_grad = delta @ params["W_dec"].T
    silent_for = step - _state.last_fired_batch
    dead_mask = (silent_for >= sae_config.batches_until_dead).astype(inputs.dtype)
    ghost_loss = (
        sae_config.ghost_coeff
        * jnp.sum((act_grad**2) * dead_mask)
        / (num_tokens * sae_config.latent_size)
    )
    total_loss = total_loss + ghost_loss

    # -----------------------------------------------------------
    # Return
    # -----------------------------------------------------------
    return total_loss, {
        "mse_loss": mse_loss.astype(float),
        "l1_loss": l1_loss.astype(float),
        "l0_loss": l0_loss.astype(float),
        "scaled_l1": scaled_l1_loss.astype(float),
        "nmse": nmse.astype(float),
        "dead_frac": dead_frac.astype(float),
        "dense_frac": dense_frac_001.astype(float),
        "dense_frac_0.05": dense_frac_005.astype(float),
        "l1_over_mse": l1_over_mse.astype(float),
        "ghost_loss": ghost_loss.astype(float),
        "ghost_dead_mask": dead_mask.mean().astype(float),
        "total_loss": total_loss.astype(float),
    }


# ----------------------------------
# Trainer code
# ----------------------------------


def project_grad_on_decoder(updates, params, eps=1e-8):
    """Remove, per decoder row, the component of the ``W_dec`` gradient along that row.

    After projection each row of ``updates["W_dec"]`` is orthogonal to the
    matching row of ``params["W_dec"]``. A step along it therefore does not
    change the row norm to first order. The projection divides by the squared
    row norm (plus ``eps``), so it stays exact even if rows drift from unit
    norm. For unit-norm rows it equals ``g - (g·w) w / (1 + eps)``.

    Returns a shallow copy of ``updates`` with only ``"W_dec"`` replaced. The
    input mapping is not mutated; it is assumed to support item assignment
    after ``.copy()`` (e.g. a ``dict``).
    """
    W_dec = params["W_dec"]
    g_dec = updates["W_dec"]
    dot = jnp.sum(g_dec * W_dec, axis=-1, keepdims=True)
    sq_norm = jnp.sum(W_dec * W_dec, axis=-1, keepdims=True)
    updates = updates.copy()
    updates["W_dec"] = g_dec - (dot / (sq_norm + eps)) * W_dec
    return updates


def add_decoder_projection(optimizer):
    """Wrap ``optimizer`` so decoder gradients are projected before every update.

    ``init`` is delegated to the wrapped optimizer unchanged. ``update`` first
    runs the gradients through ``project_grad_on_decoder``, which requires
    ``params``, and then calls the wrapped optimizer's ``update``.
    """

    def projected_update(grads, state, params):
        projected = project_grad_on_decoder(grads, params)
        return optimizer.update(projected, state, params)

    return optax.GradientTransformation(init=optimizer.init, update=projected_update)


def renorm_decoder(params, eps=1e-8):
    """Return a shallow copy of ``params`` with every ``W_dec`` row scaled to ~unit L2 norm.

    ``eps`` is added to each row norm so all-zero rows stay zero instead of
    producing NaN. All other entries are shared with the input unchanged.
    """
    decoder = params["W_dec"]
    row_norms = jnp.linalg.norm(decoder, axis=-1, keepdims=True)
    renormed = params.copy()
    renormed["W_dec"] = decoder / (row_norms + eps)
    return renormed


@partial(jax.jit, static_argnames=["sae_config"])
def train_step(state, inputs, mask, sae_config, step):
    """Run one optimisation step for a single SAE.

    Args:
        state: ``TrainState`` with params, optimizer and firing history.
        inputs: activations fed to the SAE.
        mask: per-token loss mask, or ``None`` to use every position.
        sae_config: static ``SAEConfig``; ``sae_config.loss_fn`` is differentiated.
        step: current global step, recorded for latents that fire.

    Returns:
        ``(new_state, metrics)``. ``metrics`` is the aux dict returned by the
        loss function, computed with the pre-update parameters.
    """
    grad_fn = jax.value_and_grad(sae_config.loss_fn, has_aux=True)
    (_, metrics), grads = grad_fn(
        state.params, state, inputs, mask, sae_config, step
    )

    updates, new_opt_state = state.optimizer.update(
        grads, state.opt_state, state.params
    )
    # Decoder rows are renormalised to unit norm after every update.
    new_params = renorm_decoder(optax.apply_updates(state.params, updates))

    # Firing history uses the *updated* parameters and only unmasked tokens.
    _, acts = run(inputs, new_params, sae_config, return_act=True)
    mask_tokens = _broadcast_mask(mask, acts)
    if mask_tokens is not None:
        acts = acts * mask_tokens
    token_axes = tuple(range(acts.ndim - 1))
    fired_mask = jnp.any(acts != 0, axis=token_axes)
    new_last = jnp.where(fired_mask, step, state.last_fired_batch)

    return state.replace(
        params=new_params,
        opt_state=new_opt_state,
        last_fired_batch=new_last,
    ), metrics


def train(ds, random_key, config: TrainConfig):
    """Train every SAE in ``config`` on activations captured from the language model.

    For each batch from ``ds`` (a mapping with ``"inputs"`` and ``"positions"``),
    the model is run through ``utils.run_and_capture``. Each SAE then takes one
    ``train_step`` on the activation at its ``(layer_id, placement)``. Logging
    and checkpointing go through ``utils.maybe_log`` / ``utils.maybe_save``. A
    ``"final"`` checkpoint is attempted on exit, including after an exception or
    interrupt. Training stops early once ``config.token_limit`` mask-counted
    tokens have been processed.
    """
    run_fn, mask_fn, hooks = utils.get_hook_tools(config)
    states = create_states(random_key, config)

    token_limit = config.token_limit
    current_step = 0
    # Host-side Python scalar: summing on device would accumulate in 32 bits
    # (unless x64 is enabled), wrapping past ~2.1B tokens for integer masks and
    # losing precision for float masks, so a large token_limit is never reached.
    elapsed_tokens = 0

    try:
        for batch in tqdm(ds):
            inputs = batch["inputs"]
            positions = batch["positions"]
            mask = mask_fn(inputs)

            residuals = utils.run_and_capture(
                run_fn,
                inputs,
                positions,
                config.lm_params,
                config.lm_config,
                hooks,
            )

            new_states = []
            metrics = []
            for state, sae_config in zip(states, config.sae_configs):
                new_state, per_sae_metrics = train_step(
                    state,
                    residuals[(sae_config.layer_id, sae_config.placement)],
                    mask,
                    sae_config,
                    current_step,
                )
                new_states.append(new_state)
                metrics.append(per_sae_metrics)
            states = new_states

            utils.maybe_log(current_step + 1, metrics, elapsed_tokens, config)
            utils.maybe_save(current_step + 1, states, config)

            elapsed_tokens += jnp.sum(mask).item()
            current_step += 1

            if token_limit is not None and elapsed_tokens >= token_limit:
                logging.info(
                    "Reached limit of %.2fB tokens, stopping training.",
                    token_limit / 1e9,
                )
                break
    finally:
        utils.maybe_save("final", states, config)


# ----------------------------------
# Utility class
# ----------------------------------


@dataclass
class SAEKit:
    """Inference bundle: tokenizer, language model, and one trained SAE per layer.

    Build with :meth:`load`. :meth:`get_encoded` runs the language model on a
    batch and encodes one layer's captured activations with its SAE.
    """

    tokenizer: Any
    lm_config: Any
    lm_params: Any
    sae_configs: list
    sae_params: list
    train_config: Any
    run_fn: Any
    mask_fn: Any
    hooks: Any

    @classmethod
    def load(
        cls, model_dir, checkpoint_id, sae_dir, *, act_fn=relu, trunc_length=256
    ):
        """Load a model checkpoint and its per-layer SAEs from disk.

        Files read:
          * ``{model_dir}/tokenizer.json`` and ``{model_dir}/generation_config.json``
          * ``{model_dir}/checkpoints/checkpoint_{checkpoint_id}.pkl``
          * ``{sae_dir}/sae_{i}_config.json`` and
            ``{sae_dir}/sae_{i}/checkpoint_final.pkl`` for every ``i`` in
            ``range(num_layers)``.

        NOTE(review): ``.pkl`` checkpoints are presumably unpickled by
        ``train_utils.load_checkpoint``. Only load from trusted directories.

        Args:
            model_dir: directory with tokenizer, model config and checkpoints.
            checkpoint_id: suffix of the model checkpoint file to load.
            sae_dir: directory with the saved SAE configs and parameters.
            act_fn: activation installed in every loaded ``SAEConfig``. Saved
                configs keep callables only as ``str()``, so it must be given
                here; defaults to ``relu``.
            trunc_length: forwarded to ``compat.load_tokenizer``.
        """
        tokenizer = compat.load_tokenizer(
            mode="train",
            tokenizer_path=f"{model_dir}/tokenizer.json",
            generation_config_file=f"{model_dir}/generation_config.json",
            trunc_length=trunc_length,
        )

        # FrozenDict.copy returns a *new* mapping with the overrides applied.
        # The result must be kept, or the tokenizer's special ids are dropped.
        model_config_raw = freeze(config_lib.load_from_dir(model_dir)).copy(
            dict(
                bos_id=tokenizer.bos_token_id,
                eos_id=tokenizer.eos_token_id,
                pad_id=tokenizer.pad_token_id,
            )
        )
        lm_config = freeze(model_config_raw)
        num_layers = model_config_raw["num_layers"]

        lm_params, *_ = train_utils.load_checkpoint(
            f"{model_dir}/checkpoints/checkpoint_{checkpoint_id}.pkl"
        )

        sae_configs = [
            SAEConfig.from_file(f"{sae_dir}/sae_{i}_config.json").replace(
                act_fn=act_fn, loss_fn=None
            )
            for i in range(num_layers)
        ]
        sae_params = [
            train_utils.load_checkpoint(f"{sae_dir}/sae_{i}/checkpoint_final.pkl")[0]
            for i in range(num_layers)
        ]

        train_config = TrainConfig(
            sae_configs=sae_configs,
            lm_params=lm_params,
            lm_config=lm_config,
            tokenizer=tokenizer,
            optimizers=None,
            checkpoint_root=None,
        )
        run_fn, mask_fn, hooks = utils.get_hook_tools(train_config)

        return cls(
            tokenizer=tokenizer,
            lm_config=lm_config,
            lm_params=lm_params,
            sae_configs=sae_configs,
            sae_params=sae_params,
            train_config=train_config,
            run_fn=run_fn,
            mask_fn=mask_fn,
            hooks=hooks,
        )

    def get_encoded(self, inputs, positions, layer_id: int, active_only=True):
        """Encode one layer's captured activations for a batch with its SAE.

        Runs the language model via ``utils.run_and_capture``, selects the
        activation captured at ``(config.layer_id, config.placement)`` for SAE
        ``layer_id``, and passes it through ``run``.

        Returns:
            The SAE activations (after ``act_fn``) when ``active_only`` is True,
            otherwise the pre-activation latents.
        """
        config = self.sae_configs[layer_id]
        prm = self.sae_params[layer_id]
        captured = utils.run_and_capture(
            self.run_fn,
            inputs,
            positions,
            self.lm_params,
            self.lm_config,
            self.hooks,
        )
        outputs = run(
            captured[(config.layer_id, config.placement)],
            prm,
            config,
            return_latents=not active_only,
            return_act=active_only,
        )
        # run returns (reconstruction, requested_array); exactly one extra is requested.
        return outputs[1]
