import itertools
from typing import Optional, Union

import torch as t
import torch.nn as nn
from einops import rearrange, repeat
from jaxtyping import Float
from loguru import logger
from torch.utils.data import DataLoader
from tqdm import tqdm

from auto_encoder.config import AutoEncoderConfig
from auto_encoder.config_enums import AutoEncoderType, ResamplingType
from auto_encoder.helpers.ae_output_types import SupermodelOutput
from auto_encoder.models.base_ae import AutoEncoderBase
from auto_encoder.models.vanilla_ae import VanillaAE
from auto_encoder.training.supermodel import FrozenTransformerAutoencoderSuperModel


def get_resample_mask(
    activations_tracker: t.Tensor, resampling_eps: float
) -> tuple[t.Tensor, t.Tensor]:
    """Identify dead features from a per-feature activation tracker.

    A feature is pronounced dead when it was never activated while the tracker was
    accumulating. For a boolean tracker that means its entry is ``False``. For a numeric
    tracker (e.g. activation counts) it means the entry is zero.

    Args:
        activations_tracker: 1-D tensor with one entry per feature.
        resampling_eps: Not used by the current implementation; kept for interface
            compatibility.

    Returns:
        ``(resampling_mask, feature_dead)``:
        ``resampling_mask`` is a bool tensor of shape [num_features, 1], convenient for
        broadcasting over weight rows. ``feature_dead`` is a bool tensor of shape
        [num_features]. Both are True for dead features.
    """
    if activations_tracker.dtype == t.bool:
        feature_dead = ~activations_tracker
    else:
        # ``~`` on an integer tensor is a bitwise NOT (~0 == -1, ~1 == -2), which would
        # produce a non-boolean "mask"; on floats it raises. Compare against zero instead.
        feature_dead = activations_tracker == 0

    resampling_mask = feature_dead.unsqueeze(-1)  # [num_features, 1]
    return resampling_mask, feature_dead


def _reset_optim_state_vanilla_(
    optimizer_state: dict,
    feature_dead: t.Tensor,
    linear: nn.Linear,
    feature_bias_F: Optional[nn.Parameter] = None,
) -> None:
    """Zero the Adam moment estimates that belong to dead features, in place.

    Only the rows of ``linear.weight`` (and the entries of ``feature_bias_F``) selected by
    ``feature_dead`` are cleared. Optimizer state for alive features is left intact.
    Parameters that have no optimizer state yet (e.g. before the first
    ``optimizer.step()``) are skipped.

    Args:
        optimizer_state: Mapping from parameter to its per-parameter state dict, such as
            ``torch.optim.Optimizer.state``.
        feature_dead: Boolean mask over the first dimension of ``linear.weight``.
        linear: Layer whose weight rows correspond to features.
        feature_bias_F: Optional per-feature bias whose state should also be cleared.
    """
    dead = feature_dead.bool()
    params = [linear.weight] if feature_bias_F is None else [linear.weight, feature_bias_F]

    for param in params:
        # ``.get`` avoids inserting empty entries into a defaultdict-backed state.
        param_state = optimizer_state.get(param)
        if not param_state:
            continue
        for key in ("exp_avg", "exp_avg_sq"):
            if key in param_state:
                # Boolean indexing on dim 0 selects whole rows of the 2-D weight and
                # single entries of the 1-D bias; the assignment writes in place.
                param_state[key][dead] = 0.0


def reset_optim_state_for_dead_features_(
    optimizer_state: dict,
    feature_dead: t.Tensor,
    supermodel: FrozenTransformerAutoencoderSuperModel,
) -> None:
    """Clear optimizer state for dead features of the supermodel's autoencoder.

    Only the vanilla autoencoder is supported. Its encoder weight and ``feature_bias_F``
    are forwarded to ``_reset_optim_state_vanilla_``.

    Raises:
        ValueError: if the autoencoder type is not vanilla.
    """
    if not supermodel.autoencoder_type.vanilla:
        raise ValueError(f"Invalid autoencoder type {supermodel.autoencoder_type}")

    vanilla_ae = supermodel.autoencoder
    assert isinstance(vanilla_ae, VanillaAE)
    _reset_optim_state_vanilla_(
        optimizer_state,
        feature_dead,
        linear=vanilla_ae.encoder,
        feature_bias_F=vanilla_ae.feature_bias_F,  # type: ignore
    )


