#!/usr/bin/env python3
import os
import argparse
import glob
import re
import subprocess
import shutil
from pathlib import Path
import cv2
import numpy as np
import sys


def build_combined_first_mask(masks_dir: str) -> str:
    """
    Build a **single** first-frame mask containing *all* persons/objects.

    If ``mask_0.png`` already exists in *masks_dir* it is returned unchanged.
    Otherwise every per-object mask of frame 0 (``OBJ<id>_<zeros>.png``) is
    binarised at 127, the results are OR-ed together and the union is written
    to ``combined_mask_0.png`` inside *masks_dir*.

    Returns:
        Path of the mask file to use for the first frame.

    Raises:
        FileNotFoundError: no per-object mask for frame 0 was found.
        RuntimeError: none of the candidate masks could be decoded, or the
            combined mask could not be written.
    """
    std = os.path.join(masks_dir, "mask_0.png")
    if os.path.exists(std):
        return std

    # The glob on its own also matches frame 10000, 20000, ... ("*0000"), so
    # keep only names whose frame index consists of zeros exclusively.
    first_masks = sorted(
        path
        for path in glob.glob(os.path.join(masks_dir, "OBJ*_*0000.png"))
        if re.search(r"_0+\.png$", os.path.basename(path))
    )
    if not first_masks:
        raise FileNotFoundError("No first-frame masks found for combination")

    combined = None
    for mp in first_masks:
        m = cv2.imread(mp, cv2.IMREAD_GRAYSCALE)
        if m is None:
            # Keep going with the remaining objects, but do not hide the gap.
            print(f"Warning: could not read first-frame mask {mp}, skipping")
            continue
        _, m_bin = cv2.threshold(m, 127, 255, cv2.THRESH_BINARY)
        combined = m_bin if combined is None else cv2.bitwise_or(combined, m_bin)

    if combined is None:
        raise RuntimeError("Failed to create combined mask")

    combined_path = os.path.join(masks_dir, "combined_mask_0.png")
    # cv2.imwrite reports failure through its return value.
    if not cv2.imwrite(combined_path, combined):
        raise RuntimeError(f"Failed to write combined mask to {combined_path}")
    return combined_path

