import torch
import torch.nn as nn
import torch.nn.functional as F

import math
import random

from models.model_stylegan2 import ConvLayer, EqualLinear, StyledConv, ToRGB, ConstantInput, ResBlock
from models.model_stylegan2 import Discriminator, Generator
from models.vq import Quantize

# Spatial resolution (``hw``) of the VQ module attached to each style-latent
# index: GeneratorPI builds one VQ module per ``vqidx`` in
# ``range(dislow, dishigh)`` at resolution ``idx2res[vqidx]``.
idx2res = [4,8,8,16,16,32,32,64,64]

def get_channels(channel_multiplier):
    """Return the channel-width table keyed by spatial resolution.

    Resolutions 4 through 32 always use 512 channels; from 64 upward the
    width is a fixed base value multiplied by ``channel_multiplier``.
    """
    widths = dict.fromkeys((4, 8, 16, 32), 512)
    scaled_bases = ((64, 256), (128, 128), (256, 64), (512, 32), (1024, 16))
    for resolution, base in scaled_bases:
        widths[resolution] = base * channel_multiplier
    return widths

class Reshape(nn.Module):
    """Module wrapper around ``Tensor.view`` so a reshape can sit inside ``nn.Sequential``."""

    def __init__(self, shape):
        super().__init__()
        # Target shape passed to ``Tensor.view``; may contain a -1 wildcard.
        self.shape = shape

    def forward(self, tensor):
        target = tuple(self.shape)
        return tensor.view(target)

class StyledConvMod(StyledConv):
    """StyledConv variant whose style is given as a spatial feature map.

    The style map is global-average-pooled and its two trailing singleton
    dimensions squeezed away; the resulting vector modulates ``self.conv``.
    This forward only applies ``self.conv`` and ``self.activate``; it does not
    call any noise layer.
    """

    def forward(self, input, style):
        # assumes style is (batch, C, H, W) -- TODO confirm against callers
        pooled = F.adaptive_avg_pool2d(style, 1)
        style_vector = pooled.squeeze(-1).squeeze(-1)
        return self.activate(self.conv(input, style_vector))

class ToRGBMod(ToRGB):
    """ToRGB variant whose style is given as a spatial feature map.

    The style map is global-average-pooled to a vector before modulating
    ``self.conv``; ``self.bias`` is added, and an optional ``skip`` image is
    passed through ``self.upsample`` and summed in.
    """

    def forward(self, input, style, skip=None):
        style_vector = F.adaptive_avg_pool2d(style, 1).squeeze(-1).squeeze(-1)
        rgb = self.conv(input, style_vector) + self.bias
        if skip is None:
            return rgb
        return rgb + self.upsample(skip)
 
