import tempfile
import numpy as np
import PIL
from typing import List, Union, Optional, Tuple
import mediapy
import os
from PIL import Image
import cv2

def save_video(
    video_frames: Union[List[np.ndarray], List[PIL.Image.Image]],
    output_video_path: Optional[str] = None,
    fps: int = 10,
    crf: int = 18,
) -> str:
    """Encode a sequence of frames to a video file with ``mediapy``.

    Args:
        video_frames: Non-empty list of frames. The type of the first frame
            decides the conversion:
            - NumPy frames with dtype ``uint8`` are used as-is.
            - Other NumPy frames are assumed to be normalized to ``[0, 1]``.
              Floating values are clipped to that range, then everything is
              scaled by 255 and cast to ``uint8``.
            - PIL images are converted with ``np.array``.
        output_video_path: Destination path. When ``None``, a new temporary
            ``.mp4`` file is created and its path is used.
        fps: Frame rate forwarded to ``mediapy.write_video``.
        crf: Constant rate factor forwarded to ``mediapy.write_video``.

    Returns:
        The path the video was written to.

    Raises:
        ValueError: If ``video_frames`` is empty.
    """
    if len(video_frames) == 0:
        raise ValueError("video_frames must contain at least one frame")

    if output_video_path is None:
        # mkstemp creates the file atomically and keeps it on disk.
        # Close our handle so the encoder can overwrite the file.
        fd, output_video_path = tempfile.mkstemp(suffix=".mp4")
        os.close(fd)

    if isinstance(video_frames[0], np.ndarray):
        converted = []
        for frame in video_frames:
            if frame.dtype != np.uint8:
                # Non-uint8 frames are treated as [0, 1] intensities.
                if np.issubdtype(frame.dtype, np.floating):
                    # Clip first so out-of-range floats don't wrap on the uint8 cast.
                    frame = np.clip(frame, 0, 1)
                frame = (frame * 255).astype(np.uint8)
            # uint8 frames are already in [0, 255]; multiplying them by 255
            # would overflow and invert the image.
            converted.append(frame)
        video_frames = converted
    elif isinstance(video_frames[0], PIL.Image.Image):
        video_frames = [np.array(frame) for frame in video_frames]

    mediapy.write_video(output_video_path, video_frames, fps=fps, crf=crf)
    return output_video_path

def export_to_video(video_frames, output_video_path, fps):
    """Write frames to an MPEG-4 (``mp4v`` FourCC) file using OpenCV.

    Args:
        video_frames: Non-empty sequence of frames. Each is unpacked as
            ``(height, width, channels)`` and converted with
            ``cv2.COLOR_RGB2BGR``. 3-channel RGB uint8 arrays are presumably
            expected (TODO: confirm with callers).
        output_video_path: Destination file path.
        fps: Frame rate of the written video.

    Returns:
        ``output_video_path``.

    Raises:
        ValueError: If ``video_frames`` is empty.
        OSError: If OpenCV cannot open a writer for ``output_video_path``.
    """
    if len(video_frames) == 0:
        raise ValueError("video_frames must contain at least one frame")

    h, w, _ = video_frames[0].shape
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    video_writer = cv2.VideoWriter(
        output_video_path, fourcc, fps=fps, frameSize=(w, h))
    if not video_writer.isOpened():
        # Otherwise every write() would silently do nothing.
        raise OSError(f"Failed to open video writer for: {output_video_path}")

    try:
        for frame in video_frames:
            # OpenCV expects BGR channel order.
            video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        # release() flushes and finalizes the container.
        # Without it the output file may be incomplete.
        video_writer.release()
    return output_video_path

def export_to_gif(frames, output_gif_path, fps):
    """
    Export a list of frames to an infinitely looping GIF.

    Args:
    - frames (list): Non-empty list of frames, as NumPy arrays (converted
      with ``Image.fromarray``) or PIL Image objects.
    - output_gif_path (str): Path to save the output GIF. A trailing
      ``.mp4`` extension is replaced with ``.gif``, so a video path can be
      passed directly.
    - fps (float): Playback rate. Each frame is shown for ``1000 / fps`` ms.
      GIF delays are stored in 1/100 s units, so the delay is at least 10 ms.
      A missing or non-positive ``fps`` falls back to 500 ms per frame.

    Returns:
    - str: The path the GIF was written to.

    Raises:
    - ValueError: If ``frames`` is empty.
    """
    if len(frames) == 0:
        raise ValueError("frames must contain at least one frame")

    # Convert numpy arrays to PIL Images if needed.
    pil_frames = [
        Image.fromarray(frame) if isinstance(frame, np.ndarray) else frame
        for frame in frames
    ]

    # Only swap the extension. str.replace would also rewrite '.mp4'
    # appearing elsewhere in the path (e.g. in a directory name).
    if output_gif_path.endswith('.mp4'):
        output_gif_path = output_gif_path[:-len('.mp4')] + '.gif'

    if fps and fps > 0:
        duration_ms = max(10, int(round(1000 / fps)))
    else:
        duration_ms = 500

    pil_frames[0].save(output_gif_path,
                       format='GIF',
                       append_images=pil_frames[1:],
                       save_all=True,
                       duration=duration_ms,
                       loop=0)
    return output_gif_path
            
