import os

import torch
import torch.distributions as dist
import torch.nn as nn


def calc_input_dims(tensor_dataset, model_config, forward_pass_use_groups):
    """Compute the input widths of the encoder and decoder networks.

    Args:
        tensor_dataset: Object exposing ``tensors``, a sequence of tensors
            whose second dimension is the feature width. ``tensors[0]`` is
            the gene-expression matrix. For the conditional models
            ``tensors[1]`` is concatenated as labels and an optional
            ``tensors[2]`` can be added on top (presumably group
            indicators -- confirm against the caller).
        model_config: Object with a ``model`` name and, for the supported
            models, a ``latent_dim`` attribute.
        forward_pass_use_groups: When truthy and a third tensor is present,
            its width is added to the label width.

    Returns:
        Tuple ``(encoder_input_dim, decoder_input_dim)``.

    Raises:
        AssertionError: If a conditional model is given a dataset that does
            not hold 2 or 3 tensors, or if ``model_config.model`` is not a
            supported model name.
    """
    tensors = tensor_dataset.tensors
    gene_expression_dim = tensors[0].shape[1]
    model_name = model_config.model

    # Set up encoder/decoder from possible inputs (X, c, d) and (z, c, d)
    if model_name in ("cvae", "contrastive_cvae", "trvae", "tr_cvamp"):
        # Raised explicitly (rather than via ``assert``) so the check still
        # runs under ``python -O``; the exception type is kept for callers.
        if len(tensors) not in (2, 3):
            raise AssertionError(
                f"{model_name} expects a dataset with 2 or 3 tensors, "
                f"got {len(tensors)}"
            )
        label_dims = tensors[1].shape[1]
        if len(tensors) == 3 and forward_pass_use_groups:
            label_dims += tensors[2].shape[1]
        return gene_expression_dim + label_dims, model_config.latent_dim + label_dims

    if model_name == "vae":
        return gene_expression_dim, model_config.latent_dim

    # ``assert False`` would be stripped under ``-O`` and fall through to an
    # UnboundLocalError; raise the same exception type unconditionally.
    raise AssertionError(f"{model_name} is not handled when creating encoder")


def _make_batchnorm_layer(input_dim: int, use_batchnorm: bool) -> nn.Module:
    """Return a ``BatchNorm1d`` over ``input_dim`` features, or a no-op layer.

    When ``use_batchnorm`` is falsy an ``nn.Identity`` is returned, so callers
    can place the result in a network unconditionally.
    """
    if not use_batchnorm:
        return nn.Identity()
    return nn.BatchNorm1d(input_dim)


