import os
import time
import json
import yaml  # you might need to install PyYAML for this script
import random
import argparse
import numpy as np
from datetime import datetime
import torch


# =======================
# Local-checkout path setup (not needed if the 'causal_profiler' package is installed)
# =======================
# 🔧 EDIT: If you install 'causal_profiler' you can delete this section
import sys

# Directory containing this script, with symlinks resolved by realpath.
current_dir = os.path.dirname(os.path.realpath(__file__))
# Absolute path of "<two directory levels above this script>/causal_profiler".
# NOTE(review): presumably this is the local checkout that contains the
# importable 'causal_profiler' package -- confirm against the repo layout.
project_root = os.path.abspath(
    os.path.join(current_dir, os.pardir, os.pardir, "causal_profiler")
)
# Prepend (index 0) so this directory is searched before every other
# sys.path entry when the imports below are resolved.
sys.path.insert(0, project_root)

# =======================
# Import CausalProfiler
# =======================
from causal_profiler import (
    CausalProfiler,
    SpaceOfInterest,
    ErrorMetric,
    MechanismFamily,
    NoiseMode,
    NoiseDistribution,
    VariableDataType,
    QueryType,
    NeuralNetworkType,
    FunctionSampling,
)


# =======================
# Import your causal method wrapper (e.g. a CausalNF wrapper)
# =======================
from my_causal_method import MyCausalMethod


# =======================
# Utility functions
# =======================
def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (the
    default generator plus all CUDA devices), then sets cuDNN's
    ``deterministic`` flag and clears its ``benchmark`` flag.

    Args:
        seed: Value passed unchanged to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def to_enum(enum_class, value):
    """Resolve ``value`` to a member of ``enum_class`` when it is a name.

    Strings are looked up by member name via ``enum_class[value]`` (an
    unknown name raises ``KeyError``); any non-string value is returned
    unchanged.
    """
    return enum_class[value] if isinstance(value, str) else value


def log_print(msg, logfile=None):
    """Print ``msg`` and, if ``logfile`` is given, append it to that file.

    Args:
        msg: Message to emit. Non-string objects are formatted the same way
            ``print`` formats them.
        logfile: Optional path of a text file to append to. A falsy value
            (``None`` or ``""``) disables file logging.
    """
    print(msg)
    if logfile:
        # Explicit UTF-8: messages may contain non-ASCII characters, and the
        # platform default encoding is not guaranteed to be able to encode them.
        with open(logfile, "a", encoding="utf-8") as f:
            f.write(f"{msg}\n")


def json_dump_convert(o):
    """``default=`` hook for ``json.dump`` that converts NumPy values.

    Args:
        o: Object the JSON encoder could not serialize on its own.

    Returns:
        A plain ``bool``, ``int`` or ``float`` for NumPy scalars, or a
        (nested) list for ``np.ndarray``.

    Raises:
        TypeError: If ``o`` is none of the supported NumPy types.
    """
    if isinstance(o, np.bool_):
        return bool(o)
    # np.integer / np.floating are the abstract bases of every sized NumPy
    # integer / float scalar type, so no concrete types need to be listed.
    if isinstance(o, np.integer):
        return int(o)
    if isinstance(o, np.floating):
        return float(o)
    if isinstance(o, np.ndarray):
        return o.tolist()
    raise TypeError(f"Object of type {type(o)} is not JSON serializable")


