import math
import random

import torch
from torch import nn
from torch.nn import functional as F

from model.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
from model.common import Discriminator


def make_kernel(k):
    """Build a normalized 2-D FIR kernel.

    Args:
        k: 1-D or 2-D sequence of filter taps. A 1-D sequence is expanded to
            the separable 2-D kernel ``k[i] * k[j]``.

    Returns:
        A float32 tensor divided by its own sum (entries sum to 1).
    """
    kernel = torch.tensor(k, dtype=torch.float32)
    if kernel.ndim == 1:
        # separable filter: kernel[i, j] = k[i] * k[j]
        kernel = kernel.unsqueeze(1) * kernel.unsqueeze(0)
    return kernel / kernel.sum()

class ConstantInput(nn.Module):
    """Learned constant token grid used as the generator's starting input.

    ``in_size`` is a ``(resolution, channels, patch_size, _)`` config tuple.
    The grid holds ``(resolution // patch_size) ** 2`` tokens, each with
    ``channels * patch_size ** 2`` features. This matches the per-token width
    that ModulatedConv, EqualConv and ToRGB expect. For ``patch_size == 1``
    (the config this file uses) the shape is ``(1, resolution**2, channels)``,
    so existing checkpoints are unaffected.
    """

    def __init__(self, in_size):
        super().__init__()
        in_isize, in_channel, in_psize, _ = in_size
        num = in_isize // in_psize
        # each token carries a full in_psize x in_psize patch of in_channel features
        self.input = nn.Parameter(torch.randn(1, num ** 2, in_channel * in_psize ** 2))

    def forward(self, batch_size):
        """Return the constant tiled along the batch axis.

        Returns:
            Tensor of shape ``(batch_size, tokens, features)``.
        """
        return self.input.repeat(batch_size, 1, 1)

class EqualLinear(nn.Module):
    """Linear layer with equalized learning rate.

    The stored weight is drawn from ``N(0, 1) / lr_mul``. At run time it is
    multiplied by ``lr_mul / sqrt(in_dim)``, and the bias by ``lr_mul``.
    When ``activation`` is truthy, the scaled bias goes to
    ``fused_leaky_relu`` instead of ``F.linear``.
    """

    def __init__(
        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
    ):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim) / lr_mul)
        self.bias = (
            nn.Parameter(torch.full((out_dim,), float(bias_init))) if bias else None
        )
        self.activation = activation
        # run-time rescale factor (equalized learning rate)
        self.scale = lr_mul * (1 / math.sqrt(in_dim))
        self.lr_mul = lr_mul

    def forward(self, input):
        weight = self.weight * self.scale
        bias = None if self.bias is None else self.bias * self.lr_mul
        if not self.activation:
            return F.linear(input, weight, bias=bias)
        # bias is applied inside the fused activation
        return fused_leaky_relu(F.linear(input, weight), bias)

    def __repr__(self):
        out_features, in_features = self.weight.shape
        return f"{type(self).__name__}({in_features}, {out_features})"

class PixelNorm(nn.Module):
    """Rescale each vector along the last axis to unit root-mean-square (eps 1e-8)."""

    def forward(self, input):
        mean_square = input.pow(2).mean(dim=-1, keepdim=True)
        return input * (mean_square + 1e-8).rsqrt()

class Blur(nn.Module):
    """Low-pass filter an image with a normalized FIR kernel.

    The input is reflect-padded by ``len(blur_kernel) - 1`` pixels in total per
    axis, with any odd pixel going on the trailing side. It is then filtered
    with ``upfirdn2d`` and no extra padding. Presumably this preserves the
    spatial size, given upfirdn2d's usual output size of ``H + pad - k + 1``.
    """

    def __init__(self, blur_kernel=[1.,3.,3.,1.]):
        super().__init__()

        taps = len(blur_kernel)
        self.register_buffer('kernel', make_kernel(blur_kernel))
        # (taps - 1) total padding; for integer taps, ceil((taps - 1) / 2) == taps // 2
        self.pad1 = (taps - 1) // 2
        self.pad2 = taps // 2

    def forward(self, input):
        padding = (self.pad1, self.pad2) * 2  # (left, right, top, bottom)
        return upfirdn2d(F.pad(input, padding, mode='reflect'), self.kernel)

