# -*- coding: utf-8 -*-

import argparse
from pathlib import Path
import os
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.utils.data as data
from PIL import Image
from PIL import ImageFile
from tensorboardX import SummaryWriter
from torchvision import transforms
from torchvision.utils import save_image
import wandb

import net_aesstyler as net
from sampler import InfiniteSamplerWrapper

# PIL: allow arbitrarily large images (no DecompressionBombError) ...
Image.MAX_IMAGE_PIXELS = None
# ... and decode truncated files instead of raising "OSError: image file is truncated".
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Let cuDNN autotune convolution algorithms for the (fixed-size) training crops.
cudnn.benchmark = True


def train_transform():
    """Build the augmentation pipeline shared by content and style images.

    Every image is resized to 512x512, a random 256x256 patch is cropped
    from it, and the patch is converted to a tensor with ``ToTensor``.
    """
    steps = (
        transforms.Resize(size=(512, 512)),
        transforms.RandomCrop(256),
        transforms.ToTensor(),
    )
    return transforms.Compose(list(steps))


class FlatFolderDataset(data.Dataset):
    """Dataset over the files that sit directly inside a single directory.

    Each item is ``(transform(rgb_image), filename)``: the file is opened with
    PIL, converted to RGB and passed through ``transform``.

    Args:
        root: directory whose (non-recursive) files form the dataset.
        transform: callable applied to each RGB ``PIL.Image``.
    """

    def __init__(self, root, transform):
        super(FlatFolderDataset, self).__init__()
        self.root = root
        # Keep regular files only: a sub-directory would make Image.open fail.
        # Sort because os.listdir returns entries in arbitrary order, which
        # would make the index -> file mapping differ between runs.
        self.paths = sorted(
            name for name in os.listdir(self.root)
            if os.path.isfile(os.path.join(self.root, name))
        )
        self.transform = transform

    def __getitem__(self, index):
        path = self.paths[index]
        # The context manager releases the file handle once the RGB copy exists.
        with Image.open(os.path.join(self.root, path)) as raw:
            img = raw.convert('RGB')
        img = self.transform(img)
        return img, path

    def __len__(self):
        return len(self.paths)

    def name(self):
        return 'FlatFolderDataset'


def adjust_learning_rate(optimizer, iteration_count, base_lr=None, lr_decay=None):
    """Apply inverse-time decay to every parameter group of ``optimizer``.

    The new rate is ``base_lr / (1 + lr_decay * iteration_count)``.

    Args:
        optimizer: object exposing ``param_groups`` (e.g. a ``torch.optim`` optimizer).
        iteration_count: current 0-based training iteration.
        base_lr: initial learning rate; defaults to the script's ``args.lr``.
        lr_decay: decay coefficient; defaults to the script's ``args.lr_decay``.

    Returns:
        The learning rate written into every parameter group.
    """
    # Fall back to the command-line values so existing two-argument calls keep working.
    if base_lr is None:
        base_lr = args.lr
    if lr_decay is None:
        lr_decay = args.lr_decay
    lr = base_lr / (1.0 + lr_decay * iteration_count)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr


parser = argparse.ArgumentParser()
# Basic options
parser.add_argument('--content_dir', type=str, default='PATH_TO_COCO',
                    help='Directory path to a batch of content images')
parser.add_argument('--style_dir', type=str, default='PATH_TO_WIKIART',
                    help='Directory path to a batch of style images')
# Models
parser.add_argument('--vgg', type=str,
                    default='checkpoints/adaattn/vgg_normalised.pth',
                    help='Weights loaded into the VGG encoder')
parser.add_argument('--decoder', type=str,
                    default='checkpoints/adaattn/AdaAttN/latest_net_decoder.pth',
                    help='Weights loaded into the decoder')
parser.add_argument('--transform', type=str,
                    default='checkpoints/adaattn/AdaAttN/latest_net_transformer.pth',
                    help='Weights loaded (non-strictly) into the transformer')
parser.add_argument('--net_adaattn_3', type=str,
                    default='checkpoints/adaattn/AdaAttN/latest_net_adaattn_3.pth',
                    help='Weights for the AdaAttN module used with --skip_connection_3')
parser.add_argument('--sample_path', type=str,
                    default='samples',
                    help='Directory to save the intermediate samples')

# Training options
parser.add_argument('--save_dir', type=str,
                    default='./exp',
                    help='Directory to save the model')
parser.add_argument('--log_dir', type=str,
                    default='./logs',
                    help='Directory to save the log')
