# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: skip-file
"""The NCSNv2 model."""

import flax.linen as nn
import functools

from .utils import register_model
from .layers import (
    CondRefineBlock,
    RefineBlock,
    ResidualBlock,
    ncsn_conv3x3,
    ConditionalResidualBlock,
    get_act,
)
from .normalization import get_normalization
import ml_collections

# Short module-local aliases for two names imported from `.layers`; the model
# classes in this file refer to the layers through these shorter names.
CondResidualBlock = ConditionalResidualBlock
conv3x3 = ncsn_conv3x3


def get_network(config):
    """Pick the NCSNv2 variant that matches `config.data.image_size`.

    Args:
      config: configuration object exposing `data.image_size`.

    Returns:
      A `functools.partial` of the selected model class with `config` already
      bound: `NCSNv2` for sizes below 96, `NCSNv2_128` for sizes from 96 to 128
      inclusive, and `NCSNv2_256` for sizes above 128 up to 256 inclusive.

    Raises:
      NotImplementedError: if the image size does not fall in any of the ranges
        above (i.e. it is larger than 256).
    """
    size = config.data.image_size

    # Guard clauses, checked in increasing order of size; each later test only
    # runs when every earlier upper bound has already been exceeded.
    if size < 96:
        return functools.partial(NCSNv2, config=config)
    if size <= 128:
        return functools.partial(NCSNv2_128, config=config)
    if size <= 256:
        return functools.partial(NCSNv2_256, config=config)

    raise NotImplementedError(
        f"No network suitable for {config.data.image_size}px implemented yet."
    )


@register_model(name="ncsnv2_64")
class NCSNv2(nn.Module):
    """NCSNv2 model: a residual-block backbone followed by RefineBlocks.

    Attributes:
      config: experiment configuration. This module reads `model.nf`,
        `model.interpolation` and `data.centered` directly, and passes the
        whole config to `get_act` and `get_normalization`.
    """

    config: ml_collections.ConfigDict

    @nn.compact
    def __call__(self, x, labels, train=True):
        """Apply the network to `x`.

        Args:
          x: input batch. Feature-map `shape[1:3]` is handed to the refine
            blocks and `x.shape[-1]` sizes the output, so x is presumably laid
            out as (batch, height, width, channels) -- TODO confirm.
          labels: accepted but not used by this method.
          train: accepted but not used by this method.

        Returns:
          The result of the last `conv3x3` call, requested with `x.shape[-1]`
          output features.
        """
        cfg = self.config
        width = cfg.model.nf
        act = get_act(cfg)
        normalizer = get_normalization(cfg)
        interpolation = cfg.model.interpolation

        # Bind the keyword arguments shared by every block once. Submodules are
        # still instantiated in the same order as before, which matters because
        # compact modules derive auto-generated names from creation order.
        res_block = functools.partial(
            ResidualBlock, act=act, normalization=normalizer
        )
        refine = functools.partial(
            RefineBlock, act=act, interpolation=interpolation
        )

        # Data that is not already centered is mapped through 2 * x - 1.
        h = x if cfg.data.centered else 2 * x - 1.0
        h = conv3x3(h, width, stride=1, bias=True)

        # ResNet backbone; feat1..feat4 are kept as skip inputs for the decoder.
        h = res_block(width, resample=None)(h)
        feat1 = res_block(width, resample=None)(h)
        h = res_block(2 * width, resample="down")(feat1)
        feat2 = res_block(2 * width, resample=None)(h)
        h = res_block(2 * width, resample="down", dilation=2)(feat2)
        feat3 = res_block(2 * width, resample=None, dilation=2)(h)
        h = res_block(2 * width, resample="down", dilation=4)(feat3)
        feat4 = res_block(2 * width, resample=None, dilation=4)(h)

        # U-Net style decoder: each RefineBlock receives one skip feature plus
        # the previous refine output (the first one receives only feat4).
        h = refine(feat4.shape[1:3], 2 * width, start=True)([feat4])
        h = refine(feat3.shape[1:3], 2 * width)([feat3, h])
        h = refine(feat2.shape[1:3], 2 * width)([feat2, h])
        h = refine(feat1.shape[1:3], width, end=True)([feat1, h])

        # Output head: normalize, activate, project to the input's last-axis size.
        h = normalizer()(h)
        h = act(h)
        return conv3x3(h, x.shape[-1])