def run_matanyone(input_path: str,
                  mask_path: str,
                  output_dir: str,
                  suffix: str,
                  matanyone_script: str) -> bool:
    """
    Invoke MatAnyone to generate per-frame RGBA human masks.

    The script is started with the interpreter that runs this process, so the
    child uses the same environment instead of whatever ``python`` happens to
    be first on ``PATH`` (which may be missing or a different installation).

    Args:
        input_path: Passed to the script as ``-i``.
        mask_path: Passed to the script as ``-m`` (first-frame mask).
        output_dir: Passed to the script as ``-o``.
        suffix: Passed to the script as ``--suffix``.
        matanyone_script: Path of the MatAnyone inference script.

    Returns:
        True when the script exits with status 0; False when it exits with a
        non-zero status or cannot be started at all.
    """
    cmd = [
        sys.executable or "python", matanyone_script,
        "-i", input_path,
        "-m", mask_path,
        "-o", output_dir,
        "--suffix", suffix,
        "-w", "0",
        "--alpha_only"
    ]
    print("Running MatAnyone segmentation command:")
    print(" ".join(cmd))
    try:
        subprocess.run(cmd, check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers an interpreter that cannot be launched.
        print(f"Failed to run MatAnyone: {e}")
        return False
    print(f"MatAnyone segmentation completed, RGBA results saved to {output_dir}")
    return True


def parse_args():
    """Parse the command line of the segmentation workflow.

    Three path options are mandatory (``--input_dir``, ``--output_dir``,
    ``--video_path``); the remaining string options carry defaults and
    ``--keep_temp`` is a boolean switch.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(
        description='Complete video processing workflow: detect keyframes, predict forward/backward separately, merge results and cleanup temporary files'
    )

    # (flag, help text) for options the caller must always supply.
    mandatory = (
        ('--input_dir', 'Directory containing original video frames and detection results'),
        ('--output_dir', 'Directory to save processing results'),
        ('--video_path', 'Original video file path for getting frame rate'),
    )
    for flag, help_text in mandatory:
        parser.add_argument(flag, type=str, required=True, help=help_text)

    # (flag, default, help text) for options with a fallback value.
    defaulted = (
        ('--sam_script', 'scripts/demo_mul.py', 'SAM segmentation script path'),
        ('--model_path', 'sam2/checkpoints/sam2.1_hiera_large.pt', 'SAM model path'),
        ('--device', 'cuda:0', 'Device to run on'),
        ('--matanyone_script', '../../MatAnyone/inference_matanyone.py',
         'Path to MatAnyone inference_matanyone.py'),
    )
    for flag, fallback, help_text in defaulted:
        parser.add_argument(flag, type=str, default=fallback, help=help_text)

    parser.add_argument('--keep_temp', action='store_true',
                        help='Keep temporary files (default: false)')
    return parser.parse_args()

def get_detection_frame_id(input_dir):
    """Locate the detection keyframe and its bounding-box file.

    Looks for ``frame_<id>_detected.jpg`` files in *input_dir*. When several
    exist, the lowest numeric id is used so the result does not depend on
    directory listing order.

    Returns:
        ``(frame_id, bbox_path)`` where ``bbox_path`` is
        ``<input_dir>/frame_<frame_id>_boxes.txt``.

    Raises:
        ValueError: no detected frame exists, none of the candidates carries
            a numeric id, or the matching boxes file is missing.
    """
    detected = glob.glob(os.path.join(input_dir, "frame_*_detected.jpg"))
    if not detected:
        raise ValueError(f"No detected frames found in {input_dir}")

    pattern = re.compile(r'frame_(\d+)_detected\.jpg')
    matches = (pattern.search(os.path.basename(f)) for f in detected)
    frame_ids = [int(m.group(1)) for m in matches if m]
    if not frame_ids:
        # The glob accepts any text between "frame_" and "_detected.jpg".
        raise ValueError(
            f"No detected frame with a numeric frame id found in {input_dir}")

    fid = min(frame_ids)
    bbox = os.path.join(input_dir, f"frame_{fid}_boxes.txt")

    if not os.path.exists(bbox):
        print(f"Corresponding bbox file not found: {bbox}")
        raise ValueError(f"Corresponding bbox file not found: {bbox}")
    return fid, bbox

def get_video_fps(video_path):
    """Return the frame rate OpenCV reports for *video_path*.

    The value of ``cv2.CAP_PROP_FPS`` is returned unvalidated; the capture is
    not checked for having been opened successfully.
    """
    capture = cv2.VideoCapture(video_path)
    frame_rate = capture.get(cv2.CAP_PROP_FPS)
    capture.release()
    return frame_rate

def get_original_frame_ids(input_dir):
    """Return the sorted, de-duplicated frame ids found in ``<input_dir>/images``.

    Files named ``frame_<digits>.jpg`` or ``frame_<digits>.png`` are
    recognised. An id present under both extensions is reported once, so
    callers iterating the result handle every frame a single time.

    Raises:
        ValueError: the directory holds no ``frame_*`` image, or none of the
            candidates has a purely numeric frame id.
    """
    images_dir = os.path.join(input_dir, "images")
    files = (glob.glob(os.path.join(images_dir, "frame_*.jpg"))
             + glob.glob(os.path.join(images_dir, "frame_*.png")))
    if not files:
        raise ValueError(f"No frame files found in directory {images_dir}")

    # Anchored match: names such as "frame_1.jpg.png" are not frame images.
    pattern = re.compile(r'frame_(\d+)\.(jpg|png)')
    matches = (pattern.fullmatch(os.path.basename(f)) for f in files)
    ids = {int(m.group(1)) for m in matches if m}
    if not ids:
        raise ValueError(
            f"No frame files with a numeric frame id found in directory {images_dir}")
    return sorted(ids)

def prepare_forward_frames(input_dir, output_dir, frame_id):
    """Stage the frames from the keyframe to the last frame for the forward pass.

    Creates ``forward_frames`` and ``forward_masks`` under *output_dir*, then
    copies every existing ``<input_dir>/images/frame_<i>.jpg`` with
    ``frame_id <= i <= max id`` to ``forward_frames/frame_<n>.jpg``, where
    ``n`` counts up from 0. Only ``.jpg`` sources are copied. The
    ``n -> i`` pairs are also written to ``forward_mapping.txt`` as
    ``n,i`` lines.

    Returns:
        ``(forward_frames_dir, forward_masks_dir, mapping)`` with ``mapping``
        a dict from new index to original frame id.
    """
    forward_dir = os.path.join(output_dir, "forward_frames")
    forward_masks_dir = os.path.join(output_dir, "forward_masks")
    for directory in (forward_dir, forward_masks_dir):
        os.makedirs(directory, exist_ok=True)

    last_id = max(get_original_frame_ids(input_dir))
    images_dir = os.path.join(input_dir, "images")

    mapping = {}
    for original_id in range(frame_id, last_id + 1):
        source = os.path.join(images_dir, f"frame_{original_id}.jpg")
        if not os.path.exists(source):
            continue
        # The next free sequential index equals the number of frames copied so far.
        sequence_idx = len(mapping)
        shutil.copy(source, os.path.join(forward_dir, f"frame_{sequence_idx}.jpg"))
        mapping[sequence_idx] = original_id

    with open(os.path.join(output_dir, "forward_mapping.txt"), 'w') as handle:
        handle.writelines(f"{new},{orig}\n" for new, orig in mapping.items())

    return forward_dir, forward_masks_dir, mapping

def prepare_backward_frames(input_dir, output_dir, frame_id):
    """Stage the frames from the keyframe back to the first frame for the backward pass.

    Creates ``backward_frames`` and ``backward_masks`` under *output_dir*,
    then walks the ids from *frame_id* down to the smallest id and copies
    every existing ``<input_dir>/images/frame_<i>.jpg`` to
    ``backward_frames/frame_<n>.jpg``, where ``n`` counts up from 0 (so the
    staged sequence runs backwards in time). Only ``.jpg`` sources are
    copied. The ``n -> i`` pairs are also written to
    ``backward_mapping.txt`` as ``n,i`` lines.

    Returns:
        ``(backward_frames_dir, backward_masks_dir, mapping)`` with
        ``mapping`` a dict from new index to original frame id.
    """
    backward_dir = os.path.join(output_dir, "backward_frames")
    backward_masks_dir = os.path.join(output_dir, "backward_masks")
    for directory in (backward_dir, backward_masks_dir):
        os.makedirs(directory, exist_ok=True)

    first_id = min(get_original_frame_ids(input_dir))
    images_dir = os.path.join(input_dir, "images")

    mapping = {}
    for original_id in range(frame_id, first_id - 1, -1):
        source = os.path.join(images_dir, f"frame_{original_id}.jpg")
        if not os.path.exists(source):
            continue
        # The next free sequential index equals the number of frames copied so far.
        sequence_idx = len(mapping)
        shutil.copy(source, os.path.join(backward_dir, f"frame_{sequence_idx}.jpg"))
        mapping[sequence_idx] = original_id

    with open(os.path.join(output_dir, "backward_mapping.txt"), 'w') as handle:
        handle.writelines(f"{new},{orig}\n" for new, orig in mapping.items())

    return backward_dir, backward_masks_dir, mapping

def create_bbox_file(bbox_file, output_dir, filename):
    """Rewrite corner-format boxes as origin-plus-size boxes.

    Each line of *bbox_file* holding exactly four comma-separated numbers
    ``x1,y1,x2,y2`` becomes ``x1,y1,w,h`` (``w = x2 - x1``, ``h = y2 - y1``)
    in ``<output_dir>/<filename>``. Lines with any other field count are
    dropped; a four-field line with a non-numeric value raises ``ValueError``.

    Returns:
        Path of the file that was written.
    """
    new_file = os.path.join(output_dir, filename)
    with open(bbox_file) as reader, open(new_file, 'w') as writer:
        for raw_line in reader:
            fields = raw_line.strip().split(',')
            if len(fields) != 4:
                continue
            left, top, right, bottom = (float(value) for value in fields)
            width = right - left
            height = bottom - top
            writer.write(f"{left},{top},{width},{height}\n")
    print(f"Created converted bbox file: {new_file}")
    return new_file

def run_sam_segmentation(sam_script, input_path, txt_path, mask_dir, output_path, model_path, device):
    """Run the SAM segmentation script as a child process.

    The script is started with the interpreter running this process, and the
    child's ``PYTHONPATH`` is prefixed with the directory two levels above the
    script plus its ``sam2`` sub-directory. Entries are joined with
    ``os.pathsep`` and an unset/empty inherited ``PYTHONPATH`` contributes no
    empty entry.

    Args:
        sam_script: Script path; a leading ``scripts/`` is dropped when the
            current working directory already ends in ``/scripts``.
        input_path: Passed as ``--input_path``.
        txt_path: Passed as ``--txt_path``.
        mask_dir: Passed as ``--mask_dir``.
        output_path: Passed as ``--video_output_path``.
        model_path: Passed as ``--model_path``.
        device: Passed as ``--device``.

    Returns:
        True when the script exits with status 0; False when it exits with a
        non-zero status or cannot be started.
    """
    if sam_script.startswith('scripts/') and os.getcwd().endswith('/scripts'):
        sam_script = sam_script.replace('scripts/', '', 1)
    cmd = [
        sys.executable or "python", sam_script,
        "--input_path", input_path,
        "--txt_path", txt_path,
        "--mask_dir", mask_dir,
        "--save_to_video",
        "--video_output_path", output_path,
        "--model_path", model_path,
        "--device", device
    ]
    root = Path(sam_script).resolve().parent.parent
    sam2_dir = root / "sam2"

    env = os.environ.copy()
    search_path = [str(root), str(sam2_dir)]
    inherited = env.get("PYTHONPATH")
    if inherited:
        search_path.append(inherited)
    env["PYTHONPATH"] = os.pathsep.join(search_path)

    print("Running SAM segmentation command:", " ".join(cmd))
    try:
        subprocess.run(cmd, check=True, env=env)
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers an interpreter that cannot be launched.
        print(f"Error running SAM segmentation script: {e}")
        return False
    print(f"SAM segmentation completed, results saved to {mask_dir}")
    return True


def merge_alpha_masks(forward_rgba_dir, backward_rgba_dir,
                      forward_mapping, backward_mapping,
                      output_dir, original_frame_ids):
    """
    Merge forward/backward MatAnyone RGBA masks into one-to-one
    with original frames, saving under ../alphamasks alongside masks/.

    For every id in *original_frame_ids* the file ``frame_<n>.png`` is copied
    from the forward directory when the id appears in *forward_mapping*,
    otherwise from the backward directory, to ``alphamasks/frame_<id>.png``.
    Ids missing from both mappings, and source files that do not exist, are
    reported and skipped.

    Args:
        forward_rgba_dir: Directory holding the forward-pass ``frame_<n>.png``.
        backward_rgba_dir: Directory holding the backward-pass ``frame_<n>.png``.
        forward_mapping: Dict from forward index ``n`` to original frame id.
        backward_mapping: Dict from backward index ``n`` to original frame id.
        output_dir: Directory whose parent receives the ``alphamasks`` folder.
        original_frame_ids: Original frame ids to produce masks for.

    Returns:
        Path of the ``alphamasks`` directory.
    """
    # Resolve the sibling through ".." rather than dirname(): dirname() of a
    # path with a trailing separator is the path itself, which would place
    # "alphamasks" inside output_dir instead of next to it.
    alpha_dir = os.path.normpath(os.path.join(output_dir, os.pardir, "alphamasks"))
    os.makedirs(alpha_dir, exist_ok=True)

    fwd_rev = {orig: new for new, orig in forward_mapping.items()}
    bwd_rev = {orig: new for new, orig in backward_mapping.items()}

    count = 0
    for orig in original_frame_ids:
        # The keyframe is part of both passes; the forward result wins.
        if orig in fwd_rev:
            new_idx = fwd_rev[orig]
            src_dir = forward_rgba_dir
        elif orig in bwd_rev:
            new_idx = bwd_rev[orig]
            src_dir = backward_rgba_dir
        else:
            print(f"Warning: original frame {orig} not found in any mapping")
            continue

        rgba_file = os.path.join(src_dir, f"frame_{new_idx}.png")
        if not os.path.exists(rgba_file):
            print(f"Warning: RGBA mask {rgba_file} not found")
            continue

        dst = os.path.join(alpha_dir, f"frame_{orig}.png")
        shutil.copy(rgba_file, dst)
        count += 1

    print(f"Merged {count} alpha masks to {alpha_dir}")
    return alpha_dir


def merge_masks(forward_masks_dir, backward_masks_dir, forward_mapping, backward_mapping, output_dir, original_frame_ids):
    """Merge forward and backward predicted masks into original frame sequence

    Two layouts are supported and must be the same in both source
    directories: per-object files ``OBJ<id>_<nnnn>.png`` ("SAM2" format) or a
    single ``mask_<n>.png`` per frame ("standard" format). Each source file is
    copied into *output_dir* under the original frame id, i.e. as
    ``OBJ<id:02>_<orig:04>.png`` or ``mask_<orig>.png``. A frame present in
    both mappings is taken from the forward pass.

    Args:
        forward_masks_dir: Directory with the forward-pass masks.
        backward_masks_dir: Directory with the backward-pass masks.
        forward_mapping: Dict from forward index to original frame id.
        backward_mapping: Dict from backward index to original frame id.
        output_dir: Destination directory (created when missing).
        original_frame_ids: Original frame ids to merge.

    Returns:
        The destination directory.

    Raises:
        ValueError: only one of the two source directories uses the SAM2
            layout.
    """
    merged_masks_dir = os.path.join(output_dir)
    os.makedirs(merged_masks_dir, exist_ok=True)

    forward_is_sam2 = bool(glob.glob(os.path.join(forward_masks_dir, "OBJ*.png")))
    backward_is_sam2 = bool(glob.glob(os.path.join(backward_masks_dir, "OBJ*.png")))
    if forward_is_sam2 != backward_is_sam2:
        raise ValueError("Forward and backward mask formats are inconsistent, cannot merge")

    is_sam2 = forward_is_sam2
    print(f"Detected {'SAM2' if is_sam2 else 'standard'} format mask files")

    forward_reverse_mapping = {old_idx: new_idx for new_idx, old_idx in forward_mapping.items()}
    backward_reverse_mapping = {old_idx: new_idx for new_idx, old_idx in backward_mapping.items()}

    obj_pattern = re.compile(r'OBJ(\d+)_')
    report_every = 50
    next_report = report_every
    processed_count = 0

    for orig_idx in original_frame_ids:
        if orig_idx in forward_reverse_mapping:
            source_dir = forward_masks_dir
            new_idx = forward_reverse_mapping[orig_idx]
            source_type = "forward"
        elif orig_idx in backward_reverse_mapping:
            source_dir = backward_masks_dir
            new_idx = backward_reverse_mapping[orig_idx]
            source_type = "backward"
        else:
            print(f"Warning: original frame {orig_idx} does not exist in mapping")
            continue

        # Collect the (source, destination) pairs for this frame first so the
        # copy and its error handling are written once for both layouts.
        copies = []
        if is_sam2:
            for obj_mask in sorted(glob.glob(os.path.join(source_dir, f"OBJ*_{new_idx:04}.png"))):
                match = obj_pattern.search(os.path.basename(obj_mask))
                if match is None:
                    # The glob accepts any text after "OBJ"; skip non-numeric ids.
                    print(f"Warning: cannot parse object id from {obj_mask}, skipping")
                    continue
                obj_id = int(match.group(1))
                copies.append((obj_mask,
                               os.path.join(merged_masks_dir, f"OBJ{obj_id:02}_{orig_idx:04}.png")))
        else:
            mask_path = os.path.join(source_dir, f"mask_{new_idx}.png")
            if os.path.exists(mask_path):
                copies.append((mask_path,
                               os.path.join(merged_masks_dir, f"mask_{orig_idx}.png")))
            else:
                print(f"Warning: {source_type} mask {new_idx} not found")

        for src, dest_path in copies:
            try:
                shutil.copy(src, dest_path)
            except OSError as e:
                print(f"Error processing mask: {e}, mask: {src}")
            else:
                processed_count += 1

        # Report only when another block of masks has actually been copied;
        # a plain modulo test repeats the same line while the count stalls.
        if processed_count >= next_report:
            print(f"Processed {processed_count} masks")
            next_report = (processed_count // report_every + 1) * report_every

    print(f"Merged {processed_count} masks to {merged_masks_dir}")
    return merged_masks_dir


def generate_final_video(input_dir, merged_masks_dir, output_dir, original_frame_ids, video_path):
    """
    Generate visualization video combining masks and color overlays.

    Every original frame is read from ``<input_dir>/images`` (``.jpg`` first,
    ``.png`` as fallback), overlaid with its merged masks and appended to
    ``<output_dir>/../final_samurai_result.mp4``. The video size is taken
    from the first frame; frames and masks of another size are resized to it.

    With per-object masks (``OBJ<id>_<nnnn>.png``) each object found for the
    frame is tinted with a colour chosen by its id and its largest contour is
    boxed. Otherwise ``mask_<id>.png`` is blended onto the frame.

    Frames or masks that cannot be read are reported and skipped.

    Args:
        input_dir: Directory containing the ``images`` folder.
        merged_masks_dir: Directory with the merged masks.
        output_dir: Directory whose parent receives the video.
        original_frame_ids: Sequence of frame ids, in output order.
        video_path: Source video, used only to obtain the frame rate.

    Returns:
        Path of the written video.

    Raises:
        ValueError: *original_frame_ids* is empty.
        FileNotFoundError: the first frame cannot be read.
        RuntimeError: the video writer could not be opened.
    """
    if not original_frame_ids:
        raise ValueError("No original frames available to build the final video")

    images_dir = os.path.join(input_dir, "images")
    out_vid = os.path.join(output_dir, "../final_samurai_result.mp4")

    def read_frame(frame_idx):
        """Load frame *frame_idx*, preferring .jpg over .png; None if unreadable."""
        for ext in ("jpg", "png"):
            candidate = os.path.join(images_dir, f"frame_{frame_idx}.{ext}")
            if os.path.exists(candidate):
                return cv2.imread(candidate)
        return None

    first = read_frame(original_frame_ids[0])
    if first is None:
        raise FileNotFoundError(
            f"Cannot read first frame {original_frame_ids[0]} from {images_dir}")
    h, w = first.shape[:2]

    fps = get_video_fps(video_path)
    if not (fps and fps > 0):
        # Covers 0, negative values and NaN from an unreadable source video.
        fallback_fps = 30.0
        print(f"Warning: invalid frame rate {fps} for {video_path}, using {fallback_fps}")
        fps = fallback_fps

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    vw = cv2.VideoWriter(out_vid, fourcc, fps, (w, h))
    if not vw.isOpened():
        raise RuntimeError(f"Could not open video writer for {out_vid}")

    is_sam2 = len(glob.glob(os.path.join(merged_masks_dir, "OBJ*.png"))) > 0
    colors = [(128,0,0),(128,128,0),(128,0,128),(0,128,128),
              (64,0,0),(64,128,0),(64,0,128),(0,64,128),
              (128,64,0),(128,0,64),(0,128,64)]
    obj_pattern = re.compile(r'OBJ(\d+)_')
    total = len(original_frame_ids)

    try:
        for i, idx in enumerate(original_frame_ids):
            frame = read_frame(idx)
            if frame is None:
                print(f"Warning: frame {idx} could not be read, skipping")
            else:
                if frame.shape[:2] != (h, w):
                    # The writer was opened for (w, h); keep every frame at that size.
                    frame = cv2.resize(frame, (w, h))

                if is_sam2:
                    # Discover the objects present for this frame instead of
                    # probing a fixed id range, so no object id is left out.
                    found = []
                    for mp in glob.glob(os.path.join(merged_masks_dir, f"OBJ*_{idx:04}.png")):
                        match = obj_pattern.search(os.path.basename(mp))
                        if match:
                            found.append((int(match.group(1)), mp))

                    for oid, mp in sorted(found):
                        mask = cv2.imread(mp, cv2.IMREAD_GRAYSCALE)
                        if mask is None:
                            print(f"Warning: mask {mp} could not be read, skipping")
                            continue
                        if mask.shape[:2] != (h, w):
                            # Nearest neighbour keeps the mask binary.
                            mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
                        _, bm = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
                        col = colors[oid % len(colors)]
                        cm = np.zeros((h, w, 3), dtype=np.uint8)
                        cm[bm > 0] = col
                        frame = cv2.addWeighted(frame, 1.0, cm, 0.5, 0)
                        # [-2] selects the contour list for both the 2-tuple
                        # and the 3-tuple return shapes of findContours.
                        cnts = cv2.findContours(bm, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
                        if len(cnts) > 0:
                            x, y, ww, hh = cv2.boundingRect(max(cnts, key=cv2.contourArea))
                            cv2.rectangle(frame, (x, y), (x + ww, y + hh), col, 2)
                else:
                    mp = os.path.join(merged_masks_dir, f"mask_{idx}.png")
                    if os.path.exists(mp):
                        mask = cv2.imread(mp)
                        if mask is None:
                            print(f"Warning: mask {mp} could not be read, skipping")
                        else:
                            if mask.shape[:2] != frame.shape[:2]:
                                mask = cv2.resize(mask, (w, h))
                            frame = cv2.addWeighted(frame, 1.0, mask, 0.7, 0)

                vw.write(frame)

            if i % 50 == 0 or i == total - 1:
                print(f"Frame processing progress: {i+1}/{total}")
    finally:
        # Finalise the container even when a frame fails part-way through.
        vw.release()

    print(f"Generated final video: {out_vid}")
    return out_vid

def cleanup_temp_files(output_dir):
    """Delete the intermediate directories and files of a run from *output_dir*.

    Removes the forward/backward frame, mask and RGBA directories as well as
    the per-direction bbox, mapping and result-video files. Entries that do
    not exist are ignored; anything else in *output_dir* is left untouched.
    """
    directory_names = (
        "forward_frames", "backward_frames",
        "forward_masks", "backward_masks",
        "forward_rgba", "backward_rgba",
    )
    file_names = (
        "forward_bboxes.txt", "backward_bboxes.txt",
        "forward_mapping.txt", "backward_mapping.txt",
        "forward_result.mp4", "backward_result.mp4",
    )

    for name in directory_names:
        target = os.path.join(output_dir, name)
        if not os.path.exists(target):
            continue
        shutil.rmtree(target)
        print(f"Deleted directory: {target}")

    for name in file_names:
        target = os.path.join(output_dir, name)
        if not os.path.exists(target):
            continue
        os.remove(target)
        print(f"Deleted file: {target}")

    print("Temporary file cleanup completed!")

def main():
    """Run the full pipeline and return the process exit status.

    Steps: locate the detection keyframe, then for the forward and the
    backward direction stage the frames, convert the boxes, run SAM and
    MatAnyone; finally merge alpha masks and SAM masks back onto the original
    frame ids, render the visualisation video and (unless ``--keep_temp``)
    remove the intermediate files.

    Returns:
        0 on success. 1 when a SAM run fails (the pipeline stops, since the
        following steps need its masks) or when a MatAnyone run fails (the
        remaining steps still run, but the alpha masks are incomplete).
    """
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    frame_id, bbox_file = get_detection_frame_id(args.input_dir)
    print(f"Detected keyframe ID: {frame_id}, bbox file: {bbox_file}")

    # Both directions run the same sequence of steps; only the staging
    # function and the name prefix differ.
    passes = (
        ("forward", prepare_forward_frames),
        ("backward", prepare_backward_frames),
    )
    staged = {}
    matting_ok = True
    for direction, prepare in passes:
        frames_dir, masks_dir, mapping = prepare(args.input_dir, args.output_dir, frame_id)
        bbox = create_bbox_file(bbox_file, args.output_dir, f"{direction}_bboxes.txt")
        sam_ok = run_sam_segmentation(
            args.sam_script, frames_dir, bbox, masks_dir,
            os.path.join(args.output_dir, f"{direction}_result.mp4"),
            args.model_path, args.device)
        if not sam_ok:
            print(f"SAM segmentation failed for the {direction} pass, aborting")
            return 1

        first_mask = build_combined_first_mask(masks_dir)
        rgba_dir = os.path.join(args.output_dir, f"{direction}_rgba")
        if not run_matanyone(frames_dir, first_mask, rgba_dir, direction,
                             args.matanyone_script):
            # Keep going so the SAM masks and the video are still produced.
            print(f"Warning: MatAnyone failed for the {direction} pass")
            matting_ok = False

        staged[direction] = (masks_dir, mapping, rgba_dir)

    fwd_masks, fwd_map, fwd_rgba = staged["forward"]
    bwd_masks, bwd_map, bwd_rgba = staged["backward"]

    original_ids = get_original_frame_ids(args.input_dir)
    print(f"Found {len(original_ids)} original frames")

    alpha_dir = merge_alpha_masks(fwd_rgba, bwd_rgba, fwd_map, bwd_map,
                                  args.output_dir, original_ids)
    print(f"Alpha masks saved to: {alpha_dir}")

    merged_masks = merge_masks(fwd_masks, bwd_masks, fwd_map, bwd_map,
                               args.output_dir, original_ids)

    final_video = generate_final_video(args.input_dir, merged_masks,
                                       args.output_dir, original_ids, args.video_path)

    if not args.keep_temp:
        cleanup_temp_files(args.output_dir)

    print("Processing completed!")
    print(f"Final SAM masks saved at: {merged_masks}")
    print(f"Final alpha masks saved at: {alpha_dir}")
    print(f"Final segmentation visualization video: {final_video}")
    if not matting_ok:
        print("MatAnyone reported a failure; alpha masks may be incomplete")
        return 1
    return 0

# Run the pipeline only when executed as a script; the value returned by
# main() becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())