import torch

from xad.models.bases import CategoricalConditionalBatchNorm, BatchNorm2d


class ResConv2d(torch.nn.Conv2d):
    """``torch.nn.Conv2d`` with Xavier-uniform weights and a zero-initialised bias.

    Args:
        *args: Positional arguments forwarded unchanged to ``torch.nn.Conv2d``.
        xavier_gain: Gain handed to ``torch.nn.init.xavier_uniform_`` for the weight.
        **kwargs: Keyword arguments forwarded unchanged to ``torch.nn.Conv2d``.
    """

    def __init__(self, *args, xavier_gain: float = 1., **kwargs):
        # The gain has to exist before the parent constructor runs, because
        # reset_parameters() below reads it while the parent is initialising.
        self.gain = xavier_gain
        super().__init__(*args, **kwargs)

    def reset_parameters(self):
        """Re-initialise the weight (Xavier uniform) and, if present, zero the bias."""
        # The parent reset is kept so the sequence of random draws stays the
        # same as before; its result is overwritten right afterwards.
        super().reset_parameters()
        torch.nn.init.xavier_uniform_(self.weight, gain=self.gain)
        # With bias=False the parent registers ``bias`` as None; there is
        # nothing to zero in that case.
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)


class ResLinear(torch.nn.Linear):
    """``torch.nn.Linear`` with Xavier-uniform weights and a zero-initialised bias.

    Args:
        *args: Positional arguments forwarded unchanged to ``torch.nn.Linear``.
        xavier_gain: Gain handed to ``torch.nn.init.xavier_uniform_`` for the weight.
        **kwargs: Keyword arguments forwarded unchanged to ``torch.nn.Linear``.
    """

    def __init__(self, *args, xavier_gain: float = 1., **kwargs):
        # The gain has to exist before the parent constructor runs, because
        # reset_parameters() below reads it while the parent is initialising.
        self.gain = xavier_gain
        super().__init__(*args, **kwargs)

    def reset_parameters(self):
        """Re-initialise the weight (Xavier uniform) and, if present, zero the bias."""
        # The parent reset is kept so the sequence of random draws stays the
        # same as before; its result is overwritten right afterwards.
        super().reset_parameters()
        torch.nn.init.xavier_uniform_(self.weight, gain=self.gain)
        # With bias=False the parent registers ``bias`` as None; there is
        # nothing to zero in that case.
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)


class ResEmbedding(torch.nn.Embedding):
    """``torch.nn.Embedding`` whose weight table is Xavier-uniform initialised.

    Args:
        *args: Positional arguments forwarded unchanged to ``torch.nn.Embedding``.
        xavier_gain: Gain handed to ``torch.nn.init.xavier_uniform_`` for the weight.
        **kwargs: Keyword arguments forwarded unchanged to ``torch.nn.Embedding``.
    """

    def __init__(self, *args, xavier_gain: float = 1., **kwargs):
        # The gain has to exist before the parent constructor runs, because
        # reset_parameters() below reads it while the parent is initialising.
        self.gain = xavier_gain
        super().__init__(*args, **kwargs)

    def reset_parameters(self):
        """Re-initialise the weight (Xavier uniform), keeping the padding row at zero."""
        # The parent reset is kept so the sequence of random draws stays the
        # same as before; its result is overwritten right afterwards.
        super().reset_parameters()
        torch.nn.init.xavier_uniform_(self.weight, gain=self.gain)
        # xavier_uniform_ rewrites every row, including the padding row the
        # parent reset had just zeroed; restore the zero padding vector.
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)


