import functools
import os

import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
import torch.nn.functional as F

# from torch.nn import Parameter as P
import sys

sys.path.insert(1, os.path.join(sys.path[0], ".."))
import BigGAN.layers as layers

# from sync_batchnorm import SynchronizedBatchNorm2d as SyncBatchNorm2d
from BigGAN.diffaugment_utils import DiffAugment


# Architectures for G
# Attention is specified as an underscore-separated list of resolutions: '32_64'
# applies an attention block at both 32x32 and 64x64, while '64' applies one
# only at 64x64.
def G_arch(ch=64, attention="64", ksize="333333", dilation="111111"):
    """Build the generator layout table, keyed by output image resolution.

    Args:
        ch: Base channel width; every channel count is ``ch`` times a fixed
            per-stage multiplier.
        attention: Underscore-separated resolutions (e.g. ``"32_64"``) at
            which an attention block is requested.
        ksize: Accepted for signature compatibility; not used here.
        dilation: Accepted for signature compatibility; not used here.

    Returns:
        A dict mapping 512/256/128/64/32 to a dict with the keys
        ``in_channels``, ``out_channels``, ``upsample``, ``resolution`` and
        ``attention`` (resolution -> bool).

    Raises:
        ValueError: If a token of ``attention`` is not an integer literal.
    """
    # Parse the attention spec a single time; membership is all we need.
    attn_resolutions = {int(token) for token in attention.split("_")}

    def layout(in_mults, out_mults):
        # Stages start at 8x8 (2**3) and double in resolution each time.
        exponents = range(3, 3 + len(in_mults))
        return {
            "in_channels": [ch * mult for mult in in_mults],
            "out_channels": [ch * mult for mult in out_mults],
            "upsample": [True] * len(in_mults),
            "resolution": [2 ** exp for exp in exponents],
            "attention": {
                2 ** exp: (2 ** exp in attn_resolutions) for exp in exponents
            },
        }

    return {
        512: layout([16, 16, 8, 8, 4, 2, 1], [16, 8, 8, 4, 2, 1, 1]),
        256: layout([16, 16, 8, 8, 4, 2], [16, 8, 8, 4, 2, 1]),
        128: layout([16, 16, 8, 4, 2], [16, 8, 4, 2, 1]),
        64: layout([16, 16, 8, 4], [16, 8, 4, 2]),
        32: layout([4, 4, 4], [4, 4, 4]),
    }


