import torch
import torch.nn as nn
from torchvision.utils import save_image
import os
import sys
import tqdm
import argparse
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(parent_dir)
from dataloader import Test_Loader
from MIMOUNet_prior import build_MIMOUnet_net
from utils import same_seed, count_parameters, judge_and_remove_module_dict
from Latent_BlurDM_arch import LE_arch
from Latent_BlurDM_arch import LatentExposureDiffusion as LatentBlurDM
import numpy as np
import cv2
from torchvision import transforms

class Normalize(object):
    """Rescale every non-flow array of a sample dict by 1/255 and shift it.

    With ``ZeroToOne=False`` (the default) the offset 0.5 is subtracted after
    dividing by 255, so [0, 255] input becomes [-0.5, 0.5]; with
    ``ZeroToOne=True`` no offset is applied. The ``'flow'`` entry is never
    modified. The dict is updated in place and also returned.
    """

    def __init__(self, ZeroToOne=False):
        super().__init__()
        self.ZeroToOne = ZeroToOne
        # Offset subtracted after scaling by 1/255.
        self.num = 0.5 if not ZeroToOne else 0

    def __call__(self, data):
        for name in list(data):
            if name == 'flow':
                continue
            scaled = data[name] / 255 - self.num
            # .copy() yields a fresh C-ordered array.
            data[name] = scaled.copy()
        return data

class ToTensor(object):
    """Turn every numpy array in a sample dict into a torch tensor.

    The last axis is moved to the front (HWC -> CHW for images). The dict is
    updated in place and also returned.
    """

    def __call__(self, data):
        for name, array in list(data.items()):
            moved = array.transpose((2, 0, 1))
            # from_numpy shares memory with the numpy view; clone() gives the
            # tensor its own storage.
            data[name] = torch.from_numpy(moved).clone()
        return data
    
@torch.no_grad()
def predict(model, model_le, args, device,
            input_dir='./imgs/demo_input', save_dir='./imgs/demo_output'):
    """Deblur a single image and write the restored result to disk.

    Reads ``args.img_name`` from *input_dir*, reflect-pads it so height and
    width are multiples of 8, runs ``model_le`` to obtain ``z_pred`` and
    ``model`` to restore the image, crops the padding off again and saves the
    result under the same file name in *save_dir*.

    Args:
        model: deblurring network, called as ``model(input, z_pred)``. Its
            output is indexed with ``[2]`` — presumably the full-resolution
            scale of a multi-scale output (TODO confirm).
        model_le: network that produces ``z_pred`` from the padded input.
        args: namespace providing ``img_name``.
        device: torch device the models live on.
        input_dir: directory containing the input image.
        save_dir: directory the output is written to (created if missing).

    Raises:
        FileNotFoundError: if the input image cannot be read.
    """
    model.eval()
    model_le.eval()

    os.makedirs(save_dir, exist_ok=True)

    img_path = os.path.join(input_dir, args.img_name)
    blur = cv2.imread(img_path)
    # cv2.imread signals failure by returning None rather than raising;
    # surface that as a clear error instead of an AttributeError below.
    if blur is None:
        raise FileNotFoundError(f"could not read input image: {img_path}")
    blur = cv2.cvtColor(blur.astype(np.float32), cv2.COLOR_BGR2RGB)

    # Normalize() divides by 255 and subtracts 0.5; ToTensor moves channels first.
    transform = transforms.Compose([Normalize(), ToTensor()])
    sample = transform({'blur': blur})

    inp = sample['blur'].unsqueeze(0).to(device)
    _, _, h, w = inp.shape
    # Pad bottom/right so both spatial sizes are divisible by `factor`;
    # the padding is cropped away again after inference.
    factor = 8
    pad_h = (factor - h % factor) % factor
    pad_w = (factor - w % factor) % factor
    inp = torch.nn.functional.pad(inp, (0, pad_w, 0, pad_h), mode='reflect')

    z_pred = model_le(inp)
    output = model(inp, z_pred)
    output = output[2][:, :, :h, :w].clamp(-0.5, 0.5)

    # Shift from [-0.5, 0.5] back to [0, 1] for save_image.
    save_image(output.squeeze(0).cpu() + 0.5, os.path.join(save_dir, args.img_name))



def _extract_state_dict(checkpoint, candidate_keys):
    """Pick the state dict out of a loaded checkpoint.

    Returns the entry under the first key of *candidate_keys* present in
    *checkpoint*, or *checkpoint* itself when none is present. The result is
    passed through ``judge_and_remove_module_dict`` (presumably strips a
    DataParallel ``module.`` prefix — TODO confirm).
    """
    for key in candidate_keys:
        if key in checkpoint:
            return judge_and_remove_module_dict(checkpoint[key])
    return judge_and_remove_module_dict(checkpoint)


def main():
    """Parse CLI arguments, load both networks and deblur one demo image."""
    # hyperparameters
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", default=1, type=int)
    parser.add_argument("--img_name", default='GOPR0384_11_00-000002.png', type=str)
    parser.add_argument("--model_path", default='./weights/final_deblur_MIMO_UNetPlus_stage3.pth', type=str)
    parser.add_argument("--model_dm_path", default='./weights/final_dm_MIMO_UNetPlus_stage3.pth', type=str)
    # argparse does not check defaults against `choices`; list the default so
    # it can also be requested explicitly on the command line.
    parser.add_argument("--model", default='MIMO-UNetPlusPrior', type=str,
                        choices=['MIMO-UNet', 'MIMO-UNetPlus', 'MIMO-UNetPlusPrior'])
    parser.add_argument("--dataset", default='GoPro', type=str, choices=['GoPro+HIDE', 'GoPro', 'HIDE', 'Realblur_J', 'RealBlur_R', 'RWBI'])
    parser.add_argument("--crop_size", default=None, type=int)

    args = parser.parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device :", device)

    # Models
    net = build_MIMOUnet_net(args.model)
    net_dm = LatentBlurDM()

    # map_location lets checkpoints saved on GPU load on a CPU-only machine.
    # NOTE: torch.load unpickles the file — only load trusted checkpoints.
    load_model_state = torch.load(args.model_path, map_location=device)
    load_le_model_state = torch.load(args.model_dm_path, map_location=device)

    net.load_state_dict(_extract_state_dict(load_model_state, ('model_state', 'model')))
    net_dm.load_state_dict(_extract_state_dict(load_le_model_state, ('model_dm_state',)))

    net = nn.DataParallel(net).to(device)
    net_dm = nn.DataParallel(net_dm).to(device)

    print(f'args: {args}')
    print(f'model parameters: {count_parameters(net)}')

    same_seed(2023)
    predict(net, net_dm, args=args, device=device)


if __name__ == "__main__":
    main()