class Upsample(nn.Module):
    """Upsample an image, or a sequence of square patch tokens, by ``factor``.

    Args:
        ipsize: patch side length of 3-D token input.
        opsize: patch side length used to re-tokenize the output when
            ``forward(..., flat=True)``; required in that case.
        kernel: FIR taps passed to ``make_kernel``.
        factor: integer upsampling factor. With ``factor == 1``, ``forward``
            returns its input unchanged (even when ``flat=True``).
    """

    def __init__(self, ipsize, opsize=None, kernel=(1., 3., 3., 1.), factor=2):
        super().__init__()

        self.factor = factor
        self.ipsize = ipsize
        self.opsize = opsize

        if factor > 1:
            # multiplied by factor**2, presumably to compensate for the zeros
            # that upsampling inserts (standard upfirdn2d convention)
            kernel = make_kernel(kernel) * (factor ** 2)
            self.register_buffer("kernel", kernel)
            p = kernel.shape[0] - factor
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2
            # total padding p + factor - 1; with the usual upfirdn2d size formula
            # this yields an output exactly `factor` times the input size
            self.pad = (pad0, pad1)

    def forward(self, x, flat=False):
        """Upsample ``x``.

        Args:
            x: either an image ``(b, c, H, W)``, or tokens
                ``(b, n*n, c * ipsize**2)``. Token ``row*n + col`` holds the
                patch at grid position (row, col), and its features are
                ordered (channel, y, x).
            flat: if True, re-tokenize the upsampled image into
                ``(b, m*m, c * opsize**2)`` patches of side ``opsize``.

        Returns:
            The upsampled image, or tokens when ``flat`` is True.
        """
        if self.factor == 1:
            return x

        if x.dim() == 3:
            # tokens -> image: (b, n, n, C, p, p) -> (b, C, n*p, n*p)
            b, t, c = x.size()
            n = int(math.sqrt(t))
            channel = c // (self.ipsize ** 2)
            x = x.view(b, n, n, channel, self.ipsize, self.ipsize)
            x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(
                b, channel, n * self.ipsize, n * self.ipsize)

        x = upfirdn2d(x, self.kernel, up=self.factor, down=1, pad=self.pad)

        if flat:
            # take the channel count from the upsampled image itself so 4-D
            # input works too (previously `channel` was unbound for it)
            b, channel, h, _ = x.size()
            n = h // self.opsize
            x = x.view(b, channel, n, self.opsize, n, self.opsize)
            x = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(
                b, n ** 2, channel * (self.opsize ** 2))

        return x


class EqualConv(nn.Module):
    """Token-space layer: optional patch upsampling, an equalized-lr linear
    projection of each flattened patch, a bias, then ``FusedLeakyReLU``.

    ``in_size`` / ``out_size`` are ``(resolution, channels, patch_size, _)``
    config tuples. After the ``Upsample`` step, each token has
    ``in_channels * out_patch**2`` features. It is projected to
    ``out_channels * out_patch**2`` features.
    """

    def __init__(self, in_size, out_size, bias_init=0, factor=1):
        super().__init__()

        _, in_channels, in_patch, _ = in_size
        _, out_channels, out_patch, _ = out_size
        self.in_size, self.out_size = in_size, out_size
        self.upsampler = Upsample(in_patch, out_patch, factor=factor)

        fan_in = in_channels * (out_patch ** 2)
        fan_out = out_channels * (out_patch ** 2)
        self.scale = 1 / math.sqrt(fan_in)
        self.weight = nn.Parameter(torch.randn(fan_out, fan_in))
        self.bias = nn.Parameter(torch.zeros(fan_out).fill_(bias_init))
        self.activate = FusedLeakyReLU(fan_out)

    def forward(self, x):
        tokens = self.upsampler(x, flat=True)
        projected = F.linear(tokens, self.weight * self.scale) + self.bias
        batch, length, width = projected.size()
        # the activation is applied to a 2-D (batch * tokens, features) view
        activated = self.activate(projected.view(batch * length, width))
        return activated.view(batch, length, width)


