# From https://github.com/brownvc/R3GAN/blob/main/R3GAN/FusedOperators.py
import math

import numpy
import torch
import torch.nn as nn
from torch_utils.ops import bias_act, upfirdn2d


class BiasedActivationReference(nn.Module):
    """Pure-PyTorch bias-add followed by ``LeakyReLU(0.2)``.

    A learnable bias, initialised to zero, is added along dimension 1 of
    the input and the sum is passed through a leaky ReLU with negative
    slope 0.2.  ``Gain`` is not applied in ``forward``; it is only exposed
    as a class attribute for code that scales weight initialisation.

    Args:
        InputUnits: number of bias entries (size of dimension 1 of the
            input).
    """

    # sqrt(2 / (1 + slope**2)) evaluated for slope = 0.2.
    Gain = math.sqrt(2 / (1 + 0.2**2))
    Function = nn.LeakyReLU(0.2)

    def __init__(self, InputUnits):
        super().__init__()

        self.Bias = nn.Parameter(torch.zeros(InputUnits))

    def forward(self, x):
        """Return ``LeakyReLU(x + Bias)`` with the bias cast to ``x.dtype``."""
        Bias = self.Bias.to(x.dtype)

        # Inputs with more than two dimensions get the bias shaped as
        # (1, C, 1, 1); anything else gets (1, C).
        if len(x.shape) > 2:
            Bias = Bias.view(1, -1, 1, 1)
        else:
            Bias = Bias.view(1, -1)

        return BiasedActivationReference.Function(x + Bias)


class BiasedActivationCUDA(nn.Module):
    """Bias-add plus activation delegated to ``bias_act.bias_act``.

    Holds a learnable bias initialised to zero and forwards it, cast to the
    input dtype, to ``bias_act.bias_act`` with ``act="lrelu"`` and
    ``gain=1``.  ``Gain`` is not applied in ``forward``; it is only exposed
    as a class attribute for code that scales weight initialisation.

    Args:
        InputUnits: number of bias entries.
    """

    # sqrt(2 / (1 + slope**2)) evaluated for slope = 0.2.
    Gain = math.sqrt(2 / (1 + 0.2**2))
    # Activation name passed to bias_act.
    Function = "lrelu"

    def __init__(self, InputUnits):
        super().__init__()

        self.Bias = nn.Parameter(torch.zeros(InputUnits))

    def forward(self, x):
        """Apply the fused bias and activation op to ``x``."""
        Bias = self.Bias.to(x.dtype)
        return bias_act.bias_act(
            x, Bias, act=BiasedActivationCUDA.Function, gain=1
        )


# Implementation selected for the rest of this file; ResidualBlock
# instantiates it and reads its ``Gain`` attribute.
BiasedActivation = BiasedActivationCUDA


# From https://github.com/brownvc/R3GAN/blob/main/R3GAN/Resamplers.py


def CreateLowpassKernel(Weights, Inplace):
    """Build a normalised square 2-D kernel from 1-D filter taps.

    Args:
        Weights: 1-D sequence of filter taps.
        Inplace: when true the taps are used as given; otherwise they are
            first convolved with ``[1, 1]``, which lengthens the filter by
            one tap.

    Returns:
        A ``torch.Tensor`` holding the outer product of the (possibly
        lengthened) taps with themselves, divided by the sum of its entries
        so that they add up to one.
    """
    if Inplace:
        Row = numpy.array([Weights])
    else:
        Row = numpy.convolve(Weights, [1, 1]).reshape(1, -1)

    # (n, 1) @ (1, n) gives the (n, n) outer product of the taps.
    Outer = torch.Tensor(Row.T @ Row)
    return Outer / torch.sum(Outer)


