from typing import Optional, Union

import torch as t
from einops import rearrange
from loguru import logger
from torch import nn
from torch.nn import functional as F
from torch.optim import Adam

from auto_encoder import device
from auto_encoder.config import AutoEncoderConfig
from auto_encoder.helpers.ae_metrics import AutoEncoderMetrics
from auto_encoder.helpers.ae_output_types import MutualChoiceAEOutput, SupermodelOutput
from auto_encoder.models.vanilla_ae import VanillaAE


class MutualChoiceSAE(VanillaAE):
    """Sparse autoencoder with a TopK-style budget and two sparsification modes.

    Pre-activations are passed through GELU and then sparsified in one of two ways:

    * "feature choice": delegated to ``self.top_m_maybe_stochastic`` (defined
      outside this class);
    * "mutual choice": keep the ``topk * batch * seq_len`` largest activations
      across the whole batch, so tokens share one global budget instead of
      each getting exactly ``topk``.

    ``config.feature_choice_likelihood`` selects the mode. Only 1.0 (feature
    choice) and 0.0 (mutual choice) are currently supported.
    """

    def __init__(
        self,
        config: AutoEncoderConfig,
        medoid_initial_tensor_N: Optional[t.Tensor],
        preprocess_scaling_factor: Optional[float],
        device: str = device,
    ):
        """Add an encoder and feature bias on top of what ``VanillaAE`` builds.

        Args:
            config: Autoencoder configuration. ``num_neurons``, ``num_features``
                and ``feature_choice_likelihood`` are read here; the full config
                is also forwarded to ``VanillaAE``.
            medoid_initial_tensor_N: Forwarded unchanged to ``VanillaAE``.
            preprocess_scaling_factor: Forwarded unchanged to ``VanillaAE``.
            device: Forwarded unchanged to ``VanillaAE``.
        """
        super().__init__(config, medoid_initial_tensor_N, preprocess_scaling_factor, device)

        ## ENCODER
        self.encoder = nn.Linear(config.num_neurons, config.num_features, bias=False)
        # Initialise the encoder as 0.1 x the transpose of the decoder weight, so
        # encoder and decoder start out proportional to each other.
        decoder_weight = self.decoder.weight.data.clone()
        self.encoder.weight = nn.Parameter(decoder_weight.T * 0.1)

        self.feature_bias_F = nn.Parameter(t.zeros(config.num_features))

        ## ACTIVATION FUNCTION
        # ``self.topk`` must already exist here (presumably set by VanillaAE).
        # Clamping keeps ``topk * batch * seq_len`` in _mutual_choice_act_fn from
        # exceeding the number of elements handed to ``t.topk``.
        self.topk = min(self.topk, config.num_features)
        # Log the clamped value that is actually used, not the raw config value.
        logger.info(f"Using TopK activation function with topk={self.topk}")

        self.feature_choice_likelihood = config.feature_choice_likelihood

    def encode(self, x: t.Tensor) -> t.Tensor:
        """Return pre-activation features: ``encoder(x) + feature_bias_F``."""
        return self.encoder(x) + self.feature_bias_F

    def _feature_choice_act_fn(self, pre_activation_features_BSF: t.Tensor) -> t.Tensor:
        """Sparsify via ``self.top_m_maybe_stochastic`` (defined outside this class).

        Only the ``features_BSF`` field of that helper's output is used.
        """
        topm_output = self.top_m_maybe_stochastic(pre_activation_features_BSF)
        return topm_output.features_BSF

    def _mutual_choice_act_fn(self, pre_activation_features_BSF: t.Tensor) -> t.Tensor:
        """Keep the ``topk * batch * seq_len`` largest activations batch-wide.

        All (token, feature) pairs compete for one shared budget. On average each
        token keeps ``topk`` features, but individual tokens may keep more or
        fewer. Entries that are not selected are set to zero.

        Args:
            pre_activation_features_BSF: Rank-3 tensor (batch, seq_len, num_features).

        Returns:
            Tensor of the same shape with only the selected entries non-zero.
        """
        batch_size, seq_len, num_features = pre_activation_features_BSF.shape

        # Flatten everything so a single top-k runs over the whole batch.
        pre_activation_features_Bsf = rearrange(
            pre_activation_features_BSF,
            "batch seq_len num_features -> (batch seq_len num_features)",
        )
        batch_topk_out = t.topk(pre_activation_features_Bsf, self.topk * batch_size * seq_len)
        features_Bsf = t.zeros_like(pre_activation_features_Bsf)
        features_Bsf = features_Bsf.scatter(-1, batch_topk_out.indices, batch_topk_out.values)

        features_BSF = rearrange(
            features_Bsf,
            "(batch seq_len num_features) -> batch seq_len num_features",
            seq_len=seq_len,
            num_features=num_features,
        )

        return features_BSF

    def activation_fn(self, pre_activation_features_BSF: t.Tensor) -> t.Tensor:
        """Apply GELU, then sparsify with feature choice or mutual choice.

        Raises:
            NotImplementedError: If ``feature_choice_likelihood`` is neither 1.0
                nor 0.0. Randomly mixing the two modes is not implemented.
        """
        pre_activation_features_BSF = F.gelu(pre_activation_features_BSF)
        # NOTE(review): GELU can output small negative values. In mutual choice,
        # if fewer entries are positive than the selection budget, some kept
        # activations will be <= 0. Confirm this is acceptable.

        if self.feature_choice_likelihood == 1.0:
            use_feature_choice = True
        elif self.feature_choice_likelihood == 0.0:
            use_feature_choice = False
        else:
            raise NotImplementedError(
                "Feature choice likelihood randomisation not implemented"
            )

        if use_feature_choice:
            return self._feature_choice_act_fn(pre_activation_features_BSF)
        return self._mutual_choice_act_fn(pre_activation_features_BSF)

    def forward(
        self,
        x_BSN: t.Tensor,
        output_intermediate_activations: bool = False,
        output_supermodel_output: bool = False,
    ) -> Union[MutualChoiceAEOutput, SupermodelOutput, t.Tensor]:
        """Run the autoencoder and optionally compute losses and metrics.

        Args:
            x_BSN: Input activations.
            output_intermediate_activations: If True (and ``output_supermodel_output``
                is False), return a ``MutualChoiceAEOutput`` with unscaled losses.
            output_supermodel_output: If True, return a ``SupermodelOutput`` whose
                ``scalar_loss`` combines the coefficient-scaled losses.

        Returns:
            The reconstruction ``x_hat_BSN`` when both flags are False. Otherwise
            a ``MutualChoiceAEOutput`` or ``SupermodelOutput`` as described above.

        Side effects:
            ``update_feature_activity_`` runs only when at least one flag is True.
        """
        ### RUN THROUGH AUTOENCODER
        x_center_BSN, normalising_constants_BS = self.preprocess(x_BSN)

        pre_activation_features_BSF = self.encode(x_center_BSN)

        features_BSF = self.activation_fn(pre_activation_features_BSF)

        raw_x_hat_BSN = self.decode(features_BSF)

        x_hat_BSN = self.postprocess(raw_x_hat_BSN, normalising_constants_BS)

        if not output_intermediate_activations and not output_supermodel_output:
            return x_hat_BSN

        ### UPDATE FEATURE ACTIVITY
        self.update_feature_activity_(features_BSF)

        ### COMPUTE LOSSES
        aux_dead_reconstruction_loss, aux_dying_reconstruction_loss = self.aux_k_losses(
            x_BSN, x_hat_BSN, pre_activation_features_BSF
        )
        l1_sparsity_loss = self.l1_loss(features_BSF)

        nfm_loss, nfm_inf_loss = self.nfm_losses()

        ae_output = MutualChoiceAEOutput(
            x_hat_BSN=x_hat_BSN,
            features_BSF=features_BSF,
            normalising_constants_BS=normalising_constants_BS,
            l1_sparsity_loss=l1_sparsity_loss,
            aux_dead_loss=aux_dead_reconstruction_loss,
            aux_dying_loss=aux_dying_reconstruction_loss,
            nfm_loss=nfm_loss,
            nfm_inf_loss=nfm_inf_loss,
        )

        if not output_supermodel_output:
            return ae_output

        mse_reconstruction_loss = F.mse_loss(x_hat_BSN, x_BSN)
        l0_sparsity_metric = self._l0_sparsity(features_BSF)

        # Scale each auxiliary loss by its configured coefficient.
        l1_sparsity_loss = l1_sparsity_loss * self.config.auxiliary_l1_sparsity_coef
        aux_dead_reconstruction_loss = (
            aux_dead_reconstruction_loss * self.config.auxiliary_dead_loss_coef
        )
        aux_dying_reconstruction_loss = (
            aux_dying_reconstruction_loss * self.config.auxiliary_dying_loss_coef
        )
        nfm_loss = nfm_loss * self.config.auxiliary_nfm_loss_coef
        nfm_inf_loss = nfm_inf_loss * self.config.auxiliary_nfm_inf_loss_coef

        # NOTE(review): the NFM terms are scaled and reported in ``metrics`` but
        # are NOT added to ``overall_loss``, so ``auxiliary_nfm_*_coef`` only
        # affects logging here. Confirm this is intended.
        overall_loss = (
            mse_reconstruction_loss
            + l1_sparsity_loss
            + aux_dead_reconstruction_loss
            + aux_dying_reconstruction_loss
        )

        metrics = AutoEncoderMetrics(
            overall_loss=overall_loss.item(),
            mse_reconstruction_loss=mse_reconstruction_loss.item(),
            l0_sparsity_metric=l0_sparsity_metric.item(),
            l1_sparsity_loss=l1_sparsity_loss.item(),
            aux_dead_reconstruction_loss=aux_dead_reconstruction_loss.item(),
            aux_dying_reconstruction_loss=aux_dying_reconstruction_loss.item(),
            nfm_loss=nfm_loss.item(),
            nfm_inf_loss=nfm_inf_loss.item(),
        )

        return SupermodelOutput(
            scalar_loss=overall_loss,
            feature_activations_BSF=features_BSF,
            metrics=metrics,
            initial_neuron_activations_BSN=x_BSN,
            reconstructed_neuron_activations_BSN=x_hat_BSN,
            normalising_constants_BS=normalising_constants_BS,
            batched_loss_BS=None,
        )

    def reorder_features_(self, features_hit_order_F: t.Tensor) -> None:
        """Permute features in place using the index tensor ``features_hit_order_F``.

        Reorders the encoder weight rows, the feature bias entries and the
        decoder weight columns consistently. Other per-feature state, if any
        exists outside this class, is not touched here.
        """
        self.encoder.weight.data = self.encoder.weight.data[features_hit_order_F]
        self.feature_bias_F.data = self.feature_bias_F.data[features_hit_order_F]
        self.decoder.weight.data = self.decoder.weight.data[:, features_hit_order_F]


