import argparse
import logging
import os.path
import sys
import time
from collections import OrderedDict
import torchvision.utils as tvutils

import numpy as np
import torch
import torch.nn.functional as F
from IPython import embed
import lpips

import options as option
from models import create_model


import utils as util
from data import create_dataloader, create_dataset
from data.util import bgr2ycbcr

#### options
parser = argparse.ArgumentParser()
parser.add_argument("-opt", type=str, required=True, help="Path to options YAML file.")
opt = option.parse(parser.parse_args().opt, is_train=False)

# Wrap the parsed options, presumably so that optional keys read as None
# instead of raising KeyError (see options.dict_to_nonedict) — TODO confirm.
opt = option.dict_to_nonedict(opt)

#### mkdir and logger
# Hand every configured path to util.mkdirs except the experiments root and
# any pretrained-model / resume entries (which point at existing files).
util.mkdirs(
    (
        path
        for key, path in opt["path"].items()
        if key != "experiments_root"
        and "pretrain_model" not in key
        and "resume" not in key
    )
)

# Point ./result at the parent of the results root. os.symlink fails if the
# name already exists, so drop a stale link/file first — without shelling out.
result_link = "./result"
if os.path.lexists(result_link):
    os.remove(result_link)
os.symlink(os.path.join(opt["path"]["results_root"], ".."), result_link)

util.setup_logger(
    "base",
    opt["path"]["log"],
    "test_" + opt["name"],
    level=logging.INFO,
    screen=True,
    tofile=True,
)
logger = logging.getLogger("base")
logger.info(option.dict2str(opt))

#### Create test dataset and dataloader
def _build_test_loader(dataset_opt):
    """Create the dataset and dataloader for one test phase and log its size."""
    dataset = create_dataset(dataset_opt)
    loader = create_dataloader(dataset, dataset_opt)
    logger.info(
        "Number of test images in [{:s}]: {:d}".format(dataset_opt["name"], len(dataset))
    )
    return loader


# One loader per entry of opt["datasets"], visited in sorted key order.
test_loaders = [
    _build_test_loader(phase_opt)
    for _phase_name, phase_opt in sorted(opt["datasets"].items())
]

# load pretrained model by default
model = create_model(opt)
device = model.device

# LPIPS perceptual metric (AlexNet backbone), placed on the model's device.
lpips_fn = lpips.LPIPS(net="alex").to(device)

scale = opt["degradation"]["scale"]


def add_gaussian_level(images, level):
    """Add zero-mean Gaussian noise with a fixed standard deviation.

    Args:
        images: Image tensor with values expected in [0, 1] (presumably a
            (B, C, H, W) batch — any shape works).
        level: Noise standard deviation on the 0-255 intensity scale; it is
            divided by 255 before being applied.

    Returns:
        A new tensor with the same shape, dtype and device as ``images``,
        clamped to [0, 1].
    """
    # A plain scalar std keeps the noise on the input's device/dtype and
    # broadcasts for any input rank (a CPU (B,1,1,1) tensor broke CUDA inputs
    # and mis-broadcast non-4-D inputs).
    noise = torch.randn_like(images) * float(level) / 255.0
    return torch.clamp(images + noise, 0, 1)


def add_salt_and_pepper_noise(images, salt_prob, pepper_prob):
    """Set random elements to 1.0 (salt) or 0.0 (pepper).

    Each element is independently turned to salt with probability
    ``salt_prob``, then independently to pepper with probability
    ``pepper_prob``. Pepper is applied second, so it wins where both hit.

    Args:
        images: Image tensor with values expected in [0, 1]; it is not modified.
        salt_prob: Per-element probability of being set to 1.0.
        pepper_prob: Per-element probability of being set to 0.0.

    Returns:
        A new tensor with the same shape, dtype and device as ``images``,
        clamped to [0, 1].
    """
    noisy_images = images.clone()

    # Draw the masks on the input's device so boolean indexing works for
    # CUDA tensors too; float32 draws keep the CPU RNG sequence unchanged.
    salt_mask = torch.rand(images.shape, device=images.device) < salt_prob
    noisy_images[salt_mask] = 1.0

    pepper_mask = torch.rand(images.shape, device=images.device) < pepper_prob
    noisy_images[pepper_mask] = 0.0

    return torch.clamp(noisy_images, 0, 1)


