"""
Multi-Seed Model Training Script

By default, this script trains a model 20 times, once for each random seed from
23 to 42 (configurable via --start_seed/--end_seed), saving calibration and test
data for each seed to enable downstream evaluation.
Only the weights of the model trained with the final seed (42 by default) are saved.

Supports datasets: COVID-19, MovieLens

Usage:
    python train_multiple_seeds.py --dataset covid --num_epochs 10
    python train_multiple_seeds.py --dataset movielens --num_epochs 20
"""

import os
import argparse
import numpy as np
from typing import Dict, Optional
from Covid_data import get_covid_data
from MovieLens_data import get_movielens_data


def train_multiple_seeds(
    dataset: str,
    data_dir: Optional[str] = None,
    data_root: Optional[str] = None,
    output_dir: Optional[str] = None,
    start_seed: int = 23,
    end_seed: int = 42,
    num_epochs: int = 10,
    batch_size: int = 32
) -> Dict[int, Dict[str, tuple]]:
    """
    Train model with multiple random seeds and save results.

    For every seed in ``[start_seed, end_seed]`` the dataset-specific loader
    is called with ``train_model=True``. It returns calibration and test
    probabilities and labels, which are written to
    ``<output_dir>/probabilities/seed_<seed>.npz``. Model weights are requested
    only for the last seed, at ``<output_dir>/models/seed_<end_seed>.pth``.

    Args:
        dataset: Dataset name ('covid' or 'movielens')
        data_dir: Root directory of dataset (for COVID)
        data_root: Data root directory (for MovieLens)
        output_dir: Directory to save results (auto-set if None)
        start_seed: Starting seed value (inclusive)
        end_seed: Ending seed value (inclusive)
        num_epochs: Number of training epochs per seed
        batch_size: Batch size for training

    Returns:
        Dictionary mapping each seed to a dict with:
        'cal_shape' and 'test_shape' (``(probs_shape, labels_shape)`` pairs,
        for verification); 'output_file' (path of the saved .npz); and
        'model_path' (weights path passed to the loader for the last seed,
        otherwise None).

    Raises:
        ValueError: If ``dataset`` is unknown or ``start_seed > end_seed``.
            Both checks run before any directory is created.
    """
    default_output_dirs = {
        "covid": "results/Covid_data",
        "movielens": "results/MovieLens_data",
    }

    # Validate before touching the filesystem. Previously an explicit
    # output_dir let directories be created before the unknown dataset
    # was reported, and an empty seed range silently returned {}.
    if dataset not in default_output_dirs:
        raise ValueError(f"Unknown dataset: {dataset}")
    if start_seed > end_seed:
        raise ValueError(
            f"start_seed ({start_seed}) must be <= end_seed ({end_seed})"
        )

    if output_dir is None:
        output_dir = default_output_dirs[dataset]

    # Create output directories
    prob_dir = os.path.join(output_dir, "probabilities")
    model_dir = os.path.join(output_dir, "models")
    os.makedirs(prob_dir, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)

    total_iterations = end_seed - start_seed + 1
    dataset_name = "COVID-19" if dataset == "covid" else "MovieLens"
    print("\n" + "=" * 80)
    print(f"Multi-Seed {dataset_name} Model Training")
    print("=" * 80)
    print(f"Seeds: {start_seed} to {end_seed} ({total_iterations} total)")
    print(f"Output directory: {output_dir}")
    print(f"Epochs per seed: {num_epochs}")
    print(f"Batch size: {batch_size}")
    print("=" * 80 + "\n")

    results_summary = {}

    for iteration, seed in enumerate(range(start_seed, end_seed + 1), start=1):
        print("\n" + "=" * 80)
        print(f"SEED {seed} - Iteration {iteration}/{total_iterations}")
        print("=" * 80)

        # Only the last seed's weights are kept. Name the file after that
        # seed (not a hard-coded 42) so custom seed ranges are labelled correctly.
        model_path = (
            os.path.join(model_dir, f"seed_{seed}.pth") if seed == end_seed else None
        )

        # Arguments shared by both dataset loaders.
        loader_kwargs = dict(
            model_path=model_path,
            train_model=True,
            num_epochs=num_epochs,
            batch_size=batch_size,
            seed=seed,
        )
        if dataset == "covid":
            cal_probs, cal_labels, test_probs, test_labels = get_covid_data(
                data_dir=data_dir, **loader_kwargs
            )
        else:  # "movielens" -- dataset was validated above
            cal_probs, cal_labels, test_probs, test_labels = get_movielens_data(
                data_root=data_root, **loader_kwargs
            )

        # Save probabilities and labels
        output_file = os.path.join(prob_dir, f"seed_{seed}.npz")
        np.savez(
            output_file,
            cal_probs=cal_probs,
            cal_labels=cal_labels,
            test_probs=test_probs,
            test_labels=test_labels
        )

        # Store results for summary
        results_summary[seed] = {
            'cal_shape': (cal_probs.shape, cal_labels.shape),
            'test_shape': (test_probs.shape, test_labels.shape),
            'output_file': output_file,
            'model_path': model_path,
        }

        print(f"\n✓ Saved probabilities to: {output_file}")
        if model_path:
            print(f"✓ Saved model weights to: {model_path}")

    return results_summary