if __name__ == "__main__":
    # Smoke test: fit a tiny MutualChoiceSAE to one fixed random batch and print
    # the reconstruction loss after every optimisation step.
    BATCH_SIZE = 3
    SEQ_LEN = 10
    NUM_NEURONS = 15
    NUM_FEATURES = 30
    NUM_STEPS = 100

    model = MutualChoiceSAE(
        AutoEncoderConfig(
            num_neurons=NUM_NEURONS,
            num_features=NUM_FEATURES,
            batch_size=BATCH_SIZE,
            seq_len=SEQ_LEN,
        ),
        medoid_initial_tensor_N=None,
        preprocess_scaling_factor=None,
        device=device,
    ).to(device)
    x = t.randn(BATCH_SIZE, SEQ_LEN, NUM_NEURONS).to(device)

    print(x[0])
    # Inspection-only forward pass; no need to build an autograd graph.
    with t.no_grad():
        print(model(x)[0])

    optimizer = Adam(model.parameters(), lr=1e-3)

    # Explicit float instead of an uninitialised tensor placeholder.
    final_loss = float("nan")
    for _ in range(NUM_STEPS):
        optimizer.zero_grad()
        reconstructed_neuron_acts = model(x)

        loss = F.mse_loss(reconstructed_neuron_acts, x)

        loss.backward()
        optimizer.step()
        final_loss = loss.item()
        print("Loss", final_loss)

    print("Final loss", final_loss)
