"""Main simulation runner for MMD/Wald/CODITE experiments."""

import numpy as np
from numpy.random import SeedSequence
import pandas as pd
import torch
import time
import pickle
import os
import logging
from datetime import datetime
from typing import Dict, List, Tuple

from config import SimulationConfig
from codite_mmd_test import codite_mmd_test
from proposed_test import proposed_mmd_test, proposed_wald_test
from utils import validate_device

import argparse

def parse_args():
    """Define and parse the command-line options for the simulation runner.

    Reads ``sys.argv``. Options are registered in this order: input data
    paths, experiment parameters (job id, replicate/bootstrap counts,
    significance level), experimental factors (test types, scenarios, sample
    sizes, misspecification pairs), output options, and run flags
    (``--dry_run``, ``--use_oracle_propensity``).

    Returns:
        argparse.Namespace: the parsed options.
    """
    cli = argparse.ArgumentParser(description='Run MMD/Wald/CODITE simulation experiments')

    # Input .npy files, registered as (flag, default path, help text).
    for flag, default_path, help_text in (
        ('--X_data_path', './sim_data/null_unedited.npy', 'Path to X data file'),
        ('--Y_null_path', './sim_data/null_bright_rand.npy', 'Path to Y null data file'),
        ('--Y_alt_path', './sim_data/alternate_bright_rand_rot_det.npy',
         'Path to Y alternative data file'),
        ('--A_data_path', './sim_data/trt_assign.npy', 'Path to treatment assignment data file'),
        ('--propensity_path', './sim_data/propensities.npy',
         'Path to known propensity scores file'),
    ):
        cli.add_argument(flag, default=default_path, help=help_text)

    # Experiment parameters.
    cli.add_argument('--job_id', type=int, default=None,
                     help='Unique job ID for parallel runs (auto-generated if not specified)')
    cli.add_argument('--n_replicates', type=int, default=1000,
                     help='Number of replicates per configuration')
    cli.add_argument('--n_bootstrap', type=int, default=1000,
                     help='Number of bootstrap/permutation samples')
    cli.add_argument('--alpha', type=float, default=0.05, help='Significance level')

    # Experimental factors (each accepts one or more values).
    test_names = ('wald', 'codite', 'mmd')
    cli.add_argument('--test_types', nargs='+', default=list(test_names),
                     choices=list(test_names), help='Test types to run')
    scenario_names = ('null', 'alternate')
    cli.add_argument('--scenarios', nargs='+', default=list(scenario_names),
                     choices=list(scenario_names), help='Scenarios to test')
    cli.add_argument('--sample_sizes', nargs='+', type=int, default=[250, 500, 750, 1000],
                     help='Sample sizes to test')
    misspec_help = ('Misspecification combinations as "ps,om" pairs. '
                    'Examples: "False,False" "False,True" "True,False" "True,True". '
                    'If not specified, uses all 4 combinations from config.py')
    cli.add_argument('--misspec_combinations', nargs='+', help=misspec_help)

    # Output options.
    cli.add_argument('--results_dir', default=None,
                     help='Results directory (auto-generated if not specified)')
    cli.add_argument('--log_level', default='INFO',
                     choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
                     help='Logging level')

    # Run flags.
    cli.add_argument('--dry_run', action='store_true',
                     help='Test setup without running full experiments')
    cli.add_argument('--use_oracle_propensity', action='store_true',
                     help="If set, uses propensity scores from file instead of estimating them.")

    return cli.parse_args()

def setup_logging(results_dir: str, level: "int | str" = logging.INFO) -> logging.Logger:
    """Configure root logging to a timestamped file in ``results_dir`` and the console.

    Args:
        results_dir: Existing directory in which ``simulation_<YYYYmmdd_HHMMSS>.log``
            is created.
        level: Root logging level (an int such as ``logging.DEBUG`` or a level
            name such as ``"DEBUG"``). Defaults to ``logging.INFO``, which matches
            the previous hard-coded behaviour.

    Returns:
        The logger for this module.

    Notes:
        ``force=True`` removes and closes any handlers already attached to the
        root logger, so calling this again reconfigures logging from scratch.
        The ``lightgbm`` logger is raised to ERROR. If ``optuna`` is importable,
        its verbosity is lowered to WARNING.
    """
    # Plain basicConfig setup so it works the same under SLURM.
    log_file = os.path.join(results_dir, f"simulation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")

    logging.basicConfig(
        level=level,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler(),
        ],
        force=True,  # Override any existing root handlers
    )

    # Quiet external libraries.
    logging.getLogger('lightgbm').setLevel(logging.ERROR)
    try:
        import optuna
    except ImportError:
        pass  # optuna is optional; nothing to quiet
    else:
        optuna.logging.set_verbosity(optuna.logging.WARNING)

    return logging.getLogger(__name__)

