# -*- coding: utf-8 -*-
import sys

sys.path.append('./')
sys.path.append('../')

import torch
import torch.nn as nn
from function import calc_mean_std, mean_variance_norm
from torchsummary import summary
from torchvision import transforms
import torch.nn.init as init
import torch.autograd as autograd
from torch.autograd import Variable

# True when PyTorch can see a CUDA device; selects the matching legacy float tensor type.
cuda = torch.cuda.is_available()
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor


# decoder = nn.Sequential(
#     nn.ReflectionPad2d((1, 1, 1, 1)),
#     nn.Conv2d(512, 256, (3, 3)),
#     nn.ReLU(),
#     nn.Upsample(scale_factor=2, mode='nearest'),
#     nn.ReflectionPad2d((1, 1, 1, 1)),
#     nn.Conv2d(256, 256, (3, 3)),
#     nn.ReLU(),
#     nn.ReflectionPad2d((1, 1, 1, 1)),
#     nn.Conv2d(256, 256, (3, 3)),
#     nn.ReLU(),
#     nn.ReflectionPad2d((1, 1, 1, 1)),
#     nn.Conv2d(256, 256, (3, 3)),
#     nn.ReLU(),
#     nn.ReflectionPad2d((1, 1, 1, 1)),
#     nn.Conv2d(256, 128, (3, 3)),
#     nn.ReLU(),
#     nn.Upsample(scale_factor=2, mode='nearest'),
#     nn.ReflectionPad2d((1, 1, 1, 1)),
#     nn.Conv2d(128, 128, (3, 3)),
#     nn.ReLU(),
#     nn.ReflectionPad2d((1, 1, 1, 1)),
#     nn.Conv2d(128, 64, (3, 3)),
#     nn.ReLU(),
#     nn.Upsample(scale_factor=2, mode='nearest'),
#     nn.ReflectionPad2d((1, 1, 1, 1)),
#     nn.Conv2d(64, 64, (3, 3)),
#     nn.ReLU(),
#     nn.ReflectionPad2d((1, 1, 1, 1)),
#     nn.Conv2d(64, 3, (3, 3)),
# )


class Decoder(nn.Module):
    """Decode 512-channel features back to a 3-channel image.

    ``decoder_layer_1`` maps 512 -> 256 channels and doubles the resolution;
    ``decoder_layer_2`` upsamples twice more, so the output is 8x the input's
    spatial size. With ``skip_connection_3=True`` the second stage expects an
    extra 256-channel tensor concatenated along the channel axis.
    """

    def __init__(self, skip_connection_3=False):
        super(Decoder, self).__init__()

        def pad_conv_relu(in_ch, out_ch):
            # Reflection padding by 1 keeps the 3x3 convolution size-preserving.
            return [nn.ReflectionPad2d((1, 1, 1, 1)), nn.Conv2d(in_ch, out_ch, (3, 3)), nn.ReLU()]

        def upsample():
            return nn.Upsample(scale_factor=2, mode='nearest')

        # Channels entering stage two double when the relu3_1 skip tensor is concatenated.
        stage_two_in = 256 + 256 if skip_connection_3 else 256

        # Layer order (and therefore Sequential indices / state_dict keys) is fixed.
        self.decoder_layer_1 = nn.Sequential(*pad_conv_relu(512, 256), upsample())
        self.decoder_layer_2 = nn.Sequential(
            *pad_conv_relu(stage_two_in, 256),
            *pad_conv_relu(256, 256),
            *pad_conv_relu(256, 256),
            *pad_conv_relu(256, 128),
            upsample(),
            *pad_conv_relu(128, 128),
            *pad_conv_relu(128, 64),
            upsample(),
            *pad_conv_relu(64, 64),
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(64, 3, (3, 3)),
        )

    def forward(self, cs, c_adain_3_feat=None):
        """Decode ``cs``; if given, ``c_adain_3_feat`` is channel-concatenated after stage one."""
        hidden = self.decoder_layer_1(cs)
        if c_adain_3_feat is not None:
            hidden = torch.cat((hidden, c_adain_3_feat), dim=1)
        return self.decoder_layer_2(hidden)


# VGG-19-style convolutional feature extractor (16 3x3 convs, reflection padding,
# ceil-mode 2x2 max-pooling, no fully-connected layers). No weights are set here;
# they are presumably loaded from a checkpoint elsewhere — TODO confirm.
#
# Net slices the children of its `encoder` argument (presumably this module) by
# POSITION, so the layer order below must not change:
#   [0:4]   input   -> relu1-1      [4:11]  relu1-1 -> relu2-1
#   [11:18] relu2-1 -> relu3-1      [18:31] relu3-1 -> relu4-1
#   [31:44] relu4-1 -> relu5-1
# The leading 1x1 Conv2d(3, 3) looks like a learned input normalisation
# (as in the AdaIN "vgg_normalised" model) — NOTE(review): confirm against the weights.
vgg = nn.Sequential(
    nn.Conv2d(3, 3, (1, 1)),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(3, 64, (3, 3)),
    nn.ReLU(),  # relu1-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),  # relu1-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 128, (3, 3)),
    nn.ReLU(),  # relu2-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),  # relu2-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 256, (3, 3)),
    nn.ReLU(),  # relu3-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 512, (3, 3)),
    nn.ReLU(),  # relu4-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-1, this is the last layer used
    # Layers after index 43 are not included in any of Net's enc_* slices.
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU()  # relu5-4
)