class Generator(nn.Module):
    """BigGAN-style generator with optional class and instance conditioning.

    The latent vector goes through a linear layer, is reshaped to a
    ``bottom_width x bottom_width`` feature map, and then runs through one
    stage per entry of the ``G_arch`` table for the requested resolution. Each
    stage is a ``layers.GBlock`` optionally followed by ``layers.Attention``.
    A final norm-activation-conv head produces 3 channels which are squashed
    with ``tanh``.

    Every stage receives a conditioning vector built from the class label
    and/or the instance features (see ``get_condition_embeddings``). With
    ``hier=True`` the latent is split into chunks and one chunk is appended to
    the conditioning vector of each stage.

    Unless ``no_optim`` is set or ``embedded_optimizer`` is false, an Adam
    optimizer over all parameters is stored in ``self.optim``.
    """

    def __init__(
            self,
            G_ch=64,
            dim_z=128,
            bottom_width=4,
            resolution=128,
            G_kernel_size=3,
            G_attn="64",
            n_classes=1000,
            num_G_SVs=1,
            num_G_SV_itrs=1,
            G_shared=True,
            shared_dim=0,
            hier=False,
            cross_replica=False,
            mybn=False,
            G_activation=nn.ReLU(inplace=False),
            G_lr=5e-5,
            G_B1=0.0,
            G_B2=0.999,
            adam_eps=1e-8,
            BN_eps=1e-5,
            SN_eps=1e-12,
            G_mixed_precision=False,
            G_fp16=False,
            G_init="ortho",
            skip_init=False,
            no_optim=False,
            G_param="SN",
            norm_style="bn",
            class_cond=True,
            embedded_optimizer=True,
            instance_cond=False,
            G_shared_feat=True,
            shared_dim_feat=2048,
            **kwargs
    ):
        super(Generator, self).__init__()

        # ---- Plain configuration, mirrored onto the instance ----
        self.ch = G_ch  # channel width multiplier
        self.dim_z = dim_z  # latent size (rounded down below when hier)
        self.bottom_width = bottom_width  # side of the first feature map
        self.resolution = resolution  # key into the G_arch table
        self.kernel_size = G_kernel_size  # recorded only; convs below are 3x3
        self.attention = G_attn  # attention spec such as "32_64"
        self.n_classes = n_classes  # number of classes for label conditioning
        self.G_shared = G_shared  # use a shared class embedding?
        # Width of the class embedding; a non-positive request means dim_z.
        if shared_dim > 0:
            self.shared_dim = shared_dim
        else:
            self.shared_dim = dim_z
        self.hier = hier  # hierarchical latent space?
        self.cross_replica = cross_replica  # forwarded to the norm layers
        self.mybn = mybn  # forwarded to the norm layers
        self.activation = G_activation  # nonlinearity shared by all stages
        self.init = G_init  # weight initialization style
        self.G_param = G_param  # "SN" selects spectrally-normalized layers
        self.norm_style = norm_style
        self.BN_eps = BN_eps
        self.SN_eps = SN_eps
        self.fp16 = G_fp16
        self.G_shared_feat = G_shared_feat  # embed instance features?
        self.shared_dim_feat = shared_dim_feat
        # Layout (channels / resolutions / attention) for this resolution
        self.arch = G_arch(self.ch, self.attention)[resolution]

        # ---- Latent layout ----
        # Flat latent by default: a single slot and no per-stage chunk.
        self.num_slots, self.z_chunk_size = 1, 0
        if self.hier:
            # One slot for the first linear layer plus one per stage.
            self.num_slots = 1 + len(self.arch["in_channels"])
            self.z_chunk_size = self.dim_z // self.num_slots
            # Shrink dim_z so that it divides evenly into the slots.
            self.dim_z = self.num_slots * self.z_chunk_size

        # ---- Layer factories ----
        if self.G_param == "SN":
            sn_options = dict(
                num_svs=num_G_SVs, num_itrs=num_G_SV_itrs, eps=self.SN_eps
            )
            self.which_conv = functools.partial(
                layers.SNConv2d, kernel_size=3, padding=1, **sn_options
            )
            self.which_linear = functools.partial(layers.SNLinear, **sn_options)
        else:
            self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
            self.which_linear = nn.Linear

        # The embedding is deliberately never spectrally normalized: applying
        # SN to G's embedding seems to randomly cripple G.
        self.which_embedding = nn.Embedding
        if self.G_shared:
            bn_linear = functools.partial(self.which_linear, bias=False)
        else:
            bn_linear = self.which_embedding

        # Size of the conditioning vector consumed by the conditional norm.
        if class_cond or instance_cond:
            input_sz_bn = self.z_chunk_size
        else:
            input_sz_bn = self.n_classes
        if class_cond:
            input_sz_bn += self.shared_dim
        if instance_cond:
            input_sz_bn += self.shared_dim_feat
        self.which_bn = functools.partial(
            layers.ccbn,
            which_linear=bn_linear,
            cross_replica=self.cross_replica,
            mybn=self.mybn,
            input_size=input_sz_bn,
            norm_style=self.norm_style,
            eps=self.BN_eps,
        )

        # ---- Modules (creation order is kept stable on purpose) ----
        # Without shared embeddings, self.shared is a pure passthrough.
        if G_shared:
            self.shared = self.which_embedding(n_classes, self.shared_dim)
        else:
            self.shared = layers.identity()
        # Same idea for instance features; the input width is fixed at 2048.
        if G_shared_feat:
            self.shared_feat = self.which_linear(2048, self.shared_dim_feat)
        else:
            self.shared_feat = layers.identity()

        # First linear layer: one latent slot -> flattened bottom feature map
        self.linear = self.which_linear(
            self.dim_z // self.num_slots,
            self.arch["in_channels"][0] * (self.bottom_width ** 2),
        )

        # Collect the stages as nested lists first (outer: resolution stage,
        # inner: layers within that stage), then register them in one go.
        stages = []
        stage_specs = zip(
            self.arch["in_channels"],
            self.arch["out_channels"],
            self.arch["upsample"],
            self.arch["resolution"],
        )
        for width_in, width_out, do_upsample, stage_res in stage_specs:
            upsampler = None
            if do_upsample:
                upsampler = functools.partial(F.interpolate, scale_factor=2)
            stage = [
                layers.GBlock(
                    in_channels=width_in,
                    out_channels=width_out,
                    which_conv=self.which_conv,
                    which_bn=self.which_bn,
                    activation=self.activation,
                    upsample=upsampler,
                )
            ]
            # Attention, when requested for this resolution, goes last.
            if self.arch["attention"][stage_res]:
                print("Adding attention layer in G at resolution %d" % stage_res)
                stage.append(layers.Attention(width_out, self.which_conv))
            stages.append(stage)

        # ModuleLists make sure every layer is registered with the module.
        self.blocks = nn.ModuleList([nn.ModuleList(stage) for stage in stages])

        # Output head: norm -> activation -> conv down to 3 channels.
        final_width = self.arch["out_channels"][-1]
        self.output_layer = nn.Sequential(
            layers.bn(final_width, cross_replica=self.cross_replica, mybn=self.mybn),
            self.activation,
            self.which_conv(final_width, 3),
        )

        # Weight initialization can be skipped, e.g. for testing.
        if not skip_init:
            self.init_weights()

        # ---- Optimizer ----
        # An EMA copy (no_optim) or an externally managed optimizer needs
        # nothing further.
        if no_optim or not embedded_optimizer:
            return
        self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps
        adam_options = dict(
            lr=self.lr,
            betas=(self.B1, self.B2),
            weight_decay=0,
            eps=self.adam_eps,
        )
        if G_mixed_precision:
            print("Using fp16 adam in G...")
            from cifar100.generative_models.BigGAN import utils

            self.optim = utils.Adam16(params=self.parameters(), **adam_options)
        else:
            self.optim = optim.Adam(params=self.parameters(), **adam_options)

    def init_weights(self):
        """Initialize conv/linear/embedding weights according to ``self.init``.

        Recognized styles are ``"ortho"``, ``"N02"`` (normal, std 0.02) and
        ``"glorot"``/``"xavier"``. For any other style a message is printed
        for each matching module and its weights are left untouched. The
        number of parameters in the visited modules is stored in
        ``self.param_count`` and printed.
        """
        self.param_count = 0
        for module in self.modules():
            if not isinstance(module, (nn.Conv2d, nn.Linear, nn.Embedding)):
                continue
            style = self.init
            if style == "ortho":
                init.orthogonal_(module.weight)
            elif style == "N02":
                init.normal_(module.weight, 0, 0.02)
            elif style in ["glorot", "xavier"]:
                init.xavier_uniform_(module.weight)
            else:
                print("Init style not recognized...")
            self.param_count += sum(
                param.data.nelement() for param in module.parameters()
            )
        print("Param count for Gs initialized parameters: %d" % self.param_count)

    def get_condition_embeddings(self, cl=None, feat=None):
        """Embed the available conditionings and join them on the last axis.

        Args:
            cl: Class labels passed through ``self.shared``, or ``None``.
            feat: Instance features passed through ``self.shared_feat``, or
                ``None``.

        Returns:
            The concatenated embeddings, or an empty list when neither
            conditioning is given.
        """
        pieces = []
        if cl is not None:
            pieces.append(self.shared(cl))
        if feat is not None:
            pieces.append(self.shared_feat(feat))
        if not pieces:
            # Nothing to condition on: hand back the empty list itself.
            return pieces
        return torch.cat(pieces, dim=-1)

    def forward(self, z, label=None, feats=None):
        """Generate images from latents and optional conditioning.

        Args:
            z: Latent batch; split along dim 1 into chunks of
                ``self.z_chunk_size`` when ``self.hier`` is set.
            label: Optional class labels.
            feats: Optional instance features.

        Returns:
            ``tanh`` of the output head applied to the last stage's features.
        """
        y = self.get_condition_embeddings(label, feats)
        if self.hier:
            # First chunk drives the linear layer; each remaining chunk is
            # appended to the conditioning of one stage.
            chunks = torch.split(z, self.z_chunk_size, 1)
            z = chunks[0]
            ys = [torch.cat([y, chunk], 1) for chunk in chunks[1:]]
        else:
            # Every stage sees the same conditioning vector.
            ys = [y] * len(self.blocks)

        # Project the latent and fold it into a square feature map.
        h = self.linear(z)
        h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width)

        # A stage may hold several layers (block plus attention).
        for stage_idx, stage in enumerate(self.blocks):
            for layer in stage:
                h = layer(h, ys[stage_idx])

        return torch.tanh(self.output_layer(h))