class VQBase(nn.Module):
    """Shared machinery for the VQ modules used by ``GeneratorPI``.

    Subclasses build a feature map from a learned constant (``self.base``),
    quantize it with ``self.quantize`` and use the quantized map to modulate a
    generator feature as ``LeakyReLU(feat * gamma + beta)``. Subclasses must
    define ``forward(feat, quantize)``; ``forward_with_embed`` relies on it.
    """

    def __init__(self, style_dim, hw, dim, n_embed, vq_condition=False, decay=0.99, eps=1e-5):
        super().__init__()

        # Channel width of the internal VQ feature map and of the codebook vectors.
        self.base_dim = 128
        self.hw = hw            # spatial size the quantized map is resized to
        self.dim = dim          # channel width of the generator feature being modulated
        self.vq_condition = vq_condition
        self.base = ConstantInput(self.base_dim)

        self.quantize = Quantize(self.base_dim, n_embed, decay, eps)

        # Two downsampling convs; their output is global-pooled into a style_dim vector.
        self.to_vector = nn.Sequential(
            ConvLayer(self.base_dim, style_dim//2, 3, downsample=True),
            ConvLayer(style_dim//2, style_dim, 3, downsample=True))

        self.activate = nn.LeakyReLU(0.1)

    def _style_vector_from(self, quantize):
        """Pool ``self.to_vector(quantize)`` over space and drop the two singleton dims."""
        pooled = F.adaptive_avg_pool2d(self.to_vector(quantize), 1)
        return pooled.squeeze(-1).squeeze(-1)

    def vqfeat_to_quant(self, vq_feat):
        """Quantize a channels-first VQ feature map.

        The map is moved to channels-last for ``self.quantize`` and back to
        channels-first afterwards, then resized to ``self.hw``.

        Returns:
            ``(quantize, diff, embed_ind, style_vector)``.
        """
        quantize, diff, embed_ind = self.quantize(vq_feat.permute(0, 2, 3, 1))
        quantize = F.interpolate(quantize.permute(0, 3, 1, 2), self.hw)
        return quantize, diff, embed_ind, self._style_vector_from(quantize)

    def quant_mod_feat(self, feat, quantize):
        """Return ``LeakyReLU(feat * gamma + beta)``.

        ``gamma`` is the first ``self.dim`` channels of ``quantize`` and
        ``beta`` the remaining channels.
        """
        gamma, beta = quantize[:, :self.dim], quantize[:, self.dim:]
        return self.activate(feat * gamma + beta)

    def forward_with_embed(self, feat, emb_ind, style_conv, noise):
        """Run ``style_conv`` plus the VQ modulation from precomputed codebook indices.

        ``emb_ind`` is b x h x w; ``self.quantize.embed_code`` turns it into a
        b x h x w x feat_dim code map, which is used in place of a freshly
        quantized map.
        """
        codes = self.quantize.embed_code(emb_ind)
        quantize = F.interpolate(codes.permute(0, 3, 1, 2), self.hw)
        feat = style_conv(feat, self._style_vector_from(quantize), noise)
        return self.forward(feat, quantize)

class VQ0(VQBase):
    """Unconditioned VQ module: styled convs -> quantize -> conv affine head.

    ``vq_condition`` and ``pre_quant`` are accepted for interface parity with
    the other VQ variants; nothing in this class reads them.
    """

    def __init__(self, style_dim, hw, dim, n_embed, vq_condition=False, decay=0.99, eps=1e-5):
        super().__init__(style_dim, hw, dim, n_embed, vq_condition, decay, eps)

        # (kernel size, upsample flag) for the three styled convs applied to the constant input.
        layer_specs = ((3, hw > 8), (3, True), (1, hw == 32))
        self.pre_quant = nn.ModuleList([
            StyledConv(self.base_dim, self.base_dim, kernel, style_dim, upsample=up)
            for kernel, up in layer_specs])

        # Produces 2 * dim channels: gamma and beta for quant_mod_feat.
        self.after_quant = nn.Sequential(
            ConvLayer(self.base_dim, dim, 3),
            ConvLayer(dim, dim*2, 1))

    def get_quant_and_vector(self, style, pre_quant=None):
        """Build the VQ feature map from ``style``, then quantize it (see ``vqfeat_to_quant``)."""
        feat = self.base(style)
        for layer in self.pre_quant:
            feat = layer(feat, style)
        return self.vqfeat_to_quant(feat)

    def forward(self, feat, quantize):
        """Modulate ``feat`` with affine parameters predicted from ``quantize``."""
        return self.quant_mod_feat(feat, self.after_quant(quantize))
    
class VQ3(VQBase):
    """VQ module with optional conditioning on the previous VQ map and on the content feature.

    When ``vq_condition`` is set, the pre-quantization map is affinely
    modulated by ``from_pre_vq(pre_quant)``. The affine head after
    quantization also sees a projection of the generator feature being
    modulated.
    """

    def __init__(self, style_dim, hw, dim, n_embed, vq_condition=False, decay=0.99, eps=1e-5):
        super().__init__(style_dim, hw, dim, n_embed, vq_condition, decay, eps)

        # (kernel size, upsample flag) for the three styled convs applied to the constant input.
        layer_specs = ((3, hw > 8), (3, True), (1, hw == 32))
        self.pre_quant = nn.ModuleList([
            StyledConv(self.base_dim, self.base_dim, kernel, style_dim, upsample=up)
            for kernel, up in layer_specs])
        if self.vq_condition:
            # Predicts gamma/beta (base_dim each) from the previous module's quantized map.
            self.from_pre_vq = ConvLayer(self.base_dim, self.base_dim*2, 1)

        content_dim = self.base_dim // 4
        self.from_content = ConvLayer(dim, content_dim, 3)

        # Input: quantized map (base_dim) concatenated with projected content (base_dim // 4).
        self.after_quant = nn.Sequential(
            ConvLayer(self.base_dim + content_dim, dim, 3),
            ConvLayer(dim, dim*2, 1))

    def get_quant_and_vector(self, style, pre_quant=None):
        """Build the VQ feature map (conditioned on ``pre_quant`` if enabled) and quantize it."""
        feat = self.base(style)
        for layer in self.pre_quant:
            feat = layer(feat, style)

        if self.vq_condition:
            feat = self.pre_vq_condition(feat, pre_quant)

        return self.vqfeat_to_quant(feat)

    def forward(self, feat, quantize):
        """Modulate ``feat`` with affines predicted from ``quantize`` and a projection of ``feat``."""
        merged = torch.cat([quantize, self.from_content(feat)], dim=1)
        return self.quant_mod_feat(feat, self.after_quant(merged))

    def pre_vq_condition(self, vq_feat, pre_quantize):
        """Return ``LeakyReLU(vq_feat * gamma + beta)`` with gamma/beta from ``pre_quantize``."""
        affines = F.interpolate(self.from_pre_vq(pre_quantize), vq_feat.shape[2:])
        gamma, beta = affines[:, :self.base_dim], affines[:, self.base_dim:]
        return self.activate(vq_feat * gamma + beta)

class VQ6(VQBase):
    """VQ module with additive conditioning on the previous VQ map.

    When ``vq_condition`` is set, a styled conv of the previous module's
    quantized map is resized and added between the two pre-quantization
    convs. The affine head after quantization also sees a 1x1 projection of
    the generator feature being modulated.
    """

    def __init__(self, style_dim, hw, dim, n_embed, vq_condition=False, decay=0.99, eps=1e-5):
        super().__init__(style_dim, hw, dim, n_embed, vq_condition, decay, eps)

        self.pre_quant_1 = StyledConv(self.base_dim, self.base_dim, 3, style_dim, upsample=(hw >= 8))
        if self.vq_condition:
            # Processes the previous module's quantized map before it is added to pre_quant_1's output.
            self.from_pre_vq = StyledConv(self.base_dim, self.base_dim, 3, style_dim, upsample=False)
        self.pre_quant_2 = StyledConv(self.base_dim, self.base_dim, 3, style_dim, upsample=(hw > 8))

        half = self.base_dim // 2
        # Projects the generator feature (dim channels) to base_dim // 2 channels.
        self.after_quant_1 = ConvLayer(dim, half, kernel_size=1)
        # Input: quantized map (base_dim) concatenated with the projected feature (base_dim // 2).
        self.after_quant_2 = nn.Sequential(
            ConvLayer(half * 3, dim, 3),
            ConvLayer(dim, dim*2, 1))

    def get_quant_and_vector(self, style, pre_quant=None):
        """Build the VQ feature map (plus ``pre_quant`` conditioning if enabled) and quantize it."""
        feat = self.pre_quant_1(self.base(style), style)
        if self.vq_condition:
            cond = self.from_pre_vq(pre_quant, style)
            feat = feat + F.interpolate(cond, feat.shape[2:])
        feat = self.pre_quant_2(feat, style)

        return self.vqfeat_to_quant(feat)

    def forward(self, feat, quantize):
        """Modulate ``feat`` with affines predicted from ``quantize`` and a projection of ``feat``."""
        projected = self.after_quant_1(feat)
        affines = self.after_quant_2(torch.cat([quantize, projected], dim=1))
        return self.quant_mod_feat(feat, affines)

class GeneratorPI(Generator):
    """Generator whose layers with style-latent index in [dislow, dishigh) are VQ-driven.

    For each such index ``vqidx`` a VQ module is built at resolution
    ``idx2res[vqidx]``: ``VQ3`` when ``vq == 3``, ``VQ6`` when ``vq == 6``,
    otherwise ``VQ0``. During synthesis the styled conv at that index gets
    the VQ module's pooled style vector instead of ``latent[:, idx]``. Its
    output is then modulated by the quantized feature map.
    """

    def __init__(self, size, style_dim, n_mlp, channel_multiplier=2, dislow=2, dishigh=5, n_embed=6, blur_kernel=[1, 3, 3, 1], lr_mlp=0.01, vq=0):
        super().__init__(size, style_dim, n_mlp, channel_multiplier, blur_kernel, lr_mlp)

        self.dislow = dislow
        self.dishigh = dishigh

        self.vqs = nn.ModuleList()
        # n_embed is either one codebook size shared by all VQ modules or a
        # per-module sequence of sizes.
        n_embeds = n_embed
        if isinstance(n_embed, int):
            n_embeds = [n_embed] * (dishigh - dislow)

        vq_module = VQ0
        if vq == 3:
            vq_module = VQ3
        elif vq == 6:
            vq_module = VQ6

        for i, vqidx in enumerate(range(self.dislow, self.dishigh)):
            hw = idx2res[vqidx]
            # Every module after the first is conditioned on its predecessor's
            # quantized map (see _get_vqs).
            self.vqs.append(vq_module(style_dim, hw, self.channels[hw], n_embeds[i], vq_condition=(i > 0)))

    def _is_vq_layer(self, idx):
        """Return True if style-latent index ``idx`` is driven by a VQ module."""
        return self.dislow <= idx < self.dishigh

    def set_vq_decay(self, decay=0.9997):
        """Set the codebook decay used by every VQ module.

        Previously only ``self.vqs[i].decay`` was assigned. No VQ wrapper
        reads that attribute, so the value never reached the wrapped
        ``Quantize`` module, which only got ``decay`` at construction. The
        wrapper attribute is still set for backward compatibility. The value
        is now also forwarded to ``vq.quantize.decay`` when that attribute
        exists (presumably read by Quantize's codebook update; confirm in
        models.vq).
        """
        for vq in self.vqs:
            vq.decay = decay
            quantizer = getattr(vq, "quantize", None)
            current = getattr(quantizer, "decay", None)
            if isinstance(current, torch.Tensor):
                # A registered buffer cannot be re-assigned a float; update it in place.
                current.fill_(decay)
            elif current is not None:
                quantizer.decay = decay

    def _prepare_latents(self, styles, input_is_latent, noise, randomize_noise, truncation, truncation_latent, inject_index ):
        """Turn ``styles`` into per-layer latents and resolve the noise inputs.

        Applies the mapping network (unless ``input_is_latent``), truncation
        and two-style mixing. With one style it is repeated over ``n_latent``
        layers unless it is already at least 3-D.

        Returns:
            ``(latent, noise)``.
        """
        if not input_is_latent:
            styles = [self.style(s) for s in styles]

        if noise is None:
            if randomize_noise:
                noise = self.make_noise()
            else:
                noise = [getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)]

        if truncation < 1:
            styles = [truncation_latent + truncation * (style - truncation_latent) for style in styles]

        if len(styles) < 2:
            inject_index = self.n_latent
            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:
                latent = styles[0]
        else:
            if inject_index is None:
                # The random crossover is either >= dishigh or <= dislow, so all
                # VQ-driven indices [dislow, dishigh) always take the same style.
                if random.randint(0, 1) == 0:
                    inject_index = random.randint(self.dishigh, self.n_latent - 1)
                else:
                    inject_index = random.randint(1, self.dislow)

            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            if self.n_latent - inject_index > 0:
                latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
                latent = torch.cat([latent, latent2], 1)
        return latent, noise

    def _get_vqs(self, latent):
        """Run every VQ module on its latent slice.

        Each module gets ``latent[:, vqidx]`` and, as ``pre_quant``, the
        previous module's quantized map (None for the first module).

        Returns:
            ``(quants, diffs, emb_inds, new_style_latents)``. ``diffs`` is the
            sum of the modules' ``diff`` outputs; the other three are lists
            with one entry per VQ module.
        """
        quants, diffs, emb_inds, new_style_latents = [None], 0, [], []
        for vq, vqidx in zip(self.vqs, range(self.dislow, self.dishigh)):
            cur_quant, cur_diff, cur_embind, cur_style = vq.get_quant_and_vector(latent[:, vqidx], pre_quant=quants[-1])
            quants.append(cur_quant)
            diffs += cur_diff
            emb_inds.append(cur_embind)
            new_style_latents.append(cur_style)
        return quants[1:], diffs, emb_inds, new_style_latents

    def _forward_main(self, latent, noise):
        """Synthesize an image from per-layer latents, routing VQ layers through ``self.vqs``.

        Returns:
            ``(image, out, diffs, emb_inds)``. ``out`` is the last feature map;
            ``diffs`` and ``emb_inds`` come from ``_get_vqs``.
        """
        out = self.input(latent)
        out = self.conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        quants, diffs, emb_inds, new_style_latents = self._get_vqs(latent)

        # i is the style-latent index of conv1 in each pair. Indices in
        # [dislow, dishigh) are VQ-driven; j walks the VQ modules in order.
        i = 1
        j = 0
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs):
            # 1st conv
            if self._is_vq_layer(i):
                out = conv1(out, new_style_latents[j], noise=noise1)
                out = self.vqs[j](out, quants[j])
                j += 1
            else:
                out = conv1(out, latent[:, i], noise=noise1)

            # 2nd conv
            if self._is_vq_layer(i + 1):
                out = conv2(out, new_style_latents[j], noise=noise2)
                out = self.vqs[j](out, quants[j])
                j += 1
            else:
                out = conv2(out, latent[:, i + 1], noise=noise2)

            skip = to_rgb(out, latent[:, i + 2], skip)

            i += 2

        image = skip
        return image, out, diffs, emb_inds

    def forward(self,
        styles,
        return_latents=False,
        inject_index=None,
        truncation=1,
        truncation_latent=None,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
        return_feature=False
    ):
        """Generate images from a list of styles.

        Returns:
            ``(image, out, latent, diff, embed_idxs)`` if ``return_feature``;
            otherwise ``(image, latent, diff, embed_idxs)``, where ``latent``
            is None unless ``return_latents``.
        """
        latent, noise = self._prepare_latents(styles, inject_index=inject_index,
                                            input_is_latent=input_is_latent, noise=noise,
                                            randomize_noise=randomize_noise,
                                            truncation=truncation,
                                            truncation_latent=truncation_latent)

        image, out, diff, embed_idxs = self._forward_main(latent, noise)

        if return_feature:
            return image, out, latent, diff, embed_idxs
        if return_latents:
            return image, latent, diff, embed_idxs
        else:
            return image, None, diff, embed_idxs

    def decode(self, latent, noise=None, 
        random_noise=False, return_feature=False):
        """Synthesize directly from per-layer latents (no mapping, truncation or mixing).

        Uses the stored ``noise_{i}`` buffers unless ``noise`` is given or
        ``random_noise`` is set.

        Returns:
            ``(image, out, diff, embed_idxs)`` if ``return_feature``, else
            ``(image, diff, embed_idxs)``.
        """
        if noise is None:
            if random_noise:
                noise = self.make_noise()
            else:
                noise = [getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)]

        image, out, diff, embed_idxs = self._forward_main(latent, noise)

        if return_feature:
            return image, out, diff, embed_idxs
        return image, diff, embed_idxs

    def decode_with_embed(self, latent, embed_inds):
        """Synthesize using given codebook indices for the VQ layers.

        ``embed_inds`` is a list with one entry per VQ module, each of shape
        b x h x w. The stored noise buffers are always used.

        Returns:
            The generated image.
        """
        noise = [getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)]

        out = self.input(latent)
        out = self.conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        # Same layer walk as _forward_main, but VQ layers decode from embed_inds.
        i = 1
        j = 0
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs):
            # 1st conv
            if self._is_vq_layer(i):
                out = self.vqs[j].forward_with_embed(out, embed_inds[j], conv1, noise1)
                j += 1
            else:
                out = conv1(out, latent[:, i], noise=noise1)

            # 2nd conv
            if self._is_vq_layer(i + 1):
                out = self.vqs[j].forward_with_embed(out, embed_inds[j], conv2, noise2)
                j += 1
            else:
                out = conv2(out, latent[:, i + 1], noise=noise2)

            skip = to_rgb(out, latent[:, i + 2], skip)
            i += 2

        image = skip
        return image