class InterpolativeUpsamplerReference(nn.Module):
    """Stride-2 upsampler built on ``nn.functional.conv_transpose2d``.

    Every channel is filtered independently with the same 2-D low-pass
    kernel, scaled by 4, using a stride-2 transposed convolution.

    Args:
        Filter: 1-D sequence of filter taps handed to
            ``CreateLowpassKernel(Filter, Inplace=False)``.
    """

    def __init__(self, Filter):
        super().__init__()

        self.register_buffer(
            "Kernel", CreateLowpassKernel(Filter, Inplace=False)
        )
        self.FilterRadius = len(Filter) // 2

    def forward(self, x):
        """Upsample a 4-D tensor indexed as (batch, channel, H, W).

        Returns a tensor with the same batch and channel sizes and the
        spatial size produced by the transposed convolution.
        """
        Batch, Channels = x.shape[0], x.shape[1]
        # NOTE(review): the factor 4 presumably compensates for the zeros
        # inserted by stride-2 upsampling in both spatial dimensions.
        Kernel = 4 * self.Kernel[None, None].to(x.dtype)

        # Fold channels into the batch axis so the single-channel kernel is
        # applied to each one separately.  ``reshape`` (rather than
        # ``view``) keeps this working for non-contiguous inputs such as
        # channels_last tensors, which ``view`` rejects.
        y = nn.functional.conv_transpose2d(
            x.reshape(Batch * Channels, 1, x.shape[2], x.shape[3]),
            Kernel,
            stride=2,
            padding=self.FilterRadius,
        )

        return y.reshape(Batch, Channels, y.shape[2], y.shape[3])


class InterpolativeDownsamplerReference(nn.Module):
    """Stride-2 downsampler built on ``nn.functional.conv2d``.

    Every channel is filtered independently with the same 2-D low-pass
    kernel using a stride-2 convolution.

    Args:
        Filter: 1-D sequence of filter taps handed to
            ``CreateLowpassKernel(Filter, Inplace=False)``.
    """

    def __init__(self, Filter):
        super().__init__()

        self.register_buffer(
            "Kernel", CreateLowpassKernel(Filter, Inplace=False)
        )
        self.FilterRadius = len(Filter) // 2

    def forward(self, x):
        """Downsample a 4-D tensor indexed as (batch, channel, H, W).

        Returns a tensor with the same batch and channel sizes and the
        spatial size produced by the strided convolution.
        """
        Batch, Channels = x.shape[0], x.shape[1]
        Kernel = self.Kernel[None, None].to(x.dtype)

        # Fold channels into the batch axis so the single-channel kernel is
        # applied to each one separately.  ``reshape`` (rather than
        # ``view``) keeps this working for non-contiguous inputs such as
        # channels_last tensors, which ``view`` rejects.
        y = nn.functional.conv2d(
            x.reshape(Batch * Channels, 1, x.shape[2], x.shape[3]),
            Kernel,
            stride=2,
            padding=self.FilterRadius,
        )

        return y.reshape(Batch, Channels, y.shape[2], y.shape[3])


class InplaceUpsamplerReference(nn.Module):
    """Pixel-shuffle upsampler followed by a per-channel low-pass filter.

    The input is rearranged with ``pixel_shuffle(x, 2)`` and every
    resulting channel is then filtered with the same 2-D kernel using a
    stride-1 convolution.

    Args:
        Filter: 1-D sequence of filter taps handed to
            ``CreateLowpassKernel(Filter, Inplace=True)``.
    """

    def __init__(self, Filter):
        super().__init__()

        self.register_buffer(
            "Kernel", CreateLowpassKernel(Filter, Inplace=True)
        )
        self.FilterRadius = len(Filter) // 2

    def forward(self, x):
        """Upsample a 4-D tensor indexed as (batch, channel, H, W).

        Returns the filtered tensor, reshaped to the shape of the
        pixel-shuffled input.
        """
        Kernel = self.Kernel[None, None].to(x.dtype)
        x = nn.functional.pixel_shuffle(x, 2)
        Shape = x.shape

        # Fold channels into the batch axis so the single-channel kernel is
        # applied to each one separately.  ``reshape`` (rather than
        # ``view``) is used because the shuffled tensor is not guaranteed
        # to be contiguous, e.g. for channels_last inputs.
        y = nn.functional.conv2d(
            x.reshape(Shape[0] * Shape[1], 1, Shape[2], Shape[3]),
            Kernel,
            stride=1,
            padding=self.FilterRadius,
        )

        return y.reshape(Shape)


