#!/usr/bin/env python3
"""
Aggregate results from cluster execution into a summary table.

Usage:
    python aggregate_cluster_results.py results_dir/
    python aggregate_cluster_results.py results_dir/ --import_to database.db
"""

import json
import argparse
import sys
from pathlib import Path
from collections import defaultdict
from datetime import datetime
import pandas as pd
import numpy as np

# Add parent directory to path to import dashboard modules
sys.path.insert(0, str(Path(__file__).parent.parent))

from flask import Flask
from models import db, Experiment, SubJob


def convert_hyperparameters_to_display_format(hyperparameters):
    """
    Convert hyperparameters from JSON list format to display format.

    Converts formats like:
    - ["uniform", 0.1, 0.9] -> {type: "uniform", mode: "range", min: 0.1, max: 0.9}
    - ["randint", 1, 10] -> {type: "randint", mode: "range", min: 1, max: 10}
    - ["loguniform", 0.001, 1.0] -> {type: "loguniform", mode: "range", min: 0.001, max: 1.0}
    - [value] -> {mode: "fixed", value: value}
    - [val1, val2, ...] (categorical) -> {mode: "range", type: "categorical", choices: [val1, val2, ...]}

    A list is only treated as a distribution when it starts with a known
    distribution name AND carries both bounds; a shorter list such as
    ["uniform", 0.5] is treated as categorical instead of raising IndexError.
    Dict specs are passed through unchanged and bare scalars become fixed values.

    Args:
        hyperparameters: Dict of model_name -> param_name -> param_spec
            (None is treated as an empty dict)

    Returns:
        Converted hyperparameters in display format
    """
    distribution_types = {"randint", "uniform", "loguniform", "1-loguniform"}

    def convert_spec(param_spec):
        """Convert a single parameter spec to its display representation."""
        if isinstance(param_spec, dict):
            # Already in display format, use as-is
            return param_spec
        if not isinstance(param_spec, list):
            # Single value (not in list)
            return {"mode": "fixed", "value": param_spec}

        head = param_spec[0] if param_spec else None
        # The isinstance check keeps unhashable heads (e.g. nested lists) from
        # raising TypeError on the set-membership test.
        if (
            len(param_spec) >= 3
            and isinstance(head, str)
            and head in distribution_types
        ):
            # Distribution type: ["uniform", 0.1, 0.9]
            return {
                "type": head,
                "mode": "range",
                "min": param_spec[1],
                "max": param_spec[2],
            }
        if len(param_spec) == 1:
            # Fixed value: [value]
            return {"mode": "fixed", "value": param_spec[0]}
        # Categorical: [val1, val2, ...]
        return {"mode": "range", "type": "categorical", "choices": param_spec}

    return {
        model_name: {
            param_name: convert_spec(param_spec)
            for param_name, param_spec in model_params.items()
        }
        for model_name, model_params in (hyperparameters or {}).items()
    }


def find_cluster_config(results_dir):
    """
    Find the cluster configuration file in the results directory.

    The config file is copied there by run_cluster_experiment.py and should
    be a JSON file that's not a result file (not matching run_*.json) and
    not an error file (not ending in _ERROR.json).

    Only regular files directly inside ``results_dir`` are considered, and
    candidates are sorted by path so the choice does not depend on the
    filesystem's directory listing order.

    Args:
        results_dir: Path to results directory

    Returns:
        Path to config file, or None if not found. When several candidates
        exist, the first one with "config" in its name wins; otherwise the
        first candidate in sorted order is returned.
    """
    results_dir = Path(results_dir)

    candidates = sorted(
        path
        for path in results_dir.glob("*.json")
        if path.is_file()
        and not path.name.startswith("run_")
        and not path.name.endswith("_ERROR.json")
    )
    if not candidates:
        return None

    # Prefer files with "config" in the name, fall back to the first candidate
    named_config = [path for path in candidates if "config" in path.name.lower()]
    return named_config[0] if named_config else candidates[0]


