#!/usr/bin/env python3
import os
import subprocess
import argparse
from pathlib import Path
import shutil
import tempfile
import re

import numpy as np
import cv2
import torch
import gc
import sys
sys.path.append("./sam2")
from sam2.build_sam import build_sam2_video_predictor
import imageio.v3 as iio


# Overlay palette, one entry per tracked object (looked up as
# COLOR[obj_id % len(COLOR)]). The triples are handed directly to OpenCV
# drawing calls on frames loaded via cv2.imread / cv2.VideoCapture, so each
# one is interpreted in (B, G, R) channel order.
COLOR = [
    (128, 0, 0), (128, 128, 0), (128, 0, 128), (0, 128, 128),
    (64, 0, 0), (64, 128, 0), (64, 0, 128), (0, 64, 128),
    (128, 64, 0), (128, 0, 64), (0, 128, 64),
]

def convert_frames_to_mp4(input_dir, output_file, fps=30, crf=23):
    """Encode ``frame_<N>.png`` images from *input_dir* into an H.264 MP4 using ffmpeg.

    Args:
        input_dir: Directory holding the PNG frames, named ``frame_%d.png``.
        output_file: Destination MP4 path; missing parent directories are created.
        fps: Frames per second of the output video (default: 30).
        crf: x264 Constant Rate Factor - lower means higher quality
            (18-28 is a typical range).

    Raises:
        ValueError: if *input_dir* does not exist.
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
        FileNotFoundError: if the ``ffmpeg`` executable cannot be located.
    """
    src = Path(input_dir)
    if not src.exists():
        raise ValueError(f"Input directory {src} does not exist")

    dst = Path(output_file)
    dst.parent.mkdir(parents=True, exist_ok=True)

    frame_pattern = src / 'frame_%d.png'
    ffmpeg_cmd = [
        'ffmpeg', '-y',
        '-framerate', str(fps),
        '-i', str(frame_pattern),
        '-c:v', 'libx264',
        '-preset', 'medium',
        '-crf', str(crf),
        '-pix_fmt', 'yuv420p',  # widest player compatibility
        str(dst),
    ]

    print("Running ffmpeg command:")
    print(" ".join(ffmpeg_cmd))

    try:
        subprocess.run(ffmpeg_cmd, check=True)
    except FileNotFoundError:
        print("Error: ffmpeg not found. Please install ffmpeg first.")
        raise
    except subprocess.CalledProcessError as err:
        print(f"Error during conversion: {err}")
        raise
    else:
        print(f"\nSuccessfully created video: {dst}")

def load_txt(gt_path):
    """Read one bounding-box prompt per line from a comma-separated text file.

    Each non-blank line must contain ``x,y,w,h`` (ints or floats; each value is
    parsed as float and then truncated with ``int``).

    Returns:
        dict mapping the 0-based line number to ``((x1, y1, x2, y2), 0)``,
        where ``x2 = x + w`` and ``y2 = y + h``. Blank lines are skipped, but
        they still advance the line counter, so keys may have gaps.
    """
    prompts = {}
    with open(gt_path, 'r') as handle:
        for line_no, raw_line in enumerate(handle):
            text = raw_line.strip()
            if not text:
                continue
            left, top, box_w, box_h = map(float, text.split(','))
            left, top, box_w, box_h = int(left), int(top), int(box_w), int(box_h)
            corners = (left, top, left + box_w, top + box_h)
            prompts[line_no] = (corners, 0)
    return prompts

def determine_model_cfg(model_path):
    """Map a checkpoint path to its SAMURAI model config by size tag.

    The first tag found in *model_path*, checked in the order large,
    base_plus, small, tiny, selects the config.

    Raises:
        ValueError: if none of the size tags appear in *model_path*.
    """
    size_to_cfg = (
        ("large", "configs/samurai/sam2.1_hiera_l.yaml"),
        ("base_plus", "configs/samurai/sam2.1_hiera_b+.yaml"),
        ("small", "configs/samurai/sam2.1_hiera_s.yaml"),
        ("tiny", "configs/samurai/sam2.1_hiera_t.yaml"),
    )
    for size_tag, cfg_path in size_to_cfg:
        if size_tag in model_path:
            return cfg_path
    raise ValueError("Unknown model size in path!")