def compute_static_model_loss(
    dataloader: DataLoader,
    supermodel: FrozenTransformerAutoencoderSuperModel,
    device: str,
    num_examples: int,
) -> tuple[t.Tensor, t.Tensor]:
    """Run the supermodel on the first ``num_examples`` batches and collect per-token losses.

    Note that ``num_examples`` counts *batches* drawn from ``dataloader``, not individual
    sequences. No gradients are recorded, and no batch beyond the last one used is fetched.

    Args:
        dataloader: Yields dicts with an ``"input_ids"`` entry.
        supermodel: Model whose output exposes ``batched_loss_BS``.
        device: Device the inputs and returned tensors are placed on.
        num_examples: Maximum number of batches to evaluate.

    Returns:
        ``(batches, losses)``: the concatenated input ids and the matching
        ``batched_loss_BS`` values, both concatenated along dim 0.

    Raises:
        ValueError: if no batch was evaluated (empty dataloader or ``num_examples <= 0``).
    """
    batch_list: list[t.Tensor] = []
    loss_list: list[t.Tensor] = []

    with t.no_grad():
        first_batches = itertools.islice(dataloader, max(num_examples, 0))
        for batch in tqdm(first_batches, total=max(num_examples, 0)):
            # as_tensor avoids the copy-and-warn path of t.tensor when given a tensor.
            input_batch = t.as_tensor(batch["input_ids"], device=device)
            supermodel_output: SupermodelOutput = supermodel(input_batch)

            batch_list.append(input_batch)  # batch_size, seq_len
            loss_list.append(supermodel_output.batched_loss_BS)

    if not batch_list:
        raise ValueError(
            f"compute_static_model_loss evaluated no batches (num_examples={num_examples})"
        )

    batches = t.cat(batch_list, dim=0).to(device)  # num_batches*batch_size, seq_len
    losses = t.cat(loss_list, dim=0).to(device)
    return batches, losses


def apply_naive_resampling_(
    supermodel: FrozenTransformerAutoencoderSuperModel,
    ae_config: AutoEncoderConfig,
    device: str,
    optimizer: Optional[t.optim.Optimizer],
    activations_tracker: t.Tensor,
) -> None:
    """Re-initialise the encoder rows of dead features with fresh Kaiming-uniform values.

    The steps are:
    1. Determine the dead features from ``activations_tracker``.
    2. Overwrite their encoder rows with a freshly Kaiming-initialised matrix.
    3. Zero their biases.
    4. If an optimizer is given, clear its state for them.

    ``device`` is accepted for signature parity with ``apply_fancy_resampling_`` and is
    not used.

    Raises:
        ValueError: if the autoencoder type is not vanilla.
    """
    with t.no_grad():
        dead_rows, dead_features = get_resample_mask(
            activations_tracker, ae_config.resampling_eps
        )  # [num_features, 1], [num_features]

        if not supermodel.autoencoder_type.vanilla:
            raise ValueError(f"Invalid autoencoder type {supermodel.autoencoder_type}")

        vanilla_ae = supermodel.autoencoder
        assert isinstance(vanilla_ae, VanillaAE)
        encoder_weight = vanilla_ae.encoder.weight

        # Wipe dead rows, then fill exactly those rows from a fresh random matrix.
        encoder_weight.data *= ~dead_rows
        fresh_weights = t.zeros_like(encoder_weight)
        nn.init.kaiming_uniform_(fresh_weights)
        encoder_weight.data += fresh_weights * dead_rows

        zero_out_bias_for_dead_features_(supermodel, dead_features)

        if optimizer is None:
            return
        reset_optim_state_for_dead_features_(optimizer.state, dead_features, supermodel)


def _sample_unit_input_vectors(
    supermodel: FrozenTransformerAutoencoderSuperModel,
    train_dataloader: DataLoader,
    ae_config: AutoEncoderConfig,
    device: str,
    num_samples: int,
) -> t.Tensor:
    """Sample ``num_samples`` unit-norm MLP activation vectors, biased towards high-loss inputs.

    Sequences are drawn with replacement, with probability proportional to the sum of
    their squared per-token losses. For each sample, one token position is then drawn
    uniformly, and the MLP activation vector at that position is L2-normalised.

    Returns:
        Tensor of shape [num_samples, num_neurons].
    """
    #### Step 2: Compute loss for the static model on (a large number of) random inputs
    batches, losses = compute_static_model_loss(
        train_dataloader,
        supermodel,
        device=device,
        num_examples=ae_config.num_static_loss_samples,
    )  # [num_sequences, seq_len] each

    #### Step 3: Weight every sequence by its summed squared loss. t.multinomial
    # normalises non-negative weights itself; fall back to uniform if all are zero.
    sequence_weights = losses.float().pow(2).sum(dim=-1)  # [num_sequences]
    if sequence_weights.sum() <= 0:
        sequence_weights = t.ones_like(sequence_weights)

    #### Step 4: Pick one sequence per sample, then one token position per sample
    indices = t.multinomial(
        sequence_weights, num_samples=num_samples, replacement=True
    ).to(device)
    chosen_sequences = batches.to(device)[indices]  # [num_samples, seq_len]

    acts_chunks: list[t.Tensor] = []
    for chunk in tqdm(t.split(chosen_sequences, ae_config.batch_size)):
        supermodel_output: SupermodelOutput = supermodel(chunk)
        acts_chunks.append(supermodel_output.initial_neuron_activations_BSN)
    mlp_activations = t.cat(acts_chunks, dim=0).to(device)  # [num_samples, seq_len, num_neurons]

    num_rows, seq_len, _ = mlp_activations.shape
    act_device = mlp_activations.device
    rows = t.arange(num_rows, device=act_device)
    # An independent position per sample, so samples from the same sequence can differ.
    positions = t.randint(0, seq_len, (num_rows,), device=act_device)
    chosen_acts = mlp_activations[rows, positions].float()  # [num_samples, num_neurons]

    # Clamp: an all-zero activation vector would otherwise give 0/0 = NaN.
    norms = t.norm(chosen_acts, dim=-1, keepdim=True).clamp_min(1e-8)
    return chosen_acts / norms


