from abc import ABC, abstractmethod

import torch
import torch.nn as nn
from torch import Tensor


class ConditionalGenerator(nn.Module):
    """Encoder/decoder generator whose two stages both receive a condition.

    The constructor only records ``latent_dim`` and ``condition_shape`` and
    leaves ``encoder`` and ``decoder`` as ``None``; nothing in this class
    assigns them, so they have to be provided before ``forward`` is called
    (presumably by a subclass -- confirm against the concrete models).
    """

    class Encoder(nn.Module, ABC):
        """Abstract base for the encoding stage of the generator."""

        def __init__(self):
            super().__init__()

        @abstractmethod
        def forward(self, x: torch.Tensor, c: torch.Tensor):
            """Encode ``x`` given the condition ``c``; this base version returns ``None``."""

        def freeze(self):
            """Turn off ``requires_grad`` on every parameter of this module."""
            self.requires_grad_(False)

        def unfreeze(self):
            """Turn ``requires_grad`` back on for every parameter of this module."""
            self.requires_grad_(True)

    class Decoder(nn.Module, ABC):
        """Abstract base for the decoding stage of the generator."""

        def __init__(self):
            super().__init__()

        @abstractmethod
        def forward(self, z: torch.Tensor, condition: torch.Tensor):
            """Decode ``z`` given ``condition``; this base version returns ``None``."""

        def freeze(self):
            """Turn off ``requires_grad`` on every parameter of this module."""
            self.requires_grad_(False)

        def unfreeze(self):
            """Turn ``requires_grad`` back on for every parameter of this module."""
            self.requires_grad_(True)

    def __init__(self, latent_dim: int, condition_shape: torch.Size):
        """Store the configuration and create empty encoder/decoder slots.

        Args:
            latent_dim: Stored as ``self.latent_dim``; not used by this class.
            condition_shape: Stored as ``self.condition_shape``; not used by
                this class.
        """
        super().__init__()
        self.condition_shape = condition_shape
        self.latent_dim = latent_dim
        # Placeholders: the actual sub-modules are attached after construction.
        self.encoder: "ConditionalGenerator.Encoder | None" = None
        self.decoder: "ConditionalGenerator.Decoder | None" = None

    def forward(self, x: torch.Tensor, condition: torch.Tensor):
        """Encode ``x`` and decode the result, passing ``condition`` to both stages.

        Returns:
            Whatever ``self.decoder`` returns for the encoded input.
        """
        return self.decoder(self.encoder(x, condition), condition)


class ConditionalDiscriminator(nn.Module):
    """Discriminator base that scores an input together with a condition.

    ``encoder`` starts out as ``None`` and is not assigned anywhere in this
    class, so it has to be attached before ``forward`` is used.

    NOTE(review): ``forward`` and ``parameterize`` carry ``@abstractmethod``,
    but the class does not derive from ``ABC``, so Python does not enforce the
    overrides and the class itself can be instantiated.
    """

    def __init__(self, latent_dim: int, condition_shape: torch.Size):
        """Store the configuration and create an empty encoder slot.

        Args:
            latent_dim: Stored as ``self.latent_dim``; not used by this class.
            condition_shape: Stored as ``self.condition_shape``; not used by
                this class.
        """
        super().__init__()
        self.condition_shape = condition_shape
        self.latent_dim = latent_dim
        # Placeholder: the actual sub-module is attached after construction.
        self.encoder: "torch.nn.Module | None" = None

    @abstractmethod
    def forward(self, x: torch.Tensor, condition: torch.Tensor):
        """Return the output of ``self.encoder`` called with ``x`` and ``condition``."""
        return self.encoder(x, condition)

    def freeze(self):
        """Turn off ``requires_grad`` on every parameter of this module."""
        self.requires_grad_(False)

    def unfreeze(self):
        """Turn ``requires_grad`` back on for every parameter of this module."""
        self.requires_grad_(True)

    @abstractmethod
    def parameterize(self):
        """Hook for subclasses; this base version does nothing and returns ``None``."""


class ConceptNN(nn.Module):
    """Base module that records how many concepts it works with.

    NOTE(review): ``parameterize`` is marked ``@abstractmethod``, but the class
    does not derive from ``ABC``, so the override is not enforced.
    """

    def __init__(self, concepts: int):
        """Store ``concepts`` as ``self.n_concepts``."""
        super().__init__()
        self.n_concepts = concepts

    @abstractmethod
    def parameterize(self):
        """Hook for subclasses; this base version does nothing and returns ``None``."""


class BatchNorm2d(torch.nn.BatchNorm2d):
    """``torch.nn.BatchNorm2d`` with special handling for minibatches of size one.

    The mean and variance of a single-sample minibatch are unstable. When the
    layer would use batch statistics and ``len(input) <= 1``, the forward pass
    is split in two: one gradient-free pass in the current mode (which updates
    the running statistics when they are tracked), followed by a pass with
    ``training`` temporarily set to ``False`` that produces the output.
    """

    def reset_parameters(self):
        """Reset the running statistics and set the affine weight to 1 and bias to 0."""
        self.reset_running_stats()
        if self.affine:
            torch.nn.init.ones_(self.weight)
            torch.nn.init.zeros_(self.bias)

    def forward(self, input: Tensor) -> Tensor:
        """Normalize ``input``, taking the two-pass route for single-sample batches.

        Args:
            input: Batch to normalize; its first dimension is taken as the
                batch size.

        Returns:
            The output of the parent ``forward``. In the single-sample case
            this is the output of the second (``training == False``) pass.
        """
        uses_batch_stats = self.training or not self.track_running_stats
        if not uses_batch_stats or len(input) > 1:
            return super().forward(input)

        original_training = self.training
        # First pass: regular forward in the current mode, without gradients,
        # purely for its side effect on the running statistics.
        with torch.no_grad():
            super().forward(input)
        # Second pass: behave as in eval mode. The flag is restored in
        # ``finally`` so an exception cannot leave the layer stuck in eval mode.
        try:
            self.training = False
            return super().forward(input)
        finally:
            self.training = original_training