class Encoder(nn.Module):
    """Encoder used in all models: a multilayer perceptron with LeakyReLU activations.

    ``forward`` returns two ``Normal`` distributions over the latent space,
    ``(posterior, penalty_dist)``. Both share the same mean; how their scales
    are produced is selected by ``learn_sigma``:

    * ``"fix_all"``: both use the constant scale ``bandwidth`` (the same
      distribution object is returned twice).
    * ``"learn_all"``: both use the learned scale (same object twice).
    * ``"learn_but_decouple"``: the posterior uses the learned scale; the
      penalty distribution uses the same values detached from autograd.
    * ``"learn_elbo_fix_penalty"``: the posterior uses the learned scale; the
      penalty distribution uses the constant scale ``bandwidth``.

    The learned scale is a sigmoid output rescaled into the range
    ``[1e-4, 10]``.

    Args:
        x_dim: Width of the input.
        z_dim: Width of the latent space.
        hidden_dim: Width of the hidden layers (unused when ``n_layers == 0``).
        learn_sigma: One of the mode strings listed above. It is validated
            in ``forward``, not at construction time.
        n_layers: Number of hidden layers; ``0`` gives a purely linear encoder.
        use_batchnorm: Whether hidden layers are followed by ``BatchNorm1d``.
        bandwidth: Constant scale used by the fixed-scale modes.
    """

    # Bounds of the learned standard deviation.
    _MIN_SCALE = 1e-4
    _MAX_SCALE = 10.

    def __init__(self, x_dim, z_dim, hidden_dim, learn_sigma="fix_all", n_layers=1, use_batchnorm=False, bandwidth=0.1):
        super().__init__()
        self.n_layers = n_layers
        # NOTE: layers are created in a fixed order so that seeded
        # initialisation and parameter ordering stay stable.
        if n_layers == 0:
            # Linear Encoder
            self.fc21 = nn.Linear(x_dim, z_dim)  # mean
            self.fc22 = nn.Linear(x_dim, z_dim)  # scale
        else:
            # Non-linear Encoder
            self.fc1 = nn.Linear(x_dim, hidden_dim)
            self.batchnorm1 = _make_batchnorm_layer(hidden_dim, use_batchnorm)
            self.fc21 = nn.Linear(hidden_dim, z_dim)  # mean
            self.fc22 = nn.Linear(hidden_dim, z_dim)  # scale

        # setup the non-linearities
        self.relu = nn.LeakyReLU()
        self.learn_sigma = learn_sigma
        self.sigmoid = nn.Sigmoid()
        self.bandwidth = bandwidth

        # Hidden layers beyond the first; the stateless activation module is
        # shared between them.
        if n_layers > 1:
            self.middle: nn.Module = nn.Sequential(
                *[
                    nn.Sequential(
                        nn.Linear(hidden_dim, hidden_dim),
                        _make_batchnorm_layer(hidden_dim, use_batchnorm),
                        self.relu,
                    )
                    for _ in range(n_layers - 1)
                ]
            )
        else:
            self.middle = nn.Identity()

    def _learned_scale(self, hidden):
        """Return the learned standard deviation, squashed into [1e-4, 10]."""
        span = self._MAX_SCALE - self._MIN_SCALE
        return self._MIN_SCALE + span * self.sigmoid(self.fc22(hidden))

    def _fixed_scale(self, z_loc):
        """Return a constant ``bandwidth`` scale shaped like ``z_loc``.

        ``ones_like`` keeps the dtype and device of ``z_loc`` so the
        resulting ``Normal`` has consistent ``loc``/``scale`` dtypes even
        when the module runs in a non-default precision.
        """
        return self.bandwidth * torch.ones_like(z_loc)

    def forward(self, x):
        """Encode ``x`` and return ``(posterior, penalty_dist)``.

        Raises:
            ValueError: If ``self.learn_sigma`` is not a recognised mode.
        """
        if self.n_layers == 0:
            hidden = x
        else:
            hidden = self.relu(self.batchnorm1(self.fc1(x)))
            hidden = self.middle(hidden)

        z_loc = self.fc21(hidden)
        mode = self.learn_sigma
        if mode == "fix_all":
            posterior = dist.Normal(z_loc, self._fixed_scale(z_loc))
            penalty_dist = posterior
        elif mode == "learn_all":
            posterior = dist.Normal(z_loc, self._learned_scale(hidden))
            penalty_dist = posterior
        elif mode == "learn_but_decouple":
            z_scale = self._learned_scale(hidden)
            posterior = dist.Normal(z_loc, z_scale)
            # Same values, but no gradient reaches the scale head through
            # the penalty distribution.
            penalty_dist = dist.Normal(z_loc, z_scale.detach())
        elif mode == "learn_elbo_fix_penalty":
            posterior = dist.Normal(z_loc, self._learned_scale(hidden))
            penalty_dist = dist.Normal(z_loc, self._fixed_scale(z_loc))
        else:
            raise ValueError(f"learn_sigma option ({self.learn_sigma}) not recognised")
        return posterior, penalty_dist


