import argparse
import os.path
import time
import cv2
import pandas as pd
import torch.cuda
import numpy as np
from models.sam_demo import FPC
import utils
from skimage.metrics import peak_signal_noise_ratio as PSNR
from skimage.metrics import structural_similarity as SSIM
from lpips.lpips import *

parser = argparse.ArgumentParser(description="Args of this repo.")
parser.add_argument("--rho", default=0.01, type=float,
                    help="rho passed to FPC; also part of the checkpoint file name.")
parser.add_argument("--beta", default=0.0, type=float,
                    help="beta passed to FPC; also part of the checkpoint file name.")
parser.add_argument("--rate", default=0.04, type=float,
                    help="Sampling ratio (FPC cs_ratio / config ratio).")
parser.add_argument("--device", default="0",
                    help="CUDA device index such as '0', or a full torch device string "
                         "such as 'cpu' or 'cuda:1'.")
opt = parser.parse_args()
# A bare device index ("0") is expanded to "cuda:0"; any other value (e.g. "cpu",
# "cuda:1") is used as given instead of producing an invalid "cuda:cpu".
opt.device = f"cuda:{opt.device}" if opt.device.isdigit() else opt.device

import numpy as np

def compute_angular_error(original, reconstructed, axis=-1, eps=1e-10):
    """
    Compute the angle (in radians) between corresponding vectors of two arrays.

    The vectors lie along ``axis``. The dot product and the norms are taken along
    that axis, so the result has the input shape with ``axis`` removed. For example,
    an (H, W, C) input gives an (H, W) result.

    Inputs are converted to float64 first, so integer images (e.g. uint8) do not
    overflow in the element-wise product.

    :param original: original vector array, e.g. shape (H, W, C)
    :param reconstructed: reconstructed vector array, same shape as ``original``
    :param axis: axis along which the vector components lie (default: last axis)
    :param eps: small constant added to the norm product to avoid division by zero
    :return: angular error array in radians, values in [0, pi]
    """
    a = np.asarray(original, dtype=np.float64)
    b = np.asarray(reconstructed, dtype=np.float64)

    dot_product = np.sum(a * b, axis=axis)
    norm_a = np.linalg.norm(a, axis=axis)
    norm_b = np.linalg.norm(b, axis=axis)

    # eps guards against zero-length vectors (division by zero).
    cos_theta = dot_product / (norm_a * norm_b + eps)
    # Floating-point error can push cos_theta slightly outside [-1, 1], where arccos is NaN.
    cos_theta = np.clip(cos_theta, -1.0, 1.0)
    return np.arccos(cos_theta)

def compute_mapsnr(original, reconstructed):
    """
    Compute MaPSNR in dB from the angular error between two vector arrays.

    The per-vector angles from ``compute_angular_error`` are averaged. The result is
    ``10 * log10(pi**2 / mean_angle**2)``, i.e. pi is treated as the largest possible
    angle. The mean angle itself is squared; this is not the mean of squared angles.

    :param original: original vector array; the vectors lie along the last axis.
                     For a 2-D (H, W) input, each row is treated as one vector.
    :param reconstructed: reconstructed vector array, same shape as ``original``
    :return: MaPSNR in dB; ``inf`` when the mean angular error is exactly zero
    """
    angular_error = compute_angular_error(original, reconstructed)
    mean_angular_error = np.mean(angular_error)

    if mean_angular_error == 0:
        # Perfect directional agreement: the ratio is unbounded. Return +inf explicitly
        # instead of dividing by zero, which would also emit a RuntimeWarning.
        return np.inf

    return 10 * np.log10((np.pi ** 2) / (mean_angular_error ** 2))


def _pad_and_split_blocks(x, block_size):
    """Zero-pad a 2-D tensor up to multiples of ``block_size`` and cut it into square blocks.

    :param x: 2-D tensor of shape (h, w)
    :param block_size: side length B of each block
    :return: ``(blocks, n_cols)``. ``blocks`` has shape (n_rows * n_cols, 1, B, B), with
             block (r, c) at index ``r * n_cols + c``. ``n_cols`` is the number of blocks per row.
    """
    h, w = x.shape
    expand_h = h + (-h) % block_size
    expand_w = w + (-w) % block_size
    padded = x.new_zeros(expand_h, expand_w)
    padded[:h, :w] = x

    blocks = padded.unsqueeze(0).unsqueeze(0)
    # Split into block columns first, then block rows, giving row-major block order.
    blocks = torch.cat(torch.split(blocks, split_size_or_sections=block_size, dim=3), dim=0)
    blocks = torch.cat(torch.split(blocks, split_size_or_sections=block_size, dim=2), dim=0)
    return blocks, expand_w // block_size


def _merge_blocks(blocks, n_cols, h, w):
    """Stitch row-major (n_rows * n_cols, 1, B, B) blocks back into one image, cropped to (h, w)."""
    rows = torch.cat(torch.split(blocks, split_size_or_sections=n_cols, dim=0), dim=2)
    image = torch.cat(torch.split(rows, split_size_or_sections=1, dim=0), dim=3)
    return image.squeeze()[:h, :w]