def load_all_results(results_dir):
    """
    Load all result JSON files from the results directory.

    Files matching ``run_*.json`` are searched recursively and processed in
    sorted path order. Files that cannot be read or parsed, and files whose
    top-level JSON value is not an object, are skipped with a warning so that
    callers can rely on every returned result being a dict.

    Args:
        results_dir: Directory containing result JSON files

    Returns:
        List of (result_file_path, result_dict) tuples
    """
    results_dir = Path(results_dir)

    all_results = []
    for result_file in sorted(results_dir.glob("**/run_*.json")):
        try:
            # JSON is UTF-8 by specification; do not depend on the locale default.
            with open(result_file, "r", encoding="utf-8") as f:
                result = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError
            print(f"Warning: Failed to parse {result_file}: {e}")
            continue

        if not isinstance(result, dict):
            print(
                f"Warning: Failed to parse {result_file}: "
                f"expected a JSON object, got {type(result).__name__}"
            )
            continue

        all_results.append((result_file, result))

    return all_results


def aggregate_results(results_dir):
    """
    Aggregate all result JSON files into a summary table.

    Every ``run_*.json`` file found recursively under ``results_dir``
    contributes one row with ``run_id`` (file stem), ``dataset``, ``model`` and
    ``success``. Aggregated metrics (``mean/std/min/max_test_rmse``,
    ``n_folds``) are added when ``mean_test_rmse`` is present; otherwise a
    plain ``test_rmse`` is added when available. Missing optional metrics are
    left empty instead of causing the whole run to be dropped.

    Exits the process with status 1 if ``results_dir`` does not exist.

    Args:
        results_dir: Directory containing result JSON files

    Returns:
        DataFrame with aggregated results sorted by dataset and model, or an
        empty DataFrame when no result file could be used
    """
    results_dir = Path(results_dir)

    if not results_dir.exists():
        print(f"Error: Results directory not found: {results_dir}")
        sys.exit(1)

    # Find all result JSON files (search recursively in subdirectories)
    result_files = sorted(results_dir.glob("**/run_*.json"))

    if not result_files:
        print(f"Warning: No result files found in {results_dir}")
        return pd.DataFrame()

    print(f"Found {len(result_files)} result files")

    rows = []
    for result_file in result_files:
        try:
            # JSON is UTF-8 by specification; do not depend on the locale default.
            with open(result_file, "r", encoding="utf-8") as f:
                result = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError
            print(f"Warning: Failed to parse {result_file}: {e}")
            continue

        if not isinstance(result, dict):
            print(
                f"Warning: Failed to parse {result_file}: "
                f"expected a JSON object, got {type(result).__name__}"
            )
            continue

        # Extract key information
        row = {
            "run_id": result_file.stem,
            "dataset": result.get("dataset", "unknown"),
            "model": result.get("model", "unknown"),
            "success": result.get("success", False),
        }

        # Add performance metrics
        if "mean_test_rmse" in result:
            row["mean_test_rmse"] = result["mean_test_rmse"]
            # The spread metrics are optional: a missing one must not discard
            # the run, so fall back to None (rendered as NaN in the table).
            for metric in ("std_test_rmse", "min_test_rmse", "max_test_rmse"):
                row[metric] = result.get(metric)
            row["n_folds"] = result.get("n_folds", 1)
        elif "test_rmse" in result:
            row["test_rmse"] = result["test_rmse"]

        # Add error if present
        if "error" in result:
            row["error"] = result["error"]

        rows.append(row)

    if not rows:
        print("Error: No valid results found")
        return pd.DataFrame()

    # Sort by dataset, then model
    return pd.DataFrame(rows).sort_values(["dataset", "model"])


def _utc_now_naive():
    """Return the current UTC time as a naive datetime (tzinfo stripped)."""
    from datetime import timezone

    return datetime.now(timezone.utc).replace(tzinfo=None)


