import os
import random
import h5py
import numpy as np
import torch
from pathlib import Path
from tqdm import tqdm
from PIL import Image
import logging
import torch.nn.functional as F
import gc

# Configure logger
logger = logging.getLogger(__name__)

def _select_scene_dirs(places_dir, scene_names):
    """Resolve requested scene names to sub-directories of ``places_dir``.

    Args:
        places_dir: ``Path`` of the dataset root; each sub-directory is one scene.
        scene_names: Iterable of scene names, or None to take every sub-directory.

    Returns:
        list: ``(scene_name, scene_dir)`` pairs in processing order. Empty when
        nothing usable was found (the reason has already been logged).
    """
    all_scene_dirs = {d.name: d for d in places_dir.iterdir() if d.is_dir()}
    if not all_scene_dirs:
        logger.error(f"No scene directories found in {places_dir}")
        return []

    if scene_names is None:
        logger.info("No scene names specified, will process all available scenes")
        # iterdir() order is filesystem-dependent; sort so the scene indices
        # (and the order in which random samples are drawn) are stable.
        return sorted(all_scene_dirs.items())

    selected = []
    # dict.fromkeys drops repeated names while keeping first-seen order; a
    # duplicate would otherwise fail when its HDF5 group is created twice.
    for scene_name in dict.fromkeys(scene_names):
        if scene_name in all_scene_dirs:
            selected.append((scene_name, all_scene_dirs[scene_name]))
        else:
            logger.warning(f"Scene '{scene_name}' does not exist in the dataset, will be skipped")

    if not selected:
        logger.error("No valid scenes found, processing terminated")
    return selected


def _embed_images(model, preprocess, image_paths, batch_size, device, desc):
    """Encode images in batches and return L2-normalized embeddings.

    Images that cannot be opened or preprocessed are logged and skipped, so
    the returned paths list may be shorter than ``image_paths``.

    Returns:
        tuple: ``(embeddings, paths)`` where ``embeddings`` is a numpy array
        with one row per entry of ``paths``, or None when no image could be
        processed.
    """
    embedding_chunks = []
    embedded_paths = []

    for start in tqdm(range(0, len(image_paths), batch_size), desc=desc):
        batch_tensors = []
        batch_paths = []
        for img_path in image_paths[start:start + batch_size]:
            try:
                # Context manager guarantees the underlying file is closed
                # even if decoding or preprocessing fails.
                with Image.open(img_path) as img:
                    img_tensor = preprocess(img.convert('RGB')).unsqueeze(0)
                batch_tensors.append(img_tensor)
                batch_paths.append(str(img_path))
            except Exception as e:
                # Deliberately best-effort: one bad image must not abort the scene.
                logger.warning(f"Failed to process image {img_path}: {e}")

        if not batch_tensors:
            continue

        batch = torch.cat(batch_tensors)

        with torch.no_grad():
            # NOTE(review): `normalize=False` presumes an open_clip-style
            # encode_image signature; normalization is done explicitly below.
            batch_embeddings = model.encode_image(batch.to(device), normalize=False)
            batch_embeddings = F.normalize(batch_embeddings, dim=-1).cpu().numpy()

        embedding_chunks.append(batch_embeddings)
        embedded_paths.extend(batch_paths)

        # Drop references to the batch before asking the allocator to release memory.
        del batch
        del batch_tensors
        torch.cuda.empty_cache()

    if not embedding_chunks:
        return None, embedded_paths
    return np.concatenate(embedding_chunks, axis=0), embedded_paths


