import io
import re
import json
import random
from typing import List, Dict, Any, Callable, Sequence, Union

import av
import numpy as np
from PIL import Image

import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF  # noqa: F401  (kept for parity)

import webdataset as wds
from torch.utils.data import IterableDataset

# TorchCodec support is currently switched off: the import below is commented
# out, so ``VideoDecoder`` is always None. The TorchCodec helpers in this module
# raise RuntimeError in that state and ``decode_video`` uses PyAV instead.
try:
    # from torchcodec.decoders import VideoDecoder
    VideoDecoder = None  # type: ignore
except Exception:
    # With the import disabled the try body is a plain assignment, so this
    # branch can only run once the import above is re-enabled and fails.
    print("No TorchCodec :(")
    VideoDecoder = None  # type: ignore


# -----------------------------------------------------------------------------
# Helper utilities
# -----------------------------------------------------------------------------


def _extract_sequences(frames: torch.Tensor, sequence_length: int) -> List[torch.Tensor]:
    """Return sliding‑window sequences of length *sequence_length* from *frames* (T,C,H,W)."""
    if sequence_length is None:
        return [frames]

    T = frames.shape[0]
    if T < sequence_length:
        return []

    # unfold returns a *view*; contiguous() makes a copy so we can safely permute
    sequences_tensor = frames.unfold(0, sequence_length, sequence_length)  # (N,C,H,W,sequence_length)
    return [
        sequences_tensor[i].contiguous().permute(3, 0, 1, 2)  # (sequence_length,C,H,W)
        for i in range(sequences_tensor.size(0))
    ]


def dict_collate_fn(batch: List[Dict[str, Any]], device: torch.device) -> Dict[str, Any]:
    """Merge a list of sample dicts into one dict of batched fields.

    The keys of the first sample define the output keys. A field whose first
    value is a tensor is stacked along a new leading dimension and moved to
    *device*; any other field is returned as a plain list of per-sample values.
    """

    def _merge(field: str) -> Any:
        column = [item[field] for item in batch]
        if not isinstance(column[0], torch.Tensor):
            return column
        return torch.stack(column).to(device, non_blocking=True)

    return {field: _merge(field) for field in batch[0]}


# -----------------------------------------------------------------------------
# Selective decoding helpers — each keeps its own resources *scoped*.
# -----------------------------------------------------------------------------


def _decode_and_process_frames_torchcodec(
    video_bytes: bytes,
    indices: Sequence[int],
    transform: Callable[[torch.Tensor], torch.Tensor] | None,
) -> torch.Tensor:
    """Decode the frames at *indices* with TorchCodec.

    Args:
        video_bytes: Encoded video bitstream.
        indices: Frame indices to decode, in the order they should be returned.
        transform: Optional callable applied to the whole (T,C,H,W) batch.

    Returns:
        Tensor of shape (T,C,H,W) (after *transform*, if given).

    Raises:
        RuntimeError: If TorchCodec is not available.
    """
    if VideoDecoder is None:
        raise RuntimeError("TorchCodec requested but not available")

    # TorchCodec's VideoDecoder does not implement the context-manager
    # protocol, so it is created directly and released by dropping the
    # reference. The layout is requested explicitly so no permute is needed.
    decoder = VideoDecoder(video_bytes, dimension_order="NCHW")  # type: ignore[call-arg]
    frame_batch = decoder.get_frames_at(list(indices)).data  # (T,C,H,W)
    del decoder

    # FrameBatch.data is already a torch.Tensor; as_tensor is a no-op then.
    frame_batch = torch.as_tensor(frame_batch)
    if transform is not None:
        frame_batch = transform(frame_batch)
    return frame_batch


def _decode_and_process_frames_pyav(
    video_bytes: bytes,
    indices: Sequence[int],
    transform: Callable[[torch.Tensor], torch.Tensor] | None,
) -> torch.Tensor:
    """Decode the frames at *indices* with PyAV, closing the container immediately.

    Frames are decoded sequentially from the start of the first video stream
    and decoding stops as soon as the highest requested index has been read.
    Requested indices that lie beyond the end of the stream are skipped, so the
    result may hold fewer frames than ``len(indices)`` (the probed frame count
    is only an estimate for some containers).

    Args:
        video_bytes: Encoded video bitstream.
        indices: Frame indices to decode, in the order they should be returned.
        transform: Optional callable applied to the whole (T,C,H,W) batch.

    Returns:
        Tensor of shape (T,C,H,W) (after *transform*, if given).

    Raises:
        ValueError: If *indices* is empty.
        RuntimeError: If none of the requested frames could be decoded.
    """
    if len(indices) == 0:
        raise ValueError("No frame indices requested")

    wanted = set(indices)
    last_needed = max(wanted)
    decoded: Dict[int, np.ndarray] = {}

    with av.open(io.BytesIO(video_bytes)) as container:
        # container.decode() handles demuxing and flushing of buffered frames.
        for frame_idx, frame in enumerate(container.decode(video=0)):
            if frame_idx in wanted:
                decoded[frame_idx] = frame.to_ndarray(format="rgb24")
            if frame_idx >= last_needed:
                break

    # Keep the caller's ordering, dropping indices past the end of the stream.
    frames = [decoded[i] for i in indices if i in decoded]
    if not frames:
        raise RuntimeError("No requested frames could be decoded")

    frames_tensor = torch.from_numpy(np.stack(frames)).permute(0, 3, 1, 2)  # (T,C,H,W)
    if transform is not None:
        frames_tensor = transform(frames_tensor)
    return frames_tensor