class ResGenEncoderBlock(torch.nn.Module):
    """Pre-activation residual block with optional 2x average-pool downsampling.

    The residual branch is norm -> activation -> (pool) -> conv -> norm ->
    activation -> conv. The shortcut is the identity unless the channel count
    changes or ``downsample`` is set, in which case it is a 1x1 ``ResConv2d``
    followed by the same pooling.

    Args:
        in_channels: Channel count of the input.
        out_channels: Channel count of the output.
        hidden_channels: Channel count between the two convolutions;
            falls back to ``out_channels`` when ``None``.
        ksize: Kernel size of the two main convolutions.
        pad: Padding of the two main convolutions.
        activation: Callable applied after each normalisation layer.
        downsample: Whether to average-pool (kernel 2) both branches.
        n_classes: When positive, ``CategoricalConditionalBatchNorm`` layers are
            built with this value; otherwise plain ``BatchNorm2d`` layers are used.
        xavier_gain: Gain forwarded to every ``ResConv2d`` in the block.
    """

    def __init__(self, in_channels, out_channels, hidden_channels=None, ksize=3, pad=1,
                 activation=torch.nn.functional.relu, downsample=False, n_classes=0, xavier_gain=2**0.5):
        super().__init__()
        if hidden_channels is None:
            hidden_channels = out_channels

        def make_norm(channels):
            # Class-conditional normalisation only when classes were requested.
            if n_classes > 0:
                return CategoricalConditionalBatchNorm(channels, n_classes)
            return BatchNorm2d(channels)

        self.activation = activation
        self.downsample = downsample
        self.learnable_sc = in_channels != out_channels or downsample
        self.n_classes = n_classes

        # Submodules are created in a fixed order (convs, norms, shortcut) so
        # that parameter ordering and random initialisation stay stable.
        self.c1 = ResConv2d(in_channels, hidden_channels, ksize, padding=pad, xavier_gain=xavier_gain)
        self.c2 = ResConv2d(hidden_channels, out_channels, ksize, padding=pad, xavier_gain=xavier_gain)
        self.b1 = make_norm(in_channels)
        self.b2 = make_norm(hidden_channels)
        if self.learnable_sc:
            self.c_sc = ResConv2d(in_channels, out_channels, 1, padding=0, xavier_gain=xavier_gain)

    def forward(self, x, condition):
        """Return ``residual(x) + shortcut(x)``.

        ``condition`` is handed to both normalisation layers as a second
        positional argument unless it is ``None``, in which case they are
        called with the feature map only.
        """
        norm_args = () if condition is None else (condition,)
        pool = torch.nn.functional.avg_pool2d

        residual = self.activation(self.b1(x, *norm_args))
        if self.downsample:
            residual = pool(residual, 2)
        residual = self.c1(residual)
        residual = self.activation(self.b2(residual, *norm_args))
        residual = self.c2(residual)

        if not self.learnable_sc:
            return residual + x

        shortcut = self.c_sc(x)
        if self.downsample:
            shortcut = pool(shortcut, 2)
        return residual + shortcut


class ResGenDecoderBlock(torch.nn.Module):
    """Pre-activation residual block with optional 2x upsampling.

    The residual branch is norm -> activation -> (upsample) -> conv -> norm ->
    activation -> conv. The shortcut is the identity unless the channel count
    changes or ``upsample`` is set, in which case the input is (optionally)
    upsampled and passed through a 1x1 ``ResConv2d``.

    Args:
        in_channels: Channel count of the input.
        out_channels: Channel count of the output.
        hidden_channels: Channel count between the two convolutions;
            falls back to ``out_channels`` when ``None``.
        ksize: Kernel size of the two main convolutions.
        pad: Padding of the two main convolutions.
        activation: Callable applied after each normalisation layer.
        upsample: Whether to upsample both branches with
            ``torch.nn.functional.interpolate(..., scale_factor=2)``.
        n_classes: When positive, ``CategoricalConditionalBatchNorm`` layers are
            built with this value; otherwise plain ``BatchNorm2d`` layers are used.
        xavier_gain: Gain forwarded to every ``ResConv2d`` in the block.
    """
    # https://github.com/t-vi/pytorch-tvmisc/blob/master/wasserstein-distance/sn_projection_cgan_64x64_143c.ipynb

    def __init__(self, in_channels, out_channels, hidden_channels=None, ksize=3, pad=1,
                 activation=torch.nn.functional.relu, upsample=False, n_classes=0, xavier_gain=2**0.5):
        super().__init__()
        if hidden_channels is None:
            hidden_channels = out_channels

        def make_norm(channels):
            # Class-conditional normalisation only when classes were requested.
            if n_classes > 0:
                return CategoricalConditionalBatchNorm(channels, n_classes)
            return BatchNorm2d(channels)

        self.activation = activation
        self.upsample = upsample
        self.learnable_sc = in_channels != out_channels or upsample
        self.n_classes = n_classes

        # Submodules are created in a fixed order (convs, norms, shortcut) so
        # that parameter ordering and random initialisation stay stable.
        self.c1 = ResConv2d(in_channels, hidden_channels, ksize, padding=pad, xavier_gain=xavier_gain)
        self.c2 = ResConv2d(hidden_channels, out_channels, ksize, padding=pad, xavier_gain=xavier_gain)
        self.b1 = make_norm(in_channels)
        self.b2 = make_norm(hidden_channels)
        if self.learnable_sc:
            self.c_sc = ResConv2d(in_channels, out_channels, 1, padding=0, xavier_gain=xavier_gain)

    def forward(self, x, condition):
        """Return ``residual(x) + shortcut(x)``.

        ``condition`` is handed to both normalisation layers as a second
        positional argument unless it is ``None``, in which case they are
        called with the feature map only.
        """
        norm_args = () if condition is None else (condition,)
        upscale = torch.nn.functional.interpolate

        residual = self.activation(self.b1(x, *norm_args))
        if self.upsample:
            residual = upscale(residual, scale_factor=2)
        residual = self.c1(residual)
        residual = self.activation(self.b2(residual, *norm_args))
        residual = self.c2(residual)

        if not self.learnable_sc:
            return residual + x

        # Unlike the residual branch, the shortcut is resized before its conv.
        shortcut_in = upscale(x, scale_factor=2) if self.upsample else x
        return residual + self.c_sc(shortcut_in)