class InplaceDownsamplerReference(nn.Module):
    """Per-channel low-pass filter followed by a pixel-unshuffle.

    Every channel is filtered with the same 2-D kernel using a stride-1
    convolution and the result is rearranged with
    ``pixel_unshuffle(y, 2)``.

    Args:
        Filter: 1-D sequence of filter taps handed to
            ``CreateLowpassKernel(Filter, Inplace=True)``.
    """

    def __init__(self, Filter):
        super().__init__()

        self.register_buffer(
            "Kernel", CreateLowpassKernel(Filter, Inplace=True)
        )
        self.FilterRadius = len(Filter) // 2

    def forward(self, x):
        """Downsample a 4-D tensor indexed as (batch, channel, H, W).

        Returns the pixel-unshuffled version of the filtered input.
        """
        Shape = x.shape
        Kernel = self.Kernel[None, None].to(x.dtype)

        # Fold channels into the batch axis so the single-channel kernel is
        # applied to each one separately.  ``reshape`` (rather than
        # ``view``) keeps this working for non-contiguous inputs such as
        # channels_last tensors, which ``view`` rejects.
        y = nn.functional.conv2d(
            x.reshape(Shape[0] * Shape[1], 1, Shape[2], Shape[3]),
            Kernel,
            stride=1,
            padding=self.FilterRadius,
        ).reshape(Shape)

        return nn.functional.pixel_unshuffle(y, 2)


class InterpolativeUpsamplerCUDA(nn.Module):
    """Upsampler that delegates to ``upfirdn2d.upsample2d``.

    Args:
        Filter: 1-D sequence of filter taps handed to
            ``CreateLowpassKernel(Filter, Inplace=False)``; the resulting
            2-D kernel is stored as the ``Kernel`` buffer.
    """

    def __init__(self, Filter):
        super().__init__()

        Kernel = CreateLowpassKernel(Filter, Inplace=False)
        self.register_buffer("Kernel", Kernel)

    def forward(self, x):
        """Return ``upfirdn2d.upsample2d`` applied to ``x`` and the kernel."""
        Kernel = self.Kernel
        return upfirdn2d.upsample2d(x, Kernel)


class InterpolativeDownsamplerCUDA(nn.Module):
    """Downsampler that delegates to ``upfirdn2d.downsample2d``.

    Args:
        Filter: 1-D sequence of filter taps handed to
            ``CreateLowpassKernel(Filter, Inplace=False)``; the resulting
            2-D kernel is stored as the ``Kernel`` buffer.
    """

    def __init__(self, Filter):
        super().__init__()

        Kernel = CreateLowpassKernel(Filter, Inplace=False)
        self.register_buffer("Kernel", Kernel)

    def forward(self, x):
        """Return ``upfirdn2d.downsample2d`` applied to ``x`` and the kernel."""
        Kernel = self.Kernel
        return upfirdn2d.downsample2d(x, Kernel)


class InplaceUpsamplerCUDA(nn.Module):
    """Pixel-shuffle upsampler filtered through ``upfirdn2d.upfirdn2d``.

    Args:
        Filter: 1-D sequence of filter taps handed to
            ``CreateLowpassKernel(Filter, Inplace=True)``; ``len(Filter) // 2``
            is used as the padding of the filtering step.
    """

    def __init__(self, Filter):
        super().__init__()

        Kernel = CreateLowpassKernel(Filter, Inplace=True)
        self.register_buffer("Kernel", Kernel)
        self.FilterRadius = len(Filter) // 2

    def forward(self, x):
        """Apply ``pixel_shuffle(x, 2)`` and then filter the result."""
        Shuffled = nn.functional.pixel_shuffle(x, 2)
        return upfirdn2d.upfirdn2d(
            Shuffled, self.Kernel, padding=self.FilterRadius
        )


