import torch
from torch import nn
from torch.nn import functional as F

from collections import abc
import math
import random

def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
    """Add an optional channel bias, apply leaky ReLU, then multiply by ``scale``.

    The bias is reshaped to ``(1, C, 1, ..., 1)`` so it broadcasts along
    dimension 1 of ``input`` regardless of how many trailing dimensions follow.

    Args:
        input: tensor whose dimension 1 is the channel axis (when ``bias`` is given).
        bias: optional bias with ``bias.shape[0] == input.shape[1]`` (expected 1-D);
            ``None`` skips the addition.
        negative_slope: slope used for negative inputs.
        scale: gain applied after the activation (default sqrt(2)).

    Returns:
        ``leaky_relu(input + bias, negative_slope) * scale``.
    """
    if bias is None:
        activated = F.leaky_relu(input, negative_slope=negative_slope)
        return activated * scale

    trailing_ones = [1] * (input.ndim - bias.ndim - 1)
    broadcast_bias = bias.view(1, bias.shape[0], *trailing_ones)
    activated = F.leaky_relu(input + broadcast_bias, negative_slope=negative_slope)
    return activated * scale

def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    """Upsample, FIR-filter and downsample ``input`` via ``upfirdn2d_native``.

    Scalar ``up``/``down`` factors are expanded to ``(x, y)`` pairs, and a
    two-element ``pad`` is reused for both axes, so the native routine always
    receives the fully expanded argument list.

    Args:
        input: 4-D tensor (unpacked as ``(N, C, H, W)`` by the native routine).
        kernel: 2-D FIR kernel.
        up: int or ``(up_x, up_y)`` upsampling factor.
        down: int or ``(down_x, down_y)`` downsampling factor.
        pad: ``(pad0, pad1)`` for both axes, or ``(pad_x0, pad_x1, pad_y0, pad_y1)``.

    Returns:
        The tensor produced by ``upfirdn2d_native``.
    """
    up_xy = up if isinstance(up, abc.Iterable) else (up, up)
    down_xy = down if isinstance(down, abc.Iterable) else (down, down)
    padding = (*pad, *pad) if len(pad) == 2 else pad
    return upfirdn2d_native(input, kernel, *up_xy, *down_xy, *padding)

def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
    """Pure-PyTorch upsample -> pad/crop -> FIR filter -> downsample.

    Args:
        input: 4-D tensor unpacked as (N, C, H, W).
        kernel: 2-D tensor (kernel_h, kernel_w). It is flipped before
            ``F.conv2d`` so the filter is applied as a true convolution.
        up_x, up_y: zero-insertion upsampling factors.
        down_x, down_y: strides used to subsample the filtered result.
        pad_x0, pad_x1, pad_y0, pad_y1: padding applied after upsampling;
            negative values crop instead of pad.

    Returns:
        Tensor of shape (N, C, out_h, out_w), where
        out_h = (H * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
        and out_w is computed the same way along x.
    """
    _, channel, in_h, in_w = input.shape
    # Fold channels into the batch dimension so every channel is filtered
    # independently by a single-channel convolution below.
    input = input.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    # Zero-insertion upsampling: append (up - 1) zeros after every pixel
    # along x and y, then merge the inserted axes back in.
    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    # Apply the positive part of each pad, then crop by the negative part.
    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    # Valid (unpadded) convolution with the flipped kernel.
    out = out.permute(0, 3, 1, 2)
    out = out.reshape(
        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
    )
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    # Downsample by keeping every down-th row/column.
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x

    # Unfold the batch*channel dimension back into (N, C).
    return out.view(-1, channel, out_h, out_w)

class FusedLeakyReLU(nn.Module):
    """Module wrapper around ``fused_leaky_relu`` with an optional learnable bias.

    Args:
        channel: length of the per-channel bias.
        bias: if True, create a zero-initialised bias parameter; otherwise
            ``self.bias`` is ``None`` and no bias is added.
        negative_slope: leaky ReLU slope for negative inputs.
        scale: gain applied after the activation.
    """

    def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(channel)) if bias else None
        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        return fused_leaky_relu(
            input, self.bias, negative_slope=self.negative_slope, scale=self.scale
        )