class MultinomialDecoder(nn.Module):
    """Decoder mapping latent inputs to a ``Multinomial`` over ``x_dim`` outputs.

    The logits are the network output plus a fixed ``baseline`` vector.
    With ``n_layers == 0`` the network is a single linear map; otherwise it
    is an MLP with LeakyReLU activations and optional batch normalisation.

    Args:
        x_dim: Number of output categories.
        z_dim: Width of the decoder input.
        hidden_dim: Width of the hidden layers.
        n_layers: Number of hidden layers; ``0`` gives a linear decoder.
        baseline: Optional tensor added to the logits. Values are clamped
            from below at ``-1e16`` (presumably to keep ``-inf`` entries
            finite -- confirm). Defaults to zeros of length ``x_dim``.
        return_hidden: When truthy, ``forward`` returns a tuple
            ``(distribution, hidden)`` instead of just the distribution.
        use_batchnorm: Whether hidden layers are followed by ``BatchNorm1d``.
    """

    # TODO: Introduce Dropout
    def __init__(
        self,
        x_dim,
        z_dim,
        hidden_dim,
        n_layers=1,
        baseline=None,
        return_hidden=False,
        use_batchnorm=False,
    ):
        super().__init__()
        self.n_layers = n_layers
        self.relu = nn.LeakyReLU()
        self.return_hidden = return_hidden
        # Created even for the linear decoder (where it is unused) so the
        # module layout and state-dict keys stay as they have always been.
        self.batchnorm1 = _make_batchnorm_layer(hidden_dim, use_batchnorm)
        if n_layers == 0:
            # Linear Decoder
            self.fc1 = nn.Linear(z_dim, x_dim)
        else:
            # Non-linear Decoder
            self.fc1 = nn.Linear(z_dim, hidden_dim)
            self.fc2 = nn.Linear(hidden_dim, x_dim)
            if n_layers > 1:
                self.middle = nn.Sequential(
                    *[
                        nn.Sequential(
                            nn.Linear(hidden_dim, hidden_dim),
                            _make_batchnorm_layer(hidden_dim, use_batchnorm),
                            self.relu,
                        )
                        for _ in range(n_layers - 1)
                    ]
                )
            else:
                self.middle = nn.Identity()

        if baseline is not None:
            baseline = baseline.clamp(min=-1e16)
        else:
            baseline = torch.zeros(x_dim)
        # Registered as a buffer so ``.to(device)`` / dtype casts move it with
        # the rest of the module. ``persistent=False`` keeps it out of the
        # state dict, so existing checkpoints still load unchanged.
        self.register_buffer("baseline", baseline, persistent=False)

    def forward(self, z):
        """Decode ``z`` into a ``Multinomial`` distribution.

        Returns:
            The distribution, or ``(distribution, hidden)`` when
            ``return_hidden`` is set. ``hidden`` is the logits before the
            baseline is added for the linear decoder, and the activation of
            the first hidden layer otherwise.
        """
        if self.n_layers == 0:
            # Linear Decoder: the "hidden" output is the logits minus baseline.
            hidden = self.fc1(z)
            logits = hidden + self.baseline
        else:
            # Non-linear Decoder
            hidden = self.relu(self.batchnorm1(self.fc1(z)))
            logits = self.fc2(self.middle(hidden)) + self.baseline

        likelihood = dist.Multinomial(logits=logits, validate_args=False)
        if self.return_hidden:
            # Aaron: Needs updating with multilayer (only the first hidden
            # layer is returned, not the output of ``self.middle``).
            return likelihood, hidden
        return likelihood


class GaussianDecoder(nn.Module):
    """Decoder mapping latent inputs to a ``Normal`` distribution over ``x_dim`` outputs.

    An MLP with LeakyReLU activations feeds two linear heads: one producing
    the mean (width ``x_dim``) and one producing a single value per sample
    that becomes a standard deviation shared across all outputs.
    """

    # TODO: Introduce dropout
    # TODO: Have option for Decoder to be Linear if n_layers = 0
    #       In a similar way to the Encoder and MultinomialDecoder
    def __init__(
        self,
        x_dim,
        z_dim,
        hidden_dim,
        n_layers=1,
        return_hidden=False,
        use_batchnorm=False,
    ):
        super().__init__()
        # First hidden layer.
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.batchnorm1 = _make_batchnorm_layer(hidden_dim, use_batchnorm)
        # Output heads: mean and (scalar) standard deviation.
        self.fc21 = nn.Linear(hidden_dim, x_dim)
        self.fc22 = nn.Linear(hidden_dim, 1)
        self.relu = nn.LeakyReLU()  # To match TrVAE
        self.softplus = nn.Softplus()
        self.return_hidden = return_hidden

        # Any hidden layers beyond the first one.
        extra_blocks = []
        if n_layers > 1:
            for _ in range(n_layers - 1):
                extra_blocks.append(
                    nn.Sequential(
                        nn.Linear(hidden_dim, hidden_dim),
                        _make_batchnorm_layer(hidden_dim, use_batchnorm),
                        self.relu,
                    )
                )
        self.middle = nn.Sequential(*extra_blocks) if extra_blocks else nn.Identity()

    def forward(self, z):
        """Decode ``z``; also return the first hidden activation if ``return_hidden``."""
        first_hidden = self.relu(self.batchnorm1(self.fc1(z)))
        last_hidden = self.middle(first_hidden)
        mean = self.fc21(last_hidden)
        # One positive value per sample, broadcast over every output and
        # offset by 1e-2.
        shared_sd = self.softplus(self.fc22(last_hidden))
        likelihood = dist.Normal(mean, 1e-2 + shared_sd.expand(mean.shape))
        return (likelihood, first_hidden) if self.return_hidden else likelihood