# Aesthetic discriminator
class AesDiscriminator(nn.Module):
    """Three-scale patch discriminator.

    The same conv-stack architecture (separate weights per scale) is applied to
    the input at full, 1/2 and 1/4 resolution. Each stack yields a 512-channel
    feature map (1/16 of its input's size) and a 1-channel score map.
    """

    def __init__(self, in_channels=3):
        super(AesDiscriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, normalize=True):
            """Return [4x4 stride-2 conv, optional InstanceNorm2d, LeakyReLU(0.2)]."""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=False))
            return layers

        # One feature stack + one score head per scale (constructed interleaved,
        # as before, so random initialisation order is unchanged).
        self.models = nn.ModuleList()
        self.score_models = nn.ModuleList()
        for _ in range(3):
            self.models.append(
                nn.Sequential(
                    *discriminator_block(in_channels, 64, normalize=False),
                    *discriminator_block(64, 128),
                    *discriminator_block(128, 256),
                    *discriminator_block(256, 512)
                )
            )
            self.score_models.append(
                nn.Sequential(
                    nn.Conv2d(512, 1, 3, padding=1)
                )
            )

        # 3x3 average pool, stride 2: halves the resolution between scales.
        # The kernel size must not depend on in_channels (it previously did,
        # which only worked because the default happened to be 3).
        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def compute_loss(self, x, gt):
        """Return the sum over scales of the MSE between each score map and scalar ``gt``."""
        _, outputs = self.forward(x)
        return sum(torch.mean((out - gt) ** 2) for out in outputs)

    def forward(self, x):
        """Run all scales on ``x``.

        Returns:
            feat: finest-scale features plus the coarser ones nearest-upsampled
                to its size (e.g. 512x16x16 for a 256x256 input).
            outputs: list of the three score maps, finest first.
        """
        feats = []
        outputs = []
        for model, score_model in zip(self.models, self.score_models):
            scale_feat = model(x)  # computed once and shared with the score head
            feats.append(scale_feat)
            outputs.append(score_model(scale_feat))
            x = self.downsample(x)

        target_size = (feats[0].size(2), feats[0].size(3))
        feat = feats[0]
        for coarse in feats[1:]:
            # Out-of-place add: feats[0] is an input of score_models[0], so an
            # in-place update would break its backward pass.
            feat = feat + nn.functional.interpolate(coarse, size=target_size, mode='nearest')
        return feat, outputs


class AesDiscriminator_new(nn.Module):
    """Aesthetic critic wrapping a pretrained TANet loaded from a fixed checkpoint.

    ``forward`` resizes the input to 224x224 and returns ``(features, scores)``
    from TANet (TANet itself returns them in the opposite order).
    """

    def __init__(self, in_channels=3):
        super(AesDiscriminator_new, self).__init__()

        self.model = TANet()
        self.model.eval()
        # Load onto the first GPU when available, otherwise fall back to CPU.
        map_location = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.model.load_state_dict(
            torch.load('./checkpoints/tanet/SRCC_513_LCC_531_MSE_016.pth',
                       map_location=map_location))
        self.avg_pool = nn.AdaptiveAvgPool2d(output_size=(14, 14))
        # self.conv = nn.Conv2d(in_channels=1536, out_channels=512, kernel_size=(1, 1))

        # Not used by forward; kept for interface compatibility. Kernel size is 3
        # (previously it was taken from in_channels by mistake).
        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
        # Built once instead of on every forward call.
        self.resize = transforms.Resize((224, 224))

    def compute_loss(self, x, gt):
        """Return ``(loss, score)`` where ``score`` sums the TANet outputs and ``loss = -score``.

        ``gt`` is accepted for interface compatibility with the other
        discriminators but is not used. NOTE(review): ``sum`` iterates over the
        first dimension of the TANet output (presumably the batch) — confirm.
        """
        _, outputs = self.forward(x)
        score = sum(out for out in outputs)
        loss = -score
        return loss, score

    def forward(self, x):
        """Resize ``x`` to 224x224 and return TANet's ``(feat, output)``."""
        x = self.resize(x)
        output, feat = self.model(x)
        return feat, output


# the newly added discriminator to suppress the strange texture
class Discriminator(torch.nn.Module):
    """WGAN-style convolutional critic added to suppress unnatural textures.

    ``main_module`` applies six 4x4 stride-2 convolutions, each followed by
    affine InstanceNorm and LeakyReLU(0.2), halving the spatial size each time.
    ``output`` is an unpadded 4x4 convolution with no sigmoid, so scores are
    unbounded. For a 3x256x256 input the features are 1024x4x4 and each score
    map is 1x1x1.
    """

    def __init__(self):
        super().__init__()
        # Channels: 3 -> 32 -> 64 -> 128 -> 256 -> 512 -> 1024; output dim = 1.
        self.main_module = nn.Sequential(
            # Batch normalization is omitted in the critic: the gradient-penalty
            # (WGAN-GP) objective penalizes the norm of the critic's gradient with
            # respect to each input independently, not the entire batch.
            # There is no good and fast layer-norm implementation here, so
            # per-instance normalization (nn.InstanceNorm2d) is used instead.
            # Image (3x256x256)
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(32, affine=True),
            nn.LeakyReLU(0.2, inplace=True),

            # State (32x128x128)
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(64, affine=True),
            nn.LeakyReLU(0.2, inplace=True),

            # State (64x64x64)
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128, affine=True),
            nn.LeakyReLU(0.2, inplace=True),

            # State (128x32x32)
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(256, affine=True),
            nn.LeakyReLU(0.2, inplace=True),

            # State (256x16x16)
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(512, affine=True),
            nn.LeakyReLU(0.2, inplace=True),

            # State (512x8x8)
            nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(1024, affine=True),
            nn.LeakyReLU(0.2, inplace=True))
        # Output of main module -> State (1024x4x4) for a 256x256 input.

        self.output = nn.Sequential(
            # The critic's output is not a probability, so no sigmoid is applied.
            nn.Conv2d(in_channels=1024, out_channels=1, kernel_size=4, stride=1, padding=0))

        # Kaiming-normal initialisation for every convolution.
        self.apply(self.init_weights)

    def forward(self, x):
        """Return the critic score map for ``x``."""
        x = self.main_module(x)
        return self.output(x)

    def feature_extraction(self, x):
        """Return ``main_module`` features flattened per sample.

        For a 256x256 input this is a vector of 1024 * 4 * 4 = 16384 values.
        Flattening per sample (rather than ``view(-1, 16384)``) keeps the batch
        dimension intact for other input sizes.
        """
        x = self.main_module(x)
        return x.flatten(start_dim=1)

    def init_weights(self, m):
        """Kaiming-normal (fan_in) init for Conv2d weights; other modules are untouched."""
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu')

    def compute_loss(self, x, gt):
        """Return the squared error between each sample's score map and ``gt``, summed over the batch.

        The result has the shape of a single sample's score map (e.g. (1, 1, 1)
        for 256x256 inputs), not a mean.
        """
        outputs = self.forward(x)
        return sum((out - gt) ** 2 for out in outputs)


