#!/usr/bin/env python3
"""
Dataset splitting script for Caltech-256
Splits the dataset into train/validation/test splits while maintaining
the same class distribution proportions across all splits.
"""

import os
import argparse
import shutil
import random
from pathlib import Path
from collections import defaultdict
import json

import numpy as np
from tqdm import tqdm


def get_class_distribution(dataset_path):
    """
    Count the image files held by each class sub-directory.

    Every immediate sub-directory of ``dataset_path`` is treated as one class;
    plain files sitting directly in ``dataset_path`` are ignored. Only files
    matching ``*.jpg``, ``*.jpeg`` or ``*.png`` inside the class directory
    itself (not recursively) are counted.

    Args:
        dataset_path (str): Path to the dataset directory

    Returns:
        dict: Mapping of class directory name -> number of image files,
            inserted in sorted directory order
    """
    root = Path(dataset_path)
    patterns = ("*.jpg", "*.jpeg", "*.png")

    print("Analyzing dataset class distribution...")

    counts = {}
    entries = sorted(root.iterdir())
    for entry in tqdm(entries, desc="Analyzing classes"):
        if not entry.is_dir():
            continue
        counts[entry.name] = sum(len(list(entry.glob(pattern))) for pattern in patterns)

    return counts


def _compute_split_sizes(class_counts, train_ratio, val_ratio):
    """Return per-class ``{'total', 'train', 'val', 'test'}`` image counts."""
    # Products such as 100 * 0.29 evaluate to 28.999999999999996 in floating
    # point; without the epsilon, int() (== floor for non-negative values)
    # would silently drop one image from the intended floor(count * ratio).
    eps = 1e-9
    split_info = {}
    for class_name, count in class_counts.items():
        train_size = int(count * train_ratio + eps)
        # Ratios may sum to 1.0 only within 1e-6, so clamp val to what is left
        # after train; this guarantees the test remainder is never negative.
        val_size = min(int(count * val_ratio + eps), count - train_size)
        split_info[class_name] = {
            'total': count,
            'train': train_size,
            'val': val_size,
            'test': count - train_size - val_size,  # Remaining goes to test
        }
    return split_info


def _print_split_summary(split_info, totals, train_ratio, val_ratio, test_ratio):
    """Print the per-class split table, the grand totals and the ratios."""
    print("\n" + "="*60)
    print("DATASET SPLIT SUMMARY")
    print("="*60)
    print(f"{'Class':<20} {'Total':<8} {'Train':<8} {'Val':<8} {'Test':<8}")
    print("-"*60)

    for class_name, info in split_info.items():
        print(f"{class_name:<20} {info['total']:<8} {info['train']:<8} {info['val']:<8} {info['test']:<8}")

    print("-"*60)
    print(f"{'TOTAL':<20} {totals['total']:<8} {totals['train']:<8} {totals['val']:<8} {totals['test']:<8}")
    print(f"Ratios: {100*train_ratio:.1f}% / {100*val_ratio:.1f}% / {100*test_ratio:.1f}%")
    print("="*60)