parser.add_argument('--lr', type=float, default=1e-4,
                    help='initial learning rate')
parser.add_argument('--lr_decay', type=float, default=5e-5,
                    help='inverse-time decay: lr = lr / (1 + lr_decay * iteration)')
parser.add_argument('--stage1_iter', type=int, default=0,
                    help='iterations trained without the GAN/adversarial generator losses')
parser.add_argument('--stage2_iter', type=int, default=30000,
                    help='iterations trained with the GAN/adversarial generator losses')
parser.add_argument('--batch_size', type=int, default=6)
parser.add_argument('--lambda_content', type=float, default=0.0, help='weight for L2 content loss')
parser.add_argument('--lambda_global', type=float, default=3.5, help='weight for L2 style loss')
parser.add_argument('--lambda_local', type=float, default=0.4, help='weight for attention weighted style loss')
parser.add_argument('--gan_weight', type=float, default=1.2,
                    help='weight for the GAN generator loss')
parser.add_argument('--ad_weight', type=float, default=5.0,
                    help='weight for the adversarial generator and discriminator losses')
parser.add_argument('--n_threads', type=int, default=16,
                    help='number of DataLoader worker processes')
parser.add_argument('--save_model_interval', type=int, default=500,
                    help='iterations between checkpoint saves')

parser.add_argument('--log_interval', type=int, default=50,
                    help='iterations between sample dumps and console logs')
parser.add_argument('--use_wandb', action='store_true', help='whether use wandb')
parser.add_argument('--resume', action='store_true', help='enable it to train the model from checkpoints')
parser.add_argument('--checkpoints', type=str,
                    default='./checkpoints_aesstyler',
                    help='Directory to save the training checkpoints')
# NOTE: the script currently forces the next two flags to True after parsing,
# so passing or omitting them has no effect.
parser.add_argument('--skip_connection_3', action='store_true',
                    help='if specified, add skip connection on ReLU-3')
parser.add_argument('--shallow_layer', action='store_true',
                    help='if specified, also use features of shallow layers')
args = parser.parse_args()
# NOTE(review): the script unconditionally overrides the --skip_connection_3
# and --shallow_layer flags, so both features are always enabled. The override
# happens before wandb.init so the logged run config matches what is actually
# used. Confirm every use of net_adaattn_3 tolerates None before removing it.
args.skip_connection_3 = True
args.shallow_layer = True