class ResDisBlock(torch.nn.Module):
    """Normalisation-free pre-activation residual block with optional 2x downsampling.

    The residual branch is activation -> conv -> activation -> conv -> (pool).
    The shortcut is the identity unless the channel count changes or
    ``downsample`` is set, in which case it is a 1x1 ``ResConv2d`` followed by
    the same pooling.

    Args:
        in_channels: Channel count of the input.
        out_channels: Channel count of the output.
        hidden_channels: Channel count between the two convolutions;
            falls back to ``in_channels`` when ``None``.
        ksize: Kernel size of the two main convolutions.
        pad: Padding of the two main convolutions.
        activation: Callable applied before each main convolution.
        downsample: Whether to average-pool (kernel 2) both branches.
    """
    # https://github.com/t-vi/pytorch-tvmisc/blob/master/wasserstein-distance/sn_projection_cgan_64x64_143c.ipynb

    def __init__(self, in_channels, out_channels, hidden_channels=None, ksize=3, pad=1,
                 activation=torch.nn.functional.relu, downsample=False):
        super().__init__()
        if hidden_channels is None:
            hidden_channels = in_channels

        self.activation = activation
        self.downsample = downsample
        self.learnable_sc = (in_channels != out_channels) or downsample

        # Main convolutions use a sqrt(2) Xavier gain; the shortcut conv keeps
        # ResConv2d's default gain.
        main_gain = 2**0.5
        self.c1 = ResConv2d(in_channels, hidden_channels, ksize, padding=pad, xavier_gain=main_gain)
        self.c2 = ResConv2d(hidden_channels, out_channels, ksize, padding=pad, xavier_gain=main_gain)
        if self.learnable_sc:
            self.c_sc = ResConv2d(in_channels, out_channels, 1, padding=0)

    def parameterize(self):
        """Apply ``torch.nn.utils.spectral_norm`` to every convolution of the block."""
        convs = [self.c1, self.c2]
        if self.learnable_sc:
            convs.append(self.c_sc)
        for conv in convs:
            torch.nn.utils.spectral_norm(conv)

    def forward(self, x):
        """Return ``residual(x) + shortcut(x)``."""
        pool = torch.nn.functional.avg_pool2d

        residual = self.c1(self.activation(x))
        residual = self.c2(self.activation(residual))
        if self.downsample:
            residual = pool(residual, 2)

        if not self.learnable_sc:
            return residual + x

        shortcut = self.c_sc(x)
        if self.downsample:
            shortcut = pool(shortcut, 2)
        return residual + shortcut


class ResDisOptimizedBlock(torch.nn.Module):
    """Always-downsampling residual block without a leading activation.

    The residual branch is conv -> activation -> conv -> pool and the shortcut
    is a 1x1 ``ResConv2d`` followed by the same pooling. Both use
    ``avg_pool2d`` with kernel size 2.

    Args:
        in_channels: Channel count of the input.
        out_channels: Channel count of the output (also used between the convs).
        ksize: Kernel size of the two main convolutions.
        pad: Padding of the two main convolutions.
        activation: Callable applied between the two main convolutions.
    """
    # https://github.com/t-vi/pytorch-tvmisc/blob/master/wasserstein-distance/sn_projection_cgan_64x64_143c.ipynb

    def __init__(self, in_channels, out_channels, ksize=3, pad=1, activation=torch.nn.functional.relu):
        super().__init__()
        self.activation = activation

        # Main convolutions use a sqrt(2) Xavier gain; the shortcut conv keeps
        # ResConv2d's default gain.
        main_gain = 2**0.5
        self.c1 = ResConv2d(in_channels, out_channels, ksize, padding=pad, xavier_gain=main_gain)
        self.c2 = ResConv2d(out_channels, out_channels, ksize, padding=pad, xavier_gain=main_gain)
        self.c_sc = ResConv2d(in_channels, out_channels, 1, padding=0)

    def parameterize(self):
        """Apply ``torch.nn.utils.spectral_norm`` to every convolution of the block."""
        for conv in (self.c1, self.c2, self.c_sc):
            torch.nn.utils.spectral_norm(conv)

    def forward(self, x):
        """Return ``residual(x) + shortcut(x)``, both average-pooled by 2."""
        pool = torch.nn.functional.avg_pool2d

        residual = self.c2(self.activation(self.c1(x)))
        residual = pool(residual, 2)
        shortcut = pool(self.c_sc(x), 2)
        return residual + shortcut
