import cv2

from .inference import *


def test_func(input_img=".data/input.png", output_path=".data/output.png"):
    """
    Mock super-resolution to test the helper functions without the model.

    Runs the pre-/post-processing helpers imported from ``.inference``
    (``process_img``, ``prepare_input``, ``get_y_channels``,
    ``rev_process_img``). The FSRCNN model is replaced by bicubic
    interpolation, so no model weights are needed.

    The input is first downscaled by 4x. The 4x bicubic upscale therefore
    yields an image of roughly the original size.

    Args:
        input_img: Path of the image to read.
        output_path: Path the upscaled image is written to.

    Raises:
        OSError: If the input image cannot be read or the output image
            cannot be written.
    """
    print("Test function running...")

    img = cv2.imread(input_img)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError(f"Could not read image from {input_img!r}")
    img = cv2.resize(img, (0, 0), fx=0.25, fy=0.25)

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    y, crcb = process_img(img)
    # Bicubic stand-in for FSRCNN: 4x upscale of the luma channel, then add a
    # trailing channel axis (mimicking what FSRCNN does).
    y = cv2.resize(y, (0, 0), fx=4, fy=4, interpolation=cv2.INTER_CUBIC)[..., None]
    # Select the device the same way main() does, so the mock also runs on
    # machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tensor = prepare_input([y], device)

    y = get_y_channels(tensor)[0]
    # cv2.resize expects dsize as (width, height); assuming y is (H, W[, C]),
    # shape[1::-1] gives (W, H).
    crcb = cv2.resize(crcb, y.shape[1::-1], interpolation=cv2.INTER_CUBIC)
    img = rev_process_img(y, crcb)
    sr_img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

    # cv2.imwrite reports failure via its return value rather than raising.
    if not cv2.imwrite(output_path, sr_img):
        raise OSError(f"Could not write image to {output_path!r}")
    print("Output image written to", output_path)


def main(
    input_img=".data/input.png",
    output_path=".data/output.png",
    model_path=".artifacts/pretrained/fsrcnn_x4-T91-97a30bfb.pth.tar",
    scaling_factor=4,
):
    """
    Perform super-resolution on a single image.

    Reads ``input_img``, upscales it with an ``FSRCNNInference`` model loaded
    from ``model_path``, and writes the result to ``output_path``. Runs on
    CUDA when available, otherwise on CPU.

    Args:
        input_img: Path of the image to upscale.
        output_path: Path the super-resolved image is written to.
        model_path: Path of the FSRCNN checkpoint.
        scaling_factor: Upscaling factor passed to the model. Presumably this
            must match the checkpoint (the default one is named "x4");
            verify against FSRCNNInference.

    Raises:
        OSError: If the input image cannot be read or the output image
            cannot be written.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Running on", device.upper())

    # Read the image before loading the model so a bad path fails fast.
    # cv2.imread returns None on failure instead of raising.
    img = cv2.imread(input_img)
    if img is None:
        raise OSError(f"Could not read image from {input_img!r}")

    model = FSRCNNInference(model_path, scaling_factor, device)

    # OpenCV loads BGR. The model is fed a list (batch) of RGB images and
    # returns a list, of which we take the single result.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    sr_img = model([img])[0]
    sr_img = cv2.cvtColor(sr_img, cv2.COLOR_RGB2BGR)

    # cv2.imwrite reports failure via its return value rather than raising.
    if not cv2.imwrite(output_path, sr_img):
        raise OSError(f"Could not write image to {output_path!r}")
    print("Output image written to", output_path)


if __name__ == "__main__":
    # Because of the relative import (`from .inference import *`), this module
    # must be run as part of its package (`python -m <package>.<module>`).
    # Running the file directly as a script fails on that import.
    main()
