import torch
from torch import nn
import argparse
import logging
import os
import numpy as np
import torch.distributed as dist
import yaml
from tqdm import tqdm
from torchvision.utils import save_image
from time import time
from torchsummary import summary

from data import define_dataloader
from utils import denormalize, set_seed
from idm.model import GuideTSR, NLayerDiscriminator, I2IViT
from eval import guidetsr_eval
from models.loss import PerceptualLoss

# Command-line options. The three loss weights below scale the terms of the
# generator objective: l1 * gen_l1 + perceptual * gen_percep + adversarial * l_gan.
parser = argparse.ArgumentParser()
parser.add_argument('--cfg', type=str, default='configs/VIT_gopro_deblur.yaml',
                    help='yaml file for configuration')
parser.add_argument('--gpu_ids', type=str, default=None)
parser.add_argument('--img_size', type=int, default=256, help="the input image size")
parser.add_argument('--l_gan', type=float, default=0.1,
                    help="weight of the generator adversarial (GAN) loss")
parser.add_argument('--gen_l1', type=float, default=10, help="weight of vit l1 loss")
parser.add_argument('--gen_percep', type=float, default=0.5, help="weight of vit perceptual loss")
parser.add_argument('--output', type=str, default="", help="output path for train and resume")

# parse configs
args = parser.parse_args()
# Parse the YAML experiment config; the file handle is closed as soon as it has been read.
# NOTE(review): yaml.Loader can construct arbitrary Python objects from tags in the file,
# so only point --cfg at trusted configs (yaml.safe_load suffices for plain-data configs).
with open(args.cfg, encoding="utf-8") as cfg_file:
    cfg = yaml.load(cfg_file, Loader=yaml.Loader)
# Use the first CUDA device when available, otherwise fall back to CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# set seed------------------------------------------------
set_seed(cfg['seed'])

# output setting -----------------------------------------
# A non-empty --output takes precedence over the out_path given in the YAML config.
if args.output:
    cfg["out_path"] = args.output

# Sample images and checkpoints live in sub-folders of the output root.
img_dir, models_dir = (os.path.join(cfg["out_path"], sub) for sub in ("imgs", "models"))
for sub_dir in (img_dir, models_dir):
    os.makedirs(sub_dir, exist_ok=True)
# --------------------------------------------------------

# logging ------------------------------------------------
# INFO-level records go to <out_path>/train.log and are mirrored to the console
# through an extra handler attached to the root logger.
log_file = os.path.join(cfg["out_path"], "train.log")
logging.basicConfig(filename=log_file, level=logging.INFO)
console = logging.StreamHandler()
console.setLevel(logging.INFO)
logging.getLogger().addHandler(console)
# --------------------------------------------------------

# dataset ------------------------------------------------
train_data, val_data = define_dataloader(cfg["datasets"])

# model --------------------------------------------------
# Generator (I2IViT) and a discriminator for 3-channel images of size --img_size.
model = I2IViT(cfg)
dis = NLayerDiscriminator(3, args.img_size)
logging.info('Initial Model Finished')

if cfg['train']['resume'] > 0:
    # Resume from the "latest" checkpoints written at the end of every epoch.
    # map_location lets checkpoints saved on a GPU be loaded on the current device (incl. CPU).
    model.load_state_dict(
        torch.load(os.path.join(models_dir, "generator_latest.pth"), map_location=device),
        strict=True)
    dis.load_state_dict(
        torch.load(os.path.join(models_dir, "dis_latest.pth"), map_location=device),
        strict=True)
    logging.info('Resuming training from epoch: {}.'.format(cfg['train']['resume']))
model.to(device)
dis.to(device)
# torchsummary defaults to device="cuda"; pass the real device type so this also works on CPU.
# NOTE(review): assumes the generator accepts args.img_size inputs (default 256, as before).
summary(model, [[3, args.img_size, args.img_size]], device=device.type)

