


import os
import os.path as osp
import numpy as np
import torch
import cv2
from tqdm import tqdm
import argparse
import multiprocessing
from concurrent.futures import ThreadPoolExecutor, as_completed
import time
import logging
from functools import partial
from insightface.app import FaceAnalysis


# Root-logger setup: every record is written both to a log file in the current
# working directory and to the console. The process id is part of the format
# so that lines coming from different worker processes can be told apart.
logging.basicConfig(
    format='%(asctime)s - %(process)d - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("face_embedding_precompute.log"),
        logging.StreamHandler(),
    ],
    level=logging.INFO,
)
# The root logger itself is used as the module-wide logger.
logger = logging.getLogger()

# One FaceAnalysis instance per GPU id, per process. process_video() asks for
# an app once per video, and rebuilding the models every time is wasted work.
_FACE_APP_CACHE = {}


def get_face_app(gpu_id=0):
    """Return a prepared FaceAnalysis app bound to the given GPU.

    The first call for a given ``gpu_id`` builds and prepares the app; later
    calls in the same process return the cached instance instead of loading
    the models again.

    Args:
        gpu_id: Physical GPU index. It is exported as ``CUDA_VISIBLE_DEVICES``
            so that this process only sees that single device.

    Returns:
        A ``FaceAnalysis`` instance on which ``prepare`` has been called with
        a 640x640 detection size.
    """
    # Restrict this process to one GPU. This presumably only takes effect if
    # it runs before CUDA is first initialised in the process.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    cached_app = _FACE_APP_CACHE.get(gpu_id)
    if cached_app is not None:
        return cached_app

    # NOTE(review): the model name below looks like an unfilled placeholder
    # ("path to ..."); confirm it points at the real buffalo_l model pack.
    face_app = FaceAnalysis(
        name="path to /DanceTogether/models/buffalo_l",
        providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
    )

    # ctx_id is 0 because, with the mask above, the selected GPU is the only
    # (hence first) device visible to this process.
    face_app.prepare(ctx_id=0, det_size=(640, 640))

    _FACE_APP_CACHE[gpu_id] = face_app
    return face_app

def crop_face(ref_tensor, mask_path, width, height):
    """Cut the masked face region out of a frame and resize it to 112x112.

    Args:
        ref_tensor: Channel-first image tensor, indexed as
            ``[channel, row, column]``; presumably already sized
            ``height`` x ``width`` by the caller.
        mask_path: Path of a grayscale mask image whose non-zero pixels mark
            the face.
        width: Width the mask is resized to before locating the face.
        height: Height the mask is resized to before locating the face.

    Returns:
        A contiguous float tensor with spatial size 112x112, or ``None`` when
        the mask file is missing, unreadable, or contains no non-zero pixel.
    """
    if not os.path.exists(mask_path):
        return None

    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None or mask.sum() == 0:
        return None

    # Bring the mask to the requested frame resolution so its coordinates can
    # be used to index the frame tensor.
    mask = cv2.resize(mask, (width, height))
    coords = torch.nonzero(torch.from_numpy(mask))
    if coords.numel() == 0:
        return None

    # Each row of coords is (row, column); the per-column extrema give the
    # tight bounding box of the mask.
    upper_left = coords.min(0).values
    lower_right = coords.max(0).values

    # Grow the box by a fixed margin, clamped to the frame borders.
    margin = 20
    top = max(0, int(upper_left[0]) - margin)
    bottom = min(height, int(lower_right[0]) + margin)
    left = max(0, int(upper_left[1]) - margin)
    right = min(width, int(lower_right[1]) + margin)

    region = ref_tensor[:, top:bottom, left:right].float()

    # interpolate() needs a batch dimension; add it, resize, then drop it.
    resized = torch.nn.functional.interpolate(
        region.unsqueeze(0),
        size=[112, 112],
        mode='bilinear',
        align_corners=False,
    )
    return resized.squeeze(0).contiguous()

def _crop_to_uint8_image(crop):
    """Convert a channel-first crop holding 0-255 values to an HWC uint8 array."""
    image = crop.permute(1, 2, 0).cpu().numpy()
    if image.dtype == np.uint8:
        return image
    # The frame is read as 8-bit pixels and only cast to float, so the values
    # are already on the 0-255 scale. Multiplying by 255 before the uint8 cast
    # would wrap around and destroy the image.
    return np.clip(np.rint(image), 0, 255).astype(np.uint8)