@register_model(name="ncsn")
class NCSN(nn.Module):
    """NCSNv1 model: the label-conditioned counterpart of the NCSNv2 layout.

    Every residual block, refine block and the final normalization layer is
    called with `labels` in addition to the feature tensor.

    Attributes:
      config: experiment configuration. This module reads `model.nf`,
        `model.interpolation` and `data.centered` directly, and passes the
        whole config to `get_act` and `get_normalization(..., conditional=True)`.
    """

    config: ml_collections.ConfigDict

    @nn.compact
    def __call__(self, x, labels, train=True):
        """Apply the conditional network to `x`.

        Args:
          x: input batch. Feature-map `shape[1:3]` is handed to the refine
            blocks and `x.shape[-1]` sizes the output, so x is presumably laid
            out as (batch, height, width, channels) -- TODO confirm.
          labels: conditioning input forwarded unchanged to every conditional
            block and to the final normalization layer.
          train: accepted but not used by this method.

        Returns:
          The result of the last `conv3x3` call, requested with `x.shape[-1]`
          output features.
        """
        cfg = self.config
        width = cfg.model.nf
        act = get_act(cfg)
        normalizer = get_normalization(cfg, conditional=True)
        interpolation = cfg.model.interpolation

        # Bind the keyword arguments shared by every block once. Submodules are
        # still instantiated in the same order as before, which matters because
        # compact modules derive auto-generated names from creation order.
        res_block = functools.partial(
            CondResidualBlock, act=act, normalization=normalizer
        )
        refine = functools.partial(
            CondRefineBlock,
            act=act,
            normalizer=normalizer,
            interpolation=interpolation,
        )

        # Data that is not already centered is mapped through 2 * x - 1.
        h = x if cfg.data.centered else 2 * x - 1.0
        h = conv3x3(h, width, stride=1, bias=True)

        # ResNet backbone; feat1..feat4 are kept as skip inputs for the decoder.
        h = res_block(width, resample=None)(h, labels)
        feat1 = res_block(width, resample=None)(h, labels)
        h = res_block(2 * width, resample="down")(feat1, labels)
        feat2 = res_block(2 * width, resample=None)(h, labels)
        h = res_block(2 * width, resample="down", dilation=2)(feat2, labels)
        feat3 = res_block(2 * width, resample=None, dilation=2)(h, labels)
        h = res_block(2 * width, resample="down", dilation=4)(feat3, labels)
        feat4 = res_block(2 * width, resample=None, dilation=4)(h, labels)

        # U-Net style decoder: each CondRefineBlock receives one skip feature
        # plus the previous refine output (the first one receives only feat4).
        h = refine(feat4.shape[1:3], 2 * width, start=True)([feat4], labels)
        h = refine(feat3.shape[1:3], 2 * width)([feat3, h], labels)
        h = refine(feat2.shape[1:3], 2 * width)([feat2, h], labels)
        h = refine(feat1.shape[1:3], width, end=True)([feat1, h], labels)

        # Output head: conditional normalization, activation, final projection.
        h = normalizer()(h, labels)
        h = act(h)
        return conv3x3(h, x.shape[-1])


@register_model(name="ncsnv2_128")
class NCSNv2_128(nn.Module):  # pylint: disable=invalid-name
    """NCSNv2 model for 128px images: five skip levels, widths up to 4 * nf.

    Attributes:
      config: experiment configuration. This module reads `model.nf`,
        `model.interpolation` and `data.centered` directly, and passes the
        whole config to `get_act` and `get_normalization`.
    """

    config: ml_collections.ConfigDict

    @nn.compact
    def __call__(self, x, labels, train=True):
        """Apply the network to `x`.

        Args:
          x: input batch. Feature-map `shape[1:3]` is handed to the refine
            blocks and `x.shape[-1]` sizes the output, so x is presumably laid
            out as (batch, height, width, channels) -- TODO confirm.
          labels: accepted but not used by this method.
          train: accepted but not used by this method.

        Returns:
          The result of the last `conv3x3` call, requested with `x.shape[-1]`
          output features.
        """
        cfg = self.config
        width = cfg.model.nf
        act = get_act(cfg)
        normalizer = get_normalization(cfg)
        interpolation = cfg.model.interpolation

        # Bind the keyword arguments shared by every block once. Submodules are
        # still instantiated in the same order as before, which matters because
        # compact modules derive auto-generated names from creation order.
        res_block = functools.partial(
            ResidualBlock, act=act, normalization=normalizer
        )
        refine = functools.partial(
            RefineBlock, act=act, interpolation=interpolation
        )

        # Data that is not already centered is mapped through 2 * x - 1.
        h = x if cfg.data.centered else 2 * x - 1.0
        h = conv3x3(h, width, stride=1, bias=True)

        # ResNet backbone; feat1..feat5 are kept as skip inputs for the decoder.
        h = res_block(width, resample=None)(h)
        feat1 = res_block(width, resample=None)(h)
        h = res_block(2 * width, resample="down")(feat1)
        feat2 = res_block(2 * width, resample=None)(h)
        h = res_block(2 * width, resample="down")(feat2)
        feat3 = res_block(2 * width, resample=None)(h)
        h = res_block(4 * width, resample="down", dilation=2)(feat3)
        feat4 = res_block(4 * width, resample=None, dilation=2)(h)
        h = res_block(4 * width, resample="down", dilation=4)(feat4)
        feat5 = res_block(4 * width, resample=None, dilation=4)(h)

        # U-Net style decoder: each RefineBlock receives one skip feature plus
        # the previous refine output (the first one receives only feat5).
        h = refine(feat5.shape[1:3], 4 * width, start=True)([feat5])
        h = refine(feat4.shape[1:3], 2 * width)([feat4, h])
        h = refine(feat3.shape[1:3], 2 * width)([feat3, h])
        h = refine(feat2.shape[1:3], width)([feat2, h])
        h = refine(feat1.shape[1:3], width, end=True)([feat1, h])

        # Output head: normalize, activate, project to the input's last-axis size.
        h = normalizer()(h)
        h = act(h)
        return conv3x3(h, x.shape[-1])


