import os
import glob
import time
import pickle
import json
import random
from metrics import TabMetrics

import torch
import numpy as np
import pandas as pd

from copy import deepcopy
from tqdm import tqdm
import argparse

import wandb
import warnings
warnings.filterwarnings('ignore')



BAR = "=============="
def print_with_bar(log_msg):
    """Print ``log_msg`` framed by ``BAR`` on both sides.

    If the framed message contains the substring "End", one extra newline is
    appended so that a blank line follows it in the output.
    """
    framed = "".join((BAR, log_msg, BAR))
    trailer = "\n" if "End" in framed else ""
    print(framed + trailer)


def align_synthetic_to_real(syn_df, real_data_path, info):
    """
    Restrict categorical columns of a synthetic table to values seen in the real table.

    The real table is read from ``real_data_path`` and both tables are addressed by
    positional column index. The categorical columns are ``info["cat_col_idx"]``,
    plus ``info["target_col_idx"]`` unless ``info["task_type"]`` is "regression".
    Values are compared as strings with a trailing ".0" removed; every synthetic
    value that does not occur in the real column is replaced by the real column's
    mode (or, when the real column holds no non-missing value, by an arbitrary
    allowed value).

    Returns a copy of ``syn_df``; the input frame is never modified. Aligned columns
    come back as (".0"-stripped) strings. When the two tables have a different number
    of columns, a warning is printed and the copy is returned untouched, still
    carrying its original column labels.
    """
    reference = pd.read_csv(real_data_path)
    reference.columns = range(len(reference.columns))
    aligned = syn_df.copy()

    n_syn, n_real = aligned.shape[1], reference.shape[1]
    if n_syn != n_real:
        print(
            "Warning: Synthetic has {} columns but real has {}. "
            "Evaluation requires the same schema (real.csv and sample CSV must match). Skipping categorical alignment.".format(
                n_syn, n_real
            )
        )
        return aligned
    aligned.columns = range(n_syn)

    # The numeric index list is read from `info` but plays no part in the alignment.
    _numeric_idx = list(info.get("num_col_idx", []))
    categorical_idx = list(info.get("cat_col_idx", []))
    target_idx = list(info.get("target_col_idx", []))
    is_regression = info.get("task_type", "classification") == "regression"
    # For non-regression tasks the target column is treated as categorical too.
    columns_to_align = categorical_idx if is_regression else categorical_idx + target_idx

    trailing_zero = r"\.0$"
    for col in columns_to_align:
        if col >= n_syn:
            continue

        real_col = reference[col]
        real_as_str = real_col.astype(str).str.replace(trailing_zero, "", regex=True)
        allowed = set(real_as_str.unique())

        if len(real_col.dropna()) > 0:
            fallback = str(real_col.mode().iloc[0])
            # Same ".0" normalisation as applied to the compared values.
            fallback = fallback[:-2] if fallback.endswith(".0") else fallback
        else:
            fallback = list(allowed)[0] if allowed else ""

        syn_as_str = aligned[col].astype(str).str.replace(trailing_zero, "", regex=True)
        # The lambda is consumed immediately, so capturing the loop variables is safe.
        aligned[col] = syn_as_str.map(lambda value: value if value in allowed else fallback)
    return aligned


def _fmt(v):
    """Format a metric value for print (handle None and nan)."""
    if v is None:
        return "N/A"
    if isinstance(v, float) and (np.isnan(v) or np.isinf(v)):
        return "N/A"
    if isinstance(v, float):
        return f"{v:.4f}"
    return str(v)


def print_evaluation_results(out_metrics):
    """Print a fixed selection of evaluation metrics as a framed block.

    Only keys present in ``out_metrics`` are printed, always in the order
    below: quality/density scores first, then ``mle``, then the derived
    quality scores. Values are rendered with ``_fmt``.
    """
    rule = "=" * 60
    report_order = (
        # Quality / density
        "quality/Shape", "quality/Trend", "quality/Overall", "density/Shape", "density/Trend",
        # MLE
        "mle",
        # Extra quality scores, if present
        "quality/F1_sample", "quality/Q_geom", "quality/alpha_precision_all", "quality/beta_recall_all",
    )

    print("\n" + rule)
    print("EVALUATION RESULTS (synthetic vs real data)")
    print(rule)
    for key in report_order:
        if key in out_metrics:
            print(f"  {key}: {_fmt(out_metrics[key])}")
    print(rule + "\n")


