import argparse
from datetime import datetime
import os
import yaml
from torch import nn
import logging 
import torch
from PIL import Image
import numpy as np
from tqdm import tqdm
from idm.model import GuideTSR
from idm.sr3_modules.vit_utils import interpolate_pos_embed

from torchvision.utils import save_image
#from torch.utils.data import DataLoader
from torchvision import transforms

from data import define_dataloader
from utils import denormalize, set_seed
from metrics import PSNR, ssim

parser = argparse.ArgumentParser()
parser.add_argument("--cfg", type=str, default="./configs/guideTSR_celeba_X8.yaml", help="the config files")
# NOTE(review): the help text says "128 or 256" but the default is 1024; this value only
# seeds cfg["model"]["config"]["img_size"] before evaluation starts.
parser.add_argument("--hr_shape", type=int, default=1024, help="hr image size 128 or 256")
# NOTE(review): --lr_shape is parsed but not referenced anywhere in this file.
parser.add_argument("--lr_shape", type=int, default=128, help="lr image size 128 or 256")
parser.add_argument("--ckpt", type=str, default="./results/guide_tsr_gan_1/models/generator_latest.pth", help="ckpt")
parser.add_argument("--output", default='', help="where to store the output")
parser.add_argument("--gpu", type=int, default=0, help="gpu number")
opt = parser.parse_args()

# Load the experiment config. The file handle is closed deterministically via `with`.
# SECURITY NOTE: yaml.Loader can construct arbitrary Python objects; only pass trusted
# config files here (yaml.safe_load would be safer if the configs use no custom tags).
with open(opt.cfg, encoding="utf-8") as cfg_file:
    cfg = yaml.load(cfg_file, Loader=yaml.Loader)
set_seed(cfg["seed"])

# output setting -----------------------------------
# A non-empty --output takes precedence over the "output" entry of the config file.
if opt.output:
    cfg["output"] = opt.output

# Predicted images are written under <output>/imgs.
img_dir = os.path.join(cfg["output"], "imgs")
os.makedirs(img_dir, exist_ok=True)
# --------------------------------------------------

# logging ------------------------------------------
# Records go to <output>/train.log; INFO-level records are mirrored to the console
# through an extra handler on the root logger.
logging.basicConfig(
    filename=os.path.join(cfg["output"], "train.log"),
    level=logging.INFO,
)
console = logging.StreamHandler()
console.setLevel(logging.INFO)
logging.getLogger().addHandler(console)
# --------------------------------------------------

# Config overrides for evaluation:
#  - data_len = -1 for the validation set (presumably "use every sample"; confirm in data module)
#  - the model's initial img_size comes from --hr_shape
cfg["datasets"]["val"]["config"]["data_len"] = -1
cfg["model"]["config"]["img_size"] = opt.hr_shape


# models and datasets setting ----------------------
# Use the requested GPU when CUDA is available; otherwise fall back to CPU (with a
# warning) instead of failing later when tensors are moved to a missing device.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{opt.gpu}")
else:
    logging.warning("CUDA is not available; falling back to CPU (evaluation will be slow)")
    device = torch.device("cpu")

# Only the validation loader is used; the first returned loader is discarded.
_, val_data = define_dataloader(cfg["datasets"])
# ----------------------------------------------------

# eval --------------------------------
# Per-batch metric values, averaged after the loop.
total_psnr = []
total_ssim = []
psnr = PSNR(255)

# The model is (re)built whenever the ground-truth height differs from the size it was
# last built for. The loaded model's size is tracked separately from cfg so the model is
# always constructed on the first iteration; comparing against cfg alone left `model`
# undefined when the first image height already equalled --hr_shape.
model = None
model_size = None

with torch.no_grad():
    for item in tqdm(val_data):
        lq = item["lq"].to(device)
        gt = item["gt"].to(device)
        fname = item["fname"]

        # NOTE(review): only the height is read, so images are presumably square.
        _, _, h, _ = gt.shape

        # Set the resolution for differently sized inputs.
        if model is None or model_size != h:
            cfg["model"]["config"]["img_size"] = h
            model = GuideTSR(cfg)
            ckpt = torch.load(opt.ckpt, map_location="cpu")
            # Presumably adapts the stg2 positional embeddings stored in ckpt to the new
            # img_size before loading; reloaded from disk each time in case it mutates ckpt.
            interpolate_pos_embed(model.stg2, ckpt, profix="stg2.")
            model.load_state_dict(ckpt)
            model.to(device)
            model.eval()
            model_size = h

        out = model(lq)

        total_psnr.append(psnr(out, gt).cpu().numpy())
        total_ssim.append(ssim(gt, out).cpu().numpy())
        img_grid = denormalize(out)
        # NOTE(review): only fname[0] is used, so this assumes a batch size of 1.
        save_image(img_grid, os.path.join(img_dir, fname[0]), nrow=1, normalize=False)

    # Guard against an empty validation set, which would otherwise divide by zero.
    if total_psnr:
        avg_psnr = sum(total_psnr) / len(total_psnr)
        avg_ssim = sum(total_ssim) / len(total_ssim)
        logging.info(f"psnr: {avg_psnr}, ssim: {avg_ssim}")
    else:
        logging.warning("no validation samples were evaluated; metrics not computed")
