import argparse
from datetime import datetime
import os
import yaml
from torch import nn
import logging 
import torch
from PIL import Image
import numpy as np
from tqdm import tqdm

from torchvision.utils import save_image
#from torch.utils.data import DataLoader
from torch.autograd import Variable

from data import define_dataloader
from models import get_model
from torchsummary import summary
from utils import denormalize, set_seed
from models.loss import mse

# Command-line interface ------------------------------
# --cfg points at the YAML experiment config; --output (when non-empty)
# overrides the config's output directory; --gpu selects the CUDA device index.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--cfg",
    type=str,
    default="./configs/celeba_inpainting1024.yaml",
    help="the config files",
)
parser.add_argument(
    "--hr_shape",
    type=int,
    default=128,
    help="training image size 128 or 256",
)
parser.add_argument(
    "--output",
    default='',
    help="where to store the output",
)
parser.add_argument(
    "--gpu",
    type=int,
    default=0,
    help="gpu number",
)
opt = parser.parse_args()

# Load the experiment configuration and seed all RNGs from it.
# The `with` block closes the config file as soon as parsing finishes
# (previously the handle returned by open() was never closed).
# NOTE(review): yaml.Loader can construct arbitrary Python objects from YAML
# tags. That is acceptable for a trusted local config passed via --cfg; switch
# to yaml.safe_load if the configs never rely on Python-specific tags.
with open(opt.cfg, encoding="utf-8") as cfg_file:
    cfg = yaml.load(cfg_file, Loader=yaml.Loader)
set_seed(cfg["seed"])

# output setting -----------------------------------
# A non-empty --output on the command line takes precedence over cfg["output"].
if opt.output:
    cfg["output"] = opt.output

# Sample images go to <output>/imgs and checkpoints to <output>/models;
# both directories are created up front (no error if they already exist).
img_dir, models_dir = (os.path.join(cfg["output"], sub) for sub in ("imgs", "models"))
for _out_dir in (img_dir, models_dir):
    os.makedirs(_out_dir, exist_ok=True)
# --------------------------------------------------

# logging ------------------------------------------
# basicConfig attaches a file handler writing INFO+ records to
# <output>/train.log on the root logger; an extra StreamHandler (stderr by
# default) mirrors the same records to the console.
log_path = os.path.join(cfg["output"], "train.log")
logging.basicConfig(filename=log_path, level=logging.INFO)

console = logging.StreamHandler()
console.setLevel(logging.INFO)
root_logger = logging.getLogger()
root_logger.addHandler(console)
# --------------------------------------------------

# models and datasets setting ----------------------
device = torch.device(f"cuda:{opt.gpu}")
# Make the selected GPU the *current* CUDA device. torchsummary's default
# device="cuda" builds its dummy inputs with `.type(torch.cuda.FloatTensor)`,
# which places them on the current device; without this call, --gpu != 0
# leaves the model and the probe tensors on different GPUs and summary() fails.
torch.cuda.set_device(device)

train_data, val_data = define_dataloader(cfg["dataset"])
diffusion = get_model(cfg["diffusion"]).to(device)
if cfg["train"]["resume"] > 0:
    # Only the denoiser weights are restored, always from netG_latest.pth;
    # the resume value sets the starting epoch but is not checked against the file.
    ckpt_path = os.path.join(models_dir, "netG_latest.pth")
    diffusion.denoise_fn.load_state_dict(torch.load(ckpt_path, map_location=device))
    logging.info(f"pretrained model at {cfg['train']['resume']} epoch load success!!")

# Layer-by-layer summary of the denoiser. It is probed with a 6-channel
# hr_shape x hr_shape input plus a 1-element second input (batch dim added by
# torchsummary). The spatial size follows --hr_shape (default 128, as before).
summary(diffusion.denoise_fn, [[6, opt.hr_shape, opt.hr_shape], [1]])
# ----------------------------------------------------

# set train optimization -----------------------------
# Adam over every parameter of the diffusion wrapper; the learning rate and
# (beta1, beta2) are read from the "para" section of the config.
# (The misspelled name `optimzer` is kept because the training loop uses it.)
adam_cfg = cfg["para"]
optimzer = torch.optim.Adam(
    diffusion.parameters(),
    lr=adam_cfg["lr"],
    betas=(adam_cfg["b1"], adam_cfg["b2"]),
)
# Hand the mse loss (models.loss.mse) to the model, then let it set up its
# noise schedule for `device` (presumably schedule tensors on that device).
diffusion.set_loss(mse)
diffusion.set_new_noise_schedule(device)
# ---------------------------------------------------

for epoch in range(cfg["train"]["resume"], cfg["train"]["epoch"]):
    # One pass over the training loader. Per-batch losses are collected so the
    # epoch's mean loss can be logged.
    total_loss = []
    time_s = datetime.now()
    for item in tqdm(train_data):
        gt = item["image"].to(device)
        con_img = item["con_image"].to(device)
        mask_img = item["mask_image"].to(device)  # only used for the sample grid below
        mask = item["mask"].to(device)

        optimzer.zero_grad()

        # Calling the diffusion model returns the scalar training loss.
        loss = diffusion(gt, con_img, mask=mask)
        loss.backward()
        optimzer.step()
        total_loss.append(loss.item())
    # --------------------------------------------------
    if not total_loss:
        # With zero batches the mean below would divide by zero, and the
        # sampling step would reference batch tensors that were never bound.
        raise RuntimeError(f"epoch {epoch+1}: training dataloader yielded no batches")
    avg_loss = sum(total_loss) / len(total_loss)
    logging.info(f"epoch: {epoch+1}, avg_loss: {avg_loss}, timecost: {datetime.now() - time_s}")

    if (epoch+1) % cfg["train"]["n_sample"] == 0:
        # Qualitative check on up to 4 items of this epoch's last batch.
        # Tensors are concatenated along the last (width) axis and saved with
        # nrow=1, so each row of the PNG reads: mask_img | restored | gt.
        diffusion.eval()
        with torch.no_grad():
            out, _ = diffusion.restoration(con_img[:4], y_t=con_img[:4], y_0=gt[:4], mask=mask[:4], sample_num=1)
            img_grid = denormalize(torch.cat([mask_img[:4], out, gt[:4]], -1))
            save_image(img_grid, os.path.join(img_dir, f"epoch_{epoch+1}.png"), nrow=1, normalize=False)
        diffusion.train()

    # NOTE: per-epoch numbered checkpoints are disabled; only the rolling
    # "latest" checkpoint below is written. Re-enable if history is needed.
    # if (epoch+1) % cfg["train"]["save_ckpt"] == 0:
    #     torch.save(diffusion.denoise_fn.state_dict(), os.path.join(models_dir, f"netG_{epoch+1}.pth"))
    #     logging.info(f"epoch {epoch+1} model saved !!!")

    # Overwrite the rolling checkpoint every epoch (denoiser weights only).
    torch.save(diffusion.denoise_fn.state_dict(), os.path.join(models_dir, "netG_latest.pth"))
    logging.info("latest models saved !!!")
            
                
        

    