class Evaluator:
    def __init__(
            self, sample_path,
            metrics, logger,
            result_save_path,
            real_data_path=None,
            info=None,
            **kwargs
    ):
        self.sample_path = sample_path
        self.sampled_synthetic = pd.read_csv(sample_path)
        print(f"Synthetic sample shape: {self.sampled_synthetic.shape}")
        # Do NOT align synthetic to real here - parent evaluator uses raw synthetic; alignment
        # (replacing unknown categories with mode) collapsed LGB_S diversity and lowered quality metrics.

        self.metrics = metrics
        self.logger = logger
        self.result_save_path = result_save_path        
        
    def report_test(self, num_runs):
        """Repeat the evaluation ``num_runs`` times and save per-run and summary tables.

        From each run the ``density/Shape``, ``density/Trend`` and ``mle`` metrics
        are collected. Shape and trend are reported as errors, ``(1 - score) * 100``;
        ``mle`` is reported unchanged. Writes ``all_results.csv`` (one row per run)
        and ``avg_std.csv`` (mean and standard deviation, rounded to 3 decimals)
        into ``self.result_save_path``.
        """
        out_dir = self.result_save_path

        shape_scores, trend_scores, mle_scores = [], [], []
        for run_idx in range(num_runs):
            print_with_bar(f"GENERAL Evaluation Run {run_idx}")
            out_metrics, extras, _ = self.evaluate_generation()
            print(extras)
            print(f"Results of Run {run_idx} are: \n{out_metrics}")
            shape_scores.append(out_metrics["density/Shape"])
            trend_scores.append(out_metrics["density/Trend"])
            mle_scores.append(out_metrics["mle"])

        all_results = pd.DataFrame({
            "shape": (1 - np.array(shape_scores)) * 100,
            "trend": (1 - np.array(trend_scores)) * 100,
            "mle": np.array(mle_scores),
        })

        # Column-wise mean / std side by side, one row per reported metric.
        summary = pd.concat(
            [all_results.mean(axis=0).round(3), all_results.std(axis=0).round(3)],
            axis=1,
            ignore_index=True,
        )
        summary.columns = ["avg", "std"]
        summary.index = ["shape", "trend", "mle"]

        all_results.to_csv(f"{out_dir}/all_results.csv", index=True)
        summary.to_csv(f"{out_dir}/avg_std.csv", index=True)
        print_with_bar(f"The AVG over {num_runs} runs are: \n{summary}")
        
    def report_test_dcr(self, num_runs):
        """Repeat the evaluation ``num_runs`` times and save the ``dcr`` results.

        From each run the ``dcr`` metric and the ``dcr_real`` / ``dcr_test`` extras
        are collected. Writes into ``self.result_save_path``:
          * ``all_results.csv`` - ``dcr`` per run, scaled by 100;
          * ``avg_std.csv``     - its mean and standard deviation (3 decimals);
          * ``dcr.csv``         - the ``dcr_real`` and ``dcr_test`` extras of all
                                  runs, concatenated along axis 0.
        """
        out_dir = self.result_save_path

        dcr_scores, real_parts, test_parts = [], [], []
        for run_idx in range(num_runs):
            print_with_bar(f"DCR Evaluation Run {run_idx}")
            out_metrics, extras, _ = self.evaluate_generation()
            print(f"Results of Run {run_idx} are: \n{out_metrics}")
            dcr_scores.append(out_metrics["dcr"])
            real_parts.append(extras["dcr_real"])
            test_parts.append(extras["dcr_test"])

        all_results = pd.DataFrame({"dcr": np.array(dcr_scores) * 100})

        summary = pd.concat(
            [all_results.mean(axis=0).round(3), all_results.std(axis=0).round(3)],
            axis=1,
            ignore_index=True,
        )
        summary.columns = ["avg", "std"]
        summary.index = ["dcr"]

        # The summary tables are written before the per-sample values are assembled.
        all_results.to_csv(f"{out_dir}/all_results.csv", index=True)
        summary.to_csv(f"{out_dir}/avg_std.csv", index=True)

        per_sample = pd.DataFrame({
            "dcr_real": np.concatenate(real_parts, axis=0),
            "dcr_test": np.concatenate(test_parts, axis=0),
        })
        per_sample.to_csv(f"{out_dir}/dcr.csv", index=False)

        print_with_bar(f"The AVG over {num_runs} runs are: \n{summary}")
        
    def test(self):
        """Run one detailed evaluation, enrich the metrics, then print and log them.

        Steps:
          1. Evaluate the stored synthetic sample, saving metric details to disk.
          2. Derive ``quality/F1_sample`` (harmonic mean of alpha-precision and
             beta-recall) and ``quality/Q_geom`` (geometric mean of shape, trend,
             recall and precision) and add them to the metrics.
          3. Best effort: log the training sharpness history and merge the
             validation loss variance statistics, both read from JSON files that
             share the sample CSV's base name. Failures there only print a message.
          4. Print the metrics and pass them to ``self.logger.log``.

        Raises:
            KeyError: if one of the four ``quality/*`` inputs of step 2 is missing.
        """
        out_metrics, _, _ = self.evaluate_generation(save_metric_details=True, plot_density=True)

        shape = out_metrics['quality/Shape']
        trend = out_metrics['quality/Trend']
        recall = out_metrics['quality/beta_recall_all']
        precision = out_metrics['quality/alpha_precision_all']

        # The harmonic mean is undefined when both terms are zero; report 0.0
        # instead of dividing by zero after the whole evaluation has already run.
        pr_sum = recall + precision
        F1_sample = 2 * (recall * precision) / pr_sum if pr_sum != 0 else 0.0
        Q_geom = (shape * trend * recall * precision) ** (1 / 4)

        out_metrics['quality/F1_sample'] = F1_sample
        out_metrics['quality/Q_geom'] = Q_geom

        # NOTE: loading of "<sample>_volatility.json" (training/volatility, logscaleV
        # and the log-sensitivity statistics) and of "<sample>_sensitivity_history.json"
        # is intentionally disabled for now. The sensitivity history was logged the
        # same way as the sharpness history below, with "sensitivity" in place of
        # "sharpness" in every key.
        self._log_sharpness_history()
        self._merge_val_loss_variance(out_metrics)

        print_with_bar(f"Results of the test are: \n{out_metrics}")
        print_evaluation_results(out_metrics)

        self.logger.log(out_metrics)

    def _find_sidecar_file(self, suffix):
        """Return the path of the JSON file that accompanies the sample CSV, or None.

        The file ``<sample path without extension><suffix>`` is tried first. If it
        does not exist, ``<repo>/data/synthetic`` (``<repo>`` being the parent of
        this module's directory) is walked for a file named
        ``<sample base name><suffix>`` and the first match is returned.
        """
        sample_path = getattr(self, "sample_path", None)
        if sample_path is None:
            return None

        base, _ = os.path.splitext(sample_path)
        candidate = base + suffix
        if os.path.exists(candidate):
            return candidate

        base_name, _ = os.path.splitext(os.path.basename(sample_path))
        wanted = base_name + suffix
        synthetic_root = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data", "synthetic")
        if os.path.exists(synthetic_root):
            for root, _dirs, files in os.walk(synthetic_root):
                if wanted in files:
                    return os.path.join(root, wanted)
        return None

    def _log_sharpness_history(self):
        """Log the training sharpness history to wandb / ``self.logger`` (best effort).

        Reads the ``sharpness_history`` list (entries with ``step`` and
        ``sharpness``) from ``<sample>_sharpness_history.json``. Logs one line plot
        of the whole series, then one record per entry; ``training/log_sharpness``
        is added only for strictly positive values. Any failure is reported with a
        warning and never propagates.
        """
        history_file = self._find_sidecar_file("_sharpness_history.json")
        if history_file is None:
            print("Info: Sharpness history file not found; it may not have been saved during training.")
            return

        try:
            with open(history_file, "r") as f:
                sharpness_history = json.load(f).get("sharpness_history", [])
            if not sharpness_history:
                return

            # Registering the custom x-axis is optional bookkeeping: it may fail
            # (e.g. wandb not initialised) without making the logging below useless.
            try:
                wandb.define_metric("training/sharpness", step_metric="training_step")
                wandb.define_metric("training/log_sharpness", step_metric="training_step")
            except Exception as e:
                print(f"Info: Could not define wandb sharpness metrics: {e}")

            # Whole series as a single line plot.
            table_data = [[entry["step"], entry["sharpness"]] for entry in sharpness_history]
            table = wandb.Table(columns=["step", "sharpness"], data=table_data)
            self.logger.log({"training/sharpness_over_steps": wandb.plot.line(
                table, "step", "sharpness", title="Sharpness Over Training Steps"
            )})

            # Individual points, so the series also shows up as a regular metric.
            # NOTE(review): if self.logger is a wandb run, log calls whose `step` is
            # below the run's current step are presumably dropped - confirm that the
            # stored steps are increasing and ahead of the run's step.
            log_sharpness_count = 0
            for entry in sharpness_history:
                log_dict = {
                    "training/sharpness": entry["sharpness"],
                    "training_step": entry["step"],
                }
                if entry["sharpness"] > 0:  # log() is only defined for positive values
                    log_dict["training/log_sharpness"] = np.log(entry["sharpness"])
                    log_sharpness_count += 1
                self.logger.log(log_dict, step=entry["step"])

            print(f"✓ Loaded and logged {len(sharpness_history)} sharpness measurements from {history_file}")
            print(f"✓ Logged {log_sharpness_count} log(sharpness) values to wandb (training/log_sharpness)")
        except Exception as e:
            print(f"Warning: Could not load sharpness history from {history_file}: {e}")

    def _merge_val_loss_variance(self, out_metrics):
        """Copy validation loss variance statistics into ``out_metrics`` (best effort).

        Reads ``<sample>_val_loss_variance.json`` and, for every known key that is
        present and not null, stores the value under ``validation/...`` and prints
        it. A failure part-way (e.g. a non-numeric value) prints a warning; values
        stored before the failure stay in ``out_metrics``.
        """
        stats_file = self._find_sidecar_file("_val_loss_variance.json")
        if stats_file is None:
            print("Info: Validation loss variance file not found; it may not have been computed during training.")
            return

        # (key in the JSON file, metric name, label used in the console message)
        fields = (
            ("mean_loss", "validation/loss_mean", "validation loss mean"),
            ("V_within", "validation/V_within", "V_within"),
            ("V_between", "validation/V_between", "V_between"),
            ("V_total", "validation/V_total", "V_total"),
            ("V_eff", "validation/V_eff", "V_eff"),
            ("loss_mean_plus_V_total", "validation/loss_mean_plus_V_total", "loss_mean + V_total"),
            ("C", "validation/C", "Bernstein complexity term (C)"),
            ("score", "validation/score", "Bernstein score"),
        )
        try:
            with open(stats_file, "r") as f:
                stats = json.load(f)
            for json_key, metric_name, label in fields:
                value = stats.get(json_key, None)
                if value is None:
                    continue
                out_metrics[metric_name] = value
                # Only the first message names the file the values came from.
                origin = f" from {stats_file}" if json_key == "mean_loss" else ""
                print(f"✓ Loaded {label}: {value:.6f}{origin}")
        except Exception as e:
            print(f"Warning: Could not load validation loss variance from {stats_file}: {e}")

    def evaluate_generation(self, save_metric_details=False, plot_density=False, ema=False):        
        # Sampled synthetic table
        syn_df_loaded = self.sampled_synthetic #(num_samples, ema=ema)
        save_path = self.result_save_path
        
        # Compute evaluation metrics on the sample
        out_metrics, extras = self.metrics.evaluate(syn_df_loaded)
        
        # Save metrics and metric details
        path = os.path.join(save_path, "all_results.json")
        with open(path, "w") as json_file:
            json.dump(out_metrics, json_file, indent=4, separators=(", ", ": "))        # always locally save the output metrics
        if save_metric_details:
            for name, extra in extras.items():
                if isinstance(extra, pd.DataFrame):
                    extra.to_csv(os.path.join(save_path, f"{name}.csv"))
                elif isinstance(extra, dict):
                    with open(os.path.join(save_path, f"{name}.json"), "w") as json_file:
                        json.dump(extra, json_file, indent=4, separators=(", ", ": "))
                else:
                    raise NotImplementedError(f"Extra file generated during evaluations has type {type(extra)}, and code to save this type of file is not implemented")
        
        # Plot density figures
        ## activate later for generating plots
        # if plot_density:
        #     img = self.metrics.plot_density(syn_df_loaded)
        #     path = os.path.join(save_path, "density_plots.png")
        #     img.save(path)
        #     print(
        #         f"The density plots are saved at {path}"
        #     )
        return out_metrics, extras, syn_df_loaded    



