from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Sequence

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset


# Sequence directory names (looked up directly under the dataset root) that
# make up each split; LusnarDepthDataset._split_to_sequences maps a split
# name to one of these tuples.
TRAIN_SEQS = ("Moon_2", "Moon_3", "Moon_4", "Moon_5", "Moon_7", "Moon_9")
VAL_SEQS = ("Moon_6",)
TEST_SEQS = ("Moon_1", "Moon_8")

# Depth value that is always excluded from the valid mask built in
# LusnarDepthDataset.__getitem__. 65504.0 is the largest finite IEEE 754
# half-precision value; presumably the dataset writes it for pixels without
# a depth return -- TODO confirm against the dataset documentation.
INVALID_DEPTH_VALUE = 65504.0
# Default for LusnarDepthDataset's ``max_depth`` argument. The value is only
# stored on the instance in this file; it is not applied to the depth map or
# the valid mask here. Units are meters going by the name -- TODO confirm.
MAX_DEPTH_METERS = 50.0


def read_pfm(path: str) -> np.ndarray:
    """Read a PFM (Portable Float Map) file into a float32 array.

    The header consists of a magic line (``PF`` for 3-channel, ``Pf`` for
    1-channel data), a ``width height`` line and a scale line whose sign
    selects the byte order (negative means little-endian). Blank lines and
    lines starting with ``#`` between the magic and the dimensions are
    skipped.

    Args:
        path: Location of the PFM file.

    Returns:
        A float32 array of shape ``(height, width)`` for ``Pf`` files or
        ``(height, width, 3)`` for ``PF`` files. Samples are multiplied by
        the absolute value of the header scale.

    Raises:
        ValueError: If the magic is not ``PF``/``Pf``, the header is
            truncated or malformed, or the payload holds fewer samples
            than the header promises.

    NOTE(review): the PFM format stores rows bottom-to-top, and this reader
    returns rows in file order without a vertical flip. Confirm that this
    matches how the files being loaded were written.
    """
    with open(path, "rb") as f:
        header = f.readline().decode("ascii", errors="ignore").strip()
        if header not in {"PF", "Pf"}:
            raise ValueError(f"Not a PFM file: {path}")
        channels = 3 if header == "PF" else 1

        # Skip blank and comment lines ahead of the dimensions. readline()
        # returns b"" only at EOF, so that case must end the loop; otherwise
        # a truncated header would spin here forever.
        while True:
            raw = f.readline()
            if not raw:
                raise ValueError(f"Truncated PFM header in {path}")
            dims = raw.decode("ascii", errors="ignore").strip()
            if dims and not dims.startswith("#"):
                break

        try:
            width, height = map(int, dims.split())
        except ValueError as err:
            raise ValueError(f"Malformed PFM dimensions {dims!r} in {path}") from err
        if width < 0 or height < 0:
            raise ValueError(f"Negative PFM dimensions {dims!r} in {path}")

        scale_text = f.readline().decode("ascii", errors="ignore").strip()
        try:
            scale = float(scale_text)
        except ValueError as err:
            raise ValueError(f"Malformed PFM scale {scale_text!r} in {path}") from err
        # The sign of the scale encodes endianness; its magnitude is the
        # actual scale factor.
        endian = "<" if scale < 0 else ">"
        scale = abs(scale)

        count = width * height * channels
        data = np.fromfile(f, endian + "f", count=count)
        if data.size != count:
            raise ValueError(f"Unexpected PFM data size in {path}")

        shape = (height, width, 3) if channels == 3 else (height, width)
        data = data.reshape(shape)

        if scale != 1.0:
            data = data * scale
        # astype also converts big-endian payloads to native byte order.
        return data.astype(np.float32)


