import os
import cv2
import argparse
import numpy as np
from tqdm import tqdm
from PIL import Image, ImageSequence
from pathlib import Path


def read_video_frames(path, n_frames):
    """Return the leading frames of a .gif or .mp4 file as RGB PIL images.

    Args:
        path: Path (str or ``Path``) to the input file. The type is chosen
            from the file suffix, compared case-insensitively.
        n_frames: Maximum number of frames to return.

    Returns:
        A list of at most ``n_frames`` ``PIL.Image`` objects in "RGB" mode.
        The list is shorter when the file holds fewer frames.

    Raises:
        ValueError: If the suffix is neither ``.gif`` nor ``.mp4``.
    """
    path = Path(path)
    suffix = path.suffix.lower()
    frames = []

    if suffix == ".gif":
        # The context manager closes the file handle once we are done; the
        # frames kept are the results of convert(), not the GIF object itself.
        with Image.open(path) as gif:
            for i, frame in enumerate(ImageSequence.Iterator(gif)):
                if i >= n_frames:
                    break
                frames.append(frame.convert("RGB"))
    elif suffix == ".mp4":
        cap = cv2.VideoCapture(str(path))
        try:
            while len(frames) < n_frames:
                ok, frame = cap.read()
                if not ok:
                    # End of stream, or the capture could not be read.
                    break
                # Frames are converted BGR -> RGB before wrapping in PIL.
                frames.append(
                    Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                )
        finally:
            # Release the capture even if a conversion above raises.
            cap.release()
    else:
        raise ValueError("Unsupported file type")

    return frames


def _global_foreground_bbox(frames, white_thresh):
    """Return (x_min, y_min, x_max, y_max) of the foreground over all frames.

    For RGBA frames the foreground is every pixel with non-zero alpha; for
    other frames it is every pixel whose channel mean is <= ``white_thresh``.
    """
    x_min, y_min, x_max, y_max = np.inf, np.inf, 0, 0
    for image in frames:
        if image.mode == "RGBA":
            _, mask = cv2.threshold(
                np.array(image.split()[-1]), 0, 255, cv2.THRESH_BINARY
            )
        else:
            # Already a strict 0/255 mask, so no further thresholding needed.
            mask = (np.array(image).mean(-1) <= white_thresh).astype(np.uint8) * 255

        # NOTE(review): a frame with no foreground pixels presumably yields an
        # empty rect at the origin, which would pull the minima to 0 — confirm
        # against the cv2.boundingRect documentation.
        x, y, w, h = cv2.boundingRect(mask)
        x_min = min(x_min, x)
        y_min = min(y_min, y)
        x_max = max(x_max, x + w)
        y_max = max(y_max, y + h)
    return x_min, y_min, x_max, y_max


