from typing import Optional, Union

import torch as t
from einops import rearrange
from loguru import logger
from torch import nn
from torch.nn import functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

from auto_encoder import device
from auto_encoder.config import AutoEncoderConfig
from auto_encoder.helpers.ae_output_types import AutoEncoderOutput, SupermodelOutput
from auto_encoder.helpers.get_autoencoder import create_autoencoder
from auto_encoder.models.base_ae import AutoEncoderBase


class FrozenTransformerAutoencoderSuperModel(nn.Module):
    """Frozen causal LM paired with a trainable autoencoder on its hidden states.

    The transformer's parameters are frozen and it is kept in eval mode; only
    the autoencoder is meant to be trained. Hidden states are read from the
    transformer's ``hidden_states`` output at index ``layer_to_hook`` and fed
    to the autoencoder.

    The transformer must expose its blocks as ``transformer.transformer.h``
    (this attribute path is used both to count layers and to register the
    ablation hooks).
    """

    def __init__(
        self,
        transformer: PreTrainedModel,
        device: str,
        medoid_initial_tensor_N: Optional[t.Tensor],
        expert_initial_tensors: Optional[t.Tensor],
        scaling_factor: Optional[float],
        ae_config: AutoEncoderConfig,
        autoencoder: Optional[AutoEncoderBase] = None,
        layer_to_hook: Optional[int] = None,
    ):
        """Freeze ``transformer`` and attach (or build) the autoencoder.

        Args:
            transformer: Language model to freeze; moved to ``device``.
            device: Device both sub-models are moved to.
            medoid_initial_tensor_N: Forwarded to ``create_autoencoder``; only
                used when ``autoencoder`` is None.
            expert_initial_tensors: Forwarded to ``create_autoencoder``; only
                used when ``autoencoder`` is None.
            scaling_factor: Forwarded to ``create_autoencoder``; only used
                when ``autoencoder`` is None.
            ae_config: Autoencoder configuration; also the source of every
                loss coefficient mirrored onto this module.
            autoencoder: Pre-built autoencoder. When None, one is created from
                ``ae_config``.
            layer_to_hook: Index into the transformer's ``hidden_states``.
                Defaults to the middle layer (``num_layers // 2``).
        """
        super().__init__()
        self.transformer = transformer.to(device)  # type: ignore
        self.transformer.eval()

        # Default to the middle of the network when no layer is specified.
        num_transformer_layers = len(self.transformer.transformer.h)
        self.layer_to_hook = (
            layer_to_hook if layer_to_hook is not None else num_transformer_layers // 2
        )

        # Freeze the transformer
        for param in self.transformer.parameters():
            param.requires_grad = False

        self.autoencoder_type = ae_config.autoencoder_type

        # Use the supplied autoencoder, or build one from the config
        if autoencoder is not None:
            self.autoencoder = autoencoder.to(device)
        else:
            self.autoencoder = create_autoencoder(
                config=ae_config,
                autoencoder_type=ae_config.autoencoder_type,
                medoid_initial_tensor_N=medoid_initial_tensor_N,
                expert_initial_tensors=expert_initial_tensors,
                scaling_factor=scaling_factor,
            ).to(device)

        logger.info(f"Autoencoder: {self.autoencoder}")

        self.ae_config = ae_config

        # Loss coefficients mirrored from the config.
        self.load_balancing_loss_coef = ae_config.auxiliary_balancing_loss_coef
        self.router_z_loss_coef = ae_config.router_z_loss_coef
        self.expert_importance_loss_coef = ae_config.expert_importance_loss_coef

        self.auxiliary_gating_recon_loss_coef = ae_config.auxiliary_gating_recon_coef
        self.auxiliary_l0_sparsity_coef = ae_config.auxiliary_l0_sparsity_coef
        self.auxiliary_l1_sparsity_coef = ae_config.auxiliary_l1_sparsity_coef
        self.auxiliary_l2_sparsity_coef = ae_config.auxiliary_l2_sparsity_coef

        self.auxiliary_nfm_loss_coef = ae_config.auxiliary_nfm_loss_coef
        self.auxiliary_nfm_inf_loss_coef = ae_config.auxiliary_nfm_inf_loss_coef
        self.auxiliary_multi_info_loss_coef = ae_config.auxiliary_multi_info_loss_coef
        self.auxiliary_hessian_loss_coef = ae_config.auxiliary_hessian_loss_coef

        self.auxiliary_codebook_loss_coef = ae_config.auxiliary_codebook_loss_coef
        self.auxiliary_commitment_loss_coef = ae_config.auxiliary_commitment_loss_coef

        self.feature_reconstruction_loss_coef = ae_config.feature_reconstruction_loss_coef

    def train(self, mode: bool = True):
        """Set the training mode, keeping the frozen transformer in eval mode.

        ``nn.Module.train`` recurses into every child, which would otherwise
        switch the frozen transformer back to training mode (e.g. re-enabling
        dropout) and undo the ``eval()`` call made in ``__init__``.

        Args:
            mode: True for training mode, False for evaluation mode.

        Returns:
            This module.
        """
        super().train(mode)
        self.transformer.eval()
        return self

    def stash_mlp_acts_hook(self, module, input, output):
        """Forward hook that stores ``output`` on ``self.mlp_neuron_activations``.

        Not registered anywhere in this class; callers must attach it.
        """
        self.mlp_neuron_activations = output

    def stash_pre_mlp_hook(self, module, input, output):
        """Forward hook that stores ``input[0]`` on ``self.pre_mlp_activations``.

        Not registered anywhere in this class; callers must attach it.
        """
        self.pre_mlp_activations = input[0]

    def mse(self, x: t.Tensor, y: t.Tensor) -> t.Tensor:
        """Return the mean squared difference of ``x`` and ``y`` over the last dim."""
        return t.mean((x - y) ** 2, dim=-1)

    def l1(self, x: t.Tensor) -> t.Tensor:
        """Return the mean absolute value of ``x`` over the last dim."""
        return t.mean(t.abs(x), dim=-1)

    def _next_token_loss(self, logits_BSV: t.Tensor, input_ids_BS: t.Tensor) -> t.Tensor:
        """Mean cross-entropy of each position's logits against the next token.

        The final position's logits are dropped (there is no next token) and
        the first token is dropped from the labels, so both have ``S - 1``
        positions per sequence.
        """
        shifted_logits_BsV = rearrange(logits_BSV[:, :-1, :], "b s v -> (b s) v")
        shifted_labels_Bs = rearrange(input_ids_BS[:, 1:], "b s -> (b s)")
        return F.cross_entropy(shifted_logits_BsV, shifted_labels_Bs, reduction="mean")

    def get_neuron_activations(self, input_ids: t.Tensor) -> t.Tensor:
        """Return the transformer's ``hidden_states[layer_to_hook]`` (no grad)."""
        with t.no_grad():
            output: CausalLMOutputWithCrossAttentions = self.transformer(
                input_ids, output_hidden_states=True
            )
            hooked_residual_stream_activations = output.hidden_states[self.layer_to_hook]  # type: ignore
            return hooked_residual_stream_activations

    def get_feature_activations(self, input_ids: t.Tensor) -> t.Tensor:
        """Return the autoencoder's ``features_BSF`` for ``input_ids``."""
        neuron_activations_BSN = self.get_neuron_activations(input_ids)
        autoencoder_output: AutoEncoderOutput = self.autoencoder(
            neuron_activations_BSN, output_intermediate_activations=True
        )

        return autoencoder_output.features_BSF

    def forward(
        self,
        input_ids_BS: t.Tensor,
        output_eval_metrics: bool = False,
    ) -> SupermodelOutput:
        """Run the frozen transformer, then the autoencoder on its hidden states.

        Args:
            input_ids_BS: Token ids fed to the transformer.
            output_eval_metrics: When True, additionally compute the downstream
                loss recovered (two extra transformer forward passes) and store
                it, clamped to ``[0, 1]``, on the returned output's metrics.

        Returns:
            The autoencoder's ``SupermodelOutput``.
        """
        # The transformer is frozen, so nothing here needs an autograd graph.
        with t.no_grad():
            output: CausalLMOutputWithCrossAttentions = self.transformer(
                input_ids_BS, output_hidden_states=True
            )
            clean_loss = self._next_token_loss(output.logits, input_ids_BS)
            initial_neuron_activations_BSN = output.hidden_states[self.layer_to_hook]  # type: ignore

        # Run the activations through the autoencoder (gradients flow here)
        supermodel_output: SupermodelOutput = self.autoencoder(
            initial_neuron_activations_BSN, output_supermodel_output=True
        )
        assert isinstance(supermodel_output, SupermodelOutput)

        if output_eval_metrics:
            downstream_loss_recovered = self._compare_ablated_losses_resid_stream(
                input_ids_BS,
                clean_loss,
                supermodel_output.feature_activations_BSF,
                supermodel_output.normalising_constants_BS,
            )

            supermodel_output.metrics.downstream_loss_recovered = (
                downstream_loss_recovered.clamp(0, 1).item()
            )

        logger.debug(supermodel_output.metrics)

        return supermodel_output

    def autoencoder_loss(
        self, initial_mlp_neuron_activations: t.Tensor, dtype: t.dtype
    ) -> SupermodelOutput:
        """Run the autoencoder alone on precomputed activations.

        Note that this casts ``self.autoencoder`` to ``dtype`` in place.

        Args:
            initial_mlp_neuron_activations: Activations to feed the autoencoder.
            dtype: Dtype both the autoencoder and the activations are cast to.

        Returns:
            The autoencoder's ``SupermodelOutput``. If it did not report an
            ``nfm_loss``, both NFM metrics are filled in from
            ``self.autoencoder.nfm_losses()``.
        """
        self.autoencoder.to(dtype=dtype)

        initial_mlp_neuron_activations = initial_mlp_neuron_activations.to(dtype=dtype)

        supermodel_output: SupermodelOutput = self.autoencoder(
            initial_mlp_neuron_activations, output_supermodel_output=True
        )

        if supermodel_output.metrics.nfm_loss is None:
            nfm_loss, nfm_inf_loss = self.autoencoder.nfm_losses()
            supermodel_output.metrics.nfm_loss = nfm_loss.item()
            supermodel_output.metrics.nfm_inf_loss = nfm_inf_loss.item()

        return supermodel_output

    def save_autoencoder(self, path: str, other_details: str = "") -> None:
        """Delegate to ``self.autoencoder.save(path, other_details=...)``."""
        self.autoencoder.save(path, other_details=other_details)

    def _compare_ablated_losses_mlp(
        self,
        input_ids: t.Tensor,
        clean_loss: t.Tensor,
        pre_mlp_activations: t.Tensor,
        feature_activations: t.Tensor,
        normalising_constants: t.Tensor,
    ) -> t.Tensor:
        """Downstream loss recovered, with whole-MLP ablation as the baseline.

        NOTE(review): this relies on ``self.transformer`` exposing
        ``unembedding``, ``mlp`` and ``final_layer_norm``, none of which are
        used anywhere else in this class (the rest goes through
        ``transformer.transformer.h``). It looks like a legacy path for a
        different model type - confirm before relying on it.

        Returns:
            ``1 - (clean - codebook_ablated) / (clean - mlp_ablated)``. The
            result is non-finite when ``clean_loss`` equals the MLP-ablated loss.
        """
        # Baseline: unembed the pre-MLP activations directly (MLP fully ablated).
        mlp_ablated_logits = self.transformer.unembedding(pre_mlp_activations)

        # Swap out the MLP activations for the reconstructed ones using the autoencoder
        raw_codebook_mlp_activations = self.autoencoder.decode(
            feature_activations
        )  # raw_x_hat
        codebook_mlp_activations = self.autoencoder.postprocess(
            raw_codebook_mlp_activations, normalising_constants
        )  # x_hat
        codebook_post_mlp_activations = self.transformer.mlp[2](codebook_mlp_activations)  # y
        post_resid_connection = codebook_post_mlp_activations + pre_mlp_activations  # x
        post_resid_connection = self.transformer.final_layer_norm(post_resid_connection)  # x
        codebook_ablated_logits = self.transformer.unembedding(
            post_resid_connection
        )  # batch_size, seq_len, vocab_size

        mlp_ablated_loss = self._next_token_loss(mlp_ablated_logits, input_ids)
        codebook_ablated_loss = self._next_token_loss(codebook_ablated_logits, input_ids)

        # Downstream loss recovered is the proportion of the downstream loss that is recovered with the codebook
        # using the loss recovered when the whole MLP is ablated as a baseline
        downstream_loss_recovered = 1 - (clean_loss - codebook_ablated_loss) / (
            clean_loss - mlp_ablated_loss
        )

        return downstream_loss_recovered

    def _hooked_next_token_loss(self, input_ids_BS: t.Tensor, hook) -> t.Tensor:
        """Next-token loss with ``hook`` temporarily attached to the hooked block.

        ``hook`` is a forward hook ``(module, input, output) -> output`` and is
        registered on ``transformer.transformer.h[layer_to_hook]`` for exactly
        one forward pass.
        """
        handle = self.transformer.transformer.h[self.layer_to_hook].register_forward_hook(
            hook
        )
        try:
            ablated_output: CausalLMOutputWithCrossAttentions = self.transformer(input_ids_BS)
        finally:
            # Always detach: a leaked hook would silently corrupt every later
            # forward pass through the transformer.
            handle.remove()

        return self._next_token_loss(ablated_output.logits, input_ids_BS)

    def _compare_ablated_losses_resid_stream(
        self,
        input_ids_BS: t.Tensor,
        clean_loss: t.Tensor,
        feature_activations_BSF: t.Tensor,
        normalising_constants_BS: t.Tensor,
    ) -> t.Tensor:
        """Downstream loss recovered, with zero-ablation as the baseline.

        Runs the transformer twice with the first output of block
        ``h[layer_to_hook]`` overwritten: once with the autoencoder's
        reconstruction and once with zeros. This is an evaluation-only metric,
        so everything runs without gradient tracking.

        NOTE(review): ``forward`` feeds the autoencoder
        ``hidden_states[layer_to_hook]``, whereas the hooks here overwrite the
        *output* of block ``h[layer_to_hook]``. For Hugging Face GPT-2 style
        models ``hidden_states[i]`` is the input to block ``i``, so the two may
        be off by one layer - confirm which is intended.

        Args:
            input_ids_BS: Token ids fed to the transformer.
            clean_loss: Next-token loss of the unmodified transformer.
            feature_activations_BSF: Autoencoder features to decode.
            normalising_constants_BS: Passed to ``autoencoder.postprocess``.

        Returns:
            ``1 - (clean - codebook_ablated) / (clean - zero_ablated)``; 1 means
            the reconstruction matches the clean loss, 0 means it is no better
            than zeroing the activations. The result is non-finite when
            ``clean_loss`` equals the zero-ablated loss.
        """
        with t.no_grad():
            # Swap out the neuron activations for the reconstructed ones using the autoencoder
            raw_codebook_neuron_activations_BSN = self.autoencoder.decode(
                feature_activations_BSF
            )  # raw_x_hat
            codebook_neuron_activations_BSN = self.autoencoder.postprocess(
                raw_codebook_neuron_activations_BSN, normalising_constants_BS
            )  # x_hat

            def recon_acts_hook(module, input, output):
                # Replace only the first element of the block's output tuple.
                return (codebook_neuron_activations_BSN,) + output[1:]

            def zero_ablated_hook(module, input, output):
                return (t.zeros_like(output[0]),) + output[1:]

            codebook_ablated_loss = self._hooked_next_token_loss(input_ids_BS, recon_acts_hook)
            zero_ablated_loss = self._hooked_next_token_loss(input_ids_BS, zero_ablated_hook)

        # Downstream loss recovered is the proportion of the downstream loss that is
        # recovered with the reconstruction, using zero-ablation of the hooked
        # block's output as the baseline.
        downstream_loss_recovered = 1 - (clean_loss - codebook_ablated_loss) / (
            clean_loss - zero_ablated_loss
        )

        return downstream_loss_recovered

    def _l0_sparsity(self, feature_activations_BSF: t.Tensor) -> t.Tensor:
        """Return the average number of non-zero features per token.

        The input must be 3-D; the count of non-zero entries is divided by the
        product of its first two dimensions.
        """
        batch_size, seq_len, _num_features = feature_activations_BSF.shape

        num_active_features = t.count_nonzero(feature_activations_BSF)
        active_features_per_token = num_active_features / (batch_size * seq_len)
        return active_features_per_token