def _load_rgb(path: Path) -> np.ndarray:
    """Load an image file as a float32 RGB array scaled to [0, 1].

    The file is decoded with ``cv2.IMREAD_UNCHANGED`` and then converted to
    three-channel RGB: single-channel images are replicated, 4-channel
    images drop their alpha channel, and anything else is treated as BGR.

    Args:
        path: Location of the image file.

    Returns:
        Array of shape ``(height, width, 3)`` with dtype float32.

    Raises:
        FileNotFoundError: If OpenCV returns no image for ``path`` (it does
            so for missing as well as undecodable files).
    """
    img = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
    if img is None:
        raise FileNotFoundError(f"Could not load image: {path}")
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    elif img.shape[2] == 4:
        img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
    else:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # IMREAD_UNCHANGED keeps the stored bit depth, so a 16-bit PNG arrives as
    # uint16. Normalise by the dtype's full-scale value instead of a fixed
    # 255 so such files also land in [0, 1]; 8-bit input is unaffected.
    # Non-unsigned dtypes keep the previous divide-by-255 behavior.
    full_scale = float(np.iinfo(img.dtype).max) if img.dtype.kind == "u" else 255.0
    return img.astype(np.float32) / full_scale


def _resize_with_aspect(
    image: np.ndarray,
    depth: np.ndarray,
    mask: np.ndarray,
    input_size: int,
    ensure_multiple_of: int = 14,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Resize an image, its depth map and its validity mask to a common size.

    One scale factor is applied to both axes, chosen so that the shorter
    side of ``image`` reaches ``input_size``. Each target side is then
    rounded up independently to a multiple of ``ensure_multiple_of``.

    Args:
        image: Array whose first two axes are height and width; it alone
            determines the target size.
        depth: Depth map, resampled with nearest-neighbour interpolation.
        mask: Validity mask, resampled with nearest-neighbour interpolation.
        input_size: Target length for the shorter image side before rounding.
        ensure_multiple_of: Both output sides are multiples of this value.

    Returns:
        ``(image, depth, mask)`` at the new size; the mask is boolean.
    """
    src_h, src_w = image.shape[:2]
    # The larger of the two ratios is the one belonging to the shorter side.
    factor = max(input_size / src_h, input_size / src_w)

    def _snap_up(length: int) -> int:
        # Scale a side length, then round up to the next allowed multiple.
        steps = np.ceil((length * factor) / ensure_multiple_of)
        return int(steps * ensure_multiple_of)

    out_h = _snap_up(src_h)
    out_w = _snap_up(src_w)
    target = (out_w, out_h)  # cv2.resize takes (width, height)

    resized_image = cv2.resize(image, target, interpolation=cv2.INTER_CUBIC)
    # Nearest-neighbour keeps depth samples and mask values unblended.
    resized_depth = cv2.resize(depth, target, interpolation=cv2.INTER_NEAREST)
    mask_as_float = mask.astype(np.float32)
    resized_mask = cv2.resize(mask_as_float, target, interpolation=cv2.INTER_NEAREST)
    return resized_image, resized_depth, (resized_mask > 0.5)


@dataclass(frozen=True)
class LusnarSample:
    """Immutable record pairing one image file with its depth file."""

    # Path of the image file.
    image_path: Path
    # Path of the depth file paired with ``image_path``.
    depth_path: Path


class LusnarDepthDataset(Dataset):
    """Dataset of paired RGB images and PFM depth maps grouped by sequence.

    For every sequence of the chosen split, images are looked up in
    ``<root>/<sequence>/image0/images`` (``.png``, ``.jpg``, ``.jpeg``) and
    depth maps in ``<root>/<sequence>/image0/depth`` (``.pfm``). An image and
    a depth map form a sample when their file stems match.
    """

    def __init__(
        self,
        root: str | Path,
        split: str,
        input_size: int = 518,
        max_depth: float = MAX_DEPTH_METERS,
    ) -> None:
        """Index every sample of ``split`` found under ``root``.

        Args:
            root: Directory holding the sequence folders.
            split: ``"train"``, ``"val"``/``"valid"``/``"validation"`` or
                ``"test"``; matched case-insensitively.
            input_size: Forwarded to the resize step as the target length
                of the shorter image side.
            max_depth: Stored as ``self.max_depth``; not otherwise used by
                this class.

        Raises:
            ValueError: If ``split`` is not a recognised name.
            RuntimeError: If no image/depth pair is found.
        """
        self.root = Path(root)
        self.split = split.lower()
        self.input_size = int(input_size)
        self.max_depth = float(max_depth)

        sequences = self._split_to_sequences(self.split)
        self.samples = self._build_index(self.root, sequences)
        if len(self.samples) == 0:
            raise RuntimeError(f"No samples found for split='{self.split}' at {self.root}")

    @staticmethod
    def _split_to_sequences(split: str) -> Sequence[str]:
        """Return the sequence names belonging to an already-lowercased split.

        Raises:
            ValueError: If ``split`` is not a recognised name.
        """
        by_split = {
            "train": TRAIN_SEQS,
            "val": VAL_SEQS,
            "valid": VAL_SEQS,
            "validation": VAL_SEQS,
            "test": TEST_SEQS,
        }
        sequences = by_split.get(split)
        if sequences is None:
            raise ValueError(f"Unsupported split: {split}")
        return sequences

    @staticmethod
    def _build_index(root: Path, sequences: Sequence[str]) -> list[LusnarSample]:
        """Collect the image/depth pairs of the given sequences.

        Sequences lacking either the image or the depth directory are
        skipped, as are depth files without an image of the same stem.
        Within a sequence, samples follow the sorted order of the depth
        file paths.
        """
        patterns = ("*.png", "*.jpg", "*.jpeg")
        pairs: list[LusnarSample] = []
        for name in sequences:
            seq_root = root / name / "image0"
            image_dir = seq_root / "images"
            depth_dir = seq_root / "depth"
            if not (image_dir.exists() and depth_dir.exists()):
                continue

            # Keyed by stem; if one stem exists with several extensions, the
            # file matched by the later pattern replaces the earlier one.
            by_stem: dict[str, Path] = {
                candidate.stem: candidate
                for pattern in patterns
                for candidate in image_dir.glob(pattern)
            }

            pairs.extend(
                LusnarSample(image_path=by_stem[depth_file.stem], depth_path=depth_file)
                for depth_file in sorted(depth_dir.glob("*.pfm"))
                if depth_file.stem in by_stem
            )
        return pairs

    def __len__(self) -> int:
        """Return the number of indexed samples."""
        return len(self.samples)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor | str]:
        """Load, resize and normalise the sample at ``index``.

        Returns:
            A dict with these entries:

            * ``"image"``: float32 tensor, channels first, normalised with
              fixed per-channel mean and standard deviation.
            * ``"depth"``: float32 tensor with the resized depth map;
              invalid pixels keep the values read from the file.
            * ``"valid_mask"``: bool tensor flagging pixels whose original
              depth was finite, positive, below 65500 and different from
              ``INVALID_DEPTH_VALUE``.
            * ``"image_path"`` / ``"depth_path"``: source file paths as
              strings.
        """
        entry = self.samples[index]
        rgb = _load_rgb(entry.image_path)
        depth_map = read_pfm(str(entry.depth_path)).astype(np.float32)

        # Validity is decided on the raw depth map, before any resampling.
        is_finite = np.isfinite(depth_map)
        in_range = (depth_map > 0.0) & (depth_map < 65500.0)
        not_sentinel = depth_map != INVALID_DEPTH_VALUE
        valid = is_finite & in_range & not_sentinel

        rgb, depth_map, valid = _resize_with_aspect(
            rgb, depth_map, valid, self.input_size, ensure_multiple_of=14
        )

        # Per-channel statistics (the widely used ImageNet mean/std values).
        channel_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        channel_std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        rgb = (rgb - channel_mean) / channel_std

        # Move the channel axis to the front (HWC -> CHW).
        chw = np.transpose(rgb, (2, 0, 1)).astype(np.float32)

        return {
            "image": torch.from_numpy(chw),
            "depth": torch.from_numpy(depth_map.astype(np.float32)),
            "valid_mask": torch.from_numpy(valid.astype(np.bool_)),
            "image_path": str(entry.image_path),
            "depth_path": str(entry.depth_path),
        }