# Discriminator architecture, same paradigm as G's above
def D_arch(ch=64, attention="64", ksize="333333", dilation="111111"):
    """Build the discriminator layout table, keyed by input image resolution.

    Args:
        ch: Base channel width; channel counts are ``ch`` times a fixed
            per-stage multiplier (the first stage always takes 3 channels).
        attention: Underscore-separated resolutions (e.g. ``"32_64"``) at
            which an attention block is requested.
        ksize: Accepted for signature compatibility; not used here.
        dilation: Accepted for signature compatibility; not used here.

    Returns:
        A dict mapping 256/128/64/32 to a dict with the keys ``in_channels``,
        ``out_channels``, ``downsample``, ``resolution`` and ``attention``
        (resolution -> bool).

    Raises:
        ValueError: If a token of ``attention`` is not an integer literal.
    """
    # Parse the attention spec a single time; membership is all we need.
    attn_resolutions = frozenset(int(token) for token in attention.split("_"))

    def attention_flags(stop_exp):
        # Flags for the resolutions 2**2 .. 2**(stop_exp - 1).
        return {
            2 ** exp: (2 ** exp in attn_resolutions) for exp in range(2, stop_exp)
        }

    # One row per supported input size:
    # (size, input multipliers after the RGB stage, output multipliers,
    #  downsample flags, per-stage resolutions, attention exponent bound)
    rows = (
        (
            256,
            [1, 2, 4, 8, 8, 16],
            [1, 2, 4, 8, 8, 16, 16],
            [True] * 6 + [False],
            [128, 64, 32, 16, 8, 4, 4],
            8,
        ),
        (
            128,
            [1, 2, 4, 8, 16],
            [1, 2, 4, 8, 16, 16],
            [True] * 5 + [False],
            [64, 32, 16, 8, 4, 4],
            8,
        ),
        (
            64,
            [1, 2, 4, 8],
            [1, 2, 4, 8, 16],
            [True] * 4 + [False],
            [32, 16, 8, 4, 4],
            7,
        ),
        (
            32,
            [4, 4, 4],
            [4, 4, 4, 4],
            [True, True, False, False],
            [16, 16, 16, 16],
            6,
        ),
    )

    arch = {}
    for size, in_mults, out_mults, downsample, resolutions, stop_exp in rows:
        arch[size] = {
            "in_channels": [3] + [ch * mult for mult in in_mults],
            "out_channels": [mult * ch for mult in out_mults],
            "downsample": downsample,
            "resolution": resolutions,
            "attention": attention_flags(stop_exp),
        }
    return arch