class DiscriminatorAE(Discriminator):
    """Discriminator with an auxiliary decoder that reconstructs an image from its features.

    ``forward(ae=True)`` also returns a reconstruction built from the encoder
    features at resolutions 128 down to 4. ``forward(extract=True)`` also
    returns the intermediate features collected by ``extract``.
    """

    def __init__(self, size, channel_multiplier=2, blur_kernel=[1,3,3,1]):

        super().__init__(size, channel_multiplier, blur_kernel)

        self.size = size
        # Powers of two from `size` down to 4, largest first.
        self.extract_res = [2**p for p in range( int(math.log2(self.size)) , 1, -1)]

        # Four upsampling styled convs followed by a ToRGB. In `reconstruct`
        # each stage is styled by the encoder feature at the matching scale.
        self.decoder = nn.ModuleList( [
            StyledConvMod( self.channels[8], self.channels[16]//2, 3, 
                        style_dim=self.channels[8], upsample=True, blur_kernel=blur_kernel ),
            StyledConvMod( self.channels[16]//2, self.channels[32]//2, 3, 
                        style_dim=self.channels[16], upsample=True, blur_kernel=blur_kernel ),
            StyledConvMod( self.channels[32]//2, self.channels[64]//2, 3, 
                        style_dim=self.channels[32], upsample=True, blur_kernel=blur_kernel ),
            StyledConvMod( self.channels[64]//2, self.channels[128]//2, 3, 
                        style_dim=self.channels[64], upsample=True, blur_kernel=blur_kernel ),
            ToRGBMod(self.channels[128]//2, self.channels[128], upsample=False)            
        ] )

    def extract(self, input, feature_res=None ):
        """Run ``self.convs`` layer by layer and collect intermediate features.

        A feature is kept when its width is in ``feature_res`` (default:
        ``self.extract_res``). The walk stops as soon as a kept feature's
        width equals the last entry of ``feature_res``.

        Returns:
            List of kept feature maps in the order they were produced.
        """
        if feature_res is None:
            feature_res = self.extract_res
        out = []
        feat = input
        for conv in self.convs:
            feat = conv(feat)
            if feat.shape[-1] in feature_res:
                out.append(feat)
                if feat.shape[-1] == feature_res[-1]:
                    break
        return out

    def reconstruct(self, input):
        """Reconstruct an image from the encoder features at resolutions 128..4.

        Returns:
            ``(rec_img, coarsest_feature)``; the latter is the last extracted
            feature and is what the real/fake head consumes.
        """
        feats = self.extract(input, feature_res=[128,64,32,16,8,4])

        # Start from the 8px feature; decoder stage k is styled by feats[4 - k].
        decode = feats[-2]
        for i in range(4, 0, -1):
            decode = self.decoder[4-i](decode, feats[i] )

        rec_img = self.decoder[4](decode, feats[0] )
        return rec_img, feats[-1]

    def getRFFeat(self, out):
        """Append minibatch standard-deviation channels to ``out``.

        The std is taken over the whole batch (group = batch), averaged over
        each ``stddev_feat`` channel group and all pixels, then broadcast
        back to every sample and pixel.
        """
        batch, channel, height, width = out.shape
        group = batch #min(batch, self.stddev_group)
        stddev = out.view(
            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
        )
        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
        stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
        stddev = stddev.repeat(group, 1, height, width)
        out_rf = torch.cat([out, stddev], 1)

        return out_rf

    def forward(self, input, ae=False, extract=False):
        """Score ``input`` as real/fake, optionally with reconstruction and features.

        Returns:
            ``out_rf`` alone when neither flag is set. Otherwise a list
            ``[out_rf, rec_img (if ae), feats (if extract)]``.
        """
        if ae:
            rec_img, out = self.reconstruct(input)
            if extract:
                # reconstruct() keeps its 128..4 feature list private; previously
                # `feats` was left unbound here and raised NameError below.
                feats = self.extract(input)
        elif extract:
            feats = self.extract(input)
            out = feats[-1]
        else:
            out = self.convs(input)

        out_rf = self.getRFFeat(out)
        out_rf = self.final_rf(out_rf)

        outputs = [out_rf]
        if ae:
            outputs.append(rec_img)
        if extract:
            outputs.append(feats)

        if len(outputs)==1: return outputs[0]
        return outputs

class StyleEncoder(nn.Module):
    """Map a feature pyramid (largest resolution first) to a stack of style vectors.

    One head (downsampling conv -> 4x4 adaptive max-pool -> flatten -> three
    EqualLinear layers) is built per output style. Resolution ``size`` gets
    three heads, resolution 4 gets one, and every resolution in between gets
    two. ``forward`` returns the styles ordered smallest resolution first.
    """

    def __init__(self, size, style_dim, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        channels = get_channels(channel_multiplier)

        self.style_dim = style_dim

        self.to_style = nn.ModuleList()
        for p in range(int(math.log2(size)), 1, -1):
            s = 2 ** p
            if s == 4:
                heads_per_res = 1
            elif s == size:
                heads_per_res = 3
            else:
                heads_per_res = 2

            width = channels[s]
            for _ in range(heads_per_res):
                self.to_style.append(nn.Sequential(
                    ConvLayer(width, width, 3, downsample=True, blur_kernel=blur_kernel),
                    nn.AdaptiveMaxPool2d(4),
                    Reshape([-1, width * 4 * 4]),
                    EqualLinear(width * 4 * 4, style_dim, activation="fused_lrelu"),
                    EqualLinear(style_dim, style_dim, activation="fused_lrelu"),
                    EqualLinear(style_dim, style_dim, activation=None),
                ))

    def forward(self, feat_list):
        """Encode ``feat_list`` (large -> small) into stacked styles (small -> large).

        Feature ``i`` feeds head ``2*i + 1`` and, unless it is 4 px wide, also
        head ``2*i + 2``. The first feature additionally feeds head 0.
        """
        styles = [self.to_style[0](feat_list[0])]
        for i, feat in enumerate(feat_list):
            first = 2 * i + 1
            styles.append(self.to_style[first](feat))
            if feat.shape[-1] != 4:
                styles.append(self.to_style[first + 1](feat))

        stacked = [s.unsqueeze(1) for s in reversed(styles)]
        return torch.cat(stacked, dim=1)
  
class GeneratorHR(nn.Module):
    """High-resolution generator: a 256px ``GeneratorPI`` followed by extra upsampling stages.

    Each extra stage is an upsampling ``StyledConv``, a ``StyledConv`` and a
    ``ToRGB``, one stage per doubling from 512 up to ``size``. Only the
    wrapped 256px generator uses the mapping network and VQ modules. The
    extra stages reuse latents taken from that generator's latent stack.
    """

    def __init__(self, size, style_dim, n_mlp, channel_multiplier=2, dislow=2, dishigh=5, n_embed=6, blur_kernel=[1, 3, 3, 1], lr_mlp=0.01, vq=0):
        super().__init__()

        self.gen_256 = GeneratorPI(256, style_dim, n_mlp, channel_multiplier, dislow, dishigh, n_embed, blur_kernel, lr_mlp, vq=vq)
        self.n_latent = self.gen_256.n_latent
        self.channels = self.gen_256.channels
        self.style = self.gen_256.style

        self.convs = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        self.num_extra_layers = 2 if size == 512 else 4
        # Fixed noise buffers for the extra layers: two per resolution, starting at 2**9 = 512.
        for layer_idx in range(self.num_extra_layers):
            side = 2 ** ((layer_idx + 18) // 2)
            self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(1, 1, side, side))

        in_channel = self.channels[256]
        for log_res in range(9, int(math.log2(size)) + 1):
            out_channel = self.channels[2 ** log_res]
            self.convs.append(
                StyledConv(in_channel, out_channel, 3, style_dim, upsample=True, blur_kernel=blur_kernel))
            self.convs.append(
                StyledConv(out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel))
            self.to_rgbs.append(ToRGB(out_channel, style_dim))
            in_channel = out_channel

    def get_latent(self, input):
        """Map ``input`` through the wrapped generator's mapping network."""
        return self.gen_256.style(input)

    def make_noise(self):
        """Delegate to the wrapped 256px generator's ``make_noise``."""
        return self.gen_256.make_noise()

    def _extra_noise(self, randomize):
        """Noise for the extra layers: Nones when ``randomize`` is set, else the stored buffers."""
        if randomize:
            return [None] * self.num_extra_layers
        return [getattr(self.noises, f"noise_{i}") for i in range(self.num_extra_layers)]

    def _synthesize_extra(self, out, skip, extra_latent, extra_noise):
        """Run the extra conv/ToRGB stages on the 256px feature ``out`` and image ``skip``.

        Stage ``k`` uses ``extra_latent[:, 2k]`` for its upsampling conv and
        ``extra_latent[:, 2k + 1]`` for its second conv and its ToRGB.
        """
        stages = zip(self.convs[::2], self.convs[1::2],
                     extra_noise[::2], extra_noise[1::2], self.to_rgbs)
        for step, (up_conv, conv, up_noise, conv_noise, to_rgb) in enumerate(stages):
            w_up = extra_latent[:, 2 * step]
            w = extra_latent[:, 2 * step + 1]
            out = up_conv(out, w_up, up_noise)
            out = conv(out, w, conv_noise)
            skip = to_rgb(out, w, skip)
        return skip

    def forward(self, styles, return_latents=False, inject_index=None, 
                    truncation=1, truncation_latent=None,
                    input_is_latent=False, noise=None,
                    randomize_noise=True):
        """Generate high-resolution images.

        The input latent only covers the 256px generator. The extra layers use
        a random permutation-selected subset of those per-layer latents.

        Returns:
            ``(image, latent, diff, embed_idxs)``. ``latent`` (the 256px
            latents followed by the extra-layer latents) is None unless
            ``return_latents``.
        """
        skip, out, latent, diff, embed_idxs = self.gen_256.forward(styles, 
                    inject_index=inject_index, truncation=truncation, truncation_latent=truncation_latent,
                    input_is_latent=input_is_latent, noise=noise, randomize_noise=randomize_noise, return_latents=True, return_feature=True)

        picked = torch.randperm(latent.shape[1])[:self.num_extra_layers]
        extra_latent = latent[:, picked]
        skip = self._synthesize_extra(out, skip, extra_latent, self._extra_noise(randomize_noise))

        if not return_latents:
            return skip, None, diff, embed_idxs
        return skip, torch.cat([latent, extra_latent], dim=1), diff, embed_idxs

    def decode(self, latent, noise=None, random_noise=False):
        """Synthesize from a full latent stack whose last ``num_extra_layers`` entries drive the extra layers."""
        skip, out, diff, embed_idxs = self.gen_256.decode(latent, noise, random_noise, return_feature=True)

        extra_latent = latent[:, -self.num_extra_layers:]
        skip = self._synthesize_extra(out, skip, extra_latent, self._extra_noise(random_noise))

        return skip, diff, embed_idxs

class DiscriminatorHR(nn.Module):
    """High-resolution discriminator: extra input layers in front of a 256px ``DiscriminatorAE``.

    ``extra_convs`` is a 1x1 ``ConvLayer`` from RGB plus ``num_extra_layer``
    ResBlocks, ending at ``channels[256]`` channels. Its output feeds
    ``dis_256.convs[1:]``; the backbone's own input conv
    (``dis_256.convs[0]``) is skipped.
    """

    def __init__(self, size, channel_multiplier=2, blur_kernel=[1,3,3,1]):

        super().__init__()

        self.dis_256 = DiscriminatorAE(256, channel_multiplier, blur_kernel)
        self.channels = self.dis_256.channels
        self.extract_res = self.dis_256.extract_res

        self.num_extra_layer = 2 if size==1024 else 1
        extra_convs = [ConvLayer(3, self.channels[size], 1)]
        in_channel = self.channels[size]
        # One ResBlock per extra level; output widths step down to channels[256].
        for i in range(self.num_extra_layer,0,-1):
            out_channel = self.channels[int(2**(7+i))]
            extra_convs.append(ResBlock(in_channel, out_channel, blur_kernel))
            in_channel = out_channel
        self.extra_convs = nn.Sequential(*extra_convs)

    def extract(self, input, feature_res=None):
        """Run the extra layers then the 256px backbone, collecting intermediate features.

        A feature is kept when its width is in ``feature_res`` (default:
        ``self.extract_res``). The walk stops as soon as a kept feature's
        width equals the last entry of ``feature_res``.

        Returns:
            List of kept feature maps in the order they were produced.
        """
        if feature_res is None:
            feature_res = self.extract_res
        # Extra input stages first, then the backbone minus its own input conv.
        layers = list(self.extra_convs) + list(self.dis_256.convs)[1:]
        out = []
        feat = input
        for layer in layers:
            feat = layer(feat)
            if feat.shape[-1] in feature_res:
                out.append(feat)
                if feat.shape[-1] == feature_res[-1]:
                    break
        return out

    def reconstruct(self, input):
        """Reconstruct an image with ``dis_256.decoder`` from features at resolutions 128..4.

        Returns:
            ``(rec_img, coarsest_feature)``.
        """
        feats = self.extract(input, feature_res=[128,64,32,16,8,4])

        # Start from the 8px feature; decoder stage k is styled by feats[4 - k].
        decode = feats[-2]
        for i in range(4, 0, -1):
            decode = self.dis_256.decoder[4-i](decode, feats[i] )

        rec_img = self.dis_256.decoder[4](decode, feats[0] )
        return rec_img, feats[-1]

    def forward(self, input, ae=False, extract=False):
        """Score ``input`` as real/fake, optionally with reconstruction and features.

        Returns:
            ``out_rf`` alone when neither flag is set. Otherwise a list
            ``[out_rf, rec_img (if ae), feats (if extract)]``.
        """
        if ae:
            rec_img, out = self.reconstruct(input)
            if extract:
                # reconstruct() keeps its 128..4 feature list private; previously
                # `feats` was left unbound here and raised NameError below.
                feats = self.extract(input)
        else:
            feats = self.extract(input)
            out = feats[-1]

        out_rf = self.dis_256.getRFFeat(out)
        out_rf = self.dis_256.final_rf(out_rf)

        outputs = [out_rf]
        if ae:
            outputs.append(rec_img)
        if extract:
            outputs.append(feats)

        if len(outputs)==1: return outputs[0]
        return outputs