import clip
import torch
import numpy as np
import os
import json
from PIL import Image
import cv2
from tqdm import tqdm

def load_mask_info(mask_dir):
    """Load and return the parsed contents of ``info.json`` in ``mask_dir``.

    Args:
        mask_dir: Directory that contains an ``info.json`` file.

    Returns:
        The decoded JSON document (whatever ``json.load`` produces for it).

    Raises:
        OSError: If ``info.json`` cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    info_path = os.path.join(mask_dir, "info.json")
    # JSON is UTF-8 by specification; do not depend on the platform's
    # locale encoding, which differs between systems.
    with open(info_path, "r", encoding="utf-8") as f:
        return json.load(f)

def extract_masked_region(image, mask):
    """Return ``image`` with every pixel outside ``mask`` set to zero.

    Any non-zero mask value counts as foreground, so masks stored as 0/1,
    0/255 or with intermediate values all keep the foreground pixels at
    their original intensity instead of scaling them by ``mask / 255``.

    Args:
        image: Array of shape ``(H, W)`` or ``(H, W, C)``.
        mask: Array of shape ``(H, W)`` or ``(H, W, 1)``, or a multi-channel
            array that OpenCV can convert from BGR to grayscale.

    Returns:
        A floating-point array with the same shape as ``image``: foreground
        pixels keep their values, background pixels are 0.
    """
    # Collapse a multi-channel mask to a single channel.
    if mask.ndim == 3:
        if mask.shape[2] == 1:
            mask = mask[:, :, 0]
        else:
            mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)

    # Binarize: 1.0 where the mask is set, 0.0 elsewhere. float64 keeps the
    # result dtype the same as the previous ``image * (mask / 255.0)`` form.
    keep = (mask > 0).astype(np.float64)

    # Add a trailing axis so the mask broadcasts over any channel count.
    if image.ndim == 3:
        keep = keep[:, :, np.newaxis]

    return image * keep

def get_all_valid_views(mask_dir, image_dir, category_id):
    """Collect the (image, mask) path pairs available for one category.

    A mask file belongs to the category when its name has the form
    ``"<category_id>_<image stem>.png"``. The matching image is looked up as
    ``"<image stem>.jpg"`` inside ``image_dir``; masks whose image does not
    exist are skipped.

    Args:
        mask_dir: Directory containing the mask PNG files.
        image_dir: Directory containing the source JPG images.
        category_id: Category identifier used as the mask filename prefix.

    Returns:
        A list of ``(image_path, mask_path)`` tuples, ordered by mask
        filename so repeated runs visit the views in the same order.
    """
    prefix = f"{category_id}_"
    mask_ext = ".png"
    valid_views = []

    # os.listdir returns entries in arbitrary order; sort for reproducibility.
    for mask_file in sorted(os.listdir(mask_dir)):
        if not (mask_file.startswith(prefix) and mask_file.endswith(mask_ext)):
            continue

        # Strip only the leading prefix and the trailing extension, so a
        # ".png" that appears elsewhere in the name is left untouched.
        image_stem = mask_file[len(prefix):-len(mask_ext)]
        image_path = os.path.join(image_dir, image_stem + ".jpg")

        if os.path.exists(image_path):
            valid_views.append((image_path, os.path.join(mask_dir, mask_file)))

    return valid_views

def extract_averaged_features(mask_dir, image_dir, model_name="ViT-B/32", device=None):
    """Compute one averaged CLIP image embedding per category.

    For each category id in ``range(info["total_categories"])`` (read from
    ``info.json`` in ``mask_dir``), every available (image, mask) view is
    masked, encoded with CLIP and L2-normalized; the per-view embeddings are
    then averaged. The average itself is not re-normalized.

    Args:
        mask_dir: Directory holding the mask PNGs and ``info.json``.
        image_dir: Directory holding the source images.
        model_name: CLIP model name passed to ``clip.load``.
        device: Torch device string; when ``None``, ``"cuda"`` is used if
            available, otherwise ``"cpu"``.

    Returns:
        A ``(labels, features)`` tuple. ``labels`` holds the ids of the
        categories that produced at least one embedding, and ``features``
        stacks their averaged embeddings row by row in the same order. Both
        are empty arrays when no category produced an embedding.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # The model is loaded before the metadata, matching the original ordering.
    model, preprocess = clip.load(model_name, device=device)
    num_categories = load_mask_info(mask_dir)["total_categories"]

    mean_rows = []
    kept_ids = []

    print(f"Processing {num_categories} categories...")

    for category_id in tqdm(range(num_categories), desc="Extracting averaged features"):
        views = get_all_valid_views(mask_dir, image_dir, category_id)
        if not views:
            print(f"Warning: No valid images found for category {category_id}")
            continue

        per_view = []
        for image_path, mask_path in views:
            # Best effort: a failing view is reported and skipped.
            try:
                bgr = cv2.imread(image_path)
                mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

                # Skip unreadable files and masks with no foreground pixel.
                if bgr is None or mask is None or not np.any(mask > 0):
                    continue

                cutout = extract_masked_region(bgr, mask).astype(np.uint8)
                # OpenCV loads BGR; PIL and CLIP preprocessing expect RGB.
                rgb = cv2.cvtColor(cutout, cv2.COLOR_BGR2RGB)
                batch = preprocess(Image.fromarray(rgb)).unsqueeze(0).to(device)

                with torch.no_grad():
                    embedding = model.encode_image(batch)
                    embedding = embedding / embedding.norm(dim=-1, keepdim=True)

                per_view.append(embedding.cpu().numpy())
            except Exception as e:
                print(f"Error processing {image_path}: {e}")

        if not per_view:
            print(f"Warning: No valid features extracted for category {category_id}")
            continue

        # One row per view; average over the rows, keeping a (1, D) shape.
        stacked = np.vstack(per_view)
        mean_rows.append(stacked.mean(axis=0, keepdims=True))
        kept_ids.append(category_id)
        print(f"Category {category_id}: {len(stacked)} valid views")

    if not mean_rows:
        return np.array([]), np.array([])

    return np.array(kept_ids), np.vstack(mean_rows)

