from libs.ood.utils import extract_ood_dataset_info
from exp_ood import OODDetectionExp
from factory import (
    make_ood_dataset,
    make_model,
    make_optimizer,
    make_reporter,
    make_lossfn,
)
import sys
import toml
import torch
import os
import copy
from ray import tune
from ray.tune.schedulers import ASHAScheduler
import ray
import numpy as np

# Fixed list of seeds to use for experiments.
# The script's entry point feeds this list to tune.grid_search, so one trial
# is launched per seed, each using the same base configuration.
SEEDS = [42, 123, 456, 789, 1024]

def run_experiment(config):
    """Run one OOD-detection experiment for a single seed.

    This function is used as the Ray Tune trainable: it never raises for
    errors that occur while building or running the experiment, but reports
    them through the returned dict instead.

    Args:
        config: Dict with two keys:
            "base_cfg": parsed experiment configuration. It is deep-copied,
                so the caller's object is not mutated.
            "seed": seed for this trial; stored in ``cfg["exp"]["seed"]``.

    Returns:
        A dict that always contains ``"seed"``, ``"_metric"`` (a constant
        ``1.0`` placeholder for the scheduler) and ``"success"``. On success
        it also contains ``"metrics"`` (whatever ``exp.run()`` returned); on
        failure it contains ``"error"`` (``str`` of the caught exception).

    Raises:
        KeyError: If ``config`` lacks ``"base_cfg"`` or ``"seed"`` (these are
            read before the error-capturing ``try`` block).
    """
    import traceback

    # Extract base config and seed
    base_cfg = config["base_cfg"]
    seed = config["seed"]

    try:
        # Work on a private copy so trials never share mutable config state.
        cfg = copy.deepcopy(base_cfg)
        cfg["exp"]["seed"] = seed

        # Use CUDA whenever torch can see a device.
        # NOTE(review): presumably Ray restricts the visible devices to the
        # GPU assigned to this trial -- confirm for the deployed Ray version.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        cfg["exp"]["device"] = device
        cfg["model"]["device"] = device

        # Reporter is optional.
        reporter = make_reporter(cfg["reporter"], cfg) if "reporter" in cfg else None

        # Prepare dataset
        dataset_ind, dataset_ood_tr, dataset_ood_te = make_ood_dataset(cfg["dataset"])
        extract_ood_dataset_info(cfg["dataset"], dataset_ind, dataset_ood_tr, dataset_ood_te)

        # Setup loss function and model
        loss_fn, eval_func = make_lossfn(cfg["lossfn"])
        model = make_model(cfg["model"], dataset_ind)

        # Configure optimizer(s); the shape of what the factories return
        # depends on the model name.
        warmup_optimizer = None
        model_name = cfg["model"]["name"].lower()
        if model_name == "sgcn":
            teacher, model = model
            # NOTE(review): the teacher optimizer is built but not handed to
            # OODDetectionExp below -- confirm whether SGCN training needs it.
            teacher_optimizer = make_optimizer(cfg["optimizer"], teacher)
            # Previously `optimizer` was left unbound in this branch, so the
            # OODDetectionExp(...) call raised UnboundLocalError and every
            # SGCN trial was reported as failed.
            optimizer = make_optimizer(cfg["optimizer"], model)
        elif model_name == "gpn":
            optimizer, warmup_optimizer = make_optimizer(cfg["optimizer"], model)
        else:
            optimizer = make_optimizer(cfg["optimizer"], model)

        # Create and run experiment
        exp = OODDetectionExp(
            cfg=cfg["exp"],
            cfg_model=cfg["model"],
            model=model,
            criterion=loss_fn,
            eval_func=eval_func,
            optimizer=optimizer,
            warmup_optimizer=warmup_optimizer,
            reporter=reporter,
            dataset_ind=dataset_ind,
            dataset_ood_tr=dataset_ood_tr,
            dataset_ood_te=dataset_ood_te,
        )

        # Run the experiment and get results
        results = exp.run()

        # Return results directly with success flag
        return {
            "seed": seed,
            "metrics": results,
            "_metric": 1.0,  # Dummy metric for scheduler
            "success": True,
        }

    except Exception as e:
        # Deliberate best-effort boundary: one failing seed must not abort
        # the remaining trials, so report the failure instead of raising.
        print(f"Error in experiment with seed {seed}: {str(e)}")
        # Keep the full traceback in the trial log; str(e) alone often hides
        # where the failure happened.
        traceback.print_exc()
        return {
            "seed": seed,
            "_metric": 1.0,  # Dummy metric for scheduler
            "success": False,
            "error": str(e),
        }