def extract_places_embeddings(
    model, 
    preprocess, 
    places_dir, 
    output_file, 
    scene_names,
    samples_per_scene, 
    batch_size=32, 
    device="cuda:0"
):
    """
    Extract CLIP embeddings for specified scenes from the Places dataset and save all scenes in one file.

    The output HDF5 file contains:
        - ``scene_names``: variable-length string dataset, one entry per selected scene
        - ``scenes/<scene_name>``: one group per scene with attribute ``index`` and,
          when at least one image was embedded, datasets ``embeddings``
          (L2-normalized along the last axis) and ``image_paths``
        - file attributes ``total_scenes`` and ``processed_scenes``

    Images are sampled with the global ``random`` module state; candidate files
    are sorted first, so seeding ``random`` makes the selection reproducible.
    
    Args:
        model: CLIP model exposing ``encode_image``
        preprocess: CLIP preprocessing function applied to each PIL image
        places_dir: Path to the Places dataset directory (one sub-directory per scene)
        output_file: Path to output h5 file (parent directories are created)
        scene_names: List of scene names to process, if None, all scenes will be processed.
            Duplicate names are processed once.
        samples_per_scene: Number of images to process per scene
        batch_size: Batch size for processing
        device: Computation device
        
    Returns:
        bool: Whether embeddings were successfully extracted and saved for at
        least one scene. False is also returned (with an error logged) when
        ``places_dir`` is missing or contains no usable scenes.
    """
    
    # Ensure output directory exists
    output_file = Path(output_file)
    output_file.parent.mkdir(parents=True, exist_ok=True)
    
    # Fail in the documented way (log + False) instead of letting iterdir() raise
    places_dir = Path(places_dir)
    if not places_dir.is_dir():
        logger.error(f"Places directory {places_dir} does not exist or is not a directory")
        return False
    
    selected = _select_scene_dirs(places_dir, scene_names)
    if not selected:
        return False
    
    num_scenes = len(selected)
    selected_scene_names = [name for name, _ in selected]
    logger.info(f"Will process {num_scenes} scenes: {', '.join(selected_scene_names[:5])}{'...' if num_scenes > 5 else ''}")
    
    success_count = 0
    
    # Create H5 file and save all scenes
    with h5py.File(output_file, 'w') as f:
        # Create scene names dataset as index
        scene_names_dataset = f.create_dataset('scene_names', shape=(num_scenes,), dtype=h5py.special_dtype(vlen=str))
        
        # Create scenes group
        scenes_group = f.create_group('scenes')
        
        for scene_idx, (scene_name, scene_dir) in enumerate(selected):
            scene_names_dataset[scene_idx] = scene_name
            
            # The group is created even if the scene ends up empty, so every
            # entry of `scene_names` has a matching group carrying its index.
            scene_group = scenes_group.create_group(scene_name)
            scene_group.attrs['index'] = scene_idx
            
            # glob() order depends on the filesystem; sort so that a seeded
            # random.sample picks the same images on every machine.
            image_files = sorted(list(scene_dir.glob('*.jpg')) + list(scene_dir.glob('*.png')))
            if not image_files:
                logger.warning(f"No image files found in scene {scene_name}")
                continue
            
            # Randomly select images
            samples = min(samples_per_scene, len(image_files))
            selected_images = random.sample(image_files, samples)
            
            scene_embeddings, scene_image_paths = _embed_images(
                model,
                preprocess,
                selected_images,
                batch_size,
                device,
                desc=f"Processing scene {scene_name}",
            )
            
            # If there are processed images, save to the scene group
            if scene_embeddings is not None:
                total_images = len(scene_image_paths)
                
                # Save embeddings and image paths (row i of embeddings belongs to path i)
                scene_group.create_dataset('embeddings', data=scene_embeddings)
                image_paths_dataset = scene_group.create_dataset('image_paths', shape=(total_images,), 
                                                            dtype=h5py.special_dtype(vlen=str))
                for i, path in enumerate(scene_image_paths):
                    image_paths_dataset[i] = path
                
                logger.info(f"Scene {scene_name}: Successfully processed {total_images} images, embedding shape: {scene_embeddings.shape}")
                success_count += 1
            else:
                logger.warning(f"Scene {scene_name}: No images were successfully processed")
            
            # Clean up memory after scene processing
            del scene_embeddings
            del scene_image_paths
            gc.collect()
            torch.cuda.empty_cache()
        
        # Record overall statistics
        f.attrs['total_scenes'] = num_scenes
        f.attrs['processed_scenes'] = success_count
        logger.info(f"Successfully processed {success_count}/{num_scenes} scenes in total")
    
    return success_count > 0


