from pathlib import Path
from typing import Union, Optional
import numpy as np
import os
from enum import IntEnum, StrEnum, Enum, auto
from itertools import product
import jax

class ArrayBackend(Enum):
    """Selects which array library DOTmarkLoader returns images as."""

    # Explicit values match what ``auto()`` assigns for a plain Enum:
    # 1, 2, ... in declaration order.
    NUMPY = 1
    JAX = 2

class DOTmarkResolution(IntEnum):
    """Resolutions available in the DOTmark data set.

    Each value is the number that appears in the CSV file names
    (``data<res>_10NN.csv``); all of them are powers of two.
    """

    VERY_LOW = 1 << 5   # 32
    LOW = 1 << 6        # 64
    MEDIUM = 1 << 7     # 128
    HIGH = 1 << 8       # 256
    VERY_HIGH = 1 << 9  # 512

class DOTmarkClass(StrEnum):
    Cauchy_Density = "CauchyDensity"
    Classic_Images = "ClassicImages"
    GRF_moderate = "GRFmoderate"
    GRF_rough = "GRFrough"
    GRF_smooth = "GRFsmooth"
    Log_GRF = "LogGRF"
    Logit_GRF = "LogitGRF"
    Microscopy_Images = "MicroscopyImages"
    Shapes = "Shapes"
    White_Noise = "WhiteNoise"