class InplaceDownsamplerCUDA(nn.Module):
    """``upfirdn2d.upfirdn2d`` filter followed by a pixel-unshuffle.

    Args:
        Filter: 1-D sequence of filter taps handed to
            ``CreateLowpassKernel(Filter, Inplace=True)``; ``len(Filter) // 2``
            is used as the padding of the filtering step.
    """

    def __init__(self, Filter):
        super().__init__()

        Kernel = CreateLowpassKernel(Filter, Inplace=True)
        self.register_buffer("Kernel", Kernel)
        self.FilterRadius = len(Filter) // 2

    def forward(self, x):
        """Filter ``x`` and then apply ``pixel_unshuffle(..., 2)``."""
        Filtered = upfirdn2d.upfirdn2d(
            x, self.Kernel, padding=self.FilterRadius
        )
        return nn.functional.pixel_unshuffle(Filtered, 2)


# Resampler implementations exposed under implementation-agnostic names.
# UpsampleLayer and DownsampleLayer use InterpolativeUpsampler and
# InterpolativeDownsampler respectively.
InterpolativeUpsampler = InterpolativeUpsamplerCUDA
InterpolativeDownsampler = InterpolativeDownsamplerCUDA
InplaceUpsampler = InplaceUpsamplerCUDA
InplaceDownsampler = InplaceDownsamplerCUDA


# From https://github.com/brownvc/R3GAN/blob/main/R3GAN/Networks.py

# import math
# import torch
# import torch.nn as nn
# from .Resamplers import InterpolativeUpsampler, InterpolativeDownsampler
# from .FusedOperators import BiasedActivation


def MSRInitializer(Layer, ActivationGain=1):
    """Re-initialise ``Layer`` in place with fan-in scaled Gaussian weights.

    The weights are drawn from a normal distribution with mean 0 and
    standard deviation ``ActivationGain / sqrt(FanIn)``, where ``FanIn`` is
    ``weight.size(1)`` times the number of elements in one
    ``weight[0][0]`` slice.  A bias, if the layer has one, is set to zero.

    Args:
        Layer: module exposing ``weight`` and ``bias`` attributes.
        ActivationGain: multiplier applied to the standard deviation.

    Returns:
        The same ``Layer`` object, to allow inline use.
    """
    Weight = Layer.weight.data
    FanIn = Weight.size(1) * Weight[0][0].numel()
    Weight.normal_(0, ActivationGain / math.sqrt(FanIn))

    Bias = Layer.bias
    if Bias is not None:
        Bias.data.zero_()

    return Layer


class Convolution(nn.Module):
    """Bias-free, stride-1 2-D convolution initialised by ``MSRInitializer``.

    Args:
        InputChannels: number of input channels.
        OutputChannels: number of output channels.
        KernelSize: side length of the square kernel; the padding is
            ``(KernelSize - 1) // 2``.
        Groups: number of convolution groups.
        ActivationGain: gain forwarded to ``MSRInitializer``.
    """

    def __init__(
        self,
        InputChannels,
        OutputChannels,
        KernelSize,
        Groups=1,
        ActivationGain=1,
    ):
        super().__init__()

        Layer = nn.Conv2d(
            InputChannels,
            OutputChannels,
            kernel_size=KernelSize,
            stride=1,
            padding=(KernelSize - 1) // 2,
            groups=Groups,
            bias=False,
        )
        self.Layer = MSRInitializer(Layer, ActivationGain=ActivationGain)

    def forward(self, x):
        """Convolve ``x`` with the stored weight cast to ``x.dtype``."""
        # The functional form is used so the weight can follow the dtype of
        # the input instead of the dtype it is stored in.
        Weight = self.Layer.weight.to(x.dtype)
        return nn.functional.conv2d(
            x,
            Weight,
            padding=self.Layer.padding,
            groups=self.Layer.groups,
        )