# main function here
# Flow: (1) Training produces samples and writes them to sample_path (CSV).
#       (2) This script loads that sampled CSV and real data (real.csv, test.csv, val.csv).
#       (3) It computes evaluation metrics (density/quality, MLE) against the real data and prints results.


def _enable_determinism():
    """Fix the global RNG seeds and request deterministic torch/CUDA kernels."""
    print("DETERMINISTIC MODE is enabled!!!")
    ## Set global random seeds
    torch.manual_seed(0)
    random.seed(0)
    np.random.seed(0)

    ## Ensure deterministic CUDA operations
    # NOTE(review): PYTHONHASHSEED is read by the interpreter at start-up, so
    # assigning it here should only influence child processes, not this one.
    os.environ['PYTHONHASHSEED'] = '0'
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'  # or ':16:8'
    torch.use_deterministic_algorithms(True)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(0)
        torch.cuda.manual_seed_all(0)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def main(args):
    """Evaluate one sampled synthetic CSV against the real data splits.

    Steps: resolve the sample/real/test/val/info paths from the CLI options,
    build the run name, create the report directory, optionally switch to
    deterministic mode, start a wandb run, then build ``TabMetrics`` and
    ``Evaluator`` and run either the averaged report or a single test.

    Args:
        args: Parsed CLI namespace. Attributes read here: ``device``,
            ``dataname``, ``method``, ``p``, ``pattern``, ``cov``,
            ``noise_seed``, ``preproc``, ``beta``, ``use_log``, ``exp_name``,
            ``emstep``, ``extra``, ``deterministic``, ``project_name``,
            ``debug``, ``no_wandb``, ``report`` and ``num_runs``.

    Raises:
        FileNotFoundError: If the dataset's ``info.json`` cannot be opened.
    """
    device = args.device

    # Load real/synthetic data paths
    dataname = args.dataname
    method = args.method

    noise_p = args.p
    noise_pattern = args.pattern
    noise_cov = args.cov
    noise_seed = args.noise_seed
    preproc = args.preproc
    beta = args.beta
    use_log = args.use_log
    exp_name = args.exp_name
    emstep = args.emstep
    extra = args.extra

    is_dcr = 'dcr' in dataname

    # When strategy=1 (obs_mask), preproc is forced to 'm' in main.py
    # Match that behavior here to ensure correct file path
    if extra == 'obs_mask':
        preproc = 'm'

    # Match parent (AugMask/evaluation): use relative paths data/, synthetic/ when run from evaluation/.
    # Fall back to data_root (parent-of-evaluation/data) if relative paths don't exist (e.g. share without symlinks).
    _eval_dir = os.path.dirname(os.path.abspath(__file__))
    _repo_data = os.path.join(os.path.dirname(_eval_dir), "data")
    # Parent layout: evaluation/data/, evaluation/synthetic/ (relative paths). Share layout: data/ is sibling of evaluation/.
    _use_repo_data = not os.path.exists(f'data/{dataname}/info.json')

    # Short tag for which part of the data carries noise; anything unrecognised counts as clean.
    dt = {'x_only': 'sup_x', 'y_only': 'semi_y', 'both': 'semi_xy'}.get(noise_cov, 'clean')
    method_extra = f"{method}-{extra}" if noise_p else method

    # Only the LGB_S preprocessing encodes beta/use_log in file and run names.
    lgb_suffix = f'_beta{beta}_use_log{use_log}' if preproc == 'LGB_S' else ''

    # File name of the sampled CSV; note it uses exp_name as given on the CLI,
    # i.e. before the run tag is appended further down.
    if noise_p == 0:
        sample_file = f'{exp_name}clean.csv'
    else:
        sample_file = f'{dt}_{noise_pattern}_p{noise_p}_{noise_seed}{preproc}{lgb_suffix}.csv'

    if not _use_repo_data:
        os.makedirs(os.path.join('synthetic', method_extra, dataname), exist_ok=True)
        sample_path = f'synthetic/{method_extra}/{dataname}/{sample_file}'
        info_path = f'data/{dataname}/info.json'
        real_data_path = f'synthetic/{dataname}/real.csv'
        test_data_path = f'synthetic/{dataname}/test.csv'
        val_data_path = f'synthetic/{dataname}/val.csv'
    else:
        synthetic_dir = os.path.join(_repo_data, "synthetic", method_extra)
        os.makedirs(os.path.join(synthetic_dir, dataname), exist_ok=True)
        sample_path = os.path.join(synthetic_dir, dataname, sample_file)
        info_path = os.path.join(_repo_data, dataname, "info.json")
        real_data_path = os.path.join(_repo_data, dataname, "real.csv")
        test_data_path = os.path.join(_repo_data, dataname, "test.csv")
        val_data_path = os.path.join(_repo_data, dataname, "val.csv")

    # Run name = CLI prefix + tag describing the configuration.
    if noise_p != 0:
        run_tag = f'{method}-{extra}-{dataname}_{dt}_{noise_pattern}_{noise_p}_{noise_seed}{preproc}{lgb_suffix}'
        if method == 'diffputer':
            run_tag += f'_EM={emstep}'
    else:
        run_tag = f'{method}_{dataname}_clean'
    exp_name += run_tag

    print(80*'=')
    print(f'Evaluation for: {exp_name}')
    print(80*'=')
    print(f"Sample path (synthetic): {sample_path}")
    print(f"Real data path:          {real_data_path}")

    with open(info_path, 'r', encoding='utf-8') as f:
        info = json.load(f)

    if not os.path.exists(val_data_path):
        print(f"{args.dataname} does not have its validation set. During MLE evaluation, a validation set will be splitted from the training set!")
        val_data_path = None

    # exist_ok avoids the check-then-create race of a separate exists() test.
    result_save_path = f"eval/report_runs/{exp_name}"
    os.makedirs(result_save_path, exist_ok=True)

    ## Make everything deterministic if needed
    if args.deterministic:
        _enable_determinism()

    ## Enable Wandb
    project_name = args.project_name

    raw_config = {
        'project_name': project_name,
        'dataname': dataname,
        'method': method,
        'exp_name': exp_name,
        'noise_p': noise_p,
        'pattern': noise_pattern,
        'noise_cov': noise_cov,
        'preproc': preproc,
        'noise_seed': noise_seed,
        'beta': beta,
        'use_log': use_log,
        'strategy': extra,
        'emstep': emstep,
    }

    logger = wandb.init(
        project=project_name,
        name=exp_name,
        config=raw_config,
        mode='disabled' if args.debug or args.no_wandb else 'online',
    )

    ## Load Metrics
    if is_dcr:
        metric_list = ["dcr"]
    else:
        metric_list = [
            "density",
            'quality',
            "mle",
            #"c2st",
        ]
    metrics = TabMetrics(real_data_path, test_data_path, val_data_path, info, device, metric_list=metric_list)

    # Evaluate: load sampled synthetic CSV, align categoricals to real data, compute metrics vs real/test/val
    evaluator = Evaluator(
        sample_path, metrics, logger, result_save_path,
        real_data_path=real_data_path, info=info
    )

    if args.report:
        if is_dcr:
            evaluator.report_test_dcr(args.num_runs)
        else:
            evaluator.report_test(args.num_runs)
    else:
        evaluator.test()