def extract_frame_number(filename):
    """Return the integer after the first ``frame_`` in *filename*, or ``None`` if absent."""
    found = re.search(r'frame_(\d+)', filename)
    return None if found is None else int(found.group(1))

def prepare_frames_or_path(video_path):
    """Return *video_path* unchanged if it names an ``.mp4`` file or an existing directory.

    Raises:
        ValueError: for any other path.
    """
    is_mp4 = video_path.endswith(".mp4")
    if not is_mp4 and not Path(video_path).is_dir():
        raise ValueError("Invalid video_path format. Should be .mp4 or a directory of jpg/png frames.")
    return video_path


def _load_frames_for_visualization(video_path, default_fps=30):
    """Load every frame of *video_path* as an OpenCV BGR image for drawing overlays.

    Args:
        video_path: A directory of ``frame_<N>.jpg``/``.png`` images, or a video file.
        default_fps: Frame rate used for directories, and for videos whose
            container reports no usable FPS.

    Returns:
        ``(frames, frame_rate)``, where ``frames`` is a non-empty list of images.

    Raises:
        ValueError: if no frames can be found/decoded, or a frame image is unreadable.
    """
    video_path = Path(video_path)

    if video_path.is_dir():
        frame_paths = [
            p for p in sorted(video_path.glob("*.jpg")) + sorted(video_path.glob("*.png"))
            if extract_frame_number(p.name) is not None
        ]
        if not frame_paths:
            raise ValueError("No valid frames found in the directory.")
        frame_paths.sort(key=lambda p: extract_frame_number(p.name))

        frames = []
        for frame_path in frame_paths:
            img = cv2.imread(str(frame_path))
            if img is None:  # cv2.imread signals failure by returning None
                raise ValueError(f"Could not read frame image: {frame_path}")
            frames.append(img)
        return frames, default_fps

    cap = cv2.VideoCapture(str(video_path))
    frames = []
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            frames.append(frame)
    finally:
        cap.release()

    if not frames:
        raise ValueError("No frames were loaded from the video.")
    # CAP_PROP_FPS may be 0 (or NaN) when the container lacks rate metadata;
    # a writer opened at that rate produces an unusable file. (NaN > 0 is False.)
    frame_rate = fps if fps > 0 else default_fps
    return frames, frame_rate


def _mask_to_xywh(mask):
    """Return the tight ``[x, y, w, h]`` box around the True pixels of a 2-D boolean mask.

    ``w``/``h`` are ``max - min`` of the pixel coordinates. An empty mask yields
    ``[0, 0, 0, 0]``.
    """
    ys, xs = np.nonzero(mask)
    if ys.size == 0:
        return [0, 0, 0, 0]
    x_min, x_max = int(xs.min()), int(xs.max())
    y_min, y_max = int(ys.min()), int(ys.max())
    return [x_min, y_min, x_max - x_min, y_max - y_min]


def _draw_overlay(frame, masks_by_obj, boxes_by_obj):
    """Return a copy of *frame* with each object's mask tinted in and its box outlined."""
    img = frame.copy()
    height, width = img.shape[:2]
    for obj_id, mask in masks_by_obj.items():
        mask_img = np.zeros((height, width, 3), np.uint8)
        mask_img[mask] = COLOR[obj_id % len(COLOR)]
        img = cv2.addWeighted(img, 1, mask_img, 0.95, 0)
    for obj_id, (x, y, w, h) in boxes_by_obj.items():
        cv2.rectangle(img, (x, y), (x + w, y + h), COLOR[obj_id % len(COLOR)], 2)
    return img


