import os
import json
import argparse
import pandas as pd
import numpy as np
from evaluation.eval_corr import CorrelationAnalyzer
from evaluation.eval_stat import StatisticalDistributionAnalyzer
from evaluation.eval_time import TimeAnalyzer
from evaluation.eval_utility import run_tstr
from evaluation.eval_predsim import run_pred_similarity
from utils.convert_table_to_text import load_tables
from utils.configs import ex


def analyze_statistics(stat_analyzer, output_dir):
    """
    Generate the analyzer's statistics report and write it as CSV.

    The report is saved to ``<output_dir>/stat_summary.csv`` without the
    DataFrame index.

    Args:
        stat_analyzer: object exposing ``generate_statistics_report()``; its
            return value must support ``to_csv`` (presumably a DataFrame).
        output_dir: directory to write into; callers create it beforehand.
    """
    report = stat_analyzer.generate_statistics_report()
    target_path = os.path.join(output_dir, "stat_summary.csv")
    report.to_csv(target_path, index=False)


# Check per-item correlation results for NaN, inf, or None "mu_abs" values.
def check_invalid_values(item_specific_results):
    """
    Scan per-item results for invalid ``mu_abs`` values and report them.

    For each entry that has a ``mu_abs`` key, a warning naming the entry's key
    is printed when the value is None, NaN, or +/-inf. Entries without
    ``mu_abs`` are skipped.

    Args:
        item_specific_results: mapping of item key -> result dict (e.g. the
            output of ``CorrelationAnalyzer.analyze_by_item``).

    Returns:
        tuple[bool, bool, bool]: ``(has_nan, has_inf, has_none)``.
    """
    has_nan = False
    has_inf = False
    has_none = False

    for key, result in item_specific_results.items():
        if "mu_abs" not in result:
            continue
        value = result["mu_abs"]

        # np.floating is included because float32/float16 scalars are not
        # subclasses of Python float; checking only ``float`` silently
        # skipped their NaN/inf values.
        is_float = isinstance(value, (float, np.floating))

        if value is None:
            has_none = True
            print(f"[WARNING] None 값 발견: {key}")

        elif is_float and np.isnan(value):
            has_nan = True
            print(f"[WARNING] NaN 값 발견: {key}")

        elif is_float and np.isinf(value):
            has_inf = True
            print(f"[WARNING] Inf 값 발견: {key}")

    return has_nan, has_inf, has_none

def _json_safe_metric(value):
    """Return ``value`` as a plain Python float, or None if missing or non-finite."""
    if value is None:
        return None
    value = float(value)  # numpy scalars such as float32 are not JSON-serializable
    return value if np.isfinite(value) else None


def _mean_item_metric(item_results, metric):
    """Return the 3-decimal mean of ``metric`` over all items, or None if undefined."""
    values = [result[metric] for result in item_results.values()]
    if not values:
        # np.mean([]) would yield NaN (with a RuntimeWarning); report "no data" instead.
        return None
    return _json_safe_metric(round(np.mean(values), 3))


def analyze_correlation(corr_analyzer, real_df, syn_df, output_dir):
    """
    Perform correlation analysis and save the results.

    Writes into ``output_dir``:
      - ``real_corr_matrix.csv`` / ``synthetic_corr_matrix.csv``: the matrices
        returned by ``corr_analyzer.analyze``.
      - ``item_corr.csv``: per-item results from ``corr_analyzer.analyze_by_item``.
      - ``overall_metrics.json``: means of per-item ``mu_abs``/``cor_acc``
        (rounded to 3 decimals) plus the general ``mu_abs``/``cor_acc``.

    Metrics that are undefined (no items) or non-finite are written as JSON
    ``null`` so that the file remains strict JSON.

    Args:
        corr_analyzer: object exposing ``analyze`` and ``analyze_by_item``.
        real_df: real-data table passed through to the analyzer.
        syn_df: synthetic-data table passed through to the analyzer.
        output_dir: directory to write into; callers create it beforehand.
    """
    # General (whole-table) correlation analysis.
    general_results = corr_analyzer.analyze(real_df, syn_df, plot=True)
    general_results["real_correlation_matrix"].to_csv(
        os.path.join(output_dir, "real_corr_matrix.csv"), index=True
    )
    general_results["synthetic_correlation_matrix"].to_csv(
        os.path.join(output_dir, "synthetic_corr_matrix.csv"), index=True
    )

    # Item-specific correlation analysis.
    item_specific_results = corr_analyzer.analyze_by_item(real_df, syn_df, plot=False)
    pd.DataFrame.from_dict(item_specific_results, orient="index").to_csv(
        os.path.join(output_dir, "item_corr.csv")
    )

    # Save overall metrics.
    overall_metrics = {
        "mean_mu_abs": _mean_item_metric(item_specific_results, "mu_abs"),
        "mean_cor_acc": _mean_item_metric(item_specific_results, "cor_acc"),
        "general_mu_abs": _json_safe_metric(general_results["mu_abs"]),
        "general_cor_acc": _json_safe_metric(general_results["cor_acc"]),
    }
    with open(os.path.join(output_dir, "overall_metrics.json"), "w") as f:
        json.dump(overall_metrics, f, indent=2)