if __name__ == '__main__':
    # Script entry point: declare the CLI, pick the torch device, run main().
    parser = argparse.ArgumentParser(description="Evaluating Synthetic Data")

    # (flag, add_argument keyword arguments). Entries are registered in list
    # order, which is also the order they appear in --help.
    cli_options = [
        # Which dataset / method / sample to evaluate.
        ("--dataname", dict(type=str, default="adult", help="Name of dataset.")),
        ("--method", dict(type=str, default="tabdiff", help="Name of method.")),
        ("--sample_name", dict(type=str, default="samples_0", help="Name of sample.")),
        # Noise / preprocessing configuration.
        ("--preproc", dict(type=str, default="", help="m=mean/mode,r=random")),
        ("--cov", dict(type=str, default="", help="noise in covariate")),
        ("--pattern", dict(type=str, default="NU1", help="pattern of noise")),
        ("--p", dict(type=float, default=0.0, help="percentage")),
        ("--noise_seed", dict(type=int, default=0, help="seed noise")),
        ("--beta", dict(type=str, default="0p7", help="beta parameter for split-normal shrinkage (format: 0p7 for 0.7, avoid decimal points)")),
        ("--use_log", dict(type=str, default="switch", help="use_log mode: 'switch' (default soft-switching), 'lognormal', or 'splitnormal'")),
        # Legacy flags, currently disabled:
        #   --noisy_input  (int,   default 0,   '0=x,1=x&y,2=y')
        #   --noise_level  (float, default 0.0, 'Noise level. 0, 3, 5, 7')
        #   --noise_pattern(str,   default '0', '0=None, 1=Monotone, 2:NMU, 3:NMG')
        #   --noise_mech   (str,   default '-', '-, MCAR, MAR, MNAR')
        # Run control and logging.
        ("--no_wandb", dict(action="store_true", help="disable wandb")),
        ("--exp_name", dict(type=str, default="", help="Experiment name, used to name log directories and the wandb run name")),
        ("--report", dict(action="store_true", help="Report testing mode: this mode sequentially runs <num_runs> test runs and report the avg and std")),
        ("--num_runs", dict(type=int, default=10, help="Number of runs to be averaged in the report testing mode")),
        ("--deterministic", dict(action="store_true", help="Whether to make the entire process deterministic, i.e., fix global random seeds")),
        ("--gpu", dict(type=int, default=0, help="GPU index.")),
        ("--sample_code", dict(type=int, default=-1, help="sample code.")),
        ("--debug", dict(action="store_true", help="Enable debug mode")),
        # Miscellaneous experiment knobs.
        ("--extra", dict(type=str, default="", help="extra info")),
        ("--breaks", dict(type=int, default=30000, help="extra info")),
        ("--m_rounds", dict(type=int, default=1, help="extra info")),
        ("--emstep", dict(type=int, default=0, help="extra info")),
        ("--project_name", dict(type=str, default="augmask", help="wandb project name")),
    ]
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    # --gpu -1 forces CPU; otherwise use the requested GPU when CUDA is available.
    args.device = (
        f"cuda:{args.gpu}"
        if args.gpu != -1 and torch.cuda.is_available()
        else "cpu"
    )

    main(args)