class AdaAttN_ori(nn.Module):
    """Attentional normalisation (spatial attention variant).

    Each content position attends over style positions (keys ``f(content_key)``
    vs ``g(style_key)``); the attention-weighted mean and standard deviation of
    ``h(style)`` then re-normalise the mean/variance-normalised ``content``.
    When the style map has more than ``max_sample`` positions, a random subset
    of that size is used (seeded with ``seed`` if given).
    """

    def __init__(self, in_planes, max_sample=256 * 256, key_planes=None):
        super(AdaAttN_ori, self).__init__()
        key_dim = in_planes if key_planes is None else key_planes
        self.f = nn.Conv2d(key_dim, key_dim, (1, 1))
        self.g = nn.Conv2d(key_dim, key_dim, (1, 1))
        self.h = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.sm = nn.Softmax(dim=-1)
        self.max_sample = max_sample

    def forward(self, content, style, content_key, style_key, seed=None):
        query = self.f(content_key)
        key = self.g(style_key)
        value = self.h(style)

        batch, _, key_h, key_w = key.size()
        n_style = key_h * key_w
        key = key.view(batch, -1, n_style).contiguous()
        value = value.view(batch, -1, n_style)
        if n_style > self.max_sample:
            if seed is not None:
                torch.manual_seed(seed)
            keep = torch.randperm(n_style).to(content.device)[:self.max_sample]
            key = key[:, :, keep]
            value = value[:, :, keep]
        value = value.transpose(1, 2).contiguous()  # (b, n_s, c)

        batch, _, out_h, out_w = query.size()
        query = query.view(batch, -1, out_h * out_w).permute(0, 2, 1)  # (b, n_c, c_key)

        attn = self.sm(torch.bmm(query, key))  # (b, n_c, n_s)
        mean = torch.bmm(attn, value)  # (b, n_c, c)
        second_moment = torch.bmm(attn, value ** 2)
        std = torch.sqrt(torch.relu(second_moment - mean ** 2))

        # Back to (b, c, h, w).
        mean = mean.view(batch, out_h, out_w, -1).permute(0, 3, 1, 2).contiguous()
        std = std.view(batch, out_h, out_w, -1).permute(0, 3, 1, 2).contiguous()
        return std * mean_variance_norm(content) + mean


class AdaAttN(nn.Module):
    """Three-stage aesthetic-aware attentional normalisation.

    Stage 1 (channel attention): query ``a(global_key)``, key ``b(aesthetic_key)``,
        value ``c(aesthetic_feat)``; the resulting mean/std modulate
        ``global_feat``, followed by the 1x1 projection ``c1``.
    Stage 2 (channel attention): query ``d(aesthetic_key)``, key ``e(style_key)``,
        value ``e1(style)``; the mean/std modulate the stage-1 result.
    Stage 3 (spatial attention, as in ``AdaAttN_ori``): content positions attend
        to style positions and the weighted mean/std of ``h(stage-2 result)``
        re-normalise ``content``.

    Args:
        in_planes: channels of ``content``/``style`` and of the output.
        max_sample: cap on style positions used by stage 3; larger style maps
            are randomly subsampled (seeded with ``seed`` if given).
        key_planes: channels of ``content_key``/``style_key`` (default in_planes).
        planes2: channels of ``aesthetic_key``/``global_key``.
        planes3: channels of ``aesthetic_feat``/``global_feat``.

    The channel-attention stages need query and key to share a spatial size;
    the value's spatial size must match too, since stats are reshaped to the
    query's (h, w).
    """

    def __init__(self, in_planes, max_sample=256 * 256, key_planes=None, planes2=None, planes3=None):
        super(AdaAttN, self).__init__()
        if key_planes is None:
            key_planes = in_planes
        self.f = nn.Conv2d(key_planes, key_planes, (1, 1))
        self.g = nn.Conv2d(key_planes, key_planes, (1, 1))
        self.h = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.sm = nn.Softmax(dim=-1)
        self.max_sample = max_sample

        self.a = nn.Conv2d(planes2, planes3, (1, 1))
        self.b = nn.Conv2d(planes2, planes3, (1, 1))
        self.c = nn.Conv2d(planes3, planes3, (1, 1))
        self.c1 = nn.Conv2d(planes3, in_planes, (1, 1))

        self.d = nn.Conv2d(planes2, in_planes, (1, 1))
        self.e = nn.Conv2d(key_planes, in_planes, (1, 1))
        self.e1 = nn.Conv2d(in_planes, in_planes, (1, 1))

    def _channel_attention_stats(self, query, key, value):
        """Return (mean, std) of ``value`` under channel-to-channel attention.

        The attention matrix is (c_query x c_key), so its size does not grow
        with the number of positions and no subsampling is needed. (The former
        subsampling branch indexed the channel axis and always crashed.)
        """
        b, _, h_k, w_k = key.size()
        key = key.view(b, -1, w_k * h_k).permute(0, 2, 1)  # (b, n, c_key)
        value = value.view(b, -1, w_k * h_k).contiguous()  # (b, c_value, n)
        b, _, h, w = query.size()
        query = query.view(b, -1, w * h).contiguous()  # (b, c_query, n)
        attn = self.sm(torch.bmm(query, key))  # (b, c_query, c_key)
        mean = torch.bmm(attn, value)  # (b, c_query, n)
        std = torch.sqrt(torch.relu(torch.bmm(attn, value ** 2) - mean ** 2))
        return mean.view(b, -1, h, w).contiguous(), std.view(b, -1, h, w).contiguous()

    def _spatial_attention_stats(self, query, key, value, seed, device):
        """Return (mean, std) of ``value`` under position-to-position attention.

        Style positions are randomly subsampled to ``self.max_sample`` when
        there are more of them.
        """
        b, _, h_k, w_k = key.size()
        n_style = w_k * h_k
        key = key.view(b, -1, n_style).contiguous()  # (b, c_key, n_s)
        value = value.view(b, -1, n_style)
        if n_style > self.max_sample:
            if seed is not None:
                torch.manual_seed(seed)
            index = torch.randperm(n_style).to(device)[:self.max_sample]
            key = key[:, :, index]
            value = value[:, :, index]
        value = value.transpose(1, 2).contiguous()  # (b, n_s, c)
        b, _, h, w = query.size()
        query = query.view(b, -1, w * h).permute(0, 2, 1)  # (b, n_c, c_key)
        attn = self.sm(torch.bmm(query, key))  # (b, n_c, n_s)
        mean = torch.bmm(attn, value)  # (b, n_c, c)
        std = torch.sqrt(torch.relu(torch.bmm(attn, value ** 2) - mean ** 2))
        mean = mean.view(b, h, w, -1).permute(0, 3, 1, 2).contiguous()
        std = std.view(b, h, w, -1).permute(0, 3, 1, 2).contiguous()
        return mean, std

    def forward(self, content, style, content_key, style_key, aesthetic_feat, global_feat, aesthetic_key, global_key,
                seed=None):
        # Stage 1: aesthetic features modulate the global (codebook) features.
        mean, std = self._channel_attention_stats(
            self.a(global_key), self.b(aesthetic_key), self.c(aesthetic_feat))
        # global_feat is used un-normalised here (mean_variance_norm deliberately not applied).
        gstyle = self.c1(std * global_feat + mean)

        # Stage 2: style features modulate the stage-1 result.
        mean, std = self._channel_attention_stats(
            self.d(aesthetic_key), self.e(style_key), self.e1(style))
        astyle = std * gstyle + mean

        # Stage 3: spatial attentional normalisation of the content.
        mean, std = self._spatial_attention_stats(
            self.f(content_key), self.g(style_key), self.h(astyle), seed, content.device)
        return std * mean_variance_norm(content) + mean