def _decode_images(
    sample: Dict[str, Any],
    indices: Sequence[int],
    transform: Callable[[torch.Tensor], torch.Tensor] | None,
) -> torch.Tensor:
    """Decode RGB image frames stored as separate ``jpg<N>`` keys inside *sample*.

    Keys matching ``jpg`` followed by optional digits (case-insensitive) are
    ordered by their numeric suffix, so ``jpg2`` comes before ``jpg10``; a bare
    ``jpg`` key sorts first. *indices* then selects positions in that ordering.
    Frames that fail to decode are reported and skipped (best effort), so the
    result may hold fewer frames than ``len(indices)``.

    Args:
        sample: Mapping from key to encoded image bytes (plus unrelated keys).
        indices: Positions in the ordered frame-key list to decode.
        transform: Optional callable applied to the whole (T,C,H,W) batch.

    Returns:
        Tensor of shape (T,C,H,W) (after *transform*, if given).

    Raises:
        IndexError: If an index is outside the available frame keys.
        RuntimeError: If no frame could be decoded.
    """

    def _frame_order(key: str) -> tuple[int, str]:
        # Keys have the form "jpg" + digits; sort numerically, not as strings.
        digits = key[3:]
        return (int(digits) if digits else -1, key)

    image_keys = sorted(
        (k for k in sample if re.search(r"^jpg\d*$", k, re.IGNORECASE)),
        key=_frame_order,
    )
    image_keys = [image_keys[i] for i in indices]

    frames: List[torch.Tensor] = []
    for key in image_keys:
        try:
            with Image.open(io.BytesIO(sample[key])) as img:
                frame = torch.from_numpy(np.array(img.convert("RGB"))).permute(2, 0, 1)
            frames.append(frame)
        except Exception as exc:
            print("Error decoding image", key, exc)

    if not frames:
        raise RuntimeError("No image frames could be decoded")

    frames_tensor = torch.stack(frames)  # (T,C,H,W)
    if transform is not None:
        frames_tensor = transform(frames_tensor)
    return frames_tensor


# -----------------------------------------------------------------------------
# Video decoding entry point
# -----------------------------------------------------------------------------


def _probe_pyav(video_bytes: bytes) -> tuple[int, float | None]:
    """Read the frame count and average frame rate of the first video stream via PyAV.

    When the stream does not report a positive frame count, 10_000 is used as
    a placeholder. The frame rate is ``None`` when the stream has no average
    rate.

    Raises:
        RuntimeError: If the container holds no video stream.
    """
    with av.open(io.BytesIO(video_bytes)) as container:
        video_streams = container.streams.video
        if not video_streams:
            raise RuntimeError("No video stream found")
        primary = video_streams[0]

        frame_count = primary.frames
        if not frame_count > 0:
            frame_count = 10_000  # fallback

        rate = primary.average_rate
        fps = float(rate) if rate else None
        return frame_count, fps


def _probe_torchcodec(video_bytes: bytes) -> tuple[int, float | None]:
    """Read the frame count and average frame rate of *video_bytes* via TorchCodec.

    Returns:
        ``(n_frames, native_fps)``; *native_fps* is ``None`` when the metadata
        has no average frame rate.

    Raises:
        RuntimeError: If TorchCodec is not available, or if the metadata does
            not contain a frame count (callers can then fall back to PyAV).
    """
    if VideoDecoder is None:
        raise RuntimeError("TorchCodec not available")

    # VideoDecoder is not a context manager; it is released when the local
    # reference goes out of scope.
    decoder = VideoDecoder(video_bytes)  # type: ignore[call-arg]
    meta = decoder.metadata
    n_frames = meta.num_frames
    if n_frames is None:
        raise RuntimeError("TorchCodec could not determine the number of frames")
    native_fps = float(meta.average_fps) if meta.average_fps else None
    return n_frames, native_fps