class PixelNorm(nn.Module):
    """Scale each position to unit root-mean-square across dimension 1 (eps = 1e-8)."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        mean_square = input.pow(2).mean(dim=1, keepdim=True)
        return input * (mean_square + 1e-8).rsqrt()


def make_kernel(k):
    """Build a float32 FIR kernel whose entries sum to 1.

    A 1-D sequence of taps is expanded into its separable 2-D outer product;
    a 2-D input is used as given.

    Args:
        k: taps in any form accepted by ``torch.tensor``.

    Returns:
        float32 tensor normalised by its sum.
    """
    kernel = torch.tensor(k, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = kernel[None, :] * kernel[:, None]
    return kernel / kernel.sum()


class Upsample(nn.Module):
    """Upsample a 4-D tensor by ``factor`` using ``upfirdn2d`` with a FIR kernel.

    Args:
        kernel: taps passed to ``make_kernel``. The normalised kernel is
            multiplied by ``factor ** 2`` and registered as the ``kernel`` buffer.
        factor: integer upsampling factor applied to both axes.
    """

    def __init__(self, kernel, factor=2):
        super().__init__()

        self.factor = factor
        fir = make_kernel(kernel) * (factor ** 2)
        self.register_buffer("kernel", fir)

        # Split the kernel's overhang beyond ``factor`` between the two sides;
        # the leading side additionally absorbs ``factor - 1``.
        overhang = fir.shape[0] - factor
        self.pad = ((overhang + 1) // 2 + factor - 1, overhang // 2)

    def forward(self, input):
        return upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)


class Downsample(nn.Module):
    """Blur with a FIR kernel and downsample by ``factor`` via ``upfirdn2d``.

    Args:
        kernel: taps passed to ``make_kernel``; the normalised kernel is
            registered as the ``kernel`` buffer.
        factor: integer downsampling factor applied to both axes.
    """

    def __init__(self, kernel, factor=2):
        super().__init__()

        self.factor = factor
        fir = make_kernel(kernel)
        self.register_buffer("kernel", fir)

        # Distribute the kernel's overhang beyond ``factor`` over both sides.
        overhang = fir.shape[0] - factor
        self.pad = ((overhang + 1) // 2, overhang // 2)

    def forward(self, input):
        return upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)


class Blur(nn.Module):
    """Apply a fixed FIR blur through ``upfirdn2d`` with explicit padding.

    Args:
        kernel: taps passed to ``make_kernel``.
        pad: padding forwarded to ``upfirdn2d`` (two or four values).
        upsample_factor: when greater than 1, the kernel is multiplied by
            ``upsample_factor ** 2``.
    """

    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()

        fir = make_kernel(kernel)
        needs_gain = upsample_factor > 1
        if needs_gain:
            fir = fir * (upsample_factor ** 2)
        self.register_buffer("kernel", fir)
        self.pad = pad

    def forward(self, input):
        return upfirdn2d(input, self.kernel, pad=self.pad)


class EqualConv2d(nn.Module):
    """2-D convolution with an equalized learning rate.

    The weight is stored as standard-normal samples and multiplied at run time
    by ``1 / sqrt(in_channel * kernel_size ** 2)``.

    Args:
        in_channel, out_channel: channel counts.
        kernel_size: square kernel size.
        stride, padding: forwarded to ``F.conv2d``.
        bias: create a zero-initialised bias if True, else ``self.bias`` is ``None``.
    """

    def __init__(
        self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
    ):
        super().__init__()

        weight_shape = (out_channel, in_channel, kernel_size, kernel_size)
        self.weight = nn.Parameter(torch.randn(*weight_shape))
        fan_in = in_channel * kernel_size ** 2
        self.scale = 1 / math.sqrt(fan_in)

        self.stride = stride
        self.padding = padding

        self.bias = nn.Parameter(torch.zeros(out_channel)) if bias else None

    def forward(self, input):
        scaled_weight = self.weight * self.scale
        return F.conv2d(
            input,
            scaled_weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )

    def __repr__(self):
        out_c, in_c, k = self.weight.shape[:3]
        return (
            f"{type(self).__name__}({in_c}, {out_c}, {k}, "
            f"stride={self.stride}, padding={self.padding})"
        )


class EqualLinear(nn.Module):
    """Fully connected layer with an equalized learning rate.

    The stored weight is ``randn / lr_mul`` and is multiplied at run time by
    ``lr_mul / sqrt(in_dim)``. The bias is multiplied by ``lr_mul``. When
    ``activation`` is truthy, the bias is applied inside ``fused_leaky_relu``
    rather than by ``F.linear``.

    Args:
        in_dim, out_dim: input and output feature sizes.
        bias: create a learnable bias initialised to ``bias_init`` if True;
            otherwise ``self.bias`` is ``None`` and no bias is added.
        bias_init: initial bias value.
        lr_mul: learning-rate multiplier.
        activation: any truthy value enables the fused leaky ReLU (the value
            itself is not inspected).
    """

    def __init__(
        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
    ):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))

        else:
            self.bias = None

        self.activation = activation

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def _scaled_bias(self):
        """Return the run-time bias, or ``None`` when the layer has no bias."""
        # ``None * lr_mul`` would raise TypeError, so bias-free layers skip it.
        if self.bias is None:
            return None
        return self.bias * self.lr_mul

    def forward(self, input):
        weight = self.weight * self.scale
        bias = self._scaled_bias()

        if self.activation:
            out = F.linear(input, weight)
            return fused_leaky_relu(out, bias)

        return F.linear(input, weight, bias=bias)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
        )


class ModulatedConv2d(nn.Module):
    """Style-modulated convolution with optional demodulation and up/downsampling.

    ``self.modulation`` maps the style vector to one scale per input channel.
    In the fused path, those scales multiply a per-sample copy of the weight,
    and the batch is folded into conv groups. In the non-fused path, they
    multiply the input activations instead. With ``demodulate``, the output
    of every output channel is rescaled by ``rsqrt(sum(w ** 2) + 1e-8)`` of
    its modulated weight.

    Args:
        in_channel, out_channel: channel counts.
        kernel_size: square kernel size.
        style_dim: length of the style vector passed to ``forward``.
        demodulate: apply per-output-channel normalisation.
        upsample: stride-2 transposed conv followed by ``self.blur``.
        downsample: ``self.blur`` followed by a stride-2 conv.
        blur_kernel: FIR taps used by the blur.
        fused: choose between the grouped (fused) and activation-scaling paths.

    NOTE(review): if both ``upsample`` and ``downsample`` are True, the second
    ``Blur`` overwrites ``self.blur`` while ``forward`` takes the upsample
    branch. Also, ``self.eps`` is stored but ``forward`` uses a literal 1e-8.
    """
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        demodulate=True,
        upsample=False,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        fused=True,
    ):
        super().__init__()

        self.eps = 1e-8
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.upsample = upsample
        self.downsample = downsample

        if upsample:
            # Blur after the stride-2 transposed conv; the kernel is scaled by
            # factor ** 2 via Blur(upsample_factor=...).
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1

            self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)

        if downsample:
            # Blur before the stride-2 conv.
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            self.blur = Blur(blur_kernel, pad=(pad0, pad1))

        # Equalized-learning-rate scale applied to the stored weight at run time.
        fan_in = in_channel * kernel_size ** 2
        self.scale = 1 / math.sqrt(fan_in)
        self.padding = kernel_size // 2

        # Leading singleton dim lets the weight broadcast against per-sample styles.
        self.weight = nn.Parameter(
            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
        )

        # bias_init=1 so an all-zero style starts as a scale of 1 per input channel.
        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)

        self.demodulate = demodulate
        self.fused = fused

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
            f"upsample={self.upsample}, downsample={self.downsample})"
        )

    def forward(self, input, style):
        """Convolve ``input`` (N, C_in, H, W) with weights modulated by ``style``.

        ``style`` is fed to ``self.modulation`` (an ``EqualLinear`` with input
        size ``style_dim``). Returns a tensor with ``out_channel`` channels.
        """
        batch, in_channel, height, width = input.shape

        if not self.fused:
            # Non-fused path: shared weight, modulation applied to activations.
            weight = self.scale * self.weight.squeeze(0)
            style = self.modulation(style)

            if self.demodulate:
                # Per-sample, per-output-channel demodulation coefficients.
                w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, 1)
                dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt()

            input = input * style.reshape(batch, in_channel, 1, 1)

            if self.upsample:
                # conv_transpose2d expects weight as (in, out, kh, kw).
                weight = weight.transpose(0, 1)
                out = F.conv_transpose2d(
                    input, weight, padding=0, stride=2
                )
                out = self.blur(out)

            elif self.downsample:
                input = self.blur(input)
                out = F.conv2d(input, weight, padding=0, stride=2)

            else:
                out = F.conv2d(input, weight, padding=self.padding)

            if self.demodulate:
                out = out * dcoefs.view(batch, -1, 1, 1)

            return out

        # Fused path: build a (batch, out, in, k, k) weight per sample.
        style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
        weight = self.scale * self.weight * style

        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
            weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)

        # Flatten to grouped-conv layout: one group of out_channel filters per sample.
        weight = weight.view(
            batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
        )

        if self.upsample:
            input = input.view(1, batch * in_channel, height, width)
            # Grouped conv_transpose2d wants (batch * in, out, k, k), hence the transpose.
            weight = weight.view(
                batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
            )
            weight = weight.transpose(1, 2).reshape(
                batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
            )
            out = F.conv_transpose2d(
                input, weight, padding=0, stride=2, groups=batch
            )
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)
            out = self.blur(out)

        elif self.downsample:
            input = self.blur(input)
            _, _, height, width = input.shape
            input = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(
                input, weight, padding=0, stride=2, groups=batch
            )
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        else:
            input = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(
                input, weight, padding=self.padding, groups=batch
            )
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        return out


class NoiseInjection(nn.Module):
    """Add single-channel noise scaled by a learnable scalar.

    ``weight`` is initialised to zero, so a freshly built module returns its
    input unchanged.
    """

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, image, noise=None):
        """Return ``image + weight * noise``.

        When ``noise`` is omitted, standard-normal noise of shape
        ``(N, 1, H, W)`` is drawn and broadcast over the channels.
        """
        if noise is not None:
            return image + self.weight * noise

        n, _, h, w = image.shape
        fresh = image.new_empty(n, 1, h, w).normal_()
        return image + self.weight * fresh


class ConstantInput(nn.Module):
    """Learned ``(1, channel, size, size)`` tensor repeated along the batch."""

    def __init__(self, channel, size=4):
        super().__init__()
        self.input = nn.Parameter(torch.randn(1, channel, size, size))

    def forward(self, input):
        # Only the batch size of ``input`` is used; its values are ignored.
        return self.input.repeat(input.shape[0], 1, 1, 1)


class StyledConv(nn.Module):
    """Modulated convolution, then noise injection, then fused leaky ReLU.

    Args:
        in_channel, out_channel: channel counts.
        kernel_size: square kernel size.
        style_dim: length of the style vector.
        upsample: forwarded to ``ModulatedConv2d``.
        blur_kernel: forwarded to ``ModulatedConv2d``.
        demodulate: forwarded to ``ModulatedConv2d``.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        upsample=False,
        blur_kernel=[1, 3, 3, 1],
        demodulate=True,
    ):
        super().__init__()

        conv_kwargs = dict(
            upsample=upsample, blur_kernel=blur_kernel, demodulate=demodulate
        )
        self.conv = ModulatedConv2d(
            in_channel, out_channel, kernel_size, style_dim, **conv_kwargs
        )
        self.noise = NoiseInjection()
        self.activate = FusedLeakyReLU(out_channel)

    def forward(self, input, style, noise=None):
        """Apply conv -> noise (``noise`` may be ``None`` to sample it) -> activation."""
        modulated = self.conv(input, style)
        noisy = self.noise(modulated, noise=noise)
        return self.activate(noisy)