class Transformer(nn.Module):
    """Fuse AdaAttN results from the relu4_1 and relu5_1 levels.

    The relu5_1 result is upsampled 2x and added to the relu4_1 result, then a
    reflection-padded 3x3 convolution merges them. All aesthetic/global inputs
    are first resized (nearest) to the spatial size of the matching style map.
    """

    def __init__(self, in_planes, key_planes=None, shallow_layer=False):
        super(Transformer, self).__init__()
        self.attn_adain_4_1 = AdaAttN(in_planes=in_planes, key_planes=key_planes, planes2=352, planes3=160)
        self.attn_adain_5_1 = AdaAttN(in_planes=in_planes,
                                      key_planes=key_planes + 512 if shallow_layer else key_planes,
                                      planes2=352 + 320, planes3=320)

        self.upsample5_1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.merge_conv_pad = nn.ReflectionPad2d((1, 1, 1, 1))
        self.merge_conv = nn.Conv2d(in_planes, in_planes, (3, 3))

    @staticmethod
    def _resize_like(x, ref):
        """Nearest-neighbour resize of ``x`` to the spatial size of ``ref``.

        Functional equivalent of the per-call ``nn.Upsample(size=...)`` modules
        previously assigned in forward (which mutated the module tree).
        """
        return nn.functional.interpolate(x, size=(ref.size(2), ref.size(3)), mode='nearest')

    def forward(self, content4_1, style4_1, content5_1, style5_1,
                content4_1_key, style4_1_key, content5_1_key, style5_1_key,
                aesthetic_feat4_1, global_feats4_1, aesthetic_feat5_1, global_feats5_1,
                aesthetic_feat4_1_key, global_feats4_1_key, aesthetic_feat5_1_key, global_feats5_1_key,
                seed=None):
        out4_1 = self.attn_adain_4_1(
            content4_1, style4_1, content4_1_key, style4_1_key,
            self._resize_like(aesthetic_feat4_1, style4_1),
            self._resize_like(global_feats4_1, style4_1),
            self._resize_like(aesthetic_feat4_1_key, style4_1),
            self._resize_like(global_feats4_1_key, style4_1),
            seed=seed)
        out5_1 = self.attn_adain_5_1(
            content5_1, style5_1, content5_1_key, style5_1_key,
            self._resize_like(aesthetic_feat5_1, style5_1),
            self._resize_like(global_feats5_1, style5_1),
            self._resize_like(aesthetic_feat5_1_key, style5_1),
            self._resize_like(global_feats5_1_key, style5_1),
            seed=seed)
        fused = out4_1 + self.upsample5_1(out5_1)
        return self.merge_conv(self.merge_conv_pad(fused))