def load_png_folder_to_array(folder_path):
    """Load every ``.png`` in a folder as one stacked RGB uint8 array.

    Files are matched case-insensitively on the ``.png`` extension, without
    recursing into subdirectories. They are read in lexicographic filename
    order, so zero-pad frame numbers to get numeric order. Each image is
    converted to RGB so every frame has exactly 3 channels.

    Args:
        folder_path: Directory to scan.

    Returns:
        Array of shape ``(T, H, W, 3)`` with dtype uint8. All images must
        share the same size, otherwise ``np.stack`` raises.

    Raises:
        ValueError: If the folder contains no PNG files.
    """
    png_names = sorted(
        name for name in os.listdir(folder_path) if name.lower().endswith('.png')
    )
    if len(png_names) == 0:
        raise ValueError(f"No PNG files found in {folder_path}")

    def _read_rgb(name):
        # Force RGB so grayscale / RGBA inputs all yield 3 channels.
        rgb_image = Image.open(os.path.join(folder_path, name)).convert('RGB')
        return np.array(rgb_image, dtype=np.uint8)

    return np.stack([_read_rgb(name) for name in png_names], axis=0)


def _resize_keep_aspect_crop(img: np.ndarray,
                             target_hw: Tuple[int, int],
                             interpolation: int) -> np.ndarray:
    """Scale ``img`` uniformly to cover ``target_hw``, then center-crop to it.

    The larger of the two per-axis scale factors is used, so the resized
    image covers the target in both dimensions and nothing is distorted. The
    excess along the other axis is trimmed from both sides; an odd leftover
    pixel is removed from the bottom/right.

    Args:
        img: Image indexed as ``(H, W, ...)``.
        target_hw: Output size as ``(height, width)``.
        interpolation: OpenCV interpolation flag passed to ``cv2.resize``.

    Returns:
        The cropped image.
    """
    target_h, target_w = target_hw
    src_h, src_w = img.shape[:2]

    factor = max(target_h / src_h, target_w / src_w)
    # cv2.resize takes the destination size as (width, height).
    scaled_size = (int(round(src_w * factor)), int(round(src_h * factor)))
    scaled = cv2.resize(img, scaled_size, interpolation=interpolation)

    top = (scaled_size[1] - target_h) // 2
    left = (scaled_size[0] - target_w) // 2
    return scaled[top:top + target_h, left:left + target_w]

def _pick_video_interpolation(src_hw, resize, keep_aspect):
    """Return the default cv2 interpolation: INTER_AREA when shrinking, INTER_CUBIC when enlarging."""
    if resize is None:
        # Placeholder; no resizing happens in this case.
        return cv2.INTER_LINEAR
    h0, w0 = src_hw
    th, tw = resize
    if keep_aspect == "crop":
        # Crop mode scales by the larger per-axis factor, so that factor decides the direction.
        downscale = max(th / h0, tw / w0) < 1
    else:
        # Stretching may shrink one axis and grow the other.
        # Compare areas as an approximation.
        downscale = (th * tw) < (h0 * w0)
    return cv2.INTER_AREA if downscale else cv2.INTER_CUBIC


def _fit_video_frame(frame, resize, keep_aspect, interpolation):
    """Apply the optional resize / center-crop used by read_video_to_array to one RGB frame."""
    if resize is None:
        return frame
    if keep_aspect == "crop":
        return _resize_keep_aspect_crop(frame, resize, interpolation)
    # Plain stretch to (H, W); may distort. cv2.resize takes (width, height).
    return cv2.resize(frame, (resize[1], resize[0]), interpolation=interpolation)


