import argparse
import os
from tqdm import tqdm
import blobfile as bf
import numpy as np
import torch as th
import torch.distributed as dist

from image_adapt.guided_diffusion import dist_util, logger
from image_adapt.guided_diffusion.script_util import (
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    args_to_dict,
    add_dict_to_argparser,
)
from image_adapt.guided_diffusion.image_datasets import load_data
from torchvision import utils
import math
from image_adapt.resize_right import resize
import time
from motionblur.motionblur import *
from image_adapt.guided_diffusion.svd_replacement import *
from image_adapt.guided_diffusion.corruption_function import *
from image_adapt.guided_diffusion.jpeg_torch import *
# from imagecorruptions import corrupt
import torchvision.transforms as transforms
from image_adapt.guided_diffusion.ssim import *
# from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity



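# Added helper: yields (model_kwargs, filename) per batch from load_data, with the clean image batch stored in model_kwargs["ref_img"].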
def load_reference(data_dir, batch_size, image_size, class_cond=False, corruption="shot_noise", severity=5,):
    """Stream reference batches from ``load_data``.

    ``load_data`` is called with ``deterministic=True`` and ``random_flip=False``
    and the given corruption settings. For every ``(images, kwargs, names)``
    triple it produces, the image batch is attached to ``kwargs`` under the key
    ``"ref_img"``, and ``(kwargs, names)`` is yielded.
    """
    loader = load_data(
        data_dir=data_dir,
        batch_size=batch_size,
        image_size=image_size,
        class_cond=class_cond,
        deterministic=True,
        random_flip=False,
        corruption=corruption,
        severity=severity,
    )
    for images, kwargs, names in loader:
        # The loader's own kwargs dict is reused; only the reference batch is added.
        kwargs["ref_img"] = images
        yield kwargs, names


# Degradation types accepted for ``--source``. Each one is also a candidate
# restoration target below.
_SOURCES = ("blur", "jpeg", "SR", "SR_kernel", "inpaint", "mask_sr")


def _roundtrip(H_funcs, x, batch_size):
    """Return ``H_funcs.H_pinv(H_funcs.H(x))`` viewed as (batch_size, 3, 256, 256)."""
    return H_funcs.H_pinv(H_funcs.H(x)).view(batch_size, 3, 256, 256)


def _degrade_reference(source, ref_img, kernel_batch, H_funcs_inpaint, batch_size, device):
    """Return ``ref_img`` degraded by the operator(s) named by ``source``.

    Raises:
        ValueError: if ``source`` is not one of ``_SOURCES``.
    """
    if source == "blur":
        return DeblurringArbitral2DFull(kernel_batch, 3, 256, device, conv_shape="same").H(ref_img)
    if source == "jpeg":
        qf_init = 10
        return jpeg_decode(jpeg_encode(ref_img, qf_init), qf_init)
    if source == "inpaint":
        return H_funcs_inpaint.H(ref_img)
    if source in ("SR", "SR_kernel", "mask_sr"):
        # SR round trip with factor 4. SR_kernel blurs with the kernel first;
        # mask_sr applies the pixel mask afterwards.
        x = ref_img
        if source == "SR_kernel":
            x = DeblurringArbitral2DFull(kernel_batch, 3, 256, device, conv_shape="same").H(x)
        y_0 = _roundtrip(SuperResolution(3, 256, 4, device), x, batch_size)
        if source == "mask_sr":
            y_0 = H_funcs_inpaint.H(y_0)
        return y_0
    raise ValueError(f"unknown source {source!r}; expected one of {_SOURCES}")


def _restore(diffusion, model, model_kwargs, image_size, clip_denoised, H_funcs, target, **extra):
    """Run one ``ddim_sample_loop`` for ``target``, seeded with ``contrast_img``.

    ``extra`` is forwarded verbatim (e.g. ``qf=...``). It is only passed for the
    targets that received it originally, so the sampler's defaults apply elsewhere.
    NOTE(review): N is fixed at 65 here; ``args.N`` is not consulted.
    """
    contrast_img = model_kwargs["contrast_img"]
    return diffusion.ddim_sample_loop(
        model,
        (contrast_img.shape[0], 3, image_size, image_size),
        clip_denoised=clip_denoised,
        model_kwargs=model_kwargs,
        eta=1.0,
        H_funcs=H_funcs,
        noise=contrast_img,
        N=65,
        target=target,
        **extra,
    )


def _psnr(sample, reference):
    """PSNR in dB after mapping both images from [-1, 1] to [0, 1] and clamping."""
    sample01 = ((sample + 1.0) / 2.0).clamp(0.0, 1.0)
    reference01 = ((reference + 1.0) / 2.0).clamp(0.0, 1.0)
    mse = th.mean((sample01 - reference01) ** 2)
    return 10 * th.log10(1 / mse)


def _save_batch(images, filenames, out_root, dir_suffix="", name_suffix=""):
    """Save each image as ``<out_root>/<dir><dir_suffix>/<stem><name_suffix>.png``.

    ``<dir>`` and ``<stem>`` come from the first and second '/'-separated parts
    of the image's own entry in ``filenames``. Images are normalized from
    ``value_range=(-1, 1)`` when written.
    """
    for image, name in zip(images, filenames):
        parts = name.split('/')
        out_dir = os.path.join(out_root, parts[0] + dir_suffix)
        os.makedirs(out_dir, exist_ok=True)
        # splitext only strips the real extension, unlike split('.') on the whole
        # path, which broke on directories containing a dot (e.g. "./results").
        stem = os.path.splitext(parts[1])[0]
        utils.save_image(
            image.unsqueeze(0),
            os.path.join(out_dir, stem + name_suffix + ".png"),
            nrow=1,
            normalize=True,
            value_range=(-1, 1),
        )


def _log_mean(label, values):
    """Log the mean of ``values`` as ``"<label>: %.4f"``, or n/a when empty."""
    if values:
        logger.log("%s: %.4f" % (label, sum(values) / len(values)))
    else:
        logger.log("%s: n/a (no samples)" % label)


def main():
    """Degrade reference images, restore them under every candidate target, and score.

    For each batch: degrade the clean images as named by ``--source`` and add
    Gaussian noise to obtain ``contrast_img``. Then run DDIM restoration once per
    candidate target, re-degrade each restoration, and select the target whose
    re-degraded output has the highest SSIM against ``contrast_img``. Logged
    results are whether the selection equals ``--source`` and the PSNR of the
    selected restoration versus the clean image.

    Raises:
        ValueError: for an unknown ``--source`` or a ``--D`` that is not a
            positive power of two.
    """
    args = create_argparser().parse_args()
    if args.source not in _SOURCES:
        raise ValueError(f"--source must be one of {_SOURCES}, got {args.source!r}")

    seed = 1234
    th.manual_seed(seed)
    np.random.seed(seed)
    if th.cuda.is_available():
        th.cuda.manual_seed_all(seed)

    dist_util.setup_dist()
    logger.configure(dir=args.save_dir)

    logger.log("creating model...")
    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )
    model.load_state_dict(
        dist_util.load_state_dict(args.model_path, map_location="cpu")
    )
    model.to(dist_util.dev())
    if args.use_fp16:
        model.convert_to_fp16()
    model.eval()

    logger.log("creating resizers...")
    # math.log2 is exact for powers of two (math.log(D, 2) is not always). Raise
    # instead of assert so the check survives ``python -O``.
    if args.D <= 0 or not math.log2(args.D).is_integer():
        raise ValueError(f"D must be a positive power of two, got {args.D}")

    logger.log("loading data...")
    data = load_reference(
        args.base_samples,
        args.batch_size,
        image_size=args.image_size,
        class_cond=args.class_cond,
        corruption=args.corruption,
        severity=args.severity,
    )

    device = dist_util.dev()
    batch_size = args.batch_size
    cnt = 0
    ssim_acc_list = []
    ssim_psnr_list = []
    # Never populated in this script; _log_mean reports them as n/a instead of
    # dividing by zero.
    qf_acc = []
    qf_acc_10 = []
    logger.log("creating samples...")

    for _ in tqdm(range(args.num_samples // batch_size)):
        kernel_batch, kernel_uncert_batch = edr46("motionblur", "gauss_init", batch_size, device)

        model_kwargs, filename = next(data)
        model_kwargs = {k: v.to(device) for k, v in model_kwargs.items()}

        # Drop a random half of the 256x256 pixels in all three channels. The
        # indices 3p, 3p+1, 3p+2 presumably address a channel-interleaved
        # flattening expected by Inpainting; TODO confirm.
        missing_r = th.randperm(256**2)[:256**2 // 2].to(device).long() * 3
        missing_g = missing_r + 1
        missing_b = missing_g + 1
        missing = th.cat([missing_r, missing_g, missing_b], dim=0)
        H_funcs_inpaint = Inpainting(3, 256, missing, device)

        print("======== Source: ", args.source)
        y_0 = _degrade_reference(
            args.source, model_kwargs["ref_img"], kernel_batch, H_funcs_inpaint, batch_size, device
        )

        H_funcs_uncert_blur = DeblurringArbitral2DFull(kernel_uncert_batch, 3, 256, device, conv_shape="same")
        H_funcs_SR = SuperResolution(3, 256, 4, device)

        # Noisy observation: noise std is 2 * sigma_y.
        sigma_y = args.sigma_y
        model_kwargs["contrast_img"] = y_0 + 2 * sigma_y * th.randn_like(y_0)

        pinv_y_0 = H_funcs_inpaint.H_pinv(model_kwargs["contrast_img"]).view(batch_size, 3, 256, 256)
        # pinv(H(ones)) - 1 is added to masked reconstructions below. Assuming
        # H_pinv zero-fills dropped pixels, this sets them to -1. TODO confirm.
        ones_pinv = H_funcs_inpaint.H_pinv(H_funcs_inpaint.H(th.ones_like(pinv_y_0))).reshape(*pinv_y_0.shape)
        if args.source in ("inpaint", "mask_sr"):
            model_kwargs["contrast_img"] = pinv_y_0 + ones_pinv - 1

        ref_img = model_kwargs["ref_img"].clone()
        # Hide the clean reference from the sampler.
        model_kwargs["ref_img"] = None

        def restore(H_funcs, target, **extra):
            return _restore(
                diffusion, model, model_kwargs, args.image_size, args.clip_denoised,
                H_funcs, target, **extra,
            )

        # Call order is kept so the sampler's RNG consumption is unchanged.
        qf_randn = 10
        sample_blur, H_estimate = restore(H_funcs_uncert_blur, "blur")
        # The jpeg target receives no operator and returns a quality-factor estimate.
        sample_jpeg, qf_estimate = restore(None, "jpeg", qf=qf_randn)
        sample_sr, _ = restore(H_funcs_SR, "SR", qf=qf_randn)
        sample_inpaint, _ = restore(H_funcs_inpaint, "inpaint")
        sample_mask_sr, _ = restore(H_funcs_inpaint, "mask_sr")
        sample_SR_kernel, H_SR_kernel = restore(H_funcs_uncert_blur, "SR_kernel", qf=qf_randn)

        # Re-degrade each restoration with its target's operator and compare it
        # against the observation.
        observed = model_kwargs["contrast_img"]
        print("qf_estimate: ", qf_estimate)
        ssim_jpeg = ssim(jpeg_decode(jpeg_encode(sample_jpeg, qf_estimate), qf_estimate), observed).item()
        ssim_blur = ssim(H_estimate.H(sample_blur), observed).item()
        ssim_SR = ssim(_roundtrip(H_funcs_SR, sample_sr, batch_size), observed).item()
        ssim_inpaint = ssim(
            _roundtrip(H_funcs_inpaint, sample_inpaint, batch_size) + ones_pinv - 1, observed
        ).item()
        ssim_mask_sr = ssim(
            _roundtrip(H_funcs_inpaint, _roundtrip(H_funcs_SR, sample_mask_sr, batch_size), batch_size)
            + ones_pinv - 1,
            observed,
        ).item()
        ssim_SR_kernel = ssim(
            _roundtrip(H_funcs_SR, H_SR_kernel.H(sample_SR_kernel), batch_size), observed
        ).item()

        # max() keeps the first of equal scores, same as the former strict-< cascade.
        candidates = [
            ("jpeg", ssim_jpeg, sample_jpeg),
            ("blur", ssim_blur, sample_blur),
            ("SR", ssim_SR, sample_sr),
            ("inpaint", ssim_inpaint, sample_inpaint),
            ("mask_sr", ssim_mask_sr, sample_mask_sr),
            ("SR_kernel", ssim_SR_kernel, sample_SR_kernel),
        ]
        ssim_res, ssim_sample, sample_select = max(candidates, key=lambda c: c[1])
        print(ssim_res, ssim_sample)
        ssim_acc_list.append(1 if ssim_res == args.source else 0)

        # Each image is saved under its own filename. Previously filename[0] was
        # reused, so every image in a batch overwrote the same file.
        out_root = os.path.join(logger.get_dir(), args.corruption)
        _save_batch(ref_img, filename, os.path.join(out_root, "gt"))
        _save_batch(model_kwargs["contrast_img"], filename, os.path.join(out_root, "input"), dir_suffix="_gt")
        _save_batch(sample_select, filename, os.path.join(out_root, "select"), name_suffix=f"_{ssim_res}")

        # Compare every sample with its own reference. Previously ref_img[0] was
        # used for all of them.
        for sample, reference in zip(sample_select, ref_img):
            ssim_psnr_list.append(_psnr(sample, reference).item())
            cnt += 1

    _log_mean("Accuracy SSIM", ssim_acc_list)
    _log_mean("Total Average SSIM PSNR ", ssim_psnr_list)
    _log_mean("qf_acc", qf_acc)
    _log_mean("qf_acc_10", qf_acc_10)
    logger.log("Number of samples: %d" % cnt)


def create_argparser():
    """Build the command-line parser for this script.

    Script-level defaults are merged with ``model_and_diffusion_defaults()``.
    On a name clash, the model/diffusion value wins. The merged dict is then
    registered on a fresh ``argparse.ArgumentParser`` through
    ``add_dict_to_argparser``.
    """
    script_defaults = {
        "clip_denoised": True,
        "num_samples": 10000,
        "batch_size": 4,
        "D": 4,  # scaling factor; main() requires a power of two
        "N": 50,  # NOTE(review): main() passes N=65 to the sampler and does not read this
        "use_ddim": False,
        "base_samples": "",
        "model_path": "",
        "save_dir": "",
        "corruption": "shot_noise",
        "severity": 5,
        "scale": 1,
        "source": None,
        "target": None,
        "sigma_y": 0.02,
    }
    merged = {**script_defaults, **model_and_diffusion_defaults()}
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, merged)
    return parser


if __name__ == "__main__":
    # Script entry point.
    main()