class ToRGB(nn.Module):
    """Project patch tokens to an RGB image, optionally adding a skip image.

    Args:
        in_size: ``(resolution, channels, patch_size, _)`` config tuple.
            Tokens are expected as ``(batch, (resolution // patch_size) ** 2,
            channels * patch_size ** 2)``.
        upsample: if True, a ``skip`` image passed to ``forward`` is upsampled
            2x before being added. If False, ``skip`` is added as-is and must
            already match the output resolution.
    """

    def __init__(self, in_size, upsample=True):
        super().__init__()

        self.ires, in_channel, ipsize, _ = in_size
        self.ipsize = ipsize
        fan_in = in_channel * (ipsize ** 2)
        self.scale = 1 / math.sqrt(fan_in)
        self.weight = nn.Parameter(torch.randn(3 * (ipsize ** 2), fan_in))
        self.bias = nn.Parameter(torch.zeros(3 * (ipsize ** 2)))
        # always define the attribute so forward() never hits an AttributeError
        self.upsample = Upsample(ipsize, ipsize, factor=2) if upsample else None

    def forward(self, x, skip=None):
        """Return ``(batch, 3, n*p, n*p)`` where ``n = resolution // p``, plus ``skip`` if given."""
        x = F.linear(x, self.weight * self.scale)
        x = x + self.bias

        b, _, _ = x.size()  # also enforces 3-D token input
        n = self.ires // self.ipsize
        # tokens -> image: token row*n + col becomes the patch at grid (row, col)
        x = x.view(b, n, n, 3, self.ipsize, self.ipsize)
        x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(
            b, 3, n * self.ipsize, n * self.ipsize)

        if skip is not None:
            if self.upsample is not None:
                skip = self.upsample(skip)
            x = x + skip

        return x


# --------------------------------------------------
#                    Main Designs 
# --------------------------------------------------
def norm(input, norm_type='layernorm'):
    """Center (unless 'l2norm') and L2-normalize ``input`` along one axis.

    The expected layout is ``(batch, tokens, features)``:
      * 'layernorm': subtract the mean over features, then divide by the L2
        norm over features;
      * 'l2norm': divide by the L2 norm over features, with no centering;
      * 'insnorm': like 'layernorm', but across tokens (dim 1).

    The divisor is ``sqrt(sum of squares + 1e-8)``, not a standard deviation.
    The result therefore has (approximately) unit L2 norm, not unit variance.

    Raises:
        NotImplementedError: for any other ``norm_type``.
    """
    reduce_dim = {'layernorm': -1, 'l2norm': -1, 'insnorm': 1}.get(norm_type)
    if reduce_dim is None:
        raise NotImplementedError('have not implemented this type of normalization')

    centered = input
    if norm_type != 'l2norm':
        centered = centered - centered.mean(dim=reduce_dim, keepdim=True)

    inv_norm = torch.rsqrt(centered.pow(2).sum(dim=reduce_dim, keepdim=True) + 1e-8)
    return centered * inv_norm


class MultiheadAttention(nn.Module):
    """Scaled dot-product attention with equalized-lr projections.

    Despite the name, this uses a single head: q, k and v are each projected
    to ``embed_dim`` with a bias-free ``EqualLinear``, and the logits are
    scaled by ``embed_dim ** -0.5``.
    """

    def __init__(self, embed_dim, qdim, kdim, vdim):
        super().__init__()

        self.scale = embed_dim ** -0.5
        self.to_q = EqualLinear(qdim, embed_dim, bias=False)
        self.to_k = EqualLinear(kdim, embed_dim, bias=False)
        self.to_v = EqualLinear(vdim, embed_dim, bias=False)

    def forward(self, q, k, v):
        """Attend from ``q`` over ``k``/``v`` (all 3-D, batch first).

        Returns:
            ``(out, (attn.detach(),))``. ``out`` is ``(batch, n_q, embed_dim)``
            and ``attn`` holds the softmax weights ``(batch, n_q, n_k)``.
        """
        _batch, _n_q, _q_dim = q.size()  # unpacking doubles as a 3-D check
        query = self.to_q(q)
        key = self.to_k(k)
        value = self.to_v(v)

        logits = torch.bmm(query, key.transpose(1, 2)) * self.scale
        weights = torch.softmax(logits, dim=-1)
        return torch.bmm(weights, value), (weights.detach(),)