class ResidualBlock(nn.Module):
    """Residual block of three convolutions with two biased activations.

    The branch is a 1x1 convolution to ``InputChannels * ExpansionFactor``
    channels, a grouped ``KernelSize`` convolution at that width, and a 1x1
    convolution back to ``InputChannels``.  The last convolution is
    initialised with gain 0, so its weights start at zero.

    Args:
        InputChannels: channel count of the input and output.
        Cardinality: number of groups of the middle convolution.
        ExpansionFactor: channel multiplier of the branch.
        KernelSize: kernel size of the middle convolution.
        VarianceScalingParameter: value whose ``-1 / 4`` power scales the
            initialisation gain of the first two convolutions.
    """

    def __init__(
        self,
        InputChannels,
        Cardinality,
        ExpansionFactor,
        KernelSize,
        VarianceScalingParameter,
    ):
        super().__init__()

        NumberOfLinearLayers = 3
        ExpandedChannels = InputChannels * ExpansionFactor
        Exponent = -1 / (2 * NumberOfLinearLayers - 2)
        ActivationGain = (
            BiasedActivation.Gain * VarianceScalingParameter**Exponent
        )

        # Layers are created in this order to keep the sequence of random
        # draws and the parameter registration order unchanged.
        self.LinearLayer1 = Convolution(
            InputChannels,
            ExpandedChannels,
            KernelSize=1,
            ActivationGain=ActivationGain,
        )
        self.LinearLayer2 = Convolution(
            ExpandedChannels,
            ExpandedChannels,
            KernelSize=KernelSize,
            Groups=Cardinality,
            ActivationGain=ActivationGain,
        )
        self.LinearLayer3 = Convolution(
            ExpandedChannels, InputChannels, KernelSize=1, ActivationGain=0
        )

        self.NonLinearity1 = BiasedActivation(ExpandedChannels)
        self.NonLinearity2 = BiasedActivation(ExpandedChannels)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        Hidden = self.NonLinearity1(self.LinearLayer1(x))
        Hidden = self.NonLinearity2(self.LinearLayer2(Hidden))
        Residual = self.LinearLayer3(Hidden)

        return x + Residual


class UpsampleLayer(nn.Module):
    """Optional 1x1 channel projection followed by upsampling.

    Args:
        InputChannels: channel count of the input.
        OutputChannels: channel count of the output; a 1x1 ``Convolution``
            is only created when it differs from ``InputChannels``.
        ResamplingFilter: filter taps handed to ``InterpolativeUpsampler``.
    """

    def __init__(self, InputChannels, OutputChannels, ResamplingFilter):
        super().__init__()

        self.Resampler = InterpolativeUpsampler(ResamplingFilter)

        # The projection attribute only exists when the widths differ;
        # forward() checks for it with hasattr.
        if InputChannels != OutputChannels:
            self.LinearLayer = Convolution(
                InputChannels, OutputChannels, KernelSize=1
            )

    def forward(self, x):
        """Project ``x`` if a projection exists, then upsample it."""
        if hasattr(self, "LinearLayer"):
            x = self.LinearLayer(x)

        return self.Resampler(x)


class DownsampleLayer(nn.Module):
    """Downsampling followed by an optional 1x1 channel projection.

    Args:
        InputChannels: channel count of the input.
        OutputChannels: channel count of the output; a 1x1 ``Convolution``
            is only created when it differs from ``InputChannels``.
        ResamplingFilter: filter taps handed to
            ``InterpolativeDownsampler``.
    """

    def __init__(self, InputChannels, OutputChannels, ResamplingFilter):
        super().__init__()

        self.Resampler = InterpolativeDownsampler(ResamplingFilter)

        # The projection attribute only exists when the widths differ;
        # forward() checks for it with hasattr.
        if InputChannels != OutputChannels:
            self.LinearLayer = Convolution(
                InputChannels, OutputChannels, KernelSize=1
            )

    def forward(self, x):
        """Downsample ``x``, then project it if a projection exists."""
        x = self.Resampler(x)

        if hasattr(self, "LinearLayer"):
            x = self.LinearLayer(x)

        return x


class GenerativeBasis(nn.Module):
    """Map a vector to a 4x4 feature map via a learned basis.

    A bias-free linear layer produces one coefficient per output channel,
    and each coefficient scales that channel's learned 4x4 basis.

    Args:
        InputDimension: width of the input vector.
        OutputChannels: number of channels of the produced feature map.
    """

    def __init__(self, InputDimension, OutputChannels):
        super().__init__()

        # The basis is drawn before the linear layer is created so the
        # sequence of random draws stays the same.
        Basis = torch.empty(OutputChannels, 4, 4).normal_(0, 1)
        self.Basis = nn.Parameter(Basis)

        Projection = nn.Linear(InputDimension, OutputChannels, bias=False)
        self.LinearLayer = MSRInitializer(Projection)

    def forward(self, x):
        """Return the basis scaled per sample and per channel."""
        Coefficients = self.LinearLayer(x).view(x.shape[0], -1, 1, 1)
        return self.Basis.view(1, -1, 4, 4) * Coefficients


