import argparse
import os.path
import time
import cv2
import pandas as pd
import torch.cuda
import numpy as np
from models.fsam_demo import FPC
import utils
from skimage.metrics import peak_signal_noise_ratio as PSNR
from skimage.metrics import structural_similarity as SSIM
from lpips.lpips import *

# Command-line options for the evaluation run. rho/beta/rate are forwarded to
# the FPC model and also select which checkpoint file is loaded (see __main__).
parser = argparse.ArgumentParser(description="Args of this repo.")
parser.add_argument("--rho", default=0.01, type=float,
                    help="rho hyper-parameter passed to FPC; also part of the checkpoint file name.")
parser.add_argument("--beta", default=0.0, type=float,
                    help="beta hyper-parameter passed to FPC; also part of the checkpoint file name.")
parser.add_argument("--rate", default=0.04, type=float,
                    help="Sampling ratio; passed to FPC as cs_ratio and to utils.GetConfig as ratio.")
parser.add_argument("--device", default="0",
                    help="CUDA device index (e.g. 0) or a full torch device string such as 'cpu' "
                         "or 'cuda:1'. A bare index falls back to CPU when CUDA is unavailable.")
opt = parser.parse_args()

# A bare index keeps the original meaning ("0" -> "cuda:0"). Previously any
# value was prefixed with "cuda:", so "cpu" became the invalid "cuda:cpu" and
# machines without CUDA could not run the script at all.
if opt.device.isdigit():
    if torch.cuda.is_available():
        opt.device = f"cuda:{opt.device}"
    else:
        print(f"CUDA is not available; running on CPU instead of cuda:{opt.device}.")
        opt.device = "cpu"

def _split_into_blocks(x, block_size):
    """Zero-pad a 2-D image tensor and cut it into square blocks.

    The image is padded on the bottom and right so that both sides are
    multiples of ``block_size``, then tiled. Blocks are stacked along dim 0 in
    row-major order: the block at tile row ``r``, column ``c`` lands at index
    ``r * n_cols + c``, and each block has shape ``(1, block_size, block_size)``.

    Returns:
        (blocks, expand_w): the ``(n_rows * n_cols, 1, B, B)`` block tensor and
        the padded width, which ``_merge_blocks`` needs to undo the tiling.
    """
    h, w = x.size()
    expand_h = h + (-h) % block_size
    expand_w = w + (-w) % block_size
    padded = torch.zeros(expand_h, expand_w, dtype=x.dtype)
    padded[:h, :w] = x
    blocks = padded.unsqueeze(0).unsqueeze(0)
    # Splitting the width first and the height second yields row-major order.
    blocks = torch.cat(torch.split(blocks, block_size, dim=3), dim=0)
    blocks = torch.cat(torch.split(blocks, block_size, dim=2), dim=0)
    return blocks, expand_w


def _merge_blocks(blocks, h, w, expand_w, block_size):
    """Inverse of ``_split_into_blocks``: re-tile row-major blocks and crop to ``h x w``."""
    n_cols = expand_w // block_size
    # Each run of n_cols consecutive blocks is one tile row. Stacking the rows
    # along dim 2 leaves one full-height column strip per batch entry...
    strips = torch.cat(torch.split(blocks, n_cols, dim=0), dim=2)
    # ...and those column strips are then laid side by side along dim 3.
    image = torch.cat(torch.split(strips, 1, dim=0), dim=3)
    return image.squeeze()[:h, :w]