@register_model(name="ncsnv2_256")
class NCSNv2_256(nn.Module):  # pylint: disable=invalid-name
    """NCSNv2 model for 256px images: six skip levels, widths up to 4 * nf.

    Attributes:
      config: experiment configuration. This module reads `model.nf`,
        `model.interpolation` and `data.centered` directly, and passes the
        whole config to `get_act` and `get_normalization`.
    """

    config: ml_collections.ConfigDict

    @nn.compact
    def __call__(self, x, labels, train=True):
        """Apply the network to `x`.

        Args:
          x: input batch. Feature-map `shape[1:3]` is handed to the refine
            blocks and `x.shape[-1]` sizes the output, so x is presumably laid
            out as (batch, height, width, channels) -- TODO confirm.
          labels: accepted but not used by this method.
          train: accepted but not used by this method.

        Returns:
          The result of the last `conv3x3` call, requested with `x.shape[-1]`
          output features.
        """
        cfg = self.config
        width = cfg.model.nf
        act = get_act(cfg)
        normalizer = get_normalization(cfg)
        interpolation = cfg.model.interpolation

        # Bind the keyword arguments shared by every block once. Submodules are
        # still instantiated in the same order as before, which matters because
        # compact modules derive auto-generated names from creation order.
        res_block = functools.partial(
            ResidualBlock, act=act, normalization=normalizer
        )
        refine = functools.partial(
            RefineBlock, act=act, interpolation=interpolation
        )

        # Data that is not already centered is mapped through 2 * x - 1.
        h = x if cfg.data.centered else 2 * x - 1.0
        h = conv3x3(h, width, stride=1, bias=True)

        # ResNet backbone; the feat* tensors are kept as skip inputs for the
        # decoder. feat3b is the extra level this variant inserts between the
        # third and fourth skip levels.
        h = res_block(width, resample=None)(h)
        feat1 = res_block(width, resample=None)(h)
        h = res_block(2 * width, resample="down")(feat1)
        feat2 = res_block(2 * width, resample=None)(h)
        h = res_block(2 * width, resample="down")(feat2)
        feat3 = res_block(2 * width, resample=None)(h)
        h = res_block(2 * width, resample="down")(feat3)
        feat3b = res_block(2 * width, resample=None)(h)
        h = res_block(4 * width, resample="down", dilation=2)(feat3b)
        feat4 = res_block(4 * width, resample=None, dilation=2)(h)
        h = res_block(4 * width, resample="down", dilation=4)(feat4)
        feat5 = res_block(4 * width, resample=None, dilation=4)(h)

        # U-Net style decoder: each RefineBlock receives one skip feature plus
        # the previous refine output (the first one receives only feat5).
        h = refine(feat5.shape[1:3], 4 * width, start=True)([feat5])
        h = refine(feat4.shape[1:3], 2 * width)([feat4, h])
        h = refine(feat3b.shape[1:3], 2 * width)([feat3b, h])
        h = refine(feat3.shape[1:3], 2 * width)([feat3, h])
        h = refine(feat2.shape[1:3], width)([feat2, h])
        h = refine(feat1.shape[1:3], width, end=True)([feat1, h])

        # Output head: normalize, activate, project to the input's last-axis size.
        h = normalizer()(h)
        h = act(h)
        return conv3x3(h, x.shape[-1])
