# ===== file: data/video_stream_dataset.py =====
import os
import csv
import logging
from pathlib import Path
from typing import Iterable, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from torio.io import StreamingMediaDecoder

log = logging.getLogger(__name__)


def error_avoidance_collate(batch):
    """
    Collate a batch after discarding samples that are ``None``.

    Datasets in this module return ``None`` for samples they could not load;
    dropping them here keeps a single bad file from crashing the loader.

    Returns:
        The ``default_collate`` result for the remaining samples, or ``None``
        when no sample survives (every entry was ``None`` or the batch was empty).
    """
    survivors = list(filter(lambda sample: sample is not None, batch))
    if survivors:
        from torch.utils.data.dataloader import default_collate
        return default_collate(survivors)
    return None


def _read_paths_from_csv(
    csv_path: Union[str, Path],
    column: str = "video_folder",
) -> List[str]:
    """
    Read a list of video paths from a CSV file. Uses pandas if available,
    otherwise falls back to Python's csv module.
    """
    csv_path = str(csv_path)
    paths: List[str] = []
    try:
        import pandas as pd  # type: ignore
        df = pd.read_csv(csv_path)
        if column not in df.columns:
            raise ValueError(f"Column '{column}' not found in CSV: {csv_path}")
        # dropna & strip
        series = df[column].dropna().astype(str).map(lambda s: s.strip())
        paths = [p for p in series.tolist() if len(p) > 0]
    except Exception as e_pd:
        log.warning(f"pandas failed to read CSV ({e_pd}); falling back to csv module.")
        with open(csv_path, "r", newline="") as f:
            reader = csv.DictReader(f)
            if column not in reader.fieldnames:
                raise ValueError(f"Column '{column}' not found in CSV header: {reader.fieldnames}")
            for row in reader:
                val = (row.get(column) or "").strip()
                if val:
                    paths.append(val)
    return paths