class ToRGB(nn.Module):
    """Project features to 3 channels with a 1x1 modulated conv (no demodulation).

    A learned ``(1, 3, 1, 1)`` bias is added. If ``skip`` is given, it is added
    to the result; when the layer was built with ``upsample=True``, the skip is
    first passed through ``Upsample(blur_kernel)``.

    Args:
        in_channel: input feature channels.
        style_dim: length of the style vector.
        upsample: whether to upsample ``skip`` before adding it.
        blur_kernel: FIR taps for the skip upsampler.
    """

    def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        # Stored as None when disabled. Previously the attribute was missing,
        # so passing ``skip`` with upsample=False raised AttributeError.
        self.upsample = Upsample(blur_kernel) if upsample else None

        self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, input, style, skip=None):
        out = self.conv(input, style) + self.bias

        if skip is not None:
            if self.upsample is not None:
                skip = self.upsample(skip)
            out = out + skip

        return out

class Quantize(nn.Module):
    """Nearest-neighbour vector quantiser with a straight-through gradient.

    The codebook ``embed`` has shape ``(dim, n_embed)`` and is registered as a
    buffer, not a parameter. ``cluster_size``, ``embed_avg``, ``decay`` and
    ``eps`` are stored (and saved in ``state_dict``) but no method of this class
    reads them.
    NOTE(review): nothing here updates ``embed``, so the codebook keeps its
    random initialisation unless code outside this class modifies it.
    """

    def __init__(self, dim, n_embed, decay=0.997, eps=1e-5):
        super().__init__()

        self.dim = dim
        self.n_embed = n_embed
        self.decay = decay
        self.eps = eps

        codebook = torch.randn(dim, n_embed) * 0.1 + 0.1
        self.register_buffer("embed", codebook)
        self.register_buffer("cluster_size", torch.zeros(n_embed))
        self.register_buffer("embed_avg", codebook.clone())

    def forward(self, input):
        """Quantise ``input`` along its last dimension (which must have size ``dim``).

        Returns:
            quantized: same shape as ``input``. Its forward value is the codebook
                vectors, and its gradient flows to ``input`` unchanged.
            diff: scalar mean squared distance between ``input`` and the detached
                codebook vectors.
            embed_ind: codebook indices with shape ``input.shape[:-1]``.
        """
        vectors = input.reshape(-1, self.dim)
        codebook = self.embed
        # Squared Euclidean distance to every code: |x|^2 - 2 x.e + |e|^2.
        sq_dist = (
            vectors.pow(2).sum(1, keepdim=True)
            - 2 * vectors @ codebook
            + codebook.pow(2).sum(0, keepdim=True)
        )
        nearest = (-sq_dist).max(dim=1).indices
        nearest = nearest.view(*input.shape[:-1])
        codes = self.embed_code(nearest)

        commitment = (codes.detach() - input).pow(2).mean()
        straight_through = input + (codes - input).detach()
        return straight_through, commitment, nearest

    def embed_code(self, embed_id):
        """Look up codebook vectors for integer indices ``embed_id``."""
        codebook_rows = self.embed.transpose(0, 1)
        return F.embedding(embed_id, codebook_rows)

class Reshape(nn.Module):
    """Module form of ``Tensor.view`` with a fixed target shape (``-1`` allowed)."""

    def __init__(self, shape):
        super().__init__()
        self.shape = shape

    def forward(self, tensor):
        target = tuple(self.shape)
        return tensor.view(target)