class Net(nn.Module):
    """Training wrapper: frozen sliced encoder, attention transform, decoder and losses.

    ``forward`` stylizes ``content`` with ``style`` and returns the stylized
    image together with content, local/global style, aesthetic and adversarial
    loss terms.

    Args:
        encoder: module whose children are sliced by position into the
            relu1_1 .. relu5_1 stages (presumably the module-level ``vgg``).
        decoder: called as ``decoder(features, skip_feature_or_None)``.
        discriminator: called as ``discriminator(x) -> (features, outputs)`` and
            exposes ``compute_loss(x, gt) -> (loss, score)``.
        disc_ad: exposes ``compute_loss(x, gt)``; called directly by
            ``calculate_gradient_penalty``.
        Transform: attention module; ``Transformer(in_planes=512)`` is built
            when ``None``.
        net_adaattn_3: optional relu3_1 attention module, used when
            ``args.skip_connection_3`` is true.
        args: options object; reads ``lambda_content``, ``lambda_global``,
            ``lambda_local``, ``shallow_layer`` and ``skip_connection_3``.
    """

    def __init__(self, encoder, decoder, discriminator, disc_ad, Transform=None, net_adaattn_3=None, args=None):
        super(Net, self).__init__()
        enc_layers = list(encoder.children())
        self.enc_1 = nn.Sequential(*enc_layers[:4])  # input -> relu1_1
        self.enc_2 = nn.Sequential(*enc_layers[4:11])  # relu1_1 -> relu2_1
        self.enc_3 = nn.Sequential(*enc_layers[11:18])  # relu2_1 -> relu3_1
        self.enc_4 = nn.Sequential(*enc_layers[18:31])  # relu3_1 -> relu4_1
        self.enc_5 = nn.Sequential(*enc_layers[31:44])  # relu4_1 -> relu5_1

        self.args = args
        self.seed = 6666  # fixed seed for random style-position subsampling
        self.max_sample = 64 * 64  # max style positions used by the local style loss

        if net_adaattn_3 is not None:
            self.net_adaattn_3 = net_adaattn_3

        # Build the default transformer when none is supplied.
        if Transform is None:
            self.transform = Transformer(in_planes=512)
        else:
            self.transform = Transform
        self.decoder = decoder
        self.discriminator = discriminator
        self.disc_ad = disc_ad
        self.cross_entropy_loss = nn.CrossEntropyLoss()

        self.mse_loss = nn.MSELoss()

        # Freeze the encoder.
        for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4', 'enc_5']:
            for param in getattr(self, name).parameters():
                param.requires_grad = False

        # Codebook: `style_dict` rows are compared (cosine similarity) against
        # flattened relu5_1 style features; `image_list` holds the matching
        # entries, presumably image file names used to locate stored features.
        self.style_dict = torch.load('./codebook/style_dict_1k.pt',
                                     map_location='cuda')
        self.image_list = torch.load('./codebook/image_lists_1k.pt',
                                     map_location='cuda')

    def encode_with_intermediate(self, input):
        """Return the [relu1_1, relu2_1, relu3_1, relu4_1, relu5_1] features of ``input``."""
        results = [input]
        for i in range(5):
            func = getattr(self, 'enc_{:d}'.format(i + 1))
            results.append(func(results[-1]))
        return results[1:]

    @staticmethod
    def get_key(feats, last_layer_idx, need_shallow=True):
        """Build attention keys from ``feats`` up to ``last_layer_idx``.

        With ``need_shallow`` (and ``last_layer_idx > 0``) every shallower level
        is resized to the last level's size, mean/variance-normalised and
        concatenated along channels; otherwise only the normalised last level
        is returned.
        """
        if need_shallow and last_layer_idx > 0:
            results = []
            _, _, h, w = feats[last_layer_idx].shape
            for i in range(last_layer_idx):
                results.append(mean_variance_norm(nn.functional.interpolate(feats[i], (h, w))))
            results.append(mean_variance_norm(feats[last_layer_idx]))
            return torch.cat(results, dim=1)
        else:
            return mean_variance_norm(feats[last_layer_idx])

    def compute_content_loss(self, stylized_feats):
        """Set ``self.loss_content``: MSE of normalised features at levels 1..4."""
        self.loss_content = torch.tensor(0., device=stylized_feats[0].device)
        if self.args.lambda_content > 0:
            for i in range(1, 5):
                self.loss_content += self.mse_loss(mean_variance_norm(stylized_feats[i]),
                                                   mean_variance_norm(self.content_feats[i]))

    def compute_style_loss(self, stylized_feats):
        """Set ``self.loss_global`` (mean/std matching) and ``self.loss_local`` (attention-based) at levels 1..4."""
        device = stylized_feats[0].device
        self.loss_global = torch.tensor(0., device=device)
        if self.args.lambda_global > 0:
            for i in range(1, 5):
                s_feats_mean, s_feats_std = calc_mean_std(self.style_feats[i])
                stylized_feats_mean, stylized_feats_std = calc_mean_std(stylized_feats[i])
                self.loss_global += self.mse_loss(
                    stylized_feats_mean, s_feats_mean) + self.mse_loss(stylized_feats_std, s_feats_std)
        self.loss_local = torch.tensor(0., device=device)
        if self.args.lambda_local > 0:
            for i in range(1, 5):
                c_key = self.get_key(self.content_feats, i, self.args.shallow_layer)
                s_key = self.get_key(self.style_feats, i, self.args.shallow_layer)
                s_value = self.style_feats[i]
                b, _, h_s, w_s = s_key.size()
                s_key = s_key.view(b, -1, h_s * w_s).contiguous()
                if h_s * w_s > self.max_sample:
                    # Re-seeding makes the sampled positions identical every call.
                    torch.manual_seed(self.seed)
                    index = torch.randperm(h_s * w_s).to(device)[:self.max_sample]
                    s_key = s_key[:, :, index]
                    style_flat = s_value.view(b, -1, h_s * w_s)[:, :, index].transpose(1, 2).contiguous()
                else:
                    style_flat = s_value.view(b, -1, h_s * w_s).transpose(1, 2).contiguous()
                b, _, h_c, w_c = c_key.size()
                c_key = c_key.view(b, -1, h_c * w_c).permute(0, 2, 1).contiguous()
                attn = torch.bmm(c_key, s_key)
                # attn: b, n_c, n_s
                attn = torch.softmax(attn, dim=-1)
                # mean: b, n_c, c
                mean = torch.bmm(attn, style_flat)
                # std: b, n_c, c
                std = torch.sqrt(torch.relu(torch.bmm(attn, style_flat ** 2) - mean ** 2))
                # mean, std: b, c, h, w
                mean = mean.view(b, h_c, w_c, -1).permute(0, 3, 1, 2).contiguous()
                std = std.view(b, h_c, w_c, -1).permute(0, 3, 1, 2).contiguous()
                self.loss_local += self.mse_loss(stylized_feats[i],
                                                 std * mean_variance_norm(self.content_feats[i]) + mean)

    def compute_losses(self, gt_feats):
        """Compute and store the content and style losses for ``gt_feats``."""
        self.compute_content_loss(gt_feats)
        self.compute_style_loss(gt_feats)

    def _retrieve_global_aes_feats(self, style_relu5_1):
        """Return five stacked per-level feature tensors from the nearest codebook entry per sample.

        Nearness is cosine similarity between the flattened ``style_relu5_1``
        and each row of ``self.style_dict``.
        """
        style_vec = style_relu5_1.flatten(start_dim=1)  # (b, d); must match the codebook width
        style_dict = self.style_dict.unsqueeze(0).expand(style_vec.shape[0], -1, -1)  # (b, K, d)

        dict_norm = torch.norm(style_dict, p=2, dim=2)  # (b, K)
        vec_norm = torch.norm(style_vec, p=2, dim=1).unsqueeze(-1)  # (b, 1)
        cos_sim = torch.bmm(style_dict, style_vec.unsqueeze(-1)).squeeze(-1)  # (b, K)
        cos_sim = cos_sim / (dict_norm * vec_norm + 1e-10)

        _, idx = torch.topk(cos_sim, 1, dim=1, largest=True)
        image_idxs = [self.image_list[i] for i in idx.squeeze(-1).tolist()]

        per_level = [[], [], [], [], []]
        for image_name in image_idxs:
            # Each file must hold exactly five tensors (one per feature level).
            f1, f2, f3, f4, f5 = torch.load(f'./codebook/aes_dict_1k_new/{image_name[:-4]}.pt',
                                            map_location='cuda')
            for bucket, level_feat in zip(per_level, (f1, f2, f3, f4, f5)):
                bucket.append(level_feat)
        return [torch.stack(bucket, dim=0) for bucket in per_level]

    def forward(self, content, style, aesthetic=False):
        """Stylize ``content`` with ``style`` and compute all losses.

        Returns:
            (g_t, loss_content, loss_local, loss_global, loss_gan_g, loss_ad_g,
            loss_ad_d, score)
        """
        self.style_feats = self.encode_with_intermediate(style)
        self.content_feats = self.encode_with_intermediate(content)

        # The decoder always receives this; it stays None unless the relu3_1 skip is used.
        c_adain_feat_3 = None
        if aesthetic:
            global_feats = self._retrieve_global_aes_feats(self.style_feats[4])
            aesthetic_s_feats, _ = self.discriminator(style)
            shallow = self.args.shallow_layer

            if self.args.skip_connection_3:
                c_adain_feat_3 = self.net_adaattn_3(self.content_feats[2], self.style_feats[2],
                                                    self.get_key(self.content_feats, 2, shallow),
                                                    self.get_key(self.style_feats, 2, shallow),
                                                    self.seed)

            stylized = self.transform(self.content_feats[3], self.style_feats[3], self.content_feats[4],
                                      self.style_feats[4],
                                      self.get_key(self.content_feats, 3, shallow),
                                      self.get_key(self.style_feats, 3, shallow),
                                      self.get_key(self.content_feats, 4, shallow),
                                      self.get_key(self.style_feats, 4, shallow),
                                      aesthetic_s_feats[3], global_feats[3], aesthetic_s_feats[4], global_feats[4],
                                      self.get_key(aesthetic_s_feats, 3, shallow),
                                      self.get_key(global_feats, 3, shallow),
                                      self.get_key(aesthetic_s_feats, 4, shallow),
                                      self.get_key(global_feats, 4, shallow),
                                      self.seed)
        else:
            # NOTE(review): Transformer.forward also requires the key/aesthetic
            # arguments; this branch only works with a Transform accepting four.
            stylized = self.transform(self.content_feats[3], self.style_feats[3], self.content_feats[4],
                                      self.style_feats[4])

        g_t = self.decoder(stylized, c_adain_feat_3)
        g_t_feats = self.encode_with_intermediate(g_t)

        self.compute_losses(g_t_feats)

        # Adversarial losses: the critic sees real styles vs detached outputs;
        # the generator term uses the non-detached output.
        loss_ad_d = self.disc_ad.compute_loss(style, 1) + self.disc_ad.compute_loss(g_t.detach(), 0)
        loss_ad_g = self.disc_ad.compute_loss(g_t, 1)

        loss_gan_g, score = self.discriminator.compute_loss(g_t, 1)

        return g_t, self.loss_content, self.loss_local, self.loss_global, loss_gan_g, loss_ad_g, loss_ad_d, score

    def calculate_gradient_penalty(self, real_images, fake_images):
        """Return the WGAN-GP gradient penalty for ``disc_ad``, scaled by 0.001.

        The gradient norm is taken over each whole sample (all channels and
        pixels), as in Gulrajani et al.; previously it was taken over the
        channel dimension only.
        """
        batch = real_images.size(0)
        eta = torch.empty(batch, 1, 1, 1, dtype=torch.float32).uniform_(0, 1)
        eta = eta.to(real_images.device).expand_as(real_images)

        # Fresh leaf tensor so gradients w.r.t. the interpolated images can be taken.
        interpolated = (eta * real_images + (1 - eta) * fake_images).detach().requires_grad_(True)

        prob_interpolated = self.disc_ad(interpolated)

        gradients = autograd.grad(outputs=prob_interpolated, inputs=interpolated,
                                  grad_outputs=torch.ones_like(prob_interpolated),
                                  create_graph=True, retain_graph=True)[0]

        grad_norm = gradients.reshape(batch, -1).norm(2, dim=1)
        return ((grad_norm - 1) ** 2).mean() * 0.001