class DiscriminativeBasis(nn.Module):
    """Reduce a feature map to a vector with a depthwise 4x4 convolution.

    The unpadded, per-channel 4x4 convolution is followed by flattening
    and a bias-free linear layer.

    Args:
        InputChannels: channel count of the input feature map.
        OutputDimension: width of the produced vector.
    """

    def __init__(self, InputChannels, OutputDimension):
        super().__init__()

        # groups=InputChannels gives every channel its own 4x4 kernel.
        Basis = nn.Conv2d(
            InputChannels,
            InputChannels,
            kernel_size=4,
            stride=1,
            padding=0,
            groups=InputChannels,
            bias=False,
        )
        self.Basis = MSRInitializer(Basis)

        Projection = nn.Linear(InputChannels, OutputDimension, bias=False)
        self.LinearLayer = MSRInitializer(Projection)

    def forward(self, x):
        """Return the linear projection of the flattened convolution."""
        Features = self.Basis(x)
        return self.LinearLayer(Features.view(x.shape[0], -1))


class GeneratorStage(nn.Module):
    """One generator stage: a transition layer followed by residual blocks.

    The transition layer is a ``GenerativeBasis`` when ``ResamplingFilter``
    is ``None`` and an ``UpsampleLayer`` otherwise.

    Args:
        InputChannels: input width of the transition layer.
        OutputChannels: output width of the transition layer and width of
            every residual block.
        Cardinality: forwarded to every ``ResidualBlock``.
        NumberOfBlocks: number of residual blocks.
        ExpansionFactor: forwarded to every ``ResidualBlock``.
        KernelSize: forwarded to every ``ResidualBlock``.
        VarianceScalingParameter: forwarded to every ``ResidualBlock``.
        ResamplingFilter: filter taps for the ``UpsampleLayer``, or
            ``None`` to use a ``GenerativeBasis``.
        DataType: dtype the input is cast to in ``forward``.
    """

    def __init__(
        self,
        InputChannels,
        OutputChannels,
        Cardinality,
        NumberOfBlocks,
        ExpansionFactor,
        KernelSize,
        VarianceScalingParameter,
        ResamplingFilter=None,
        DataType=torch.float32,
    ):
        super().__init__()

        if ResamplingFilter is None:
            TransitionLayer = GenerativeBasis(InputChannels, OutputChannels)
        else:
            TransitionLayer = UpsampleLayer(
                InputChannels, OutputChannels, ResamplingFilter
            )

        Blocks = [
            ResidualBlock(
                OutputChannels,
                Cardinality,
                ExpansionFactor,
                KernelSize,
                VarianceScalingParameter,
            )
            for _ in range(NumberOfBlocks)
        ]

        # The transition layer comes first, then the residual blocks.
        self.Layers = nn.ModuleList([TransitionLayer, *Blocks])
        self.DataType = DataType

    def forward(self, x):
        """Cast ``x`` to ``DataType`` and run it through every layer."""
        Features = x.to(self.DataType)

        for Layer in self.Layers:
            Features = Layer(Features)

        return Features