class ModulatedConv(nn.Module):
    """Re-style patch tokens with per-token styles chosen by attention.

    ``forward`` first normalizes the tokens with ``norm`` (removing the old
    style). It then derives one style vector per token and applies it per
    channel, either multiplicatively ('prod') or additively ('plus').

    The per-token style comes from attention. The query is each token's
    spatial mean plus a learned positional embedding ``pos``. The keys are
    learned ``keys`` (one per style vector). The values are the incoming
    ``style`` vectors.

    Args:
        in_size: ``(resolution, channels, patch_size, style_num)`` config tuple.
        style_dim: width of the incoming style vectors (attention values).
        style_mod: 'prod' or 'plus'.
        norm_type: passed through to ``norm``.
    """

    def __init__(
        self,
        in_size,
        style_dim,
        style_mod='prod',
        norm_type='layernorm'
    ):
        super().__init__()

        resolution, channels, patch, n_styles = in_size
        self.ipsize = patch
        self.style_mod = style_mod
        self.norm_type = norm_type
        # NOTE(review): orthogonal_ flattens trailing dims, so with a leading dim
        # of 1 this is a single random unit vector of length n_styles*channels,
        # not n_styles mutually orthogonal keys; confirm this is intended.
        self.keys = nn.Parameter(nn.init.orthogonal_(torch.empty(1, n_styles, channels)))
        n_tokens = (resolution // patch) ** 2
        self.pos = nn.Parameter(torch.zeros(1, n_tokens, channels))
        self.attention = MultiheadAttention(channels, channels, channels, style_dim)

    def forward(self, input, style, is_new_style=False):
        """Normalize ``input`` and apply a per-token style.

        Args:
            input: ``(batch, tokens, channels * patch_size**2)``. ``tokens``
                must match ``self.pos``.
            style: when ``is_new_style`` is False, these are the attention
                values, presumably ``(batch, style_num, style_dim)`` to line up
                with ``self.keys``. When True, ``style`` is used directly and
                must broadcast to ``(batch, tokens, channels)``.

        Returns:
            ``(out, (new_style.detach(),))``, where ``out`` has ``input``'s shape.
        """
        batch, tokens, width = input.size()

        # strip the previous style, then expose patches: (b, t, C, p, p)
        feats = norm(input, norm_type=self.norm_type)
        feats = feats.view(batch, tokens, -1, self.ipsize, self.ipsize)

        if is_new_style:
            new_style = style
        else:
            query = feats.mean(dim=[3, 4])
            keys = self.keys.repeat(batch, 1, 1)
            pos = self.pos.repeat(batch, 1, 1)
            new_style, _ = self.attention(q=query + pos, k=keys, v=style)

        modulation = new_style[..., None, None]
        if self.style_mod == 'prod':
            out = feats * modulation
        elif self.style_mod == 'plus':
            out = feats + modulation
        else:
            raise NotImplementedError('Have not implemented this type of style modulation')

        return out.view(batch, tokens, width), (new_style.detach(), )



# ---------------------------------------------------------
#                       Networks 
# ---------------------------------------------------------

class MappingNetwork(nn.Module):
    """MLP that maps a latent z to ``style_num`` style vectors.

    The network is ``PixelNorm`` followed by ``n_mlp`` ``EqualLinear`` layers,
    all with fused leaky ReLU. The last layer widens to
    ``style_num * style_dim``, and the output is reshaped to
    ``(batch, style_num, style_dim)``.
    """

    def __init__(
        self,
        style_dim=512,
        style_num=16,
        n_mlp=8,
        lr_mlp=0.01
    ):
        super().__init__()

        hidden = [
            EqualLinear(style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu")
            for _ in range(n_mlp - 1)
        ]
        head = EqualLinear(
            style_dim, style_num * style_dim, lr_mul=lr_mlp, activation="fused_lrelu")
        self.style_dim = style_dim
        self.style_num = style_num
        self.style = nn.Sequential(PixelNorm(), *hidden, head)

    def forward(self, z):
        return self.style(z).view(-1, self.style_num, self.style_dim)
        

class Generator(nn.Module):
    """Patch-token generator driven by sets of style vectors.

    A ``MappingNetwork`` turns each latent z into ``style_num`` style vectors of
    width ``style_dim``. Synthesis starts from a learned 8x8 token grid
    (``configs[8]``). Each stage i = 4 .. log_size then runs
    ModulatedConv -> upsampling EqualConv -> ModulatedConv -> EqualConv,
    with no trailing EqualConv in the last stage. ToRGB layers accumulate
    the RGB image through upsampled skips.

    Args:
        args: options object. This class reads ``args.style_num``,
            ``args.style_mod`` and ``args.norm``.
        size: output resolution. ``2 ** int(log2(size))`` must be a key of
            ``self.configs`` (8 .. 1024).
        style_dim: width of z and of each style vector.
        n_mlp: number of layers in the mapping network.
        lr_mlp: learning-rate multiplier of the mapping network.
        channel_multiplier: width multiplier for resolutions >= 64.
    """

    def __init__(
        self,
        args, 
        size,
        style_dim=512,
        n_mlp = 8,
        lr_mlp=0.01,
        channel_multiplier=2,
    ):
        super().__init__()

        self.size = size
        self.style_dim = style_dim
        self.style_num = args.style_num #16/32/64/128
        # resolution -> (resolution, channels, patch size, style_num)
        self.configs = {
            8:  (8, 512, 1, self.style_num), 
            16: (16, 512, 1, self.style_num), 
            32: (32, 512, 1, self.style_num), 
            64: (64,   256 * channel_multiplier, 1, self.style_num), 
            128: (128, 128 * channel_multiplier, 2, self.style_num), #1
            256: (256, 64  * channel_multiplier, 2, self.style_num), #1
            512: (512, 32  * channel_multiplier, 4, self.style_num), 
            1024: (1024, 16 * channel_multiplier, 4, self.style_num), 
        }

        self.style = MappingNetwork(
            style_dim=style_dim, style_num=args.style_num, n_mlp=n_mlp, lr_mlp=lr_mlp)

        self.input = ConstantInput(in_size=self.configs[8])
        self.log_size = int(math.log(size, 2))
        # one latent entry per ModulatedConv: mod1 plus two per stage (4 .. log_size)
        self.n_latent = (self.log_size - 3) * 2  + 1

        self.mod1 = ModulatedConv(
            self.configs[8], style_dim, style_mod=args.style_mod, norm_type=args.norm)
        self.conv1 = EqualConv(self.configs[8], self.configs[8], factor=1)
        self.to_rgb1 = ToRGB(self.configs[8], upsample=False)

        self.mods = nn.ModuleList()
        self.convs = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        
        # stage i produces a 2**i x 2**i image
        for i in range(4, self.log_size + 1):
            self.mods.append(ModulatedConv(
                self.configs[2**(i-1)], style_dim, style_mod=args.style_mod, norm_type=args.norm))
            self.convs.append(EqualConv(self.configs[2**(i-1)], self.configs[2**i], factor=2))
            self.mods.append(ModulatedConv(
                self.configs[2**i], style_dim, style_mod=args.style_mod, norm_type=args.norm))
            if i == self.log_size: 
                # no trailing conv in the last stage; None keeps convs paired
                # with mods for the zip in forward()
                self.convs.append(None)
            else: 
                self.convs.append(EqualConv(self.configs[2**i], self.configs[2**i], factor=1))
            self.to_rgbs.append(ToRGB(self.configs[2**i], upsample=True))


    def mean_latent(self, n_latent):
        """Average the mapping-network output over ``n_latent`` random z.

        Returns:
            Tensor of shape ``(1, style_num, style_dim)``, usable as ``truncation_latent``.
        """
        latent_in = torch.randn(
            n_latent, self.style_dim, device=self.input.input.device
        )
        latent = self.style(latent_in).mean(0, keepdim=True)

        return latent

    def get_latent(self, input):
        """Map latent z through the mapping network to ``(batch, style_num, style_dim)``."""
        return self.style(input)


    def forward(
        self,
        styles,
        inject_index=None,
        kv_index=None,
        truncation=1,
        truncation_latent=None,
        input_is_latent=False,
        input_is_style=False,
        return_styles=False,
        return_latents=False,
    ):
        """Synthesize images.

        Args:
            styles: list of inputs. By default each entry is a latent z for the
                mapping network. With ``input_is_latent``, entries are already
                mapped styles. With ``input_is_style``, ``styles[0]`` is
                indexed per layer and each entry is applied directly as the
                per-token style (attention is skipped).
            inject_index: layer at which two-style mixing switches from the
                first mixed latent to the second; random when None. It is
                overwritten with ``n_latent`` when fewer than two styles are given.
            kv_index: pair of split points along the style-vector axis. They
                swap style vectors between ``styles[0]`` and ``styles[1]``.
                Random when None.
            truncation: values < 1 interpolate toward ``truncation_latent``,
                which must then be provided.
            truncation_latent: e.g. the result of ``mean_latent``.
            input_is_latent: skip the mapping network.
            input_is_style: see ``styles``.
            return_styles: accepted but unused.
            return_latents: if True, return the per-layer latent as the second output.

        Returns:
            ``(image, latent)`` if ``return_latents``, else ``(image, None)``.
            ``image`` is the accumulated ToRGB output ``(batch, 3, H, W)``.
        """
    
        if not input_is_style: 
            if not input_is_latent:
                styles = [self.style(s) for s in styles]

            if truncation < 1:
                style_t = []
                for style in styles:
                    style_t.append(truncation_latent + truncation * (style - truncation_latent))
                styles = style_t

            if len(styles) < 2:
                # single style: use it for every layer
                inject_index = self.n_latent
                if styles[0].ndim == 3: # (b, num, dim) -> (layer, b, num, dim)
                    latent = styles[0].unsqueeze(0).repeat(inject_index, 1, 1, 1)
                else:
                    # NOTE(review): assumed already per-layer (layer, b, num, dim)
                    latent = styles[0]
            else:
                # style mixing; only styles[0] and styles[1] are used.
                # NOTE(review): random.randint(1, x - 1) raises ValueError when
                # x < 2, i.e. style_num == 1 or n_latent == 1 (size 8).
                if kv_index is None: 
                    kv_index = random.randint(1, self.style_num - 1), random.randint(1, self.style_num - 1)
                # swap a prefix of style vectors between the two inputs
                latent1 = torch.cat([styles[0][:, :kv_index[0], :], styles[1][:, kv_index[0]:, :]], dim=1)
                latent2 = torch.cat([styles[1][:, :kv_index[1], :], styles[0][:, kv_index[1]:, :]], dim=1)
                if inject_index is None:
                    inject_index = random.randint(1, self.n_latent - 1)
                # layers [0, inject_index) use latent1, the rest use latent2
                latent1 = latent1.unsqueeze(0).repeat(inject_index, 1, 1, 1)
                latent2 = latent2.unsqueeze(0).repeat(self.n_latent - inject_index, 1, 1, 1)
                latent = torch.cat([latent1, latent2], dim=0)
        else: 
            latent = styles[0]

        # -------------------------
        #    network feedforward
        # -------------------------
        # latent[i] is the style input for the i-th ModulatedConv
        out = self.input(latent[0].size(0))
        out, _ = self.mod1(out, latent[0], is_new_style=input_is_style)
        skip = self.to_rgb1(out)
        out = self.conv1(out)

        i = 1
        for mod1, mod2, conv1, conv2, to_rgb in zip(
            self.mods[::2], self.mods[1::2], self.convs[::2], self.convs[1::2], self.to_rgbs
        ):
            out, _ = mod1(out, latent[i], is_new_style=input_is_style)
            out = conv1(out)  # 2x upsampling conv
            out, _ = mod2(out, latent[i+1], is_new_style=input_is_style)
            skip = to_rgb(out, skip)  # adds the 2x-upsampled previous RGB
            if conv2 is not None: 
                out = conv2(out)
            i += 2

        image = skip

        if return_latents:
            return image, latent
        else:
            return image, None