from typing import Optional, Union

import torch as t
from loguru import logger
from torch import nn
from torch.nn import functional as F
from torch.optim import Adam

from auto_encoder import device
from auto_encoder.config import AutoEncoderConfig
from auto_encoder.helpers.ae_metrics import AutoEncoderMetrics
from auto_encoder.helpers.ae_output_types import MutualChoiceAEOutput, SupermodelOutput
from auto_encoder.models.base_ae import TopKOutput
from auto_encoder.models.vanilla_ae import VanillaAE


class TopKSAE(VanillaAE):
    """Sparse autoencoder with a GELU + top-k activation, built on ``VanillaAE``.

    On top of whatever the parent class sets up, this class adds:

    * a bias-free linear encoder (``num_neurons -> num_features``) whose
      weights start as the transposed decoder weights scaled by 0.1;
    * a learnable per-feature bias added to the encoder output;
    * a ``topk`` value clamped so it never exceeds ``num_features``.

    Preprocessing, decoding, postprocessing, the top-k selection itself and
    the auxiliary losses are inherited helpers that are not defined in this
    class.
    """

    def __init__(
        self,
        config: AutoEncoderConfig,
        medoid_initial_tensor_N: Optional[t.Tensor],
        preprocess_scaling_factor: Optional[float],
        device: str = device,
    ):
        """Build the encoder and feature bias on top of the parent's setup.

        Args:
            config: Autoencoder configuration. ``num_neurons`` and
                ``num_features`` are read here to size the encoder and bias.
            medoid_initial_tensor_N: Forwarded unchanged to ``VanillaAE``.
            preprocess_scaling_factor: Forwarded unchanged to ``VanillaAE``.
            device: Forwarded unchanged to ``VanillaAE``; defaults to the
                package-level ``auto_encoder.device``.
        """
        super().__init__(config, medoid_initial_tensor_N, preprocess_scaling_factor, device)

        ## ENCODER
        self.encoder = nn.Linear(config.num_neurons, config.num_features, bias=False)
        # Initialise the encoder weights as decoder.T scaled by 0.1.
        # `self.decoder` is expected to have been created by the parent constructor.
        decoder_weight = self.decoder.weight.data.clone()
        self.encoder.weight = nn.Parameter(decoder_weight.T * 0.1)

        self.feature_bias_F = nn.Parameter(t.zeros(config.num_features))

        ## ACTIVATION FUNCTION
        # `self.topk` is expected to be set by the parent constructor; it cannot
        # meaningfully exceed the number of features, so clamp it.
        self.topk = min(self.topk, config.num_features)
        # Log the effective (clamped) value rather than the raw config value, so
        # the message matches what the model actually uses.
        logger.info(f"Using TopK activation function with topk={self.topk}")

    def encode(self, x: t.Tensor) -> t.Tensor:
        """Return pre-activation features: ``encoder(x) + feature_bias_F``."""
        return self.encoder(x) + self.feature_bias_F

    def activation_fn(self, pre_activation_features_BSF: t.Tensor) -> TopKOutput:
        """Apply GELU, then the inherited top-k selection.

        Args:
            pre_activation_features_BSF: Output of :meth:`encode`.

        Returns:
            The ``TopKOutput`` produced by ``self.topk_maybe_stochastic``;
            :meth:`forward` reads its ``features_BSF`` attribute.
        """
        pre_activation_features_BSF = F.gelu(pre_activation_features_BSF)

        topk_output = self.topk_maybe_stochastic(pre_activation_features_BSF)

        return topk_output

    def forward(
        self,
        x_BSN: t.Tensor,
        output_intermediate_activations: bool = False,
        output_supermodel_output: bool = False,
    ) -> Union[MutualChoiceAEOutput, SupermodelOutput, t.Tensor]:
        """Reconstruct ``x_BSN`` and optionally return losses and metrics.

        Args:
            x_BSN: Input activations; the last dimension must be
                ``num_neurons`` to match the encoder.
            output_intermediate_activations: If True (and
                ``output_supermodel_output`` is False), return a
                ``MutualChoiceAEOutput`` carrying the reconstruction, the
                features and the *unweighted* auxiliary losses.
            output_supermodel_output: If True, return a ``SupermodelOutput``
                with the combined scalar loss and scalar metrics. Takes
                precedence over ``output_intermediate_activations``.

        Returns:
            * The reconstruction tensor only, when both flags are False. In
              this mode feature-activity tracking is NOT updated and no losses
              are computed.
            * ``MutualChoiceAEOutput`` when only
              ``output_intermediate_activations`` is set.
            * ``SupermodelOutput`` when ``output_supermodel_output`` is set.
        """

        x_center_BSN, normalising_constants_BS = self.preprocess(x_BSN)

        pre_activation_features_BSF = self.encode(x_center_BSN)

        top_k_output = self.activation_fn(pre_activation_features_BSF)
        features_BSF = top_k_output.features_BSF

        raw_x_hat_BSN = self.decode(features_BSF)

        x_hat_BSN = self.postprocess(raw_x_hat_BSN, normalising_constants_BS)

        # Fast path: plain reconstruction, no bookkeeping and no loss computation.
        if not output_intermediate_activations and not output_supermodel_output:
            return x_hat_BSN

        ### UPDATE FEATURE ACTIVITY
        self.update_feature_activity_(features_BSF)

        ### COMPUTE LOSSES
        l1_sparsity_loss = self.l1_loss(features_BSF)

        aux_dead_reconstruction_loss, aux_dying_reconstruction_loss = self.aux_k_losses(
            x_BSN, x_hat_BSN, pre_activation_features_BSF
        )

        # This model computes no NFM losses; zeros fill the corresponding output fields.
        zero_tensor = t.tensor(0, dtype=x_BSN.dtype, device=x_BSN.device)
        nfm_loss, nfm_inf_loss = zero_tensor, zero_tensor.clone()

        # Built before the coefficient scaling below, so it holds unweighted losses.
        ae_output = MutualChoiceAEOutput(
            x_hat_BSN=x_hat_BSN,
            features_BSF=features_BSF,
            normalising_constants_BS=normalising_constants_BS,
            l1_sparsity_loss=l1_sparsity_loss,
            aux_dead_loss=aux_dead_reconstruction_loss,
            aux_dying_loss=aux_dying_reconstruction_loss,
            nfm_loss=nfm_loss,
            nfm_inf_loss=nfm_inf_loss,
        )

        if not output_supermodel_output:
            return ae_output

        mse_reconstruction_loss = F.mse_loss(x_hat_BSN, x_BSN)
        l0_sparsity_metric = self._l0_sparsity(features_BSF)

        # From here on the auxiliary losses are weighted by their config
        # coefficients; the metrics below therefore report the weighted values.
        l1_sparsity_loss = l1_sparsity_loss * self.config.auxiliary_l1_sparsity_coef
        aux_dead_reconstruction_loss = (
            aux_dead_reconstruction_loss * self.config.auxiliary_dead_loss_coef
        )
        aux_dying_reconstruction_loss = (
            aux_dying_reconstruction_loss * self.config.auxiliary_dying_loss_coef
        )

        overall_loss = (
            mse_reconstruction_loss
            + l1_sparsity_loss
            + aux_dead_reconstruction_loss
            + aux_dying_reconstruction_loss
        )

        # `.item()` detaches each scalar into a plain Python number for logging.
        metrics = AutoEncoderMetrics(
            overall_loss=overall_loss.item(),
            mse_reconstruction_loss=mse_reconstruction_loss.item(),
            l0_sparsity_metric=l0_sparsity_metric.item(),
            l1_sparsity_loss=l1_sparsity_loss.item(),
            nfm_loss=nfm_loss.item(),
            nfm_inf_loss=nfm_inf_loss.item(),
            aux_dead_reconstruction_loss=aux_dead_reconstruction_loss.item(),
            aux_dying_reconstruction_loss=aux_dying_reconstruction_loss.item(),
        )

        # `scalar_loss` stays a tensor so the caller can backpropagate through it.
        return SupermodelOutput(
            scalar_loss=overall_loss,
            feature_activations_BSF=features_BSF,
            metrics=metrics,
            initial_neuron_activations_BSN=x_BSN,
            reconstructed_neuron_activations_BSN=x_hat_BSN,
            normalising_constants_BS=normalising_constants_BS,
            batched_loss_BS=None,
        )