def _parse_iso_timestamps(raw_values, label):
    """Parse ISO-8601 strings into naive UTC datetimes, warning on bad entries."""
    from datetime import timezone

    parsed = []
    for raw in raw_values:
        try:
            # Rewrite a "Z" suffix as an explicit offset; fromisoformat() on
            # older Python versions does not accept "Z".
            if "Z" in raw:
                value = datetime.fromisoformat(raw.replace("Z", "+00:00"))
            else:
                value = datetime.fromisoformat(raw)
            # Normalise aware values to naive UTC so that min()/max() and the
            # later start/end comparison never mix aware and naive datetimes.
            if value.tzinfo is not None:
                value = value.astimezone(timezone.utc).replace(tzinfo=None)
        except Exception as e:
            print(f"Warning: Failed to parse {label} '{raw}': {e}")
            continue
        parsed.append(value)
    return parsed


def _is_usable_result(result):
    """Return True if a result succeeded or carries at least one test RMSE."""
    if result.get("success") is True or result.get("test_rmse") is not None:
        return True
    # Compare against None explicitly: an RMSE of 0.0 is a valid score.
    return any(
        fold.get("test_rmse") is not None for fold in (result.get("results") or [])
    )


def _fold_results(result):
    """Return the per-fold result dicts of a run (the run itself for simple CV)."""
    fold_results = result.get("results") or []
    if not fold_results and result.get("test_rmse") is not None:
        # Simple CV case - result itself is the single result
        fold_results = [result]
    return fold_results


def _optuna_log_paths(result_file, dataset_name, model_name, nested):
    """Return the Optuna log files stored next to a result file, as strings."""
    safe_dataset = dataset_name.replace(" ", "_").replace("/", "_")
    safe_model = model_name.replace(" ", "_").replace("/", "_")
    result_dir = result_file.parent
    if nested:
        # Nested CV writes one log file per outer fold
        return [
            str(log_file)
            for log_file in result_dir.glob(f"{safe_dataset}_{safe_model}_fold*.log")
        ]
    log_file = result_dir / f"{safe_dataset}_{safe_model}.log"
    return [str(log_file)] if log_file.exists() else []