def print_summary(results_summary: Dict[int, Dict[str, tuple]], output_dir: str):
    """
    Print final summary of all training runs.

    Seed numbers, file names and the downstream usage example are derived from
    ``results_summary`` instead of being hard-coded to seeds 23-42. This keeps
    the summary accurate for custom seed ranges.

    Args:
        results_summary: Dictionary mapping seed -> per-seed results. Each
            entry must contain 'cal_shape' and 'test_shape' (each a
            ``(probs_shape, labels_shape)`` pair). 'output_file' and
            'model_path' are used when present.
        output_dir: Output directory path
    """
    separator = "=" * 80
    print("\n\n" + separator)
    print("TRAINING COMPLETED - SUMMARY")
    print(separator)

    if not results_summary:
        # Nothing to report; picking a reference seed below would fail.
        print("\nNo seeds were processed; nothing to summarize.")
        print(separator + "\n")
        return

    seeds = sorted(results_summary)
    first_seed, last_seed = seeds[0], seeds[-1]
    prob_dir = os.path.join(output_dir, "probabilities")
    model_dir = os.path.join(output_dir, "models")

    reference = results_summary[first_seed]
    cal_probs_shape, cal_labels_shape = reference['cal_shape']
    test_probs_shape, test_labels_shape = reference['test_shape']
    # Check, rather than assume, that every seed produced identically shaped data.
    shapes_consistent = all(
        entry['cal_shape'] == reference['cal_shape']
        and entry['test_shape'] == reference['test_shape']
        for entry in results_summary.values()
    )

    print(f"\nTotal seeds processed: {len(results_summary)}")
    if shapes_consistent:
        print("\nData shapes (consistent across all seeds):")
    else:
        print(f"\nData shapes (seed {first_seed}; WARNING: shapes differ across seeds):")
    print(f"  Calibration probabilities: {cal_probs_shape}")
    print(f"  Calibration labels: {cal_labels_shape}")
    print(f"  Test probabilities: {test_probs_shape}")
    print(f"  Test labels: {test_labels_shape}")

    # Prefer the paths actually written by training; fall back to the naming scheme.
    prob_files = [
        os.path.basename(results_summary[seed].get('output_file') or f"seed_{seed}.npz")
        for seed in seeds
    ]
    if len(prob_files) > 3:
        prob_files = prob_files[:2] + ["..."] + prob_files[-1:]
    model_files = [
        os.path.basename(results_summary[seed]['model_path'])
        for seed in seeds
        if results_summary[seed].get('model_path')
    ]

    print("\nOutput structure:")
    print(f"  Probabilities: {prob_dir}/")
    for name in prob_files:
        print(f"    - {name}")
    print(f"  Model: {model_dir}/")
    for name in model_files:
        print(f"    - {name}")

    # Use a range() literal when seeds are contiguous (the normal case),
    # otherwise an explicit list of the processed seeds.
    if seeds == list(range(first_seed, last_seed + 1)):
        seed_iterable = f"range({first_seed}, {last_seed + 1})"
    else:
        seed_iterable = repr(seeds)

    print("\nDownstream usage example:")
    print("```python")
    print("import numpy as np")
    # repr() gives a valid Python string literal even for paths with backslashes/quotes.
    print(f"prob_dir = {prob_dir!r}")
    print(f"for seed in {seed_iterable}:")
    print("    data = np.load(f'{prob_dir}/seed_{seed}.npz')")
    print("    cal_probs = data['cal_probs']")
    print("    cal_labels = data['cal_labels']")
    print("    test_probs = data['test_probs']")
    print("    test_labels = data['test_labels']")
    print("    # Your evaluation code here")
    print("```")

    print("\n" + separator)
    print("All training runs completed successfully!")
    print(separator + "\n")