def process_video(args, video_path):
    """Track every prompt box through *video_path* with a SAM2/SAMURAI video predictor.

    Each non-blank line of ``args.txt_path`` (parsed by ``load_txt``) becomes a
    separate object. Its box is used as the prompt on frame 0, and the predictor
    then propagates all objects through the video.

    Args:
        args: Parsed CLI namespace. Reads ``device``, ``model_path``, ``txt_path``,
            ``save_to_video``, ``video_output_path`` and ``mask_dir``.
        video_path: Directory of frames or a video file. It is passed as a string
            to ``predictor.init_state``.

    Side effects:
        * If ``args.mask_dir`` is set, writes one binary PNG per object per frame,
          named ``OBJ<obj_id:02>_<frame_idx:04>.png``.
        * If ``args.save_to_video`` is set, writes an ``mp4v``-encoded video with
          mask overlays and bounding boxes to ``args.video_output_path``.
          Otherwise no video file is created.

    Raises:
        ValueError: from ``determine_model_cfg``, or when frames for the
            visualisation cannot be loaded.
    """
    device = args.device
    model_cfg = determine_model_cfg(args.model_path)
    predictor = build_sam2_video_predictor(model_cfg, args.model_path, device=device)
    prompts = load_txt(args.txt_path)

    # Frames and writer are only needed for the visualisation video. Opening the
    # writer unconditionally crashed mask-only runs (width/height undefined).
    loaded_frames = []
    out = None
    if args.save_to_video:
        loaded_frames, frame_rate = _load_frames_for_visualization(video_path)
        height, width = loaded_frames[0].shape[:2]
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(args.video_output_path, fourcc, frame_rate, (width, height))

    mask_dir = args.mask_dir
    if mask_dir is not None:
        mask_dir = Path(mask_dir)
        mask_dir.mkdir(exist_ok=True, parents=True)

    # fp16 autocast only makes sense when the model actually runs on CUDA.
    use_cuda_amp = torch.device(device).type == 'cuda'
    state = None
    try:
        with torch.inference_mode(), torch.autocast('cuda', dtype=torch.float16, enabled=use_cuda_amp):
            state = predictor.init_state(str(video_path), offload_video_to_cpu=True)

            # One object per prompt line, all seeded on frame 0.
            for obj_id, (bbox, _label) in enumerate(prompts.values()):
                predictor.add_new_points_or_box(state, box=bbox, frame_idx=0, obj_id=obj_id)

            for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
                mask_to_vis = {}
                bbox_to_vis = {}

                for obj_id, mask in zip(object_ids, masks):
                    # Threshold the first channel of the mask at 0 to get a boolean mask.
                    mask = mask[0].cpu().numpy() > 0.0
                    bbox_to_vis[obj_id] = _mask_to_xywh(mask)
                    mask_to_vis[obj_id] = mask

                    if mask_dir is not None:
                        mask_path = mask_dir / f'OBJ{obj_id:02}_{frame_idx:04}.png'
                        iio.imwrite(str(mask_path), mask.astype(np.uint8) * 255)

                if out is not None and frame_idx < len(loaded_frames):
                    out.write(_draw_overlay(loaded_frames[frame_idx], mask_to_vis, bbox_to_vis))
    finally:
        # Always finalise the video file and free model memory, even on error.
        if out is not None:
            out.release()
        del predictor, state
        gc.collect()
        torch.clear_autocast_cache()
        torch.cuda.empty_cache()

def main():
    """Command-line entry point: parse options, validate the input path and run tracking.

    Raises:
        ValueError: if ``--input_path`` does not exist, or is neither a directory
            nor a file with an ``.mp4`` suffix (case-insensitive).
    """
    parser = argparse.ArgumentParser(description='Process video frames with SAM.')

    # Required inputs.
    parser.add_argument('--input_path', required=True,
                        help='Input path: directory of frames or MP4 video.')
    parser.add_argument('--txt_path', required=True,
                        help='Path to ground truth text file.')

    # Optional settings.
    parser.add_argument('--model_path', default="sam2/checkpoints/sam2.1_hiera_large.pt",
                        help='Path to the model checkpoint.')
    parser.add_argument('--video_output_path', default="demo.mp4",
                        help='Path to save the output video.')
    parser.add_argument('--save_to_video', action='store_true',
                        help='Flag to save results to a video.')
    parser.add_argument('--mask_dir',
                        help='If provided, save mask images to the given directory')
    parser.add_argument('--device', default="cuda:0",
                        help='Device to run the model on.')

    args = parser.parse_args()
    source = Path(args.input_path)

    if not source.exists():
        raise ValueError(f"Input path does not exist: {source}")

    looks_like_mp4 = source.suffix.lower() == '.mp4'
    if not (looks_like_mp4 or source.is_dir()):
        raise ValueError("Input path must be a directory of frames or an MP4 video.")

    process_video(args, source)

if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not when imported.
    main()