def save_to_npy(labels, features, prefix="clip_output", output_dir="clip"):
    """Save labels and features as two ``.npy`` files.

    Writes ``<output_dir>/<prefix>_features.npy`` and
    ``<output_dir>/<prefix>_labels.npy``, creating ``output_dir`` if needed.

    Args:
        labels: Array-like of labels to store.
        features: Array-like of features to store.
        prefix: Filename prefix shared by both output files.
        output_dir: Destination directory. Defaults to ``"clip"``, the
            directory this function always used before (the original code
            notes it as matching ``extract_clip_features.py``).
    """
    os.makedirs(output_dir, exist_ok=True)

    features_path = os.path.join(output_dir, f"{prefix}_features.npy")
    labels_path = os.path.join(output_dir, f"{prefix}_labels.npy")
    np.save(features_path, features)
    np.save(labels_path, labels)

    # Message format kept as-is (forward slashes) so existing logs match.
    print(f"Saved to {output_dir}/{prefix}_features.npy and {output_dir}/{prefix}_labels.npy")

if __name__ == "__main__":
    # Input locations; these are placeholders to be replaced before running.
    mask_dir = "/path/to/masks"  # directory with mask PNGs and info.json
    image_dir = "/path/to/images"  # directory with the source images

    # Compute per-category averaged features, then write them to disk.
    labels, features = extract_averaged_features(mask_dir=mask_dir, image_dir=image_dir)
    save_to_npy(labels=labels, features=features)

    # Report the shapes of what was extracted.
    print(f"Extracted features shape: {features.shape}")
    print(f"Extracted labels shape: {labels.shape}")