def _face_box_area(face):
    """Return the area of a detected face's (x1, y1, x2, y2) bounding box."""
    x1, y1, x2, y2 = (float(v) for v in face.bbox[:4])
    return max(0.0, x2 - x1) * max(0.0, y2 - y1)


def _embed_crop(crop, face_app):
    """Return the embedding of the largest face in a crop, or zeros if none."""
    if crop is None:
        return np.zeros(512, np.float32)
    # NOTE(review): the crop is RGB here, while insightface examples feed
    # cv2.imread (BGR) images to FaceAnalysis.get. Left unchanged so the
    # embeddings stay consistent with whatever consumes them; confirm.
    faces = face_app.get(_crop_to_uint8_image(crop))
    if not faces:
        return np.zeros(512, np.float32)
    return max(faces, key=_face_box_area).embedding


def process_frame(frame_name, frames_path, face_masks_path0, face_masks_path1, embeddings_dir, width, height, face_app):
    """Compute and save the face embeddings of up to two people for one frame.

    Two files, ``<frame>_embed0.npy`` and ``<frame>_embed1.npy``, are written
    to ``embeddings_dir``. A person whose mask is missing or empty, or in whose
    crop no face is detected, gets an all-zero 512-dimensional embedding.
    Frames whose two output files already exist and are non-empty are skipped.

    Args:
        frame_name: File name of the frame inside ``frames_path``.
        frames_path: Directory holding the frame images.
        face_masks_path0: Directory holding person 0 face masks.
        face_masks_path1: Directory holding person 1 face masks (may not exist).
        embeddings_dir: Output directory for the ``.npy`` files.
        width: Width the frame is resized to before cropping.
        height: Height the frame is resized to before cropping.
        face_app: Prepared face analysis app providing ``get(image)``.

    Returns:
        ``(frame_name, success)``. ``success`` is ``False`` when the frame
        cannot be read or any exception occurs; exceptions are logged, not
        raised.
    """
    try:
        frame_path = osp.join(frames_path, frame_name)
        # Masks are looked up under the frame's stem with a ".jpg" extension,
        # whatever the frame's own extension is.
        mask_name = osp.splitext(osp.basename(frame_path))[0] + ".jpg"

        output_stem = osp.splitext(frame_name)[0]
        embed0_path = osp.join(embeddings_dir, f"{output_stem}_embed0.npy")
        embed1_path = osp.join(embeddings_dir, f"{output_stem}_embed1.npy")

        # Resume support: skip frames whose outputs were already written.
        already_done = all(
            os.path.exists(path) and os.path.getsize(path) > 0
            for path in (embed0_path, embed1_path)
        )
        if already_done:
            logger.debug(f"Skipping {frame_name}: embeddings already exist")
            return frame_name, True

        face_mask_path0 = osp.join(face_masks_path0, mask_name)
        face_mask_path1 = osp.join(face_masks_path1, mask_name)
        person1_exists = os.path.exists(face_masks_path1) and os.path.exists(face_mask_path1)

        ref_bgr = cv2.imread(frame_path)
        if ref_bgr is None:
            logger.warning(f"Could not read {frame_path}. Skipping.")
            return frame_name, False

        ref_bgr = cv2.resize(ref_bgr, (width, height))
        ref_rgb = cv2.cvtColor(ref_bgr, cv2.COLOR_BGR2RGB)

        # Fall back to the CPU when CUDA is unavailable instead of failing
        # every frame; the ONNX providers already allow a CPU fallback.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        ref_tensor = torch.from_numpy(ref_rgb).to(dtype=torch.float32, device=device, non_blocking=True)
        ref_tensor = ref_tensor.permute(2, 0, 1)  # HWC -> CHW, values stay 0-255

        crop0 = crop_face(ref_tensor, face_mask_path0, width, height)
        crop1 = crop_face(ref_tensor, face_mask_path1, width, height) if person1_exists else None

        # Each person is handled independently, so embedding N always belongs
        # to person N even when the other person's crop is missing.
        embedding0 = _embed_crop(crop0, face_app)
        embedding1 = _embed_crop(crop1, face_app)

        np.save(embed0_path, embedding0)
        np.save(embed1_path, embedding1)

        return frame_name, True
    except Exception as e:
        # Best effort per frame: log with traceback and report the failure so
        # one bad frame does not abort the rest of the video.
        logger.exception(f"Error processing frame {frame_name}: {e}")
        return frame_name, False