# optim --------------------------------------------------
optim_G = torch.optim.Adam(model.parameters(), lr=cfg["optim"]["lr"], betas=(cfg["optim"]["b1"], cfg["optim"]["b2"]))
# The discriminator trains with a 10x smaller learning rate than the generator.
optim_D = torch.optim.Adam(dis.parameters(), lr=cfg["optim"]["lr"]*0.1, betas=(cfg["optim"]["b1"], cfg["optim"]["b2"]))
l1_loss = nn.L1Loss().to(device)
percep = PerceptualLoss().to(device)
# Cosine schedule for the generator LR; the training loop steps it once every lr_freq epochs,
# so T_max counts scheduler steps, not epochs.
schedule = torch.optim.lr_scheduler.CosineAnnealingLR(optim_G, T_max=cfg["train"]["epoch"] // cfg["train"]["lr_freq"])

if cfg["train"]["resume"] > 0:
    # Scheduler state is not checkpointed. Replay the steps an uninterrupted run would
    # already have taken (one per completed block of lr_freq epochs); otherwise a resumed
    # run restarts the cosine schedule at the initial LR.
    # NOTE(review): Adam moment estimates are not checkpointed either and restart from zero.
    import warnings
    with warnings.catch_warnings():
        # Stepping a scheduler before any optimizer.step() emits a harmless ordering warning.
        warnings.simplefilter("ignore", UserWarning)
        for _ in range(cfg["train"]["resume"] // cfg["train"]["lr_freq"]):
            schedule.step()

# Train --------------------------------------------------
# Each batch: one generator step on weighted L1 + perceptual + adversarial loss,
# then one discriminator step on the hinge loss.
for epoch in range(cfg["train"]["resume"], cfg["train"]["epoch"]):
    start = time()
    total_loss = []
    total_gan = []
    total_stg2_l1 = []
    total_stg2_percep = []
    total_loss_d = []
    for item in tqdm(train_data):
        # Move to the selected device (not .cuda()) so the script also runs on CPU.
        lq = item["lq"].to(device)
        gt = item["gt"].to(device)

        # ---- generator step ----
        optim_G.zero_grad()
        out = model(lq)
        stg2_l1 = l1_loss(gt, out)
        stg2_percep = percep(gt, out)
        fake = dis(out)
        # Adversarial term: push the discriminator's score on generated images up.
        l_gan = -fake.mean()
        loss = stg2_l1 * args.gen_l1 + stg2_percep * args.gen_percep + l_gan * args.l_gan
        # This backward also fills dis' .grad; optim_D.zero_grad() below discards it.
        loss.backward()
        optim_G.step()

        # ---- discriminator step (hinge loss) ----
        optim_D.zero_grad()
        real_d = dis(gt)
        fake_d = dis(out.detach())
        loss_d = torch.relu(1 - real_d).mean() + torch.relu(1 + fake_d).mean()
        loss_d.backward()
        optim_D.step()

        # -------------------------------------------
        total_loss.append(loss.item())
        total_gan.append(l_gan.item())
        total_stg2_l1.append(stg2_l1.item())
        total_stg2_percep.append(stg2_percep.item())
        total_loss_d.append(loss_d.item())

    # Fail clearly instead of with a ZeroDivisionError / NameError further down.
    if not total_loss:
        raise RuntimeError("train_data yielded no batches; check the training dataset config")

    # lr schedule: one scheduler step every lr_freq epochs
    if (epoch+1) % cfg["train"]["lr_freq"] == 0:
        schedule.step()
    
    # loss logging and visualization
    logging.info(f"epoch: {epoch+1}, avg_gen: {sum(total_loss) / len(total_loss)}, \
    stg2 l1 loss: {sum(total_stg2_l1)/len(total_stg2_l1)}, stg2 percep loss: {sum(total_stg2_percep)/len(total_stg2_percep)}, \
    loss_gan: {sum(total_gan)/len(total_gan)}, loss_d: {sum(total_loss_d)/len(total_loss_d)}, time: {time()-start}")
    
    # Up to 4 outputs of the last batch, concatenated with their ground truth along the
    # last dimension; detached so the saved grid does not hold the autograd graph.
    img_grid = denormalize(torch.cat([out[:4].detach(), gt[:4]], -1))
    save_image(img_grid, os.path.join(img_dir, f"epoch_{epoch+1}.png"), nrow=1, normalize=False)
    
    # eval (note: keyed on the 0-based epoch, so it also runs after the first epoch)
    if epoch % cfg['train']['val_freq'] == 0:
        model.eval()
        psnr, ssim = guidetsr_eval(model, val_data, device)
        torch.cuda.empty_cache()
        logging.info(f"epoch: {epoch+1}, psnr: {psnr}, ssim: {ssim}")
        model.train()
        
    # save models: overwrite the "latest" checkpoints that resume loads
    torch.save(model.state_dict(), os.path.join(models_dir, "generator_latest.pth"))
    torch.save(dis.state_dict(), os.path.join(models_dir, "dis_latest.pth"))
