from typing import Optional, Union

import torch as t
from torch import nn
from torch.nn import functional as F
from torch.optim import Adam

from auto_encoder import device
from auto_encoder.config import AutoEncoderConfig
from auto_encoder.helpers.ae_metrics import AutoEncoderMetrics
from auto_encoder.helpers.ae_output_types import SupermodelOutput, VanillaAEOutput
from auto_encoder.models.base_ae import AutoEncoderBase


class VanillaAE(AutoEncoderBase):
    """Autoencoder with a bias-free linear encoder, a learned feature bias and a ReLU.

    Only the encoder side is defined here. The decoder (``self.decoder`` /
    ``self.decode``), ``preprocess`` / ``postprocess``, the sparsity helpers
    (``l1_loss``, ``_l0_sparsity``), ``update_feature_activity_`` and
    ``self.config`` are not defined in this class and are expected to be
    provided by ``AutoEncoderBase``.
    """

    def __init__(
        self,
        config: AutoEncoderConfig,
        medoid_initial_tensor_N: Optional[t.Tensor],
        preprocess_scaling_factor: Optional[float],
        device: str = device,
    ):
        """Build the encoder on top of the decoder created by the base class.

        Args:
            config: Supplies ``num_neurons`` (encoder input width) and
                ``num_features`` (encoder output width).
            medoid_initial_tensor_N: Forwarded unchanged to ``AutoEncoderBase``.
            preprocess_scaling_factor: Forwarded unchanged to ``AutoEncoderBase``.
            device: Forwarded unchanged to ``AutoEncoderBase``; defaults to the
                package-level ``device``.
        """
        super().__init__(config, medoid_initial_tensor_N, preprocess_scaling_factor, device)

        ## ENCODER
        self.encoder = nn.Linear(config.num_neurons, config.num_features, bias=False)
        # Initialise the encoder weights as a scaled-down (x0.1) transpose of the
        # decoder weights. The clone keeps the two parameters independent, so
        # they are NOT tied after initialisation.
        decoder_weight_NF = self.decoder.weight.data.clone()
        self.encoder.weight = nn.Parameter(decoder_weight_NF.T * 0.1)

        # The encoder bias lives in its own parameter (the Linear layer above is
        # created with bias=False) and starts at zero.
        self.feature_bias_F = nn.Parameter(t.zeros(config.num_features))

        ## ACTIVATION FUNCTION
        self.relu = nn.ReLU()

    def encode(self, x: t.Tensor) -> t.Tensor:
        """Return pre-activation features: ``encoder(x) + feature_bias_F``."""
        return self.encoder(x) + self.feature_bias_F

    def activation_fn(self, pre_activation_features: t.Tensor) -> t.Tensor:
        """Apply the ReLU non-linearity to the pre-activation features."""
        return self.relu(pre_activation_features)

    def forward(
        self,
        x_BSN: t.Tensor,
        output_intermediate_activations: bool = False,
        output_supermodel_output: bool = False,
    ) -> Union[VanillaAEOutput, SupermodelOutput, t.Tensor]:
        """Reconstruct ``x_BSN`` and optionally return activations and losses.

        Args:
            x_BSN: Input activations. The suffix suggests
                (batch, sequence, neurons) - TODO confirm against callers.
            output_intermediate_activations: If True (and
                ``output_supermodel_output`` is False), return a
                ``VanillaAEOutput``.
            output_supermodel_output: If True, return a ``SupermodelOutput``
                with the loss and metrics, regardless of the other flag.

        Returns:
            * The reconstruction tensor when both flags are False.
            * ``VanillaAEOutput`` (reconstruction, features, normalising
              constants) when only ``output_intermediate_activations`` is True.
            * ``SupermodelOutput`` when ``output_supermodel_output`` is True.

        Note:
            ``update_feature_activity_`` is only called when one of the flags
            is set; the plain-tensor path returns before it.
        """
        x_center_BSN, normalising_constants_BS = self.preprocess(x_BSN)

        pre_activation_features_BSF = self.encode(x_center_BSN)
        features_BSF = self.activation_fn(pre_activation_features_BSF)
        raw_x_hat_BSN = self.decode(features_BSF)

        x_hat_BSN = self.postprocess(raw_x_hat_BSN, normalising_constants_BS)

        if not output_intermediate_activations and not output_supermodel_output:
            return x_hat_BSN

        ### UPDATE FEATURE ACTIVITY
        self.update_feature_activity_(features_BSF)

        if not output_supermodel_output:
            # No loss is reported on this path, so the L1 penalty is not
            # computed here (it used to be computed and then discarded).
            # NOTE(review): this assumes `l1_loss` has no side effects - confirm.
            return VanillaAEOutput(
                x_hat_BSN=x_hat_BSN,
                features_BSF=features_BSF,
                normalising_constants_BS=normalising_constants_BS,
            )

        ### COMPUTE LOSSES
        l1_sparsity_loss = self.l1_loss(features_BSF)

        # The reconstruction error is measured against the raw input, not the
        # preprocessed one, since x_hat_BSN has already been post-processed.
        mse_reconstruction_loss = F.mse_loss(x_hat_BSN, x_BSN)
        l0_sparsity_metric = self._l0_sparsity(features_BSF)

        # Weight the sparsity penalty by the configured coefficient before it
        # is added to the overall loss (and before it is reported in metrics).
        l1_sparsity_loss = l1_sparsity_loss * self.config.auxiliary_l1_sparsity_coef

        overall_loss = mse_reconstruction_loss + l1_sparsity_loss

        # `.item()` converts each value to a plain Python number; the tensor
        # `overall_loss` itself is returned separately as `scalar_loss`.
        metrics = AutoEncoderMetrics(
            overall_loss=overall_loss.item(),
            mse_reconstruction_loss=mse_reconstruction_loss.item(),
            l0_sparsity_metric=l0_sparsity_metric.item(),
            l1_sparsity_loss=l1_sparsity_loss.item(),
        )

        return SupermodelOutput(
            scalar_loss=overall_loss,
            feature_activations_BSF=features_BSF,
            metrics=metrics,
            initial_neuron_activations_BSN=x_BSN,
            reconstructed_neuron_activations_BSN=x_hat_BSN,
            normalising_constants_BS=normalising_constants_BS,
            batched_loss_BS=None,
        )


if __name__ == "__main__":
    # Smoke test: build a small autoencoder, push one random batch through it
    # and print the first input entry next to the first entry of the output.
    #
    # NOTE: an argparse `--gpu` flag for picking the CUDA device and a short
    # training loop were previously sketched here as commented-out code; they
    # referenced names that are not defined in this module and were never run.

    demo_config = AutoEncoderConfig(num_neurons=15, num_features=30)
    autoencoder = VanillaAE(
        demo_config,
        medoid_initial_tensor_N=None,
        preprocess_scaling_factor=None,
    )

    # The last dimension (15) matches `num_neurons` in the config above.
    sample_input = t.randn(3, 10, 15)
    sample_input = sample_input.to(device)

    # With the default flags the model returns only the reconstruction tensor.
    reconstruction = autoencoder(sample_input)

    print(sample_input[0])
    print(reconstruction[0])
