#!/usr/bin/env python3
"""
Clean RAM++ Inference Script
Based on inference_ram_plus.py with improved modularity and batch processing support
"""

import argparse
import os
import json
import time
from pathlib import Path
from typing import List, Dict, Optional, Union, Tuple

import torch
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import numpy as np

# Add parent directory to path to import modules
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Optional dependency: the `ram` package provides the model constructor
# (`ram_plus`), the inference helper (imported here as `inference`) and the
# preprocessing factory (`get_transform`).
# ORIGINAL_RAM_AVAILABLE records whether the import succeeded so that
# RAMPlusInference can check it later instead of this module failing on import.
try:
    from ram.models import ram_plus
    from ram import inference_ram as inference
    from ram import get_transform
    ORIGINAL_RAM_AVAILABLE = True
except ImportError:
    print("Warning: Original RAM modules not available")
    ORIGINAL_RAM_AVAILABLE = False


class ImageDataset(Dataset):
    """Map-style dataset that opens image files from a list of paths on demand.

    Each item is a dict with the (optionally transformed) RGB image under
    ``'image'``, the source path under ``'image_path'`` and the file name
    without its extension under ``'image_id'``.
    """

    def __init__(self, image_paths: List[str], transform=None):
        # The transform is optional; when falsy the raw PIL image is returned.
        self.transform = transform
        self.image_paths = image_paths

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and package it with its metadata."""
        path = self.image_paths[idx]
        rgb = Image.open(path).convert('RGB')
        sample = rgb if not self.transform else self.transform(rgb)
        return dict(image=sample, image_path=path, image_id=Path(path).stem)


class RAMPlusInference:
    """RAM++ inference wrapper for single images and DataLoader batches.

    The class relies on the module-level optional imports ``ram_plus``,
    ``inference`` and ``get_transform``; constructing an instance raises
    ``ImportError`` when those are unavailable.
    """

    def __init__(self,
                 model_type: str = 'original',
                 pretrained_path: Optional[str] = None,
                 image_size: int = 384,
                 device: str = 'auto',
                 threshold: float = 0.5,
                 **model_kwargs):
        """
        Initialize RAM++ inference

        Args:
            model_type: currently only 'original' (kept for backward compatibility)
            pretrained_path: path to pretrained model checkpoint
            image_size: input image size
            device: computation device ('auto', 'cpu', 'cuda', 'cuda:0', ...);
                'auto' selects CUDA when available, otherwise CPU
            threshold: classification threshold; stored on the instance but
                not read by any method of this class
            **model_kwargs: additional model arguments (only 'vit' is read)

        Raises:
            ValueError: if ``model_type`` is not 'original'
            ImportError: if the RAM modules could not be imported
        """
        self.model_type = model_type
        self.pretrained_path = pretrained_path
        self.image_size = image_size
        self.threshold = threshold

        # Every explicit device string ('cpu', 'cuda', 'cuda:1', ...) is
        # handled by torch.device directly; only 'auto' needs a decision.
        if device == 'auto':
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        print(f"Using device: {self.device}")

        # Load model
        self.model = self._load_model(**model_kwargs)
        self.model.to(self.device)
        self.model.eval()

        # Get transform
        self.transform = self._get_transform()

        print(f"Model loaded successfully. Type: {model_type}")

    def _require_supported_model_type(self):
        """Raise ValueError unless ``self.model_type`` is the supported 'original'."""
        if self.model_type != 'original':
            raise ValueError(f"Unknown model type: {self.model_type}")

    def _load_model(self, **kwargs):
        """Build the RAM++ model; ``kwargs['vit']`` selects the backbone (default 'swin_l')."""
        self._require_supported_model_type()

        if not ORIGINAL_RAM_AVAILABLE:
            raise ImportError("Original RAM modules not available")

        return ram_plus(
            pretrained=self.pretrained_path,
            image_size=self.image_size,
            vit=kwargs.get('vit', 'swin_l')
        )

    def _get_transform(self):
        """Return the RAM transform when available, else a torchvision fallback."""
        if self.model_type == 'original' and ORIGINAL_RAM_AVAILABLE:
            return get_transform(image_size=self.image_size)

        from torchvision import transforms
        return transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    # The annotation is quoted so that it is not evaluated when the class
    # body is executed.
    def predict_single(
        self,
        image_input: "Union[str, os.PathLike, Image.Image, torch.Tensor]",
    ) -> Dict:
        """
        Predict on a single image

        Args:
            image_input: image path (``str`` or path-like), PIL Image, or
                preprocessed tensor. A 3-D tensor gets a leading batch
                dimension; any other tensor is passed through unchanged.

        Returns:
            prediction dictionary with keys 'image_path', 'tags_en',
            'tags_zh' and 'raw_output'. 'image_path' is 'PIL_Image' for PIL
            input and 'unknown' for tensor input.

        Raises:
            ValueError: for an unsupported input type or model type
        """
        self._require_supported_model_type()

        # Tensor input carries no path information.
        image_path = 'unknown'

        if isinstance(image_input, torch.Tensor):
            # Already preprocessed: only make sure a batch dimension exists.
            batched = image_input.unsqueeze(0) if image_input.dim() == 3 else image_input
            image_tensor = batched.to(self.device)
        else:
            if isinstance(image_input, (str, os.PathLike)):
                image_path = os.fspath(image_input)
                image = Image.open(image_path).convert('RGB')
            elif isinstance(image_input, Image.Image):
                image_path = 'PIL_Image'
                # Normalise the mode so PIL input is treated like file input.
                image = image_input.convert('RGB')
            else:
                raise ValueError(f"Unsupported image input type: {type(image_input)}")
            image_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            res = inference(image_tensor, self.model)

        return {
            'image_path': image_path,
            'tags_en': res[0],
            'tags_zh': res[1] if len(res) > 1 else None,
            'raw_output': res
        }

    def predict_batch(self, dataloader: DataLoader, max_batches: Optional[int] = None) -> List[Dict]:
        """
        Predict on a batch of images using DataLoader

        Args:
            dataloader: iterable of batch dicts holding an 'image' tensor and,
                optionally, 'image_path' and 'image_id' sequences
            max_batches: maximum number of batches to process (None for all;
                0 processes nothing)

        Returns:
            list of prediction dictionaries, one per image

        Raises:
            ValueError: if the model type is not supported
        """
        self._require_supported_model_type()
        results = []

        with torch.no_grad():
            for batch_idx, batch in enumerate(dataloader):
                # Compare against None so that an explicit 0 is respected.
                if max_batches is not None and batch_idx >= max_batches:
                    break

                images = batch['image'].to(self.device)
                batch_size = images.shape[0]

                # Size the placeholder to the batch so a missing key does not
                # raise IndexError for the second and later images.
                placeholder = ['unknown'] * batch_size
                image_paths = batch.get('image_path', placeholder)
                image_ids = batch.get('image_id', placeholder)

                # The inference helper is called once per image, each time
                # with a one-image slice of the batch.
                for i in range(batch_size):
                    res = inference(images[i:i + 1], self.model)
                    results.append({
                        'image_path': image_paths[i],
                        'image_id': image_ids[i],
                        'tags_en': res[0],
                        'tags_zh': res[1] if len(res) > 1 else None,
                        'raw_output': res
                    })

                if batch_idx % 10 == 0:
                    print(f"Processed batch {batch_idx + 1}/{len(dataloader)}")

        return results
    

def create_image_dataloader(image_dir: str, 
                          image_size: int = 384, 
                          batch_size: int = 16,
                          num_workers: int = 4) -> DataLoader:
    """Create DataLoader for images in a directory

    Files directly inside ``image_dir`` whose extension (compared
    case-insensitively) is .jpg, .jpeg, .png, .bmp or .tiff are collected in
    sorted order and wrapped in an ``ImageDataset``.

    Args:
        image_dir: directory to scan (not recursive)
        image_size: side length images are resized to
        batch_size: number of images per batch
        num_workers: number of DataLoader worker processes

    Returns:
        a non-shuffling DataLoader over the discovered images

    Raises:
        ValueError: if no matching image file is found (including when
            ``image_dir`` is not an existing directory)
    """
    import torchvision.transforms as transforms

    image_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}

    # Scan the directory once and compare lower-cased suffixes. Globbing each
    # extension in both cases lists files twice where matching is
    # case-insensitive, and misses mixed-case extensions such as ".Jpg".
    root = Path(image_dir)
    entries = root.iterdir() if root.is_dir() else ()
    image_paths = sorted(
        str(entry) for entry in entries
        if entry.suffix.lower() in image_exts and entry.is_file()
    )

    if not image_paths:
        raise ValueError(f"No images found in {image_dir}")

    print(f"Found {len(image_paths)} images in {image_dir}")

    # Create transform
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                           std=[0.229, 0.224, 0.225])
    ])

    # Create dataset and dataloader
    dataset = ImageDataset(image_paths, transform=transform)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        # Only request pinned host memory when CUDA is available.
        pin_memory=torch.cuda.is_available()
    )


def _json_default(value):
    """Convert NumPy / torch values that ``json`` cannot serialise natively."""
    if isinstance(value, (np.ndarray, np.generic, torch.Tensor)):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def _save_results(results, output_dir):
    """Write ``results`` to ``<output_dir>/inference_results.json`` (no file if empty)."""
    os.makedirs(output_dir, exist_ok=True)
    if not results:
        return

    output_file = os.path.join(output_dir, 'inference_results.json')
    # ensure_ascii=False keeps non-ASCII tags readable in the file instead of
    # writing them as \uXXXX escapes; the parsed JSON content is unchanged.
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False, default=_json_default)
    print(f"Results saved to {output_file}")


def main():
    """Command-line entry point: tag a single image or every image in a directory.

    Parses the CLI arguments, builds a ``RAMPlusInference`` instance, runs
    prediction on ``--image`` (preferred when both are given) or
    ``--image-dir``, prints a summary, and writes the results as JSON when
    both ``--save-results`` and ``--output-dir`` are supplied.
    """
    parser = argparse.ArgumentParser(description='Clean RAM++ Inference Script')

    # Model arguments
    parser.add_argument('--model-type', default='original', choices=['original'],
                       help='Model type to use')
    parser.add_argument('--pretrained', 
                       default='pretrained/ram_plus_swin_large_14m.pth',
                       help='Path to pretrained model')
    parser.add_argument('--image-size', default=384, type=int,
                       help='Input image size')
    parser.add_argument('--threshold', default=0.5, type=float,
                       help='Classification threshold')
    parser.add_argument('--vit', default='swin_l', choices=['swin_b', 'swin_l'],
                       help='Vision transformer size')

    # Input arguments
    parser.add_argument('--image', type=str,
                       help='Single image path')
    parser.add_argument('--image-dir', type=str,
                       help='Directory containing images')

    # Processing arguments
    parser.add_argument('--batch-size', default=16, type=int,
                       help='Batch size for processing')
    parser.add_argument('--num-workers', default=4, type=int,
                       help='Number of workers for data loading')
    parser.add_argument('--max-batches', type=int,
                       help='Maximum number of batches to process')
    parser.add_argument('--device', default='auto', type=str,
                       help='Device to use (auto, cpu, cuda, cuda:0, cuda:1, etc.)')

    # Output arguments
    parser.add_argument('--output-dir', type=str,
                       help='Output directory for results')
    parser.add_argument('--save-results', action='store_true',
                       help='Save results to JSON file')

    args = parser.parse_args()

    # Validate the request before constructing the model, so a missing
    # input is reported without paying for model loading first.
    if not args.image and not args.image_dir:
        print("Please specify --image or --image-dir")
        return

    if args.save_results and not args.output_dir:
        # Saving needs a destination; say so instead of silently skipping it.
        print("Warning: --save-results requires --output-dir; results will not be saved")

    # Initialize inference
    inferencer = RAMPlusInference(
        model_type=args.model_type,
        pretrained_path=args.pretrained,
        image_size=args.image_size,
        threshold=args.threshold,
        device=args.device,
        vit=args.vit
    )

    results = []

    if args.image:
        # Single image inference
        print(f"Processing single image: {args.image}")
        result = inferencer.predict_single(args.image)
        results.append(result)

        # Print results
        print(f"English Tags: {result['tags_en']}")
        if result['tags_zh']:
            print(f"Chinese Tags: {result['tags_zh']}")
    else:
        # Batch processing
        print(f"Processing images in directory: {args.image_dir}")
        dataloader = create_image_dataloader(
            args.image_dir,
            image_size=args.image_size,
            batch_size=args.batch_size,
            num_workers=args.num_workers
        )

        start_time = time.perf_counter()
        results = inferencer.predict_batch(dataloader, max_batches=args.max_batches)
        elapsed = time.perf_counter() - start_time

        print(f"Processed {len(results)} images in {elapsed:.2f} seconds")
        # Skip the average when nothing was processed to avoid dividing by zero.
        if results:
            print(f"Average time per image: {elapsed / len(results):.3f} seconds")

    # Save results if requested
    if args.save_results and args.output_dir:
        _save_results(results, args.output_dir)


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
