from model import util
import os
from glob import glob
import argparse
import torch
from model.hiir_unet import HiIRUnet


def get_data(testset):
    """Return the sorted paths of all ``.tif`` images of a test set.

    Args:
        testset: Name of the sub-directory of ``./dataset`` to scan.

    Returns:
        A lexicographically sorted list of path strings matching
        ``./dataset/<testset>/*.tif``; empty when nothing matches.
    """
    pattern = f"./dataset/{testset}/*.tif"
    matches = glob(pattern)
    matches.sort()
    return matches


def load_model(color: str = "color", noise_level: int = 50):
    """Build the HiIRUnet "Base" denoiser and load its pretrained weights.

    Args:
        color: Tag inserted into the checkpoint file name
            (e.g. ``"color"`` or ``"gray"``).
        noise_level: Noise sigma inserted into the checkpoint file name.

    Returns:
        The ``HiIRUnet`` instance with the weights from
        ``./model_zoo/dn_{color}_c_sigma{noise_level}.ckpt`` loaded. Moving
        the model to the target device is left to the caller.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        RuntimeError: If the checkpoint keys do not match the model exactly
            (``strict=True``).
    """

    # HiIRUnet Base
    model = HiIRUnet(
        img_size=256,
        in_channels=3,
        out_channels=3,
        embed_dim=128,
        depths="4+4+4+4",
        num_heads="4+4+8+16",
        expansion_ratio="1+1+2+4",
        window_size=32,
        grid_size=16,
        global_size=128,
        shift=True,
        qkv_conv=False,
        qk_reduce=True,
        version="v2",
        mlp_ratio=2.0,
        mlp_type="locality",
        mlp_kernel_size=3,
        multiple_degradation=False,
        subsample_type="simple",
        dual_conv=False,
        # True for dual-pixel defocus deblurring only; that task also needs
        # in_channels=6.
        dual_pixel_task=False,
        fairscale_checkpoint=True,
    )

    # Imported lazily so the analysis helper is only required when a model is
    # actually built.
    from model.common import model_analysis
    model_analysis(model)

    checkpoint_path = f"./model_zoo/dn_{color}_c_sigma{noise_level}.ckpt"
    # Remap all storages to the CPU while deserializing: without this, a
    # checkpoint saved from CUDA tensors cannot be loaded on a CPU-only host.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict, strict=True)

    return model


def main(args):
    """Denoise every image of the selected test set and save the results.

    Each image returned by ``get_data`` is read, corrupted with noise from
    ``util.add_noise`` at ``args.noise_level``, passed through the model and
    written to ``./results/dn`` under its original file name.

    Args:
        args: Parsed command-line namespace providing ``test_set``,
            ``noise_level``, ``color`` and ``data_range``.
    """
    # Data
    img_list = get_data(args.test_set)
    # Inference is pinned to the CPU. To use a GPU when one is available:
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device = "cpu"
    # Model
    model = load_model("color" if args.color == "rgb" else "gray", args.noise_level)
    model = model.to(device)
    # Modules are created in training mode; switch to evaluation mode so that
    # any dropout / normalisation layers behave deterministically.
    model.eval()

    # The output directory is the same for every image, so create it once.
    # Note this now happens even when the test set contains no images.
    save_path = "./results/dn"
    os.makedirs(save_path, exist_ok=True)

    # No gradients are needed for testing; disabling autograd avoids keeping
    # intermediate activations alive.
    with torch.no_grad():
        for img_path in img_list:
            img_gt = util.imread_uint(img_path, n_channels=3)
            noise = util.add_noise(img_gt, args.noise_level, img_path)
            img_lq = util.uint2tensor4(img_gt + noise, args.data_range)
            img_lq = img_lq.to(device)

            img_hq = model(img_lq)
            img_hq = util.tensor2uint(img_hq, args.data_range)

            util.imsave(img_hq, os.path.join(save_path, os.path.basename(img_path)))


if __name__ == "__main__":
    # Script entry point: declare the command-line options from a table of
    # (flag, default, type) triples, parse them and run the test loop.
    cli = argparse.ArgumentParser("Hi-IR Dn test code")
    option_table = (
        ("--data_range", 1.0, float),
        ("--test_set", "McMaster", str),
        ("--noise_level", 50, int),
        ("--color", "rgb", str),
    )
    for flag, default_value, caster in option_table:
        cli.add_argument(flag, default=default_value, type=caster)
    args = cli.parse_args()
    main(args)