def split_dataset(dataset_path, output_path, train_ratio=0.7, val_ratio=0.20, test_ratio=0.10, seed=42):
    """
    Split the dataset into train/validation/test splits while maintaining class distribution.

    Each class sub-directory of ``dataset_path`` is shuffled with the given
    seed, and its ``*.jpg``/``*.jpeg``/``*.png`` files are copied (with
    metadata, via ``shutil.copy2``) into ``output_path/{train,val,test}/<class>/``.
    Per class, the train and val sizes are ``floor(count * ratio)`` and every
    remaining image goes to test. A ``split_summary.json`` describing the
    split is written to ``output_path``.

    Existing files in the output split directories are NOT removed; a warning
    is printed if any are found, because they would be mixed into this split.

    Args:
        dataset_path (str): Path to the original dataset (one sub-directory per class)
        output_path (str): Path to output directory
        train_ratio (float): Proportion for training set (default: 0.7)
        val_ratio (float): Proportion for validation set (default: 0.20)
        test_ratio (float): Proportion for test set (default: 0.10)
        seed (int): Seed for the global ``random`` and ``numpy`` RNGs, for reproducibility

    Returns:
        dict: Mapping class name -> {'total', 'train', 'val', 'test'} image counts

    Raises:
        ValueError: If the ratios do not sum to 1.0 (within 1e-6) or any ratio
            lies outside [0, 1] (NaN included).
    """
    # Validate ratios
    total_ratio = train_ratio + val_ratio + test_ratio
    if abs(total_ratio - 1.0) > 1e-6:
        raise ValueError(f"Ratios must sum to 1.0, got {total_ratio}")
    for name, ratio in (("train", train_ratio), ("val", val_ratio), ("test", test_ratio)):
        # Negative / >1 ratios can still sum to 1.0 and would yield negative
        # slice sizes; a NaN ratio also fails this comparison.
        if not 0.0 <= ratio <= 1.0:
            raise ValueError(f"{name}_ratio must be within [0, 1], got {ratio}")

    # Set random seed
    random.seed(seed)
    np.random.seed(seed)

    dataset_path = Path(dataset_path)
    output_path = Path(output_path)

    # Create output directories
    train_path = output_path / "train"
    val_path = output_path / "val"
    test_path = output_path / "test"

    for split_path in (train_path, val_path, test_path):
        # Leftovers from an earlier run (e.g. another seed) would end up mixed
        # into this split and can leak images across splits.
        if split_path.is_dir() and any(split_path.iterdir()):
            print(f"Warning: {split_path} is not empty; files from a previous run "
                  f"will be mixed into this split.")
        split_path.mkdir(parents=True, exist_ok=True)

    # Per-class sizes and overall totals
    class_counts = get_class_distribution(dataset_path)
    split_info = _compute_split_sizes(class_counts, train_ratio, val_ratio)
    totals = {
        key: sum(info[key] for info in split_info.values())
        for key in ('total', 'train', 'val', 'test')
    }

    _print_split_summary(split_info, totals, train_ratio, val_ratio, test_ratio)

    # Perform the split
    print("\nPerforming dataset split...")

    for class_name, sizes in tqdm(split_info.items(), desc="Processing classes"):
        class_dir = dataset_path / class_name

        # Sort before shuffling so the outcome depends only on the seed, not
        # on the filesystem's directory-listing order.
        image_files = sorted(
            list(class_dir.glob("*.jpg")) + list(class_dir.glob("*.jpeg")) + list(class_dir.glob("*.png"))
        )
        random.shuffle(image_files)

        train_end = sizes['train']
        val_end = train_end + sizes['val']
        assignments = (
            (train_path, image_files[:train_end]),
            (val_path, image_files[train_end:val_end]),
            (test_path, image_files[val_end:]),
        )

        # Class directories are created in every split, even when empty.
        for split_path, files in assignments:
            class_split_dir = split_path / class_name
            class_split_dir.mkdir(exist_ok=True)
            for file_path in files:
                shutil.copy2(file_path, class_split_dir / file_path.name)

    # Save split information
    split_summary = {
        'dataset_path': str(dataset_path),
        'output_path': str(output_path),
        'split_ratios': {
            'train': train_ratio,
            'val': val_ratio,
            'test': test_ratio
        },
        'seed': seed,
        'class_distributions': split_info,
        'total_images': totals['total'],
        'split_totals': {
            'train': totals['train'],
            'val': totals['val'],
            'test': totals['test']
        }
    }

    summary_path = output_path / "split_summary.json"
    with open(summary_path, 'w') as f:
        json.dump(split_summary, f, indent=2)

    print(f"\nSplit summary saved to: {summary_path}")

    return split_info