def add_speckle_noise(images, level):
    """Add multiplicative (speckle) Gaussian noise: ``x + x * n``.

    Args:
        images: Image tensor with values expected in [0, 1].
        level: Standard deviation of the zero-mean Gaussian ``n``. Unlike
            ``add_gaussian_level``, it is used as-is (not divided by 255).

    Returns:
        A new tensor with the same shape, dtype and device as ``images``,
        clamped to [0, 1].
    """
    # Scalar std: keeps the noise on the input's device and broadcasts for any
    # input rank (a CPU (B,1,1,1) std tensor broke CUDA and non-4-D inputs).
    noise = torch.randn_like(images) * float(level)
    return torch.clamp(images + images * noise, 0, 1)


def add_poisson_noise(images, scale=3.5):
    """Add amplified Poisson (shot) noise.

    The image is treated as Poisson rates on the 0-255 scale. The sampled
    deviation from the original is multiplied by ``scale`` and added back.

    Args:
        images: Image tensor with non-negative values, expected in [0, 1]
            (``torch.poisson`` rejects negative rates).
        scale: Amplification factor for the Poisson deviation; 0 disables noise.

    Returns:
        A new tensor with the same shape as ``images``, clamped to [0, 1].
    """
    sampled = torch.poisson(images * 255) / 255.0
    # Previously referenced the undefined name `image`, raising NameError.
    noise = sampled - images
    return torch.clamp(images + noise * scale, 0, 1)