class DiscriminatorStage(nn.Module):
    """One discriminator stage: residual blocks, then a transition layer.

    The transition layer is a ``DiscriminativeBasis`` when
    ``ResamplingFilter`` is ``None`` and a ``DownsampleLayer`` otherwise.

    Args:
        InputChannels: width of every residual block and input width of
            the transition layer.
        OutputChannels: output width of the transition layer.
        Cardinality: forwarded to every ``ResidualBlock``.
        NumberOfBlocks: number of residual blocks.
        ExpansionFactor: forwarded to every ``ResidualBlock``.
        KernelSize: forwarded to every ``ResidualBlock``.
        VarianceScalingParameter: forwarded to every ``ResidualBlock``.
        ResamplingFilter: filter taps for the ``DownsampleLayer``, or
            ``None`` to use a ``DiscriminativeBasis``.
        DataType: dtype the input is cast to in ``forward``.
    """

    def __init__(
        self,
        InputChannels,
        OutputChannels,
        Cardinality,
        NumberOfBlocks,
        ExpansionFactor,
        KernelSize,
        VarianceScalingParameter,
        ResamplingFilter=None,
        DataType=torch.float32,
    ):
        super().__init__()

        # The transition layer is built before the residual blocks to keep
        # the sequence of random draws unchanged, even though it is placed
        # last in the layer list.
        if ResamplingFilter is None:
            TransitionLayer = DiscriminativeBasis(
                InputChannels, OutputChannels
            )
        else:
            TransitionLayer = DownsampleLayer(
                InputChannels, OutputChannels, ResamplingFilter
            )

        Blocks = [
            ResidualBlock(
                InputChannels,
                Cardinality,
                ExpansionFactor,
                KernelSize,
                VarianceScalingParameter,
            )
            for _ in range(NumberOfBlocks)
        ]

        self.Layers = nn.ModuleList([*Blocks, TransitionLayer])
        self.DataType = DataType

    def forward(self, x):
        """Cast ``x`` to ``DataType`` and run it through every layer."""
        Features = x.to(self.DataType)

        for Layer in self.Layers:
            Features = Layer(Features)

        return Features


class Generator(nn.Module):
    """Generator: a stack of ``GeneratorStage`` modules plus an output conv.

    The first stage is built without a resampling filter (so it uses a
    ``GenerativeBasis``); every later stage receives ``ResamplingFilter``.
    A final 1x1 ``Convolution`` maps the last stage to 3 channels.

    Args:
        NoiseDimension: width of the noise input; the first stage takes
            ``NoiseDimension + ConditionEmbeddingDimension`` inputs.
        WidthPerStage: channel count of each stage, first stage first.
        CardinalityPerStage: per-stage ``Cardinality`` value.
        BlocksPerStage: per-stage number of residual blocks; their sum is
            used as the variance scaling parameter of every stage.
        ExpansionFactor: forwarded to every stage.
        ConditionDimension: width of the conditioning input ``y``; when
            ``None`` no embedding layer is created.
        ConditionEmbeddingDimension: output width of the embedding layer.
        KernelSize: forwarded to every stage.
        ResamplingFilter: filter taps for every stage after the first.
            The default is an immutable tuple so that no mutable object is
            shared between instances.
    """

    def __init__(
        self,
        NoiseDimension,
        WidthPerStage,
        CardinalityPerStage,
        BlocksPerStage,
        ExpansionFactor,
        ConditionDimension=None,
        ConditionEmbeddingDimension=0,
        KernelSize=3,
        ResamplingFilter=(1, 2, 1),
    ):
        super().__init__()

        VarianceScalingParameter = sum(BlocksPerStage)

        # Stages are constructed in network order to keep the sequence of
        # random draws during initialisation unchanged.
        MainLayers = [
            GeneratorStage(
                NoiseDimension + ConditionEmbeddingDimension,
                WidthPerStage[0],
                CardinalityPerStage[0],
                BlocksPerStage[0],
                ExpansionFactor,
                KernelSize,
                VarianceScalingParameter,
            )
        ]
        for Index in range(len(WidthPerStage) - 1):
            MainLayers.append(
                GeneratorStage(
                    WidthPerStage[Index],
                    WidthPerStage[Index + 1],
                    CardinalityPerStage[Index + 1],
                    BlocksPerStage[Index + 1],
                    ExpansionFactor,
                    KernelSize,
                    VarianceScalingParameter,
                    ResamplingFilter,
                )
            )

        self.MainLayers = nn.ModuleList(MainLayers)
        self.AggregationLayer = Convolution(WidthPerStage[-1], 3, KernelSize=1)

        # The embedding attribute only exists for conditional models;
        # forward() checks for it with hasattr.
        if ConditionDimension is not None:
            self.EmbeddingLayer = MSRInitializer(
                nn.Linear(
                    ConditionDimension, ConditionEmbeddingDimension, bias=False
                )
            )

    def forward(self, x, y=None):
        """Run the generator.

        Args:
            x: noise input; concatenated with the embedded condition along
                dimension 1 when an embedding layer exists.
            y: conditioning input, required when the model was built with
                a ``ConditionDimension``.

        Returns:
            The output of the final 3-channel convolution.
        """
        if hasattr(self, "EmbeddingLayer"):
            x = torch.cat([x, self.EmbeddingLayer(y)], dim=1)

        for Layer in self.MainLayers:
            x = Layer(x)

        return self.AggregationLayer(x)