def main():
    """
    Main entry point with CLI argument parsing.

    Steps:
    1. Parse arguments and fill in dataset-specific defaults.
    2. Validate the seed range, epoch count and batch size.
    3. Run ``train_multiple_seeds`` and print the summary.

    The output directory is resolved exactly once and passed to both training
    and the summary, so the reported location always matches where files were
    written.
    """
    parser = argparse.ArgumentParser(
        description="Train models with multiple random seeds",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Example usage:
  # COVID-19 dataset
  python train_multiple_seeds.py --dataset covid --num_epochs 10
  python train_multiple_seeds.py --dataset covid --data_dir /path/to/dataset --num_epochs 10
  
  # MovieLens dataset
  python train_multiple_seeds.py --dataset movielens --num_epochs 20
  python train_multiple_seeds.py --dataset movielens --data_root /path/to/root --num_epochs 20
        """
    )

    parser.add_argument(
        '--dataset',
        type=str,
        required=True,
        choices=['covid', 'movielens'],
        help='Dataset to use: covid or movielens'
    )

    parser.add_argument(
        '--data_dir',
        type=str,
        default=None,
        help='[COVID] Root directory of COVID-19 dataset (default: /path/to/your/covid19-radiography-database/COVID-19_Radiography_Dataset)'
    )

    parser.add_argument(
        '--data_root',
        type=str,
        default=None,
        help='[MovieLens] Root directory for MovieLens data (default: /path/to/your/data)'
    )

    parser.add_argument(
        '--output_dir',
        type=str,
        default=None,
        help='Output directory for results (default: results/{Covid_data|MovieLens_data})'
    )

    parser.add_argument(
        '--start_seed',
        type=int,
        default=23,
        help='Starting seed value (default: 23)'
    )

    parser.add_argument(
        '--end_seed',
        type=int,
        default=42,
        help='Ending seed value (default: 42)'
    )

    parser.add_argument(
        '--num_epochs',
        type=int,
        default=None,
        help='Number of training epochs per seed (default: 10 for covid, 20 for movielens)'
    )

    parser.add_argument(
        '--batch_size',
        type=int,
        default=None,
        help='Batch size for training (default: 32 for covid, 256 for movielens)'
    )

    args = parser.parse_args()

    # Set dataset-specific defaults (None means "not given on the command line")
    if args.dataset == "covid":
        if args.data_dir is None:
            args.data_dir = '/path/to/your/covid19-radiography-database/COVID-19_Radiography_Dataset'
        if args.num_epochs is None:
            args.num_epochs = 10
        if args.batch_size is None:
            args.batch_size = 32
    elif args.dataset == "movielens":
        if args.data_root is None:
            args.data_root = '/path/to/your/data'
        if args.num_epochs is None:
            args.num_epochs = 20
        if args.batch_size is None:
            args.batch_size = 256

    # Validate arguments (parser.error exits with status 2)
    if args.start_seed > args.end_seed:
        parser.error(f"start_seed ({args.start_seed}) must be <= end_seed ({args.end_seed})")

    if args.num_epochs <= 0:
        parser.error(f"num_epochs must be positive, got {args.num_epochs}")

    if args.batch_size <= 0:
        parser.error(f"batch_size must be positive, got {args.batch_size}")

    # Resolve the output directory once and hand the same value to training
    # and to the summary. Previously the summary recomputed it with a
    # truthiness check, so e.g. --output_dir "" reported the default
    # directory while training wrote elsewhere.
    output_dir = args.output_dir
    if output_dir is None:
        output_dir = (
            "results/Covid_data" if args.dataset == "covid" else "results/MovieLens_data"
        )

    # Run training
    results_summary = train_multiple_seeds(
        dataset=args.dataset,
        data_dir=args.data_dir,
        data_root=args.data_root,
        output_dir=output_dir,
        start_seed=args.start_seed,
        end_seed=args.end_seed,
        num_epochs=args.num_epochs,
        batch_size=args.batch_size
    )

    # Print summary
    print_summary(results_summary, output_dir)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()

