import os
import cv2
import torch
import numpy as np
from tqdm import tqdm
from torchvision import transforms
import imageio
import argparse
import sys

sys.path.append("/root_path/RAFT/core")
from raft import RAFT
from utils.utils import InputPadder

# Torch device string used for the model and all input tensors: CUDA when available, else CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

def load_raft_model(ckpt_path):
    """Build a RAFT network, load ``ckpt_path`` into it, and return it ready for inference.

    The weights are loaded through a ``torch.nn.DataParallel`` wrapper, presumably
    because the checkpoint's keys carry the ``module.`` prefix (TODO confirm). The
    unwrapped module is then moved to ``DEVICE`` and switched to eval mode.

    Args:
        ckpt_path: path of the checkpoint file passed to ``torch.load``.

    Returns:
        The RAFT ``nn.Module`` on ``DEVICE`` in eval mode.
    """
    # Fixed configuration for the (vendored) RAFT constructor; the depth_* / max_depth
    # fields are read by that fork, not by this script.
    raft_cfg = argparse.Namespace(
        small=False,
        mixed_precision=False,
        alternate_corr=False,
        dropout=0.0,
        max_depth=8,
        depth_network=False,
        depth_residual=False,
        depth_scale=1.0,
    )
    wrapper = torch.nn.DataParallel(RAFT(raft_cfg))
    weights = torch.load(ckpt_path, map_location=DEVICE)
    wrapper.load_state_dict(weights)

    net = wrapper.module.to(DEVICE)
    return net.eval()

# Frame rate used when the container reports none (0) or an invalid (NaN) value.
_FALLBACK_FPS = 30.0


def _raft_flow(raft, img1_bgr, img2_bgr):
    """Estimate dense optical flow from ``img1_bgr`` to ``img2_bgr`` with ``raft``.

    Returns a C-contiguous float32 array of shape (H, W, 2) holding per-pixel
    (dx, dy) displacements, indexed by pixel coordinates of ``img1_bgr``.
    """
    # NOTE(review): ToTensor scales uint8 pixels to [0, 1] and keeps OpenCV's BGR
    # channel order. Upstream RAFT's forward() normalises with 2 * (x / 255) - 1,
    # i.e. it expects RGB in [0, 255]. Confirm the vendored RAFT fork matches
    # this input convention before trusting the flow quality.
    to_tensor = transforms.ToTensor()
    t1 = to_tensor(img1_bgr).unsqueeze(0).to(DEVICE)
    t2 = to_tensor(img2_bgr).unsqueeze(0).to(DEVICE)
    padder = InputPadder(t1.shape)
    i1, i2 = padder.pad(t1, t2)
    with torch.no_grad():
        _, flow = raft(i1, i2, iters=20, test_mode=True)
    flow = padder.unpad(flow)[0].permute(1, 2, 0).cpu().numpy()
    # permute() yields a strided view; cv2.remap needs contiguous float32 data.
    return np.ascontiguousarray(flow, dtype=np.float32)


def _sample_flow(flow, pos):
    """Bilinearly sample an (H, W, 2) ``flow`` field at the (x, y) points in ``pos``."""
    # pos is an (H, W, 2) float32 map, usable directly as remap's CV_32FC2 map1.
    return cv2.remap(flow, pos, None, cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE)


def _in_bounds(pos, H, W):
    """Return a bool (H, W) map: True where the tracked point lies inside the W x H frame."""
    x, y = pos[..., 0], pos[..., 1]
    return (0 <= x) & (x < W) & (0 <= y) & (y < H)


def _splat_mask(pos, vis, H, W):
    """Mark, in a uint8 (H, W) mask, every pixel a visible track lands on (rounded), then dilate 2x2."""
    m = np.zeros((H, W), np.uint8)
    ys, xs = np.where(vis)
    px = np.round(pos[ys, xs, 0]).astype(int)
    py = np.round(pos[ys, xs, 1]).astype(int)
    # Rounding can push a point at e.g. x = W - 0.3 onto column W, so re-check bounds.
    inb = (0 <= px) & (px < W) & (0 <= py) & (py < H)
    m[py[inb], px[inb]] = 1
    # A 2x2 dilation closes the single-pixel holes left by forward splatting.
    return cv2.dilate(m, np.ones((2, 2), np.uint8))