class Discriminator(nn.Module):
    """Discriminator: an input conv plus ``DiscriminatorStage`` modules.

    A 1x1 ``Convolution`` lifts the 3-channel input to ``WidthPerStage[0]``
    channels.  Every stage but the last receives ``ResamplingFilter``; the
    last one is built without it (so it uses a ``DiscriminativeBasis``)
    and outputs either 1 value or ``ConditionEmbeddingDimension`` values
    per sample.

    Args:
        WidthPerStage: channel count of each stage, first stage first.
        CardinalityPerStage: per-stage ``Cardinality`` value.
        BlocksPerStage: per-stage number of residual blocks; their sum is
            used as the variance scaling parameter of every stage.
        ExpansionFactor: forwarded to every stage.
        ConditionDimension: width of the conditioning input ``y``; when
            ``None`` no embedding layer is created.
        ConditionEmbeddingDimension: output width of the embedding layer
            and of the last stage in the conditional case.
        KernelSize: forwarded to every stage.
        ResamplingFilter: filter taps for every stage except the last.
            The default is an immutable tuple so that no mutable object is
            shared between instances.
    """

    def __init__(
        self,
        WidthPerStage,
        CardinalityPerStage,
        BlocksPerStage,
        ExpansionFactor,
        ConditionDimension=None,
        ConditionEmbeddingDimension=0,
        KernelSize=3,
        ResamplingFilter=(1, 2, 1),
    ):
        super().__init__()

        VarianceScalingParameter = sum(BlocksPerStage)

        # Stages are constructed in network order, before the extraction
        # layer, to keep the sequence of random draws unchanged.
        MainLayers = [
            DiscriminatorStage(
                WidthPerStage[Index],
                WidthPerStage[Index + 1],
                CardinalityPerStage[Index],
                BlocksPerStage[Index],
                ExpansionFactor,
                KernelSize,
                VarianceScalingParameter,
                ResamplingFilter,
            )
            for Index in range(len(WidthPerStage) - 1)
        ]

        if ConditionDimension is None:
            OutputDimension = 1
        else:
            OutputDimension = ConditionEmbeddingDimension
        MainLayers.append(
            DiscriminatorStage(
                WidthPerStage[-1],
                OutputDimension,
                CardinalityPerStage[-1],
                BlocksPerStage[-1],
                ExpansionFactor,
                KernelSize,
                VarianceScalingParameter,
            )
        )

        self.ExtractionLayer = Convolution(3, WidthPerStage[0], KernelSize=1)
        self.MainLayers = nn.ModuleList(MainLayers)

        # The embedding attribute only exists for conditional models;
        # forward() checks for it with hasattr.
        if ConditionDimension is not None:
            self.EmbeddingLayer = MSRInitializer(
                nn.Linear(
                    ConditionDimension, ConditionEmbeddingDimension, bias=False
                ),
                ActivationGain=1 / math.sqrt(ConditionEmbeddingDimension),
            )

    def forward(self, x, y=None):
        """Score a batch of inputs.

        Args:
            x: 3-channel input, cast to the first stage's ``DataType``.
            y: conditioning input, required when the model was built with
                a ``ConditionDimension``.

        Returns:
            A 1-D tensor with one value per batch element.  In the
            conditional case this is the sum over dimension 1 of the
            features multiplied by the embedded condition.
        """
        x = self.ExtractionLayer(x.to(self.MainLayers[0].DataType))

        for Layer in self.MainLayers:
            x = Layer(x)

        if hasattr(self, "EmbeddingLayer"):
            x = (x * self.EmbeddingLayer(y)).sum(dim=1, keepdim=True)

        return x.view(x.shape[0])