def _zero_decoder_moments_(
    optimizer_state: dict, feature_dead: t.Tensor, decoder_weight: t.Tensor
) -> None:
    """Zero the Adam moment estimates of the decoder columns that belong to dead features."""
    param_state = optimizer_state.get(decoder_weight)
    if not param_state:
        return
    for key in ("exp_avg", "exp_avg_sq"):
        if key in param_state:
            param_state[key][:, feature_dead] = 0.0


def apply_fancy_resampling_(
    supermodel: FrozenTransformerAutoencoderSuperModel,
    train_dataloader: DataLoader,
    ae_config: AutoEncoderConfig,
    device: str,
    optimizer: Optional[t.optim.Optimizer],
    activations_tracker: t.Tensor,
) -> None:
    """Resample dead features from high-loss input activations, in place.

    For every dead feature:
    - Its decoder column is set to a unit-norm MLP activation vector. Vectors are sampled
      with probability weighted by the squared loss of the static model.
    - Its encoder row is set to that same vector, rescaled by
      ``encoder_vector_renormalisation`` (vanilla autoencoder only).
    - Its encoder bias is zeroed.
    - If an optimizer is given, its Adam state for the modified encoder, bias and decoder
      entries is cleared.

    Alive features are left untouched. This is a no-op when no feature is dead.
    """
    with t.no_grad():
        #### Step 1: Identify dead features
        _, feature_dead = get_resample_mask(activations_tracker, ae_config.resampling_eps)
        feature_dead = feature_dead.bool()  # [num_features]
        num_dead = int(feature_dead.sum().item())
        if num_dead == 0:
            return

        #### Steps 2-4: Only dead features need new vectors, so only sample that many
        unit_vectors = _sample_unit_input_vectors(
            supermodel,
            train_dataloader,
            ae_config,
            device,
            num_samples=num_dead,
        )  # [num_dead, num_neurons]

        autoencoder: AutoEncoderBase = supermodel.autoencoder

        # Decoder weight is [num_neurons, num_features]; overwrite the dead columns in place.
        decoder_weight = autoencoder.decoder.weight.data
        decoder_weight[:, feature_dead] = unit_vectors.T.to(
            device=decoder_weight.device, dtype=decoder_weight.dtype
        )

        #### Step 5: Set each dead encoder vector to its sampled input vector, renormalised
        # to 0.2 x the average norm of the alive encoder vectors.
        if supermodel.autoencoder_type.vanilla:
            assert isinstance(autoencoder, VanillaAE)
            encoder_weight = autoencoder.encoder.weight.data  # [num_features, num_neurons]
            candidate = encoder_weight.clone()
            candidate[feature_dead] = unit_vectors.to(
                device=candidate.device, dtype=candidate.dtype
            )
            encoder_weight.copy_(
                encoder_vector_renormalisation(
                    ae_config=ae_config,
                    feature_dead=feature_dead,
                    encoder_weights=candidate,
                )
            )

        # Set bias for dead features to 0
        zero_out_bias_for_dead_features_(supermodel, feature_dead)

        #### Step 6: Reset Adam state for every modified weight and bias entry
        if optimizer is not None:
            reset_optim_state_for_dead_features_(optimizer.state, feature_dead, supermodel)
            _zero_decoder_moments_(optimizer.state, feature_dead, autoencoder.decoder.weight)