if __name__ == "__main__":
    # Smoke test: build a tiny TopKSAE and overfit it to one random batch using
    # plain MSE on the reconstruction-only forward path.
    #
    # The device is the package-level default imported from `auto_encoder`. To
    # target a specific GPU, rebind `device` (e.g. f"cuda:{gpu_index}") before
    # the model is built.
    BATCH_SIZE, SEQ_LEN = 3, 10
    NUM_NEURONS, NUM_FEATURES = 15, 30

    smoke_config = AutoEncoderConfig(
        num_neurons=NUM_NEURONS,
        num_features=NUM_FEATURES,
        batch_size=BATCH_SIZE,
        seq_len=SEQ_LEN,
    )
    model = TopKSAE(
        smoke_config,
        medoid_initial_tensor_N=None,
        preprocess_scaling_factor=None,
        device=device,
    ).to(device)

    # Sampled on CPU after the model is constructed, then moved to the device.
    x = t.randn(BATCH_SIZE, SEQ_LEN, NUM_NEURONS).to(device)

    # Show the first input sample next to its initial reconstruction.
    print(x[0])
    print(model(x)[0])

    optimizer = Adam(model.parameters(), lr=1e-3)

    # Placeholder so `loss` is bound even before the first training step.
    loss = t.empty((1,))
    for _ in range(100):
        optimizer.zero_grad()

        # With both output flags left False, the model returns only the reconstruction.
        loss = F.mse_loss(model(x), x)

        loss.backward()
        optimizer.step()
        print("Loss", loss.item())

    print("Final loss", loss)
