import argparse
from datetime import datetime
import os
import yaml
from torch import nn
import logging 
import torch
from PIL import Image
import numpy as np
from tqdm import tqdm
from idm.model import I2IViT
from idm.sr3_modules.vit_utils import interpolate_pos_embed

from torchvision.utils import save_image
#from torch.utils.data import DataLoader
from torchvision import transforms

from data import define_dataloader
from utils import denormalize, set_seed
from metrics import PSNR, ssim

# command-line options -------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--cfg", type=str, default="./configs/guideTSR_celeba_X8.yaml", help="the config files")
# nargs="+" so the CLI can supply a pair like the list default; with plain type=int a
# command-line value would be a bare int while the default is a [H, W] list.
parser.add_argument(
    "--hr_shape",
    type=int,
    nargs="+",
    default=[512, 512],
    help="test image size as H W (a single value means a square image); each side must be a multiple of 256",
)
parser.add_argument("--ckpt", type=str, default="./results/guide_tsr_gan_1/models/generator_latest.pth", help="ckpt")
parser.add_argument("--output", default='', help="where to store the output")
parser.add_argument("--gpu", type=int, default=0, help="gpu number")
opt = parser.parse_args()

# Normalize --hr_shape so downstream code always sees a two-element [H, W] list.
if len(opt.hr_shape) == 1:
    opt.hr_shape = opt.hr_shape * 2
elif len(opt.hr_shape) != 2:
    parser.error("--hr_shape expects one or two integers")

# config --------------------------------------------
# NOTE(review): yaml.Loader can construct arbitrary Python objects; only load trusted
# config files. It is kept (rather than safe_load) in case configs rely on Python tags.
# The context manager closes the file handle, which the previous inline open() leaked.
with open(opt.cfg, encoding="utf-8") as cfg_file:
    cfg = yaml.load(cfg_file, Loader=yaml.Loader)
set_seed(cfg["seed"])

# output setting -----------------------------------
# A non-empty --output overrides the output directory from the config.
if opt.output:
    cfg["output"] = opt.output

img_dir = os.path.join(cfg["output"], "imgs")
os.makedirs(img_dir, exist_ok=True)
# --------------------------------------------------

# logging ------------------------------------------
# Log to a file in the output directory and mirror INFO messages to the console.
# NOTE(review): the file is named "train.log" even though this script evaluates.
logging.basicConfig(filename=os.path.join(cfg["output"], "train.log"), level=logging.INFO)
console = logging.StreamHandler()
console.setLevel(logging.INFO)
logging.getLogger('').addHandler(console)
# --------------------------------------------------

# Evaluate at the requested size over the whole validation split (data_len = -1).
cfg["datasets"]["val"]["config"]["image_size"] = opt.hr_shape
cfg["datasets"]["val"]["config"]["data_len"] = -1

# models and datasets setting ----------------------
# Use the requested GPU when CUDA is present; otherwise fall back to CPU so the
# script still runs (slowly) on CPU-only machines instead of failing on .to(device).
if torch.cuda.is_available():
    device = torch.device(f"cuda:{opt.gpu}")
else:
    device = torch.device("cpu")
    logging.warning("CUDA is not available; running evaluation on CPU")

# Only the validation loader is used by this script.
_, val_data = define_dataloader(cfg["datasets"])
model = I2IViT(cfg)
# ----------------------------------------------------

# load state dict -----------------------------
# Load onto CPU first so the checkpoint does not need to match the saving device.
ckpt = torch.load(opt.ckpt, map_location="cpu")
# Presumably resizes the checkpoint's positional embedding (keys prefixed "vit.")
# in place to fit this model's input size — TODO confirm against interpolate_pos_embed.
interpolate_pos_embed(model.vit, ckpt, profix="vit.")
model.load_state_dict(ckpt)
model.to(device)
# ---------------------------------------------------

# eval --------------------------------
total_psnr = []
total_ssim = []
psnr = PSNR(255)

# set image size and split to PATCH_SIZE x PATCH_SIZE patches -------
PATCH_SIZE = 256

# opt.hr_shape may be a bare int (older CLI parsing) or an [H, W] list.
_hr_shape = opt.hr_shape
if isinstance(_hr_shape, int):
    _hr_shape = [_hr_shape, _hr_shape]
hr_h, hr_w = _hr_shape
# Unfold silently drops border pixels that do not fill a whole patch, and Fold then
# cannot rebuild an H x W image, so require exact tiling up front.
if hr_h % PATCH_SIZE or hr_w % PATCH_SIZE:
    raise ValueError(f"hr_shape {[hr_h, hr_w]} must be a multiple of {PATCH_SIZE} on each side")

unfold = nn.Unfold(kernel_size=(PATCH_SIZE, PATCH_SIZE), stride=PATCH_SIZE)
fold = nn.Fold((hr_h, hr_w), kernel_size=(PATCH_SIZE, PATCH_SIZE), stride=PATCH_SIZE)

model.eval()
with torch.no_grad():
    for item in tqdm(val_data):
        gt = item["gt"].to(device)
        lq = item["lq"].to(device)
        fnames = item["fname"]

        batch, in_ch, height, width = lq.shape
        if (height, width) != (hr_h, hr_w):
            raise ValueError(f"input size {(height, width)} does not match hr_shape {(hr_h, hr_w)}")

        # Unfold gives (B, C*P*P, L) with C outermost in the middle axis; regroup into
        # a batch of B*L independent (C, P, P) patches for the model. Using the real
        # batch and channel counts (not a hard-coded 3 / batch of 1) keeps images apart.
        patches = unfold(lq)
        num_patches = patches.shape[-1]
        patches = (
            patches.reshape(batch, in_ch, PATCH_SIZE, PATCH_SIZE, num_patches)
            .permute(0, 4, 1, 2, 3)
            .reshape(batch * num_patches, in_ch, PATCH_SIZE, PATCH_SIZE)
        )

        out = model(patches)

        # Inverse of the grouping above: (B*L, C_out, P, P) -> (B, C_out*P*P, L),
        # which Fold stitches back into a (B, C_out, H, W) image.
        out_ch = out.shape[1]
        out = (
            out.reshape(batch, num_patches, out_ch, PATCH_SIZE, PATCH_SIZE)
            .permute(0, 2, 3, 4, 1)
            .reshape(batch, out_ch * PATCH_SIZE * PATCH_SIZE, num_patches)
        )
        out = fold(out)

        total_psnr.append(psnr(out, gt).cpu().numpy())
        total_ssim.append(ssim(gt, out).cpu().numpy())

        # Single-channel inputs are repeated to 3 channels so they can sit next to the
        # output and ground truth in the side-by-side grid (concatenated along width).
        lq_vis = lq.repeat(1, 3, 1, 1) if in_ch == 1 else lq
        img_grids = denormalize(torch.cat([lq_vis, out, gt], dim=-1))
        # Save one grid per sample so batches larger than 1 are not silently truncated.
        for name, img_grid in zip(fnames, img_grids):
            save_image(img_grid, os.path.join(img_dir, name), nrow=1, normalize=False)

if not total_psnr:
    # Avoid ZeroDivisionError when the validation loader yields nothing.
    logging.warning("validation set is empty; no metrics computed")
else:
    avg_psnr = sum(total_psnr) / len(total_psnr)
    avg_ssim = sum(total_ssim) / len(total_ssim)
    logging.info(f"psnr: {avg_psnr}, ssim: {avg_ssim}")

            
                
        

    