### TANet ###

import os
import torch
import numpy as np
import math
import torch.optim as optim
# import option
import nni
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from torchvision import models
# from dataset import AVADataset
from util import EDMLoss, AverageMeter
from tensorboardX import SummaryWriter
from tqdm import tqdm
from scipy.stats import pearsonr
from scipy.stats import spearmanr
from sklearn.metrics import accuracy_score
from nni.utils import merge_parameter
from torchsummary import summary

# opt = option.init()
# Module-level device handle for the first CUDA GPU.
device = torch.device("cuda", 0)
# URL of a pretrained MobileNetV2 checkpoint ("UTR" is presumably a typo for URL;
# the name is kept for compatibility).
MOBILE_NET_V2_UTR = 'https://s3-us-west-1.amazonaws.com/models-nima/mobilenetv2.pth.tar'


def adjust_learning_rate(params, optimizer, epoch, step_size=10, gamma=0.1):
    """Apply step decay to every parameter group of ``optimizer``.

    The new rate is ``params.init_lr * gamma ** (epoch // step_size)``; with the
    defaults the rate is divided by 10 every 10 epochs.

    Args:
        params: object exposing ``init_lr``.
        optimizer: object with a ``param_groups`` list of dicts (e.g. a torch optimizer).
        epoch: current epoch (0-based).
        step_size: number of epochs between decays.
        gamma: multiplicative decay factor applied at each step.

    Returns:
        The learning rate that was set.
    """
    lr = params.init_lr * (gamma ** (epoch // step_size))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr


def conv_bn(inp, oup, stride):
    """Return a 3x3 Conv2d -> BatchNorm2d -> ReLU block.

    The convolution uses padding 1 and no bias, so the following BatchNorm
    supplies the shift. The layers sit at indices 0, 1 and 2 of the returned
    ``nn.Sequential``.
    """
    layers = [
        nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)


def conv_1x1_bn(inp, oup):
    """Return a pointwise (1x1, stride 1) Conv2d -> BatchNorm2d -> ReLU block.

    The convolution has no bias; the layers sit at indices 0, 1 and 2 of the
    returned ``nn.Sequential``.
    """
    pointwise = nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False)
    norm = nn.BatchNorm2d(oup)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(pointwise, norm, activation)


class InvertedResidual(nn.Module):
    """MobileNetV2 block: 1x1 expand -> 3x3 depthwise -> 1x1 linear projection.

    Args:
        inp: input channels.
        oup: output channels.
        stride: stride of the depthwise conv; only 1 or 2 is accepted.
        expand_ratio: width multiplier of the hidden (expanded) layer. The
            expansion conv is always built, even when this is 1.

    When the block preserves both spatial size (stride 1) and channel count,
    the input is added back to the output.
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        # The skip connection is only shape-compatible when neither resolution
        # nor channel count changes.
        self.use_res_connect = stride == 1 and inp == oup

        hidden = inp * expand_ratio
        # The layer order fixes the state_dict keys conv.0 ... conv.7; keep it
        # stable so that saved weights still load.
        self.conv = nn.Sequential(
            # pointwise expansion
            nn.Conv2d(hidden // expand_ratio if False else inp, hidden, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU6(inplace=True),
            # depthwise: one 3x3 filter per channel
            nn.Conv2d(hidden, hidden, kernel_size=3, stride=stride, padding=1, groups=hidden, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU6(inplace=True),
            # linear projection (no activation)
            nn.Conv2d(hidden, oup, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(oup),
        )

    def forward(self, x):
        out = self.conv(x)
        if not self.use_res_connect:
            return out
        return x + out


class MobileNetV2(nn.Module):
    """MobileNetV2 classifier: ``features`` -> ``avgpool`` -> ``classifier``.

    Args:
        n_class: number of classifier outputs.
        input_size: square input resolution the model is built for; it must be
            a multiple of 32 and also sets the average-pool kernel
            (``input_size // 32``).
        width_mult: channel multiplier applied to every stage.

    The child order (features, avgpool, classifier) and the layer order
    inside ``features`` determine the state_dict keys. ``MV2`` also drops the
    last child by position. Keep both orders unchanged.
    """

    def __init__(self, n_class=1000, input_size=224, width_mult=1.):
        super(MobileNetV2, self).__init__()
        # Inverted residual stages:
        # t = expansion ratio, c = output channels, n = repeats,
        # s = stride of the first block in the stage.
        self.interverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        # building first layer
        assert input_size % 32 == 0
        input_channel = int(32 * width_mult)
        # The final 1x1 conv is only widened (never narrowed) by width_mult.
        self.last_channel = int(1280 * width_mult) if width_mult > 1.0 else 1280
        self.features = [conv_bn(3, input_channel, 2)]
        # Building inverted residual blocks: only the first block of each
        # stage uses the stage stride.
        for t, c, n, s in self.interverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                if i == 0:
                    self.features.append(InvertedResidual(input_channel, output_channel, s, t))
                else:
                    self.features.append(InvertedResidual(input_channel, output_channel, 1, t))
                input_channel = output_channel
        # building last several layers
        self.features.append(conv_1x1_bn(input_channel, self.last_channel))
        # make it nn.Sequential
        self.features = nn.Sequential(*self.features)

        # Fixed-kernel pool. Five stride-2 layers divide the resolution by 32,
        # so an ``input_size`` input is pooled down to 1x1.
        self.avgpool = nn.AvgPool2d(input_size // 32)

        # building classifier. Dropout(p=0.0) is a no-op, kept so the Linear
        # stays at index 1 in the state_dict keys.
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.0),
            nn.Linear(self.last_channel, n_class),
        )

        self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        # Flatten per sample. ``view(-1, last_channel)`` would silently fold
        # any leftover spatial positions into the batch dimension for inputs
        # larger than ``input_size``. Keeping the batch size lets the Linear
        # layer reject the mismatched feature size instead.
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """Initialise weights in place.

        - Conv2d: normal with std ``sqrt(2 / fan_out)``; any bias is zeroed.
        - BatchNorm2d: weight 1, bias 0.
        - Linear: normal with std 0.01, bias 0.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # fan_out = k_h * k_w * out_channels
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


def resnet365_backbone(model_file='./checkpoints/tanet/resnet18_places365.pth.tar'):
    """Build a torchvision ResNet-18 with 365 outputs and load local weights.

    Args:
        model_file: path to the checkpoint. A dict with a ``'state_dict'``
            entry is expected; a bare state dict is accepted as well. The
            default is the path this function always used.

    Returns:
        The ``resnet18(num_classes=365)`` model with the weights loaded
        (strict key matching).

    Note:
        ``torch.load`` unpickles the file, so only use trusted checkpoints.
    """
    last_model = models.resnet18(num_classes=365)

    checkpoint = torch.load(model_file, map_location=lambda storage, loc: storage)
    state_dict = checkpoint.get('state_dict', checkpoint)
    # Strip only a *leading* 'module.' (the prefix nn.DataParallel adds), not
    # every occurrence inside a key.
    prefix = 'module.'
    state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }
    last_model.load_state_dict(state_dict)
    return last_model


def mobile_net_v2(pretrained=False, path_to_model='./checkpoints/tanet/mobilenetv2.pth.tar'):
    """Build a default ``MobileNetV2`` and optionally load weights from disk.

    Args:
        pretrained: when True, load a state dict from ``path_to_model``
            (strict key matching) and print a notice.
        path_to_model: location of the weights file. The default is the path
            this function always used.

    Returns:
        The ``MobileNetV2`` instance.

    Note:
        ``torch.load`` unpickles the file, so only use trusted checkpoints.
    """
    model = MobileNetV2()

    if pretrained:
        print("read mobilenet weights")
        state_dict = torch.load(path_to_model, map_location=lambda storage, loc: storage)
        model.load_state_dict(state_dict)
    return model


def Attention(x):
    """Return the pairwise cosine-similarity map between spatial positions.

    ``x`` is unpacked as ``(batch, channels, h, w)``, and each of the
    ``h * w`` positions is treated as a ``channels``-dim vector. The result
    has shape ``(batch, h * w, h * w)``, where ``[b, i, j]`` is the cosine
    similarity of positions ``i`` and ``j``. The product of norms is clamped
    at 1e-8, so all-zero positions give 0 rather than NaN.
    """
    b, c, _h, _w = x.size()
    cols = x.view(b, c, -1)        # (b, c, hw): one column per position
    rows = cols.permute(0, 2, 1)   # (b, hw, c): one row per position

    dots = torch.matmul(rows, cols)                      # raw inner products
    row_norms = torch.norm(rows, dim=2, keepdim=True)    # (b, hw, 1)
    col_norms = torch.norm(cols, dim=1, keepdim=True)    # (b, 1, hw)
    denom = torch.matmul(row_norms, col_norms).clamp(min=1e-8)
    return torch.div(dots, denom)


def MV2():
    """Return a ``MobileNetV2`` trunk without its last child (the classifier).

    The children that remain (``features``, ``avgpool``) are re-wrapped in an
    ``nn.Sequential``, so they are indexed 0 and 1.

    NOTE(review): ``mobile_net_v2()`` is called with its default
    ``pretrained=False``, so this trunk starts from the random init in
    ``MobileNetV2._initialize_weights``. Presumably the full model weights are
    loaded later; confirm.
    """
    full_model = mobile_net_v2()
    trunk = list(full_model.children())[:-1]
    return nn.Sequential(*trunk)


class L5(nn.Module):
    """MobileNetV2 branch: 10 head outputs plus five intermediate feature maps.

    ``base_model`` is the ``MV2()`` trunk (index 0 = ``features``, index 1 =
    ``avgpool``). ``head`` maps the flattened 1280-d pooled vector to 10
    unnormalised outputs.
    """

    # Indices inside ``features`` whose outputs are collected. For a 224x224
    # input they are 32x28x28, 64x14x14, 96x14x14, 160x7x7 and 320x7x7.
    _TAP_INDICES = (6, 10, 13, 16, 17)

    def __init__(self):
        super().__init__()
        self.base_model = MV2()

        self.head = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.0),
            nn.Linear(1280, 10),
        )

    def forward(self, x):
        """Return ``(outputs, features)``.

        ``outputs`` has shape ``(N, 10)``. ``features`` holds the tapped maps
        in increasing index order.
        """
        features = []
        for stage_idx, stage in enumerate(self.base_model):
            if stage_idx != 0:
                # Non-feature stages (the pooling layer) run as a whole.
                x = stage(x)
                continue
            # Walk the feature extractor layer by layer so the
            # intermediate maps can be captured.
            for layer_idx, layer in enumerate(stage):
                x = layer(x)
                if layer_idx in self._TAP_INDICES:
                    features.append(x)
        # Requires the pooled map to be 1x1 so that the flattened size is 1280.
        flat = x.view(x.size(0), -1)
        return self.head(flat), features


