#!/usr/bin/env python
"""
Compute all dimensionality measures for CPRO experiments using SVD-based approach.

This script processes experiment results and computes various dimensionality measures:
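- Global dimensionality
- Task dimensionality (overall, collapse-only, integration-only)
- Stimulus dimensionality (overall, collapse-only, integration-only)
- Motor response dimensionality (overall, collapse-only, integration-only)
- Rule satisfaction dimensionality (overall, collapse-only, integration-only)
- Rule-domain dimensionalities (logical, sensory, motor; only with --compute-rule-domains)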

Results are saved as numpy arrays in pickle format for further analysis.
"""

import numpy as np
import pickle
import argparse
from pathlib import Path
import sys
import logging
from tqdm import tqdm
from datetime import datetime

# Add src directory to path
sys.path.append('./src')

from analysis_utils import (
    get_global_dimensionality,
    get_task_dimensionality,
    get_stimulus_dimensionality,
    get_motor_response_dimensionality,
    get_rule_satisfaction_dimensionality,
    get_logical_rule_dimensionality,
    get_sensory_rule_dimensionality,
    get_motor_rule_dimensionality,
    get_all_rule_domain_dimensionalities
)

def setup_logging(output_dir):
    """
    Configure logging to a timestamped file in ``output_dir`` and to the console.

    The log file is named ``dimensionality_computation_<YYYYmmdd_HHMMSS>.log``.
    Configuration goes through ``logging.basicConfig``, which leaves the root
    logger untouched if it already has handlers.

    Args:
        output_dir: Directory (``pathlib.Path``-like, supporting ``/``) that
            receives the log file; it must already exist.

    Returns:
        The logger named after this module.
    """
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_path = output_dir / f"dimensionality_computation_{timestamp}.log"

    # One sink for the persistent record, one for live feedback on the console.
    sinks = [logging.FileHandler(log_path), logging.StreamHandler()]
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=sinks,
        level=logging.INFO,
    )

    module_logger = logging.getLogger(__name__)
    return module_logger

def validate_experiment_files(results_dir, init_scales, training_modes, seeds, optimizer):
    """
    List the hidden-activity files that are expected but not present on disk.

    One file is expected for every (init_scale, training_mode, seed)
    combination, located at
    ``<results_dir>/cpu_experiment_<mode>_scale<scale>_<optimizer>_seed<seed>/hidden_scale<scale>_<optimizer>_seed<seed>.pt``.

    Args:
        results_dir: Batch directory (``pathlib.Path``) holding the per-run folders.
        init_scales: Initial scales to check (outermost iteration order).
        training_modes: Training-mode names to check.
        seeds: Seeds to check (innermost iteration order).
        optimizer: Optimizer name embedded in the folder and file names.

    Returns:
        List of missing file paths as strings, in iteration order; empty when
        every expected file exists.
    """
    def expected_hidden_file(init_scale, training_mode, seed):
        exp_dir = results_dir / f'cpu_experiment_{training_mode}_scale{init_scale}_{optimizer}_seed{seed}'
        return exp_dir / f'hidden_scale{init_scale}_{optimizer}_seed{seed}.pt'

    candidates = (
        expected_hidden_file(scale, mode, seed)
        for scale in init_scales
        for mode in training_modes
        for seed in seeds
    )
    return [str(path) for path in candidates if not path.exists()]

def compute_single_experiment_dimensionalities(hidden_file, trial_length=10, compute_rule_domains=False, logger=None):
    """
    Compute all dimensionality measures for a single experiment.
    
    Args:
        hidden_file: Path to hidden activities file
        trial_length: Number of timepoints to analyze
        logger: Logger instance for progress tracking
    
    Returns:
        Dictionary with all dimensionality arrays
    """
    try:
        results = {
            # 'global_dim': get_global_dimensionality(hidden_file, trial_length),
            # 'task_dim': get_task_dimensionality(hidden_file, trial_length),
            # 'task_collapse_dim': get_task_dimensionality(hidden_file, trial_length, collapse_only=True),
            # 'task_integration_dim': get_task_dimensionality(hidden_file, trial_length, integration_only=True),
            # 'stimulus_dim': get_stimulus_dimensionality(hidden_file, trial_length),
            # 'stimulus_collapse_dim': get_stimulus_dimensionality(hidden_file, trial_length, collapse_only=True),
            # 'stimulus_integration_dim': get_stimulus_dimensionality(hidden_file, trial_length, integration_only=True),
            # 'motor_response_dim': get_motor_response_dimensionality(hidden_file, trial_length),
            # 'motor_response_collapse_dim': get_motor_response_dimensionality(hidden_file, trial_length, collapse_only=True),
            # 'motor_response_integration_dim': get_motor_response_dimensionality(hidden_file, trial_length, integration_only=True),
            # 'rule_satisfaction_dim': get_rule_satisfaction_dimensionality(hidden_file, trial_length),
            # 'rule_satisfaction_collapse_dim': get_rule_satisfaction_dimensionality(hidden_file, trial_length, collapse_only=True),
            # 'rule_satisfaction_integration_dim': get_rule_satisfaction_dimensionality(hidden_file, trial_length, integration_only=True)
        }
        
        if compute_rule_domains:
            rule_domain_results = get_all_rule_domain_dimensionalities(hidden_file, trial_length)
            results.update(rule_domain_results)
        
        return results
    
    except Exception as e:
        if logger:
            logger.error(f"Error processing {hidden_file}: {str(e)}")
        return None

def _experiment_hidden_file(results_dir, training_mode, init_scale, optimizer, seed):
    """Return the hidden-activity file path for one (mode, scale, optimizer, seed) run."""
    exp_dir = results_dir / f'cpu_experiment_{training_mode}_scale{init_scale}_{optimizer}_seed{seed}'
    return exp_dir / f'hidden_scale{init_scale}_{optimizer}_seed{seed}.pt'


def _experiment_is_complete(dimensionality_arrays, index):
    """Return True when every tracked measure already holds a first-timepoint value at ``index``."""
    first_timepoint = index + (0,)
    return not any(np.isnan(array[first_timepoint]) for array in dimensionality_arrays.values())


def _save_dimensionality_arrays(dimensionality_arrays, output_dir, logger=None):
    """Pickle each measure array to ``<output_dir>/<measure_name>_arr.pkl`` via a temp file."""
    for measure_name, array in dimensionality_arrays.items():
        pkl_file = output_dir / f'{measure_name}_arr.pkl'
        # Write to a sibling temp file and rename, so a job killed mid-write
        # cannot leave a truncated pickle that would break a later --resume.
        tmp_file = pkl_file.with_name(pkl_file.name + '.tmp')
        with open(tmp_file, 'wb') as f:
            pickle.dump(array, f)
        tmp_file.replace(pkl_file)
        if logger is not None:
            logger.info(f"Saved {measure_name} to {pkl_file}")


def main():
    """
    Command-line entry point: compute dimensionality measures for a batch of experiments.

    Steps: parse arguments, check that the expected hidden-activity files
    exist, allocate one NaN-filled array of shape
    ``(n_scales, n_training_modes, n_seeds, trial_length)`` per measure,
    optionally reload previously saved arrays (``--resume``), fill the arrays
    one experiment at a time, and pickle the arrays plus a metadata dict into
    the output directory. Arrays are also saved after every 50 completed
    experiments.

    Returns:
        None. Returns early (after logging) when files are missing without
        ``--resume``, when ``--validate-only`` is given, or when a resumed
        array does not match the current configuration.
    """
    parser = argparse.ArgumentParser(description='Compute all dimensionality measures for CPRO experiments')

    # Project configuration
    PROJECT_PATH = '/home/ln275/f_mc1689_1/cpro-rnn/docs/scripts/'

    # Required arguments
    parser.add_argument('--batch-dir', type=str, required=True,
                        help='Batch directory name (e.g., slurm_cpu_batch_20250604_111254)')
    parser.add_argument('--output-subdir', type=str, required=True,
                        help='Output subdirectory name (e.g., analysis_outputs_svd)')

    # Optional path overrides (in case you want to use different base paths)
    parser.add_argument('--project-path', type=str, default=PROJECT_PATH,
                        help=f'Project base path (default: {PROJECT_PATH})')
    parser.add_argument('--results-base', type=str, default='results',
                        help='Results base directory name (default: results)')

    # Optional arguments with defaults
    parser.add_argument('--optimizer', type=str, default='sgd',
                        choices=['sgd', 'adamw'], help='Optimizer used in experiments')
    parser.add_argument('--trial-length', type=int, default=10,
                        help='Number of timepoints to analyze (default: 10)')
    parser.add_argument('--init-scales', type=float, nargs='+',
                        default=[0.01, 0.041617914502878176, 0.17320508075688776, 0.7208434242404266, 3.0],
                        help='List of initial scales used in experiments')
    parser.add_argument('--training-modes', type=str, nargs='+',
                        default=['minimal', 'balanced_16', 'balanced_32', 'balanced_48', 'maximal'],
                        help='List of training modes used in experiments')
    parser.add_argument('--seeds', type=int, nargs='+',
                        default=[42, 123, 234, 345, 456, 567, 678, 789, 890, 901],
                        help='List of seeds used in experiments')
    parser.add_argument('--resume', action='store_true',
                        help='Resume computation from existing files')
    parser.add_argument('--validate-only', action='store_true',
                        help="Only validate that all files exist, don't compute")

    # Rule-domain dimensionalities computation
    parser.add_argument('--compute-rule-domains', action='store_true',
                        help='Compute rule domain dimensionalities (logical, sensory, motor)')

    args = parser.parse_args()

    # Construct full paths
    project_path = Path(args.project_path)
    results_dir = project_path / args.results_base / args.batch_dir
    output_dir = project_path / args.results_base / args.output_subdir
    output_dir.mkdir(parents=True, exist_ok=True)

    # Setup logging
    logger = setup_logging(output_dir)
    logger.info("Starting dimensionality computation")
    logger.info(f"Results directory: {results_dir}")
    logger.info(f"Output directory: {output_dir}")
    logger.info(f"Configuration: {len(args.init_scales)} scales, {len(args.training_modes)} training modes, {len(args.seeds)} seeds")

    # Validate experiment files
    logger.info("Validating experiment files...")
    missing_files = validate_experiment_files(results_dir, args.init_scales, args.training_modes, args.seeds, args.optimizer)

    if missing_files:
        logger.error(f"Missing {len(missing_files)} experiment files:")
        for f in missing_files[:10]:  # Show first 10
            logger.error(f"  {f}")
        if len(missing_files) > 10:
            logger.error(f"  ... and {len(missing_files) - 10} more")

        if not args.resume:
            logger.error("Aborting. Use --resume to skip missing files.")
            return
        logger.warning(f"Continuing with {len(missing_files)} missing files...")

    if args.validate_only:
        logger.info("Validation complete. Exiting.")
        return

    # One NaN-filled array per measure; NaN marks "not computed yet".
    array_shape = (len(args.init_scales), len(args.training_modes), len(args.seeds), args.trial_length)

    measure_names = ['global_dim']
    for family in ('task', 'stimulus', 'motor_response', 'rule_satisfaction'):
        measure_names += [f'{family}_dim', f'{family}_collapse_dim', f'{family}_integration_dim']
    if args.compute_rule_domains:
        measure_names += ['logical_rule_dim', 'sensory_rule_dim', 'motor_rule_dim']
        logger.info("Rule domain computation enabled")

    dimensionality_arrays = {name: np.full(array_shape, np.nan) for name in measure_names}

    # Load existing arrays if resuming
    if args.resume:
        for measure_name in measure_names:
            pkl_file = output_dir / f'{measure_name}_arr.pkl'
            if not pkl_file.exists():
                continue
            logger.info(f"Loading existing {measure_name} array")
            # NOTE: unpickling can execute arbitrary code; only resume from an
            # output directory that this script itself wrote.
            with open(pkl_file, 'rb') as f:
                loaded = pickle.load(f)
            # An array saved under a different scales/modes/seeds/trial-length
            # configuration would be mis-indexed below, so refuse to continue.
            if getattr(loaded, 'shape', None) != array_shape:
                logger.error(
                    f"Existing {measure_name} array in {pkl_file} has shape "
                    f"{getattr(loaded, 'shape', None)}, expected {array_shape}. "
                    "Aborting to avoid mixing results from different configurations."
                )
                return
            dimensionality_arrays[measure_name] = loaded

    # Main computation loop
    total_experiments = len(args.init_scales) * len(args.training_modes) * len(args.seeds)
    completed = 0
    skipped = 0
    errors = 0

    logger.info(f"Starting computation for {total_experiments} experiments...")

    with tqdm(total=total_experiments, desc="Processing experiments") as pbar:
        for sc_idx, init_scale in enumerate(args.init_scales):
            for t_idx, training_mode in enumerate(args.training_modes):
                for s_idx, seed in enumerate(args.seeds):
                    index = (sc_idx, t_idx, s_idx)

                    # When resuming, skip only if *every* requested measure is
                    # already filled, so newly requested measures (e.g. rule
                    # domains added on a later run) still get computed.
                    if args.resume and _experiment_is_complete(dimensionality_arrays, index):
                        skipped += 1
                        pbar.update(1)
                        continue

                    hidden_file = _experiment_hidden_file(
                        results_dir, training_mode, init_scale, args.optimizer, seed)

                    if not hidden_file.exists():
                        logger.warning(f"Missing file: {hidden_file}")
                        errors += 1
                        pbar.update(1)
                        continue

                    # Compute dimensionalities for this experiment (exactly once).
                    results = compute_single_experiment_dimensionalities(
                        hidden_file, args.trial_length, args.compute_rule_domains, logger)

                    if results is None:
                        errors += 1
                        pbar.update(1)
                        continue

                    # Store results in arrays
                    for measure_name, values in results.items():
                        target = dimensionality_arrays.get(measure_name)
                        if target is None:
                            # Don't lose the whole batch over a measure we have no array for.
                            logger.warning(f"Ignoring unexpected measure '{measure_name}' for {hidden_file}")
                            continue
                        target[index] = values

                    completed += 1
                    pbar.update(1)

                    # Save intermediate results every 50 experiments
                    if completed % 50 == 0:
                        logger.info(f"Saving intermediate results after {completed} experiments...")
                        _save_dimensionality_arrays(dimensionality_arrays, output_dir)

    # Final save
    logger.info("Saving final results...")
    _save_dimensionality_arrays(dimensionality_arrays, output_dir, logger)

    # Save metadata
    metadata = {
        'init_scales': args.init_scales,
        'training_modes': args.training_modes,
        'seeds': args.seeds,
        'optimizer': args.optimizer,
        'trial_length': args.trial_length,
        'array_shape': array_shape,
        'compute_rule_domains': args.compute_rule_domains,
        'measures': list(dimensionality_arrays),
        'completed_experiments': completed,
        'skipped_experiments': skipped,
        'failed_experiments': errors,
        'computation_date': datetime.now().isoformat()
    }

    metadata_file = output_dir / 'dimensionality_metadata.pkl'
    with open(metadata_file, 'wb') as f:
        pickle.dump(metadata, f)

    logger.info("Computation complete!")
    logger.info(f"Completed: {completed}, Skipped: {skipped}, Errors: {errors}")
    logger.info(f"Results saved to: {output_dir}")

# Run the command-line workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()