def decode_video(
    video_source: Union[bytes, Dict[str, Any]],
    sequence_length: int | None,
    n_sequences_per_video: int | None,
    transform: Callable[[torch.Tensor], torch.Tensor] | None,
    target_fps: int | None,
    max_sequence_length: int | None = None,
) -> List[torch.Tensor]:
    """Decode *video_source* into a list of sequences respecting *sequence_length* etc."""

    # ------------------------------------------------------------------
    # 1. Identify source type & get basic metadata
    # ------------------------------------------------------------------
    if isinstance(video_source, dict):  # Image sequence stored as separate keys
        n_frames = len([k for k in video_source if re.search(r"^jpg\d*$", k, re.IGNORECASE)])
        native_fps = 25  # sensible default for image folders
        decode_fn = _decode_images
    elif isinstance(video_source, bytes):  # Actual video bitstream
        if VideoDecoder is None:
            n_frames, native_fps = _probe_pyav(video_source)
            decode_fn = _decode_and_process_frames_pyav
        else:
            try:
                n_frames, native_fps = _probe_torchcodec(video_source)
                decode_fn = _decode_and_process_frames_torchcodec
            except Exception:
                # Fallback to PyAV if TorchCodec fails at run‑time
                n_frames, native_fps = _probe_pyav(video_source)
                decode_fn = _decode_and_process_frames_pyav
    else:
        raise TypeError("Unsupported video source type")

    # ------------------------------------------------------------------
    # 2. Build index list according to target FPS
    # ------------------------------------------------------------------
    step = 1
    if target_fps and native_fps and native_fps > target_fps:
        step = int(round(native_fps / target_fps))
    sampled_indices = list(range(0, n_frames, step))

    if sequence_length is None:
        video = decode_fn(video_source, sampled_indices, transform)
        if max_sequence_length is not None:
            video = video[:max_sequence_length]
        return [video], native_fps # TODO

    # Need at least one full sequence
    if len(sampled_indices) < sequence_length:
        return [], None

    if n_sequences_per_video == 1:
        start_idx = 0  # could randomise if desired
        seq_indices = sampled_indices[start_idx : start_idx + sequence_length]
        sequences = [decode_fn(video_source, seq_indices, transform)]
    else:
        # Decode once, then slice
        full_video = decode_fn(video_source, sampled_indices, transform)
        sequences = _extract_sequences(full_video, sequence_length)
        if n_sequences_per_video is not None and len(sequences) >= n_sequences_per_video:
            sequences = random.sample(sequences, n_sequences_per_video)

    # Final clean‑up / trimming
    sequences = [seq for seq in sequences if seq.shape[0] == sequence_length]
    if max_sequence_length is not None:
        sequences = [seq[:max_sequence_length] for seq in sequences]
    return sequences, native_fps


