import logging
from pathlib import Path
from typing import Union, Optional, Dict
import os
import numpy as np

import pandas as pd
import torch
from torch.utils.data.dataset import Dataset
# import ffmpeg
import torchaudio
from torch.utils.data import default_collate, DataLoader, get_worker_info
from tqdm import tqdm
import h5py
import hdf5plugin
from tqdm import tqdm
from torio.io import StreamingMediaDecoder

# Lower-case file extensions of common video containers.
# NOTE(review): not referenced in the visible part of this file; presumably
# used to recognise video files elsewhere - confirm before removing.
VIDEO_EXTS = {".mp4", ".mov", ".mkv", ".webm"} 

# Calling getLogger() without a name returns the root logger.
log = logging.getLogger()

def video2audio_path(path: str) -> str:
    """Map an ``.mp4`` video path to the path of its separated SFX audio file.

    The first *directory* component named ``video_<n>`` is replaced by
    ``audio_<n>_separated/sfx`` and the file name ``<stem>.mp4`` becomes
    ``<stem>_sfx.flac``.  If no such directory exists only the file name is
    rewritten.

    Args:
        path: Path to a media file.

    Returns:
        The derived audio path for ``.mp4`` inputs; any other input
        (including paths that already end in ``.flac``) is returned unchanged.
    """
    if not path.endswith(".mp4"):
        return path

    parts = list(Path(path).parts)

    # Only directory components are candidates: a *file* called
    # ``video_xxx.mp4`` must not be mistaken for the ``video_<n>`` folder.
    for i, part in enumerate(parts[:-1]):
        if part.startswith("video_"):
            suffix = part.split("_", 1)[1]  # <n>
            # video_<n>  ->  audio_<n>_separated/sfx
            parts[i:i + 1] = [f"audio_{suffix}_separated", "sfx"]
            break

    parts[-1] = Path(parts[-1]).stem + "_sfx.flac"

    # Re-assemble through Path rather than os.sep.join: the root component of
    # an absolute path is already a separator, so joining with os.sep would
    # yield a doubled leading separator ("//data/...").
    return str(Path(*parts))