class OnlineVideoDataset_2(Dataset):
    """
    Stream a fixed-length clip from each video with torio.

    Each item is a dict::

        {'video_frames': [T, 3, H, W] uint8 tensor (RGB, values 0-255),
         'filename': str,   # file stem
         'path': str}

    where ``T`` is ``fps * duration_sec`` rounded down (with a small tolerance
    for floating-point error).

    - Frames are decoded at ``fps`` starting at ``start_time`` seconds.
    - If ``resize_hw=(H, W)`` is given, frames are scaled to that size by the
      decoder (``add_basic_video_stream(height=..., width=...)``); no extra
      post-decode interpolation is done here.
    - If a video yields fewer than ``T`` frames, ``__getitem__`` returns
      ``None`` (pair with ``error_avoidance_collate``), or raises
      ``RuntimeError`` when ``drop_short=True``.
    """

    def __init__(
        self,
        video_paths: Iterable[Union[str, Path]],
        *,
        start_time: float,
        duration_sec: float,
        fps: float,
        resize_hw: Optional[Tuple[int, int]] = None,  # (H, W) decoder-side resize; None = no resize
        drop_short: bool = False,                     # if True, raise on short video; else return None
        error_log_path: Optional[Union[str, Path]] = None,
    ):
        """
        Args:
            video_paths: Video files to stream; stored as strings.
            start_time: Seek offset in seconds (only applied when > 0).
            duration_sec: Clip length in seconds.
            fps: Output frame rate requested from the decoder.
            resize_hw: Optional (H, W) output size requested from the decoder.
            drop_short: If True, a short video raises RuntimeError; otherwise
                it is logged and ``None`` is returned.
            error_log_path: Optional file to which error messages are appended.

        Raises:
            ValueError: If ``fps * duration_sec`` yields fewer than one frame.
        """
        self.paths = [str(p) for p in video_paths]
        self.start = float(start_time)
        self.duration = float(duration_sec)
        self.fps = float(fps)
        self.resize_hw = resize_hw
        self.drop_short = drop_short
        self.error_log_path = str(error_log_path) if error_log_path else None

        # Expected number of frames. Round to 6 decimals before truncating so
        # float error doesn't lose a frame (0.29 * 100 == 28.999999999999996).
        self._T = int(round(self.fps * self.duration, 6))
        if self._T <= 0:
            raise ValueError(
                f"fps * duration_sec must yield at least one frame "
                f"(fps={self.fps}, duration_sec={self.duration})"
            )

    @classmethod
    def from_csv(
        cls,
        csv_path: Union[str, Path],
        *,
        column: str = "video_folder",
        start_time: float,
        duration_sec: float,
        fps: float,
        resize_hw: Optional[Tuple[int, int]] = None,
        drop_short: bool = False,
        error_log_path: Optional[Union[str, Path]] = None,
        exist_filter: bool = True,  # filter out non-existing files
    ):
        """Build a dataset from the paths listed in ``column`` of a CSV file.

        When ``exist_filter`` is True, paths for which ``os.path.exists`` is
        False are dropped (with a warning) before construction.
        """
        paths = _read_paths_from_csv(csv_path, column=column)
        if exist_filter:
            before = len(paths)
            paths = [p for p in paths if os.path.exists(p)]
            removed = before - len(paths)
            if removed > 0:
                log.warning(f"Filtered {removed} non-existing files; {len(paths)} remain.")
        return cls(
            paths,
            start_time=start_time,
            duration_sec=duration_sec,
            fps=fps,
            resize_hw=resize_hw,
            drop_short=drop_short,
            error_log_path=error_log_path,
        )

    def __len__(self) -> int:
        return len(self.paths)

    def _log_error(self, msg: str):
        """Log ``msg`` at error level and, if configured, append it to ``error_log_path``.

        Failure to write the log file is reported as a warning, never raised.
        """
        log.error(msg)
        if self.error_log_path:
            try:
                with open(self.error_log_path, "a") as f:
                    f.write(msg.rstrip() + "\n")
            except Exception as e_io:
                log.warning(f"Failed to write error log: {e_io}")

    def __getitem__(self, idx: int):
        """Decode clip ``idx``; returns the sample dict, or ``None`` for a short video.

        Raises:
            RuntimeError: If the video is short and ``drop_short`` is True.
            TypeError: If the decoder returns frames that are not uint8.
        """
        path = self.paths[idx]
        reader = StreamingMediaDecoder(path)
        stream_kwargs = dict(frames_per_chunk=self._T, frame_rate=self.fps, format="rgb24")
        if self.resize_hw is not None:
            Ht, Wt = self.resize_hw
            stream_kwargs.update(height=int(Ht), width=int(Wt))

        reader.add_basic_video_stream(**stream_kwargs)
        if self.start > 0:
            reader.seek(self.start)
        reader.fill_buffer()

        chunks = reader.pop_chunks()
        # pop_chunks may return a sequence (one entry per output stream);
        # only one stream was added above.
        vid = chunks[0] if isinstance(chunks, (list, tuple)) else chunks

        if vid is None or vid.shape[0] < self._T:
            n_got = "None" if vid is None else vid.shape[0]
            msg = f"Short video ({n_got} < {self._T}): {path}"
            if self.drop_short:
                raise RuntimeError(msg)
            self._log_error(msg)
            return None

        vid = vid[: self._T]  # [T, 3, H, W]
        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if isinstance(vid, torch.Tensor):
            if vid.dtype != torch.uint8:
                raise TypeError(f"Expected uint8 frames, got {vid.dtype}: {path}")
            frames_u8 = vid.contiguous()  # [T,3,H,W] uint8
        else:
            import numpy as np
            if not (isinstance(vid, np.ndarray) and vid.dtype == np.uint8):
                raise TypeError(f"Expected uint8 tensor/ndarray frames, got {type(vid)}: {path}")
            frames_u8 = torch.from_numpy(vid).contiguous()

        return {"video_frames": frames_u8, "filename": Path(path).stem, "path": path}
    
    
    
    