def import_to_database(results_dir, db_path, experiment_name_override=None):
    """
    Import cluster results into SQLite database.

    Each dataset becomes one Experiment, with SubJobs for each model×fold.
    For every dataset/model combination only the first usable run (in sorted
    file order) is imported; a run is usable when it reports success or carries
    at least one test RMSE. Datasets without any usable run are reported and
    counted as failed. Each dataset is committed on its own, so a failure on
    one dataset is rolled back without affecting the others.

    Timestamps are normalised to naive UTC datetimes before being stored.
    NOTE(review): this assumes the Experiment/SubJob datetime columns hold
    naive UTC values, matching the original utcnow() fallback - confirm
    against the models module.

    Args:
        results_dir: Directory containing cluster results and config file
        db_path: Path to SQLite database file
        experiment_name_override: Optional name to use instead of folder name in experiment name
    """
    results_dir = Path(results_dir)

    # Find the cluster config file
    config_path = find_cluster_config(results_dir)
    cluster_config = None
    if config_path:
        try:
            with open(config_path, "r", encoding="utf-8") as f:
                loaded_config = json.load(f)
        except Exception as e:
            print(f"⚠️  Warning: Failed to load cluster config: {e}")
        else:
            if isinstance(loaded_config, dict):
                cluster_config = loaded_config
                print(f"✅ Found cluster config: {config_path}")
            else:
                # Any non-result JSON file can be picked up as "config"; a
                # non-object payload cannot hold the expected sections.
                print(
                    f"⚠️  Warning: Cluster config {config_path} is not a JSON object, ignoring it"
                )
    else:
        print("⚠️  Warning: No cluster config found, will reconstruct from results")

    # Load all result files
    all_results = load_all_results(results_dir)
    if not all_results:
        print("❌ Error: No result files found")
        return

    print(f"Found {len(all_results)} result files")

    # Group results by dataset and model
    results_by_dataset_model = defaultdict(list)
    for result_file, result in all_results:
        if not isinstance(result, dict):
            print(f"  ⚠️  Skipping {result_file.name}: not a JSON object")
            continue
        key = (result.get("dataset", "unknown"), result.get("model", "unknown"))
        results_by_dataset_model[key].append((result_file, result))

    # For each dataset+model combination, keep only the first usable result
    results_by_dataset = defaultdict(list)
    for (dataset_name, model_name), model_results in results_by_dataset_model.items():
        successful_results = [
            entry for entry in model_results if _is_usable_result(entry[1])
        ]

        if successful_results:
            # Select the first successful result found
            first_successful = successful_results[0]
            results_by_dataset[dataset_name].append(first_successful)

            # Log if we skipped any results
            if len(model_results) > 1:
                skipped = len(model_results) - 1
                selected_file = first_successful[0].name
                print(
                    f"  ℹ️  {dataset_name}/{model_name}: selected {selected_file} from {len(model_results)} runs (skipped {skipped} other/failed)"
                )
        else:
            # No usable result for this dataset+model combination. Always say
            # so, even for a single run, so nothing is dropped silently.
            print(
                f"  ⚠️  {dataset_name}/{model_name}: all {len(model_results)} runs failed, skipping"
            )

    print(f"Found {len(results_by_dataset)} unique datasets")

    # Datasets for which no model produced a usable result never reach the
    # import loop below; remember them so the final summary counts them.
    datasets_without_results = [
        dataset_name
        for dataset_name in dict.fromkeys(
            dataset_name for dataset_name, _ in results_by_dataset_model
        )
        if dataset_name not in results_by_dataset
    ]

    # Initialize Flask app and database
    app = Flask(__name__)
    app.config["SQLALCHEMY_DATABASE_URI"] = f"sqlite:///{Path(db_path).absolute()}"
    app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
    db.init_app(app)

    with app.app_context():
        # Create tables if they don't exist
        db.create_all()

        imported_count = 0
        failed_count = 0

        for dataset_name in datasets_without_results:
            print(f"\n⚠️  Skipping dataset {dataset_name}: no successful results")
            failed_count += 1

        for dataset_name, dataset_results in results_by_dataset.items():
            try:
                # Collect all models for this dataset (only from results with valid data)
                models = set()
                for _, result in dataset_results:
                    # Include models that have valid results (success=True, or have test_rmse/results even if success=False)
                    has_valid_data = (
                        result.get("success") is True
                        or result.get("test_rmse") is not None
                        or result.get("results")
                    )
                    if has_valid_data:
                        model_name = result.get("model", "unknown")
                        if model_name != "unknown":
                            models.add(model_name)

                models = sorted(models)

                # Skip if no models with valid results
                if not models:
                    print(
                        f"\n⚠️  Skipping dataset {dataset_name}: no successful results"
                    )
                    failed_count += 1
                    continue

                print(f"\nProcessing dataset: {dataset_name} ({len(models)} models)")

                # Get dataset config from result files
                dataset_config = None
                for _, result in dataset_results:
                    if "dataset_config" in result:
                        dataset_config = result["dataset_config"]
                        break

                if not dataset_config:
                    # Fallback to cluster_config if result files don't have it
                    if cluster_config:
                        for run in cluster_config.get("runs", []):
                            if run.get("dataset", {}).get("name") == dataset_name:
                                dataset_config = run.get("dataset")
                                break

                    if not dataset_config:
                        raise ValueError(
                            f"No dataset_config found for {dataset_name} in result files or cluster_config"
                        )

                # Get optimization, cv, and resources config from cluster_config
                global_config = (
                    cluster_config.get("global_config", {}) if cluster_config else {}
                )
                optimization_config = global_config.get("optimization", {})
                cv_config = global_config.get("cv", {})
                resources_config = global_config.get("resources", {})
                hyperparameters_raw = (
                    cluster_config.get("hyperparameters", {}) if cluster_config else {}
                )
                nested_cv = cv_config.get("strategy") == "nested"

                # Convert hyperparameters from JSON list format to display format
                hyperparameters = convert_hyperparameters_to_display_format(
                    hyperparameters_raw
                )

                # Build experiment config JSON (matching dashboard format)
                experiment_config = {
                    "dataset": dataset_config,
                    "models": {
                        "model_names": models,
                        "parameters": hyperparameters,
                    },
                    "optimization": optimization_config,
                    "cv": cv_config,
                    "resources": resources_config,
                }

                # Aggregate results across all models
                all_individual_results = []
                model_summaries = []
                optuna_storage_paths = []
                fitted_models_paths = {}
                failed_jobs = []

                # Try to get n, p from dataset if possible
                n = None
                p = None
                try:
                    from services.datasets import load_dataset

                    dataset_info = load_dataset(dataset_config)
                    n = dataset_info["X"].shape[0]
                    p = dataset_info["X"].shape[1]
                except Exception:
                    # Best effort: if we can't load dataset, leave n and p as None
                    pass

                for result_file, result in dataset_results:
                    model_name = result.get("model", "unknown")

                    # Skip complete failures (failed and no partial results)
                    if (
                        result.get("success") is False
                        and result.get("test_rmse") is None
                        and not result.get("results")
                    ):
                        failed_jobs.append(
                            {
                                "model": model_name,
                                "fold_index": None,
                                "error": result.get("error", "Unknown error"),
                            }
                        )
                        print(f"  ⚠️  Skipping failed result: {result_file.name}")
                        continue

                    # Collect optuna storage paths
                    optuna_storage_paths.extend(
                        _optuna_log_paths(
                            result_file, dataset_name, model_name, nested_cv
                        )
                    )

                    # Collect fitted model paths
                    fitted_model_path = result.get("fitted_model_path")
                    if fitted_model_path:
                        fitted_models_paths[model_name] = fitted_model_path

                    # Process individual fold results
                    model_fold_results = []
                    for fold_result in _fold_results(result):
                        if fold_result.get("test_rmse") is not None:
                            all_individual_results.append(
                                {
                                    "model": model_name,
                                    "fold_index": fold_result.get("fold"),
                                    "test_rmse": fold_result.get("test_rmse"),
                                    "test_mse": fold_result.get("test_mse"),
                                    "best_cv_score": fold_result.get("best_cv_score"),
                                    "best_params": fold_result.get("best_params", {}),
                                    "fixed_params": fold_result.get("fixed_params", {}),
                                    "fit_time": fold_result.get("fit_time"),
                                }
                            )
                            model_fold_results.append(fold_result)
                        else:
                            failed_jobs.append(
                                {
                                    "model": model_name,
                                    "fold_index": fold_result.get("fold"),
                                    "error": fold_result.get("error", "Unknown error"),
                                }
                            )

                    # Calculate model summary
                    if model_fold_results:
                        test_rmses = [r["test_rmse"] for r in model_fold_results]
                        model_summaries.append(
                            {
                                "model": model_name,
                                "mean_test_rmse": float(np.mean(test_rmses)),
                                "std_test_rmse": float(np.std(test_rmses)),
                                "min_test_rmse": float(np.min(test_rmses)),
                                "max_test_rmse": float(np.max(test_rmses)),
                                "n_folds": len(test_rmses),
                            }
                        )

                # Skip creating experiment if all results are failed
                if not all_individual_results and failed_jobs:
                    print(f"  ⚠️  Skipping dataset {dataset_name}: all results failed")
                    failed_count += 1
                    continue

                # Determine best model and RMSE
                if model_summaries:
                    best_model_summary = min(
                        model_summaries, key=lambda x: x["mean_test_rmse"]
                    )
                    best_rmse = best_model_summary["mean_test_rmse"]
                    best_model = best_model_summary["model"]
                else:
                    best_rmse = None
                    best_model = None

                # Build results JSON
                results_summary = {
                    "individual_results": all_individual_results,
                    "model_summaries": model_summaries,
                    "best_rmse": best_rmse,
                    "best_model": best_model,
                    "total_models_run": len(models),
                    "failed_jobs": len(failed_jobs),
                }

                # Determine status: any failed job, even a partial failure,
                # marks the whole experiment as failed.
                status = "failed" if failed_jobs else "completed"

                # Get timestamps from ALL results to find earliest start and latest end
                # This gives accurate runtime for experiments with multiple models
                all_start_timestamps = []
                all_end_timestamps = []
                for _, result in dataset_results:
                    # Fallback to old "timestamp" field if new fields not available
                    legacy_ts = result.get("timestamp")
                    start_ts = result.get("start_timestamp") or legacy_ts
                    end_ts = result.get("end_timestamp") or legacy_ts
                    if start_ts:
                        all_start_timestamps.append(start_ts)
                    if end_ts:
                        all_end_timestamps.append(end_ts)

                parsed_starts = _parse_iso_timestamps(
                    all_start_timestamps, "start_timestamp"
                )
                parsed_ends = _parse_iso_timestamps(all_end_timestamps, "end_timestamp")

                # Earliest start / latest end, or fallback to current time
                started_at = min(parsed_starts) if parsed_starts else _utc_now_naive()
                completed_at = max(parsed_ends) if parsed_ends else _utc_now_naive()

                # Ensure completed_at is after started_at
                if completed_at < started_at:
                    completed_at = started_at

                # Use completed_at for created_at (backward compatibility)
                timestamp = completed_at

                # Calculate trials
                n_trials = optimization_config.get("n_trials", 100)
                outer_folds = cv_config.get("outer_cv_folds", 1) if nested_cv else 1
                total_trials = n_trials * len(models) * outer_folds
                # Assume all completed if we have results
                trials_completed = total_trials

                # Build experiment name: [name_or_folder]-[openml task id]-datasetname
                name_prefix = (
                    experiment_name_override
                    if experiment_name_override
                    else results_dir.name
                )
                task_id = None
                if dataset_config and dataset_config.get("type") == "openml_task":
                    task_id = dataset_config.get("task_id")

                if task_id is not None:
                    experiment_name = f"{name_prefix}-{task_id}-{dataset_name}"
                else:
                    # Fallback if no task_id available
                    experiment_name = f"{name_prefix}-{dataset_name}"

                # Create Experiment record
                experiment = Experiment(
                    name=experiment_name,
                    description=f"Imported from cluster execution on {timestamp.strftime('%Y-%m-%d %H:%M:%S')}",
                    status=status,
                    priority="normal",
                    tags=json.dumps(["cluster", "imported"]),
                    created_at=timestamp,
                    started_at=started_at,
                    completed_at=completed_at,
                    config=json.dumps(experiment_config, default=str),
                    results=json.dumps(results_summary, default=str),
                    error_log=json.dumps(failed_jobs) if failed_jobs else None,
                    optuna_storage_name="; ".join(optuna_storage_paths)
                    if optuna_storage_paths
                    else None,
                    job_id=None,  # No RQ job for cluster imports
                    sub_job_ids=json.dumps([]),
                    progress=1.0,
                    trials_completed=trials_completed,
                    total_trials=total_trials,
                    n_trials=n_trials,
                    best_rmse=best_rmse,
                    best_model=best_model,
                    is_task_collection_parent=False,
                    n=n,
                    p=p,
                    fitted_models_paths=json.dumps(fitted_models_paths)
                    if fitted_models_paths
                    else None,
                )

                db.session.add(experiment)
                db.session.flush()  # Get the experiment ID

                # Create SubJob records for each model×fold
                for result_file, result in dataset_results:
                    model_name = result.get("model", "unknown")

                    for fold_result in _fold_results(result):
                        fold_idx = fold_result.get("fold")
                        sub_job = SubJob(
                            experiment_id=experiment.id,
                            rq_job_id=None,
                            model_name=model_name,
                            fold_index=fold_idx,
                            status="completed"
                            if fold_result.get("test_rmse") is not None
                            else "failed",
                            started_at=timestamp,
                            completed_at=timestamp,
                            test_rmse=fold_result.get("test_rmse"),
                            test_mse=fold_result.get("test_mse"),
                            best_cv_score=fold_result.get("best_cv_score"),
                            best_params=json.dumps(fold_result.get("best_params", {}))
                            if fold_result.get("best_params")
                            else None,
                            fixed_params=json.dumps(fold_result.get("fixed_params", {}))
                            if fold_result.get("fixed_params")
                            else None,
                            fit_time=fold_result.get("fit_time"),
                            error_log=fold_result.get("error"),
                        )
                        db.session.add(sub_job)

                db.session.commit()
                imported_count += 1
                print(
                    f"  ✅ Imported experiment: {experiment.name} (ID: {experiment.id})"
                )

            except Exception as e:
                # Keep going with the remaining datasets; undo the partial
                # experiment/sub-job inserts of the one that failed.
                db.session.rollback()
                print(f"  ❌ Failed to import dataset {dataset_name}: {e}")
                import traceback

                traceback.print_exc()
                failed_count += 1

        print(f"\n{'=' * 80}")
        print(
            f"Import complete: {imported_count} datasets imported, {failed_count} failed"
        )
        print(f"{'=' * 80}")