def verify_split(output_path, split_info):
    """
    Verify that the split was performed correctly.

    For each of the ``train``, ``val`` and ``test`` splits, counts the
    ``*.jpg``, ``*.jpeg`` and ``*.png`` files in ``output_path/<split>/<class>``
    and compares the count with ``split_info[class][split]``. A missing class
    directory counts as 0 images. It is therefore reported as a mismatch
    unless 0 images were expected, instead of being silently skipped.

    Args:
        output_path (str): Path to the split dataset
        split_info (dict): Mapping class name -> {'train': int, 'val': int,
            'test': int, ...} expected image counts

    Returns:
        bool: True if every split/class count matches, False otherwise
    """
    print("\nVerifying split...")

    output_path = Path(output_path)
    verification_results = {}

    for split_name in ['train', 'val', 'test']:
        split_path = output_path / split_name
        verification_results[split_name] = {}

        for class_name in split_info.keys():
            class_path = split_path / class_name
            # Skipping a missing directory used to let a split that lost a
            # whole class pass verification; treat it as empty instead.
            if class_path.is_dir():
                actual_count = len(list(class_path.glob("*.jpg")) + list(class_path.glob("*.jpeg")) + list(class_path.glob("*.png")))
            else:
                actual_count = 0
            expected_count = split_info[class_name][split_name]
            verification_results[split_name][class_name] = {
                'expected': expected_count,
                'actual': actual_count,
                'match': actual_count == expected_count
            }

    # Print verification results
    print("\n" + "="*80)
    print("SPLIT VERIFICATION RESULTS")
    print("="*80)

    all_correct = True
    for split_name, classes in verification_results.items():
        print(f"\n{split_name.upper()} SPLIT:")
        print("-" * 40)
        split_correct = True

        for class_name, result in classes.items():
            status = "✓" if result['match'] else "✗"
            print(f"{class_name:<20} Expected: {result['expected']:<4} Actual: {result['actual']:<4} {status}")
            if not result['match']:
                split_correct = False
                all_correct = False

        if split_correct:
            print(f"✓ {split_name.upper()} split is correct")
        else:
            print(f"✗ {split_name.upper()} split has errors")

    if all_correct:
        print("\n🎉 All splits verified successfully!")
    else:
        print("\n⚠️  Some splits have verification errors!")

    return all_correct


def create_dataset_info(output_path, split_info):
    """
    Write ``class_mapping.json`` and ``dataset_stats.json`` into ``output_path``.

    Classes are indexed in sorted name order. Because JSON object keys are
    always strings, the integer keys of ``idx_to_class`` are read back as
    ``"0"``, ``"1"``, ...

    Args:
        output_path (str): Path to the split dataset
        split_info (dict): Mapping class name -> {'total', 'train', 'val',
            'test'} image counts
    """
    output_path = Path(output_path)

    # Create class mapping file
    classes = sorted(split_info.keys())
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    idx_to_class = {i: cls_name for cls_name, i in class_to_idx.items()}

    class_mapping = {
        'class_to_idx': class_to_idx,
        'idx_to_class': idx_to_class,
        'num_classes': len(classes)
    }

    mapping_path = output_path / "class_mapping.json"
    with open(mapping_path, 'w') as f:
        json.dump(class_mapping, f, indent=2)

    # Create dataset statistics
    split_totals = {
        split: sum(info[split] for info in split_info.values())
        for split in ('train', 'val', 'test')
    }
    stats = {
        'num_classes': len(classes),
        'total_images': sum(info['total'] for info in split_info.values()),
        'class_distributions': split_info,
        # Historical misnomer: this key holds image *counts*, not ratios.
        # It is kept unchanged for existing readers; prefer 'split_totals'.
        'split_ratios': dict(split_totals),
        'split_totals': split_totals,
    }

    stats_path = output_path / "dataset_stats.json"
    with open(stats_path, 'w') as f:
        json.dump(stats, f, indent=2)

    print(f"Class mapping saved to: {mapping_path}")
    print(f"Dataset statistics saved to: {stats_path}")