def main():
    """Run ``run_experiment`` once per seed in ``SEEDS`` and print a summary.

    Reads the TOML config path from ``sys.argv[1]``, launches one Ray Tune
    trial per seed, then prints a per-run table followed by the mean and
    standard deviation of each metric over the successful runs.

    Exits with status 1 (after printing a usage line to stderr) when no
    config path is given on the command line.
    """
    if len(sys.argv) < 2:
        print(f"Usage: python {os.path.basename(sys.argv[0])} <config.toml>", file=sys.stderr)
        sys.exit(1)

    # Read config file from command line argument like in run_ood_detection.py
    config_name = sys.argv[1]
    with open(config_name, mode="r") as f:
        base_cfg = toml.load(f)

    # Count available GPUs once and reuse the value below.
    num_gpus = torch.cuda.device_count()

    # Initialize Ray
    ray.init(num_cpus=os.cpu_count(), num_gpus=num_gpus, _temp_dir=None)

    if num_gpus == 0:
        print("No GPUs available. Running on CPU.")
    else:
        print(f"Using all available GPUs: {num_gpus}")

    print(f"Running experiments with seeds: {SEEDS}")

    # Configure resources per trial based on available GPUs
    gpus_per_trial = 1 if num_gpus > 0 else 0

    # Create Ray Tune search space - just different seeds with the same config
    search_space = {
        "base_cfg": base_cfg,
        "seed": tune.grid_search(SEEDS),
    }

    # Create a resource scheduler
    scheduler = ASHAScheduler(
        metric="_metric",  # Dummy metric name, we're not using the scheduler for early stopping
        mode="max",
        max_t=1,  # Since we're just running with different seeds
        grace_period=1,
        reduction_factor=2,
    )

    # Run the experiments
    result = tune.run(
        run_experiment,
        config=search_space,
        resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
        num_samples=1,  # We're using grid search for seeds
        scheduler=scheduler,
        name=f"ood_seeds_{os.path.basename(config_name)}",
    )

    # Print summary of results
    print("\nResults Summary:")
    print("=" * 100)
    print(f"{'Run #':<10} {'Seed':<10} {'Status':<10} {'AUROC':<10} {'AUPR_in':<10} {'AUPR_out':<10} {'FPR95':<10} {'Detection Acc':<15}")
    print("-" * 100)

    # Metric keys read from each successful trial, in table-column order.
    metric_keys = (
        "END_AUROC",
        "END_AUPR_in",
        "END_AUPR_out",
        "END_FPR95",
        "END_DETECTION_acc",
    )

    # Per-metric values from successful runs; mean and std are derived from
    # these after the loop.
    all_values = {key: [] for key in metric_keys}

    # Placeholder columns shown for any run without usable metrics.
    blank_cols = f"{'--':<10} {'--':<10} {'--':<10} {'--':<10} {'--':<15}"

    # Counters for successful and failed runs
    successful_runs = 0
    failed_runs = 0

    # Print results for all runs
    for run_num, trial in enumerate(result.trials, start=1):
        last = trial.last_result

        if not last:
            print(f"{run_num:<10} {'N/A':<10} {'NO RESULT':<10} {blank_cols}")
            failed_runs += 1
            continue

        seed = last.get("seed", "N/A")

        if not last.get("success", False):
            error = last.get("error", "Unknown error")
            print(f"{run_num:<10} {seed:<10} {'FAILED':<10} {blank_cols}")
            print(f"    Error: {error}")
            failed_runs += 1
            continue

        if "metrics" not in last:
            print(f"{run_num:<10} {seed:<10} {'NO METRICS':<10} {blank_cols}")
            failed_runs += 1
            continue

        # Extract metrics from successful run
        metrics = last["metrics"]
        auroc = metrics["END_AUROC"]
        aupr_in = metrics["END_AUPR_in"]
        aupr_out = metrics["END_AUPR_out"]
        fpr95 = metrics["END_FPR95"]
        detection_acc = metrics["END_DETECTION_acc"]

        print(f"{run_num:<10} {seed:<10} {'SUCCESS':<10} {auroc:<10.4f} {aupr_in:<10.4f} {aupr_out:<10.4f} {fpr95:<10.4f} {detection_acc:<15.4f}")

        # Collect values for the average / standard deviation calculation
        for key in metric_keys:
            all_values[key].append(metrics[key])

        successful_runs += 1

    # Calculate and print averages and standard deviations for successful runs
    if successful_runs > 0:
        avg_metrics = {
            key: sum(values) / successful_runs
            for key, values in all_values.items()
        }
        # np.std defaults to the population standard deviation (ddof=0).
        # Computed once here; it used to be rebuilt inside a per-metric loop.
        std_metrics = {
            key: np.std(values) if values else 0
            for key, values in all_values.items()
        }

        print("-" * 100)
        print(f"{'Average':<10} {'':<10} {'':<10} {avg_metrics['END_AUROC']:<10.4f} {avg_metrics['END_AUPR_in']:<10.4f} "
              f"{avg_metrics['END_AUPR_out']:<10.4f} {avg_metrics['END_FPR95']:<10.4f} {avg_metrics['END_DETECTION_acc']:<15.4f}")
        print(f"{'Std Dev':<10} {'':<10} {'':<10} {std_metrics['END_AUROC']:<10.4f} {std_metrics['END_AUPR_in']:<10.4f} "
              f"{std_metrics['END_AUPR_out']:<10.4f} {std_metrics['END_FPR95']:<10.4f} {std_metrics['END_DETECTION_acc']:<15.4f}")
    else:
        print("-" * 100)
        print("No successful runs to report statistics.")

    print("=" * 100)
    print(f"Run summary: {successful_runs} successful, {failed_runs} failed, {len(result.trials)} total")

    # Also report a short summary to easily copy-paste if there were successful runs
    if successful_runs > 0:
        print(f"model: {base_cfg['model']['name']}")
        print(f"dataset: {base_cfg['dataset']['name']}")
        print("\nShort Summary (Avg ± Std):")
        print(f"AUROC: {avg_metrics['END_AUROC']:.4f} ± {std_metrics['END_AUROC']:.4f}")
        print(f"AUPR_in: {avg_metrics['END_AUPR_in']:.4f} ± {std_metrics['END_AUPR_in']:.4f}")
        print(f"AUPR_out: {avg_metrics['END_AUPR_out']:.4f} ± {std_metrics['END_AUPR_out']:.4f}")
        print(f"FPR95: {avg_metrics['END_FPR95']:.4f} ± {std_metrics['END_FPR95']:.4f}")
        print(f"Detection Acc: {avg_metrics['END_DETECTION_acc']:.4f} ± {std_metrics['END_DETECTION_acc']:.4f}")
        print(f"Based on {successful_runs}/{len(result.trials)} successful runs")


if __name__ == "__main__":
    main()