# =======================
# Main evaluation script
# =======================
def _build_space_kwargs(space_cfg):
    """Turn one YAML space entry into keyword arguments for SpaceOfInterest."""
    # Lists become tuples, except "noise_args", which is passed on as a list.
    # A new dict is built so the loaded config is never edited in place.
    space_kwargs = {
        key: tuple(value) if isinstance(value, list) and key != "noise_args" else value
        for key, value in space_cfg.items()
    }

    # Fields whose YAML value is the *name* of an enum member.
    enum_fields = {
        "mechanism_family": MechanismFamily,
        "noise_mode": NoiseMode,
        "noise_distribution": NoiseDistribution,
        "variable_type": VariableDataType,
        "query_type": QueryType,
        "discrete_function_sampling": FunctionSampling,
    }
    for key, enum_cls in enum_fields.items():
        if key in space_kwargs:
            space_kwargs[key] = to_enum(enum_cls, space_kwargs[key])

    # .get(): "mechanism_family" is optional here, like every other enum field.
    if space_kwargs.get("mechanism_family") == MechanismFamily.NEURAL_NETWORK:
        # NOTE(review): this replaces any "mechanism_args" given in the YAML
        # entry -- confirm that overriding user-provided values is intended.
        space_kwargs["mechanism_args"] = [NeuralNetworkType.FEEDFORWARD, 8, 8]

    return space_kwargs


