# -*- coding: utf-8 -*-

import argparse
from pathlib import Path
import os
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.utils.data as data
from PIL import Image
from PIL import ImageFile
from tensorboardX import SummaryWriter
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import Dataset, DataLoader

import net_aesstyler as net
from sampler import InfiniteSamplerWrapper

# Let cuDNN benchmark the available convolution algorithms and reuse the
# fastest one; every image is resized to 256x256 by train_transform(), so the
# input shapes stay constant.
cudnn.benchmark = True
Image.MAX_IMAGE_PIXELS = None  # Disable DecompressionBombError
ImageFile.LOAD_TRUNCATED_IMAGES = True  # Disable OSError: image file is truncated


def train_transform():
    """Return the preprocessing pipeline applied to every loaded image.

    The pipeline first resizes the image to exactly 256x256 pixels and then
    converts it to a tensor with ``transforms.ToTensor``.
    """
    return transforms.Compose([
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
    ])


class FlatFolderDataset(data.Dataset):
    """Dataset over the image files stored directly inside one directory.

    Args:
        root: Directory whose files are the samples. It is not searched
            recursively.
        transform: Callable applied to each RGB ``PIL.Image`` before it is
            returned from ``__getitem__``.

    Attributes are kept as plain fields: ``root``, ``paths`` (file names
    relative to ``root``, in sorted order) and ``transform``.
    """

    def __init__(self, root, transform):
        super().__init__()
        self.root = root
        # os.listdir() yields entries in arbitrary, filesystem-dependent
        # order. Sorting makes the index -> file mapping reproducible, which
        # matters because callers name their outputs by index. Entries that
        # are not regular files (e.g. sub-directories) are skipped because
        # they cannot be opened as images.
        self.paths = sorted(
            entry for entry in os.listdir(self.root)
            if os.path.isfile(os.path.join(self.root, entry))
        )
        self.transform = transform

    def __getitem__(self, index):
        """Load the ``index``-th file, convert it to RGB and transform it."""
        full_path = os.path.join(self.root, self.paths[index])
        # The context manager closes the underlying file as soon as the pixel
        # data has been converted, instead of waiting for garbage collection.
        with Image.open(full_path) as raw:
            img = raw.convert('RGB')
        return self.transform(img)

    def __len__(self):
        """Return the number of image files found in ``root``."""
        return len(self.paths)

    def name(self):
        """Return the fixed display name of this dataset class."""
        return 'FlatFolderDataset'


def adjust_learning_rate(optimizer, iteration_count, lr=None, lr_decay=None):
    """Apply inverse-time learning-rate decay to every parameter group.

    The new rate is ``lr / (1.0 + lr_decay * iteration_count)`` and is written
    to the ``'lr'`` entry of each group in ``optimizer.param_groups``.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts.
        iteration_count: Number of iterations completed so far.
        lr: Base learning rate. When ``None`` the module-level ``args.lr`` is
            used, which is the original behaviour.
        lr_decay: Decay factor. When ``None`` the module-level
            ``args.lr_decay`` is used, which is the original behaviour.

    Raises:
        AttributeError: If a value is left as ``None`` and ``args`` does not
            carry the matching attribute. The argument parser visible in this
            file defines neither ``--lr`` nor ``--lr_decay``, so pass the
            values explicitly when calling from this script.
    """
    base_lr = args.lr if lr is None else lr
    decay = args.lr_decay if lr_decay is None else lr_decay
    new_lr = base_lr / (1.0 + decay * iteration_count)
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr


# Command-line interface. The statements in this file read content_dir,
# style_dir, vgg, decoder, transform, net_adaattn_3, output and save_ext
# directly; the whole parsed namespace is also passed to net.Net, so the
# remaining options may be consumed there (not visible from this file).
parser = argparse.ArgumentParser()
# Basic options
parser.add_argument('--content', type=str,
                    help='File path to the content image')
parser.add_argument('--content_dir', type=str, default='inputs/contents',
                    help='Directory path to a batch of content images')
parser.add_argument('--style', type=str,
                    help='File path to the style image, or multiple style \
                    images separated by commas if you want to do \
                    style interpolation')
parser.add_argument('--style_dir', type=str, default='inputs/styles',
                    help='Directory path to a batch of style images')

# Models: checkpoint paths for the individual sub-networks.
# NOTE(review): --discriminator is accepted, but the statement that would load
# it is commented out further down in this file.
parser.add_argument('--vgg', type=str, default='checkpoints/adaattn/vgg_normalised.pth')
parser.add_argument('--decoder', type=str, default='exp/decoder_iter_25000.pth')
parser.add_argument('--transform', type=str, default='exp/transformer_iter_25000.pth')
parser.add_argument('--discriminator', type=str, default='exp/discriminator_iter_25000.pth')
parser.add_argument('--net_adaattn_3', type=str, default='exp/net_adaattn_3_iter_25000.pth')
# Additional options
parser.add_argument('--content_size', type=int, default=256,
                    help='New (minimum) size for the content image, \
                    keeping the original size if set to 0')
parser.add_argument('--style_size', type=int, default=256,
                    help='New (minimum) size for the style image, \
                    keeping the original size if set to 0')
parser.add_argument('--crop', action='store_true',
                    help='do center crop to create squared image')
parser.add_argument('--save_ext', default='.jpg',
                    help='The extension name of the output image')
parser.add_argument('--output', type=str, default='outputs',
                    help='Directory to save the output image(s)')

