from typing import Tuple, Optional
from dataclasses import dataclass

import numpy as np
import torch

from src.models.gs.cameras import Cameras


@dataclass
class ImageSet:
    """
    Parallel per-image lists plus the matching cameras.

    The container is sized by `image_names`; indexing or iterating it produces one
    dict per image (see `__getitem__`).
    """

    image_names: list

    image_paths: list
    """ Full path to the image file """

    cameras: Cameras

    mask_paths: Optional[list] = None
    """ Full path to the mask file """

    depth_paths: Optional[list] = None
    """ Full path to the depth file """

    normal_paths: Optional[list] = None
    """ Full path to the normal file """

    def __post_init__(self):
        # Only the mask list is normalized: a missing one becomes one `None`
        # placeholder per entry of `image_paths`.
        if self.mask_paths is None:
            self.mask_paths = [None] * len(self.image_paths)

    def __len__(self):
        """ Number of entries in `image_names` """
        return len(self.image_names)

    def __getitem__(self, index):
        """
        Collect every field belonging to `index` into a dict.

        :param index: forwarded unchanged to each underlying list and to `cameras`
        :return: dict with keys "image_name", "image_path", "camera", "mask_path",
            "depth_path" and "normal_path"; an optional entry is None when its
            list is None
        """

        def pick(paths):
            # An optional list that was never provided yields None for every index.
            return None if paths is None else paths[index]

        return {
            "image_name": self.image_names[index],
            "image_path": self.image_paths[index],
            "camera": self.cameras[index],
            "mask_path": pick(self.mask_paths),
            "depth_path": pick(self.depth_paths),
            "normal_path": pick(self.normal_paths),
        }

    def __iter__(self):
        """ Yield the `__getitem__` dict of every index from 0 to len(self) - 1 """
        yield from (self[position] for position in range(len(self)))


@dataclass
class PointCloud:
    """
    Pair of numpy arrays describing a point cloud.

    NOTE(review): both arrays presumably hold one row per point (e.g. N x 3) --
    nothing in this class enforces or checks that; confirm against the parsers.
    """

    # Point coordinates.
    xyz: np.ndarray

    # Point colors; value range and dtype are not fixed here -- TODO confirm.
    rgb: np.ndarray


@dataclass
class DataParserOutputs:
    """
    Bundle returned by a data parser: three image sets, a point cloud, the
    appearance group ids and the camera extent.
    """

    train_set: ImageSet

    val_set: ImageSet

    test_set: ImageSet

    point_cloud: PointCloud

    # ply_path: str

    # Required field, but None is an accepted value.
    appearance_group_ids: Optional[dict]

    camera_extent: Optional[float] = None
    """ Derived from the training cameras in `__post_init__` when left as None """

    def __post_init__(self):
        """
        Fill in `camera_extent` when the caller did not provide one.

        The value is 1.1 times the largest distance between a training camera
        center and the mean of all training camera centers.
        """
        if self.camera_extent is not None:
            return

        # assumes `camera_center` is a tensor with one row per camera -- TODO confirm
        centers = self.train_set.cameras.camera_center
        offsets = centers - torch.mean(centers, dim=0)
        farthest = torch.max(torch.linalg.norm(offsets, dim=-1))
        # Scale while still a tensor, then convert to a Python float.
        self.camera_extent = float(farthest * 1.1)


class DataParser:
    """
    Base class exposing the `get_outputs` entry point.
    """

    def get_outputs(self) -> DataParserOutputs:
        """
        Placeholder implementation: does no work and returns None.

        NOTE(review): presumably concrete parsers override this -- confirm.

        :return: annotated as DataParserOutputs (train/val/test sets, point cloud, ...);
            this base version itself returns None
        """

        return None
