import torch
from torch import nn

from XXX.uib.utils.safe_module import SafeModule
from experiments.models import stochastic_dropout
from experiments.models import stochastic_model


class NoDropoutModel(SafeModule):
    """Plain three-layer MLP classifier without any dropout.

    Each sample is flattened and must yield 784 features (presumably a 28x28
    image — confirm against the data pipeline). The stack is
    Linear -> ReLU -> Linear -> ReLU -> Linear and produces raw logits; no
    softmax is applied.

    Args:
        C: Number of output logits (classes).
        expansion: Multiplier on the base hidden width of 1024 units.
    """

    C: int  # number of output logits

    def __init__(self, C: int, *, expansion: int = 1):
        super().__init__()

        self.C = C

        hidden = expansion * 1024
        # Built as a list and unpacked so the Sequential indices (and therefore
        # state_dict keys) are 0..4 in this exact order.
        layers = [
            nn.Linear(784, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, C),
        ]
        self.seqs = nn.Sequential(*layers)

    def safe_forward(self, x):
        """Return logits of shape (batch, C); all dims after the first are flattened."""
        return self.seqs(torch.flatten(x, 1))


class DropoutModel(SafeModule):
    """Three-layer MLP classifier with element-wise dropout after each hidden ReLU.

    Each sample is flattened and must yield 784 features (presumably a 28x28
    image — confirm against the data pipeline). The network returns raw logits;
    no softmax is applied.

    Args:
        C: Number of output logits (classes).
        expansion: Multiplier on the base hidden width of 1024 units.
        dropout_rate: Drop probability for both dropout layers (default 0.25,
            the value previously hard-coded).
    """

    C: int  # number of output logits

    def __init__(self, C: int, *, expansion: int = 1, dropout_rate: float = 0.25):
        super().__init__()

        self.C = C

        num_units = 1024 * expansion
        self.seqs = nn.Sequential(
            nn.Linear(784, num_units),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(num_units, num_units),
            nn.ReLU(),
            # Element-wise dropout on the 2-D (batch, features) activations.
            # nn.Dropout2d was used here before; it is meant for (N, C, H, W) /
            # (C, H, W) feature maps, and on 2-D input recent PyTorch versions
            # warn and treat the tensor as unbatched, zeroing whole samples
            # instead of individual units. Dropout layers hold no parameters and
            # the index is unchanged, so existing state_dicts still load.
            nn.Dropout(dropout_rate),
            nn.Linear(num_units, C),
        )

    def safe_forward(self, x):
        """Return logits of shape (batch, C); all dims after the first are flattened."""
        x = torch.flatten(x, 1)
        return self.seqs(x)


class StochasticDropoutModel(stochastic_model.StochasticModel):
    """Three-layer MLP classifier whose dropout layers are ``StochasticDropout``.

    Each sample is flattened and must yield 784 features. The network returns
    raw logits; no (log-)softmax is applied here.

    Args:
        C: Number of output logits (classes).
        num_samples: Forwarded to ``StochasticModel.__init__``; presumably the
            number of stochastic forward samples — confirm in stochastic_model.
        dropout_rate: Drop probability for both ``StochasticDropout`` layers.
        expansion: Multiplier on the base hidden width of 1024 units.
    """

    C: int  # number of output logits

    def __init__(self, C: int, *, num_samples, dropout_rate=0.25, expansion: int = 1):
        super().__init__(num_samples)

        self.C = C

        hidden = expansion * 1024

        def hidden_stage(fan_in):
            # Linear -> ReLU -> StochasticDropout, created in that order.
            return [
                nn.Linear(fan_in, hidden),
                nn.ReLU(),
                stochastic_dropout.StochasticDropout(dropout_rate),
            ]

        # Arguments are evaluated left to right, so module creation order and
        # Sequential indices (0..6) match a flat literal listing.
        self.seqs = nn.Sequential(
            *hidden_stage(784),
            *hidden_stage(hidden),
            nn.Linear(hidden, C),
        )

    def stochastic_forward_impl(self, x: torch.Tensor):
        """Return logits of shape (batch, C) for one stochastic pass over ``x``."""
        return self.seqs(torch.flatten(x, 1))


class DeterministicModelResFC(SafeModule):
    """Fully-connected residual network without dropout.

    The input is flattened, projected to ``width`` units, passed through
    ``num_blocks`` residual blocks (each adds Norm -> ReLU -> Linear ->
    Norm -> ReLU -> Linear back onto its input), then normalized and projected
    to ``out_capacity`` outputs.

    Args:
        in_capacity: Number of features per sample after flattening.
        out_capacity: Output dimension.
        width: Hidden width of the residual stream.
        num_blocks: Number of residual blocks.
        batch_norm: Use ``nn.BatchNorm1d`` for the normalization layers; when
            False they are ``nn.Identity`` (which ignores its width argument).
    """

    out_capacity: int  # output dimension

    def __init__(self, in_capacity: int, out_capacity: int, *, width, num_blocks, batch_norm=True):
        super().__init__()

        self.out_capacity = out_capacity

        norm = nn.BatchNorm1d if batch_norm else nn.Identity

        self.initial_bottleneck = nn.Linear(in_capacity, width)

        def residual_branch():
            return nn.Sequential(
                norm(width),
                nn.ReLU(),
                nn.Linear(width, width),
                norm(width),
                nn.ReLU(),
                nn.Linear(width, width),
            )

        self.blocks = nn.ModuleList([residual_branch() for _ in range(num_blocks)])

        self.encoder = nn.Sequential(norm(width), nn.Linear(width, out_capacity))

    def safe_forward(self, x: torch.Tensor):
        """Return outputs of shape (batch, out_capacity) for input ``x``."""
        h = self.initial_bottleneck(torch.flatten(x, 1))
        for branch in self.blocks:
            h = h + branch(h)  # residual connection
        return self.encoder(h)