class Discriminator(nn.Module):
    """Spectrally-normalized projection discriminator.

    The input runs through one stage per entry of the ``D_arch`` table for the
    requested resolution (a ``layers.DBlock`` optionally followed by
    ``layers.Attention``). The result is activated, sum-pooled over the
    spatial axes and scored by a linear layer. When class labels and/or
    instance features are supplied to ``forward``, a projection term (inner
    product between the pooled features and the embedded conditioning) is
    added to the score.

    Only ``D_param="SN"`` is implemented; any other value raises
    ``ValueError`` at construction time.

    Args:
        D_ch: Channel width multiplier.
        D_wide: Forwarded to ``layers.DBlock`` as ``wide``.
        resolution: Key into the ``D_arch`` table (256, 128, 64 or 32).
        D_kernel_size: Stored as ``self.kernel_size``; the convolutions built
            here always use ``kernel_size=3``.
        D_attn: Attention spec such as ``"32_64"``.
        n_classes: Number of classes for the class embedding.
        num_D_SVs, num_D_SV_itrs, SN_eps: Forwarded to the ``layers.SN*``
            constructors.
        D_activation: Nonlinearity used by the blocks and before pooling.
        D_lr, D_B1, D_B2, adam_eps: Hyper-parameters of the embedded Adam.
        output_dim: Output width of the final linear layer.
        D_mixed_precision: Use ``utils.Adam16`` instead of ``optim.Adam``.
        D_fp16: Stored as ``self.fp16``.
        D_init: Weight initialization style, see ``init_weights``.
        skip_init: Skip weight initialization.
        D_param: Parameterization style; must be ``"SN"``.
        class_cond: Create the class-embedding projection head.
        embedded_optimizer: Create ``self.optim``.
        instance_cond: Create the instance-feature projection head.
        instance_sz: Input width of the instance-feature head.
        **kwargs: Ignored.

    Raises:
        ValueError: If ``D_param`` is not ``"SN"``.
    """

    def __init__(
            self,
            D_ch=64,
            D_wide=True,
            resolution=128,
            D_kernel_size=3,
            D_attn="64",
            n_classes=1000,
            num_D_SVs=1,
            num_D_SV_itrs=1,
            D_activation=nn.ReLU(inplace=False),
            D_lr=2e-4,
            D_B1=0.0,
            D_B2=0.999,
            adam_eps=1e-8,
            SN_eps=1e-12,
            output_dim=1,
            D_mixed_precision=False,
            D_fp16=False,
            D_init="ortho",
            skip_init=False,
            D_param="SN",
            class_cond=True,
            embedded_optimizer=True,
            instance_cond=False,
            instance_sz=2048,
            **kwargs
    ):
        super(Discriminator, self).__init__()
        # Fail fast: the layer factories below exist only for the SN
        # parameterization. Without this check an unsupported value would
        # surface later as an opaque missing-attribute error.
        if D_param != "SN":
            raise ValueError(
                "Unsupported D_param %r: the discriminator only implements "
                "the spectrally-normalized ('SN') parameterization" % (D_param,)
            )
        # Width multiplier
        self.ch = D_ch
        # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
        self.D_wide = D_wide
        # Resolution
        self.resolution = resolution
        # Kernel size (recorded only; the convs below are 3x3)
        self.kernel_size = D_kernel_size
        # Attention spec
        self.attention = D_attn
        # Number of classes
        self.n_classes = n_classes
        # Activation
        self.activation = D_activation
        # Initialization style
        self.init = D_init
        # Parameterization style
        self.D_param = D_param
        # Epsilon for spectral norm
        self.SN_eps = SN_eps
        # fp16?
        self.fp16 = D_fp16
        # Architecture
        self.arch = D_arch(self.ch, self.attention)[resolution]

        # Spectrally-normalized conv / linear / embedding factories sharing
        # the same power-iteration settings.
        sn_options = dict(num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, eps=self.SN_eps)
        self.which_conv = functools.partial(
            layers.SNConv2d, kernel_size=3, padding=1, **sn_options
        )
        self.which_linear = functools.partial(layers.SNLinear, **sn_options)
        self.which_embedding = functools.partial(layers.SNEmbedding, **sn_options)

        # Prepare model.
        # The stages are collected as nested lists (outer: resolution stage,
        # inner: layers within that stage) and registered afterwards.
        stages = []
        stage_specs = zip(
            self.arch["in_channels"],
            self.arch["out_channels"],
            self.arch["downsample"],
            self.arch["resolution"],
        )
        for index, (width_in, width_out, do_down, stage_res) in enumerate(stage_specs):
            stage = [
                layers.DBlock(
                    in_channels=width_in,
                    out_channels=width_out,
                    which_conv=self.which_conv,
                    wide=self.D_wide,
                    activation=self.activation,
                    # The very first block does not pre-activate its input.
                    preactivation=(index > 0),
                    downsample=(nn.AvgPool2d(2) if do_down else None),
                )
            ]
            # If attention is requested at this resolution, attach it last.
            if self.arch["attention"][stage_res]:
                print("Adding attention layer in D at resolution %d" % stage_res)
                stage.append(layers.Attention(width_out, self.which_conv))
            stages.append(stage)
        # Turn the stages into ModuleLists so that everything is registered.
        self.blocks = nn.ModuleList([nn.ModuleList(stage) for stage in stages])

        final_width = self.arch["out_channels"][-1]
        # Linear output layer. The output dimension is typically 1, but may be
        # larger if we're e.g. turning this into a VAE with an inference output
        self.linear = self.which_linear(final_width, output_dim)
        # Heads for projection discrimination. With both conditionings each
        # head gets half of the feature width, because forward concatenates
        # their outputs before taking the inner product.
        if class_cond and instance_cond:
            self.linear_feat = self.which_linear(instance_sz, final_width // 2)
            self.embed = self.which_embedding(self.n_classes, final_width // 2)
        elif class_cond:
            self.embed = self.which_embedding(self.n_classes, final_width)
        elif instance_cond:
            self.linear_feat = self.which_linear(instance_sz, final_width)

        # Initialize weights
        if not skip_init:
            self.init_weights()

        # Set up optimizer
        if embedded_optimizer:
            self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps
            adam_options = dict(
                lr=self.lr,
                betas=(self.B1, self.B2),
                weight_decay=0,
                eps=self.adam_eps,
            )
            if D_mixed_precision:
                print("Using fp16 adam in D...")
                from cifar100.generative_models.BigGAN import utils

                self.optim = utils.Adam16(params=self.parameters(), **adam_options)
            else:
                self.optim = optim.Adam(params=self.parameters(), **adam_options)

    def init_weights(self):
        """Initialize conv/linear/embedding weights according to ``self.init``.

        Recognized styles are ``"ortho"``, ``"N02"`` (normal, std 0.02) and
        ``"glorot"``/``"xavier"``. For any other style a message is printed
        for each matching module and its weights are left untouched. The
        number of parameters in the visited modules is stored in
        ``self.param_count`` and printed.
        """
        self.param_count = 0
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear, nn.Embedding)):
                if self.init == "ortho":
                    init.orthogonal_(module.weight)
                elif self.init == "N02":
                    init.normal_(module.weight, 0, 0.02)
                elif self.init in ["glorot", "xavier"]:
                    init.xavier_uniform_(module.weight)
                else:
                    print("Init style not recognized...")
                self.param_count += sum(
                    p.data.nelement() for p in module.parameters()
                )
        print("Param count for D" "s initialized parameters: %d" % self.param_count)

    def forward(self, x, y=None, feat=None):
        """Score a batch of images, optionally conditioned.

        Args:
            x: Image batch fed to the first block.
            y: Optional class labels; requires the ``embed`` head.
            feat: Optional instance features; requires the ``linear_feat``
                head.

        Returns:
            The unconditional linear score plus, when conditioning is given,
            the projection term summed over the feature axis (kept as a
            trailing dimension of size 1).
        """
        h = x
        for blocklist in self.blocks:
            for block in blocklist:
                h = block(h)
        # Apply global sum pooling as in SN-GAN
        h = torch.sum(self.activation(h), [2, 3])
        # Class-unconditional score
        out = self.linear(h)

        # Pick the conditioning vector to project the pooled features onto.
        if y is not None and feat is not None:
            projection = torch.cat([self.embed(y), self.linear_feat(feat)], dim=-1)
        elif y is not None:
            projection = self.embed(y)
        elif feat is not None:
            projection = self.linear_feat(feat)
        else:
            return out
        # Add the projection of the features onto the conditioning as evidence.
        return out + torch.sum(projection * h, 1, keepdim=True)


