# -----------------------------------------------------------------------------------
# SemanIR: Sharing Key Semantics in Transformer Makes Efficient Image Restoration
# -----------------------------------------------------------------------------------


from model import util
import os
from glob import glob
import argparse
import torch
from model.semanir_unet import SemanIRUnet


def get_data(testset):
    """Return the sorted paths of all ``.tif`` files in ``./dataset/<testset>/``."""
    pattern = f"./dataset/{testset}/*.tif"
    tif_paths = glob(pattern)
    # glob gives no ordering guarantee; sort so runs are reproducible.
    tif_paths.sort()
    return tif_paths


def load_model(color="color", noise_level=50):
    """Build the SemanIRUnet denoiser and load pretrained weights when available.

    Args:
        color: Tag inserted into the checkpoint file name
            (``./model_zoo/dn_{color}_c_sigma{noise_level}.ckpt``).
        noise_level: Noise sigma inserted into the checkpoint file name.

    Returns:
        The constructed ``SemanIRUnet``. If the checkpoint file does not exist,
        the model is returned with its freshly initialized weights and a
        ``UserWarning`` is emitted.
    """
    import warnings

    # SemanIRUnet
    model = SemanIRUnet(
        img_size=128,
        in_channels=3,
        out_channels=3,
        dim=48,
        window_size=8,
        top_k=16,
        num_blocks=[4, 6, 6, 8],
        num_refinement_blocks=4,
        heads=[1, 2, 4, 8],
        mlp_ratio=4,
        bias=True,
        dual_pixel_task=False,  ## True for dual-pixel defocus deblurring only. Also set in_channels=6
        version="v2"
    )

    from model.common import model_analysis
    model_analysis(model)

    checkpoint_path = f"./model_zoo/dn_{color}_c_sigma{noise_level}.ckpt"
    if os.path.isfile(checkpoint_path):
        # Remap storages to CPU so a checkpoint saved from a GPU run still loads
        # on a CPU-only machine; the caller moves the model to its target device.
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)
    else:
        # Keep the best-effort behavior (return the untrained model) but make it
        # visible: outputs produced without pretrained weights are meaningless.
        warnings.warn(
            f"Checkpoint not found at {checkpoint_path}; "
            "using randomly initialized weights.",
            stacklevel=2,
        )

    return model


def main(args):
    """Denoise every image of the selected test set and write the results.

    For each ``.tif`` returned by ``get_data(args.test_set)``, noise produced by
    ``util.add_noise`` is added to the image, the model is run on the result,
    and the output is saved under ``./results/dn`` with the input's file name.

    Args:
        args: Parsed command-line namespace providing ``test_set``,
            ``noise_level``, ``color`` and ``data_range``.
    """
    # Data
    img_list = get_data(args.test_set)
    # Inference is pinned to the CPU.
    device = "cpu"
    # Model: "rgb" selects the color checkpoint, anything else the gray one.
    # NOTE(review): images are still read with n_channels=3 below even when the
    # gray checkpoint is selected -- confirm this is intended.
    model = load_model("color" if args.color == "rgb" else "gray", args.noise_level)
    model = model.to(device)
    # Switch train-only layers (dropout, norm statistics) to inference behavior.
    model.eval()

    # Create the output directory once instead of re-checking for every image.
    save_path = "./results/dn"
    os.makedirs(save_path, exist_ok=True)

    for img_path in img_list:
        img_gt = util.imread_uint(img_path, n_channels=3)
        noise = util.add_noise(img_gt, args.noise_level, img_path)
        # NOTE(review): assumes util.add_noise returns an array whose sum with
        # img_gt does not wrap around (e.g. float noise) -- verify in util.
        img_lq = util.uint2tensor4(img_gt + noise, args.data_range)
        img_lq = img_lq.to(device)

        # No gradients are needed for testing; skip building the autograd graph.
        with torch.no_grad():
            img_hq = model(img_lq)
        img_hq = util.tensor2uint(img_hq, args.data_range)

        util.imsave(img_hq, os.path.join(save_path, os.path.basename(img_path)))


if __name__ == "__main__":
    # Command-line options, declared as (flag, default, type) in display order.
    cli = argparse.ArgumentParser(prog="SemanIR Dn test code")
    option_table = (
        ("--data_range", 1.0, float),
        ("--test_set", "McMaster", str),
        ("--noise_level", 50, int),
        ("--color", "rgb", str),
    )
    for flag, default_value, value_type in option_table:
        cli.add_argument(flag, default=default_value, type=value_type)
    args = cli.parse_args()
    main(args)