# Following https://github.com/t-vi/pytorch-tvmisc/blob/master/wasserstein-distance/sn_projection_cgan_64x64_143c.ipynb
# and https://github.com/asmadotgh/dissect/blob/3a7a9ff5a8ff8ebf9a6183fcea9a43f8e53106a8/explainer/ops.py#L97
class CategoricalConditionalBatchNorm(torch.nn.Module):
    """Batch normalization with one affine ``(weight, bias)`` row per category.

    Normalization statistics are shared across categories; only the affine
    transform applied afterwards is selected per sample through ``cats``.
    As in the chainer SN-GAN implementation, per-category weight and bias are
    kept.
    """

    def __init__(self, num_features, num_cats, eps=2e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        """Create the layer.

        Args:
            num_features: Size of the feature (second) dimension of the input.
            num_cats: Number of categories, i.e. rows of ``weight``/``bias``.
            eps: Value passed as ``eps`` to ``torch.nn.functional.batch_norm``.
            momentum: Update factor for the running statistics; ``None``
                selects a cumulative moving average.
            affine: Whether to keep per-category ``weight`` and ``bias``.
            track_running_stats: Whether to keep ``running_mean``,
                ``running_var`` and ``num_batches_tracked`` buffers.
        """
        super().__init__()
        self.num_features = num_features
        self.num_cats = num_cats
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            # Values are filled in by reset_parameters() below.
            self.weight = torch.nn.Parameter(torch.empty(num_cats, num_features))
            self.bias = torch.nn.Parameter(torch.empty(num_cats, num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        else:
            # Running statistics are buffers, so the absent ones are
            # registered as (None) buffers as well.
            self.register_buffer('running_mean', None)
            self.register_buffer('running_var', None)
            self.register_buffer('num_batches_tracked', None)
        self.reset_parameters()

    def reset_running_stats(self):
        """Reset running mean to 0, running variance to 1 and the batch counter to 0."""
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        """Reset the running statistics and set every weight to 1 and bias to 0."""
        self.reset_running_stats()
        if self.affine:
            self.weight.data.fill_(1.0)
            self.bias.data.zero_()

    def forward(self, input, cats):
        """Normalize ``input`` and apply the affine rows selected by ``cats``.

        Args:
            input: Tensor whose first dimension is the batch and whose second
                dimension has ``num_features`` entries.
            cats: Index tensor passed to ``index_select`` on ``weight`` and
                ``bias``; it must select one row per sample so the result can
                be viewed as ``(batch, num_features, 1, ...)``.

        Returns:
            The normalized tensor, scaled and shifted per sample when
            ``affine`` is set.
        """
        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            self.num_batches_tracked += 1
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

        use_batch_stats = self.training or not self.track_running_stats

        if self.training and self.track_running_stats and len(input) <= 1:
            # Minibatch size is 1 --> unstable minibatch mean and variance.
            # Run one gradient-free training-mode pass only to update the
            # running statistics, then normalize with those statistics as if
            # the layer were in eval mode.
            with torch.no_grad():
                torch.nn.functional.batch_norm(
                    input, self.running_mean, self.running_var, None, None,
                    True, exponential_average_factor, self.eps
                )
            use_batch_stats = False

        # Without tracked running statistics there is nothing to fall back to,
        # so batch statistics are used regardless of the batch size.
        out = torch.nn.functional.batch_norm(
            input, self.running_mean, self.running_var, None, None,
            use_batch_stats, exponential_average_factor, self.eps
        )
        if self.affine:
            # Broadcast the selected rows over all trailing dimensions.
            shape = [input.size(0), self.num_features] + (input.dim() - 2) * [1]
            weight = self.weight.index_select(0, cats).view(shape)
            bias = self.bias.index_select(0, cats).view(shape)
            out = out * weight + bias
        return out

    def extra_repr(self):
        """Return the configuration summary shown in the module's ``repr``."""
        return (
            f'{self.num_features}, num_cats={self.num_cats}, eps={self.eps}, '
            f'momentum={self.momentum}, affine={self.affine}, '
            f'track_running_stats={self.track_running_stats}'
        )


class ConditionalGAN(ConditionalGenerator):
    """Subclass of ``ConditionalGenerator`` that adds no members of its own."""


class ADNN(nn.Module):
    """Base module whose ``forward`` has to be supplied by a subclass.

    NOTE(review): ``forward`` is marked ``@abstractmethod``, but the class does
    not derive from ``ABC``, so instantiation is not blocked; calling the
    un-overridden ``forward`` raises ``NotImplementedError`` instead.
    """

    @abstractmethod
    def forward(self, x: torch.Tensor, return_encoding: bool = False):
        """Not implemented here.

        Args:
            x: Input tensor.
            return_encoding: Presumably asks the subclass to also return an
                intermediate encoding -- confirm against the implementations.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