# Parallelized G_D to minimize cross-gpu communication
# Without this, Generator outputs would get all-gathered and then rebroadcast.
class G_D(nn.Module):
    """Run the generator and discriminator back to back in one module.

    Keeping both networks inside one module minimizes cross-GPU
    communication: without it, generator outputs would get all-gathered and
    then rebroadcast.

    Args:
        G: Generator, called as ``G(z, gy, feats_g)``.
        D: Discriminator, called as ``D(images, labels, features)``.
        optimizer_G: Optional optimizer kept alongside ``G``.
        optimizer_D: Optional optimizer kept alongside ``D``.
    """

    def __init__(self, G, D, optimizer_G=None, optimizer_D=None):
        super(G_D, self).__init__()
        self.G = G
        self.D = D
        self.optimizer_G = optimizer_G
        self.optimizer_D = optimizer_D

    def forward(
            self,
            z,
            gy,
            feats_g=None,
            x=None,
            dy=None,
            feats=None,
            train_G=False,
            return_G_z=False,
            split_D=False,
            policy=False,
            DA=False,
    ):
        """Generate a fake batch and score it (and optionally a real batch).

        Args:
            z: Latents for the generator.
            gy: Labels for the generated batch (may be ``None``).
            feats_g: Instance features for the generated batch.
            x: Optional real image batch.
            dy: Labels of the real batch; only used together with ``x``.
            feats: Instance features of the real batch; only used together
                with ``x``.
            train_G: Whether gradients are recorded through the generator.
            return_G_z: When no real batch is given, also return the
                (un-augmented) generated images.
            split_D: Run D separately on fake and real data instead of on
                their concatenation along the batch dimension.
            policy: Augmentation policy forwarded to ``DiffAugment``.
            DA: Apply ``DiffAugment`` to everything fed to D.

        Returns:
            ``(D_fake, D_real)`` when ``x`` is given; otherwise ``D_fake``, or
            ``(D_fake, G_z)`` if ``return_G_z`` is set.
        """
        # Gradients flow through G only while G itself is being trained.
        with torch.set_grad_enabled(train_G):
            G_z = self.G(z, gy, feats_g)

        has_real = x is not None

        if split_D:
            # Honor DA here too, so the flag is not silently ignored in
            # split mode; fake and real batches are augmented separately.
            fake_input = DiffAugment(G_z, policy=policy) if DA else G_z
            D_fake = self.D(fake_input, gy, feats_g)
            if has_real:
                real_input = DiffAugment(x, policy=policy) if DA else x
                D_real = self.D(real_input, dy, feats)
                return D_fake, D_real
            return (D_fake, G_z) if return_G_z else D_fake

        # Joint pass: stack fake and real data along the batch dimension for
        # improved efficiency. Real labels and features are appended only
        # when a real batch is present, so the conditioning always matches
        # the images handed to D.
        D_input, D_class, D_feats = G_z, gy, feats_g
        if has_real:
            D_input = torch.cat([G_z, x], 0)
            if dy is not None:
                D_class = torch.cat([gy, dy], 0)
            if feats_g is not None and feats is not None:
                D_feats = torch.cat([feats_g, feats], 0)
        if DA:
            D_input = DiffAugment(D_input, policy=policy)

        D_out = self.D(D_input, D_class, D_feats)
        if has_real:
            # Undo the concatenation: (D_fake, D_real)
            return torch.split(D_out, [G_z.shape[0], x.shape[0]])
        return (D_out, G_z) if return_G_z else D_out