def _frame_sort_key(name):
    """Sort key ordering frames by their numeric index, odd names last.

    Handles both ``frame_<n>.<ext>`` and ``<n>.<ext>`` names. A name without a
    parsable number sorts after all numbered frames instead of raising.
    """
    token = name.split('_')[1] if name.startswith('frame_') else name
    token = token.split('.')[0]
    try:
        return (0, int(token), name)
    except ValueError:
        return (1, 0, name)


def process_video(video_path, width, height, gpu_id):
    """Extract face embeddings for every frame of one video directory.

    The directory is expected to contain ``images/`` (frames) and
    ``faces/person_0`` / ``faces/person_1`` (masks). Results are written to
    ``embeddings/`` inside the same directory. Frames are processed
    concurrently by up to 12 threads sharing one face analysis app.

    Args:
        video_path: Root directory of the video.
        width: Width frames are resized to.
        height: Height frames are resized to.
        gpu_id: GPU index passed to ``get_face_app`` and used in log messages.

    Returns:
        ``(video_path, success_count, total_frames)``. ``(video_path, 0, 0)``
        is returned when the frames directory is missing, holds no ``.png`` /
        ``.jpg`` files, or any unexpected error occurs (errors are logged).
    """
    try:
        frames_path = osp.join(video_path, "images")
        face_masks_path0 = osp.join(video_path, "faces/person_0")
        face_masks_path1 = osp.join(video_path, "faces/person_1")

        # Validate the inputs before paying for model initialisation.
        if not os.path.exists(frames_path):
            logger.warning(f"Path {frames_path} does not exist. Skipping.")
            return video_path, 0, 0

        image_files = sorted(
            (name for name in os.listdir(frames_path) if name.endswith(('.png', '.jpg'))),
            key=_frame_sort_key,
        )
        if not image_files:
            logger.warning(f"No image files found in {frames_path}. Skipping.")
            return video_path, 0, 0

        face_app = get_face_app(gpu_id)
        logger.info(f"GPU {gpu_id} initialized for processing {video_path}")

        embeddings_dir = osp.join(video_path, "embeddings")
        os.makedirs(embeddings_dir, exist_ok=True)

        success_count = 0
        total_frames = len(image_files)
        # image_files is non-empty here, so at least one worker is requested.
        num_threads = min(12, total_frames)

        logger.info(f"GPU {gpu_id} processing {video_path} with {num_threads} threads for {total_frames} frames")

        process_frame_with_params = partial(
            process_frame,
            frames_path=frames_path,
            face_masks_path0=face_masks_path0,
            face_masks_path1=face_masks_path1,
            embeddings_dir=embeddings_dir,
            width=width,
            height=height,
            face_app=face_app
        )

        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            futures = {executor.submit(process_frame_with_params, frame): frame for frame in image_files}

            with tqdm(total=total_frames, desc=f"GPU {gpu_id} - Processing frames") as pbar:
                for future in as_completed(futures):
                    frame = futures[future]
                    try:
                        _, success = future.result()
                        if success:
                            success_count += 1
                    except Exception as e:
                        logger.error(f"Error in thread processing frame {frame}: {e}")
                    finally:
                        # Advance the bar whether the frame succeeded or not.
                        pbar.update(1)

        logger.info(f"GPU {gpu_id} completed processing {video_path}: {success_count}/{total_frames} frames")
        return video_path, success_count, total_frames

    except Exception as e:
        # Keep the worker alive for the remaining videos; the traceback goes
        # to the log for later diagnosis.
        logger.exception(f"Error processing video {video_path}: {e}")
        return video_path, 0, 0