if args.use_wandb:
    run = wandb.init(project='AesStyler',
                     name='AesStyler',
                     config=args,
                     reinit=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Key-channel width handed to the Transformer. With shallow layers enabled this
# is presumably the concatenation of 512/256/128/64-channel VGG features —
# confirm against net_aesstyler.
if args.shallow_layer:
    channels = 512 + 256 + 128 + 64
else:
    channels = 512

# Output locations for model snapshots, TensorBoard logs and resumable checkpoints.
save_dir = Path(args.save_dir)
log_dir = Path(args.log_dir)
checkpoints_dir = Path(args.checkpoints)
for directory in (save_dir, log_dir, checkpoints_dir):
    directory.mkdir(exist_ok=True, parents=True)
writer = SummaryWriter(log_dir=str(log_dir))

max_sample = 64 * 64
if args.skip_connection_3:
    # AdaAttN module for the ReLU-3 skip connection (see --skip_connection_3);
    # its key width grows to 256 + 128 + 64 when shallow-layer features are used.
    net_adaattn_3 = net.AdaAttN_ori(in_planes=256, key_planes=256 + 128 + 64 if args.shallow_layer else 256,
                                    max_sample=max_sample)
else:
    net_adaattn_3 = None
transform = net.Transformer(
    in_planes=512, key_planes=channels, shallow_layer=args.shallow_layer)
decoder = net.Decoder(args.skip_connection_3)
vgg = net.vgg
discriminator = net.AesDiscriminator_new()
disc_ad = net.Discriminator()

# Load pretrained weights onto the CPU first. This works whether or not the
# checkpoints were saved from a GPU, and the assembled network is moved to
# `device` right below anyway.
decoder.load_state_dict(torch.load(args.decoder, map_location='cpu'))
# strict=False: checkpoint keys that do not match the Transformer are ignored
# instead of raising.
transform.load_state_dict(torch.load(args.transform, map_location='cpu'), strict=False)
vgg.load_state_dict(torch.load(args.vgg, map_location='cpu'))
if net_adaattn_3 is not None:
    net_adaattn_3.load_state_dict(torch.load(args.net_adaattn_3, map_location='cpu'))

# Keep only the first 44 layers of the VGG encoder.
vgg = nn.Sequential(*list(vgg.children())[:44])
network = net.Net(vgg, decoder, discriminator, disc_ad, transform, net_adaattn_3, args)
network.train()
network.to(device)

# network.train() switched every submodule to training mode; put the
# aesthetic discriminator's inner model back into eval mode.
discriminator.model.eval()

# Content and style images get identical, separately constructed augmentation pipelines.
content_tf, style_tf = train_transform(), train_transform()

content_dataset = FlatFolderDataset(args.content_dir, content_tf)
style_dataset = FlatFolderDataset(args.style_dir, style_tf)


def _batch_stream(dataset):
    """Return an iterator over batches of ``dataset`` sampled by InfiniteSamplerWrapper.

    The training loop pulls batches with ``next()``; InfiniteSamplerWrapper
    presumably yields indices forever so the iterator never runs dry —
    confirm in sampler.py.
    """
    loader = data.DataLoader(dataset,
                             batch_size=args.batch_size,
                             sampler=InfiniteSamplerWrapper(dataset),
                             num_workers=args.n_threads)
    return iter(loader)


content_iter = _batch_stream(content_dataset)
style_iter = _batch_stream(style_dataset)

# Generator-side modules share one Adam optimizer; the adversarial
# discriminator (disc_ad) is optimized separately.
generator_param_groups = [{'params': network.decoder.parameters()},
                          {'params': network.transform.parameters()}]
# net_adaattn_3 is None when the ReLU-3 skip connection is disabled.
if network.net_adaattn_3 is not None:
    generator_param_groups.append({'params': network.net_adaattn_3.parameters()})
optimizer = torch.optim.Adam(generator_param_groups, lr=args.lr)
optimizer_D = torch.optim.Adam(network.disc_ad.parameters(), lr=args.lr)

# Index of the last completed iteration; -1 means training starts from scratch.
start_iter = -1

# Enable it to train the model from checkpoints
if args.resume:
    # map_location lets a checkpoint written on another device be restored here.
    checkpoints = torch.load(checkpoints_dir / 'checkpoints.pth.tar', map_location=device)
    network.load_state_dict(checkpoints['net'])
    optimizer.load_state_dict(checkpoints['optimizer'])
    # Older checkpoints did not store the discriminator optimizer state.
    if 'optimizer_D' in checkpoints:
        optimizer_D.load_state_dict(checkpoints['optimizer_D'])
    # Despite its name, "epoch" holds the last completed iteration index.
    start_iter = checkpoints['epoch']

# Training


def _save_cpu_state_dict(module, filename):
    """Save ``module.state_dict()`` to ``save_dir / filename`` with every tensor moved to CPU.

    Values are replaced in place inside the state dict returned by the module.
    """
    state_dict = module.state_dict()
    for key in state_dict.keys():
        state_dict[key] = state_dict[key].to(torch.device('cpu'))
    torch.save(state_dict, save_dir / filename)


total_iter = args.stage1_iter + args.stage2_iter
output_dir = Path(args.sample_path)
# Create every sample directory once instead of on each iteration.
for sample_dir in (output_dir, output_dir / 'style', output_dir / 'content', output_dir / 'stylized'):
    sample_dir.mkdir(exist_ok=True, parents=True)
disc_params = list(network.disc_ad.parameters())

for i in range(start_iter + 1, total_iter):
    adjust_learning_rate(optimizer, iteration_count=i)
    adjust_learning_rate(optimizer_D, iteration_count=i)
    content_images, _ = next(content_iter)
    content_images = content_images.to(device)
    style_images, _ = next(style_iter)
    style_images = style_images.to(device)

    stylized_results, loss_c, loss_local, loss_global, loss_gan_g, loss_ad_g, loss_ad_d, score = network(
        content_images, style_images, aesthetic=True)

    score = score / args.batch_size

    loss_c = loss_c * args.lambda_content
    loss_local = loss_local * args.lambda_local
    loss_global = loss_global * args.lambda_global
    loss_gan_g = args.gan_weight * loss_gan_g

    # Discriminator gradients.
    optimizer_D.zero_grad()
    loss_ad_d = args.ad_weight * loss_ad_d
    loss_ad_d.backward(retain_graph=False)
    # Stash the discriminator-only gradients and detach them from the params,
    # so the generator backward below cannot accumulate into them.
    # NOTE(review): whether the generator loss reaches disc_ad at all depends on
    # net.Net; if it freezes the discriminator this is a no-op.
    disc_grads = [param.grad for param in disc_params]
    for param in disc_params:
        param.grad = None

    # Stage 1 trains only the content/style losses; stage 2 adds GAN/adversarial terms.
    if i < args.stage1_iter:
        loss = loss_c + loss_local + loss_global
    else:
        loss_ad_g = args.ad_weight * loss_ad_g
        loss = loss_c + loss_local + loss_global + loss_gan_g + loss_ad_g

    # Also clears any generator-side gradients produced by loss_ad_d.backward().
    optimizer.zero_grad()
    loss.backward(retain_graph=False)
    # Put back the discriminator-only gradients, discarding whatever the
    # generator objective wrote into disc_ad.
    for param, grad in zip(disc_params, disc_grads):
        param.grad = grad
    optimizer.step()
    optimizer_D.step()

    writer.add_scalar('loss_content', loss_c.item(), i + 1)
    writer.add_scalar('loss_local', loss_local.item(), i + 1)
    writer.add_scalar('loss_global', loss_global.item(), i + 1)
    writer.add_scalar('loss_gan_g', loss_gan_g.item(), i + 1)

    if args.use_wandb:
        wandb.log(
            {'loss_content': loss_c.item(), 'loss_local': loss_local.item(), 'loss_global': loss_global.item(),
             'loss_gan_g': loss_gan_g.item(),
             'score': score.item(), 'loss_ad_d': loss_ad_d.item(), 'loss_ad_g': loss_ad_g.item()}, step=i + 1)

    # Save intermediate results
    if (i + 1) % args.log_interval == 0:
        visualized_imgs = torch.cat([content_images, style_images, stylized_results])
        output_name = output_dir / 'output{:d}.jpg'.format(i + 1)
        save_image(visualized_imgs, str(output_name), nrow=args.batch_size)
        if i < args.stage1_iter:
            # The stage-1 log previously referenced an undefined `loss_s`.
            print('[%d/%d] loss_content:%.4f, loss_local:%.4f, loss_global:%.4f' % (
                i + 1, total_iter, loss_c.item(), loss_local.item(), loss_global.item()))
        else:
            print(
                '[%d/%d] loss_content:%.4f, loss_local:%.4f, loss_global:%.4f loss_gan_g:%.4f, loss_ad_g:%.4f, loss_ad_d:%.4f' % (
                    i + 1, total_iter, loss_c.item(), loss_local.item(), loss_global.item(),
                    loss_gan_g.item(),
                    loss_ad_g.item(), loss_ad_d.item()))
        # Iterate over the images actually present rather than assuming args.batch_size.
        for idx, (content_img, style_img, stylized_img) in enumerate(
                zip(content_images, style_images, stylized_results)):
            image_name = f'output{i + 1}_{idx}.jpg'
            save_image(stylized_img, str(output_dir / 'stylized' / image_name))
            save_image(content_img, str(output_dir / 'content' / image_name))
            save_image(style_img, str(output_dir / 'style' / image_name))
        if args.use_wandb:
            # Log the output image and metrics to W&B
            wandb.log({"output_image": wandb.Image(str(output_name))}, step=i + 1)

    # Save models
    if (i + 1) % args.save_model_interval == 0 or (i + 1) == total_iter:
        checkpoints = {
            "net": network.state_dict(),
            "optimizer": optimizer.state_dict(),
            # Persist the discriminator optimizer too so a resumed run keeps its Adam state.
            "optimizer_D": optimizer_D.state_dict(),
            # Key name kept for compatibility; it stores the iteration index.
            "epoch": i
        }
        torch.save(checkpoints, checkpoints_dir / 'checkpoints.pth.tar')

        _save_cpu_state_dict(network.decoder, 'decoder_iter_{:d}.pth'.format(i + 1))
        _save_cpu_state_dict(network.transform, 'transformer_iter_{:d}.pth'.format(i + 1))
        _save_cpu_state_dict(network.discriminator, 'discriminator_iter_{:d}.pth'.format(i + 1))
        # net_adaattn_3 is None when the ReLU-3 skip connection is disabled.
        if network.net_adaattn_3 is not None:
            _save_cpu_state_dict(network.net_adaattn_3, 'net_adaattn_3_iter_{:d}.pth'.format(i + 1))
        _save_cpu_state_dict(network.disc_ad, 'disc_ad_iter_{:d}.pth'.format(i + 1))

writer.close()