class L1(nn.Module):
    """Hypernetwork head producing the weight and bias of a target linear layer.

    Two linear maps read the same 365-d input (in TANet, the ResNet-18
    output). For an ``(N, 365)`` input, ``forward`` returns a dict with:

    - ``'res_last_out_w'``: tensor of shape ``(N, 100)``.
    - ``'res_last_out_b'``: tensor of shape ``(N, 1)``.
    """

    def __init__(self):
        super(L1, self).__init__()

        self.last_out_w = nn.Linear(365, 100)
        self.last_out_b = nn.Linear(365, 1)
        # Both layers use PyTorch's default nn.Linear initialisation.

    def forward(self, x):
        return {
            'res_last_out_w': self.last_out_w(x),
            'res_last_out_b': self.last_out_b(x),
        }


# L3
class TargetNet(nn.Module):
    """Target network whose last linear layer is supplied on each call.

    ``fc1 -> bn1 -> relu1`` maps the 365-d input to 100 features. These are
    passed through ``F.linear`` with ``paras['res_last_out_w']`` and
    ``paras['res_last_out_b']``, followed by a softmax over dim 1.

    NOTE(review): the weight in ``paras`` comes from ``L1`` with one row per
    sample (``(N, 100)``). ``F.linear(q, w, b)`` therefore yields an
    ``(N, N)`` matrix in which row i scores sample i against every sample's
    weights, so the output depends on batch composition. This is left
    unchanged because trained weights presumably rely on it; confirm intent.
    """

    def __init__(self):
        super(TargetNet, self).__init__()
        # L2
        self.fc1 = nn.Linear(365, 100)
        # Built on the CPU like the rest of the model. Move the whole network
        # with ``.to(device)`` / ``.cuda()`` instead of pinning submodules to
        # the GPU here.
        self.bn1 = nn.BatchNorm1d(100)
        self.relu1 = nn.PReLU()

        # relu7 and sig are not used in forward. relu7 owns a learnable
        # weight, so it is kept for state_dict / checkpoint compatibility.
        self.relu7 = nn.PReLU()
        self.sig = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x, paras):
        q = self.fc1(x)
        q = self.bn1(q)
        q = self.relu1(q)

        # Apply the hypernetwork-generated layer functionally. Assigning a
        # fresh submodule to ``self`` on every call mutates the module during
        # forward and keeps the last batch's tensors (and graph) referenced.
        q = F.linear(q, paras['res_last_out_w'], paras['res_last_out_b'])
        return self.softmax(q)