def generate_spatial_gaussian_noise(image, std_dev, kernel_size=3):
    """
    Add spatially correlated Gaussian noise to a batch of images.

    Independent Gaussian noise is smoothed per channel with a uniform
    ``kernel_size`` x ``kernel_size`` box (mean) filter and added to the image.
    The filter uses zero padding, so spatial size is preserved.

    Parameters:
    - image: Input image tensor (shape: [batch_size, channels, height, width]),
      values expected in [0, 1]
    - std_dev: Standard deviation of the Gaussian noise on the 0-255 scale,
      before smoothing. Averaging reduces the effective std, roughly by a
      factor of kernel_size away from the borders.
    - kernel_size: Side length of the box filter; must be a positive odd
      integer (e.g., 3, 5, 7) so the output keeps the input's size

    Returns:
    - Noisy image with the same shape, dtype and device as ``image``,
      clamped to [0, 1]

    Raises:
    - ValueError: if ``kernel_size`` is not a positive odd integer
    """
    # An even kernel with padding k//2 grows the output by one pixel and fails
    # later with an opaque shape mismatch; reject it up front.
    if kernel_size < 1 or kernel_size % 2 == 0:
        raise ValueError(
            f"kernel_size must be a positive odd integer, got {kernel_size}"
        )

    channels = image.shape[1]

    # Generate independent Gaussian noise
    noise = torch.randn_like(image) * std_dev / 255

    # Uniform averaging kernel, one per channel (depthwise via groups=channels).
    # Built on the image's device/dtype so CUDA and float64 inputs work.
    kernel = torch.ones(
        (channels, 1, kernel_size, kernel_size),
        dtype=image.dtype,
        device=image.device,
    ) / (kernel_size**2)

    # Smooth the noise to introduce spatial correlation
    noise = F.conv2d(noise, kernel, padding=kernel_size // 2, groups=channels)

    # Add the correlated noise and clamp to the valid range [0, 1]
    return torch.clamp(image + noise, 0, 1)


def _mean(values):
    """Return the arithmetic mean of ``values``, or NaN when it is empty."""
    return sum(values) / len(values) if values else float("nan")


for test_loader in test_loaders:
    # NOTE(review): `level` only labels the averaged results below; no noise
    # function is applied to the LQ input in this loop, so every level
    # evaluates the same inputs and writes into the same output directory.
    for level in opt["test_noise_level"]:
        test_set_name = test_loader.dataset.opt["name"]
        logger.info("\nTesting [{:s}]...".format(test_set_name))
        dataset_dir = os.path.join(opt["path"]["results_root"], test_set_name)
        util.mkdir(dataset_dir)

        test_results = OrderedDict()
        test_results["psnr"] = []
        test_results["ssim"] = []
        test_results["psnr_y"] = []
        test_results["ssim_y"] = []
        test_results["lpips"] = []
        test_times = []

        for test_data in test_loader:
            # GT is always assumed present; the dataset-driven check is disabled:
            # need_GT = False if test_loader.dataset.opt["dataroot_GT"] is None else True
            need_GT = True
            img_path = test_data["GT_path"][0] if need_GT else test_data["LQ_path"][0]
            img_name = os.path.splitext(os.path.basename(img_path))[0]

            #### input dataset_LQ
            LQ, GT = test_data["LQ"], test_data["GT"]

            model.feed_data(LQ, GT)
            tic = time.time()
            model.test()
            toc = time.time()
            test_times.append(toc - tic)

            visuals = model.get_current_visuals()
            SR_img = visuals["Output"]
            output = util.tensor2img(SR_img.squeeze())  # uint8
            LQ_ = util.tensor2img(visuals["Input"].squeeze())  # uint8
            GT_ = util.tensor2img(visuals["GT"].squeeze())  # uint8

            # Save input / ground truth / output into the per-dataset results
            # directory created above (previously a "your path" placeholder).
            util.save_img(LQ_, os.path.join(dataset_dir, img_name + "_LQ.png"))
            util.save_img(GT_, os.path.join(dataset_dir, img_name + "_HQ.png"))
            util.save_img(output, os.path.join(dataset_dir, img_name + "_Output.png"))

            if need_GT:
                # PSNR/SSIM on the full (uncropped) uint8 images.
                psnr = util.calculate_psnr(GT_, output)
                ssim = util.calculate_ssim(GT_, output)
                # `* 2 - 1` maps (presumably [0, 1]) tensors to [-1, 1] for LPIPS.
                # Evaluated without autograd: this is inference only.
                with torch.no_grad():
                    lp_score = (
                        lpips_fn(GT.to(device) * 2 - 1, SR_img.to(device) * 2 - 1)
                        .squeeze()
                        .item()
                    )

                test_results["psnr"].append(psnr)
                test_results["ssim"].append(ssim)
                test_results["lpips"].append(lp_score)

                # Per-image logging and the *_y lists are only filled for images
                # that are not 3-D (i.e. single-channel); 3-D images skip this.
                if len(GT_.shape) != 3:
                    logger.info(
                        "img:{:15s} - PSNR: {:.6f} dB; SSIM: {:.6f}.".format(
                            img_name, psnr, ssim
                        )
                    )
                    test_results["psnr_y"].append(psnr)
                    test_results["ssim_y"].append(ssim)
            else:
                logger.info(img_name)

        # _mean returns NaN for an empty loader instead of dividing by zero.
        ave_lpips = _mean(test_results["lpips"])
        ave_psnr = _mean(test_results["psnr"])
        ave_ssim = _mean(test_results["ssim"])
        logger.info(
            "----Average PSNR/SSIM results for {}_{}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n".format(
                test_set_name, level, ave_psnr, ave_ssim
            )
        )
        if test_results["psnr_y"] and test_results["ssim_y"]:
            ave_psnr_y = _mean(test_results["psnr_y"])
            ave_ssim_y = _mean(test_results["ssim_y"])
            logger.info(
                "----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n".format(
                    ave_psnr_y, ave_ssim_y
                )
            )

        logger.info("----average LPIPS\t: {:.6f}\n".format(ave_lpips))

        # Routed through the logger so the timing also lands in the log file.
        logger.info(f"average test time: {_mean(test_times):.4f}")