class ConvLayer(nn.Sequential):
    """Sequential block of optional ``Blur``, ``EqualConv2d`` and optional ``FusedLeakyReLU``.

    Layout (which fixes the ``state_dict`` key indices):
      * ``Blur``, only when ``downsample``. The conv that follows then uses
        stride 2 and no padding.
      * ``EqualConv2d``. Without ``downsample`` it uses stride 1 and
        ``kernel_size // 2`` padding. It has a bias only when ``bias`` is set
        and ``activate`` is not.
      * ``FusedLeakyReLU``, only when ``activate``. It carries the bias when
        ``bias`` is set.

    The convolution's padding is also exposed as ``self.padding``.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
    ):
        blocks = []

        if downsample:
            # Blur padding covers the FIR overhang plus the conv's extent.
            spill = (len(blur_kernel) - 2) + (kernel_size - 1)
            blocks.append(Blur(blur_kernel, pad=((spill + 1) // 2, spill // 2)))
            stride, conv_padding = 2, 0
        else:
            stride, conv_padding = 1, kernel_size // 2

        blocks.append(
            EqualConv2d(
                in_channel,
                out_channel,
                kernel_size,
                padding=conv_padding,
                stride=stride,
                bias=bias and not activate,
            )
        )

        if activate:
            blocks.append(FusedLeakyReLU(out_channel, bias=bias))

        super().__init__(*blocks)
        self.padding = conv_padding


class ResBlock(nn.Module):
    """Residual downsampling block.

    Main path: 3x3 conv, then 3x3 stride-2 conv. Skip path: 1x1 stride-2 conv
    without bias or activation. The two paths are summed and divided by sqrt(2).

    Args:
        in_channel, out_channel: channel counts.
        blur_kernel: FIR taps for the blur that precedes each stride-2 conv.
            It is forwarded to both downsampling layers (it used to be
            accepted but ignored).
    """

    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(
            in_channel, out_channel, 3, downsample=True, blur_kernel=blur_kernel
        )

        self.skip = ConvLayer(
            in_channel,
            out_channel,
            1,
            downsample=True,
            blur_kernel=blur_kernel,
            activate=False,
            bias=False,
        )

    def forward(self, input):
        out = self.conv2(self.conv1(input))
        skip = self.skip(input)
        return (out + skip) / math.sqrt(2)

class StyledConvMod(StyledConv):
    """``StyledConv`` whose style is a feature map, with no noise injection.

    ``style`` (assumed ``(N, C, H, W)``) is global-average-pooled to ``(N, C)``
    before modulation. The inherited ``self.noise`` module is not used.
    """

    def forward(self, input, style):
        pooled = F.adaptive_avg_pool2d(style, 1)
        style_vec = pooled.squeeze(-1).squeeze(-1)
        return self.activate(self.conv(input, style_vec))

class ToRGBMod(ToRGB):
    """``ToRGB`` whose style is a feature map, global-average-pooled to a vector.

    ``style`` is assumed to be ``(N, C, H, W)``; TODO confirm against callers.
    """

    def forward(self, input, style, skip=None):
        style_vec = F.adaptive_avg_pool2d(style, 1).squeeze(-1).squeeze(-1)
        rgb = self.conv(input, style_vec) + self.bias
        if skip is None:
            return rgb
        return rgb + self.upsample(skip)

def get_channels(channel_multiplier):
    """Map each resolution 4..1024 to its channel count.

    Resolutions up to 32 get 512 channels. From 64 upward the base count halves
    with every doubling (256, 128, 64, 32, 16) and is multiplied by
    ``channel_multiplier``.
    """
    channels = {res: 512 for res in (4, 8, 16, 32)}
    for res, base in ((64, 256), (128, 128), (256, 64), (512, 32), (1024, 16)):
        channels[res] = base * channel_multiplier
    return channels

class Discriminator(nn.Module):
    """Residual discriminator with an auxiliary feature-to-image decoder.

    ``convs`` is a 1x1 ``ConvLayer`` followed by ``ResBlock``s that halve the
    resolution from ``size`` down to 4. ``final_rf`` maps the 4x4 features,
    plus one minibatch-stddev channel from ``getRFFeat``, to one score per
    sample. ``decoder`` is used by ``reconstruct`` to rebuild a 3-channel image
    from intermediate features.

    Args:
        size: input resolution; must be a key of ``get_channels``.
        channel_multiplier: forwarded to ``get_channels``.
        blur_kernel: FIR taps forwarded to the ``ResBlock``s and decoder stages.
    """

    def __init__(self, size, channel_multiplier=2, blur_kernel=[1,3,3,1]):

        super().__init__()

        self.size = size
        self.channels = get_channels(channel_multiplier)

        convs = [ConvLayer(3, self.channels[size], 1)]

        # log2 is exact for powers of two (math.log(x, 2) can round below).
        log_size = int(math.log2(size))
        in_channel = self.channels[size]
        for i in range(log_size, 2, -1):
            out_channel = self.channels[2 ** (i - 1)]
            convs.append(ResBlock(in_channel, out_channel, blur_kernel))
            in_channel = out_channel

        self.convs = nn.Sequential(*convs)
        self.stddev_group = 4
        self.stddev_feat = 1
        self.final_rf = nn.Sequential(
            ConvLayer(in_channel + 1, self.channels[4], 3),
            Reshape([-1, self.channels[4] * 4 * 4]),
            EqualLinear(self.channels[4] * 4 * 4, self.channels[4], activation="fused_lrelu"),
            EqualLinear(self.channels[4], 1),
        )

        # Default resolutions for ``extract``: size, size / 2, ..., 4.
        self.extract_res = [2**p for p in range( int(math.log2(self.size)) , 1, -1)]

        self.decoder = nn.ModuleList( [
            StyledConvMod( self.channels[8], self.channels[16]//2, 3, 
                        style_dim=self.channels[8], upsample=True, blur_kernel=blur_kernel ),
            StyledConvMod( self.channels[16]//2, self.channels[32]//2, 3, 
                        style_dim=self.channels[16], upsample=True, blur_kernel=blur_kernel ),
            StyledConvMod( self.channels[32]//2, self.channels[64]//2, 3, 
                        style_dim=self.channels[32], upsample=True, blur_kernel=blur_kernel ),
            StyledConvMod( self.channels[64]//2, self.channels[128]//2, 3, 
                        style_dim=self.channels[64], upsample=True, blur_kernel=blur_kernel ),
            ToRGBMod(self.channels[128]//2, self.channels[128], upsample=False)            
        ] )

    def extract(self, input, feature_res=None):
        """Run ``convs`` and collect feature maps whose width is in ``feature_res``.

        Stops right after collecting a map of width ``feature_res[-1]``.
        ``feature_res`` defaults to ``self.extract_res``.

        Returns:
            List of feature maps, ordered as produced (large -> small).
        """
        if feature_res is None:
            feature_res = self.extract_res
        out = []
        feat = input
        for block in self.convs:
            feat = block(feat)
            if feat.shape[-1] in feature_res:
                out.append(feat)
                if feat.shape[-1] == feature_res[-1]:
                    break
        return out

    def reconstruct(self, input):
        """Decode an image from the 128..4 resolution features of ``input``.

        The 8x8 features are upsampled by ``decoder[0..3]``. Each stage is
        modulated by the encoder feature of matching resolution, and
        ``decoder[4]`` produces the output modulated by the 128x128 features.

        Returns:
            ``(rec_img, 4x4 feature map)``.
        """
        feats = self.extract(input, feature_res=[128,64,32,16,8,4])

        decode = feats[-2]
        for i in range(4, 0, -1):
            decode = self.decoder[4-i](decode, feats[i] )

        rec_img = self.decoder[4](decode, feats[0] )
        return rec_img, feats[-1]

    def getRFFeat(self, out):
        """Append a minibatch standard-deviation channel to ``out`` (N, C, H, W).

        The batch is split into ``group`` groups. The per-feature stddev across
        the group is averaged to one value per slot and broadcast as
        ``stddev_feat`` extra channel(s).
        """
        batch, channel, height, width = out.shape
        group = min(batch, self.stddev_group)
        # ``view`` below needs ``group`` to divide the batch; shrink it to the
        # largest such divisor. Batches that already divided are unchanged.
        while group > 1 and batch % group != 0:
            group -= 1
        stddev = out.view(
            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
        )
        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
        stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)
        stddev = stddev.repeat(group, 1, height, width)
        out_rf = torch.cat([out, stddev], 1)

        return out_rf

    def forward(self, input, ae=False, extract=False):
        """Score ``input``, optionally also reconstructing it and/or returning features.

        Args:
            input: image batch fed to ``convs``.
            ae: also return the reconstruction from ``reconstruct``.
            extract: also return the feature list from ``extract`` (default resolutions).

        Returns:
            ``out_rf`` when both flags are False. Otherwise the list
            ``[out_rf, rec_img (if ae), feats (if extract)]``.
        """
        feats = None
        if extract:
            # Computed up front so it is also available when ``ae`` is set.
            # Before, ``ae=True, extract=True`` raised NameError. ``reconstruct``
            # extracts its own fixed resolution list internally.
            feats = self.extract(input)

        if ae:
            rec_img, out = self.reconstruct(input)
        elif extract:
            out = feats[-1]
        else:
            out = self.convs(input)

        out_rf = self.final_rf(self.getRFFeat(out))

        outputs = [out_rf]
        if ae:
            outputs.append(rec_img)
        if extract:
            outputs.append(feats)

        if len(outputs) == 1:
            return outputs[0]
        return outputs


class StyleEncoder(nn.Module):
    """Turn a list of encoder feature maps into a stack of style vectors.

    One head per style is created, walking resolutions from ``size`` down to
    4: three heads at ``size``, one at 4, and two at every resolution in
    between. Each head is a stride-2 ``ConvLayer``, adaptive max-pool to 4x4,
    flatten, and three ``EqualLinear`` layers producing ``style_dim`` values.
    """

    def __init__(self, size, style_dim, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        channels = get_channels(channel_multiplier)
        self.style_dim = style_dim
        self.to_style = nn.ModuleList()

        for log_res in range(int(math.log2(size)), 1, -1):
            res = 2 ** log_res
            if res == 4:
                n_heads = 1
            elif res == size:
                n_heads = 3
            else:
                n_heads = 2

            for _ in range(n_heads):
                self.to_style.append(
                    self._make_head(channels[res], style_dim, blur_kernel)
                )

    @staticmethod
    def _make_head(in_channel, style_dim, blur_kernel):
        """Build one feature-map -> style-vector head."""
        return nn.Sequential(
            ConvLayer(in_channel, in_channel, 3, downsample=True, blur_kernel=blur_kernel),
            nn.AdaptiveMaxPool2d(4),
            Reshape([-1, in_channel * 4 * 4]),
            EqualLinear(in_channel * 4 * 4, style_dim, activation="fused_lrelu"),
            EqualLinear(style_dim, style_dim, activation="fused_lrelu"),
            EqualLinear(style_dim, style_dim, activation=None),
        )

    def forward(self, feat_list):
        """Map feature maps to styles of shape ``(N, n_styles, style_dim)``.

        ``feat_list`` is ordered large -> small; the output is ordered
        small -> large. The first map also feeds ``to_style[0]``. Map ``i``
        feeds head ``2 * i + 1`` and, unless it is 4 pixels wide, head ``2 * i + 2``.
        """
        heads = self.to_style
        styles = [heads[0](feat_list[0])]
        for i, feat in enumerate(feat_list):
            styles.append(heads[2 * i + 1](feat))
            if feat.shape[-1] != 4:
                styles.append(heads[2 * i + 2](feat))

        return torch.cat([s.unsqueeze(1) for s in reversed(styles)], dim=1)

class VQ0(nn.Module):
    """Vector-quantised spatial modulation (variant without content conditioning).

    ``get_quant_and_vector`` builds a ``base_dim``-channel map from a learned
    constant modulated by ``style``, quantises it with ``self.quantize``, and
    derives a new style vector from the quantised map. ``forward`` turns a
    quantised map into ``gamma``/``beta`` (``dim`` channels each) and returns
    ``LeakyReLU(0.1)(feat * gamma + beta)``.

    NOTE(review): ``vq_condition`` is stored but never read here, and
    ``get_quant_and_vector`` ignores ``pre_quant``. ``pre_vq_condition`` uses
    ``self.from_pre_vq``, which ``__init__`` never creates, so calling it raises
    AttributeError (VQ3 creates that layer when ``vq_condition`` is set).
    """
    def __init__(self, style_dim, hw, dim, n_embed, vq_condition=False, decay=0.99, eps=1e-5):
        super().__init__()

        self.base_dim = 128
        self.hw = hw
        self.dim = dim
        self.vq_condition = vq_condition
        self.base = ConstantInput(self.base_dim)
        # Upsampling flags depend on hw; presumably chosen so the map ends at hw x hw — confirm.
        self.pre_quant = nn.ModuleList([ 
                        StyledConv(self.base_dim, self.base_dim, 3, style_dim, upsample= (hw>8) ),
                        StyledConv(self.base_dim, self.base_dim, 3, style_dim, upsample=True),
                        StyledConv(self.base_dim, self.base_dim, 1, style_dim, upsample= (hw==32) )])
        self.quantize = Quantize(self.base_dim, n_embed, decay, eps)
        
        # Produces 2 * dim channels: first half gamma, second half beta.
        self.after_quant = nn.Sequential(
            ConvLayer(self.base_dim, dim, 3),
            ConvLayer(dim, dim*2, 1))

        # Two stride-2 layers; ``style_dim`` channels are pooled to a vector by callers.
        self.to_vector = nn.Sequential(
            ConvLayer(self.base_dim, style_dim//2, 3, downsample=True),
            ConvLayer(style_dim//2, style_dim, 3, downsample=True))

        self.activate = nn.LeakyReLU(0.1)

    def get_quant_and_vector(self, style, pre_quant=None):
        """Return ``(quantize, diff, embed_ind, new_style_vector)`` for ``style``.

        ``quantize`` is channel-first; ``diff`` and ``embed_ind`` come from
        ``Quantize.forward``. ``pre_quant`` is accepted but unused.
        """
        vq_feat = self.base(style)
        for conv in self.pre_quant:
            vq_feat = conv(vq_feat, style)
        
        # Quantize operates on the last dimension, so move channels last.
        vq_feat = vq_feat.permute(0,2,3,1)
        
        quantize, diff, embed_ind = self.quantize(vq_feat)
        quantize = quantize.permute(0, 3, 1, 2)

        new_style_vector = F.adaptive_avg_pool2d( self.to_vector(quantize) , 1).squeeze(-1).squeeze(-1)
        return quantize, diff, embed_ind, new_style_vector

    def forward(self, feat, quantize):
        """Modulate ``feat`` (``dim`` channels) with gamma/beta predicted from ``quantize``."""
        
        quantize = self.after_quant( quantize ) 
        
        gamma = quantize[:,:self.dim]
        beta = quantize[:,self.dim:]
        out = feat * gamma + beta
        return self.activate(out)
    
    def pre_vq_condition(self, vq_feat, pre_quantize):
        # NOTE(review): ``self.from_pre_vq`` is not defined on VQ0 -> AttributeError if called.
        affines = self.from_pre_vq( pre_quantize )
        affines = F.interpolate(affines, vq_feat.shape[2:])
        gamma = affines[:,:self.base_dim]
        beta = affines[:,self.base_dim:]
        out = vq_feat * gamma + beta
        return self.activate(out)

    def forward_with_embed(self, feat, emb_ind, style_conv, noise):
        """Re-run modulation from explicit codebook indices ``emb_ind``.

        The codes are resized to ``hw``. A style vector derived from them
        drives ``style_conv(feat, vector, noise)``, and ``forward`` then
        applies the spatial modulation.
        """
        # emb_ind shape: b x h x w
        quantize = self.quantize.embed_code(emb_ind) # b x h x w x feat_dim
        
        quantize = quantize.permute(0, 3, 1, 2)
        quantize = F.interpolate(quantize, self.hw)

        new_style_vector = F.adaptive_avg_pool2d( self.to_vector(quantize) , 1).squeeze(-1).squeeze(-1)
        feat = style_conv(feat, new_style_vector, noise)
        return self.forward(feat, quantize)
      
class VQ3(nn.Module):
    """Vector-quantised spatial modulation conditioned on content features.

    ``get_quant_and_vector`` builds a ``base_dim``-channel map from a learned
    constant modulated by ``style``. With ``vq_condition``, it first
    affine-modulates that map using ``pre_quant``. The map is then quantised
    and a style vector is derived from the result. ``forward`` predicts
    ``gamma``/``beta`` (``dim`` channels each) from the quantised map
    concatenated with ``from_content(feat)``, and returns
    ``LeakyReLU(0.1)(feat * gamma + beta)``.
    """

    def __init__(self, style_dim, hw, dim, n_embed, vq_condition=False, decay=0.99, eps=1e-5):
        super().__init__()

        self.base_dim = 128
        self.hw = hw
        self.dim = dim
        self.vq_condition = vq_condition
        base = self.base_dim

        # Submodules are created in the original order so RNG use and
        # parameter registration order are unchanged.
        self.base = ConstantInput(base)
        self.pre_quant = nn.ModuleList([
            StyledConv(base, base, 3, style_dim, upsample=(hw > 8)),
            StyledConv(base, base, 3, style_dim, upsample=True),
            StyledConv(base, base, 1, style_dim, upsample=(hw == 32)),
        ])
        if vq_condition:
            # Predicts 2 * base_dim channels: gamma then beta for pre_vq_condition.
            self.from_pre_vq = ConvLayer(base, base * 2, 1)

        self.quantize = Quantize(base, n_embed, decay, eps)

        self.from_content = ConvLayer(dim, base // 4, 3)

        self.after_quant = nn.Sequential(
            ConvLayer(base + base // 4, dim, 3),
            ConvLayer(dim, dim * 2, 1),
        )

        self.to_vector = nn.Sequential(
            ConvLayer(base, style_dim // 2, 3, downsample=True),
            ConvLayer(style_dim // 2, style_dim, 3, downsample=True),
        )

        self.activate = nn.LeakyReLU(0.1)

    def _apply_affine(self, x, params, split):
        """Return ``activate(x * params[:, :split] + params[:, split:])``."""
        gamma, beta = params[:, :split], params[:, split:]
        return self.activate(x * gamma + beta)

    def _style_vector(self, quantize):
        """Global-average-pool ``to_vector(quantize)`` into a ``(N, style_dim)`` vector."""
        pooled = F.adaptive_avg_pool2d(self.to_vector(quantize), 1)
        return pooled.squeeze(-1).squeeze(-1)

    def get_quant_and_vector(self, style, pre_quant=None):
        """Return ``(quantize, diff, embed_ind, new_style_vector)`` for ``style``."""
        vq_feat = self.base(style)
        for conv in self.pre_quant:
            vq_feat = conv(vq_feat, style)

        if self.vq_condition:
            vq_feat = self.pre_vq_condition(vq_feat, pre_quant)

        # Quantize works on the last dimension: channels-last in, channel-first out.
        quantize, diff, embed_ind = self.quantize(vq_feat.permute(0, 2, 3, 1))
        quantize = quantize.permute(0, 3, 1, 2)

        return quantize, diff, embed_ind, self._style_vector(quantize)

    def forward(self, feat, quantize):
        """Modulate ``feat`` with gamma/beta from ``quantize`` and ``feat`` itself."""
        content_feat = self.from_content(feat)
        params = self.after_quant(torch.cat([quantize, content_feat], dim=1))
        return self._apply_affine(feat, params, self.dim)

    def pre_vq_condition(self, vq_feat, pre_quantize):
        """Affine-modulate ``vq_feat`` with parameters predicted from ``pre_quantize``."""
        affines = F.interpolate(self.from_pre_vq(pre_quantize), vq_feat.shape[2:])
        return self._apply_affine(vq_feat, affines, self.base_dim)

    def forward_with_embed(self, feat, emb_ind, style_conv, noise):
        """Re-run modulation from codebook indices ``emb_ind`` (b x h x w)."""
        codes = self.quantize.embed_code(emb_ind).permute(0, 3, 1, 2)
        codes = F.interpolate(codes, self.hw)

        feat = style_conv(feat, self._style_vector(codes), noise)
        return self.forward(feat, codes)

class VQ6(nn.Module):
    """Vector-quantised spatial modulation whose affine head also sees ``feat``.

    ``get_quant_and_vector`` modulates a learned constant with ``style`` through
    two ``StyledConv`` layers. With ``vq_condition``, a ``StyledConv`` of
    ``pre_quant`` is added between them. The result is quantised and resized
    to ``hw``, and a style vector is derived from it. ``forward`` predicts
    ``gamma``/``beta`` (``dim`` channels each) from the quantised map
    concatenated with ``after_quant_1(feat)``.
    """

    def __init__(self, style_dim, hw, dim, n_embed, vq_condition=False, decay=0.99, eps=1e-5):
        super().__init__()

        self.base_dim = 128
        self.hw = hw
        self.dim = dim
        self.vq_condition = vq_condition
        base = self.base_dim

        # Creation order matches the original (RNG use / parameter order).
        self.base = ConstantInput(base)
        self.pre_quant_1 = StyledConv(base, base, 3, style_dim, upsample=(hw >= 8))
        if vq_condition:
            # Applied to ``pre_quant`` inside get_quant_and_vector.
            self.from_pre_vq = StyledConv(base, base, 3, style_dim, upsample=False)
        self.pre_quant_2 = StyledConv(base, base, 3, style_dim, upsample=(hw > 8))

        self.quantize = Quantize(base, n_embed, decay, eps)

        self.after_quant_1 = ConvLayer(dim, base // 2, kernel_size=1)
        # Input = quantised map (base) + after_quant_1 output (base // 2).
        self.after_quant_2 = nn.Sequential(
            ConvLayer(base // 2 * 3, dim, 3),
            ConvLayer(dim, dim * 2, 1),
        )

        self.to_vector = nn.Sequential(
            ConvLayer(base, style_dim // 2, 3, downsample=True),
            ConvLayer(style_dim // 2, style_dim, 3, downsample=True),
        )

        self.activate = nn.LeakyReLU(0.1)

    def _style_vector(self, quantize):
        """Global-average-pool ``to_vector(quantize)`` into a ``(N, style_dim)`` vector."""
        pooled = F.adaptive_avg_pool2d(self.to_vector(quantize), 1)
        return pooled.squeeze(-1).squeeze(-1)

    def get_quant_and_vector(self, style, pre_quant=None):
        """Return ``(quantize, diff, embed_ind, new_style_vector)``.

        ``quantize`` is resized to ``hw``; ``embed_ind`` keeps the
        pre-resize resolution.
        """
        vq_feat = self.pre_quant_1(self.base(style), style)
        if self.vq_condition:
            cond = self.from_pre_vq(pre_quant, style)
            vq_feat = vq_feat + F.interpolate(cond, vq_feat.shape[2:])
        vq_feat = self.pre_quant_2(vq_feat, style)

        quantize, diff, embed_ind = self.quantize(vq_feat.permute(0, 2, 3, 1))
        quantize = F.interpolate(quantize.permute(0, 3, 1, 2), self.hw)

        return quantize, diff, embed_ind, self._style_vector(quantize)

    def forward(self, feat, quantize):
        """Return ``activate(feat * gamma + beta)`` with gamma/beta from ``quantize`` and ``feat``."""
        feat_condition = self.after_quant_1(feat)
        params = self.after_quant_2(torch.cat([quantize, feat_condition], dim=1))
        gamma, beta = params[:, :self.dim], params[:, self.dim:]
        return self.activate(feat * gamma + beta)

    def forward_with_embed(self, feat, emb_ind, style_conv, noise):
        """Re-run modulation from codebook indices ``emb_ind`` (b x h x w)."""
        codes = self.quantize.embed_code(emb_ind).permute(0, 3, 1, 2)
        codes = F.interpolate(codes, self.hw)

        feat = style_conv(feat, self._style_vector(codes), noise)
        return self.forward(feat, codes)


# Resolution for each index 0..8: 4 once, then 8, 16, 32 and 64 twice each.
# NOTE(review): what the index addresses is not established here (presumably
# per-layer style/VQ slots) — confirm against its users.
idx2res = [4,8,8,16,16,32,32,64,64]
class Generator(nn.Module):
    """StyleGAN2-style synthesis network whose middle layers are driven by vector quantizers.

    Latent indices ``dislow .. dishigh - 1`` are not used directly as conv styles.
    For each such index, ``self.vqs[k].get_quant_and_vector`` turns the latent into
    a quantized spatial map plus a replacement style vector. The conv at that index
    is modulated by the replacement vector, and its output is then passed through
    ``self.vqs[k](out, quant)``. All other indices behave as in StyleGAN2.
    """

    def __init__(self, size, style_dim, n_mlp, channel_multiplier=2, dislow=2, dishigh=5, n_embed=6, blur_kernel=[1, 3, 3, 1], lr_mlp=0.01, vq=0):
        """Build the mapping network, synthesis layers, noise buffers and VQ modules.

        Args:
            size: output resolution; expected to be a power of two
                (``log_size = int(log2(size))``).
            style_dim: width of the style / latent vectors.
            n_mlp: number of ``EqualLinear`` layers after the ``PixelNorm``
                in the mapping network ``self.style``.
            channel_multiplier: forwarded to ``get_channels``.
            dislow, dishigh: half-open range ``[dislow, dishigh)`` of latent indices
                routed through a VQ module (resolution looked up in ``idx2res``).
            n_embed: codebook size, either one int for every VQ module or a
                sequence with one entry per VQ module (length ``dishigh - dislow``).
            blur_kernel: forwarded to every ``StyledConv``; never mutated here.
            lr_mlp: learning-rate multiplier of the mapping network.
            vq: VQ variant: 3 selects ``VQ3``, 6 selects ``VQ6``, any other value ``VQ0``.
        """
        # NOTE: keep the construction order below stable. It fixes the state_dict
        # layout and the order in which random values (e.g. the ``torch.randn`` noise
        # buffers, and presumably the layers' own initializers) are drawn.
        super().__init__()

        self.size = size
        self.style_dim = style_dim
        self.channels = get_channels(channel_multiplier)

        layers = [PixelNorm()]
        for _ in range(n_mlp):
            layers.append(EqualLinear(style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"))
        self.style = nn.Sequential(*layers)

        self.input = ConstantInput(self.channels[4])
        self.conv1 = StyledConv(
            self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
        )
        self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)

        self.convs = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        self.log_size = int(math.log(size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1
        in_channel = self.channels[4]
        # One fixed noise buffer per conv layer: 4x4 for conv1, then two per resolution 8..size.
        for layer_idx in range(self.num_layers):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2 ** res, 2 ** res]
            self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))

        # Per resolution 8..size: an upsampling StyledConv, a refining StyledConv and a ToRGB.
        for i in range(3, self.log_size + 1):
            out_channel = self.channels[2 ** i]

            self.convs.append(
                StyledConv(
                    in_channel,
                    out_channel,
                    3,
                    style_dim,
                    upsample=True,
                    blur_kernel=blur_kernel,
                )
            )
            self.convs.append(
                StyledConv(
                    out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
                )
            )
            self.to_rgbs.append(ToRGB(out_channel, style_dim))

            in_channel = out_channel

        self.n_latent = self.log_size * 2 - 2

        self.dislow = dislow
        self.dishigh = dishigh

        self.vqs = nn.ModuleList()
        n_embeds = n_embed
        if isinstance(n_embed, int):
            n_embeds = [n_embed] * (dishigh - dislow)

        vq_module = VQ0
        if vq == 3:
            vq_module = VQ3
        elif vq == 6:
            vq_module = VQ6
        for i, vqidx in enumerate(range(self.dislow, self.dishigh)):
            hw = idx2res[vqidx]
            # Every VQ module except the first is conditioned on its predecessor's map.
            self.vqs.append(vq_module(style_dim, hw, self.channels[hw], n_embeds[i], vq_condition=(i > 0)))

    def make_noise(self):
        """Return freshly sampled noise maps, one per conv layer, on the model's device."""
        device = self.input.input.device

        noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))

        return noises

    def _fixed_noise(self):
        """Return the registered buffers ``noise_0 .. noise_{num_layers - 1}`` in layer order."""
        return [getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)]

    def prepare_latents(self, styles, input_is_latent, noise,
            randomize_noise, truncation, truncation_latent, inject_index):
        """Turn ``styles`` into a per-layer latent and resolve the noise list.

        Args:
            styles: list of one or two style tensors. Unless ``input_is_latent``,
                each is first mapped through ``self.style``.
            input_is_latent: skip the mapping network.
            noise: explicit per-layer noise list, or ``None`` for a default.
            randomize_noise: when ``noise`` is ``None``, sample fresh noise
                (``make_noise``) instead of using the registered buffers.
            truncation: when ``< 1``, pull every style towards ``truncation_latent``.
            truncation_latent: centre of the truncation trick.
            inject_index: style-mixing crossover; ignored for a single style. When
                ``None`` with two styles, it is drawn with ``random`` from either
                ``[dishigh, n_latent - 1]`` or ``[1, dislow]``, so the crossover
                never splits the VQ range ``[dislow, dishigh)``.

        Returns:
            ``(latent, noise)`` where ``latent`` has ``n_latent`` layers along dim 1
            (built as ``styles[0].unsqueeze(1).repeat(...)`` for 2-D styles).
        """
        if not input_is_latent:
            styles = [self.style(s) for s in styles]

        if noise is None:
            noise = self.make_noise() if randomize_noise else self._fixed_noise()

        if truncation < 1:
            styles = [
                truncation_latent + truncation * (style - truncation_latent)
                for style in styles
            ]

        if len(styles) < 2:
            inject_index = self.n_latent
            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:
                # Already a per-layer latent; use as-is.
                latent = styles[0]
        else:
            if inject_index is None:
                if random.randint(0, 1) == 0:
                    inject_index = random.randint(self.dishigh, self.n_latent - 1)
                else:
                    inject_index = random.randint(1, self.dislow)

            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            if self.n_latent - inject_index > 0:
                latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
                latent = torch.cat([latent, latent2], 1)
        return latent, noise

    def _get_vqs(self, latent):
        """Run every VQ module on its latent index, chaining each map into the next.

        Returns:
            ``(quants, diffs, emb_inds, new_style_latents)``; ``diffs`` is the sum
            of the modules' ``diff`` outputs (``0`` when there are no VQ layers).
        """
        quants, emb_inds, new_style_latents = [], [], []
        diffs = 0
        pre_quant = None
        for k, vqidx in enumerate(range(self.dislow, self.dishigh)):
            quant, diff, emb_ind, style = self.vqs[k].get_quant_and_vector(
                latent[:, vqidx], pre_quant=pre_quant
            )
            quants.append(quant)
            diffs += diff
            emb_inds.append(emb_ind)
            new_style_latents.append(style)
            pre_quant = quant
        return quants, diffs, emb_inds, new_style_latents

    def _forward_main(self, latent, noise):
        """Synthesize an image from a per-layer ``latent`` and a per-layer ``noise`` list.

        Returns:
            ``(image, diffs, emb_inds)`` where the last two come from ``_get_vqs``.
        """
        out = self.input(latent)
        out = self.conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        quants, diffs, emb_inds, new_style_latents = self._get_vqs(latent)

        i = 1  # latent index of the first conv in the current resolution block
        j = 0  # next VQ module / quantized map to use
        for conv1, conv2, noise1, noise2, to_rgb in zip(
            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
        ):
            for conv, idx, layer_noise in ((conv1, i, noise1), (conv2, i + 1, noise2)):
                if self.dislow <= idx < self.dishigh:
                    # VQ layer: style comes from the quantized map, then VQ modulation.
                    out = conv(out, new_style_latents[j], noise=layer_noise)
                    out = self.vqs[j](out, quants[j])
                    j += 1
                else:
                    out = conv(out, latent[:, idx], noise=layer_noise)

            skip = to_rgb(out, latent[:, i + 2], skip)
            i += 2

        return skip, diffs, emb_inds

    def forward(self,
        styles,
        return_latents=False,
        return_noise=False,
        inject_index=None,
        truncation=1,
        truncation_latent=None,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
    ):
        """Generate images from ``styles`` (see ``prepare_latents`` for the arguments).

        Returns:
            ``(image, latent_or_None, diff, embed_idxs)``, with ``noise`` inserted
            after the latent slot when ``return_noise`` is true. The latent slot
            holds ``None`` unless ``return_latents`` is true.
        """
        latent, noise = self.prepare_latents(styles, inject_index=inject_index,
                                             input_is_latent=input_is_latent, noise=noise,
                                             randomize_noise=randomize_noise,
                                             truncation=truncation,
                                             truncation_latent=truncation_latent)

        image, diff, embed_idxs = self._forward_main(latent, noise)

        latent_out = latent if return_latents else None
        if return_noise:
            return image, latent_out, noise, diff, embed_idxs
        return image, latent_out, diff, embed_idxs

    def decode(self, latent, noise=None, random_noise=False, return_noise=False):
        """Synthesize images from an already-expanded per-layer ``latent``.

        No mapping network, truncation or style mixing is applied. When ``noise``
        is ``None`` it defaults to ``make_noise()`` if ``random_noise`` is true,
        otherwise to the fixed registered buffers.

        Returns:
            ``(image, diff, embed_idxs)``, with ``noise`` inserted after ``image``
            when ``return_noise`` is true.
        """
        if noise is None:
            noise = self.make_noise() if random_noise else self._fixed_noise()

        image, diff, embed_idxs = self._forward_main(latent, noise)

        if return_noise:
            return image, noise, diff, embed_idxs
        return image, diff, embed_idxs

    def decode_with_embed(self, latent, embed_inds):
        """Synthesize images with the VQ layers driven by given code indices.

        Args:
            latent: per-layer latent used by all non-VQ layers.
            embed_inds: list with one code-index map (b x h x w) per VQ module,
                in the order of ``self.vqs``.

        Returns:
            The generated image. The fixed noise buffers are always used.
        """
        noise = self._fixed_noise()

        out = self.input(latent)
        out = self.conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        i = 1  # latent index of the first conv in the current resolution block
        j = 0  # next VQ module / code map to use
        for conv1, conv2, noise1, noise2, to_rgb in zip(
            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
        ):
            for conv, idx, layer_noise in ((conv1, i, noise1), (conv2, i + 1, noise2)):
                if self.dislow <= idx < self.dishigh:
                    out = self.vqs[j].forward_with_embed(out, embed_inds[j], conv, layer_noise)
                    j += 1
                else:
                    out = conv(out, latent[:, idx], noise=layer_noise)

            skip = to_rgb(out, latent[:, i + 2], skip)
            i += 2

        return skip

