"""Utilities for building dataset indices."""

import os
import pandas as pd
import argparse
from pathlib import Path
from typing import List, Dict, Tuple
from sklearn.model_selection import train_test_split


_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')


def _collect_samples(data_root: Path, domains: List[str]) -> pd.DataFrame:
    """Scan ``data_root/<domain>/<class_name>/<image>`` into a DataFrame.

    Labels come from a single mapping shared by all domains: every class
    directory found in any scanned domain, sorted by name, numbered from 0.
    Only directories count as classes, so stray files such as ``.DS_Store``
    cannot shift the numbering.

    Args:
        data_root: Root directory containing domain folders
        domains: Domain folder names to scan, in output order

    Returns:
        DataFrame with columns ``img_path`` (relative to ``data_root``),
        ``label``, ``domain`` and ``class_name``.

    Raises:
        ValueError: If no image is found in any of the domains.
    """
    print("Building global class-to-index mapping...")
    domain_classes: Dict[str, List[str]] = {}
    for domain in domains:
        domain_path = data_root / domain
        if not domain_path.is_dir():
            print(f"Warning: Domain {domain} not found at {domain_path}")
            continue
        domain_classes[domain] = sorted(
            name for name in os.listdir(domain_path)
            if (domain_path / name).is_dir()
        )

    all_class_names = sorted({name for names in domain_classes.values() for name in names})
    class_to_idx = {name: idx for idx, name in enumerate(all_class_names)}
    print(f"Found {len(class_to_idx)} unique classes across all domains.")

    samples = []
    for domain, class_names in domain_classes.items():
        for class_name in class_names:
            # Sorted so the row order -- and therefore the seeded split -- does
            # not depend on the filesystem's directory-listing order.
            for img_name in sorted(os.listdir(data_root / domain / class_name)):
                if img_name.lower().endswith(_IMAGE_EXTENSIONS):
                    samples.append({
                        'img_path': os.path.join(domain, class_name, img_name),
                        'label': class_to_idx[class_name],
                        'domain': domain,
                        'class_name': class_name
                    })

    if not samples:
        raise ValueError(f"No images found under {data_root} for domains {domains}")
    return pd.DataFrame(samples)


