from model import util
import os
from glob import glob
import argparse
import torch
from model.hiir import HiIR


def get_data(testset, scale):
    """Return the sorted list of low-resolution PNG paths for a test set.

    Images are looked up in ``./dataset/{testset}_X{scale}/`` (relative to the
    current working directory). An empty list is returned when the folder is
    missing or contains no ``.png`` files.
    """
    pattern = "./dataset/{}_X{}/*.png".format(testset, scale)
    matches = glob(pattern)
    return sorted(matches)


def load_model(upscale=4, version="base", map_location="cpu"):
    """Build a Hi-IR super-resolution network and load its pretrained weights.

    Args:
        upscale: Super-resolution factor. It is passed to the network and also
            selects the checkpoint ``./model_zoo/sr_{version}_c_x{upscale}.ckpt``.
        version: Model variant; only ``"base"`` is implemented.
        map_location: Forwarded to ``torch.load`` so checkpoint tensors are
            materialized on this device regardless of where they were saved.
            Defaults to ``"cpu"`` so loading also works on machines without a
            GPU; move the model afterwards with ``model.to(device)``.

    Returns:
        The ``HiIR`` module with the checkpoint loaded using strict key matching.

    Raises:
        NotImplementedError: If ``version`` is not a supported variant.
    """
    if version == "base":
        # Hi-IR Base
        model = HiIR(
            upscale=upscale,
            window_size=16,
            grid_size=16,
            global_size=128,
            img_size=64,
            img_range=1.0,
            depths=[6, 6, 6, 6, 6, 6],
            embed_dim=180,
            num_heads=[6, 6, 6, 6, 6, 6],
            mlp_ratio=2,
            conv_type="3conv",
            shift=True,
            upsampler="pixelshuffle",
        )
    else:
        raise NotImplementedError(f"Model version {version} not implemented!")

    checkpoint_path = f"./model_zoo/sr_{version}_c_x{upscale}.ckpt"
    # Without map_location, tensors saved from a CUDA device cannot be
    # deserialized on a CPU-only machine.
    state_dict = torch.load(checkpoint_path, map_location=map_location)
    # NOTE(review): assumes the file holds a bare state_dict (not wrapped in
    # e.g. {"state_dict": ...}); strict loading would fail otherwise.
    model.load_state_dict(state_dict, strict=True)

    return model

def main(args):
    """Run Hi-IR super-resolution over a test set and save the outputs.

    Low-resolution ``*.png`` inputs are taken from ``get_data`` (i.e.
    ``./dataset/{test_set}_X{scale}/``). Each result is written under
    ``./results/sr`` with the same file name as its input.

    Args:
        args: Parsed CLI namespace providing ``test_set``, ``scale``,
            ``model_size`` and ``data_range``.
    """
    # Data
    img_list = get_data(args.test_set, args.scale)
    if not img_list:
        # Fail loudly instead of silently doing nothing on a wrong path/name.
        print(
            f"No .png images found for test set '{args.test_set}' "
            f"at scale x{args.scale}."
        )
        return

    # Inference runs on CPU; pick a CUDA device here to run on a GPU instead.
    device = "cpu"

    # Model: eval() disables train-only behaviour (e.g. dropout), if any.
    model = load_model(args.scale, args.model_size)
    model = model.to(device)
    model.eval()

    # Output directory is loop-invariant: create it once, race-free.
    save_path = "./results/sr"
    os.makedirs(save_path, exist_ok=True)

    # Pure inference: skip building the autograd graph to save memory/time.
    with torch.no_grad():
        for img_path in img_list:
            img_lr = util.imread_uint(img_path, n_channels=3)
            img_lr = util.uint2tensor4(img_lr, args.data_range)
            img_lr = img_lr.to(device)

            img_sr = model(img_lr)
            img_sr = util.tensor2uint(img_sr, args.data_range)

            util.imsave(img_sr, os.path.join(save_path, os.path.basename(img_path)))


if __name__ == "__main__":
    # CLI options as (flag, default, type) triples, registered in order.
    cli = argparse.ArgumentParser("Hi-IR SR test code")
    for flag, default_value, value_type in (
        ("--data_range", 1.0, float),
        ("--test_set", "Set5", str),
        ("--scale", 4, int),
        ("--model_size", "base", str),
    ):
        cli.add_argument(flag, default=default_value, type=value_type)
    main(cli.parse_args())
