import os
import time
import json
import yaml  # you might need to install PyYAML for this script
import random
import argparse
import numpy as np
from datetime import datetime
import torch


# =======================
# Local-checkout path setup (not needed if causal_profiler is installed as a package)
# =======================
# 🔧 EDIT: If you install 'causal_profiler' you can delete this section
import sys

current_dir = os.path.dirname(os.path.realpath(__file__))
# The package is expected two directories above this script, in a folder
# named "causal_profiler". realpath() already yields an absolute path, so
# normpath() only has to collapse the ".." components.
project_root = os.path.normpath(
    os.path.join(current_dir, os.pardir, os.pardir, "causal_profiler")
)
# Prepend so the local checkout takes precedence on import.
sys.path[:0] = [project_root]

# =======================
# Import CausalProfiler
# =======================
from causal_profiler import (
    CausalProfiler,
    SpaceOfInterest,
    ErrorMetric,
    MechanismFamily,
    NoiseMode,
    NoiseDistribution,
    VariableDataType,
    QueryType,
    NeuralNetworkType,
)


# =======================
# Dummy placeholder method
# =======================
# 🔧 EDIT: import your own method
from my_causal_method import MyCausalMethod

# from my_causal_method2 import MyCausalMethod2 as MyCausalMethod


# =======================
# Utility functions
# =======================
def set_seed(seed):
    """Seed every RNG used by this script and force deterministic cuDNN.

    Python's ``random``, NumPy's global RNG, torch's CPU RNG and the RNGs of
    all CUDA devices all receive the same ``seed``. cuDNN is then told to use
    deterministic algorithms and to skip benchmark-based autotuning.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def to_enum(enum_class, value):
    """Convert a config value to a member of ``enum_class``.

    Strings are looked up by member *name* (``enum_class[value]``); any other
    value (e.g. an already-converted enum member) is returned unchanged.

    Raises:
        KeyError: if ``value`` is a string that is not a member name. The
            message lists the valid names, so a typo in the YAML config is
            easy to spot.
    """
    if not isinstance(value, str):
        return value
    try:
        return enum_class[value]
    except KeyError:
        valid = ", ".join(enum_class.__members__)
        raise KeyError(
            f"{value!r} is not a valid {enum_class.__name__}; "
            f"expected one of: {valid}"
        ) from None


def log_print(msg, logfile=None):
    """Print ``msg`` to stdout and, if ``logfile`` is given, append it as one line.

    The log file is opened as UTF-8 so that non-ASCII text (such as emoji in
    status messages) can be written whatever the platform's default encoding
    is; with the default encoding this raises UnicodeEncodeError on e.g. cp1252.
    """
    print(msg)
    if logfile:
        with open(logfile, "a", encoding="utf-8") as f:
            f.write(f"{msg}\n")


def json_dump_convert(o):
    """``default=`` hook for ``json.dump`` that handles NumPy and torch values.

    NumPy bools, integers and floats become the matching Python scalar. NumPy
    arrays and torch tensors become (nested) lists; a 0-d tensor becomes a
    plain Python scalar.

    Raises:
        TypeError: for any other type, as the ``json`` module expects from a
            ``default`` hook.
    """
    # np.bool_ is not a subclass of np.integer, so it needs its own branch.
    if isinstance(o, np.bool_):
        return bool(o)
    if isinstance(o, np.integer):
        return int(o)
    if isinstance(o, np.floating):
        return float(o)
    if isinstance(o, np.ndarray):
        return o.tolist()
    if isinstance(o, torch.Tensor):
        return o.detach().cpu().tolist()
    raise TypeError(f"Object of type {type(o)} is not JSON serializable")


# =======================
# Main evaluation script
# =======================
def _space_kwargs(space_cfg):
    """Turn one YAML space entry into keyword arguments for ``SpaceOfInterest``.

    ``space_cfg`` should already have ``name`` and ``seed_list`` removed.
    List values become tuples (except ``noise_args``, which is passed through
    as-is). Enum-typed fields given as member-name strings are converted to
    enum members. Neural-network spaces get a fixed ``mechanism_args``.
    Returns a new dict; ``space_cfg`` itself is not modified.
    """
    kwargs = {
        key: tuple(value) if isinstance(value, list) and key != "noise_args" else value
        for key, value in space_cfg.items()
    }

    enum_fields = {
        "mechanism_family": MechanismFamily,
        "noise_mode": NoiseMode,
        "noise_distribution": NoiseDistribution,
        "variable_type": VariableDataType,
        "query_type": QueryType,
    }
    for key, enum_cls in enum_fields.items():
        if key in kwargs:
            kwargs[key] = to_enum(enum_cls, kwargs[key])

    # .get(): a space that omits mechanism_family must not crash here.
    if kwargs.get("mechanism_family") == MechanismFamily.NEURAL_NETWORK:
        # Hard-coded for this experiment; the two 8s are presumably layer
        # sizes -- NOTE(review): confirm against SpaceOfInterest.
        kwargs["mechanism_args"] = [NeuralNetworkType.FEEDFORWARD, 8, 8]
    return kwargs


def _count_hidden(scm):
    """Count SCM variables that are neither exogenous nor visible."""
    return sum(
        1 for var in scm.variables.values() if not var.exogenous and not var.visible
    )


def _full_graph_hidden_first(scm):
    """Return ``(adjacency, order)`` for the SCM with every endogenous variable visible.

    ``order`` lists the variables that were not visible first, followed by the
    originally visible variables in their original order. ``adjacency`` maps a
    position in ``order`` to a list of positions in ``order``. The SCM is
    deep-copied first, so the caller's SCM is left untouched.
    """
    import copy

    scm = copy.deepcopy(scm)
    # First call: the order of the currently visible variables.
    _, visible_order = scm.extract_adjacency_list()
    for var in scm.variables.values():
        if not var.exogenous:
            var.visible = True
    adjlist, full_order = scm.extract_adjacency_list()

    # Put the visible variables last, keeping their original order.
    visible_set = set(visible_order)
    new_order = [v for v in full_order if v not in visible_set] + visible_order

    # Remap adjacency indices from positions in full_order to positions in new_order.
    new_index = {v: i for i, v in enumerate(new_order)}
    index_map = {i: new_index[v] for i, v in enumerate(full_order)}
    new_adjlist = {
        index_map[src]: [index_map[tgt] for tgt in targets]
        for src, targets in adjlist.items()
    }
    return new_adjlist, new_order


def _run_tries(profiler, num_tries, logfile):
    """Draw one dataset and query set, then score a fresh method on it ``num_tries`` times.

    Returns:
        ``(errors, failures, runtimes)``: three lists with one entry per try.
    """
    data, (queries, targets), _ = profiler.generate_samples_and_queries()
    scm = profiler.sampler._scm
    # Just for this experiment, pass the exact number of hidden variables to the method.
    num_hidden = _count_hidden(scm)
    # DeCaFlow uses the true graph, not the projected one returned above.
    graph, index_to_variable = _full_graph_hidden_first(scm)

    errors, failures, runtimes = [], [], []
    for try_idx in range(num_tries):
        method = MyCausalMethod(num_hidden=num_hidden)
        start_time = time.perf_counter()
        # 🔧 EDIT: call your own method to answer these queries
        estimates = [
            method.estimate(query, data, graph, index_to_variable)
            for query in queries
        ]
        duration = time.perf_counter() - start_time

        error, num_failed = profiler.evaluate_error(
            estimated=estimates, target=targets
        )
        errors.append(error)
        failures.append(num_failed)
        runtimes.append(duration)

        log_print(
            f"    Try {try_idx+1}/{num_tries}: Error={error:.4f}, Failed={num_failed}",
            logfile,
        )
    return errors, failures, runtimes


def evaluate(
    config_file,
    output_dir,
    use_wandb=False,
    num_runs=3,
    num_tries=5,
    skip_spaces=("nn_no_hidden",),
    wandb_project=None,
):
    """Benchmark ``MyCausalMethod`` on every space of interest in a YAML config.

    Each entry of ``config["spaces"]`` is processed unless its name is in
    ``skip_spaces``. For each seed in the entry's ``seed_list`` (default
    ``[42]``), ``num_runs`` runs are performed. Each run draws new samples and
    queries, then re-instantiates the method ``num_tries`` times. Every try
    records the error, the failure count and the wall-clock runtime.

    Files written to ``output_dir`` (created if missing):
      * ``log.txt``: progress log (appended to);
      * ``result_<space>_seed<seed>_run<run>_<timestamp>.json``: one per run;
      * ``summary.json``: the list of all run records, written at the end.

    Args:
        config_file: path to the YAML config.
        output_dir: directory for logs and results.
        use_wandb: also log each run record to Weights & Biases.
        num_runs: runs (fresh samples + queries) per seed.
        num_tries: method re-instantiations per run.
        skip_spaces: collection of space names to skip entirely.
        wandb_project: project for ``wandb.init`` if this function starts the run.
    """
    with open(config_file, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    os.makedirs(output_dir, exist_ok=True)
    logfile = os.path.join(output_dir, "log.txt")
    summary_file = os.path.join(output_dir, "summary.json")
    results = []

    owns_wandb_run = False
    if use_wandb:
        import wandb

        # wandb.log() raises unless a run is active; reuse a caller's run if any.
        if wandb.run is None:
            wandb.init(
                project=wandb_project,
                config={
                    "config_file": config_file,
                    "num_runs": num_runs,
                    "num_tries": num_tries,
                },
            )
            owns_wandb_run = True

    try:
        for space_cfg in config["spaces"]:
            space_name = space_cfg.pop("name")
            if space_name in skip_spaces:
                log_print(f"skipping {space_name}", logfile)
                continue
            seed_list = space_cfg.pop("seed_list", [42])
            log_print(f"\n=== Evaluating: {space_name} ===", logfile)

            space = SpaceOfInterest(**_space_kwargs(space_cfg))
            profiler = CausalProfiler(space_of_interest=space, metric=ErrorMetric.L2)
            # 🔧 EDIT: use your own method (instantiated here only to read its class name)
            method_name = type(MyCausalMethod()).__name__

            for seed in seed_list:
                set_seed(seed)
                seed_log_id = f"{space_name}_seed{seed}"
                log_print(f"\n--> Seed: {seed_log_id}", logfile)

                for run in range(num_runs):
                    run_id = f"{seed_log_id}_run{run}"
                    log_print(f"  Run {run+1}/{num_runs}", logfile)

                    errors, failures, runtimes = _run_tries(
                        profiler, num_tries, logfile
                    )
                    result = {
                        "space": space_name,
                        "seed": seed,
                        "run": run,
                        "method": method_name,
                        "run_error_mean": np.mean(errors),
                        "run_error_std": np.std(errors),
                        "errors": errors,  # every try's error
                        "run_failures_mean": np.mean(failures),
                        "run_failures_std": np.std(failures),
                        "failures_all": failures,  # every try's failure count
                        "run_runtime_mean": np.mean(runtimes),
                        "run_runtime_std": np.std(runtimes),
                        "runtimes": runtimes,  # every try's runtime (seconds)
                        "num_tries": num_tries,
                    }

                    log_print(
                        f"  Run Summary: Avg Error={result['run_error_mean']:.4f}±{result['run_error_std']:.4f}",
                        logfile,
                    )
                    results.append(result)

                    # Save intermediate result so a crash does not lose finished runs.
                    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                    partial_file = os.path.join(
                        output_dir, f"result_{run_id}_{timestamp}.json"
                    )
                    with open(partial_file, "w", encoding="utf-8") as f:
                        json.dump(result, f, indent=2, default=json_dump_convert)

                    if use_wandb:
                        wandb.log(result)

        # Final summary
        with open(summary_file, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2, default=json_dump_convert)
        log_print(
            "\n✅ Evaluation complete. Summary saved to:\n" + summary_file, logfile
        )
    finally:
        if owns_wandb_run:
            wandb.finish()


# =======================
# Entry point
# =======================
# python evaluate.py --config spaces.yaml --output_dir results/method1
if __name__ == "__main__":
    # Command-line entry point: parse options and run the full evaluation.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--config",
        type=str,
        default="spaces.yaml",
        help="Path to YAML config file.",
    )
    cli.add_argument(
        "--output_dir",
        type=str,
        default="results/method1",
        help="Directory to save logs and results.",
    )
    cli.add_argument(
        "--wandb",
        action="store_true",
        help="Enable logging to Weights & Biases.",
    )
    cli.add_argument(
        "--num_runs",
        type=int,
        default=3,
        help="Number of runs per seed.",
    )
    cli.add_argument(
        "--num_tries",
        type=int,
        default=5,
        help="Number of tries per run to average error.",
    )
    opts = cli.parse_args()

    evaluate(
        opts.config,
        opts.output_dir,
        use_wandb=opts.wandb,
        num_runs=opts.num_runs,
        num_tries=opts.num_tries,
    )