def _stratified_split(
    frame: pd.DataFrame,
    fraction: float,
    seed: int
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Split ``frame`` into ``(first, rest)`` with ``fraction`` of rows in ``first``.

    The split is stratified by ``label``. ``train_test_split`` only accepts a
    float ``train_size`` strictly inside (0, 1), so degenerate cases are
    handled here: an empty frame yields two empty frames, a fraction >= 1
    keeps everything in ``first``, and a fraction <= 0 keeps everything in
    ``rest``.
    """
    empty = frame.iloc[0:0]
    if frame.empty:
        return empty, empty
    if fraction >= 1.0:
        return frame, empty
    if fraction <= 0.0:
        return empty, frame
    first, rest = train_test_split(
        frame,
        train_size=fraction,
        stratify=frame['label'],
        random_state=seed
    )
    return first, rest


def _assign_splits(
    df: pd.DataFrame,
    domains: List[str],
    train_ratio: float,
    val_ratio: float,
    test_ratio: float,
    seed: int
) -> pd.DataFrame:
    """Split each domain into train/val/test and tag rows with a ``split`` column.

    Each domain is split independently. First ``train_ratio`` of its rows go
    to train. The remainder is then divided between val and test in
    proportion ``val_ratio : test_ratio``. Both steps are stratified by label.
    Output is ordered by domain, then train/val/test.
    """
    holdout_ratio = val_ratio + test_ratio
    # Share of the held-out rows that goes to val; avoids 0/0 when nothing is held out.
    val_fraction = val_ratio / holdout_ratio if holdout_ratio > 0 else 0.0

    parts = []
    for domain in domains:
        domain_df = df[df['domain'] == domain]
        if domain_df.empty:
            continue
        train_df, temp_df = _stratified_split(domain_df, train_ratio, seed)
        val_df, test_df = _stratified_split(temp_df, val_fraction, seed)
        # assign() returns new frames, so no SettingWithCopyWarning is raised
        # on the slices produced by train_test_split. Empty parts are dropped
        # to keep pd.concat free of empty-entry dtype warnings.
        parts.extend(
            part.assign(split=name)
            for name, part in (('train', train_df), ('val', val_df), ('test', test_df))
            if not part.empty
        )
    return pd.concat(parts, ignore_index=True)


def _save_and_report(final_df: pd.DataFrame, output_file: str, domains: List[str]) -> None:
    """Write ``final_df`` to ``output_file`` as CSV and print per-domain split sizes."""
    output_path = Path(output_file)
    # DataFrame.to_csv does not create missing parent directories.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    final_df.to_csv(output_path, index=False)
    print(f"Saved index with {len(final_df)} samples to {output_file}")

    for domain in domains:
        domain_df = final_df[final_df['domain'] == domain]
        if domain_df.empty:
            continue
        print(f"\n{domain}:")
        for split in ['train', 'val', 'test']:
            print(f"  {split}: {int((domain_df['split'] == split).sum())} samples")


def _build_index(
    data_root: str,
    output_file: str,
    domains: List[str],
    train_ratio: float,
    val_ratio: float,
    test_ratio: float,
    seed: int
) -> None:
    """Scan, split and save an index for a ``domain/class/image`` directory layout."""
    samples_df = _collect_samples(Path(data_root), domains)
    final_df = _assign_splits(samples_df, domains, train_ratio, val_ratio, test_ratio, seed)
    _save_and_report(final_df, output_file, domains)


def build_officehome_index(
    data_root: str,
    output_file: str,
    train_ratio: float = 0.7,
    val_ratio: float = 0.15,
    test_ratio: float = 0.15,
    seed: int = 42
) -> None:
    """Build index for Office-Home dataset.

    Scans ``<data_root>/<domain>/<class_name>/<image>`` for the domains
    ``Art``, ``Clipart``, ``Product`` and ``RealWorld``. Missing domain folders
    are skipped with a warning. Labels are shared across domains. Each domain
    is split into train/val/test, stratified by label. The result is written
    as a CSV with columns ``img_path, label, domain, class_name, split``.

    Args:
        data_root: Root directory containing domain folders
        output_file: Output CSV file path (parent directories are created)
        train_ratio: Training set ratio
        val_ratio: Validation set ratio
        test_ratio: Test set ratio
        seed: Random seed for splitting

    Raises:
        ValueError: If no images are found. ``train_test_split`` may also
            raise ValueError, e.g. when a class has too few images to stratify.
    """
    # NOTE(review): the official Office-Home archive may name the last domain
    # "Real World" (with a space) -- folder names must match exactly.
    _build_index(
        data_root,
        output_file,
        ["Art", "Clipart", "Product", "RealWorld"],
        train_ratio,
        val_ratio,
        test_ratio,
        seed
    )


def build_domainnet_index(
    data_root: str,
    output_file: str,
    train_ratio: float = 0.7,
    val_ratio: float = 0.15,
    test_ratio: float = 0.15,
    seed: int = 42
) -> None:
    """Build index for DomainNet dataset.

    Scans ``<data_root>/<domain>/<class_name>/<image>`` for the six DomainNet
    domains. Missing domain folders are skipped with a warning. Labels come
    from one class-to-index mapping built over all domains. Each domain is
    split into train/val/test, stratified by label. The result is written as
    a CSV with columns ``img_path, label, domain, class_name, split``.

    Args:
        data_root: Root directory containing domain folders
        output_file: Output CSV file path (parent directories are created)
        train_ratio: Training set ratio
        val_ratio: Validation set ratio
        test_ratio: Test set ratio
        seed: Random seed for splitting

    Raises:
        ValueError: If no images are found. ``train_test_split`` may also
            raise ValueError, e.g. when a class has too few images to stratify.
    """
    _build_index(
        data_root,
        output_file,
        ["clipart", "infograph", "painting", "quickdraw", "real", "sketch"],
        train_ratio,
        val_ratio,
        test_ratio,
        seed
    )


def main(argv=None) -> None:
    """Command-line entry point: parse arguments and build the requested index.

    Args:
        argv: Argument list to parse; ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Raises:
        ValueError: If any ratio is outside [0, 1] (this also rejects NaN), or
            if the ratios do not sum to 1.0.
    """
    parser = argparse.ArgumentParser(description="Build dataset indices")
    parser.add_argument("--dataset", type=str, required=True,
                       choices=["officehome", "domainnet"],
                       help="Dataset name")
    parser.add_argument("--data-root", type=str, required=True,
                       help="Root directory containing domain folders")
    parser.add_argument("--output", type=str, required=True,
                       help="Output CSV file path")
    parser.add_argument("--train-ratio", type=float, default=0.7,
                       help="Training set ratio")
    parser.add_argument("--val-ratio", type=float, default=0.15,
                       help="Validation set ratio")
    parser.add_argument("--test-ratio", type=float, default=0.15,
                       help="Test set ratio")
    parser.add_argument("--seed", type=int, default=42,
                       help="Random seed for splitting")

    args = parser.parse_args(argv)

    # Validate each ratio individually first: e.g. 1.2 / -0.1 / -0.1 sums to
    # 1.0, and NaN slips through the tolerance comparison below.
    ratios = {"train": args.train_ratio, "val": args.val_ratio, "test": args.test_ratio}
    for name, value in ratios.items():
        if not 0.0 <= value <= 1.0:
            raise ValueError(f"--{name}-ratio must be in [0, 1], got {value}")

    total_ratio = args.train_ratio + args.val_ratio + args.test_ratio
    if abs(total_ratio - 1.0) > 1e-6:
        raise ValueError(f"Ratios must sum to 1.0, got {total_ratio}")

    # Dispatch on the dataset name; argparse `choices` guarantees a known key.
    builders = {
        "officehome": build_officehome_index,
        "domainnet": build_domainnet_index,
    }
    builders[args.dataset](
        args.data_root,
        args.output,
        args.train_ratio,
        args.val_ratio,
        args.test_ratio,
        args.seed
    )


if __name__ == "__main__":
    # Run the command-line interface only when executed as a script, not on import.
    main()
