import cv2
import torch
import numpy as np

from .model import FSRCNN


class FSRCNNInference:
    """Callable wrapper that super-resolves batches of RGB images with FSRCNN.

    Only the luma (Y) channel goes through the network; the two chroma
    channels are upscaled with bicubic interpolation and merged back in.
    """

    def __init__(self, model_path, scaling_factor=4, device="cpu"):
        """
        Build the model, load its weights and freeze it for inference.

        Args:
            model_path: Checkpoint passed to ``torch.load``. The loaded object
                must contain the weights under the ``"state_dict"`` key.
            scaling_factor: Forwarded to the ``FSRCNN`` constructor.
            device: Anything accepted by ``torch.device`` (e.g. ``"cpu"``,
                ``"cuda:0"``).
        """
        self.device = torch.device(device)
        self.model = FSRCNN(scaling_factor)
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)
        self.model.load_state_dict(checkpoint["state_dict"])
        self.model.to(self.device)
        self.model.eval()
        for p in self.model.parameters():
            p.requires_grad = False

    def __call__(self, imgs: list[np.ndarray]) -> list[np.ndarray]:
        """
        Super-resolve a batch of images.

        Args:
            imgs: List of RGB images. The luma channels are stacked into a
                single batch, so all images must share the same height and
                width.

        Returns:
            A list with one upscaled RGB image per input image, in input
            order. An empty input yields an empty list.
        """
        y_list, chroma_list = [], []
        for img in imgs:
            y, chroma = process_img(img)
            y_list.append(y)
            chroma_list.append(chroma)

        # Nothing to batch: stacking an empty list would raise.
        if not y_list:
            return []

        in_tensor = prepare_input(y_list, self.device)
        # torch.autocast expects the bare device type ("cuda"), not an indexed
        # device string such as "cuda:0".
        with torch.no_grad(), torch.autocast(self.device.type):
            out_tensor = self.model(in_tensor).clamp(0, 1.)
        # Autocast may produce a reduced-precision output (bfloat16 on CPU),
        # which cannot be converted to a NumPy array; go back to float32.
        out_tensor = out_tensor.float()

        y_list = get_y_channels(out_tensor)

        out_imgs = []
        for y, chroma in zip(y_list, chroma_list):
            # cv2.resize takes the target size as (width, height).
            chroma = cv2.resize(chroma, y.shape[1::-1], interpolation=cv2.INTER_CUBIC)
            out_imgs.append(rev_process_img(y, chroma))

        return out_imgs


def process_img(img: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """
    Split an RGB image into luma and chroma planes for FSRCNN.

    The image is scaled to [0, 1], converted with a fixed 3x3 matrix plus the
    offsets (16, 128, 128), and divided by 255 again.

    Args:
        img: RGB image with the colour channels on the last axis.

    Returns:
        A ``(Y, CrCb)`` pair. ``Y`` is the first converted channel, kept as a
        trailing axis of size 1 and cast to float32. ``CrCb`` holds the
        remaining two channels as float64; with the coefficients used here
        they appear to be in Cb, Cr order despite the name.
    """
    rgb_to_ycbcr = np.array([
        [65.481, -37.797, 112.0],
        [128.553, -74.203, -93.786],
        [24.966, 112.0, -18.214],
    ])
    offsets = np.array([16.0, 128.0, 128.0])

    unit_rgb = img.astype(np.float32) / 255
    ycbcr = (unit_rgb @ rgb_to_ycbcr + offsets) / 255

    luma = ycbcr[..., :1].astype(np.float32)
    chroma = ycbcr[..., 1:]
    return luma, chroma


def rev_process_img(Y: np.ndarray, CrCb: np.ndarray) -> np.ndarray:
    """
    Convert separated luma/chroma planes back to an 8-bit RGB image.

    Args:
        Y: Luma plane with a trailing channel axis of size 1, values in [0, 1].
        CrCb: The two chroma planes on the last axis, values in [0, 1], with
            the same height and width as ``Y``.

    Returns:
        The RGB image as uint8. Results are rounded to the nearest integer and
        clipped to [0, 255] before the cast.
    """
    # np.concatenate is available in every NumPy release (np.concat is 2.0+).
    ycbcr = np.concatenate([Y, CrCb], axis=2) * 255
    to_rgb = np.array([
        [0.00456621, 0.00456621, 0.00456621],
        [0, -0.00153632, 0.00791071],
        [0.00625893, -0.00318811, 0],
    ])
    rgb = ycbcr @ to_rgb * 255.0 + [-222.921, 135.576, -276.836]
    # Inputs spanning the full [0, 1] range map outside [0, 255]; a bare
    # uint8 cast would wrap those values around and truncate instead of
    # rounding (e.g. 254.9997 -> 254).
    return np.clip(np.rint(rgb), 0, 255).astype(np.uint8)


def prepare_input(y_chls: list[np.ndarray], device = "cpu") -> torch.Tensor:
    """
    Batch Y-channel arrays into a single tensor for the FSRCNN.

    Args:
        y_chls: Y-channels of the input images. Each array must be 3-D and
            all must have the same shape, since they are stacked along a new
            leading axis.
        device: Device the resulting tensor is moved to.

    Returns:
        A 4-D tensor on ``device`` whose last stacked axis has been moved to
        position 1 (channel-last to channel-first layout).
    """
    batch = torch.from_numpy(np.stack(y_chls, axis=0))
    on_device = batch.to(torch.device(device))
    # (N, H, W, C) -> (N, C, H, W)
    return on_device.permute(0, 3, 1, 2)


def get_y_channels(y_tensor: torch.Tensor) -> list[np.ndarray]:
    """
    Split a batched model output into per-image NumPy arrays.

    Args:
        y_tensor: 4-D tensor in channel-first layout (N, C, H, W).

    Returns:
        A list of N arrays in channel-last layout (H, W, C). bfloat16 input
        is returned as float32; other dtypes are kept.
    """
    # NumPy has no bfloat16 dtype, so Tensor.numpy() raises on such tensors
    # (which autocast can produce); upcast them losslessly first.
    if y_tensor.dtype == torch.bfloat16:
        y_tensor = y_tensor.float()
    # detach() lets this work even if the tensor is still attached to a graph.
    y_stack = y_tensor.detach().permute(0, 2, 3, 1).cpu().numpy()
    return list(y_stack)