def zero_out_bias_for_dead_features_(
    supermodel: FrozenTransformerAutoencoderSuperModel,
    feature_dead: t.Tensor,
) -> None:
    """Set the encoder bias of every dead feature to zero, in place.

    Biases of alive features are left unchanged.

    Args:
        supermodel: Supermodel whose (vanilla) autoencoder owns ``feature_bias_F``.
        feature_dead: Mask of shape [num_features], True for dead features.

    Raises:
        ValueError: if the autoencoder type is not vanilla.
    """
    if supermodel.autoencoder_type.vanilla:
        assert isinstance(supermodel.autoencoder, VanillaAE)
        with t.no_grad():
            # masked_fill_ writes 0 only where the mask is True (dead features);
            # multiplying by the dead mask would instead zero the alive biases.
            supermodel.autoencoder.feature_bias_F.masked_fill_(feature_dead.bool(), 0.0)

    else:
        raise ValueError(f"Invalid autoencoder type {supermodel.autoencoder_type}")


def encoder_vector_renormalisation(
    ae_config: AutoEncoderConfig,
    feature_dead: t.Tensor,
    encoder_weights: Float[t.Tensor, "num_features num_neurons"],
    scale: float = 0.2,
    eps: float = 1e-8,
) -> t.Tensor:
    """Rescale the encoder rows of dead features to a fraction of the mean alive-row norm.

    Each row of ``encoder_weights`` is one feature's encoder vector. Each dead row is
    rescaled to have L2 norm ``scale * mean(||w_i||)``, where the mean runs over the rows
    of alive features. Alive rows are returned unchanged.

    Args:
        ae_config: Not used by the computation; kept so existing call sites keep working.
        feature_dead: Mask of shape [num_features], True for dead features.
        encoder_weights: Encoder weight matrix of shape [num_features, num_neurons].
        scale: Fraction of the mean alive-row norm that dead rows are rescaled to.
        eps: Lower bound on row norms, so an all-zero dead row yields zeros instead of NaN.

    Returns:
        A new tensor of shape [num_features, num_neurons].
    """
    dead = feature_dead.bool()
    row_norms = t.norm(encoder_weights, dim=1, keepdim=True)  # [num_features, 1]

    alive_norms = row_norms[~dead]  # [num_alive, 1]
    # With no alive features there is nothing to average over; fall back to all rows
    # instead of dividing by zero.
    reference_norms = alive_norms if alive_norms.numel() > 0 else row_norms
    target_norm = reference_norms.mean() * scale  # scalar

    rescaled = encoder_weights / row_norms.clamp_min(eps) * target_norm
    # t.where (not mask arithmetic) so a NaN in a dead row cannot leak into alive rows.
    return t.where(dead.unsqueeze(1), rescaled, encoder_weights)


def apply_resampling_(
    supermodel: FrozenTransformerAutoencoderSuperModel,
    train_dataloader: DataLoader,
    ae_config: AutoEncoderConfig,
    resampling_type: ResamplingType,
    device: str,
    optimizer: Optional[t.optim.Optimizer],
    features_activated_tracker: t.Tensor,
    router_activations_tracker: Optional[t.Tensor],
    num_dead_features: int,
    num_dead_experts: int,
) -> None:
    """Resample dead features (if any), then reset the activation trackers in place.

    When ``num_dead_features > 0``, features are resampled with the strategy selected by
    ``resampling_type`` (FANCY or NAIVE). Router/expert resampling is not implemented
    here; ``num_dead_experts`` only affects logging and which trackers are reset.

    Trackers are cleared with ``zero_()`` so the caller's tensors start fresh. Rebinding
    the parameter names would not be visible to the caller.
    - When nothing is dead, both trackers are cleared.
    - Otherwise only ``features_activated_tracker`` is cleared.
    """
    if num_dead_features == 0 and num_dead_experts == 0:
        logger.info("No feature or router resampling required")

        # Reset activations trackers
        features_activated_tracker.zero_()
        if router_activations_tracker is not None:
            router_activations_tracker.zero_()
        return

    if num_dead_features > 0:
        if resampling_type == ResamplingType.FANCY:
            apply_fancy_resampling_(
                supermodel=supermodel,
                train_dataloader=train_dataloader,
                ae_config=ae_config,
                device=device,
                optimizer=optimizer,
                activations_tracker=features_activated_tracker,
            )
            logger.info("Completed fancy resampling")
        elif resampling_type == ResamplingType.NAIVE:
            apply_naive_resampling_(
                supermodel=supermodel,
                ae_config=ae_config,
                device=device,
                optimizer=optimizer,
                activations_tracker=features_activated_tracker,
            )
            logger.info("Completed naive resampling")
        else:
            logger.warning(
                f"Unsupported resampling type {resampling_type}; no features were resampled"
            )
    else:
        # Only experts are dead here; this function does not resample routers.
        logger.info("No feature resampling required")

    # Reset activations tracker
    features_activated_tracker.zero_()