# -----------------------------------------------------------------------------
# Dataset & Loader classes
# -----------------------------------------------------------------------------
class MinimalVideoDataset(IterableDataset):
    """Iterable WebDataset‑based dataset that yields pre‑processed frame sequences."""

    def __init__(
        self,
        tar_paths: List[str],
        shardshuffle: bool = True,
        repeat: int | None = 1,
        n_sequences_per_video: int | None = None,
        sequence_length: int | None = None,
        max_sequence_length: int | None = None,
        transform: Callable[[torch.Tensor], torch.Tensor] | None = None,
        fps: int | None = None,
        multi_node: bool = True,
    ) -> None:
        super().__init__()
        self.tar_paths = tar_paths
        self.shardshuffle = shardshuffle
        self.repeat = repeat
        self.n_sequences_per_video = n_sequences_per_video
        self.sequence_length = sequence_length
        self.max_sequence_length = max_sequence_length
        self.transform = transform
        self.fps = fps
        self.multi_node = multi_node

    # --------------------------------------------------
    # Utility helpers for brace‑expansion of shard patterns
    # --------------------------------------------------
    @staticmethod
    def _is_shard_pattern(path: str) -> bool:
        return bool(re.search(r"\{\d+\.\.\d+\}", path))

    def _expand_shard_patterns(self, patterns: List[str]) -> List[str]:
        brace = re.compile(r"\{(\d+)\.\.(\d+)\}")

        def _expand_one(p: str) -> List[str]:
            m = brace.search(p)
            if not m:
                return [p]
            start, end = int(m.group(1)), int(m.group(2))
            width = len(m.group(1))
            return [
                sub
                for i in range(start, end + 1)
                for sub in _expand_one(p[: m.start()] + str(i).zfill(width) + p[m.end() :])
            ]

        out: List[str] = []
        for pat in patterns:
            out.extend(_expand_one(pat) if self._is_shard_pattern(pat) else [pat])
        return out

    # --------------------------------------------------
    # Core sample decoding
    # --------------------------------------------------
    def _decode_sample(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
        video_name = sample.get("__key__")
        url = sample.get("__url__")

        # -------- class label heuristics --------
        cls_label = sample.get("cls")
        if cls_label is None:
            meta_blob = next((sample[k] for k in sample if k.endswith("json")), None)
            if meta_blob is not None:
                try:
                    meta = json.loads(meta_blob)
                    for key in ("class", "labels", "label", "actions", "category"):
                        if key in meta:
                            cls_label = meta[key]
                            break
                except json.JSONDecodeError:
                    pass

        # -------- caption heuristics --------
        caption = sample.get("txt", None)

        # -------- choose video bytes or image‑dict --------
        video_bytes = next((sample[ext] for ext in ("video.mpg", "mp4", "avi", "webm") if ext in sample), None)
        if video_bytes is None:  # image sequence instead of encoded video
            video_source: Union[bytes, Dict[str, Any]] = sample  # type: ignore[assignment]
        else:
            video_source = video_bytes  # type: ignore[assignment]

        try:
            sequences, fps = decode_video(
                video_source,
                self.sequence_length,
                self.n_sequences_per_video,
                self.transform,
                self.fps,
                self.max_sequence_length,
            )
        except Exception as exc:
            print("Error decoding video", video_name, exc)
            return []

        return [
            {
                "data": seq,
                "video_name": video_name,
                "pts": None,
                "cls": cls_label,
                "filename": url,
                "fps": fps,
                "caption": caption,
            }
            for seq in sequences
        ]

    # --------------------------------------------------
    # __iter__
    # --------------------------------------------------
    def __iter__(self):  # type: ignore[override]
        ds = wds.WebDataset(
            self._expand_shard_patterns(self.tar_paths),
            nodesplitter=wds.shardlists.split_by_node if self.multi_node else lambda shards: shards,
            shardshuffle=self.shardshuffle,
            handler=wds.warn_and_continue,
            empty_check=False,
        )

        if self.repeat:
            ds = ds.repeat(self.repeat)
        if self.shardshuffle:
            ds = ds.shuffle(1_000)

        for sample in ds:
            seqs = self._decode_sample(sample)
            if seqs:
                yield from seqs


class MinimalVideoLoader:
    """Thin wrapper around WebLoader + our Dataset."""

    def __init__(
        self,
        tar_paths: List[str],
        shardshuffle: bool = True,
        repeat: int | None = 1,
        n_sequences_per_video: int | None = None,
        sequence_length: int | None = None,
        max_sequence_length: int | None = None,
        transform: Callable[[torch.Tensor], torch.Tensor] | None = None,
        fps: int | None = None,
        batch_size: int = 1,
        num_workers: int = 0,
        device: str | torch.device = "cpu",
        prefetch_factor: int = 2,
        multi_node: bool = True,
    ) -> None:
        self.dataset = MinimalVideoDataset(
            tar_paths,
            shardshuffle,
            repeat,
            n_sequences_per_video,
            sequence_length,
            max_sequence_length,
            transform,
            fps,
            multi_node,
        )
        self.device = torch.device(device)
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor

    def __iter__(self):  # type: ignore[override]
        loader = wds.WebLoader(
            self.dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=lambda batch: dict_collate_fn(batch, self.device),
            pin_memory=True,
            prefetch_factor=self.prefetch_factor if self.num_workers > 0 else None,
            drop_last=True,
            shuffle=False,
        )
        for batch in loader:
            batch["x"] = batch["data"].permute(0, 2, 1, 3, 4)  # (B,T,C,H,W) -> (B,C,T,H,W)
            yield batch


# -----------------------------------------------------------------------------
# Convenience helper for users
# -----------------------------------------------------------------------------


def get_minimal_video_transform(size: int) -> Callable[[torch.Tensor], torch.Tensor]:
    """Build a transform that rescales, resizes and normalises a frame batch.

    The returned callable converts its input to float, divides by 255,
    bilinearly resizes the last two dimensions to ``size`` x ``size`` and maps
    the result from [0, 1] to [-1, 1].
    """
    target_hw = (size, size)

    def _transform(frames: torch.Tensor) -> torch.Tensor:
        unit_range = frames.float().div(255)
        resized = F.interpolate(
            unit_range,
            size=target_hw,
            mode="bilinear",
            align_corners=False,
        )
        centred = resized - 0.5
        return centred * 2

    return _transform