# Example usage
if __name__ == "__main__":
    import argparse
    import sys

    # Make the project-local CLIP_utils package importable when the script is
    # launched from the project root.
    project_root = os.path.abspath(os.getcwd())
    sys.path.append(project_root)
    from CLIP_utils.factory import create_model_and_transforms

    # Pretrained checkpoint tag used for each supported CLIP architecture.
    PRETRAINED_TAGS = {
        "ViT-B-32": "laion2b_s34b_b79k",
        "ViT-B-16": "laion2b_s34b_b88k",
        "ViT-L-14": "laion2b_s32b_b82k",
        "ViT-H-14": "laion2b_s32b_b79k",
    }

    # Places scene categories to embed for each downstream dataset.
    DATASET_SCENES = {
        # Various backgrounds for waterbirds and landbirds
        "waterbirds": [
            "marsh", "pond", "lake-natural", "river", "swamp", "creek", "canal-natural", "fishpond",
            "forest-broadleaf", "bamboo_forest", "orchard", "field-cultivated", "field-wild",
            "pasture", "farm", "forest_path",
        ],
        # Urban scenes + rural scenes
        "urbancars": [
            "street", "downtown", "parking_lot", "parking_garage-outdoor", "highway",
            "gas_station", "viaduct", "crosswalk",
            "field-cultivated", "field_road", "forest_road", "village", "farm", "pasture", "barn",
        ],
        # Typical daily scenes for COCO gender bias dataset
        "cocogb": [
            "living_room", "bedroom", "kitchen", "bathroom", "dining_room",
            "office", "classroom", "hotel_room", "restaurant", "shopping_mall-indoor",
            "bus_station-indoor", "supermarket", "park", "street", "playground",
        ],
        # General natural and artificial scenes for ImageNet
        "imagenet": [
            "forest-broadleaf", "bamboo_forest", "mountain", "mountain_snowy",
            "desert-vegetation", "beach", "lake-natural", "river", "waterfall",
            "canyon", "cliff", "snowfield", "street", "building_facade", "office_building",
        ],
        "nico": [
            "beach",
            "bridge",
            "desert_road",
            "forest_path",
            "river",
            "snowfield",
            "street",
            "residential_neighborhood",
            "home_office",
            "mountain_snowy",
            "pasture",
            "airplane_cabin",
            "train_station-platform",
            "garage-outdoor",
            "amphitheater",
            "corral",
        ],
    }

    # Configure logging
    logging.basicConfig(level=logging.INFO, 
                        format="%(asctime)s [%(levelname)s] %(message)s",
                        datefmt="%Y-%m-%d %H:%M")

    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Extract scene embeddings from Places dataset")
    parser.add_argument("--places_dir", type=str, default="./data/places/train", help="Places dataset directory")
    parser.add_argument("--output_dir", type=str, default="./results/places_embeddings", help="Output directory")
    parser.add_argument("--model", type=str, default="ViT-B-32", 
                        choices=list(PRETRAINED_TAGS), 
                        help="CLIP model name")
    parser.add_argument("--samples_per_scene", type=int, default=50, help="Number of images to process per scene")
    # Restricting choices rejects an unknown dataset at parse time, before the
    # (expensive) model load, and lists the valid names in --help.
    parser.add_argument("--dataset", type=str, default="nico",
                        choices=list(DATASET_SCENES),
                        help="Dataset name")
    parser.add_argument("--batch_size", type=int, default=50, help="Batch size")
    parser.add_argument("--cuda_id", type=str, default="0", help="CUDA device ID")
    args = parser.parse_args()

    # Set random seed
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)

    # Set device (falls back to CPU when CUDA is unavailable)
    device = torch.device(f"cuda:{args.cuda_id}" if torch.cuda.is_available() else "cpu")

    # Load model
    model, _, preprocess = create_model_and_transforms(args.model, pretrained=PRETRAINED_TAGS[args.model])
    model.to(device)
    model.eval()

    # Set output file path
    output_file = Path(args.output_dir) / f"{args.model}_{args.dataset}_places_embeddings.h5"
    scenes = DATASET_SCENES[args.dataset]

    # Clean memory before processing
    gc.collect()
    torch.cuda.empty_cache()

    # Extract and save embeddings
    success = False
    try:
        success = extract_places_embeddings(
            model=model,
            preprocess=preprocess,
            places_dir=args.places_dir,
            output_file=output_file,
            scene_names=scenes,
            samples_per_scene=args.samples_per_scene,
            batch_size=args.batch_size,
            device=device
        )
    finally:
        # Ensure memory is released when program ends
        del model
        gc.collect()
        torch.cuda.empty_cache()
        logger.info("Program completed, memory released")

    # Surface a failed extraction to the shell instead of exiting with status 0.
    if not success:
        logger.error(f"No embeddings were extracted from {args.places_dir}")
        sys.exit(1)