import os
import sys
sys.path.append(os.path.dirname(os.getcwd()))
os.chdir(os.path.dirname(os.getcwd()))
import copy
import argparse
import random
from pathlib import Path
from easydict import EasyDict as edict

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.distributed as dist
from torch.multiprocessing import Process
from torch.utils.data import DataLoader, Subset
from torch_ema import ExponentialMovingAverage
import torchvision.utils as tu
from i2sb.runner_sharp import Runner


from logger import Logger
import distributed_util as dist_util
# from i2sb import Runner, download_ckpt

from corruption import build_corruption
from dataset import imagenet
from i2sb import ckpt_util
from measures import *
from corruption.sisr import *

import colored_traceback.always
from ipdb import set_trace as debug
import cv2
from dataset.imagenet_subset import ImageDataset


# Checkpoint root directory.
# NOTE(review): not referenced in the visible part of this script; the --ckpt
# option below carries its own default ('./checkpoints/blur-gauss').
RESULT_DIR = Path("./checkpoints/")

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs.

    Also exports ``PYTHONHASHSEED``; note that this environment variable only
    affects interpreters started afterwards, not hashing in the current one.
    Background: https://github.com/pytorch/pytorch/issues/7068
    """
    hash_seed = str(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = hash_seed
    np.random.seed(seed)
    # Same call order as before: CPU generator, current CUDA device, all CUDA devices.
    for seeder in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)


def compute_batch(ckpt_opt, corrupt_type, corrupt_method, out, device=None):
    """Corrupt one clean batch and assemble the tensors the sampler consumes.

    Args:
        ckpt_opt: checkpoint options; only ``ckpt_opt.cond_x1`` is read here.
        corrupt_type: not used by this function; kept for call-site compatibility.
        corrupt_method: callable applied to the clean batch (moved to ``device``)
            that returns a pair; the first item becomes ``corrupt_img`` / ``x1``
            and the second becomes ``x1_forw``.
        out: a two-element ``(clean_img, y)`` batch from the data loader.
        device: device the corrupted tensors are moved to. When ``None`` (the
            default) this falls back to the module-level ``opt.device`` created
            in ``__main__`` -- the only behavior the original version had.

    Returns:
        ``(corrupt_img, x1, mask, cond, y, clean_img, x1_pinv, x1_forw)`` where
        ``mask`` is always ``None``, ``x1_pinv`` is the same object as ``x1``,
        ``cond`` is ``x1.detach()`` if ``ckpt_opt.cond_x1`` is truthy and
        ``None`` otherwise, and ``clean_img`` is returned exactly as received
        (it is not moved to ``device``).
    """
    if device is None:
        # Backward-compatible fallback to the script-level options object.
        device = opt.device

    clean_img, y = out
    mask = None
    corrupt_img, corrupt_img_y = corrupt_method(clean_img.to(device))
    x1 = corrupt_img.to(device)
    x1_pinv = x1
    x1_forw = corrupt_img_y.to(device)

    cond = x1.detach() if ckpt_opt.cond_x1 else None

    return corrupt_img, x1, mask, cond, y, clean_img, x1_pinv, x1_forw


# NOTE(review): the decorator below is commented out in the original; presumably
# runner.sharp_iteration needs autograd -- confirm before re-enabling it.
# @torch.no_grad()
def main(opt):
    """Run the sharpening sampler over the ImageNet validation subset.

    For every batch listed in ``./dataset/imagenet_val_100.txt`` this:

    1. corrupts the clean image with the ``'sisr'`` corruption,
    2. runs ``runner.sharp_iteration`` starting from the corrupted image,
    3. prints PSNR / SSIM of ``pred_x0s[:, 0]`` against the clean image, and
    4. writes the first reconstruction of the batch to
       ``result/<id>_psnr_<psnr>_recon.png`` (the directory is created).

    The mean PSNR / SSIM over all batches is printed at the end.

    Args:
        opt: options dict built in ``__main__`` (parsed CLI arguments plus
            ``device``, ``distributed`` and the rank/size fields).
    """
    log = Logger(opt.global_rank, ".log")

    # get (default) ckpt option
    ckpt_opt = ckpt_util.build_ckpt_option(opt, log, opt.ckpt)
    corrupt_type = 'sisr'
    # A falsy --nfe (e.g. 0) falls back to the checkpoint's interval - 1.
    nfe = opt.nfe or ckpt_opt.interval - 1

    # build corruption method
    corrupt_method = build_corruption(opt, log, corrupt_type=corrupt_type)

    subset_dataset = ImageDataset(os.path.join('./datasets/imagenet', 'imagenet'),
                                  os.path.join('./dataset', 'imagenet_val_100.txt'),
                                  image_size=256,
                                  normalize=True)

    val_loader = DataLoader(subset_dataset,
                            batch_size=opt.batch_size, shuffle=False, pin_memory=True, num_workers=1,
                            drop_last=False)

    # build runner
    runner = Runner(ckpt_opt, log,)

    # Two 31x31 Gaussian blur kernels (std 3 and std 1.5) handed to the sampler
    # as kernel_H and kernel_A respectively.
    kernel_H = Blurkernel(blur_type='gaussian', kernel_size=31, std=3).get_kernel()
    kernel_A = Blurkernel(blur_type='gaussian', kernel_size=31, std=1.5).get_kernel()

    # cv2.imwrite does not raise when the target directory is missing -- it just
    # returns False -- so make sure the output directory exists up front.
    result_dir = Path('result')
    result_dir.mkdir(parents=True, exist_ok=True)

    psnr_list = []
    ssim_list = []

    for out in val_loader:
        # out[-1] is the second loader output (unpacked as `y` in compute_batch);
        # its first entry is used to name the output file.
        img_id = out[-1][0]
        corrupt_img, x1, mask, cond, y, clean_img, x1_pinv, x1_forw = compute_batch(ckpt_opt, corrupt_type,
                                                                                    corrupt_method, out)

        # Pass the resolved `nfe` (not raw opt.nfe) so the checkpoint fallback applies.
        xs, pred_x0s = runner.sharp_iteration(
            nfe, x1, x1_forw, kernel_A=kernel_A, kernel_H=kernel_H, t_all=[200, 210, 220, 230, 240, 250])

        recon_img = pred_x0s[:, 0, ...].to(opt.device)

        # If the reconstruction is narrower than the clean image, crop the clean
        # image to a top-left w x w square so the metrics compare equal shapes.
        if recon_img.shape[-1] != clean_img.shape[-1]:
            w = recon_img.shape[-1]
            clean_img = clean_img[:, :, :w, :w]

        assert recon_img.shape == corrupt_img.shape

        psnr = compare_psnr(pred_x0s[:, 0].cpu(), clean_img.cpu())
        ssim = compare_ssim(pred_x0s[:, 0].cpu(), clean_img.cpu())
        print(psnr)
        print(ssim)

        psnr_list.append(psnr)
        ssim_list.append(ssim)

        # First image of the batch: CHW -> HWC, (x + 1) / 2 rescale (assumes values
        # in [-1, 1]), then reverse channel order (presumably RGB -> BGR for OpenCV).
        img = np.squeeze((recon_img.cpu()[0].permute(1, 2, 0) + 1) / 2)
        img = img[:, :, [2, 1, 0]].numpy() * 255
        out_path = result_dir / (str(img_id) + "_psnr_{:.4f}".format(psnr) + "_recon.png")
        if not cv2.imwrite(str(out_path), img):
            print("failed to write " + str(out_path))

    if psnr_list:
        mean_psnr = sum(float(p) for p in psnr_list) / len(psnr_list)
        mean_ssim = sum(float(s) for s in ssim_list) / len(ssim_list)
        print("mean PSNR {:.4f}, mean SSIM {:.4f} over {} batches".format(
            mean_psnr, mean_ssim, len(psnr_list)))

    del runner


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--n-gpu-per-node", type=int, default=1, help="number of gpu on each node")
    parser.add_argument("--master-address", type=str, default='localhost', help="address for master")
    parser.add_argument("--node-rank", type=int, default=0, help="the index of node")
    parser.add_argument("--num-proc-node", type=int, default=1, help="The number of nodes in multi node env")
    # The default of 3 reproduces the device this script previously hard-coded.
    parser.add_argument("--gpu", type=int, default=3, help="CUDA device index used for sampling")

    # data
    parser.add_argument("--image-size", type=int, default=256)

    # sample
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--ckpt", type=str, default='./checkpoints/blur-gauss', help="the checkpoint name from which we wish to sample")
    parser.add_argument("--nfe", type=int, default=300, help="sampling steps")
    parser.add_argument("--clip-denoise", action="store_true", help="clamp predicted image to [-1,1] at each")
    parser.add_argument("--use-fp16", action="store_true", help="use fp16 network weight for faster sampling")

    arg = parser.parse_args()

    opt = edict(
        distributed=(arg.n_gpu_per_node > 1),
        device="cuda:{}".format(arg.gpu),
    )
    opt.update(vars(arg))

    # one-time download: ADM checkpoint
    # download_ckpt("data/")

    set_seed(opt.seed)

    if opt.distributed:
        size = opt.n_gpu_per_node

        processes = []
        for rank in range(size):
            # Each process gets its own copy of the options with its rank filled in.
            # NOTE(review): every copy keeps the same opt.device; presumably
            # dist_util.init_processes assigns per-rank devices -- verify.
            opt = copy.deepcopy(opt)
            opt.local_rank = rank
            global_rank = rank + opt.node_rank * opt.n_gpu_per_node
            global_size = opt.num_proc_node * opt.n_gpu_per_node
            opt.global_rank = global_rank
            opt.global_size = global_size
            print('Node rank %d, local proc %d, global proc %d, global_size %d' % (
                opt.node_rank, rank, global_rank, global_size))
            p = Process(target=dist_util.init_processes, args=(global_rank, global_size, main, opt))
            p.start()
            processes.append(p)

        for p in processes:
            p.join()
    else:
        torch.cuda.set_device(opt.gpu)
        opt.global_rank = 0
        opt.local_rank = 0
        opt.global_size = 1
        dist_util.init_processes(0, opt.n_gpu_per_node, main, opt)

