import os, random, io
from pathlib import Path
from typing import Optional, Dict, Set, Any

import torch
import webdataset as wds
import pandas as pd
import re
from glob import glob
import math
from torch.utils.data import default_collate
from torch.utils.data import IterableDataset

def tensor_only_collate(batch):
    """Collate a batch, keeping only tensor-valued fields of dict samples.

    When the first sample is a dict, the keys whose value *in that first
    sample* is a tensor are collated with ``default_collate``; every other
    key (e.g. string metadata) is dropped. Any other sample type is handed
    to ``default_collate`` unchanged.
    """
    first = batch[0]
    if not isinstance(first, dict):
        return default_collate(batch)

    # The first sample alone decides which fields survive.
    tensor_keys = [name for name, value in first.items() if torch.is_tensor(value)]
    collated = {}
    for name in tensor_keys:
        collated[name] = default_collate([sample[name] for sample in batch])
    return collated

class LenWrapper(IterableDataset):
    """Iterable dataset adapter that reports a caller-supplied length.

    The wrapped object is either iterable itself or a zero-argument callable
    that returns an iterable; ``len()`` returns the fixed value given at
    construction (truncated with ``int``).
    """

    def __init__(self, dataset, length: int):
        self.dataset = dataset
        self._length = int(length)

    def __len__(self):
        return self._length

    def __iter__(self):
        source = self.dataset
        if not hasattr(source, "__iter__"):
            if not callable(source):
                raise TypeError("Wrapped object is not iterable")
            # Factory-style dataset: call it to obtain the actual iterable.
            source = source()
        return iter(source)

def _expand_shards(pattern: str) -> list[str]:
    """Resolve a shard path pattern to the sorted list of matching files.

    For patterns ending in ``.tar`` that contain ``%``, each zero-padded
    printf field (``%0Nd``) is turned into a ``*`` wildcard before globbing.

    Raises:
        FileNotFoundError: if the (possibly rewritten) pattern matches nothing.
    """
    has_printf_field = "%" in pattern and pattern.endswith(".tar")
    glob_pattern = re.sub(r"%0\d+d", "*", pattern) if has_printf_field else pattern

    matches = glob(glob_pattern)
    matches.sort()
    if not matches:
        raise FileNotFoundError(f"No shards matched pattern: {glob_pattern}")
    return matches

def _normalize_pattern(pat: str) -> str:
    """Turn printf-style shard numbering in a ``.tar`` pattern into a glob.

    Every zero-padded integer field (``%0Nd`` for any width N) is replaced by
    ``*``, using the same rule as ``_expand_shards`` so the two helpers agree
    on what counts as a numbering field. Patterns that do not end in ``.tar``
    are returned unchanged.
    """
    if pat.endswith(".tar"):
        pat = re.sub(r"%0\d+d", "*", pat)
    return pat