class StochasticDropoutModelResFC(stochastic_model.StochasticModel):
    """Fully-connected residual network with ``StochasticDropout`` in every block.

    The input is flattened, projected to ``width`` units, passed through
    ``num_blocks`` residual blocks (each adds Norm -> ReLU -> StochasticDropout
    -> Linear, twice, back onto its input), then normalized and projected to
    ``C`` logits.

    Args:
        C: Number of output logits (classes).
        width: Hidden width of the residual stream.
        num_blocks: Number of residual blocks.
        num_samples: Forwarded to ``StochasticModel.__init__``; presumably the
            number of stochastic forward samples — confirm in stochastic_model.
        dropout_rate: Drop probability for every ``StochasticDropout`` layer.
        batch_norm: Use ``nn.BatchNorm1d`` for the normalization layers; when
            False they are ``nn.Identity``.
        in_capacity: Number of features per sample after flattening. Defaults to
            784, the value previously hard-coded; mirrors the ``in_capacity``
            argument of ``DeterministicModelResFC``.
    """

    C: int  # number of output logits

    def __init__(
        self,
        C: int,
        *,
        width,
        num_blocks,
        num_samples,
        dropout_rate=0.25,
        batch_norm=True,
        in_capacity: int = 784,
    ):
        super().__init__(num_samples)

        self.C = C

        batch_norm_layer = nn.BatchNorm1d if batch_norm else nn.Identity

        self.initial_bottleneck = nn.Linear(in_capacity, width)

        self.blocks = nn.ModuleList(
            nn.Sequential(
                batch_norm_layer(width),
                nn.ReLU(),
                stochastic_dropout.StochasticDropout(dropout_rate),
                nn.Linear(width, width),
                batch_norm_layer(width),
                nn.ReLU(),
                stochastic_dropout.StochasticDropout(dropout_rate),
                nn.Linear(width, width),
            )
            for _ in range(num_blocks)
        )

        self.encoder = nn.Sequential(batch_norm_layer(width), nn.Linear(width, C))

    def stochastic_forward_impl(self, x: torch.Tensor):
        """Return logits of shape (batch, C) for one stochastic pass over ``x``."""
        x = torch.flatten(x, 1)
        x = self.initial_bottleneck(x)
        for block in self.blocks:
            x = x + block(x)  # residual connection
        return self.encoder(x)


class StochasticDropoutModelDVIBStyle(stochastic_model.StochasticModel):
    """MLP with ``StochasticDropout``, a linear latent layer and a linear decoder.

    The encoder maps flattened 784-feature inputs through two ReLU +
    StochasticDropout hidden layers to a ``latent_capacity``-wide linear output.
    The decoder is a single linear layer to ``out_capacity`` logits. No
    (log-)softmax is applied.

    Args:
        out_capacity: Number of output logits. Also exposed as ``C`` for
            consistency with the other classifiers in this module.
        num_samples: Forwarded to ``StochasticModel.__init__``; presumably the
            number of stochastic forward samples — confirm in stochastic_model.
        latent_capacity: Width of the latent layer between encoder and decoder.
        dropout_rate: Drop probability for both ``StochasticDropout`` layers.
        expansion: Multiplier on the base hidden width of 1024 units.
    """

    C: int  # alias of out_capacity
    out_capacity: int
    latent_capacity: int

    def __init__(self, out_capacity: int, *, num_samples, latent_capacity=512, dropout_rate=0.25, expansion: int = 1):
        super().__init__(num_samples)

        self.latent_capacity = latent_capacity
        self.out_capacity = out_capacity
        # The class declared ``C`` but never assigned it, so reading ``model.C``
        # failed, unlike every sibling model; expose it as an alias.
        self.C = out_capacity

        num_units = 1024 * expansion
        self.encoder = nn.Sequential(
            nn.Linear(784, num_units),
            nn.ReLU(),
            stochastic_dropout.StochasticDropout(dropout_rate),
            nn.Linear(num_units, num_units),
            nn.ReLU(),
            stochastic_dropout.StochasticDropout(dropout_rate),
            nn.Linear(num_units, latent_capacity),
        )
        self.decoder = nn.Linear(latent_capacity, out_capacity)

    def stochastic_forward_impl(self, x: torch.Tensor):
        """Return logits of shape (batch, out_capacity) for one stochastic pass."""
        x = torch.flatten(x, 1)
        x = self.encoder(x)
        return self.decoder(x)


if __name__ == "__main__":
    # Smoke test: build each model and print its structure.
    from experiments.utils import print_module

    print_module(DropoutModel(10))

    print_module(NoDropoutModel(10))

    print_module(StochasticDropoutModel(10, num_samples=10))

    # num_blocks is a required keyword-only argument of StochasticDropoutModelResFC;
    # omitting it made this script fail with a TypeError.
    print_module(StochasticDropoutModelResFC(10, num_samples=10, width=512, num_blocks=2))