def sample_data(X_data, A_data, Y_data, propensity_data, n_samples, random_state):
    """Draw ``n_samples`` row indices with replacement and apply them to every input.

    The same index vector is used for all four arrays, so rows stay aligned
    across covariates, treatment, outcome and propensity. The draw is fully
    determined by ``random_state`` (seed for ``np.random.default_rng``).

    Returns:
        tuple: ``(X, A, Y, propensity)`` resampled subsets.
    """
    generator = np.random.default_rng(seed=random_state)
    # Sampling is with replacement (bootstrap-style); replace=False was used previously.
    picked = generator.choice(len(X_data), size=n_samples, replace=True)
    return tuple(arr[picked] for arr in (X_data, A_data, Y_data, propensity_data))

def run_single_experiment(X, A, Y, propensity, test_type, misspec_ps, misspec_om, n_bootstrap, device, random_state, logger):
    """Run one hypothesis test on a single data sample.

    Args:
        X, A, Y: Sampled covariates, treatment assignments and outcomes,
            forwarded unchanged to the selected test.
        propensity: Known propensity scores, or ``None`` to let the test
            estimate them.
        test_type: ``"codite"``, ``"mmd"`` or ``"wald"``.
        misspec_ps: Whether to misspecify the propensity model.
        misspec_om: Whether to misspecify the outcome model.
        n_bootstrap: Bootstrap draws for the mmd/wald tests. CODITE receives
            ``int(1.5 * (n_bootstrap // 10))`` permutations.
        device: Forwarded to the test.
        random_state: Seed forwarded to the test.
        logger: Used to report failures.

    Returns:
        ``(test_statistic, p_value, success)``. If the test raises, or
        ``test_type`` is unknown, returns ``(nan, nan, False)``.
    """
    # Keyword arguments shared by all three test implementations.
    common_kwargs = dict(
        misspecify_propensity_model=misspec_ps,
        misspecify_outcome_model=misspec_om,
        device=device,
        random_state=random_state,
        propensity=propensity,
    )
    try:
        if test_type == "codite":
            # 1.5 * int yields a float in Python; a permutation count must be an integer.
            num_perm = int(1.5 * (n_bootstrap // 10))
            test_stat, p_val = codite_mmd_test(X, A, Y, num_perm=num_perm, **common_kwargs)
        elif test_type in ("mmd", "wald"):
            test_fn = proposed_mmd_test if test_type == "mmd" else proposed_wald_test
            # The rejection flag and info dict are not needed: the caller re-derives
            # rejection from p_val and alpha.
            test_stat, p_val, _rejected, _info = test_fn(
                X, A, Y, n_bootstrap=n_bootstrap, **common_kwargs
            )
        else:
            # Previously fell through and returned None, breaking the caller's unpacking.
            logger.error(f"Unknown test type: {test_type!r}")
            return np.nan, np.nan, False
    except Exception as e:
        # Best-effort per replicate: record the failure (with traceback) and continue.
        logger.exception(f"Error in {test_type} test: {e}")
        return np.nan, np.nan, False

    return test_stat, p_val, True

def save_checkpoint(results, results_dir, experiment_id, total_experiments, logger):
    """Persist a progress note and a pickle of all results collected so far.

    Writes ``<results_dir>/progress.txt`` and
    ``<results_dir>/checkpoints/checkpoint_<experiment_id>.pkl``. Each file is
    first written to a ``.tmp`` sibling and then moved into place with
    ``os.replace``. That way an interrupted job (e.g. a SLURM time limit)
    never leaves a truncated file under the final name.

    Any failure is logged and swallowed, so a checkpoint problem does not
    abort the simulation.
    """
    try:
        # Save progress info.
        progress_file = os.path.join(results_dir, "progress.txt")
        progress_tmp = progress_file + ".tmp"
        with open(progress_tmp, 'w') as f:
            f.write(f"Completed {experiment_id}/{total_experiments} experiments\n")
            f.write(f"Last update: {datetime.now()}\n")
        os.replace(progress_tmp, progress_file)

        # Save checkpoint.
        checkpoint_dir = os.path.join(results_dir, "checkpoints")
        os.makedirs(checkpoint_dir, exist_ok=True)

        filename = os.path.join(checkpoint_dir, f"checkpoint_{experiment_id}.pkl")
        checkpoint_tmp = filename + ".tmp"
        with open(checkpoint_tmp, 'wb') as f:
            pickle.dump(results, f)
        os.replace(checkpoint_tmp, filename)

        logger.info(f"Checkpoint saved: {experiment_id}/{total_experiments} experiments completed")
    except Exception as e:
        logger.error(f"Failed to save checkpoint: {e}")

def save_final_results(results, results_dir, logger):
    """Write final results as pickle and CSV, plus a per-configuration summary.

    Files written to ``results_dir``:
      * ``final_results.pkl``: the raw ``results`` list.
      * ``simulation_results.csv``: one row per result dict.
      * ``summary_statistics.csv``: only when there is at least one row. It is
        grouped by test_type/scenario/n_samples/misspec_ps/misspec_om with
        count/sum/mean of ``rejected``, mean of ``success``, and mean/std of
        ``test_statistic``, ``p_value`` and ``a_ratio``, rounded to 4 decimals.

    If any step fails, the error is logged and an emergency pickle
    ``emergency_results_<unix time>.pkl`` is attempted. Nothing is raised.
    """
    try:
        # Save as pickle.
        pickle_file = os.path.join(results_dir, "final_results.pkl")
        with open(pickle_file, 'wb') as f:
            pickle.dump(results, f)
        logger.info(f"Saved pickle results: {pickle_file}")

        # Save as CSV.
        df = pd.DataFrame(results)
        csv_file = os.path.join(results_dir, "simulation_results.csv")
        df.to_csv(csv_file, index=False)
        logger.info(f"Saved CSV results: {csv_file}")

        # Create summary.
        if len(df) > 0:
            # 'rejected' mixes booleans with NaN (failed runs), which yields an
            # object column. Cast to float so count/sum/mean are numeric; NaN
            # rows are still skipped by all three, exactly as before.
            summary_df = df.assign(rejected=df['rejected'].astype(float))
            summary = summary_df.groupby([
                'test_type', 'scenario', 'n_samples', 'misspec_ps', 'misspec_om'
            ]).agg({
                'rejected': ['count', 'sum', 'mean'],
                'success': 'mean',
                'test_statistic': ['mean', 'std'],
                'p_value': ['mean', 'std'],
                'a_ratio': ['mean', 'std']
            }).round(4)

            summary_file = os.path.join(results_dir, "summary_statistics.csv")
            summary.to_csv(summary_file)
            logger.info(f"Saved summary: {summary_file}")

    except Exception as e:
        logger.error(f"Failed to save results: {e}")
        # Try to at least save the raw results.
        try:
            emergency_file = os.path.join(results_dir, f"emergency_results_{int(time.time())}.pkl")
            with open(emergency_file, 'wb') as f:
                pickle.dump(results, f)
            logger.info(f"Saved emergency backup: {emergency_file}")
        except Exception:
            logger.error("Failed to save emergency backup")

def _parse_misspec_combinations(combo_strs: List[str]) -> List[Tuple[bool, bool]]:
    """Convert ``"ps,om"`` strings (e.g. ``"False,True"``) into ``(ps, om)`` bool pairs.

    Values are matched case-insensitively after stripping whitespace.

    Raises:
        ValueError: if an entry does not have exactly two comma-separated
            values, or a value is neither ``true`` nor ``false``.
    """
    combos = []
    for combo_str in combo_strs:
        parts = combo_str.split(',')
        if len(parts) != 2:
            raise ValueError(f"Invalid misspec combination format: '{combo_str}'. "
                             f"Expected format: 'ps,om' (e.g., 'False,False')")
        flags = []
        for part in parts:
            token = part.strip().lower()
            # Reject typos such as "Ture" instead of silently treating them as False.
            if token not in ('true', 'false'):
                raise ValueError(f"Invalid misspec value '{part.strip()}' in '{combo_str}'. "
                                 f"Expected 'True' or 'False'")
            flags.append(token == 'true')
        combos.append((flags[0], flags[1]))
    return combos


def _resolve_job_id(cli_job_id):
    """Return the job id: the CLI value, else a SLURM id, else a random id.

    The SLURM id is ``SLURM_ARRAY_TASK_ID`` if set, otherwise ``SLURM_JOB_ID``.
    Uses ``print`` because logging is not configured yet when this runs.
    """
    if cli_job_id is not None:
        print(f"Using provided job_id: {cli_job_id}")
        return cli_job_id

    slurm_id = os.environ.get("SLURM_ARRAY_TASK_ID") or os.environ.get("SLURM_JOB_ID")
    if slurm_id:
        job_id = int(slurm_id)
        print(f"Using SLURM ID: {job_id}")
        return job_id

    # Fallback for local runs (add randomness to avoid PID/Time collisions).
    import secrets
    job_id = secrets.randbelow(1_000_000_000)
    print(f"!!! GENERATED RANDOM JOB_ID: {job_id} !!!")
    print("!!! SAVE THIS ID FOR REPRODUCIBILITY !!!")
    return job_id


def main():
    """Run the full MMD/Wald/CODITE simulation grid.

    Workflow:
      1. Parse CLI arguments and build a ``SimulationConfig``. With
         ``--dry_run``, print the experiment count and stop.
      2. Resolve a job id, create ``<results_dir>_job<id>`` and configure
         logging, honouring ``--log_level``.
      3. Load the ``.npy`` inputs.
      4. For every test type x scenario x sample size x (misspec_ps,
         misspec_om) configuration, run ``config.n_replicates`` experiments.
         Each experiment resamples the data and seeds it from
         ``SeedSequence([job_id, experiment_id])``.

    Outputs: a checkpoint every 50 experiments, a partial CSV after each
    configuration, and the final results at the end.
    """
    import gc

    args = parse_args()

    # Parse misspec combinations if provided; None means "use config.py defaults".
    misspec_combos = None
    if args.misspec_combinations:
        misspec_combos = _parse_misspec_combinations(args.misspec_combinations)
        print(f"Using {len(misspec_combos)} custom misspec combinations: {misspec_combos}")
    else:
        print("Using default misspec combinations from config.py (all 4)")

    # Create config from arguments.
    config = SimulationConfig(
        X_data_path=args.X_data_path,
        Y_null_path=args.Y_null_path,
        Y_alt_path=args.Y_alt_path,
        A_data_path=args.A_data_path,
        propensity_path=args.propensity_path,
        n_replicates=args.n_replicates,
        n_bootstrap=args.n_bootstrap,
        alpha=args.alpha,
        test_types=args.test_types,
        scenarios=args.scenarios,
        sample_sizes=args.sample_sizes,
        misspec_combinations=misspec_combos
    )

    if args.dry_run:
        print("Dry run - would run experiments with config:")
        print(f"  Total experiments: {config.get_total_experiments():,}")
        return

    # Resolve job_id before logging exists. The old code called logger here,
    # before it was assigned, which raised UnboundLocalError on local runs.
    job_id = _resolve_job_id(args.job_id)

    # Setup results directory.
    if args.results_dir is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        results_dir = f"simulation_results_{timestamp}_job{job_id}"
    else:
        results_dir = f"{args.results_dir}_job{job_id}"

    os.makedirs(results_dir, exist_ok=True)

    # Setup logging. setup_logging configures the root logger at INFO;
    # apply the user's --log_level (previously parsed but ignored).
    logger = setup_logging(results_dir)
    logging.getLogger().setLevel(args.log_level)
    logger.info("Starting simulation experiments")
    logger.info(f"Job ID: {job_id}")
    logger.info(f"Results directory: {results_dir}")

    # Setup device.
    device = validate_device()
    logger.info(f"Using device: {device}")

    # Load data.
    logger.info("Loading data...")
    try:
        X_data = np.load(config.X_data_path)
        Y_null = np.load(config.Y_null_path)
        Y_alt = np.load(config.Y_alt_path)
        A_data = np.load(config.A_data_path)
        propensity_data = np.load(config.propensity_path)
        logger.info(f"Data loaded - X: {X_data.shape}, A: {A_data.shape}")
    except Exception as e:
        logger.error(f"Failed to load data: {e}")
        return

    # Calculate total experiments.
    total_experiments = config.get_total_experiments()
    logger.info(f"Total experiments: {total_experiments:,}")

    # Main experiment loop.
    results = []
    experiment_id = 0
    start_time = time.time()

    total_configs = (len(config.test_types) * len(config.scenarios) *
                     len(config.sample_sizes) * len(config.misspec_combinations))
    config_num = 0
    use_oracle = args.use_oracle_propensity

    for test_type in config.test_types:
        for scenario in config.scenarios:
            Y_data = Y_null if scenario == "null" else Y_alt

            for n_samples in config.sample_sizes:
                for misspec_ps, misspec_om in config.misspec_combinations:

                    if use_oracle and misspec_ps:
                        logger.warning("Configuration warning: use_oracle_propensity=True but misspec_ps=True. Oracle propensity will override misspecification!")

                    config_num += 1

                    logger.info(f"Config {config_num}/{total_configs}: {test_type}|{scenario}|n={n_samples}|ps={misspec_ps}|om={misspec_om}")

                    config_results = []
                    config_start = time.time()

                    for replicate in range(config.n_replicates):
                        experiment_id += 1

                        # Reproducible per-experiment seed derived from (job_id, experiment_id).
                        ss = SeedSequence([job_id, experiment_id])
                        unique_seed = int(ss.generate_state(1, dtype=np.uint32)[0])

                        # Sample data.
                        X_sample, A_sample, Y_sample, propensity_sample = sample_data(
                            X_data, A_data, Y_data, propensity_data, n_samples, unique_seed
                        )

                        # Pass file propensities only when oracle mode is on; None lets the test estimate them.
                        current_propensity = propensity_sample if use_oracle else None

                        a_ratio = np.mean(A_sample)

                        # Run experiment.
                        test_stat, p_val, success = run_single_experiment(
                            X_sample, A_sample, Y_sample, current_propensity,
                            test_type, misspec_ps, misspec_om,
                            config.n_bootstrap, device, unique_seed, logger
                        )

                        # Store result.
                        result = {
                            'experiment_id': experiment_id,
                            'test_type': test_type,
                            'scenario': scenario,
                            'n_samples': n_samples,
                            'misspec_ps': misspec_ps,
                            'misspec_om': misspec_om,
                            'replicate': replicate,
                            'test_statistic': test_stat,
                            'p_value': p_val,
                            'rejected': p_val <= config.alpha if success else np.nan,
                            'success': success,
                            'a_ratio': a_ratio,
                            'timestamp': datetime.now().isoformat()
                        }

                        config_results.append(result)
                        results.append(result)

                        # Progress reporting.
                        if replicate % 10 == 0 or replicate == config.n_replicates - 1:
                            successful_so_far = sum(1 for r in config_results if r['success'])
                            logger.info(f"  Replicate {replicate+1}/{config.n_replicates}: "
                                        f"success_rate={successful_so_far}/{len(config_results)}, "
                                        f"a_ratio={a_ratio:.3f}")

                        # Save checkpoint.
                        if experiment_id % 50 == 0:
                            save_checkpoint(results, results_dir, experiment_id, total_experiments, logger)

                        # Memory cleanup.
                        if experiment_id % 20 == 0:
                            gc.collect()

                    # Report configuration results.
                    successful = [r for r in config_results if r['success']]
                    if successful:
                        rejection_rate = np.mean([r['rejected'] for r in successful])
                        config_time = (time.time() - config_start) / 60
                        logger.info(f"Config completed in {config_time:.1f}min: "
                                    f"rejection_rate={rejection_rate:.3f}, "
                                    f"success={len(successful)}/{len(config_results)}")
                    else:
                        logger.warning("No successful experiments in this configuration!")

                    # Save CSV after each configuration is completed.
                    try:
                        df = pd.DataFrame(results)
                        csv_file = os.path.join(results_dir, f"partial_results_config_{config_num}.csv")
                        df.to_csv(csv_file, index=False)
                        logger.info(f"Saved partial results: {csv_file}")
                    except Exception as e:
                        logger.error(f"Failed to save partial results: {e}")

    # Final save.
    logger.info("Saving final results...")
    save_final_results(results, results_dir, logger)

    elapsed_hours = (time.time() - start_time) / 3600
    logger.info(f"Simulation completed in {elapsed_hours:.2f} hours")
    logger.info(f"Total experiments: {len(results):,}")
    logger.info(f"Results saved in: {results_dir}")

# Script entry point: run the simulation only when executed directly, not on import.
if __name__ == "__main__":
    main()