def main():
    """Paste generated per-camera video frames back into the input geometry.

    Steps:
      1. Read ``--n_frames`` frames from ``--input_video``.
      2. Compute the union foreground bounding box over those frames.
      3. Expand it to a centre-aligned square and derive ``side_len`` from
         ``--image_frame_ratio``.
      4. For each camera id in ``--cams``, read
         ``<out_dir>/<generate_id:06d>_v0<cam:02d>.mp4``, resize each frame to
         ``side_len`` x ``side_len``, crop the central box, resize it to the
         crop size, paste it on a white canvas of the input size, invert the
         colours and save it as
         ``<data_root>/frames_cam<cam:02d>/frame_<idx:03d>.png``.

    ``--W`` and ``--H`` are only echoed in the printed summary.

    Raises:
        SystemExit: Via ``argparse`` when an argument is invalid.
        ValueError: If the input video has fewer than ``--n_frames`` frames,
            or the derived ``side_len`` is not positive.
    """
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument("--data_root", required=True, help="Root directory for output frames")
    parser.add_argument("--input_video", required=True, help="Original video used for bbox computation")
    parser.add_argument("--out_dir", required=True, help="Directory containing SV4D generated videos")
    parser.add_argument("--generate_id", required=True, type=int, help="Video ID")
    parser.add_argument("--n_frames", type=int, default=240)
    parser.add_argument("--image_frame_ratio", type=float, default=0.9)
    parser.add_argument("--W", type=int, default=576)
    parser.add_argument("--H", type=int, default=576)
    parser.add_argument("--white_thresh", type=int, default=250)
    parser.add_argument("--cams", type=str, default="1,2,3,4")
    args = parser.parse_args()

    # Validate up front so bad arguments fail before any video is decoded.
    # With n_frames < 1 the frame list would be empty and the image size
    # below would never be determined.
    if args.n_frames < 1:
        parser.error("--n_frames must be at least 1")
    # Written as "not > 0" so that NaN is rejected as well.
    if not args.image_frame_ratio > 0:
        parser.error("--image_frame_ratio must be positive")
    try:
        cam_ids = [int(s) for s in args.cams.split(",") if s.strip()]
    except ValueError:
        parser.error(
            f"--cams must be a comma-separated list of integers, got {args.cams!r}"
        )

    output_root = args.data_root
    input_path = args.input_video
    processed_video_root = args.out_dir
    vid_num = args.generate_id
    n_frames = args.n_frames
    image_frame_ratio = args.image_frame_ratio
    W, H = args.W, args.H
    white_thresh = args.white_thresh

    # Step 1: read video frames
    frames = read_video_frames(input_path, n_frames)
    if len(frames) < n_frames:
        raise ValueError(f"Video has {len(frames)} frames, less than {n_frames}")

    # Step 2: compute global bounding box
    x_min, y_min, x_max, y_max = _global_foreground_bbox(frames, white_thresh)
    # Image size is taken from the last frame read (PIL size is (width, height)).
    in_w, in_h = frames[-1].size

    # Step 3: expand to square region (center aligned)
    # NOTE(review): the centre is held as (row, col) but is paired with the
    # box's x/y entries in the same index order, so x extents are measured
    # from the vertical centre and y extents from the horizontal one. That is
    # exact only for square inputs. Left unchanged because this geometry
    # presumably has to mirror the preprocessing that produced the generated
    # videos — confirm before altering.
    center_row, center_col = in_h // 2, in_w // 2
    box_square = max(
        center_row - x_min,
        center_col - y_min,
        x_max - center_row,
        y_max - center_col,
    )
    x = max(0, center_col - box_square)
    y = max(0, center_row - box_square)
    w, h = min(in_w, 2 * box_square), min(in_h, 2 * box_square)
    box_size = box_square * 2

    # Step 4: compute side_len
    side_len = int(box_size / image_frame_ratio)
    if side_len <= 0:
        raise ValueError(f"Computed side_len {side_len} is not positive")

    print("original_size:", [in_w, in_h])
    print("crop_box:", [x, y, x + w, y + h])
    print("side_len:", side_len)
    print("resize_to:", [W, H])

    # Step 5: process camera videos
    box_size_w = min(w, box_size)
    box_size_h = min(h, box_size)
    center = side_len // 2
    offset_r = center - box_size_h // 2
    offset_c = center - box_size_w // 2
    # The crop rectangle is the same for every frame of every camera.
    crop_box = (
        offset_c, offset_r,
        offset_c + box_size_w,
        offset_r + box_size_h,
    )

    for cam in cam_ids:
        proc_path = os.path.join(processed_video_root, f"{vid_num:06d}_v0{cam:02d}.mp4")
        out_dir = os.path.join(output_root, f"frames_cam{cam:02d}")
        os.makedirs(out_dir, exist_ok=True)

        cap = cv2.VideoCapture(proc_path)
        try:
            if not cap.isOpened():
                # Previously this case silently produced zero frames; keep
                # going with the other cameras but tell the user.
                print(
                    f"warning: could not open {proc_path}; skipping cam{cam:02d}",
                    file=sys.stderr,
                )
                continue

            frame_total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            with tqdm(
                total=frame_total if frame_total > 0 else None,
                desc=f"cam{cam:02d}",
                unit="f",
            ) as pbar:
                frame_idx = 0
                while True:
                    ret, frame_bgr = cap.read()
                    if not ret:
                        break

                    frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
                    pil_img = Image.fromarray(frame_rgb)

                    # Scale to the padded square, cut out the central box,
                    # then scale that box to the crop size in the input frame.
                    pil_unresized = pil_img.resize((side_len, side_len), Image.LANCZOS)
                    pil_crop = pil_unresized.crop(crop_box)
                    pil_crop_resized = pil_crop.resize((w, h), Image.LANCZOS)

                    canvas = Image.new("RGB", (in_w, in_h), (255, 255, 255))
                    canvas.paste(pil_crop_resized, (x, y))
                    # Invert every channel value, so the white canvas
                    # background is written out as black.
                    canvas = Image.fromarray(255 - np.array(canvas))

                    canvas.save(os.path.join(out_dir, f"frame_{frame_idx:03d}.png"))

                    frame_idx += 1
                    pbar.update(1)
        finally:
            # Release the capture even if a frame fails to process or save.
            cap.release()


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()