import argparse
import random
import torch
import numpy as np
import pandas as pd
import os
from datetime import datetime
from train_order import run_single_run


def blue(x):
    """Return *x* wrapped in ANSI escape codes that print it in blue (code 94)."""
    prefix, reset = "\033[94m", "\033[0m"
    return "".join((prefix, x, reset))


def red(x):
    """Return *x* wrapped in ANSI escape codes that print it in red (code 31)."""
    prefix, reset = "\033[31m", "\033[0m"
    return "".join((prefix, x, reset))

#spspsp
def _nullable_float(value):
    """Convert *value* to ``float`` for argparse, mapping "no value" spellings to None.

    ``None`` and the strings "none", "null", "nan" or "" (case-insensitive,
    surrounding whitespace ignored) yield ``None``. Any other input is passed
    to ``float``; strings are stripped first.

    Raises:
        argparse.ArgumentTypeError: If *value* cannot be converted to float.
    """
    if value is None:
        return None
    candidate = value
    if isinstance(value, str):
        candidate = value.strip()
        if candidate.lower() in {"none", "null", "nan", ""}:
            return None
    try:
        return float(candidate)
    except (TypeError, ValueError) as err:
        raise argparse.ArgumentTypeError(f"Invalid float value: {value}") from err


def _parse_bool_flag(value):
    """Interpret a textual boolean CLI value.

    Only the word "true" counts as True. The comparison ignores letter case and
    surrounding whitespace, so ``--norm true`` and ``--norm True`` agree.
    Anything else (including "False", "1" or "yes") is False.
    """
    return str(value).strip().lower() == "true"