def run_masking(video_path, output_path, mask_path, raft, *, resize_to=(720, 480)):
    """Black out the pixels of each frame that are not covered by content tracked from frame 0.

    Every first-frame pixel is tracked forward by chaining RAFT flow between
    consecutive frames. An output frame keeps only the pixels that some track,
    which has never left the frame, lands on (2x2-dilated). Every other pixel is
    set to 0. When such pixels cover less than 30% of the frame, tracks are
    re-anchored with direct first->current flow. If even then the covered area is
    below 1/6 of the frame, that mask is frozen and reused for every remaining
    frame. The first frame is written unmasked.

    Args:
        video_path: input video readable by ``cv2.VideoCapture``.
        output_path: destination of the masked libx264 video (CRF 18).
        mask_path: file passed to ``np.savez_compressed``; it stores ``mask``, a
            bool array of shape (num_frames, H, W) whose first entry is all True.
        raft: RAFT model (e.g. from ``load_raft_model``), called as
            ``raft(img1, img2, iters=20, test_mode=True)``.
        resize_to: ``(width, height)`` passed to ``cv2.resize`` for every frame.

    Returns:
        None. Prints a message and writes nothing if the video cannot be opened or
        its first frame cannot be read. The mask file is written only after the
        whole video has been encoded, so its presence marks a completed run.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print(f"Failed to open video: {video_path}")
            return

        fps = cap.get(cv2.CAP_PROP_FPS)
        if not (fps > 0):  # also catches NaN
            print(f"Unknown FPS for {video_path}; assuming {_FALLBACK_FPS}")
            fps = _FALLBACK_FPS

        ok, first = cap.read()
        if not ok:
            print(f"Failed to read first frame in {video_path}")
            return

        first = cv2.resize(first, resize_to)
        H, W = first.shape[:2]
        area_thresh = (H * W) // 6

        grid = np.stack(np.meshgrid(np.arange(W), np.arange(H)), -1).astype(np.float32)
        # pos[y0, x0] = current (x, y) of the track that started at first-frame pixel (x0, y0).
        pos = grid.copy()
        # True for tracks that have never left the frame.
        vis = np.ones((H, W), dtype=bool)

        prev = first
        frozen_mask = None  # once set, tracking stops and this mask is reused
        all_masks = [np.ones((H, W), dtype=bool)]  # frame 0 is kept entirely

        # fps stays a float (int() would turn 29.97 into 29 and stretch the video).
        # quality=None stops imageio from adding its own -crf flag, so the CRF in
        # ffmpeg_params is the only one ffmpeg sees.
        writer = imageio.get_writer(
            output_path,
            fps=fps,
            codec="libx264",
            quality=None,
            ffmpeg_params=["-crf", "18", "-preset", "slow"],
        )
        try:
            writer.append_data(first[:, :, ::-1])  # OpenCV BGR -> RGB

            # Read until the decoder runs dry; CAP_PROP_FRAME_COUNT can be inaccurate.
            while True:
                ok, cur = cap.read()
                if not ok:
                    break
                cur = cv2.resize(cur, resize_to)

                if frozen_mask is None:
                    # Advance each track by the prev->cur flow sampled at the track's
                    # *current* position, not at the pixel where it started.
                    pos += _sample_flow(_raft_flow(raft, prev, cur), pos)
                    vis &= _in_bounds(pos, H, W)
                    m = _splat_mask(pos, vis, H, W)

                    if m.sum() / (H * W) < 0.3:
                        # Coverage under 30%: rebuild all tracks from direct first->cur flow.
                        flow_0t = _raft_flow(raft, first, cur)
                        pos = np.ascontiguousarray(grid + flow_0t, dtype=np.float32)
                        vis = _in_bounds(pos, H, W)
                        m = _splat_mask(pos, vis, H, W)

                        if m.sum() < area_thresh:
                            frozen_mask = m

                    prev = cur
                else:
                    m = frozen_mask

                keep = m.astype(bool)
                all_masks.append(keep)

                out = cur.copy()  # cur is reused as `prev`, so don't zero it in place
                out[~keep] = 0
                writer.append_data(out[:, :, ::-1])
        finally:
            writer.close()
    finally:
        cap.release()

    np.savez_compressed(mask_path, mask=np.stack(all_masks, axis=0))
    
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--video_path", type=str, default="panda70m_filtered/videos", help="Directory containing input .mp4 videos")
    parser.add_argument("--output_path", type=str, default="panda70m_filtered/masked_videos", help="Directory to save masked videos")
    parser.add_argument("--mask_path", type=str, default="panda70m_filtered/masks", help="Directory to save per-video .npz masks")
    parser.add_argument("--raft_ckpt", type=str, default="RAFT/models/raft-things.pth")
    parser.add_argument("--start_idx", type=int, default=0, help="Start video index (into the sorted .mp4 list)")
    parser.add_argument("--end_idx", type=int, default=-1, help="End video index (non-inclusive); negative means through the last video")

    args = parser.parse_args()

    # Sorted so --start_idx/--end_idx select the same videos on every run.
    video_list = sorted(f for f in os.listdir(args.video_path) if f.endswith(".mp4"))

    # A negative end index means "to the end"; slicing with -1 would silently
    # drop the last video.
    end_idx = None if args.end_idx < 0 else args.end_idx
    selected_videos = video_list[args.start_idx:end_idx]
    print(f"[Processing {len(selected_videos)} videos] Index {args.start_idx} to {args.end_idx}")

    os.makedirs(args.output_path, exist_ok=True)
    os.makedirs(args.mask_path, exist_ok=True)

    model = load_raft_model(args.raft_ckpt)
    for fname in tqdm(selected_videos, desc="Batch Processing"):
        input_path = os.path.join(args.video_path, fname)
        mask_path = os.path.join(args.mask_path, os.path.splitext(fname)[0] + ".npz")
        output_path = os.path.join(args.output_path, fname)

        # run_masking saves the mask only after the video is fully written, so a
        # readable mask marks a finished video. A missing or unreadable mask means
        # the video is (re)processed, even if a (possibly partial) output exists.
        if os.path.exists(mask_path):
            try:
                with np.load(mask_path) as data:
                    data["mask"]  # force a full read to validate the file
            except Exception as e:
                print(f"⚠️ Mask corrupt or unreadable: {mask_path} - Regenerating ({e})")
            else:
                continue

        # Report a failing video and keep going instead of aborting the whole batch.
        try:
            run_masking(input_path, output_path, mask_path, model)
        except Exception as e:
            print(f"Failed to process {input_path}: {e}")