def worker_process(gpu_id, video_files, width, height):
    """Run in a child process: handle every video assigned to one GPU.

    Args:
        gpu_id: GPU index; exported as ``CUDA_VISIBLE_DEVICES`` for this process.
        video_files: Video directories assigned to this worker.
        width: Frame width forwarded to ``process_video``.
        height: Frame height forwarded to ``process_video``.

    Returns:
        List of the ``(video_path, success_count, total_frames)`` tuples
        returned by ``process_video``; a video whose processing raised is
        logged and left out of the list.
    """
    # Pin this process to its GPU before any video is touched.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    logger.info(f"GPU {gpu_id} worker started with {len(video_files)} videos assigned")
    logger.info(f"GPU {gpu_id} CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}")

    collected = []
    for position, video_path in enumerate(video_files, start=1):
        try:
            outcome = process_video(video_path, width, height, gpu_id)
        except Exception as e:
            logger.error(f"GPU {gpu_id} - Error processing video {video_path}: {e}")
            continue

        try:
            collected.append(outcome)
            done_path, success_count, total_frames = outcome
            vid_name = osp.basename(done_path)
            logger.info(f"GPU {gpu_id} - Video {position}/{len(video_files)} - {vid_name}: {success_count}/{total_frames} frames processed")
        except Exception as e:
            logger.error(f"GPU {gpu_id} - Error processing video {video_path}: {e}")

    logger.info(f"GPU {gpu_id} worker completed all {len(video_files)} videos")
    return collected

def _split_evenly(items, num_parts):
    """Split items into num_parts contiguous chunks whose sizes differ by at most one."""
    base, extra = divmod(len(items), num_parts)
    parts = []
    start = 0
    for index in range(num_parts):
        end = start + base + (1 if index < extra else 0)
        parts.append(items[start:end])
        start = end
    return parts


def distribute_workload(txt_path, width=512, height=512, num_gpus=8):
    """Distribute video processing across one worker process per GPU.

    Reads one video directory per line from ``txt_path`` (blank lines are
    ignored), splits the list into ``num_gpus`` near-equal contiguous batches
    and starts one ``worker_process`` per non-empty batch. Blocks until every
    worker has exited.

    Args:
        txt_path: Text file listing the video directories, one per line.
        width: Frame width forwarded to the workers.
        height: Frame height forwarded to the workers.
        num_gpus: Number of GPUs (and therefore worker processes) to use.

    Raises:
        ValueError: If ``num_gpus`` is smaller than 1.
    """
    if num_gpus < 1:
        raise ValueError(f"num_gpus must be at least 1, got {num_gpus}")

    with open(txt_path, 'r') as file:
        # Drop blank lines so they do not turn into empty video paths.
        video_files = [line.strip() for line in file if line.strip()]

    logger.info(f"Found {len(video_files)} videos to process")

    # Spread the remainder over the first GPUs rather than piling it all on
    # the last one. With fewer videos than GPUs, each video gets its own GPU.
    video_batches = _split_evenly(video_files, num_gpus)

    logger.info(f"Distributed videos across {num_gpus} GPUs")
    for i, batch in enumerate(video_batches):
        logger.info(f"GPU {i}: {len(batch)} videos")

    processes = []
    for gpu_id, batch in enumerate(video_batches):
        if not batch:
            continue

        p = multiprocessing.Process(
            target=worker_process,
            args=(gpu_id, batch, width, height)
        )
        processes.append((gpu_id, p))
        p.start()

        # Stagger worker start-up slightly.
        time.sleep(0.5)

    for gpu_id, p in processes:
        p.join()
        # A non-zero exit code means the worker died abnormally; surface it
        # instead of reporting an unqualified success.
        if p.exitcode != 0:
            logger.error(f"GPU {gpu_id} worker exited with code {p.exitcode}")

    logger.info("All processes completed")

if __name__ == "__main__":
    # Force the 'spawn' start method so each worker starts in a fresh
    # interpreter (presumably needed for per-process CUDA setup; confirm).
    multiprocessing.set_start_method('spawn', force=True)

    # Command-line interface.
    cli = argparse.ArgumentParser(description="Pre-compute face embeddings for video dataset using multiple GPUs")
    cli.add_argument("--txt_path", type=str, required=True, help="Path to the dataset list text file")
    for flag, default_value, help_text in (
        ("--width", 512, "Frame width"),
        ("--height", 512, "Frame height"),
        ("--num_gpus", 8, "Number of GPUs to use"),
    ):
        cli.add_argument(flag, type=int, default=default_value, help=help_text)
    options = cli.parse_args()

    # Time the whole run and report it in seconds and hours.
    began_at = time.time()
    distribute_workload(options.txt_path, options.width, options.height, options.num_gpus)
    elapsed_time = time.time() - began_at

    logger.info(f"Total execution time: {elapsed_time:.2f} seconds ({elapsed_time/3600:.2f} hours)")