class TargetFC(nn.Module):
    """Linear layer whose weight and bias are given at construction time.

    The tensors are stored by plain attribute assignment. Unless they are
    ``nn.Parameter`` objects, they are not registered as parameters and do
    not appear in ``state_dict``.
    """

    def __init__(self, weight, bias):
        super(TargetFC, self).__init__()
        self.weight, self.bias = weight, bias

    def forward(self, input_):
        w, b = self.weight, self.bias
        return F.linear(input_, w, b)


class TANet(nn.Module):
    """Three-branch network fusing 30 branch outputs into one sigmoid score.

    All branches read the same input ``x``, presumably ``(N, 3, H, W)``
    images. TODO: confirm the resolution; ``L5`` appears to need 224x224.

    - RGB branch: ``avg_RGB`` pools to 12x12 and ``Attention`` builds a
      144x144 cosine-similarity map. ``head_rgb`` turns its 20736 values into
      a 10-way softmax.
    - Hypernetwork branch: ``res365_last`` (ResNet-18, 365 outputs) feeds
      ``hypernet`` (``L1``), whose outputs parameterise the last layer of
      ``tygertnet`` (``TargetNet``). That output is average-pooled to 10
      values.
    - MobileNet branch: ``mobileNet`` (``L5``) gives 10 outputs plus a list of
      intermediate feature maps.

    ``head`` maps the concatenated 30 values to a score in (0, 1).
    ``forward`` returns ``(score, feats)``, with ``score`` of shape ``(N, 1)``.

    NOTE(review): ``TargetNet`` mixes samples within a batch (its output is
    ``N x N``), so the hypernetwork branch depends on batch composition. For
    ``N == 1`` its softmax is over a single element, which makes the branch
    constant. Confirm this is intended.
    """

    def __init__(self):
        super().__init__()
        # Registration order fixes the parameters() order (and therefore any
        # saved optimizer state), so keep it unchanged.
        self.res365_last = resnet365_backbone()
        self.hypernet = L1()

        # L3
        self.tygertnet = TargetNet()

        self.avg = nn.AdaptiveAvgPool2d((10, 1))
        self.avg_RGB = nn.AdaptiveAvgPool2d((12, 12))

        self.mobileNet = L5()
        self.softmax = nn.Softmax(dim=1)  # not used in forward

        # L4: RGB-attention head, 144 * 144 = 20736 inputs
        self.head_rgb = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(p=0.0),
            nn.Linear(20736, 10),
            nn.Softmax(dim=1),
        )

        # L6: fusion head over the 10 + 10 + 10 branch outputs
        self.head = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(p=0.0),
            nn.Linear(30, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # RGB branch: similarity map of a 12x12 pooled copy of the input.
        rgb_map = Attention(self.avg_RGB(x))
        rgb_out = self.head_rgb(rgb_map.view(rgb_map.size(0), -1))

        # Hypernetwork branch: ResNet output -> generated layer -> 10 bins.
        res_feat = self.res365_last(x)
        generated = self.hypernet(res_feat)
        hyper_out = self.tygertnet(res_feat, generated)
        hyper_out = self.avg(hyper_out.unsqueeze(dim=2)).squeeze(dim=2)

        # MobileNet branch: 10 outputs plus intermediate feature maps.
        mobile_out, feats = self.mobileNet(x)

        fused = torch.cat([mobile_out, hyper_out, rgb_out], 1)
        return self.head(fused), feats
