import clip
import torch
import numpy as np
import os
import json
from PIL import Image
import cv2
from tqdm import tqdm

def load_mask_info(mask_dir):
    """Load and return the parsed ``info.json`` stored in *mask_dir*.

    Args:
        mask_dir: Directory that contains an ``info.json`` file.

    Returns:
        The decoded JSON object, exactly as produced by ``json.load``.

    Raises:
        FileNotFoundError: If ``info.json`` does not exist in *mask_dir*.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    info_path = os.path.join(mask_dir, "info.json")
    # JSON is UTF-8 by spec; don't rely on the platform's locale encoding
    # (e.g. cp1252 on Windows), which breaks on non-ASCII content.
    with open(info_path, "r", encoding="utf-8") as f:
        return json.load(f)

def extract_masked_region(image, mask):
    """Zero out every pixel of *image* that lies outside *mask*.

    Args:
        image: Image array, either ``(H, W, C)`` (e.g. BGR from ``cv2.imread``)
            or single-channel ``(H, W)``.
        mask: Mask array of shape ``(H, W)``, ``(H, W, 1)`` or a colour mask
            ``(H, W, 3/4)`` (converted to grayscale first). Any non-zero pixel
            counts as inside the mask, matching the ``mask > 0`` area test used
            when choosing the best view.

    Returns:
        Float array with the same shape as *image*: original pixel values
        inside the mask, ``0.0`` outside.

    Raises:
        ValueError: If the mask's spatial size differs from the image's.
    """
    if mask.ndim == 3:
        if mask.shape[2] == 1:
            mask = mask[:, :, 0]
        else:
            mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)

    if mask.shape[:2] != image.shape[:2]:
        raise ValueError(
            f"Mask shape {mask.shape[:2]} does not match image shape {image.shape[:2]}"
        )

    # Binarize. Scaling by ``mask / 255.0`` only works for 0/255 masks; a 0/1
    # mask (or any value below 255) would silently dim the kept pixels.
    keep = (mask > 0).astype(np.float64)

    if image.ndim == 3:
        # Broadcast the single-channel mask across all colour channels.
        keep = keep[:, :, np.newaxis]

    return image * keep

def get_largest_mask_view(mask_dir, image_dir, category_id, image_exts=(".jpg",)):
    """Return the (image, mask) path pair whose mask covers the most pixels.

    Mask files are expected to be named ``"{category_id}_{image_stem}.png"``
    in *mask_dir*. The matching image is looked up in *image_dir* as
    ``"{image_stem}{ext}"`` for each extension in *image_exts*; the first
    existing file wins.

    Args:
        mask_dir: Directory holding the per-category mask PNGs.
        image_dir: Directory holding the source images.
        category_id: Category whose masks are considered.
        image_exts: Candidate image extensions, tried in order. The default
            ``(".jpg",)`` reproduces the original behaviour.

    Returns:
        ``(best_image_path, best_mask_path)``, or ``(None, None)`` if no mask
        with a non-zero area has an existing image and is readable.
    """
    prefix = f"{category_id}_"
    suffix = ".png"
    max_area = 0
    best_image_path = None
    best_mask_path = None

    # Sort so ties (equal areas) resolve deterministically instead of depending
    # on the arbitrary order os.listdir returns.
    for mask_file in sorted(os.listdir(mask_dir)):
        if not (mask_file.endswith(suffix) and mask_file.startswith(prefix)):
            continue

        # Strip the category prefix and only the trailing ".png"; str.replace
        # would also rewrite any ".png" occurring elsewhere in the name.
        image_stem = mask_file[len(prefix):-len(suffix)]
        image_path = None
        for ext in image_exts:
            candidate = os.path.join(image_dir, image_stem + ext)
            if os.path.exists(candidate):
                image_path = candidate
                break
        if image_path is None:
            continue

        mask_path = os.path.join(mask_dir, mask_file)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:  # cv2.imread returns None for unreadable files
            continue

        mask_area = np.count_nonzero(mask)
        # Strict ">" keeps the first (in sorted order) of equally sized masks.
        if mask_area > max_area:
            max_area = mask_area
            best_image_path = image_path
            best_mask_path = mask_path

    return best_image_path, best_mask_path

def extract_single_view_features(mask_dir, image_dir, model_name="ViT-B/32", device=None):
    """Single-view matching: encode each category's largest-mask view with CLIP.

    For every category id in ``range(info["total_categories"])`` (read from
    ``info.json`` in *mask_dir*), pick the view whose mask has the largest
    area, zero out everything outside the mask, and encode it with CLIP.

    Args:
        mask_dir: Directory holding ``info.json`` and the per-category masks.
        image_dir: Directory holding the source images.
        model_name: Model name passed to ``clip.load``.
        device: Torch device string; defaults to ``"cuda"`` when available,
            otherwise ``"cpu"``.

    Returns:
        ``(labels, features)``. ``labels`` holds the category ids that were
        successfully encoded. ``features`` stacks their L2-normalised image
        embeddings as float32, one row per label. Categories that cannot be
        processed are skipped with a warning, so ``len(labels)`` may be
        smaller than ``total_categories``. If nothing was encoded, both arrays
        are empty.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model, preprocess = clip.load(model_name, device=device)

    info = load_mask_info(mask_dir)
    total_categories = info["total_categories"]

    all_features = []
    all_labels = []

    print(f"Processing {total_categories} categories...")

    for category_id in tqdm(range(total_categories), desc="Extracting single-view features"):
        best_image_path, best_mask_path = get_largest_mask_view(mask_dir, image_dir, category_id)

        if best_image_path is None or best_mask_path is None:
            # tqdm.write prints above the progress bar instead of corrupting it.
            tqdm.write(f"Warning: No valid images found for category {category_id}")
            continue

        try:
            features = _encode_masked_view(
                model, preprocess, device, best_image_path, best_mask_path
            )
        except Exception as e:  # best-effort: one bad view must not abort the run
            tqdm.write(f"Error processing category {category_id}: {e}")
            continue

        if features is None:
            # Previously skipped silently; surface it so missing data is visible.
            tqdm.write(f"Warning: Could not read image or mask for category {category_id}")
            continue

        all_features.append(features)
        all_labels.append(category_id)

    if not all_features:
        return np.array([], dtype=np.int64), np.array([], dtype=np.float32)

    return np.array(all_labels), np.vstack(all_features)


