import os
import logging
import time
import glob

from skimage.metrics import structural_similarity as ssim
import numpy as np
import tqdm
import torch
import torch.utils.data as data

# from models.diffusion import Model
from datasets import get_dataset, data_transform, inverse_data_transform
from functions.ckpt_util import get_ckpt_path, download
from inverse_utils.denoising_by_lmap_rps import resample, latent_dps, psld, stsl, ldir, lmap_rps, daps_latent, dcdp_latent, dmap_latent, sitcom_latent
import lpips

import torchvision.utils as tvu


import random
import yaml
from inverse_utils.default_lr import get_default_lr
from guided_diffusion.unet_ffhq import create_model as create_model_ffhq
from diffusers import StableDiffusionPipeline, DDIMScheduler
import torchvision.transforms as T
import time


def load_yaml(file_path: str) -> dict:
    """Read *file_path* and return the YAML document it contains.

    Parsing uses PyYAML's ``FullLoader``; only point this at trusted config files.
    """
    with open(file_path) as stream:
        return yaml.load(stream, Loader=yaml.FullLoader)


class SD15:
    """Thin wrapper around a Stable Diffusion v1.5 ``StableDiffusionPipeline``.

    Exposes what the latent inverse solvers need: classifier-free-guided noise
    prediction (``apply_model``), VAE encoding into the scaled latent space
    (``encode_first_stage``) and VAE decoding back to pixels
    (``decode_first_stage``).
    """

    def __init__(self, model_id="exp/logs/runwayml/stable-diffusion-v1-5", device="cuda"):
        self.device = device
        self.model_id = model_id

        # Load the pipeline in half precision; the safety checker is disabled.
        self.pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            safety_checker=None
        ).to(device)
        # DDIM scheduler built from the pipeline's config (not used inside this class).
        self.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
        # Cumulative noise schedule of the pipeline's scheduler.
        self.alphas_cumprod = self.pipe.scheduler.alphas_cumprod
        self.unet = self.pipe.unet
        self.vae = self.pipe.vae
        self.unet.eval()
        self.vae.eval()
        self.tokenizer = self.pipe.tokenizer
        self.text_encoder = self.pipe.text_encoder

    def _encode_text(self, prompts):
        """Encode prompt string(s) into CLIP hidden states for UNet cross-attention.

        Tokens are padded/truncated to ``tokenizer.model_max_length`` exactly as
        diffusers' Stable Diffusion pipeline does: the UNet was trained on
        fixed-length padded embeddings, overly long captions would otherwise
        overflow the text encoder, and a batch of captions of different lengths
        could not be stacked into one tensor. Only ``input_ids`` are passed
        (no attention mask), again matching the SD1.5 pipeline.
        """
        tokens = self.tokenizer(
            prompts,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = tokens.input_ids.to(self.device)
        return self.text_encoder(input_ids).last_hidden_state

    def apply_model(self, z_t, t, prompt, cfg=1.5):
        """Predict the noise in latent ``z_t`` at timestep ``t``.

        Args:
            z_t: noisy latent batch fed to the UNet.
            t: timestep(s), passed unchanged to the UNet.
            prompt: caption string or list of caption strings.
            cfg: classifier-free guidance scale; ``1.0`` skips the
                unconditional pass and returns the conditional prediction.

        Returns:
            ``uncond + cfg * (cond - uncond)`` noise prediction (or ``cond``
            alone when ``cfg == 1.0``).
        """
        with torch.autocast("cuda", dtype=torch.float16):
            text_embeddings = self._encode_text(prompt)
            noise_pred_cond = self.unet(z_t, t, encoder_hidden_states=text_embeddings).sample
            if cfg == 1.0:
                return noise_pred_cond
            # One empty-prompt embedding per latent in the batch.
            uncond_embeddings = self._encode_text([""] * z_t.shape[0])
            noise_pred_uncond = self.unet(z_t, t, encoder_hidden_states=uncond_embeddings).sample
            return noise_pred_uncond + cfg * (noise_pred_cond - noise_pred_uncond)

    def encode_first_stage(self, x):
        """Encode images ``x`` into VAE latents scaled by ``vae.config.scaling_factor``.

        The latent is *sampled* from the VAE posterior, so this call draws from
        torch's random number generator.
        """
        with torch.autocast("cuda", dtype=torch.float16):
            latent = self.vae.encode(x).latent_dist.sample() * self.vae.config.scaling_factor
            return latent

    def decode_first_stage(self, z):
        """Undo the latent scaling and decode ``z`` back to image space with the VAE."""
        with torch.autocast("cuda", dtype=torch.float16):
            recon_img = self.vae.decode(z / self.vae.config.scaling_factor).sample
        return recon_img


class Diffusion(object):
    """Benchmark driver for latent-diffusion inverse solvers on COCO-100.

    ``sample`` loads Stable Diffusion 1.5 and calls ``sample_sequence``. That
    method degrades every test image with the operator named by ``args.deg``,
    runs the solver named by ``args.algo`` and reports PSNR / SSIM / LPIPS /
    RMSE.
    """

    def __init__(self, args, config, device=None):
        """Store the run configuration and resolve the compute device.

        Args:
            args: parsed command-line arguments for the run.
            config: experiment config; ``config.data.image_size`` is forced to
                512 for the COCO data used here.
            device: torch device; defaults to CUDA when available, else CPU.
        """
        self.args = args
        self.config = config
        self.config.data.image_size = 512  # for coco
        if device is None:
            device = (
                torch.device("cuda")
                if torch.cuda.is_available()
                else torch.device("cpu")
            )
        self.device = device

    def sample(self):
        """Load SD1.5, cache its noise schedule on ``self.device`` and run the benchmark."""
        model = SD15(device=self.device)
        self.alphas_cumprod = model.alphas_cumprod.to(self.device)
        self.num_timesteps = len(self.alphas_cumprod)
        self.sample_sequence(model)

    @staticmethod
    def _bicubic_taps(factor, a=-0.5):
        """Return the sum-normalised 1-D cubic-convolution kernel (``4 * factor`` taps) for downscaling by ``factor``."""
        def bicubic_kernel(x):
            ax = abs(x)
            if ax <= 1:
                return (a + 2) * ax**3 - (a + 3) * ax**2 + 1
            elif ax < 2:
                return a * ax**3 - 5 * a * ax**2 + 8 * a * ax - 4 * a
            return 0

        k = np.zeros((factor * 4))
        for i in range(factor * 4):
            # Tap positions are centred on the output pixel, in units of output pixels.
            x = (1 / factor) * (i - np.floor(factor * 4 / 2) + 0.5)
            k[i] = bicubic_kernel(x)
        return k / np.sum(k)

    def _build_H_funcs(self, deg):
        """Construct the forward (degradation) operator named by ``deg``.

        The string checks run in a fixed order, so ``deg`` is matched by the
        first rule it satisfies (e.g. any name containing ``'sr'`` is treated
        as super-resolution).

        Raises:
            ValueError: if ``deg`` names no supported operator.
        """
        channels = self.config.data.channels
        image_size = self.config.data.image_size

        def gaussian_taps(sigma, radius):
            # Unnormalised samples of exp(-x^2 / (2 sigma^2)) at integer x in [-radius, radius].
            def pdf(t):
                return torch.exp(torch.Tensor([-0.5 * (t / sigma) ** 2]))
            return torch.Tensor([pdf(t) for t in range(-radius, radius + 1)]).to(self.device)

        if 'sr' in deg:
            if deg[:10] == 'sr_bicubic':
                factor = int(deg[10:])
                from obs_functions.Hfuncs import SRConv
                kernel = torch.from_numpy(self._bicubic_taps(factor)).float().to(self.device)
                return SRConv(kernel / kernel.sum(), channels, image_size, self.device, stride=factor)
            # Super-Resolution
            blur_by = int(deg[2:])
            from obs_functions.Hfuncs import SuperResolution
            return SuperResolution(channels, image_size, blur_by, self.device)
        if 'inp' in deg:
            # Random inpainting: indices of a random half of the pixel positions.
            missing_r = torch.randperm(image_size**2)[:image_size**2 // 2].to(self.device).long()
            from obs_functions.Hfuncs import Inpainting
            return Inpainting(channels, image_size, missing_r, self.device)
        if deg == 'deblur_aniso':
            from obs_functions.Hfuncs import Deblurring2D
            kernel2 = gaussian_taps(20, 4)
            kernel1 = gaussian_taps(1, 4)
            return Deblurring2D(kernel1 / kernel1.sum(), kernel2 / kernel2.sum(), channels, image_size, self.device)
        if 'cs' in deg:
            compress_by = int(deg[2:])
            from obs_functions.Hfuncs import WalshHadamardCS
            return WalshHadamardCS(channels, image_size, compress_by, torch.randperm(image_size**2, device=self.device), self.device)
        if deg == 'deblur_gaussian_61':
            from obs_functions.Hfuncs import GaussianBlurOperator
            return GaussianBlurOperator(kernel_size=61, intensity=3.0, device=self.device)
        if 'deblur_gauss' in deg:
            # Gaussian Deblurring
            from obs_functions.Hfuncs import Deblurring
            kernel = gaussian_taps(10, 2)
            return Deblurring(kernel / kernel.sum(), channels, image_size, self.device)
        if deg == 'denoise':
            from obs_functions.Hfuncs import Denoising
            return Denoising(channels, image_size, self.device)
        if 'phase' in deg:
            # Phase Retrieval
            from obs_functions.Hfuncs import PhaseRetrievalOperator
            return PhaseRetrievalOperator(oversample=2.0, device=self.device)
        if 'hdr' in deg:
            from obs_functions.Hfuncs import HDR
            return HDR()
        if deg == 'deblur_nonlinear':
            from obs_functions.Hfuncs import NonlinearBlurOperator
            return NonlinearBlurOperator(self.device, opt_yml_path='./bkse/options/generate_blur/default.yml')
        raise ValueError(f"ERROR: degradation type not supported: {deg!r}")

    def _step_size(self, deg, sigma_0):
        """Return ``(lr, N)``: the tabulated defaults when ``args.default_lr`` is set, else the CLI values."""
        args = self.args
        if not args.default_lr:
            return args.lr, args.N
        dataset_name = next(
            (name for name in ('imagenet', 'celeba', 'ffhq') if name in args.config),
            'unknown',
        )
        return get_default_lr(deg, args.timesteps, sigma_0, dataset_name), 1

    def _save_vae_roundtrip(self, model, x_orig):
        """Diagnostic: print the VAE's relative reconstruction error and save the first recon.

        Writes ``output/recon.png``, overwritten on every call. Because
        ``model.encode_first_stage`` samples a latent, this consumes torch's
        global RNG and therefore affects every random draw that follows.
        """
        latent = model.encode_first_stage(x_orig)
        recon = model.decode_first_stage(latent)
        print(torch.norm(recon - x_orig) / torch.norm(x_orig))
        img_tensor = torch.clamp((recon.detach().cpu()[0] + 1.0) / 2.0, 0.0, 1.0)
        os.makedirs("output", exist_ok=True)
        T.ToPILImage()(img_tensor).save("output/recon.png")

    def _score(self, pred, orig, loss_fn_vgg):
        """Return ``(psnr, ssim, lpips, rmse)`` of ``pred`` against ``orig``, or ``None`` if PSNR is NaN.

        ``pred`` must be a CPU tensor (SSIM is computed via ``.numpy()``).
        PSNR uses a peak value of 1, so both images are presumably in [0, 1]
        after ``inverse_data_transform`` (confirm in ``datasets``).
        """
        mse = torch.mean((pred.to(self.device) - orig) ** 2)
        psnr = 10 * torch.log10(1 / mse)
        if torch.isnan(psnr):
            return None
        pred_np = pred.numpy()
        # NOTE(review): data_range comes from the prediction's own min/max rather
        # than the nominal image range; kept so reported SSIM stays comparable.
        ssim_val = ssim(pred_np, orig.cpu().numpy(), data_range=pred_np.max() - pred_np.min(), channel_axis=0)
        # LPIPS expects inputs in [-1, 1].
        lpips_val = loss_fn_vgg(2 * orig - 1.0, 2 * pred.to(torch.float32).to(self.device) - 1.0)[0, 0, 0, 0]
        return psnr, ssim_val, lpips_val, mse.sqrt()

    def sample_sequence(self, model, cls_fn=None):
        """Solve the inverse problem for every test image and report metrics.

        For each image this saves the pseudo-inverse of the noisy measurement
        (``y0_<idx>.png``), the ground truth (``orig_<idx>.png``) and the
        reconstruction (``<idx>_-1.png``) into ``args.image_folder``. It then
        prints running and final averages over the images whose PSNR is not
        NaN.

        Args:
            model: the ``SD15`` wrapper used for encoding, decoding and noise
                prediction.
            cls_fn: unused; kept for interface compatibility.

        Raises:
            ValueError: on an empty/inverted subset range, an unsupported
                degradation or an unknown algorithm.
        """
        args, config = self.args, self.config

        from datasets.cocoval import CocoVal100
        test_dataset = CocoVal100(data_root="exp/datasets/coco_100_for_inversion_512")

        if args.subset_start >= 0 and args.subset_end > 0:
            if args.subset_end <= args.subset_start:
                raise ValueError(f"subset_end ({args.subset_end}) must be greater than subset_start ({args.subset_start})")
            test_dataset = torch.utils.data.Subset(test_dataset, range(args.subset_start, args.subset_end))
        else:
            args.subset_start = 0
            args.subset_end = len(test_dataset)

        print(f'Dataset has size {len(test_dataset)}')

        def seed_worker(worker_id):
            # Every worker is seeded identically from args.seed.
            worker_seed = args.seed % 2**32
            np.random.seed(worker_seed)
            random.seed(worker_seed)

        g = torch.Generator()
        g.manual_seed(args.seed)
        val_loader = data.DataLoader(
            test_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=4,
            worker_init_fn=seed_worker,
            generator=g,
        )

        deg = args.deg
        H_funcs = self._build_H_funcs(deg)

        # for linear observations: noise level doubled to account for scaling to [-1, 1]
        if 'sr' in deg or 'inp' in deg or 'deblur_gauss' in deg:
            args.sigma_0 = 2 * args.sigma_0
        sigma_0 = args.sigma_0

        lr, N = self._step_size(deg, sigma_0)

        print(f'Start from {args.subset_start}')
        idx_init = args.subset_start
        idx_so_far = args.subset_start
        sum_psnr = sum_ssim = sum_lpips = sum_rmse = 0.0
        n_scored = 0  # images that contributed to the sums (PSNR not NaN)
        pbar = tqdm.tqdm(val_loader)
        loss_fn_vgg = lpips.LPIPS(net='vgg').to(self.device)

        def mean(total):
            # Average over scored images only, so skipped images do not bias the result.
            return total / n_scored if n_scored else float('nan')

        with torch.no_grad():
            for batch in pbar:
                x_orig = data_transform(config, batch['image'].to(self.device))
                prompt = batch['caption']
                self._save_vae_roundtrip(model, x_orig)

                y_0 = H_funcs.forward(x_orig)
                y_0 = y_0 + sigma_0 * torch.randn_like(y_0)
                y_pinv = H_funcs.H_pinv(y_0).view(y_0.shape[0], config.data.channels, config.data.image_size, config.data.image_size)
                for i in range(len(y_0)):
                    tvu.save_image(
                        inverse_data_transform(config, y_pinv[i]), os.path.join(args.image_folder, f"y0_{idx_so_far + i}.png")
                    )
                    tvu.save_image(
                        inverse_data_transform(config, x_orig[i]), os.path.join(args.image_folder, f"orig_{idx_so_far + i}.png")
                    )

                # Initial latent noise: SD1.5 latents are 4x64x64 for 512x512 images.
                x = torch.randn(y_0.shape[0], 4, 64, 64, device=self.device)
                with torch.autocast("cuda", dtype=torch.float16):
                    x, _ = self.sample_image(x, model, H_funcs, y_0, sigma_0, lr, N, optimize_iters=args.optimize_iters, vae_lr=args.vae_lr, w_prior=args.w_prior, noise_t=args.noise_t, lam=args.lam, renoise_t=args.renoise_t, eta_min=args.eta_min, ps_method=args.ps_method, stable=args.stable, last=False, prompt=prompt, classes=None)

                # Only the last entry returned by the solver is saved and scored.
                final = inverse_data_transform(config, x[-1])
                for j in range(final.size(0)):
                    tvu.save_image(
                        final[j], os.path.join(args.image_folder, f"{idx_so_far + j}_-1.png")
                    )
                    scores = self._score(final[j], inverse_data_transform(config, x_orig[j]), loss_fn_vgg)
                    if scores is None:
                        continue
                    psnr, ssim_val, lpips_val, rmse = scores
                    sum_psnr += psnr
                    sum_ssim += ssim_val
                    sum_lpips += lpips_val
                    sum_rmse += rmse
                    n_scored += 1
                idx_so_far += y_0.shape[0]

                pbar.set_description("PSNR:{}, SSIM:{}, LPIPS:{}, RMSE:{}".format(mean(sum_psnr), mean(sum_ssim), mean(sum_lpips), mean(sum_rmse)))

        avg_psnr = mean(sum_psnr)
        avg_ssim = mean(sum_ssim)
        avg_lpips = mean(sum_lpips)
        n_processed = idx_so_far - idx_init
        print("Total Average PSNR: %.2f, Total Average SSIM: %.4f, Total Average LPIPS: %.4f" % (avg_psnr, avg_ssim, avg_lpips))
        print("Number of samples: %d" % n_processed)
        if n_scored != n_processed:
            print("Samples excluded from averages (NaN PSNR): %d" % (n_processed - n_scored))

    def sample_image(self, x, model, H_funcs, y_0, sigma_0, lr, N, optimize_iters=200, vae_lr=0.5, w_prior=0.15, noise_t=50, renoise_t=100, lam=1.0, eta_min=1e-5, ps_method='latent_dps', stable=False, last=True, prompt=None, classes=None):
        """Run the solver selected by ``args.algo`` starting from latent noise ``x``.

        The timestep sequence takes ``args.timesteps`` evenly spaced steps over
        the full schedule, except ``dcdp``, which uses only the first 400 steps.
        Solver-specific keyword arguments are forwarded only to the solvers
        that accept them.

        Returns:
            The selected solver's return value, or ``result[0][-1]`` when
            ``last`` is True.

        Raises:
            ValueError: if ``args.algo`` names no known solver.
        """
        algo = self.args.algo
        skip = self.num_timesteps // self.args.timesteps
        seq = range(0, self.num_timesteps, skip)
        ac = self.alphas_cumprod
        start = time.time()
        if algo == 'resample':
            x = resample(x, seq, model, ac, H_funcs, y_0, sigma_0, lr, N, prompt=prompt, classes=classes)
        elif algo == 'latent_dps':
            x = latent_dps(x, seq, model, ac, H_funcs, y_0, sigma_0, lr, N, prompt=prompt, classes=classes)
        elif algo == 'psld':
            x = psld(x, seq, model, ac, H_funcs, y_0, sigma_0, lr, lam, N, prompt=prompt, classes=classes)
        elif algo == 'stsl':
            x = stsl(x, seq, model, ac, H_funcs, y_0, sigma_0, lr, N, stepsize=lam, prompt=prompt, classes=classes)
        elif algo == 'ldir':
            x = ldir(x, seq, model, ac, H_funcs, y_0, sigma_0, lr, N, prompt=prompt, classes=classes)
        elif algo == 'daps':
            x = daps_latent(x, seq, model, ac, H_funcs, y_0, sigma_0, lr, N, prompt=prompt, classes=classes)
        elif algo == 'dcdp':
            skip = 400 // self.args.timesteps
            seq = range(0, 400, skip)
            x = dcdp_latent(x, seq, model, ac, H_funcs, y_0, sigma_0, lr, N, optimize_iters=optimize_iters, prompt=prompt, classes=classes)
        elif algo == 'dmap':
            x = dmap_latent(x, seq, model, ac, H_funcs, y_0, sigma_0, lr, N, prompt=prompt, classes=classes)
        elif algo == 'sitcom':
            x = sitcom_latent(x, seq, model, ac, H_funcs, y_0, sigma_0, lr, N, optimize_iters=optimize_iters, prompt=prompt, classes=classes)
        elif 'lmap_rps' in algo:
            x = lmap_rps(x, seq, model, ac, H_funcs, y_0, sigma_0, lr, N, optimize_iters=optimize_iters, vae_lr=vae_lr, w_prior=w_prior, noise_t=noise_t, renoise_t=renoise_t, lam=lam, ps_method=ps_method, eta_min=eta_min, stable=stable, prompt=prompt, classes=classes)
        else:
            # Falling through would hand the untouched input noise back to the caller.
            raise ValueError(f"Unknown algorithm: {algo!r}")
        if last:
            x = x[0][-1]
        end = time.time()
        print("Time taken for sampling:{:.2f} seconds".format(end - start))
        return x