def compute_z_stats(
    dataset,
    stage1_model,
    batch_size: int,
    device: torch.device,
    audio_key: str = "audio"
):
    """Compute per-dimension mean and standard deviation of encoded latents.

    Every batch is encoded with ``stage1_model.encode``; the last two axes of
    the result are swapped and all leading axes are flattened, so statistics
    are taken over every position for each index of the encoder output's
    second-to-last axis.

    Side effects: ``stage1_model`` is switched to eval mode and moved to
    ``device``; neither is restored afterwards.

    Args:
        dataset: Dataset whose collated batches are mappings containing
            ``audio_key``.
        stage1_model: Model exposing ``eval()``, ``to(device)`` and
            ``encode(audio)``.
        batch_size: Batch size used while iterating ``dataset``.
        device: Device on which encoding and accumulation happen.
        audio_key: Key of the model input inside each batch.

    Returns:
        z_mean: shape (dim,)
        z_std : shape (dim,), population standard deviation; the variance is
            clamped to at least 1e-8 before the square root.

    Raises:
        ValueError: If the dataset yields no latent vectors at all.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=False)

    count = 0
    sum_z = None
    sum_z2 = None

    stage1_model.eval()
    stage1_model.to(device)

    with torch.no_grad():
        for batch in tqdm(loader, desc="Computing z-stats"):
            audio = batch[audio_key].to(device)  # shape: (B, ...)

            z = stage1_model.encode(audio)
            z = z.transpose(-2, -1)
            # Accumulate in float32: with half-precision latents, z * z and
            # the per-batch sums can overflow before reaching the accumulator.
            z = z.reshape(-1, z.shape[-1]).float()

            if sum_z is None:
                sum_z = torch.zeros(z.shape[1], device=device)
                sum_z2 = torch.zeros(z.shape[1], device=device)

            sum_z += z.sum(dim=0)
            sum_z2 += (z * z).sum(dim=0)
            count += z.shape[0]

    if sum_z is None or count == 0:
        # Without this guard an empty dataset fails with an opaque
        # "unsupported operand type(s) for /: 'NoneType' and 'int'".
        raise ValueError("Cannot compute z-stats: the dataset produced no samples.")

    z_mean = sum_z / count
    # Var[z] = E[z^2] - E[z]^2; rounding can push it slightly below zero,
    # hence the clamp before the square root.
    z_var = (sum_z2 / count) - (z_mean * z_mean)
    z_std = torch.sqrt(torch.clamp(z_var, min=1.0e-8))

    return z_mean, z_std



def custom_collate(batch):
    """Collate a batch after discarding every sample that is ``None``.

    Args:
        batch: Sequence of samples, some of which may be ``None``.

    Returns:
        The result of ``default_collate`` applied to the remaining samples.

    Raises:
        ValueError: If no sample is left after filtering.
    """
    valid_samples = list(filter(lambda item: item is not None, batch))
    if not valid_samples:
        raise ValueError("No valid samples in batch")
    return default_collate(valid_samples)


def load_pca_components(pca_path: Union[str, Path]):
    """
    Load PCA components from a .pt file containing a dict with keys:
      - 'mean': Tensor or NumPy array of shape (C,)
      - 'components': Tensor or NumPy array of shape (k, C)
    Both entries are converted to float32 CPU tensors.

    Returns:
      mean: torch.Tensor (C,)
      components: torch.Tensor (k, C)
    """
    # NOTE(review): torch.load unpickles the file, so only load trusted files.
    # Recent PyTorch releases default to weights_only=True, which may reject
    # files that store NumPy arrays - confirm against the installed version.
    data = torch.load(str(pca_path), map_location='cpu')
    # torch.as_tensor accepts both Tensors and ndarrays, whereas
    # torch.from_numpy raises TypeError when the entry is already a Tensor.
    mean = torch.as_tensor(data['mean']).float()  # (C,)
    components = torch.as_tensor(data['components']).float()  # (k, C)
    return mean, components

class ExtractedGameGenXwithPCA(Dataset):
    """Dataset serving a fixed-length window of pre-extracted video grid features.

    Each CSV row names a video in its ``video_folder`` column.  The matching
    feature array is read either from one HDF5 file (keyed by the video's file
    stem) or from ``<stem>_grid_resx2.npy`` files, cut to
    ``[start_index, start_index + frames_needed)`` along the first axis and,
    unless ``no_pca`` is set, projected onto the loaded PCA components.

    ``__getitem__`` returns ``None`` for samples that fail to load, so batches
    must be built with a collate function that tolerates ``None`` entries.
    """

    def __init__(
        self,
        csv_path: Union[str, Path],
        premade_feature_dir_video: Union[str, Path],
        pca_path: Union[str, Path],
        video_fps: float = 30.0,
        start_time: float = 0.0,
        duration: float = 2.56,
        vision_aggregation: bool = False,
        error_log_path: Optional[str] = None,
        source: str = "hdf5",  
        no_pca: bool = False,  # If True, skip PCA transformation
        dataset: str = "ogamedata",
    ):
        """Read the CSV index and the PCA basis and prepare feature access.

        Args:
            csv_path: CSV file with a ``video_folder`` column.
            premade_feature_dir_video: Path of the HDF5 file when ``source``
                is ``"hdf5"``, otherwise the root directory of the ``.npy``
                feature files.
            pca_path: ``.pt`` file passed to ``load_pca_components``.
            video_fps: Frame rate used to convert seconds into frame counts.
            start_time: Window start in seconds.
            duration: Window length in seconds.
            vision_aggregation: If True, read the feature shape of the last
                CSV row and store the product of its third-to-last and
                second-to-last dimensions in ``grid_feature_length``.
            error_log_path: Optional text file; truncated here and appended
                to whenever a sample fails to load.
            source: ``"hdf5"`` or ``"npy"``.
            no_pca: If True, skip the PCA projection.
            dataset: ``"ogamedata"`` stores ``.npy`` files one directory level
                deeper (under the video's parent folder name).

        Raises:
            ValueError: If ``source`` is invalid, or if the key probed for
                ``vision_aggregation`` is missing from the HDF5 file.
        """
        super().__init__()
        # Explicit raise instead of assert: asserts are stripped under -O.
        if source not in {"hdf5", "npy"}:
            raise ValueError("source must be 'hdf5' or 'npy'")
        self.source = source

        self.data = pd.read_csv(csv_path)
        self.premade_feature_dir_video = premade_feature_dir_video
        self.pca_mean, self.pca_components = load_pca_components(pca_path)
        self.video_fps = video_fps
        self.start_index = int(start_time * video_fps)
        self.frames_needed = int(duration * video_fps)
        self.no_pca = no_pca
        self.dataset = dataset

        self.error_log_path = error_log_path
        # Opened lazily on first access (see _open_h5).
        self._h5_file: Optional["h5py.File"] = None
        self.h5_path = str(premade_feature_dir_video)

        self._cache: dict[str, torch.Tensor] = {}
        # error_log_path is optional; only truncate the log when one is given.
        if self.error_log_path:
            with open(self.error_log_path, "w") as f:
                f.write("")

        if vision_aggregation:
            row = self.data.iloc[-1]
            file_path = row['video_folder']
            filename = os.path.splitext(os.path.basename(file_path))[0]
            if self.source == "hdf5":
                key = filename
                # Use a short-lived handle so no open HDF5 file is carried
                # into DataLoader worker processes; only the shape is needed,
                # so the dataset's contents are never read.
                with h5py.File(self.h5_path, "r") as hf:
                    if key not in hf:
                        raise ValueError(f"Key {key} not found in HDF5 file {self.premade_feature_dir_video}.")
                    feature_shape = hf[key].shape
            else:
                feature_shape = np.load(self._npy_path(file_path)).shape
            self.grid_feature_length = int(feature_shape[-3] * feature_shape[-2])

    def _npy_path(self, video_path) -> str:
        """Return the ``.npy`` feature path belonging to ``video_path``."""
        filename = os.path.splitext(os.path.basename(video_path))[0]
        if self.dataset == "ogamedata":
            parent = Path(video_path).parent.name
            return os.path.join(self.premade_feature_dir_video, parent, f"{filename}_grid_resx2.npy")
        return os.path.join(self.premade_feature_dir_video, f"{filename}_grid_resx2.npy")

    def _open_h5(self) -> "h5py.File":
        """Return the read-only HDF5 handle, opening it on first use."""
        if self._h5_file is None:
            self._h5_file = h5py.File(self.h5_path, "r")
        return self._h5_file

    def __len__(self):
        """Return the number of CSV rows."""
        return len(self.data)

    def __getstate__(self):
        """Drop the open HDF5 handle when pickled (e.g. for spawned workers)."""
        state = self.__dict__.copy()
        state["_h5_file"] = None
        return state

    def __del__(self):
        # getattr: __init__ may have failed before the attribute was set.
        h5_file = getattr(self, "_h5_file", None)
        if h5_file is not None:
            try:
                h5_file.close()
            except Exception:
                # Best effort only: at interpreter shutdown h5py may already
                # be partially torn down.
                pass
            self._h5_file = None

    def __getitem__(self, idx: int) -> Optional[Dict[str, Union[str, torch.Tensor]]]:
        """Load the feature window for row ``idx``.

        Returns:
            ``{'filename': <video file stem>, 'video_feature': float32 tensor}``.
            With PCA the tensor has shape (T, H, W, k); without it the sliced
            feature is returned as float32.  ``None`` is returned when loading
            or processing fails; the error is appended to ``error_log_path``
            if one was configured.
        """
        row = self.data.iloc[idx]
        video_path = row['video_folder']
        filename = os.path.splitext(os.path.basename(video_path))[0]

        try:
            if self.source == "hdf5":
                key = filename
                hf = self._open_h5()
                if key not in hf:
                    raise ValueError(f"Key {key} not found in HDF5 file {self.h5_path}.")
                # Keep the lazy dataset handle: slicing it below reads only
                # the requested frames instead of the whole array.
                feature = hf[key]
            else:
                feature = np.load(self._npy_path(video_path))  # (T, H, W, C) float16

            total_frames = feature.shape[0]
            end_index = self.start_index + self.frames_needed
            # The window must fit *after* the start offset; comparing against
            # frames_needed alone would silently return a shorter window.
            if total_frames < end_index:
                raise ValueError(
                    f"Video feature length ({total_frames}) is shorter than required ({end_index})."
                )
            feature_slice = feature[self.start_index:end_index]
            # torch.tensor copies, so the result never aliases the full array.
            feature_tensor = torch.tensor(np.asarray(feature_slice))
            if not self.no_pca:
                T, H, W, C = feature_tensor.shape

                flat_feature_tensor = feature_tensor.reshape(-1, C).float()  # (N, C)
                centered = flat_feature_tensor - self.pca_mean.unsqueeze(0)  # (N, C)
                reduced = centered @ self.pca_components.t()                 # (N, k)
                k = self.pca_components.size(0)
                # NOTE(review): the result is rounded to float16 before being
                # widened again below; kept to preserve existing outputs.
                feature_tensor = reduced.view(T, H, W, k).half().contiguous()  # (T, H, W, k)
            feature_tensor = feature_tensor.float()  # float32 Tensor
            return {
                'filename': filename,
                'video_feature': feature_tensor,  # shape: (T, H, W, k) with PCA
            }

        except Exception as e:
            # Deliberate best effort: a broken sample is logged and skipped
            # (returned as None) rather than aborting the whole epoch.
            if self.error_log_path:
                with open(self.error_log_path, 'a') as f:
                    f.write(f"{filename}: {e}\n")
            return None