class OGameDataPipe(wds.DataPipeline):
    """WebDataset pipeline yielding fixed-length clips of video/audio features.

    Each sample is expected to carry ``video.pth`` and ``audio.pth`` entries
    (decoded via ``wds.decode("torch")``). A window of ``duration_frame``
    frames is cut from both, either at a fixed ``start_index`` or at a random
    offset drawn from a generator seeded with ``seed``.

    Args:
        shard_dir: Directory containing the shard files.
        shard_pattern: Shard file name pattern; ``%0Nd`` fields are globbed.
        csv_filter: Optional CSV with a ``video_folder`` column. When given,
            only samples whose ``__key__`` equals the stem of one of those
            paths are kept.
        shuffle_size: Buffer size handed to ``wds.shuffle``.
        seed: Seed for shard ordering, the shuffle stage and clip offsets.
        audio_pretraining: If true, emitted samples contain audio only.
        start_index: First frame of the clip, or ``None`` for a random start.
        duration_frame: Number of frames per clip.
        data: Dataset name used to look up ``grid_feature_length``.
        video_embed_dim: Number of leading channels kept on the last axis of
            the video tensor.

    Raises:
        FileNotFoundError: if no shard matches the pattern.
    """

    # Value exposed as ``grid_feature_length`` for each known ``data`` name.
    _GRID_FEATURE_LENGTHS = {"ogamedata": 540, "vggsound": 336}

    def __init__(
        self,
        shard_dir: str,
        shard_pattern: str,
        csv_filter: Optional[Path] = None,
        shuffle_size: int = 10_000,
        seed: int = 42,
        audio_pretraining: bool = False,
        start_index: Optional[int] = None,
        duration_frame: int = 360,
        data: str = "ogamedata",
        video_embed_dim: int = 146,
    ):
        if csv_filter is not None:
            df = pd.read_csv(csv_filter)
            self.key_set = {Path(p).stem for p in df["video_folder"].tolist()}
        else:
            self.key_set = None

        shard_pattern = _normalize_pattern(shard_pattern)
        pattern = os.path.join(shard_dir, shard_pattern)
        # Fails early with FileNotFoundError when nothing matches.
        self.shards = _expand_shards(pattern)

        self.audio_pretraining = audio_pretraining
        self.start_index = start_index
        self.duration_frame = duration_frame
        self._rng = random.Random(seed)
        self.video_embed_dim = video_embed_dim

        super().__init__(
            wds.SimpleShardList(self.shards, seed=seed),
            wds.split_by_node,
            wds.split_by_worker,
            # NOTE(review): this shuffle stage sits before tarfile_to_samples,
            # so it receives shard entries rather than samples - confirm that
            # this is intended.
            wds.shuffle(shuffle_size, rng=random.Random(seed)),
            wds.tarfile_to_samples(handler=wds.ignore_and_continue),
            # A bound method (instead of a lambda) keeps the pipeline
            # picklable for worker processes started with "spawn".
            self._filter_by_key,
            wds.decode("torch"),
            wds.map(self._slice_and_format),
        )

        # None for an unrecognised dataset name, so the attribute always
        # exists instead of surfacing later as an AttributeError.
        self.grid_feature_length = self._GRID_FEATURE_LENGTHS.get(data)

    # ---------------- utils ---------------- #
    def _filter_by_key(self, src):
        """Yield only the samples whose ``__key__`` is allowed by ``key_set``.

        With no ``key_set`` (no CSV filter) every sample passes through.
        """
        for sample in src:
            if self.key_set is None or sample["__key__"] in self.key_set:
                yield sample

    def _slice_and_pack(self, sample: Dict[str, torch.Tensor]) -> Dict:
        """Cut a clip like ``_slice_and_format`` but also return the sample key.

        Not wired into the pipeline built in ``__init__``; kept as an
        alternative formatter that adds a ``"filename"`` entry.

        Raises:
            ValueError: if the sample is shorter than ``duration_frame`` or a
                fixed ``start_index`` would run past its end.
        """
        key = sample["__key__"]
        v = sample["video.pth"]
        a = sample["audio.pth"]

        total = v.shape[0]
        if self.duration_frame > total:
            raise ValueError(f"{key} is too short ({total} < {self.duration_frame})")
        if self.start_index is None:
            # Seeded generator, consistent with _slice_and_format.
            start = self._rng.randint(0, total - self.duration_frame)
        else:
            start = self.start_index
            if start + self.duration_frame > total:
                raise ValueError(f"{key}: start+duration out of range")

        end = start + self.duration_frame
        v = v[..., :self.video_embed_dim]
        v_clip = v[start:end]  # presumably (F, H, W, C) - TODO confirm
        a_clip = a[start:end]  # presumably (F, feat_dim) - TODO confirm

        out = {"filename": key}
        if not self.audio_pretraining:
            out["video_feature"] = v_clip
        out["audio_feature"] = a_clip
        return out

    def _slice_and_format(self, sample: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        """Cut the configured frame window out of a decoded sample.

        Returns a dict with ``"audio_feature"`` and, unless
        ``audio_pretraining`` is set, ``"video_feature"``. The video tensor is
        truncated to its first ``video_embed_dim`` entries on the last axis.

        Raises:
            ValueError: if the sample is shorter than ``duration_frame`` or a
                fixed ``start_index`` would run past its end.
        """
        v: torch.Tensor = sample["video.pth"]
        a: torch.Tensor = sample["audio.pth"]
        total_frames = v.shape[0]

        if self.duration_frame > total_frames:
            raise ValueError(f"Clip shorter than duration: {sample['__key__']}")

        if self.start_index is None:
            start = self._rng.randint(0, total_frames - self.duration_frame)
        else:
            if self.start_index + self.duration_frame > total_frames:
                raise ValueError("start_index + duration exceeds clip length")
            start = self.start_index

        end = start + self.duration_frame
        v = v[..., :self.video_embed_dim]
        v_clip = v[start:end]
        a_clip = a[start:end]

        if self.audio_pretraining:
            return {"audio_feature": a_clip}
        return {"video_feature": v_clip, "audio_feature": a_clip}

    def _peek_grid_len(self) -> int:
        """Read one sample from this pipeline and return its grid size.

        The grid size is the product of the third- and second-to-last
        dimensions of ``video_feature``.

        Raises:
            RuntimeError: if the sample has no ``video_feature`` (as in audio
                pretraining mode) or the tensor has fewer than 4 dimensions.
        """
        first = next(iter(self))
        vid = first.get("video_feature")
        if vid is None or vid.ndim < 4:
            raise RuntimeError("Unexpected video tensor shape.")
        return vid.shape[-3] * vid.shape[-2]

    def _probe_grid_len(
        self,
        shard_list: str,
        key_set: Optional[Set[str]],
        audio_pretraining: bool,
    ) -> int:
        """Scan ``shard_list`` for the first accepted sample and return its grid size.

        Builds a separate ``wds.WebDataset`` over ``shard_list`` and inspects
        the audio tensor when ``audio_pretraining`` is true, otherwise the
        video tensor. The result is the product of that tensor's third- and
        second-to-last dimensions.

        Raises:
            RuntimeError: if the inspected tensor has fewer than 4 dimensions,
                or no sample key is contained in ``key_set``.
        """
        ds = (
            wds.WebDataset(shard_list, handler=wds.ignore_and_continue, nodesplitter=None)
            .decode("torch")
            .to_tuple("__key__", "video.pth", "audio.pth")
        )
        for key, v, a in ds:
            if (key_set is None) or (key in key_set):
                tensor = a if audio_pretraining else v
                if tensor.ndim < 4:
                    raise RuntimeError(f"Unexpected tensor shape: {tensor.shape} in {key}")
                return tensor.shape[-3] * tensor.shape[-2]
        raise RuntimeError(
            "No sample matched key_set. "
            "Check csv_filter or make sure __key__ names align with Path.stem."
        )

    def __len__(self) -> int:
        """Return the number of distinct keys listed in the CSV filter.

        This counts CSV entries; it is not checked against what the shards
        actually contain.

        Raises:
            RuntimeError: if the pipeline was built without ``csv_filter``.
        """
        if self.key_set is None:
            raise RuntimeError("Length cannot be determined without csv_filter")
        return len(self.key_set)

def build_dataloader(
    shard_pattern: str,
    shard_dir: str,
    csv_path: Optional[Path],
    batch_size: int,
    num_workers: int = 4,
    audio_pretraining: bool = False,
    seed: int = 42,
    pin_memory: bool = True,
    persistent_workers: bool = True,
    prefetch_factor: int = 2,
    shuffle_size: int = 10_000,
    start_time: Optional[float] = 0,
    duration: float = 10.0,
    video_fps: int = 30,
    drop_last: bool = True,
    n_epochs: int = 400,
    data: str = "ogamedata",
    video_embed_dim: int = 146,
):
    """Build a ``wds.WebLoader`` over an ``OGameDataPipe``.

    Args:
        shard_pattern: Shard file name pattern inside ``shard_dir``.
        shard_dir: Directory containing the shards.
        csv_path: CSV used to filter samples; required, because the sample
            count is taken from it.
        batch_size: Samples per batch.
        num_workers: Number of loader worker processes.
        audio_pretraining: Forwarded to ``OGameDataPipe``.
        seed: Forwarded to ``OGameDataPipe``.
        pin_memory: Forwarded to the loader.
        persistent_workers: Forwarded to the loader when ``num_workers > 0``.
        prefetch_factor: Forwarded to the loader when ``num_workers > 0``.
        shuffle_size: Forwarded to ``OGameDataPipe``.
        start_time: Clip start in seconds; ``None`` or a negative value
            selects a random start per sample.
        duration: Clip length in seconds.
        video_fps: Frames per second used to convert times to frame counts.
        drop_last: Only affects the reported number of batches (floor vs.
            ceil of ``n_samples / batch_size``); it is not forwarded to the
            loader.
        n_epochs: Number of passes the pipeline repeats.
        data: Forwarded to ``OGameDataPipe``.
        video_embed_dim: Forwarded to ``OGameDataPipe``.

    Returns:
        ``(loader, grid_feature_length, n_samples)`` where ``len(loader)`` is
        the per-epoch batch count and ``n_samples`` is the number of keys in
        the CSV filter.

    Raises:
        RuntimeError: if ``csv_path`` is ``None``.
    """
    if start_time is None or start_time < 0:
        start_index = None
        print("start_index", start_index)
        print("start_time is None or negative, using random start index")
    else:
        start_index = int(video_fps * start_time)

    datapipes = OGameDataPipe(
        shard_dir,
        shard_pattern,
        csv_filter=csv_path,
        audio_pretraining=audio_pretraining,
        seed=seed,
        shuffle_size=shuffle_size,
        start_index=start_index,
        duration_frame=int(video_fps * duration),
        data=data,
        video_embed_dim=video_embed_dim,
    )
    if datapipes.key_set is None:
        raise RuntimeError("n_samples cannot be determined without csv_filter")
    n_samples = len(datapipes.key_set)

    if drop_last:
        n_batches = n_samples // batch_size
    else:
        n_batches = math.ceil(n_samples / batch_size)

    batched_dp = datapipes.repeat(nepochs=n_epochs)

    loader_kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
        "collate_fn": tensor_only_collate,
    }
    if num_workers > 0:
        # DataLoader rejects these options in single-process mode, so they
        # are only forwarded when worker processes exist.
        loader_kwargs["persistent_workers"] = persistent_workers
        loader_kwargs["prefetch_factor"] = prefetch_factor
    loader = wds.WebLoader(batched_dp, **loader_kwargs)

    # len() looks up __len__ on the type, so assigning it on the instance has
    # no effect; with_length() is the pipeline's supported way to set it.
    loader = loader.with_length(n_batches)

    return loader, datapipes.grid_feature_length, n_samples