class IPI2I(nn.Module):
    """Inference wrapper that encodes images to per-layer latents and re-synthesizes them.

    Images go through the discriminator's conv trunk (``self.convs``) to collect
    multi-resolution features. ``self.style_encoder`` maps those features to a
    latent, and ``self.generator.decode`` renders it. ``forward`` copies the latent
    slice ``[dislow, dishigh)`` of the ``pose`` image into the ``identity`` latent.
    """

    def __init__(self, size=256, latent=512, n_mlp=8, channel_multiplier=2, dislow=3, dishigh=6, n_embed=(2, 6, 6), vq=0):
        """Build generator, feature extractor and style encoder.

        Args:
            size: image resolution.
            latent: style dimension.
            n_mlp: mapping-network depth of the generator.
            channel_multiplier: width multiplier shared by all sub-networks.
            dislow, dishigh: latent index range handled by the generator's VQ layers.
            n_embed: codebook size per VQ layer (int or sequence). The default is a
                tuple so the shared default value cannot be mutated.
            vq: VQ variant forwarded to ``Generator``.
        """
        super().__init__()

        self.size = size
        self.channel_multiplier = channel_multiplier
        self.dislow = dislow
        self.dishigh = dishigh

        # Feature resolutions handed to the style encoder: size, size/2, ..., 4.
        self.extract_res = [2**p for p in range(int(math.log2(self.size)), 1, -1)]

        self.generator = Generator(size, latent, n_mlp, channel_multiplier, dislow, dishigh, n_embed, vq=vq)
        # VQ modules start in eval mode; note that a later ``self.train()`` would undo this.
        self.generator.vqs.eval()
        self._generate = self.generator.decode

        # Only the discriminator's conv trunk is kept, as a feature extractor.
        self.convs = Discriminator(size, channel_multiplier).convs

        self.style_encoder = StyleEncoder(size, latent, channel_multiplier)

    def load_from_training_models(self, ckpt):
        """Load generator, style encoder and feature-extractor weights from a training checkpoint.

        Args:
            ckpt: the checkpoint mapping (keys ``"g_ema"``, ``"se_ema"``, ``"d"``),
                or anything ``torch.load`` accepts (path, path-like, file object).
        """
        if not isinstance(ckpt, abc.Mapping):
            # NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints.
            # map_location="cpu" lets GPU-trained checkpoints load on CPU-only hosts;
            # load_state_dict below copies the values onto this model's device.
            ckpt = torch.load(ckpt, map_location="cpu")

        self.generator.load_state_dict(ckpt["g_ema"])
        self.style_encoder.load_state_dict(ckpt["se_ema"])

        # Validate the full discriminator state dict, then copy only its conv trunk
        # into the existing module. This keeps the trunk's device, dtype and
        # train/eval mode instead of swapping in a fresh CPU module.
        tmp_discriminator = Discriminator(self.size, self.channel_multiplier)
        tmp_discriminator.load_state_dict(ckpt["d"])
        self.convs.load_state_dict(tmp_discriminator.convs.state_dict())
        del tmp_discriminator

    @torch.no_grad()
    def _feat_extract(self, image, feature_res=None):
        """Run ``image`` through ``self.convs`` and collect feature maps.

        Args:
            image: input batch.
            feature_res: widths (last dimension) to collect; defaults to
                ``self.extract_res``. Extraction stops once a map whose width
                equals ``feature_res[-1]`` has been collected.

        Returns:
            List of the collected feature maps, in network order.
        """
        if feature_res is None:
            feature_res = self.extract_res
        out = []
        feat = image
        for conv in self.convs:
            feat = conv(feat)
            if feat.shape[-1] in feature_res:
                out.append(feat)
                if feat.shape[-1] == feature_res[-1]:
                    break
        return out

    @torch.no_grad()
    def forward(self, identity, pose, mix_low=None, mix_high=None):
        """Render ``identity`` with the latents ``[mix_low, mix_high)`` taken from ``pose``.

        ``mix_low`` / ``mix_high`` default to ``self.dislow`` / ``self.dishigh``.
        Returns the generated image batch.
        """
        feat_idt = self._feat_extract(identity)
        feat_pos = self._feat_extract(pose)

        latents_idt = self.style_encoder(feat_idt)
        latents_pos = self.style_encoder(feat_pos)

        rec_img, _ = self.forward_with_mix_latents(latents_idt, latents_pos, mix_low, mix_high)
        return rec_img

    @torch.no_grad()
    def forward_with_segmap(self, latents, segmaps):
        """Render ``latents`` with the VQ layers driven by the given code-index maps."""
        return self.generator.decode_with_embed(latents, segmaps)

    @torch.no_grad()
    def get_latents_and_rec_image(self, image):
        """Encode ``image`` and re-render it; returns ``(rec_image, latents, embed_idxs)``."""
        latents = self.get_latents(image)
        rec_image, _, embed_idxs = self._generate(latents)
        return rec_image, latents, embed_idxs

    @torch.no_grad()
    def get_latents(self, image):
        """Encode ``image`` into the per-layer latent produced by ``self.style_encoder``."""
        feats = self._feat_extract(image)
        return self.style_encoder(feats)

    @torch.no_grad()
    def forward_with_latents(self, latents):
        """Render ``latents``; returns ``(rec_image, embed_idxs)``."""
        rec_image, _, embed_idxs = self._generate(latents)
        return rec_image, embed_idxs

    @torch.no_grad()
    def forward_with_mix_latents(self, latents_idt, latents_pose, mix_low=None, mix_high=None):
        """Render ``latents_idt`` with the slice ``[mix_low, mix_high)`` taken from ``latents_pose``.

        Neither input tensor is modified. ``mix_low`` / ``mix_high`` default to
        ``self.dislow`` / ``self.dishigh``.

        Returns:
            ``(rec_image, embed_idxs)``.
        """
        if mix_low is None:
            mix_low = self.dislow
        if mix_high is None:
            mix_high = self.dishigh

        latents = latents_idt.clone()
        latents[:, mix_low:mix_high] = latents_pose[:, mix_low:mix_high]
        rec_image, _, embed_idxs = self._generate(latents)
        return rec_image, embed_idxs