def testing(network, save_img):
    """Evaluate ``network`` on the luminance (Y) channel of each test image.

    Each image is converted to YCrCb. Only Y, scaled to [0, 1], goes through the
    network, as zero-padded ``config.block_size`` square blocks on
    ``config.device``. The reconstructed blocks are re-tiled, cropped back to
    the original size, and scored against the original Y channel with PSNR and
    SSIM (data_range=255). The per-dataset averages are printed, counting only
    images that were actually evaluated.

    Reads the module-level ``config`` (``block_size``, ``device``, ``ratio``)
    that is created in the ``__main__`` section.

    Args:
        network: model called as ``network(blocks)``. It is assumed (TODO
            confirm) to return a tensor with the same ``(N, 1, B, B)`` shape as
            its input; the re-tiling relies on that.
        save_img: if true, each reconstruction (reconstructed Y combined with the
            original Cr/Cb) is written as a PNG under
            ``./recon_img/Y/<dataset>/<round(ratio * 100)>/``.
    """
    # Dataset tag -> directory of ground-truth images (walked recursively).
    dataset_paths = {"U100": "../data/Urban100/HR"}

    for item, path in dataset_paths.items():
        sum_psnr, sum_ssim = 0., 0.
        count = 0  # images actually evaluated (unreadable files are not counted)
        print("*", ("  test dataset: " + path + ", device: " + str(config.device) + "  ").center(120, "="), "*")

        img_path = None
        if save_img:
            # round() rather than int(): 0.57 * 100 == 56.999..., which int() truncates to 56.
            img_path = "./recon_img/Y/{}/{}/".format(item, int(round(config.ratio * 100)))
            # makedirs also creates missing parents; os.mkdir failed when ./recon_img/Y was absent.
            if not os.path.isdir(img_path):
                os.makedirs(img_path)
                print("\rMkdir {}".format(img_path))

        with torch.no_grad():
            for root, _dirs, files in os.walk(path):
                for file in sorted(files):
                    name = os.path.splitext(file)[0]
                    img = cv2.imread(os.path.join(root, file))
                    if img is None:
                        # Not a readable image: report it and skip it without counting it.
                        print(name)
                        continue
                    img_yuv = cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb)
                    img_rec_yuv = img_yuv.copy()
                    iorg_y = img_yuv[:, :, 0]
                    h, w = iorg_y.shape
                    x = torch.from_numpy(iorg_y / 255.).float()

                    blocks, expand_w = _split_into_blocks(x, config.block_size)
                    reconstruction = network(blocks.to(config.device))
                    x_hat = _merge_blocks(reconstruction, h, w, expand_w, config.block_size).cpu().numpy()

                    psnr = PSNR(x_hat * 255, iorg_y.astype(np.float64), data_range=255)
                    ssim = SSIM(x_hat * 255, iorg_y.astype(np.float64), data_range=255)
                    sum_psnr += psnr
                    sum_ssim += ssim
                    count += 1

                    if img_path is not None:
                        # Clip before storing into the uint8 buffer, because
                        # out-of-range floats would otherwise wrap around on the cast.
                        img_rec_yuv[:, :, 0] = np.clip(x_hat * 255, 0, 255).round()
                        im_rec_rgb = cv2.cvtColor(img_rec_yuv, cv2.COLOR_YCrCb2BGR)
                        cv2.imwrite(os.path.join(img_path, f"{name}_{round(psnr, 2)}_{round(ssim, 4)}.png"), im_rec_rgb)

        if count == 0:
            # Avoid a ZeroDivisionError when the directory is empty or missing.
            print(f"No images evaluated in {path}")
            continue
        print(f"{count} AVG RES: PSNR, {round(sum_psnr / count, 2)}, SSIM, {round(sum_ssim / count, 4)}")


if __name__ == "__main__":
    print("Start evaluate...")
    # `config` stays module-level: testing() reads config.block_size / device / ratio.
    config = utils.GetConfig(ratio=opt.rate, device=opt.device, tune=True)
    model_path = os.path.join(config.folder, f"model_rho{opt.rho:.4f}_beta_{opt.beta:.2f}.pth")
    print(model_path)
    # Check that the checkpoint exists before building the network.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Missing trained models: {model_path}")

    net = FPC(LayerNo=10, cs_ratio=opt.rate, rho=opt.rho, beta=opt.beta).to(config.device).eval()
    # The network already lives on config.device, so map the tensors straight
    # there. The previous CUDA/CPU branch was dead code, because .to() above
    # already fails when config.device names an unavailable CUDA device.
    trained_model = torch.load(model_path, map_location=config.device)
    # strict=False tolerates key mismatches. Report any mismatched keys so that
    # a checkpoint which does not fit the architecture is not evaluated with
    # silently uninitialised weights.
    incompatible = net.load_state_dict(trained_model, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(f"Warning: checkpoint key mismatch; missing: {incompatible.missing_keys}, "
              f"unexpected: {incompatible.unexpected_keys}")
    print("Trained model loaded.")

    testing(net, save_img=config.save)