def testing(network, save_img, cfg=None, datasets=("Set14", "Urban100", "General100")):
    """
    Evaluate ``network`` on the luminance (Y of YCrCb) channel of several image datasets.

    For every image under ``../data/<dataset>/HR``:
    - the Y channel is scaled to [0, 1], zero-padded to multiples of ``cfg.block_size``
      and cut into blocks;
    - the blocks are reconstructed by ``network`` and stitched back together;
    - PSNR, SSIM and MaPSNR against the original Y channel are accumulated.

    The averages are printed per dataset. Files OpenCV cannot read are reported by
    name and excluded from the averages.

    :param network: model applied to a (N, 1, B, B) block batch. Its output is assumed to
                    have the same layout, as required by the reassembly; confirm against the model.
    :param save_img: if true, write each reconstructed BGR image under
                     ``./recon_img/Y/<dataset>/<ratio*100>/``
    :param cfg: config object providing ``device``, ``block_size`` and ``ratio``; defaults
                to the module-level ``config`` created in the ``__main__`` section
    :param datasets: dataset folder names under ``../data``
    """
    if cfg is None:
        cfg = config
    block_size = cfg.block_size

    for item in datasets:
        sum_psnr, sum_ssim, sum_mapsnr = 0., 0., 0.
        n_images = 0
        path = os.path.join("..", "data", item, "HR")
        print("*", ("  test dataset: " + path + ", device: " + str(cfg.device) + "  ").center(120, "="), "*")
        with torch.no_grad():
            for root, _dirs, files in os.walk(path):
                for file in files:
                    name = os.path.splitext(file)[0]
                    img_bgr = cv2.imread(os.path.join(root, file))
                    if img_bgr is None:
                        # Not a readable image: report it and skip it without
                        # counting it towards the averages.
                        print(name)
                        continue

                    img_yuv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2YCrCb)
                    y_org = img_yuv[:, :, 0]
                    h, w = y_org.shape
                    x = torch.from_numpy(y_org / 255.).float()

                    blocks, n_cols = _pad_and_split_blocks(x, block_size)
                    reconstruction = network(blocks.to(cfg.device))
                    x_hat = _merge_blocks(reconstruction, n_cols, h, w).cpu().numpy()

                    y_rec = x_hat * 255
                    y_ref = y_org.astype(np.float64)
                    psnr = PSNR(y_rec, y_ref, data_range=255)
                    ssim = SSIM(y_rec, y_ref, data_range=255)
                    # NOTE(review): y_rec / y_ref are 2-D (H, W), so compute_mapsnr measures the
                    # angle along the last axis, i.e. between corresponding image rows rather
                    # than per pixel. Confirm this is the intended metric.
                    mapsnr = compute_mapsnr(y_rec, y_ref)

                    n_images += 1
                    sum_psnr += psnr
                    sum_ssim += ssim
                    sum_mapsnr += mapsnr

                    if save_img:
                        img_rec_yuv = img_yuv.copy()
                        # Clip and round before storing into the uint8 image; a plain float
                        # assignment would truncate and wrap out-of-range values.
                        img_rec_yuv[:, :, 0] = np.clip(np.round(y_rec), 0, 255).astype(np.uint8)
                        im_rec_rgb = cv2.cvtColor(img_rec_yuv, cv2.COLOR_YCrCb2BGR)

                        # round() avoids truncation such as int(0.29 * 100) == 28.
                        img_path = "./recon_img/Y/{}/{}/".format(item, int(round(cfg.ratio * 100)))
                        if not os.path.isdir(img_path):
                            os.makedirs(img_path)
                            print("\rMkdir {}".format(img_path))
                        cv2.imwrite(os.path.join(img_path, f"{name}_{round(psnr, 2)}_{round(ssim, 4)}.png"), im_rec_rgb)

        if n_images == 0:
            print(f"0 AVG RES: no readable images found in {path}")
            continue
        print(f"{n_images} AVG RES: PSNR={round(sum_psnr / n_images, 2)}, SSIM={round(sum_ssim / n_images, 4)}, MaPSNR={round(sum_mapsnr / n_images, 4)}")


if __name__ == "__main__":
    # These statements stay at module level rather than inside a main() function:
    # testing() falls back to the global ``config`` assigned here.
    print("Start evaluate...")
    config = utils.GetConfig(ratio=opt.rate, device=opt.device, tune=True)
    net = FPC(LayerNo=10, cs_ratio=opt.rate, rho=opt.rho, beta=opt.beta).to(config.device).eval()
    model_path = os.path.join(config.folder, f"model_rho{opt.rho}_beta{opt.beta}.pth")
    print(model_path)
    if not os.path.exists(model_path):
        raise FileNotFoundError("Missing trained models.")

    map_location = config.device if torch.cuda.is_available() else "cpu"
    # NOTE: torch.load unpickles the checkpoint; only load files you trust.
    trained_model = torch.load(model_path, map_location=map_location)
    # strict=False tolerates key mismatches. Report them so that a wrong or partial
    # checkpoint does not silently leave layers at their initial weights.
    load_result = net.load_state_dict(trained_model, strict=False)
    if load_result.missing_keys:
        print(f"Missing keys in checkpoint: {load_result.missing_keys}")
    if load_result.unexpected_keys:
        print(f"Unexpected keys in checkpoint: {load_result.unexpected_keys}")
    for rho_, beta_ in zip(net.sam_rhos, net.sam_betas):
        print(rho_, beta_)
    print("Trained model loaded.")

    testing(net, save_img=config.save)