def train_model_to_demo_model(ckpt, dist_path):
    """Convert a training checkpoint into a stand-alone inference checkpoint.

    Reads the ``"g_ema"``, ``"se_ema"``, ``"d"`` and ``"args"`` entries of the
    training checkpoint ``ckpt``. It then builds an ``IPI2I`` from the stored
    arguments and saves ``{"net": state_dict, "args": args}`` to ``dist_path``.

    Raises:
        KeyError: if the checkpoint has no ``"args"`` entry.
    """
    # NOTE: torch.load unpickles arbitrary Python objects (the checkpoint holds the
    # training argparse Namespace), so only convert checkpoints from a trusted source.
    # map_location="cpu" lets GPU-trained checkpoints be converted on CPU-only hosts.
    state = torch.load(ckpt, map_location="cpu")
    args = state.get("args")
    if args is None:
        raise KeyError(f"checkpoint {ckpt!r} has no 'args' entry; cannot rebuild the model")

    net = IPI2I(
        # Architecture fields fall back to IPI2I's defaults when the training
        # arguments do not record them.
        size=getattr(args, "size", 256),
        latent=getattr(args, "latent", 512),
        n_mlp=getattr(args, "n_mlp", 8),
        channel_multiplier=getattr(args, "channel_multiplier", 2),
        dislow=args.dislow,
        dishigh=args.dishigh,
        n_embed=args.vq_emb,
        vq=args.vq_type,
    )
    net.load_from_training_models(state)
    net.cpu()
    torch.save({"net": net.state_dict(), "args": args}, dist_path)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="training model to demo model convertor")

    # Both paths are mandatory: without them the run would only fail later, inside
    # torch.load / torch.save, with a far less helpful error.
    parser.add_argument("--ckpt", type=str, required=True, help="path to the checkpoints of the trained model")
    parser.add_argument("--dist", type=str, required=True, help="path to save the converted model")

    args = parser.parse_args()
    train_model_to_demo_model(args.ckpt, args.dist)