def get_args():
    """Parse command-line arguments for the experiment runner.

    The string-valued boolean options (``--manualSeed``, ``--dodiscover_cam``,
    ``--pre_pruning``, ``--add_edge``, ``--norm``) are converted to real bools
    (case-insensitive "true" -> True). ``results_dir`` is extended with a
    timestamp sub-directory. The timestamp comes from the
    ``EXPERIMENT_TIMESTAMP`` environment variable, or from the current local
    time (``%Y%m%d_%H%M%S``) when that variable is unset or empty.

    Returns:
        argparse.Namespace: The parsed and post-processed arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--results_dir", type=str, default="results")
    parser.add_argument("--dodiscover_cam", type=str, default="True")
    parser.add_argument("--pruning_method", type=str, default="cam", choices=["cam", "xgb", "rf", "tabpfn"])
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--model", type=str, default="CaPS", choices=["CAM", "SCORE", "DAS", "NoGAM", "DiffAN", "CaPS", "OURS", "random"])
    parser.add_argument("--pre_pruning", type=str, default="True")
    parser.add_argument("--add_edge", type=str, default="True")
    parser.add_argument("--lambda1", type=float, default=50.0)
    parser.add_argument("--lambda2", type=float, default=50.0)
    parser.add_argument("--dataset", type=str, default="sachs")
    parser.add_argument("--num_nodes", type=int, default=10)
    parser.add_argument("--num_samples", type=int, default=2000)
    parser.add_argument("--method", type=str, default="mixed", choices=["mixed", "linear", "nonlinear"])
    parser.add_argument("--linear_sem_type", type=str, default="gauss", choices=["gauss", "exp", "gumbel", "uniform", "logistic"])
    parser.add_argument("--nonlinear_sem_type", type=str, default="gp", choices=["gp", "gp-add", "mlp", "mim", "quadratic"])
    parser.add_argument("--linear_rate", type=float, default=0.5)

    # Misspecified Data Scenario Arguments
    parser.add_argument("--scenario", type=str, default="vanilla",
                        choices=["vanilla", "pnl", "lingam", "confounded", "measure_err", "timino", "unfaithful"],
                        help="Misspecified data scenario type (vanilla: no post-processing, paper 6 scenarios)")
    parser.add_argument("--rho", type=float, default=0.2,
                        help="Probability of adding confounder for confounded scenario (paper default: 0.2)")
    parser.add_argument("--gamma", type=float, default=0.8,
                        help="Signal to noise ratio for measurement error scenario (paper default: 0.8)")
    parser.add_argument("--p_unfaithful", type=float, default=0.3,
                        help="Probability of unfaithful distribution for unfaithful scenario")
    parser.add_argument("--exponent", type=float, default=3.0,
                        help="Exponent for post-nonlinear transformation in PNL scenario (paper: x^3)")

    parser.add_argument("--manualSeed", type=str, default="False")
    parser.add_argument("--runs", type=int, default=10)
    parser.add_argument("--norm", type=str, default="False")
    parser.add_argument(
        "--confidence_threshold",
        type=_nullable_float,
        default=None,
        help="Optional per-sample gain threshold override; defaults to the MDL gate.",
    )
    parser.add_argument(
        "--mlp_confidence_threshold",
        type=_nullable_float,
        default=None,
        help="Optional per-sample gain override for MLP pruning; defaults to the MDL gate.",
    )
    parser.add_argument(
        "--xgb_confidence_threshold",
        type=_nullable_float,
        default=None,
        help="Optional per-sample gain override for XGBoost pruning; defaults to the MDL gate.",
    )
    parser.add_argument(
        "--rf_confidence_threshold",
        type=_nullable_float,
        default=None,
        help="Optional per-sample gain override for RandomForest pruning; defaults to the MDL gate.",
    )
    parser.add_argument(
        "--mdl_lambda",
        type=float,
        default=1.0,
        help="Lambda parameter for the MDL gate (Eq. (tau^MDL)).",
    )
    parser.add_argument(
        "--mdl_kappa",
        type=float,
        default=25.0,
        help="Kappa parameter for the MDL gate (Eq. (tau^MDL)).",
    )

    args = parser.parse_args()

    # These options are declared as strings so shell scripts can pass
    # "True"/"False"; convert them to real booleans (case-insensitively).
    for flag in ("manualSeed", "dodiscover_cam", "pre_pruning", "add_edge", "norm"):
        setattr(args, flag, _parse_bool_flag(getattr(args, flag)))

    # Group the outputs of one experiment under a timestamped sub-directory.
    # A timestamp provided by the calling shell script takes precedence;
    # standalone runs generate a fresh one.
    timestamp = os.environ.get('EXPERIMENT_TIMESTAMP') or datetime.now().strftime("%Y%m%d_%H%M%S")
    args.results_dir = os.path.join(args.results_dir, timestamp)

    return args


def save_experiment_results(
    args,
    metrics_dict=None,
    average_metrics=None,
    experiment_type="raw",
    first_run_in_batch=False,
):
    """Append raw or aggregated experiment results to a CSV under ``args.results_dir``.

    Args:
        args: Parsed CLI namespace. ``results_dir``, ``dataset`` and ``model``
            determine the output path. The condition attribute matching
            *experiment_type* (``scenario``, ``linear_rate``, ``pruning_method``
            or ``confidence_threshold``) is added as a column when present.
        metrics_dict: Metrics of a single run; required when
            ``experiment_type == "raw"``. Only the keys listed in
            ``desired_order`` are written, in that order.
        average_metrics: Aggregated metrics across runs; required for every
            aggregated experiment type.
        experiment_type: Selects the output file.
            - "raw": one row per run in ``<task>_<dataset>_<model>.csv``. The
              ``<task>`` prefix comes from the ``EXPERIMENT_TYPE`` environment
              variable when set, otherwise it is derived from *args*.
            - "scenario", "linear_trend", "benchmark", "pruning", "tau": one
              row in ``avg_<task>_<dataset>_<model>.csv`` ("linear_trend" is
              written with the task name "linear").
            Any other value writes nothing and prints a warning.
        first_run_in_batch: When True and the raw file already exists, a blank
            line is appended first to separate consecutive batches.

    Raises:
        ValueError: If the metrics required by *experiment_type* are missing.
    """
    os.makedirs(args.results_dir, exist_ok=True)
    desired_order = [
        'shd', 'sid', 'precision', 'recall', 'F1', 'nnz',
        'gscore', 'fdr', 'tpr', 'fpr', 'fnr', 'scenario', 'pruning_method', 'linear_rate',
        'tau', 'tau_mode', 'tau_min', 'tau_max', 'tau_mean',
        'rho', 'gamma', 'p_unfaithful'
    ]

    if experiment_type == "raw":
        if metrics_dict is None:
            raise ValueError("metrics_dict is required when experiment_type='raw'")

        # Scripted runs name the task via the environment; ad-hoc runs derive it.
        experiment_task_type = os.environ.get('EXPERIMENT_TYPE')
        if experiment_task_type:
            task_name = experiment_task_type
        elif hasattr(args, 'scenario'):
            # NOTE(review): argparse always defines `scenario` (default "vanilla"),
            # so for parsed CLI args this branch always wins and the two below
            # are unreachable. Kept as-is so existing result filenames don't change.
            task_name = "scenario"
        elif hasattr(args, 'pruning_method') and args.pruning_method != 'cam':
            task_name = "pruning"
        else:
            task_name = "benchmark"
        filename = os.path.join(args.results_dir, f"{task_name}_{args.dataset}_{args.model}.csv")

        # Keep only the known metric columns, in a stable order.
        df = pd.DataFrame([metrics_dict])
        df = df[[col for col in desired_order if col in df.columns]]

        file_exists = os.path.isfile(filename)
        if first_run_in_batch and file_exists:
            # Blank separator line between batches appended to the same file.
            with open(filename, "a") as f:
                f.write("\n")
        df.to_csv(filename, mode="a", header=not file_exists, index=False)
        return

    if experiment_type not in ("scenario", "linear_trend", "benchmark", "pruning", "tau"):
        print(f"Warning: unknown experiment_type '{experiment_type}', results were not saved")
        return
    if average_metrics is None:
        raise ValueError(f"average_metrics is required when experiment_type='{experiment_type}'")

    task_name = "linear" if experiment_type == "linear_trend" else experiment_type
    filename = os.path.join(args.results_dir, f"avg_{task_name}_{args.dataset}_{args.model}.csv")

    # Leading condition column(s) identifying the experimental setting.
    data_to_save = {}
    if experiment_type == "scenario":
        if hasattr(args, 'scenario'):
            data_to_save['scenario'] = args.scenario
    elif experiment_type == "linear_trend":
        if hasattr(args, 'linear_rate'):
            data_to_save['linear_rate'] = args.linear_rate
    elif experiment_type == "pruning":
        if hasattr(args, 'pruning_method'):
            data_to_save['pruning_method'] = args.pruning_method
    elif experiment_type == "benchmark":
        # Benchmark experiments only need the dataset column
        data_to_save['dataset'] = args.dataset
    elif experiment_type == "tau":
        if hasattr(args, 'confidence_threshold'):
            # No explicit threshold override is recorded as NaN.
            data_to_save['tau'] = (
                args.confidence_threshold
                if args.confidence_threshold is not None
                else float('nan')
            )

    # Averaged metrics follow; a matching key overrides the condition column.
    data_to_save.update(average_metrics)

    df = pd.DataFrame([data_to_save])
    file_exists = os.path.isfile(filename)
    df.to_csv(filename, mode='a', header=not file_exists, index=False)


# Keys collected from every run, in the order they are reported.
_METRIC_KEYS = (
    "shd", "sid", "precision", "recall", "F1",
    "fdr", "tpr", "fpr", "fnr", "nnz",
    "tau", "tau_mode", "tau_min", "tau_max", "tau_mean",
    "scenario", "pruning_method", "linear_rate",
    "rho", "gamma", "p_unfaithful",
)
# Run metadata / labels: reported by their first observed value, never averaged.
_CATEGORICAL_KEYS = frozenset({
    "tau", "tau_mode", "rho", "gamma", "p_unfaithful",
    "scenario", "pruning_method", "linear_rate",
})
# NaN values for these tau statistics are accepted without printing a warning.
_TAU_STAT_KEYS = frozenset({"tau_min", "tau_max", "tau_mean"})
_SUPPORTED_MODELS = ("CaPS", "DiffAN", "DAS", "NoGAM", "SCORE", "CAM", "OURS", "random")


def _is_nan(value):
    """Return True when ``np.isnan(value)`` is truthy; False if it cannot be evaluated."""
    try:
        return bool(np.isnan(value))
    except (TypeError, ValueError):
        return False


def _annotate_metrics(args, metrics_dict, tau_summary):
    """Add run metadata and tau diagnostics to *metrics_dict* in place."""
    if hasattr(args, 'scenario'):
        metrics_dict['scenario'] = args.scenario
    if hasattr(args, 'pruning_method'):
        # OURS always uses TabPFN pruning
        metrics_dict['pruning_method'] = 'tabpfn' if args.model == 'OURS' else args.pruning_method
    if hasattr(args, 'linear_rate'):
        metrics_dict['linear_rate'] = args.linear_rate
    # Record tau diagnostics
    metrics_dict.update(tau_summary)
    for key in ('rho', 'gamma', 'p_unfaithful'):
        if hasattr(args, key):
            metrics_dict[key] = getattr(args, key)


def _run_experiments(args):
    """Run ``args.runs`` seeded runs, save each, and return per-key lists of values."""
    metrics_list_dict = {key: [] for key in _METRIC_KEYS}
    simulation_seeds = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

    for i in range(args.runs):
        print(red("runs {}:".format(i)))

        seed = random.randint(1, 10000)
        if args.manualSeed:
            # NOTE(review): get_args() defines no --random_seed option, so this
            # falls back to the random seed drawn above; --manualSeed True
            # currently behaves exactly like False.
            seed = getattr(args, 'random_seed', seed)
        print("Random Seed:", seed)
        random.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)

        if args.model not in _SUPPORTED_MODELS:
            raise ValueError(
                f"Model '{args.model}' is not supported. Please choose from: {', '.join(_SUPPORTED_MODELS)}"
            )
        metrics_dict, _order, tau_summary = run_single_run(args, i, simulation_seeds)
        _annotate_metrics(args, metrics_dict, tau_summary)

        print(blue(str(metrics_dict)))
        save_experiment_results(args, metrics_dict, first_run_in_batch=(i == 0))

        for key, values in metrics_list_dict.items():
            if key not in metrics_dict:
                continue
            value = metrics_dict[key]
            # NaN values are kept; np.nanmean/np.nanstd ignore them later.
            if key not in _CATEGORICAL_KEYS and key not in _TAU_STAT_KEYS and _is_nan(value):
                print(f"Warning: {key} is NaN in run {i}, it will be ignored when averaging")
            values.append(value)

    return metrics_list_dict


def _average_metrics(metrics_list_dict):
    """Reduce per-run values to one value per key for the averaged results CSV.

    Categorical keys keep their first observed value. Numeric keys are averaged
    with ``np.nanmean`` and rounded to 4 decimals. Keys with no values, or whose
    mean is NaN, are omitted. Values that cannot be averaged fall back to the
    first observed value.
    """
    average_metrics = {}
    for key, values in metrics_list_dict.items():
        if not values:
            continue
        if key in _CATEGORICAL_KEYS:
            average_metrics[key] = values[0]
            continue
        try:
            avg_val = np.nanmean(values)
            if np.isnan(avg_val):
                if key not in _TAU_STAT_KEYS:
                    print(f"Warning: Average of {key} is NaN, skipping from results")
                continue
            average_metrics[key] = round(float(avg_val), 4)
        except (TypeError, ValueError):
            average_metrics[key] = values[0]
    return average_metrics


def _resolve_experiment_type(args):
    """Pick the ``experiment_type`` used to name the averaged results file."""
    experiment_task_type = os.environ.get('EXPERIMENT_TYPE')
    if experiment_task_type:
        # Scripts say "linear"; internally the linear_trend label is used.
        return "linear_trend" if experiment_task_type == "linear" else experiment_task_type
    # NOTE(review): argparse always defines `scenario`, so for parsed CLI args
    # ad-hoc runs always resolve to "scenario" and the branches below are
    # unreachable. Kept as-is so existing result filenames don't change.
    if hasattr(args, 'scenario'):
        return "scenario"
    if hasattr(args, 'pruning_method') and args.pruning_method != 'cam':
        return "pruning"
    if hasattr(args, 'linear_rate') and args.linear_rate != 0.0:
        return "linear_trend"
    return "benchmark"


def _summarize_metrics(metrics_list_dict):
    """Format results for console output: categorical keys by first value, others as "mean+/-std"."""
    summary = {}
    for key, values in metrics_list_dict.items():
        if not values:
            continue
        if key in _CATEGORICAL_KEYS:
            summary[key] = values[0]
            continue
        try:
            mean_val = np.nanmean(values)
            std_val = np.nanstd(values)
            if np.isnan(mean_val) or np.isnan(std_val):
                continue
            summary[key] = f"{round(float(mean_val), 4)}+/-{round(float(std_val), 4)}"
        except (TypeError, ValueError):
            continue  # Skip values that cannot be computed
    return summary


def main():
    """Parse CLI arguments, run all experiments, then save and print aggregated metrics."""
    args = get_args()
    metrics_list_dict = _run_experiments(args)

    # Calculate and save the average of the completed runs
    if args.runs > 0:
        average_metrics = _average_metrics(metrics_list_dict)
        if average_metrics:
            save_experiment_results(
                args,
                average_metrics=average_metrics,
                experiment_type=_resolve_experiment_type(args),
            )

    # Print final results with std dev (if more than one run)
    if args.runs > 1:
        metrics_res_dict = _summarize_metrics(metrics_list_dict)
        if metrics_res_dict:
            print(blue(f"Average over {args.runs} runs: {metrics_res_dict}"))


if __name__ == "__main__":
    main()
