from dataclasses import dataclass
from pathlib import Path
from cv2 import Mat, cvtColor, COLOR_RGB2BGR, COLOR_RGB2YUV, COLOR_BGR2RGB, COLOR_BGR2YUV, COLOR_YUV2BGR, COLOR_YUV2RGB
from warnings import warn 

# Keys of the four image roles held by SRImage ('LR', 'SR', 'HR', 'RF').
IMAGE_CLASS = ['LR', 'SR', 'HR', 'RF']
# Colour formats accepted by SRImage.convert.
AVAIABLE_FORMATS = ['bgr', 'rgb', 'yuv']
# Correctly spelled alias of AVAIABLE_FORMATS (same list object); the misspelled
# name is kept so existing imports keep working.
AVAILABLE_FORMATS = AVAIABLE_FORMATS

# LR stands for 'Low Resolution' image
# HR stands for 'High Resolution' image
# SR stands for 'Super-Resolution' image - the result of the function SR(LR)
# RF stands for 'ReFerence' image, which is used as an alternative to SR

class SRImage:
    """Bundle of the LR / SR / HR / RF images that belong to one sample.

    Each image can be supplied as a file path, as an in-memory ``cv2.Mat``, or
    both. Paths and matrices that were not supplied are stored as ``None``.
    All matrices of one instance are expected to share the colour format
    recorded in ``self.format`` (one of ``AVAIABLE_FORMATS``).
    """

    path_LR: Path
    path_HR: Path
    path_RF: Path
    path_SR: Path
    mat_LR: Mat
    mat_HR: Mat
    mat_RF: Mat
    mat_SR: Mat
    format: str

    # Class-wide switch: the "format not provided" warning is emitted only for
    # the first instance created without an explicit format.
    format_warning = True

    # cv2 colour-conversion codes, indexed as _CONVERSION_CODES[format_from][format_to].
    # Built once here instead of on every convert_image call.
    _CONVERSION_CODES = {
        'rgb': {'yuv': COLOR_RGB2YUV, 'bgr': COLOR_RGB2BGR},
        'yuv': {'rgb': COLOR_YUV2RGB, 'bgr': COLOR_YUV2BGR},
        'bgr': {'yuv': COLOR_BGR2YUV, 'rgb': COLOR_BGR2RGB},
    }

    def __init__(
            self,
            path_LR : Path = None, path_SR : Path = None, path_HR : Path = None, path_RF : Path = None,
            mat_LR : Mat = None, mat_SR : Mat = None, mat_HR : Mat = None, mat_RF : Mat = None,
            format : str = None
        ):
        """Store the supplied paths / matrices and the colour format.

        Args:
            path_LR, path_SR, path_HR, path_RF: optional file paths (str or Path),
                normalised to ``Path``; falsy / missing paths are stored as ``None``.
            mat_LR, mat_SR, mat_HR, mat_RF: optional image matrices, wrapped in
                ``cv2.Mat``; missing matrices are stored as ``None``.
            format: colour format of the matrices. When omitted, ``'bgr'`` is
                used and a warning is emitted (once per class).

        Raises:
            ValueError: if not a single path or matrix was supplied.
        """
        paths = (path_LR, path_SR, path_HR, path_RF)
        mats = (mat_LR, mat_SR, mat_HR, mat_RF)
        is_path_provided = any(p is not None for p in paths)
        # Any single matrix is enough, mirroring the path check (the previous
        # expression required all four matrices to be present).
        is_mat_provided = any(m is not None for m in mats)

        if not (is_path_provided or is_mat_provided):
            raise ValueError("Niether path or cv2.Mat provided")

        # Missing paths become None (like the matrices) instead of leaving the
        # attribute undefined, which made later access raise AttributeError.
        self.path_LR = Path(path_LR) if path_LR else None
        self.path_SR = Path(path_SR) if path_SR else None
        self.path_HR = Path(path_HR) if path_HR else None
        self.path_RF = Path(path_RF) if path_RF else None

        self.mat_LR = Mat(mat_LR) if mat_LR is not None else None
        self.mat_HR = Mat(mat_HR) if mat_HR is not None else None
        self.mat_RF = Mat(mat_RF) if mat_RF is not None else None
        self.mat_SR = Mat(mat_SR) if mat_SR is not None else None

        if (format is None) and self.format_warning:
            warn("Image format didn't provided, please specify format (default='bgr')")
            # Clear the flag on the class, not on this fresh instance; otherwise
            # the "warn once" switch never takes effect.
            type(self).format_warning = False

        self.format = format if format else 'bgr'

    @staticmethod
    def convert_image(image : Mat, format_from : str, format_to : str, dst : Mat = None):
        """Convert ``image`` from ``format_from`` to ``format_to`` with cv2.cvtColor.

        Args:
            image: source matrix.
            format_from: current colour format of ``image`` ('bgr', 'rgb' or 'yuv').
            format_to: target colour format ('bgr', 'rgb' or 'yuv').
            dst: optional output matrix forwarded to ``cvtColor``.

        Returns:
            The converted matrix. When both formats are equal, ``image`` itself
            is returned (or copied into ``dst`` when a distinct ``dst`` is given).

        Raises:
            KeyError: if either format name is not in the conversion table.
        """
        if format_from == format_to:
            # cvtColor has no identity code; converting to the same format is a no-op.
            if dst is None or dst is image:
                return image
            dst[...] = image
            return dst

        code = SRImage._CONVERSION_CODES[format_from][format_to]
        return cvtColor(image, code, dst=dst)

    def convert(self, format : str, var : str = None):
        """Convert the stored matrices to ``format``.

        Without ``var``, every stored matrix is converted and ``self.format``
        is updated. With ``var`` (one of 'LR', 'SR', 'HR', 'RF'), only that
        matrix is converted and ``self.format`` is left unchanged. That matrix
        then no longer matches ``self.format``.

        Matrices that were never supplied (``None``) are skipped. Conversion is
        attempted in place (``dst`` = the source matrix), and the result is
        always stored back on the instance.

        Raises:
            ValueError: if ``format`` is not in ``AVAIABLE_FORMATS`` or ``var``
                is not a valid image name.
        """
        if format not in AVAIABLE_FORMATS:
            raise ValueError(f"Incorect image formate, choose from {', '.join(AVAIABLE_FORMATS)}")

        local_convert_names = IMAGE_CLASS

        if var:
            if var not in local_convert_names:
                raise ValueError("Provide var name from ('LR', 'SR', 'HR', 'RF')")
            local_convert_names = [var]

        for img_class in local_convert_names:
            attr = f'mat_{img_class}'
            img = getattr(self, attr)
            if img is None:
                # Image was not supplied; nothing to convert.
                continue
            # Store the returned matrix: cvtColor may allocate a new output
            # instead of writing into dst, and that result must not be lost.
            setattr(self, attr, SRImage.convert_image(img, self.format, format, dst=img))

        if not var:
            self.format = format