# Manual check that a converted model generates sensible images. It is kept as an
# inert string because torchvision is not otherwise needed by this module: paste it
# into a session, set `ckpt_path`, `path` and `device`, and run it.
'''
## test if the converted model generates correct image

ckpt_path = 'converted.pt'   # file written by train_model_to_demo_model
device = 'cpu'

ckpt = torch.load(ckpt_path, map_location='cpu')
args = ckpt.get('args')
net = IPI2I(size=getattr(args, 'size', 256), latent=getattr(args, 'latent', 512),
            n_mlp=getattr(args, 'n_mlp', 8),
            channel_multiplier=getattr(args, 'channel_multiplier', 2),
            dislow=args.dislow, dishigh=args.dishigh, n_embed=args.vq_emb, vq=args.vq_type)
net.load_state_dict(ckpt['net'])
net.to(device).eval()

size = net.size
path = '../../../../Images/afhq/afhq/train/'
batch = 4   # must be even: first half = identities, second half = poses

from torch.utils import data
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision.utils import save_image
transform = transforms.Compose( [
            transforms.ToTensor(),
            transforms.Resize((size, size)),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ] )
dataset = ImageFolder(path, transform=transform)
loader = iter(data.DataLoader(dataset, batch_size=batch, drop_last=True, shuffle=True))
images_idt = next(loader)[0].to(device)

half = batch // 2
mixed = net(images_idt[:half], images_idt[half:])
# map [-1, 1] back to [0, 1]; grid rows: identities, poses, mixes
save_image( torch.cat([images_idt, mixed]).add(1).mul(0.5), 'test.jpg', nrow=half )
'''