def analyze_time(real_times, syn_times, ehr, output_dir):
    """
    Compare per-stay timestamp sequences of real vs. synthetic data.

    All frames on each side are stacked, reduced to ``stay_id``/``time``, and
    turned into one ascending list of times per stay. Both sides are handed to
    ``TimeAnalyzer`` and its summary is written to
    ``<output_dir>/time_summary.csv`` (index included).

    Args:
        real_times: list of real DataFrames with ``stay_id`` and ``time`` columns.
        syn_times: list of synthetic DataFrames with the same columns.
        ehr: dataset identifier, passed through to ``TimeAnalyzer``.
        output_dir: directory to write into; callers create it beforehand.
    """
    def _times_per_stay(frames):
        # Stack every table, keep the two relevant columns, and order rows so
        # each stay's timestamps come out ascending.
        stacked = pd.concat(frames)[["stay_id", "time"]]
        ordered = stacked.sort_values(by=["stay_id", "time"]).reset_index(drop=True)
        # One list of times per stay; groupby orders groups by stay_id.
        return ordered.groupby("stay_id")["time"].apply(list).tolist()

    real_sequences = _times_per_stay(real_times)
    syn_sequences = _times_per_stay(syn_times)

    analyzer = TimeAnalyzer(ehr, real_sequences, syn_sequences)
    summary = analyzer.analyze(plot_type=None)
    summary.to_csv(os.path.join(output_dir, "time_summary.csv"), index=True)


def eval_stat_corr_time(config):
    """
    Run statistical, correlation and time evaluations of synthetic vs. real tables.

    For each table in ``config["table_names"]``:
      - the real table is restricted to the training split of the configured seed;
      - statistics and correlation results are written under
        ``results/<ehr>_<syn-root-name><suffix[1:]>/<table_name>/``.

    A time analysis over all tables is then written to the ``time``
    sub-directory of the same results folder.

    Args:
        config: mapping with at least the keys ``ehr``, ``table_names``,
            ``real_data_root``, ``syn_data_root``, ``seed`` and ``suffix``.
    """
    ehr = config["ehr"]
    table_names = config["table_names"]
    real_data_root = config["real_data_root"]
    syn_data_root = config["syn_data_root"]
    seed_column = f"seed{config['seed']}"
    suffix = config["suffix"]

    # normpath drops a trailing separator, so "data/syn/" and "data/syn" both
    # give "syn". The previous split('/')[-1] returned "" for the former, which
    # collapsed the output directory name.
    syn_root_name = os.path.basename(os.path.normpath(syn_data_root))
    # suffix[1:] drops the suffix's first character (presumably a leading
    # separator such as "_" — TODO confirm).
    output_base_dir = os.path.join("results", f"{ehr}_{syn_root_name}{suffix[1:]}")
    os.makedirs(output_base_dir, exist_ok=True)

    # Load tables and metadata.
    real_dfs, syn_dfs = load_tables(config)
    col_type = pd.read_pickle(os.path.join(real_data_root, f"{ehr}_col_dtype.pickle"))
    # reset_index() exposes the row position as an "index" column, and stay_ids
    # are matched against it. NOTE(review): this assumes stay_id equals the row
    # position in the split file — confirm.
    splits = pd.read_csv(os.path.join(real_data_root, f"{ehr}_split.csv")).reset_index()
    train_indices = splits[splits[seed_column] == "train"]["index"]

    # Per-table frames, collected for the cross-table time analysis.
    real_times = []
    syn_times = []

    for table_name in table_names:
        output_dir = os.path.join(output_base_dir, table_name)
        os.makedirs(output_dir, exist_ok=True)

        real_df = real_dfs[table_name]
        real_df = real_df[real_df.stay_id.isin(train_indices)]
        syn_df = syn_dfs[table_name]

        real_times.append(real_df)
        syn_times.append(syn_df)

        print(f"{table_name} Shapes - Real: {real_df.shape}, Synthetic: {syn_df.shape}")

        numeric_columns = col_type[table_name]["numeric_columns"]
        categorical_columns = col_type[table_name]["categorical_columns"]

        # Statistical analysis.
        stat_analyzer = StatisticalDistributionAnalyzer(
            ehr, real_df, syn_df, table_name,
            numeric_columns,
            categorical_columns,
        )
        analyze_statistics(stat_analyzer, output_dir)

        # Correlation analysis.
        corr_analyzer = CorrelationAnalyzer(
            ehr, table_name,
            numeric_columns,
            categorical_columns,
        )
        analyze_correlation(corr_analyzer, real_df, syn_df, output_dir)

    # Time analysis across all tables.
    time_output_dir = os.path.join(output_base_dir, "time")
    os.makedirs(time_output_dir, exist_ok=True)
    analyze_time(real_times, syn_times, ehr, time_output_dir)

if __name__ == "__main__":
    # Script entry point. `ex` comes from utils.configs; presumably a Sacred
    # Experiment (the `automain` decorator and the `_config` parameter match
    # Sacred's API — TODO confirm). Under that assumption, the parameter name
    # `_config` is how the resolved configuration is injected, and the function
    # name becomes the command name, so neither should be renamed.
    @ex.automain
    def run(_config):
        # Run the three evaluation stages in order on the same configuration:
        # statistics/correlation/time, then run_tstr (presumably
        # train-on-synthetic/test-on-real utility), then prediction similarity.
        eval_stat_corr_time(_config)
        run_tstr(_config)
        run_pred_similarity(_config)