def _encode_masked_view(model, preprocess, device, image_path, mask_path):
    """Return the normalised float32 CLIP embedding of one masked view, or None if a file is unreadable."""
    image = cv2.imread(image_path)
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if image is None or mask is None:
        return None

    masked_image = extract_masked_region(image, mask)

    # OpenCV loads BGR; the PIL image handed to CLIP's preprocessing is RGB.
    masked_image_pil = Image.fromarray(
        cv2.cvtColor(masked_image.astype(np.uint8), cv2.COLOR_BGR2RGB)
    )
    image_tensor = preprocess(masked_image_pil).unsqueeze(0).to(device)

    with torch.no_grad():
        features = model.encode_image(image_tensor)
        # clip.load only converts the model to fp32 on CPU, so on CUDA the
        # embeddings may be fp16. Cast first so output dtype is device-independent.
        features = features.float()
        features = features / features.norm(dim=-1, keepdim=True)

    return features.cpu().numpy()

def save_to_npy(labels, features, prefix="clip_output", output_dir="clip"):
    """Save *features* and *labels* as ``.npy`` files under *output_dir*.

    Writes ``{output_dir}/{prefix}_features.npy`` and
    ``{output_dir}/{prefix}_labels.npy``, creating *output_dir* if needed.
    The default ``"clip"`` directory (relative to the current working
    directory) keeps the layout consistent with extract_clip_features.py.

    Args:
        labels: Array of labels to save.
        features: Array of features to save.
        prefix: Filename prefix for both output files.
        output_dir: Destination directory.
    """
    os.makedirs(output_dir, exist_ok=True)
    features_path = os.path.join(output_dir, f"{prefix}_features.npy")
    labels_path = os.path.join(output_dir, f"{prefix}_labels.npy")
    np.save(features_path, features)
    np.save(labels_path, labels)
    print(f"Saved to {features_path} and {labels_path}")

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Extract single-view CLIP features (largest-mask view per category)."
    )
    # Defaults are the original placeholders; pass real paths on the command line
    # instead of editing the script.
    parser.add_argument("--mask-dir", default="/path/to/masks",
                        help="Mask directory path (must contain info.json)")
    parser.add_argument("--image-dir", default="/path/to/images",
                        help="Image directory path")
    parser.add_argument("--model", default="ViT-B/32",
                        help="CLIP model name passed to clip.load")
    parser.add_argument("--prefix", default="clip_output",
                        help="Prefix for the saved .npy files")
    args = parser.parse_args()

    # Extract features
    labels, features = extract_single_view_features(
        args.mask_dir, args.image_dir, model_name=args.model
    )

    # Save results
    save_to_npy(labels, features, prefix=args.prefix)

    print(f"Extracted features shape: {features.shape}")
    print(f"Extracted labels shape: {labels.shape}")