# Advanced options
parser.add_argument('--preserve_color', action='store_true',
                    help='If specified, preserve color of the content image')
parser.add_argument('--alpha', type=float, default=1.0,
                    help='The weight that controls the degree of \
                             stylization. Should be between 0 and 1')
parser.add_argument(
    '--style_interpolation_weights', type=str, default='',
    help='The weight for blending the style of multiple style images')
# NOTE: --skip_connection_3 and --shallow_layer are overwritten with True
# right after parsing, so passing or omitting them has no effect.
parser.add_argument('--skip_connection_3', action='store_true',
                    help='if specified, add skip connection on ReLU-3')
parser.add_argument('--shallow_layer', action='store_true',
                    help='if specified, also use features of shallow layers')
parser.add_argument('--lambda_content', type=float, default=0., help='weight for L2 content loss')
parser.add_argument('--lambda_global', type=float, default=0., help='weight for L2 style loss')
parser.add_argument('--lambda_local', type=float, default=0.,help='weight for attention weighted style loss')

# Parsed at import time; adjust_learning_rate() reads this module-level name.
args = parser.parse_args()

def _load_weights(module, checkpoint_path):
    """Load the state dict stored at ``checkpoint_path`` into ``module``.

    The checkpoint is mapped onto the CPU while loading so that weights saved
    from a GPU session can still be restored on a CPU-only machine; the
    assembled network is moved to the target device afterwards.
    """
    # NOTE(review): torch.load unpickles the file, so the checkpoint options
    # should only ever point at files from a trusted source.
    module.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))


def _build_network(device):
    """Assemble the stylisation network, in eval mode, on ``device``."""
    # Both switches are forced on regardless of the command line.
    args.skip_connection_3 = True
    args.shallow_layer = True

    # Key width handed to the Transformer; presumably the summed channel
    # counts of the feature maps it attends over -- confirm in net_aesstyler.
    channels = 512 + 256 + 128 + 64
    max_sample = 64 * 64

    # Construction order is kept as-is: modules that are not restored from a
    # checkpoint may depend on the RNG state at construction time.
    net_adaattn_3 = net.AdaAttN_ori(in_planes=256, key_planes=256 + 128 + 64,
                                    max_sample=max_sample)
    transform = net.Transformer(
        in_planes=512, key_planes=channels, shallow_layer=True)
    decoder = net.Decoder(True)
    vgg = net.vgg
    discriminator = net.AesDiscriminator_new()
    disc_ad = net.Discriminator()

    _load_weights(decoder, args.decoder)
    _load_weights(transform, args.transform)
    _load_weights(vgg, args.vgg)
    _load_weights(net_adaattn_3, args.net_adaattn_3)
    # No checkpoint is loaded into `discriminator` or `disc_ad`; they keep
    # whatever state their constructors produce.
    # NOTE(review): args.discriminator is therefore unused -- confirm that
    # this is intended.

    # Keep only the first 44 child modules of the VGG definition.
    vgg = nn.Sequential(*list(vgg.children())[:44])
    network = net.Net(vgg, decoder, discriminator, disc_ad, transform,
                      net_adaattn_3, args)
    network.eval()
    network.to(device)

    # Explicit eval() on the wrapped model, kept from the original script in
    # case network.eval() does not reach it.
    discriminator.model.eval()
    return network


def main():
    """Stylise every content image and write the results to ``args.output``.

    Content image ``k`` (in dataset order) is paired with style image ``k``.
    When there are fewer style images than content images the style loader is
    restarted, so styles are reused cyclically instead of the run aborting
    with ``StopIteration``. Each result is saved as
    ``<k>_stylized_<k><save_ext>``.

    Raises:
        StopIteration: If content images exist but the style directory
            contains no images at all.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    network = _build_network(device)

    content_dataset = FlatFolderDataset(args.content_dir, train_transform())
    style_dataset = FlatFolderDataset(args.style_dir, train_transform())

    content_dataloader = DataLoader(content_dataset, batch_size=1, num_workers=8, shuffle=False)
    style_dataloader = DataLoader(style_dataset, batch_size=1, num_workers=8, shuffle=False)
    style_iter = iter(style_dataloader)

    output_dir = args.output
    os.makedirs(output_dir, exist_ok=True)

    saved = 0  # running output index; stays correct even if batch sizes vary
    for content_images in content_dataloader:
        try:
            style_images = next(style_iter)
        except StopIteration:
            # Fewer styles than contents: start over from the first style.
            style_iter = iter(style_dataloader)
            style_images = next(style_iter)

        style_images = style_images.to(device)
        content_images = content_images.to(device)

        # Only the stylised image is needed here; the remaining return values
        # (losses and score) are ignored.
        # NOTE(review): wrapping this call in torch.no_grad() would save
        # memory, provided the forward pass never differentiates internally.
        output, *_ = network(content_images, style_images, aesthetic=True)

        # Tensor.clamp is out-of-place, so the result must be assigned.
        # save_image maps [0, 1] to [0, 255] and saturates everything
        # outside, so clamping to [0, 1] does not change the written files.
        output = output.detach().clamp(0, 1).cpu()

        batch_size = content_images.shape[0]
        for i in range(batch_size):
            index = saved + i
            file_name = f'{index}_stylized_{index}{args.save_ext}'
            save_image(output[i], os.path.join(output_dir, file_name))
        saved += batch_size


# The guard keeps DataLoader worker processes, which re-import this module on
# platforms that spawn instead of fork, from re-running the whole script.
if __name__ == '__main__':
    main()