def evaluate(
    config_file,
    output_dir,
    use_wandb=False,
    num_runs=3,
    num_tries=5,
    scaler_type="default",
):
    """Evaluate ``MyCausalMethod`` on every space listed in a YAML config.

    For each entry of ``config["spaces"]`` and each seed of that entry,
    ``num_runs`` datasets/query sets are generated with ``CausalProfiler``;
    the method answers each query set ``num_tries`` times and the error,
    failure count and wall-clock time of every try are recorded.

    Files written to ``output_dir``:
        * ``log.txt`` -- everything passed to ``log_print`` (appended).
        * ``result_<space>_seed<seed>_run<run>_<timestamp>.json`` -- one
          record per run, written as soon as the run finishes.
        * ``summary.json`` -- list of all run records, written at the end.

    Args:
        config_file: Path to the YAML config; must contain a ``spaces`` list
            whose entries have a ``name`` and may have a ``seed_list``
            (default ``[42]``). Remaining keys are passed to
            ``SpaceOfInterest``.
        output_dir: Directory for logs and results (created if missing).
        use_wandb: If true, each run record is also sent to ``wandb.log``.
            A W&B run is started (and finished at the end) only when none
            is active yet.
        num_runs: Number of generated datasets per seed.
        num_tries: Number of estimation attempts per generated dataset.
        scaler_type: Passed to ``MyCausalMethod`` as ``scale``.
    """
    if use_wandb:
        import wandb

    # Load configuration
    with open(config_file, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    os.makedirs(output_dir, exist_ok=True)
    logfile = os.path.join(output_dir, "log.txt")
    summary_file = os.path.join(output_dir, "summary.json")
    results = []

    # wandb.log() needs an active run. Start one unless the caller already
    # did, and remember whether closing it is our responsibility.
    wandb_run_started = False
    if use_wandb and wandb.run is None:
        wandb.init(
            config={
                "config_file": config_file,
                "num_runs": num_runs,
                "num_tries": num_tries,
                "scaler_type": scaler_type,
            }
        )
        wandb_run_started = True

    # Initialize your method ONCE (can re-use for all runs)
    method = MyCausalMethod(scale=scaler_type)

    for raw_space_cfg in config["spaces"]:
        # Work on a copy so popping bookkeeping keys leaves the config intact.
        space_cfg = dict(raw_space_cfg)
        space_name = space_cfg.pop("name")
        seed_list = space_cfg.pop("seed_list", [42])
        log_print(f"\n=== Evaluating: {space_name} ===", logfile)

        # Initialize profiler
        space = SpaceOfInterest(**_build_space_kwargs(space_cfg))
        profiler = CausalProfiler(space_of_interest=space, metric=ErrorMetric.L2)

        for seed in seed_list:
            set_seed(seed)
            seed_log_id = f"{space_name}_seed{seed}"
            log_print(f"\n--> Seed: {seed_log_id}", logfile)

            for run in range(num_runs):
                run_id = f"{seed_log_id}_run{run}"
                log_print(f"  Run {run+1}/{num_runs}", logfile)

                # Generate samples and queries for this run
                data, (queries, targets), (graph, index_to_variable) = (
                    profiler.generate_samples_and_queries()
                )

                # Track errors and failures for this run's tries
                run_errors = []
                run_failures = []
                run_runtimes = []

                # Perform multiple tries to get average error
                for try_idx in range(num_tries):
                    start_time = time.perf_counter()

                    # Use your method to answer all queries at once.
                    # Deliberately best-effort: a crashing method is logged
                    # and scored as all-NaN estimates instead of aborting
                    # the whole evaluation.
                    try:
                        user_estimates = method.estimate(
                            queries, data, graph, index_to_variable, space=space
                        )
                    except Exception as e:
                        log_print(f"      Batch query failed: {e}", logfile)
                        user_estimates = [np.nan] * len(queries)

                    # Timing includes the failure path above.
                    duration = time.perf_counter() - start_time

                    error, num_failed = profiler.evaluate_error(
                        estimated=user_estimates, target=targets
                    )

                    run_errors.append(error)
                    run_failures.append(num_failed)
                    run_runtimes.append(duration)

                    log_print(
                        f"    Try {try_idx+1}/{num_tries}: Error={error:.4f}, Failed={num_failed}",
                        logfile,
                    )

                # Create result record for this run
                result = {
                    "space": space_name,
                    "seed": seed,
                    "run": run,
                    "method": method.__class__.__name__,
                    "run_error_mean": np.mean(run_errors),
                    "run_error_std": np.std(run_errors),
                    "errors": run_errors,  # Store all try errors
                    "run_failures_mean": np.mean(run_failures),
                    "run_failures_std": np.std(run_failures),
                    "failures_all": run_failures,  # Store all try failures
                    "run_runtime_mean": np.mean(run_runtimes),
                    "run_runtime_std": np.std(run_runtimes),
                    "runtimes": run_runtimes,  # Store all try runtimes
                    "num_tries": num_tries,
                }

                log_print(
                    f"  Run Summary: Avg Error={result['run_error_mean']:.4f}±{result['run_error_std']:.4f}",
                    logfile,
                )
                results.append(result)

                # Save intermediate result so finished runs survive a later crash
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                partial_file = os.path.join(
                    output_dir, f"result_{run_id}_{timestamp}.json"
                )
                with open(partial_file, "w", encoding="utf-8") as f:
                    json.dump(result, f, indent=2, default=json_dump_convert)

                if use_wandb:
                    wandb.log(result)

    # Final summary
    with open(summary_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, default=json_dump_convert)
    log_print("\n✅ Evaluation complete. Summary saved to:\n" + summary_file, logfile)

    # Only close a run this function opened; a caller-owned run stays active.
    if wandb_run_started:
        wandb.finish()


# =======================
# Entry point
# =======================
# python evaluate.py --config spaces.yaml --output_dir results/method1
if __name__ == "__main__":
    # Command-line options as (flag, add_argument keyword arguments) pairs.
    option_specs = [
        (
            "--config",
            {"type": str, "default": "spaces.yaml", "help": "Path to YAML config file."},
        ),
        (
            "--output_dir",
            {
                "type": str,
                "default": "results/method1",
                "help": "Directory to save logs and results.",
            },
        ),
        (
            "--wandb",
            {"action": "store_true", "help": "Enable logging to Weights & Biases."},
        ),
        (
            "--num_runs",
            {"type": int, "default": 3, "help": "Number of runs per seed."},
        ),
        (
            "--num_tries",
            {
                "type": int,
                "default": 5,
                "help": "Number of tries per run to average error.",
            },
        ),
        (
            "--scaler_type",
            {
                "type": str,
                "default": "default",
                "choices": ["default", "min0_max1"],
                "help": "Type of scaler to use: 'default' or 'min0_max1'",
            },
        ),
    ]

    cli = argparse.ArgumentParser()
    for flag, add_kwargs in option_specs:
        cli.add_argument(flag, **add_kwargs)
    cli_args = cli.parse_args()

    # Forward every parsed option to evaluate() by keyword.
    evaluate(
        config_file=cli_args.config,
        output_dir=cli_args.output_dir,
        use_wandb=cli_args.wandb,
        num_runs=cli_args.num_runs,
        num_tries=cli_args.num_tries,
        scaler_type=cli_args.scaler_type,
    )
