from __future__ import print_function
import os
# Expose only the GPU at index 1 in PCI-bus order to this process. This must
# run before `import torch` below so CUDA picks it up; the selected GPU is
# then visible as "cuda:0", which is what main() requests.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
from time import time
from datetime import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.optim as optim
from tensorboardX import SummaryWriter
from torch.nn.utils import clip_grad_norm_
import pytorch_ssim
from GAN import Attention_GAN
from tqdm import tqdm
from torchvision import datasets, utils, transforms
from GAN.dataset_wavelet import wavelet_transform, wavelet_inverse
from GAN.parameter import create_argparser
import numpy as np
import random
import matplotlib.pyplot as plt
# Switch matplotlib to interactive mode.
# NOTE(review): nothing visible in this script plots, so this may be removable.
plt.ion()
from GAN.image_datasets import load_data, load_pair_data

from model.archs.NAFNet_arch import NAFNet


def setup_seed(seed):
    """Seed every random number generator the training run draws from.

    Covers torch (CPU), torch CUDA (all devices), NumPy's global RNG and
    Python's ``random`` module, in that order.

    Args:
        seed: integer seed applied to each generator.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def _tile_subbands(ll, highs):
    """Tile a low-frequency band and its high-frequency bands into one 2x2 tensor.

    The result is ``[[ll, highs[:, 0:3]], [highs[:, 3:6], highs[:, 6:9]]]``.
    Each row is concatenated along dim 3, and the two rows along dim 2.

    NOTE(review): assumes 4-D tensors whose high-frequency stack holds three
    3-channel sub-bands on dim 1, and that this tiling is the layout
    ``wavelet_inverse`` expects. Confirm against GAN.dataset_wavelet.
    """
    top_row = torch.cat((ll, highs[:, :3, :, :]), dim=3)
    bottom_row = torch.cat((highs[:, 3:6, :, :], highs[:, 6:9, :, :]), dim=3)
    return torch.cat((top_row, bottom_row), dim=2)


def _two_stage_generate(netG, small_batch, large_batch, image_size, device):
    """Run the coarse-to-fine wavelet generator pass and return netG's final output.

    Stage 1 inverts (level-2 LL band of ``large_batch`` + level-2 high bands
    of ``small_batch``) to ``image_size // 2``. It then concatenates Gaussian
    noise along dim 1 and runs ``netG``. Stage 2 inverts (stage-1 output +
    level-1 high bands of ``small_batch``) to ``image_size``, concatenates
    fresh noise along dim 1 and runs ``netG`` again.

    NOTE(review): stage 1 takes its LL band from ``large_batch`` (the target),
    and ``small_batch``'s level-2 LL band is never used. This mirrors the
    original code, but it lets the generator see target information. Confirm
    this is intended.
    """
    large_ll1, _ = wavelet_transform(large_batch, device)
    small_ll1, small_h1 = wavelet_transform(small_batch, device)
    large_ll2, _ = wavelet_transform(large_ll1, device)
    _, small_h2 = wavelet_transform(small_ll1, device)

    # Stage 1: half resolution.
    coarse = wavelet_inverse(
        _tile_subbands(large_ll2, small_h2), image_size // 2, device
    )
    coarse = netG(torch.cat((coarse, torch.randn_like(coarse)), dim=1))

    # Stage 2: full resolution, re-using small_batch's level-1 detail bands.
    fine = wavelet_inverse(_tile_subbands(coarse, small_h1), image_size, device)
    return netG(torch.cat((fine, torch.randn_like(fine)), dim=1))


def _generator_loss(loss_weights, dis_loss, mse_loss, ssim_loss, l1_loss):
    """Return the weighted sum of the generator loss terms.

    ``loss_weights`` maps 'dis', 'MSE', 'SSIM' and 'L1' to the multiplier of
    the corresponding term. Every term is multiplied by its weight; the
    original code added the L1 weight instead of multiplying it.
    """
    return (
        loss_weights['dis'] * dis_loss
        + loss_weights['MSE'] * mse_loss
        + loss_weights['SSIM'] * ssim_loss
        + loss_weights['L1'] * l1_loss
    )


def main():
    """Train the NAFNet generator against the Attention_GAN discriminator.

    Each epoch does the following:

    * One pass over the training loader. The generator is stepped every
      iteration; the discriminator every ``Generator_iters`` iterations.
    * One pass over the validation loader under ``torch.no_grad()``.
    * Prints the mean losses of both passes.
    * Saves sample validation images to ``im_save_path``.
    * Writes the ``netG``/``netD`` state dicts to ``model_save_path``.

    Configuration comes from ``GAN.parameter.create_argparser()`` (``lr_G``,
    ``lr_D``, ``max_epoch``, ``image_size``, ``loss_func``, ...) plus the
    hard-coded values below. The data directories must be filled in.
    """
    setup_seed(0)
    args = create_argparser().parse_args()
    args.batch_size = 2
    cont_flag = False  # True: resume from the checkpoint of epoch `retr_epoch_id`
    retr_epoch_id = 23

    total_iters = args.max_epoch
    # The discriminator is updated once every `Generator_iters` generator iterations.
    Generator_iters = 4

    im_save_path = "Validation_images_norm"
    if cont_flag:
        model_save_path = r'checkpoint/20231105-2330Validation_images_norm-lr_G=1e-04-lr_D=1e-05_norm/'
        summary_path = r'summary'
    else:
        datetime_str = datetime.now().strftime("%Y%m%d-%H%M")
        run_name = (
            datetime_str
            + im_save_path
            + "-lr_G={:.0e}-lr_D={:.0e}_norm".format(args.lr_G, args.lr_D)
        )
        model_save_path = "checkpoint/" + run_name
        summary_path = "log/" + run_name
        os.makedirs(model_save_path, exist_ok=True)
        os.makedirs(summary_path, exist_ok=True)
    # Validation images are written on resume too, so always ensure this exists.
    os.makedirs(im_save_path, exist_ok=True)

    print("===> Building model")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    torch.cuda.empty_cache()

    # Build the generator and the discriminator.
    netG = NAFNet(
        in_channel=6,  # 3 image channels + 3 noise channels (see _two_stage_generate)
        out_channel=3,
        width=64,
        enc_blk_nums=[1, 1, 1, 28],
        middle_blk_num=1,
        dec_blk_nums=[1, 1, 1, 1],
    )
    netD = Attention_GAN.Discriminator(
        args, in_channels=args.netD_ch_num, batch_norm=True
    )

    netG.to(device)
    netD.to(device)

    if cont_flag:
        # map_location lets a checkpoint saved on another device load here.
        netG.load_state_dict(torch.load(
            model_save_path + "/netG_epoch_{}.pth".format(retr_epoch_id),
            map_location=device,
        ))
        netD.load_state_dict(torch.load(
            model_save_path + "/netD_epoch_{}.pth".format(retr_epoch_id),
            map_location=device,
        ))

    # NOTE(review): both loaders are consumed with a plain `for` once per epoch.
    # If load_pair_data returns an infinite generator rather than a finite
    # DataLoader, an epoch never ends. Confirm.
    train_data_loader = load_pair_data(
        input_dir="",  # ENTER YOUR TRAINING INPUT IMAGE DIRECTORY HERE
        target_dir="",  # ENTER YOUR TRAINING TARGET IMAGE DIRECTORY HERE
        batch_size=args.batch_size,
        image_size=args.image_size,
        class_cond=args.class_cond,
        deterministic=True,
    )

    valid_data_loader = load_pair_data(
        input_dir="",  # ENTER YOUR VALIDATION INPUT IMAGE DIRECTORY HERE
        target_dir="",  # ENTER YOUR VALIDATION TARGET IMAGE DIRECTORY HERE
        batch_size=args.batch_size,
        image_size=args.image_size,
        class_cond=args.class_cond,
        deterministic=True,
    )

    # Loss functions. reduction="mean" is the non-deprecated spelling of size_average=True.
    mse_loss_fun = nn.MSELoss(reduction="mean").to(device)
    l1_loss_fun = nn.L1Loss(reduction="mean").to(device)
    ssim_loss_fun = pytorch_ssim.SSIM(size_average=True).to(device)

    # Set up the optimizers.
    optimizerG = optim.AdamW(netG.parameters(), lr=args.lr_G, betas=(0.9, 0.999))
    optimizerD = optim.AdamW(netD.parameters(), lr=args.lr_D, betas=(0.9, 0.999))

    print("===> Training Start")

    # Checkpoint `retr_epoch_id` is written after that epoch finishes, so resume
    # at the next epoch instead of retraining (and overwriting) it.
    startEpoch = retr_epoch_id + 1 if cont_flag else 0

    niter = total_iters
    flag = 0  # number of reference target images saved so far

    for epoch in range(startEpoch, niter):
        start_time = time()

        # ---------------- train ----------------
        sum_G = sum_G_l1 = sum_G_mse = sum_G_ssim = 0.0
        sum_D_real = sum_D_fake = 0.0
        steps_G = steps_D = 0
        netG.train()
        netD.train()

        for i, batch in enumerate(tqdm(train_data_loader), 1):
            small_batch = batch[0].to(device)
            target = batch[1].to(device)

            ############################
            # (1) Update G network
            ###########################
            fake = _two_stage_generate(netG, small_batch, target, args.image_size, device)

            optimizerG.zero_grad()
            G_dis_loss = -torch.mean(torch.sigmoid(netD(fake)))
            G_ssim_loss = 1 - ssim_loss_fun((fake + 1.0) / 2.0, (target + 1.0) / 2.0)
            G_mse_loss = mse_loss_fun(fake, target)
            G_l1_loss = l1_loss_fun(fake, target)
            errG = _generator_loss(
                args.loss_func, G_dis_loss, G_mse_loss, G_ssim_loss, G_l1_loss
            )
            errG.backward()
            # Clip after backward() so the freshly computed gradients are the ones clipped.
            clip_grad_norm_(netG.parameters(), 0.5)
            optimizerG.step()

            steps_G += 1
            sum_G += errG.item()
            sum_G_mse += args.loss_func['MSE'] * G_mse_loss.item()
            sum_G_l1 += args.loss_func['L1'] * G_l1_loss.item()
            sum_G_ssim += args.loss_func['SSIM'] * G_ssim_loss.item()

            ############################
            # (2) Update D network every Generator_iters iterations
            ###########################
            if i % Generator_iters == 0:
                # Also discards the gradients G's adversarial loss left in netD.
                optimizerD.zero_grad()
                D_fake_loss = torch.mean(torch.sigmoid(netD(fake.detach())))
                D_real_loss = -torch.mean(torch.sigmoid(netD(target)))
                errD = (D_fake_loss + D_real_loss) * 0.5
                errD.backward()
                clip_grad_norm_(netD.parameters(), 0.5)
                optimizerD.step()

                steps_D += 1
                sum_D_real += D_real_loss.item()
                sum_D_fake += D_fake_loss.item()

        # True means over the epoch; max(..., 1) guards an empty loader.
        run_loss_G = sum_G / max(steps_G, 1)
        run_loss_G_l1 = sum_G_l1 / max(steps_G, 1)
        run_loss_G_mse = sum_G_mse / max(steps_G, 1)
        run_loss_G_ssim = sum_G_ssim / max(steps_G, 1)
        run_loss_D_real = sum_D_real / max(steps_D, 1)
        run_loss_D_fake = sum_D_fake / max(steps_D, 1)

        # ---------------- valid ----------------
        val_sum_G = val_sum_G_l1 = val_sum_G_mse = val_sum_G_ssim = 0.0
        val_sum_D_real = val_sum_D_fake = 0.0
        steps_val = 0
        val_target = val_fake = None
        netG.eval()
        netD.eval()

        with torch.no_grad():
            for batch in tqdm(valid_data_loader):
                val_small = batch[0].to(device)
                val_target = batch[1].to(device)
                val_fake = _two_stage_generate(
                    netG, val_small, val_target, args.image_size, device
                )

                # netD is in eval mode, so one forward per input serves both the
                # G and D losses (assuming no stochastic layers stay active in eval).
                score_fake = torch.mean(torch.sigmoid(netD(val_fake)))
                score_real = torch.mean(torch.sigmoid(netD(val_target)))

                G_dis_loss = -score_fake
                G_l1_loss = l1_loss_fun(val_fake, val_target)
                G_mse_loss = mse_loss_fun(val_fake, val_target)
                G_ssim_loss = 1 - ssim_loss_fun((val_fake + 1.0) / 2.0, (val_target + 1.0) / 2.0)
                errG = _generator_loss(
                    args.loss_func, G_dis_loss, G_mse_loss, G_ssim_loss, G_l1_loss
                )
                D_fake_loss = score_fake
                D_real_loss = -score_real

                steps_val += 1
                val_sum_G += errG.item()
                val_sum_G_mse += args.loss_func['MSE'] * G_mse_loss.item()
                val_sum_G_l1 += args.loss_func['L1'] * G_l1_loss.item()
                val_sum_G_ssim += args.loss_func['SSIM'] * G_ssim_loss.item()
                val_sum_D_real += D_real_loss.item()
                val_sum_D_fake += D_fake_loss.item()

        val_loss_G = val_sum_G / max(steps_val, 1)
        val_loss_G_l1 = val_sum_G_l1 / max(steps_val, 1)
        val_loss_G_mse = val_sum_G_mse / max(steps_val, 1)
        val_loss_G_ssim = val_sum_G_ssim / max(steps_val, 1)
        val_loss_D_real = val_sum_D_real / max(steps_val, 1)
        val_loss_D_fake = val_sum_D_fake / max(steps_val, 1)

        # Save the last validation batch; skip if the validation loader was empty.
        if val_fake is not None:
            if flag <= 1:
                utils.save_image(
                    (val_target + 1) / 2,
                    "%s/inputAF_epoch_%03d.png" % (im_save_path, epoch)
                )
                flag += 1
            utils.save_image(
                (val_fake + 1) / 2,
                "%s/network_epoch_%03d.png" % (im_save_path, epoch)
            )

        # D_real_loss is -mean(sigmoid(D(x))) and D_fake_loss is mean(sigmoid(D(G(z)))),
        # so the columns are: combined D loss, D(x) = -D_real_loss, D(G(z)) = D_fake_loss.
        text = (
            "[%d/%d] Loss_G: %.4f Loss_G_l1: %.4f Loss_G_MSE: %.4f Loss_G_ssim: %.4f | Loss_D: %.4f D(x): %.4f D(G(z)): %.4f \n"
            % (
                epoch,
                niter,
                run_loss_G,
                run_loss_G_l1,
                run_loss_G_mse,
                run_loss_G_ssim,
                (run_loss_D_real + run_loss_D_fake) * 0.5,
                -run_loss_D_real,
                run_loss_D_fake,
            )
        )
        print(text)

        valid_text = (
            "[%d/%d] Valid_Loss_G: %.4f Valid_Loss_G_l1: %.4f Valid_Loss_G_MSE: %.4f Valid_Loss_G_ssim: %.4f | Valid_Loss_D: %.4f Valid_D(x): %.4f Valid_D(G(z)): %.4f\n"
            " Time: %d s"
            % (
                epoch,
                niter,
                val_loss_G,
                val_loss_G_l1,
                val_loss_G_mse,
                val_loss_G_ssim,
                (val_loss_D_real + val_loss_D_fake) * 0.5,
                -val_loss_D_real,
                val_loss_D_fake,
                int(time() - start_time),
            )
        )
        print(valid_text)
        torch.save(
            netG.state_dict(), "%s/netG_epoch_%d.pth" % (model_save_path, epoch)
        )
        torch.save(
            netD.state_dict(), "%s/netD_epoch_%d.pth" % (model_save_path, epoch)
        )


# Script entry point: start training only when this file is executed directly.
if __name__ == "__main__":
    main()