def main():
    """
    Command-line entry point: split a Caltech-256 style dataset on disk.

    Parses the arguments, runs ``split_dataset``, writes the class-mapping and
    statistics files, and with ``--verify`` checks the copied result. Raises
    ``SystemExit(1)`` in these cases, so shell pipelines can detect failure:
    the dataset path is missing or not a directory, the split raises, or
    verification reports a mismatch.
    """
    import sys
    import traceback

    parser = argparse.ArgumentParser(description='Split Caltech-256 dataset into train/val/test splits')
    parser.add_argument('--dataset_path', type=str, required=True,
                        help='Path to the original Caltech-256 dataset')
    parser.add_argument('--output_path', type=str, required=True,
                        help='Path to output directory for split dataset')
    parser.add_argument('--train_ratio', type=float, default=0.7,
                        help='Proportion for training set (default: 0.7)')
    parser.add_argument('--val_ratio', type=float, default=0.20,
                        help='Proportion for validation set (default: 0.20)')
    parser.add_argument('--test_ratio', type=float, default=0.10,
                        help='Proportion for test set (default: 0.10)')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed for reproducibility (default: 42)')
    parser.add_argument('--verify', action='store_true',
                        help='Verify the split after completion')
    parser.add_argument('--skip_copy', action='store_true',
                        help='Skip copying files (for testing). NOTE: not implemented; '
                             'files are currently always copied')

    args = parser.parse_args()

    # Validate input path (a regular file would otherwise fail later inside iterdir)
    dataset_path = Path(args.dataset_path)
    if not dataset_path.exists():
        print(f"Error: Dataset path {dataset_path} does not exist!", file=sys.stderr)
        raise SystemExit(1)
    if not dataset_path.is_dir():
        print(f"Error: Dataset path {dataset_path} is not a directory!", file=sys.stderr)
        raise SystemExit(1)

    if args.skip_copy:
        # split_dataset has no plan-only mode here, so the flag cannot be
        # honoured; say so instead of silently copying anyway.
        print("Warning: --skip_copy is not implemented; files will still be copied.",
              file=sys.stderr)

    # Create output directory
    output_path = Path(args.output_path)
    output_path.mkdir(parents=True, exist_ok=True)

    print(f"Dataset path: {dataset_path}")
    print(f"Output path: {output_path}")
    print(f"Split ratios: Train={args.train_ratio:.1%}, Val={args.val_ratio:.1%}, Test={args.test_ratio:.1%}")
    print(f"Random seed: {args.seed}")

    try:
        split_info = split_dataset(
            dataset_path=args.dataset_path,
            output_path=args.output_path,
            train_ratio=args.train_ratio,
            val_ratio=args.val_ratio,
            test_ratio=args.test_ratio,
            seed=args.seed
        )

        # Create additional dataset information
        create_dataset_info(args.output_path, split_info)

        # Verify the split if requested
        verified = verify_split(args.output_path, split_info) if args.verify else True
    except Exception as e:  # top-level boundary: report with traceback, then fail
        print(f"❌ Error during dataset split: {e}", file=sys.stderr)
        traceback.print_exc()
        raise SystemExit(1) from e

    if not verified:
        # Previously a failed verification still printed success and exited 0.
        print("❌ Split verification failed; see the report above.", file=sys.stderr)
        raise SystemExit(1)

    print("\n✅ Dataset split completed successfully!")
    print(f"📁 Split dataset saved to: {output_path}")
    print(f"📊 Split summary: {output_path}/split_summary.json")
    print(f"🗂️  Class mapping: {output_path}/class_mapping.json")
    print(f"📈 Dataset stats: {output_path}/dataset_stats.json")


if __name__ == "__main__":
    # Run the command-line interface only when executed as a script, not on import.
    main()