def read_video_to_array(path: str,
                        resize: Optional[Tuple[int, int]] = None,  # (H, W)
                        keep_aspect: Optional[str] = None,         # "crop" or None
                        interpolation: Optional[int] = None        # cv2.INTER_*
                        ) -> np.ndarray:
    """Decode a video file (e.g. ``.mp4``) into an RGB uint8 array.

    Args:
        path: Video file readable by ``cv2.VideoCapture``.
        resize: Optional target size ``(H, W)``. When ``None``, frames keep
            their decoded size.
        keep_aspect: ``"crop"`` scales uniformly and center-crops to
            ``resize``, with no distortion and no padding. Any other value
            stretches each frame to ``resize``, which may distort it.
        interpolation: ``cv2.INTER_*`` flag. When ``None`` and ``resize`` is
            set, ``INTER_AREA`` is used for downscaling and ``INTER_CUBIC``
            for upscaling, judged from the first frame's size.

    Returns:
        Array of shape ``(T, H, W, 3)``. For a video with no decodable frames,
        ``T == 0`` and the spatial dims are ``resize`` if given, otherwise 0.

    Raises:
        ValueError: If the file cannot be opened.
    """
    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        raise ValueError(f"Failed to open video: {path}")

    try:
        ok, frame = cap.read()
        if not ok:
            if resize is not None:
                return np.empty((0, resize[0], resize[1], 3), dtype=np.uint8)
            return np.empty((0, 0, 0, 3), dtype=np.uint8)

        if interpolation is None:
            interpolation = _pick_video_interpolation(frame.shape[:2], resize, keep_aspect)

        frames = []
        while ok:
            # OpenCV decodes to BGR; convert to RGB.
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(_fit_video_frame(rgb, resize, keep_aspect, interpolation))
            ok, frame = cap.read()
    finally:
        # Release the capture even if decoding or resizing raises.
        cap.release()

    return np.asarray(frames, dtype=np.uint8)

def temporal_resample(frames, T_target: int, method: str = "linear"):
    """Resample a frame sequence along time to exactly ``T_target`` frames.

    Target positions are ``T_target`` evenly spaced points on ``[0, T-1]``
    (both ends included), so the first and last input frames are always kept.
    For example, T=5 and T_target=8 gives 8 equally spaced positions in [0, 4].

    Args:
        frames: Array-like of shape ``(T, H, W, C)``, any numeric dtype.
        T_target: Desired number of frames, ``>= 1``.
        method: ``"nearest"`` picks the closest frame (fast, may stutter).
            ``"linear"`` blends the two neighbouring frames (smooth, suited
            to frame interpolation).

    Returns:
        Array of shape ``(T_target, H, W, C)`` with the input dtype. An empty
        input (``T == 0``) yields an empty array of shape ``(0, H, W, C)``.

    Raises:
        ValueError: If ``frames`` is not 4-D, ``T_target < 1``, or ``method``
            is unknown.
    """
    arr = np.asarray(frames)
    if arr.ndim != 4:
        raise ValueError(f"frames must be (T,H,W,C), got shape {arr.shape}")
    if T_target <= 0:
        raise ValueError("T_target must be >= 1")
    # Validate up front so a bad method is rejected even for 0- or 1-frame input.
    if method not in ("linear", "nearest"):
        raise ValueError(f'Unknown method "{method}", expected "linear" or "nearest".')

    T = arr.shape[0]
    if T == 0:
        return np.empty((0,) + arr.shape[1:], dtype=arr.dtype)
    if T == 1:
        # A single frame is simply repeated to the target length.
        return np.repeat(arr, T_target, axis=0)

    pos = np.linspace(0, T - 1, T_target)

    if method == "nearest":
        return arr[np.rint(pos).astype(np.int64)]

    # Blend in at least float32, but keep float64 inputs at full precision.
    work_dtype = np.result_type(arr.dtype, np.float32)

    left = np.floor(pos).astype(np.int64)
    right = np.minimum(left + 1, T - 1)
    # Shape (T_target, 1, 1, 1) so alpha broadcasts over H, W, C.
    alpha = (pos - left).astype(work_dtype)[:, None, None, None]

    f0 = arr[left].astype(work_dtype)
    f1 = arr[right].astype(work_dtype)
    out = f0 * (1.0 - alpha) + f1 * alpha

    if np.issubdtype(arr.dtype, np.integer):
        # Clip to the input dtype's own range, not a fixed 0..255,
        # so uint16 / int16 data is not corrupted.
        info = np.iinfo(arr.dtype)
        out = np.clip(np.round(out), info.min, info.max)
    return out.astype(arr.dtype)

def get_name(video_path, with_extension=True):
    """Return the file name component of ``video_path``.

    Args:
        video_path: Path to a file, e.g. ``"/data/my_clip.mp4"``.
        with_extension: When True (the default), keep the extension
            (``"my_clip.mp4"``). When False, return only the stem
            (``"my_clip"``).

    Returns:
        The base name, with or without its last extension.
    """
    video_name = os.path.basename(video_path)
    if with_extension:
        return video_name
    return os.path.splitext(video_name)[0]