def main():
    """
    Command-line entry point.

    With --import_to the results are imported into the given SQLite database
    and nothing else happens. Otherwise the results are aggregated into a
    table, summary statistics are printed, and the table is written to
    --output (or printed to stdout) in the requested --format. Exits with
    status 1 when there is nothing to aggregate.
    """
    parser = argparse.ArgumentParser(
        description="Aggregate cluster execution results into a summary table"
    )
    parser.add_argument(
        "results_dir", type=str, help="Directory containing result JSON files"
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Output file path (CSV). If not specified, prints to stdout",
    )
    parser.add_argument(
        "--format",
        type=str,
        default="csv",
        choices=["csv", "json", "markdown"],
        help="Output format",
    )
    parser.add_argument(
        "--import_to",
        type=str,
        default=None,
        help="Import results into SQLite database at this path",
    )
    parser.add_argument(
        "--name",
        type=str,
        default=None,
        help="Name to use instead of folder name in experiment name (only used with --import_to)",
    )
    args = parser.parse_args()

    # Database import mode short-circuits the table output entirely
    if args.import_to:
        import_to_database(args.results_dir, args.import_to, args.name)
        return

    df = aggregate_results(args.results_dir)
    if df.empty:
        print("No results to aggregate")
        sys.exit(1)

    # Summary block
    rule = "=" * 80
    print(f"\n{rule}")
    print("Summary Statistics")
    print(rule)
    print(f"Total runs: {len(df)}")
    print(f"Successful: {df['success'].sum()}")
    print(f"Failed: {(~df['success']).sum()}")

    if "mean_test_rmse" in df.columns:
        successful_df = df[df["success"]]
        if len(successful_df) > 0:
            rmse = successful_df["mean_test_rmse"]
            print(f"\nBest RMSE: {rmse.min():.4f}")
            print(f"Worst RMSE: {rmse.max():.4f}")
            print(f"Mean RMSE: {rmse.mean():.4f}")

    print(f"{rule}\n")

    # One serializer per format: given a path it writes the file, given None
    # it returns the rendered text.
    serializers = {
        "csv": lambda target: df.to_csv(target, index=False),
        "json": lambda target: df.to_json(target, orient="records", indent=2),
        "markdown": lambda target: df.to_markdown(target, index=False),
    }
    serialize = serializers[args.format]

    if args.output:
        output_path = Path(args.output)
        serialize(output_path)
        print(f"✅ Results saved to: {output_path}")
    else:
        # Print to stdout
        print(serialize(None))


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