class DOTmarkLoader:
    """Eagerly load DOTmark benchmark images from their CSV files.

    Every requested (class, resolution) combination is read in ``__init__``
    and cached. The ``as_*`` accessors return filtered views of that cache.
    Images are read from
    ``<data_folder>/DOTmark_1.0/Data/<class>/data<res>_10<NN>.csv`` for
    ``NN`` in ``01`` .. ``10``.
    """

    # Accepted raw values. DOTmarkClass is a StrEnum and DOTmarkResolution an
    # IntEnum, so both enum members and plain str/int values pass the
    # membership checks against these lists.
    valid_class = [class_.value for class_ in DOTmarkClass]
    valid_resolution = [res.value for res in DOTmarkResolution]

    # Each (class, resolution) set contains images with indices 1..10.
    _IMAGES_PER_SET = 10

    def __init__(
        self,
        dot_class: DOTmarkClass | list[DOTmarkClass],
        resolution: DOTmarkResolution | list[DOTmarkResolution],
        normalize: bool = False,
        array_backend: ArrayBackend = ArrayBackend.NUMPY,
        data_folder: str | os.PathLike | None = None,
    ) -> None:
        """Validate the selection and load every requested image.

        Args:
            dot_class: One DOTmark class or a list of them. Plain class-name
                strings are accepted as well.
            resolution: One resolution or a list of them. Plain ints are
                accepted as well.
            normalize: If True, each image is divided by its total sum.
            array_backend: ``NUMPY`` keeps the ``np.loadtxt`` result (int
                dtype unless normalized). ``JAX`` converts each image to a
                float32 ``jax.Array``.
            data_folder: Root directory containing ``DOTmark_1.0``. Defaults
                to ``../../../data`` relative to this source file.

        Raises:
            ValueError: If a class or resolution is not a valid DOTmark value.
            Errors raised while reading the CSV files propagate unchanged.
        """
        if data_folder is None:
            # Historical default: a ``data`` directory three levels up.
            data_folder = os.path.join(os.path.dirname(__file__), "../../../data")
        self.data_folder: str = os.path.normpath(data_folder)
        classes = self._validate_class(dot_class)
        resolutions = self._validate_resolution(resolution)
        # Both settings must be set before loading: load_image reads them.
        self.normalize = normalize
        self.array_backend = array_backend
        # Nested cache: {class: {resolution: [image_1, ..., image_10]}}.
        self._data = {
            class_: {res: self.load_resolution(class_, res) for res in resolutions}
            for class_ in classes
        }

    def as_list(
        self,
        dot_class: DOTmarkClass | list[DOTmarkClass] | None = None,
        resolution: DOTmarkResolution | list[DOTmarkResolution] | None = None,
    ) -> dict:
        """Return loaded images as ``{class: {resolution: [image, ...]}}``.

        Args:
            dot_class: Restrict the result to these classes. ``None`` keeps
                every loaded class.
            resolution: Restrict the result to these resolutions. ``None``
                keeps every loaded resolution.

        Returns:
            New outer and inner dicts on every call, so mutating the result
            does not alter the loader's cache. The image lists themselves are
            shared with the cache. Classes or resolutions that were never
            loaded are simply absent.

        Raises:
            ValueError: If a filter value is not a valid DOTmark value.
        """
        classes = self.valid_class if dot_class is None else self._validate_class(dot_class)
        resolutions = (
            self.valid_resolution if resolution is None else self._validate_resolution(resolution)
        )
        return {
            class_: {res: images for res, images in class_dict.items() if res in resolutions}
            for class_, class_dict in self._data.items()
            if class_ in classes
        }

    def as_pairwise_list(
        self,
        dot_class: DOTmarkClass | list[DOTmarkClass] | None = None,
        resolution: DOTmarkResolution | list[DOTmarkResolution] | None = None,
    ) -> dict:
        """Like :meth:`as_list`, but each image list becomes a list of ordered pairs.

        For ``k`` images, every ``(image_i, image_j)`` with ``i != j`` is
        produced (``k * (k - 1)`` pairs), with ``i`` varying slowest.
        """
        data = self.as_list(dot_class, resolution)
        return {
            class_: {res: self._pairwise_list(images) for res, images in class_dict.items()}
            for class_, class_dict in data.items()
        }

    def as_pairwise_dict(
        self,
        dot_class: DOTmarkClass | list[DOTmarkClass] | None = None,
        resolution: DOTmarkResolution | list[DOTmarkResolution] | None = None,
    ) -> dict:
        """Like :meth:`as_pairwise_list`, but pairs are keyed by their positions.

        Keys ``(i, j)`` are 0-based positions in the image list, so position
        ``i`` corresponds to file index ``i + 1``.
        """
        data = self.as_list(dot_class, resolution)
        return {
            class_: {res: self._pairwise_dict(images) for res, images in class_dict.items()}
            for class_, class_dict in data.items()
        }

    @staticmethod
    def _index_pairs(n: int):
        """Yield every ordered (i, j) with i != j over range(n), in lexicographic order."""
        return ((i, j) for i, j in product(range(n), repeat=2) if i != j)

    def _pairwise_list(self, array) -> list[tuple[np.ndarray, np.ndarray]]:
        """Return all ordered pairs of distinct-position elements of ``array``."""
        return [(array[i], array[j]) for i, j in self._index_pairs(len(array))]

    def _pairwise_dict(self, array) -> dict[tuple[int, int], tuple[np.ndarray, np.ndarray]]:
        """Return ``{(i, j): (array[i], array[j])}`` for every ordered pair with i != j."""
        return {(i, j): (array[i], array[j]) for i, j in self._index_pairs(len(array))}

    def _validate_class(self, dot_class: DOTmarkClass | list[DOTmarkClass]) -> list[DOTmarkClass]:
        """Return ``dot_class`` as a list, raising if any entry is not a DOTmark class.

        A single str (a DOTmarkClass member is one) is wrapped in a list. Any
        other iterable is copied into a new list, so a one-shot iterator is
        not exhausted by the validity check below.
        """
        classes = [dot_class] if isinstance(dot_class, str) else list(dot_class)
        if not set(classes).issubset(self.valid_class):
            raise ValueError(f"Invalid DOTmark class: {classes}")
        return classes

    def _validate_resolution(
        self, resolution: DOTmarkResolution | list[DOTmarkResolution]
    ) -> list[DOTmarkResolution]:
        """Return ``resolution`` as a list, raising if any entry is not a DOTmark resolution.

        A single Python or NumPy integer is wrapped in a list. Any other
        iterable is copied into a new list.
        """
        if isinstance(resolution, (int, np.integer)):
            resolutions = [resolution]
        else:
            resolutions = list(resolution)
        if not set(resolutions).issubset(self.valid_resolution):
            raise ValueError(f"Invalid DOTmark resolution: {resolutions}")
        return resolutions

    def _class_folder(self, class_name: str) -> Path:
        """Return the directory holding the CSV files of one DOTmark class."""
        return Path(self.data_folder) / "DOTmark_1.0" / "Data" / class_name

    def _load_image(self, name: str, res: int, index: int) -> np.ndarray:
        """Read image ``index`` of class ``name`` at resolution ``res`` as an int array."""
        # File names embed the 1-based index zero-padded to two digits after
        # a literal "10", e.g. data32_1001.csv .. data32_1010.csv.
        file_path = self._class_folder(name) / f"data{int(res)}_10{index:02d}.csv"
        return np.loadtxt(file_path, delimiter=",", dtype=int)

    def load_image(self, name: DOTmarkClass, res: DOTmarkResolution, index: int) -> np.ndarray | jax.Array:
        """Load one image and apply the configured backend and normalization.

        With the JAX backend the image is converted to float32 before any
        normalization. With normalization enabled the image is divided by its
        total sum.
        """
        image = self._load_image(name, res, index)
        if self.array_backend == ArrayBackend.JAX:
            image = jax.numpy.asarray(image).astype(jax.numpy.float32)
        if self.normalize:
            image = image / image.sum()
        return image

    def load_resolution(
        self, dot_class: DOTmarkClass, resolution: DOTmarkResolution
    ) -> list[np.ndarray | jax.Array]:
        """Load all images (indices 1..10) of one class at one resolution, in index order."""
        return [
            self.load_image(name=dot_class, res=resolution, index=idx)
            for idx in range(1, self._IMAGES_PER_SET + 1)
        ]


if __name__ == "__main__":
    # Demo: load the 32 and 64 resolution microscopy images as normalized JAX
    # arrays, then print every ordered image pair at resolution 32.
    loader = DOTmarkLoader(
        dot_class=DOTmarkClass.Microscopy_Images,
        resolution=[32, 64],
        normalize=True,
        array_backend=ArrayBackend.JAX,
    )
    pairs_32 = loader.as_pairwise_dict(resolution=32)
    print(pairs_32)
