import os
import pandas as pd
import matplotlib.pyplot as plt
import mlflow
import re
import logging
from omegaconf import OmegaConf
from omegaconf import DictConfig as Config
import plotly.express as px


# Module-level logger used by the helpers in this file.
_log = logging.getLogger(__name__)


# Secrets are read from a local YAML file instead of being hard-coded,
# to avoid leaking sensitive info.
credentials_path = "src/syngen/local/credentials.yml"
if not os.path.exists(credentials_path):
    raise FileNotFoundError(f"Credentials file not found: {credentials_path}")

credentials = OmegaConf.load(credentials_path)
MLFLOW_SERVER_URL = credentials.mlflow.server_url
MLFLOW_TRACKING_USERNAME = credentials.mlflow.user
MLFLOW_TRACKING_PASSWORD = credentials.mlflow.password


# The same server is used for both run tracking and the model registry.
mlflow.set_tracking_uri(MLFLOW_SERVER_URL)
mlflow.set_registry_uri(MLFLOW_SERVER_URL)

# Authentication is passed to MLflow through environment variables.
os.environ["MLFLOW_TRACKING_USERNAME"] = MLFLOW_TRACKING_USERNAME
os.environ["MLFLOW_TRACKING_PASSWORD"] = MLFLOW_TRACKING_PASSWORD

# Shared client used for experiment and run look-ups below.
client = mlflow.MlflowClient()


def get_next_version_number(experiment_name):
    """
    Return the next free ``version_<N>`` number for an MLflow experiment.

    Every run of the experiment is scanned for a ``version_<N>`` token in its
    run name; the result is the largest ``N`` found plus one.

    Args:
        experiment_name: Name of the MLflow experiment to inspect.

    Returns:
        int: ``1`` when the experiment does not exist, when none of its runs
        carries a ``version_<N>`` name, or when the lookup fails for any
        reason; otherwise the highest version found plus one.
    """
    try:
        experiment = client.get_experiment_by_name(experiment_name)

        # version_1 if no experiment exists
        if experiment is None:
            return 1

        max_version = 0
        page_token = None

        # search_runs returns one page at a time; follow the page token so
        # that experiments with many runs are scanned completely.
        while True:
            page = client.search_runs(
                [experiment.experiment_id],
                order_by=["start_time DESC"],
                page_token=page_token,
            )

            for run in page:
                # A run may have no name; fall back to "" so re.search does
                # not raise and trigger the version_1 fallback by mistake.
                match = re.search(r"version_(\d+)", run.info.run_name or "")
                if match:
                    max_version = max(max_version, int(match.group(1)))

            page_token = getattr(page, "token", None)
            if not page_token:
                break

        return max_version + 1

    except Exception as e:
        # Best-effort: versioning must never prevent a run from being logged.
        _log.warning(f"Error retrieving version number: {e}")
        return 1  # Fallback to version_1


def log_artifact_with_info(file_path, description):
    """
    Log ``file_path`` as an MLflow artifact when the file exists.

    Args:
        file_path: Path of the file to upload via ``mlflow.log_artifact``.
        description: Human-readable label used in the success log message.

    A missing file is not an error: a warning is logged and nothing is uploaded.
    """
    if not os.path.exists(file_path):
        _log.warning(f"File not found: {file_path}. Skipping artifact logging.")
        return

    mlflow.log_artifact(file_path)
    _log.info(f"{description} saved to MLflow artifacts.")


def log_csv(file_dir, file_name, description):
    """
    Log the file ``file_name`` located in ``file_dir`` as an MLflow artifact.

    Thin convenience wrapper: joins directory and file name, then delegates to
    ``log_artifact_with_info`` (which skips files that do not exist).
    """
    log_artifact_with_info(os.path.join(file_dir, file_name), description)


# ------------------ DATA GENERATION EXPERIMENT ------------------ #

def log_gan_metrics(datagen_method, save_path):
    """
    Save and log generator/discriminator training losses.

    Called by ``log_training_experiment`` for the "ctgan" and "copulagan" methods.

    Args:
        datagen_method: Object whose ``model`` exposes ``get_loss_values()``
            (a table with "Epoch", "Generator Loss" and "Discriminator Loss"
            columns) and ``get_loss_values_plot()``.
        save_path: Directory in which ``training_loss.csv`` and
            ``loss_plot.png`` are written before being logged as artifacts.
    """
    model = datagen_method.model
    losses = model.get_loss_values()  # __dict__["loss_values"]

    csv_file = f"{save_path}/training_loss.csv"
    losses.to_csv(csv_file, index=False)
    log_artifact_with_info(csv_file, "Training metrics dataframe")

    # (MLflow metric name, loss table column), logged once per row.
    metric_columns = (
        ("generator_loss", "Generator Loss"),
        ("discriminator_loss", "Discriminator Loss"),
    )
    for _, record in losses.iterrows():
        step = int(record["Epoch"])
        for metric_name, column in metric_columns:
            mlflow.log_metric(metric_name, record[column], step=step)

    # The model supplies its own loss figure (presumably a Plotly figure,
    # given the ``write_image`` call).
    figure = model.get_loss_values_plot()

    plot_file = f"{save_path}/loss_plot.png"
    figure.write_image(plot_file)
    log_artifact_with_info(plot_file, "Training plots")


def log_tvae_metrics(datagen_method, save_path):
    """
    Save and log the per-epoch training loss of a TVAE model.

    Args:
        datagen_method: Object whose ``model.get_loss_values()`` returns a
            table with "Epoch" and "Loss" columns.
        save_path: Directory in which ``training_loss.csv`` and
            ``loss_plot.png`` are written before being logged as artifacts.
    """
    loss_df = datagen_method.model.get_loss_values()

    loss_csv_path = f"{save_path}/training_loss.csv"
    loss_df.to_csv(loss_csv_path, index=False)
    log_artifact_with_info(loss_csv_path, "Training metrics dataframe")

    for _, row in loss_df.iterrows():
        epoch = int(row["Epoch"])
        mlflow.log_metric("loss", row["Loss"], step=epoch)

    # Draw on a dedicated figure so curves left on pyplot's current figure by
    # other code cannot leak into this plot.
    fig, ax = plt.subplots()
    try:
        ax.plot(loss_df["Epoch"], loss_df["Loss"], label="Loss")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        ax.set_title("Training Loss per Epoch")
        ax.legend()
        ax.grid(True)
        fig.tight_layout()

        loss_plot_path = f"{save_path}/loss_plot.png"
        fig.savefig(loss_plot_path)
        log_artifact_with_info(loss_plot_path, "Training plots")
    finally:
        # Always release the figure, even if saving or logging fails.
        plt.close(fig)


def log_ctabgan_metrics(datagen_method, save_path):
    """
    Save and log the CTABGAN training losses.

    Args:
        datagen_method: Object whose ``loss_values`` attribute is a table with
            "g_loss", "d_loss", "info_loss", "classifier_fake_loss" and
            "classifier_real_loss" columns. The table's index is used as the
            MLflow step (and as the x-axis of the plot).
        save_path: Directory in which ``training_loss.csv`` and
            ``loss_plot.png`` are written before being logged as artifacts.
    """
    losses = datagen_method.loss_values

    csv_file = os.path.join(save_path, "training_loss.csv")
    losses.to_csv(csv_file, index=False)
    log_artifact_with_info(csv_file, "CTABGAN training loss (CSV)")

    # (MLflow metric name, loss table column), logged once per row.
    metric_columns = (
        ("generator_loss", "g_loss"),
        ("discriminator_loss", "d_loss"),
        ("info_loss", "info_loss"),
        ("classifier_fake_loss", "classifier_fake_loss"),
        ("classifier_real_loss", "classifier_real_loss"),
    )
    for step, record in losses.iterrows():
        for metric_name, column in metric_columns:
            mlflow.log_metric(metric_name, record[column], step=step)

    # Only the generator, discriminator and info losses are plotted.
    figure = px.line(
        losses,
        x=losses.index,
        y=["g_loss", "d_loss", "info_loss"],
        labels={"value": "Loss", "index": "Epoch"},
        title="CTABGAN Training Losses",
    )

    plot_file = os.path.join(save_path, "loss_plot.png")
    figure.write_image(plot_file)
    log_artifact_with_info(plot_file, "CTABGAN training loss plot")


def log_transformer_metrics(datagen_method, save_path):
    """
    Save and log the training loss of a model exposing ``loss_values``.

    This is the fallback metrics logger used by ``log_training_experiment``
    for every method without a dedicated logger.

    Args:
        datagen_method: Object whose ``model.loss_values`` is a table with
            "Epoch" and "Loss" columns. If the model also has a
            ``latent_space`` attribute, it is written out as an image too.
        save_path: Directory in which ``training_loss.csv``, ``loss_plot.png``
            and, when available, ``pca_plot.png`` are written.
    """
    model = datagen_method.model
    losses = model.loss_values

    # Loss table as a CSV artifact
    csv_file = os.path.join(save_path, "training_loss.csv")
    losses.to_csv(csv_file, index=False)
    log_artifact_with_info(csv_file, "Training metrics dataframe")

    # One MLflow data point per row. Only the total loss is logged; the
    # "Loss_Categorical" / "Loss_Numerical" columns are currently not used.
    for _, record in losses.iterrows():
        step = int(record["Epoch"])
        mlflow.log_metric("loss", record["Loss"], step=step)

    # Loss curve
    plt.figure(figsize=(8, 5))
    plt.plot(losses["Epoch"], losses["Loss"], label="Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training Loss per Epoch")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()

    plot_file = os.path.join(save_path, "loss_plot.png")
    plt.savefig(plot_file)
    log_artifact_with_info(plot_file, "Training plots")
    plt.close()

    # Latent-space figure is optional: nothing more to do without it.
    if not hasattr(model, "latent_space"):
        return

    _log.info("Saving PCA visualization...")
    pca_file = os.path.join(save_path, "pca_plot.png")
    # The model is expected to hold a ready-made figure in ``latent_space``.
    latent_figure = model.latent_space
    latent_figure.write_image(pca_file)
    _log.info(f"PCA visualization saved to {pca_file}")
    log_artifact_with_info(pca_file, "Latent space visualization with PCA")



def log_training_experiment(cfg: Config, datagen_method):
    """
    Log one training run of a data generation method to MLflow.

    The run is created in the experiment
    ``<method>_<dataset>_seed<seed>_train`` under the name ``version_<N>``.
    Parameters, the saved model, the original train/test splits and the
    method-specific training metrics are logged.

    Args:
        cfg: Configuration providing ``datagen_method`` (``name``, ``params``,
            ``model_filename``), ``dataset`` (``name``, ``splits``), ``paths``
            and ``seed``.
        datagen_method: Trained generator. ``train_time`` / ``gen_time`` are
            logged when present; the model is saved through
            ``datagen_method.model.save`` if available, otherwise through
            ``datagen_method.save``.

    Returns:
        The MLflow run id of the created run.
    """
    method_name = cfg.datagen_method.name
    dataset_name = cfg.dataset.name
    experiment_name = f"{method_name}_{dataset_name}_seed{cfg.seed}_train"

    # Run names are versioned per experiment: version_1, version_2, ...
    run_name = f"version_{get_next_version_number(experiment_name)}"

    mlflow.set_experiment(experiment_name)

    with mlflow.start_run(run_name=run_name, nested=True) as run:
        # NOTE: all methods (including "tabddpm", which once had a separate
        # branch sketched here) share this single logging path.
        run_params = {
            "dataset_name": dataset_name,
            "generation_method": method_name,
            "seed": cfg.seed,
            "training_time": getattr(datagen_method, "train_time", None),
            "generation_time": getattr(datagen_method, "gen_time", None),
            **cfg.datagen_method.params,
        }
        mlflow.log_params(run_params)

        # Persist the trained model locally, then attach it to the run.
        model_dir = os.path.join(cfg.paths.datagen_methods_dir, f"{method_name}/{dataset_name}")
        os.makedirs(model_dir, exist_ok=True)
        model_path = os.path.join(model_dir, f"{cfg.datagen_method.model_filename}")

        saver = datagen_method
        if hasattr(datagen_method, "model") and hasattr(datagen_method.model, "save"):
            saver = datagen_method.model
        saver.save(model_path)
        log_artifact_with_info(model_path, f"'{method_name}'")

        # Original train and test datasets
        processed_dir = cfg.paths.processed_data_dir
        log_csv(processed_dir, cfg.dataset.splits.train.path, "Original train data")
        log_csv(processed_dir, cfg.dataset.splits.test.path, "Original test data")

        # Method-specific loss logging; anything unlisted uses the
        # transformer logger.
        metric_loggers = {
            "ctgan": log_gan_metrics,
            "copulagan": log_gan_metrics,
            "tvae": log_tvae_metrics,
            "ctabgan": log_ctabgan_metrics,
        }
        log_metrics_for_method = metric_loggers.get(method_name, log_transformer_metrics)
        log_metrics_for_method(datagen_method, model_dir)

        return run.info.run_id


# ------------------ EVALUATION EXPERIMENT ------------------ #

# (MLflow metric name, ``mle_results`` column) pairs. Each column is averaged
# over all rows of ``mle_results`` to give one summary metric.
_MLE_SUMMARY_COLUMNS = (
    ('Avg_Balanced_Accuracy_Real_All_Models', 'Avg Balanced Accuracy (Real)'),
    ('Std_Balanced_Accuracy_Real_All_Models', 'Std Balanced Accuracy (Real)'),
    ('Avg_Balanced_Accuracy_Synthetic_All_Models', 'Avg Balanced Accuracy (Synthetic)'),
    ('Std_Balanced_Accuracy_Synthetic_All_Models', 'Std Balanced Accuracy (Synthetic)'),
    ('Avg_Balanced_Accuracy_Diff_All_Models', 'Avg Balanced Accuracy Diff'),

    ('Avg_F1_Score_Class_0_Real_All_Models', 'Avg F1 Score (Class 0, Real)'),
    ('Std_F1_Score_Class_0_Real_All_Models', 'Std F1 Score (Class 0, Real)'),
    ('Avg_F1_Score_Class_0_Synthetic_All_Models', 'Avg F1 Score (Class 0, Synthetic)'),
    ('Std_F1_Score_Class_0_Synthetic_All_Models', 'Std F1 Score (Class 0, Synthetic)'),
    ('Avg_F1_Score_Diff_Class_0_All_Models', 'Avg F1 Score Diff (Class 0)'),

    ('Avg_F1_Score_Class_1_Real_All_Models', 'Avg F1 Score (Class 1, Real)'),
    ('Std_F1_Score_Class_1_Real_All_Models', 'Std F1 Score (Class 1, Real)'),
    ('Avg_F1_Score_Class_1_Synthetic_All_Models', 'Avg F1 Score (Class 1, Synthetic)'),
    ('Std_F1_Score_Class_1_Synthetic_All_Models', 'Std F1 Score (Class 1, Synthetic)'),
    ('Avg_F1_Score_Diff_Class_1_All_Models', 'Avg F1 Score Diff (Class 1)'),

    ('Avg_F1_Score_Real_All_Models', 'Avg F1 Score (Real)'),
    ('Std_F1_Score_Real_All_Models', 'Std F1 Score (Real)'),
    ('Avg_F1_Score_Synthetic_All_Models', 'Avg F1 Score (Synthetic)'),
    ('Std_F1_Score_Synthetic_All_Models', 'Std F1 Score (Synthetic)'),
    ('Avg_F1_Score_Diff_All_Models', 'Avg F1 Score Diff'),

    ('Avg_AUC_ROC_Real_All_Models', 'Avg AUC-ROC (Real)'),
    ('Std_AUC_ROC_Real_All_Models', 'Std AUC-ROC (Real)'),
    ('Avg_AUC_ROC_Synthetic_All_Models', 'Avg AUC-ROC (Synthetic)'),
    ('Std_AUC_ROC_Synthetic_All_Models', 'Std AUC-ROC (Synthetic)'),
    ('Avg_AUC_ROC_Diff_All_Models', 'Avg AUC-ROC Diff'),
)


def _collect_eval_metrics(mle_results, dcr_metrics, dcr_nndr_per_class_records, avg_wd, avg_jsd,
                          wd_jsd_per_class, corr_results):
    """Assemble the flat ``{metric name: value}`` dict logged for an evaluation run."""
    metrics = {name: mle_results[column].mean() for name, column in _MLE_SUMMARY_COLUMNS}

    metrics.update({
        "Pearson_Pairwise_Correlation_Score": corr_results,

        "Avg_WD": avg_wd,
        "Avg_JSD": avg_jsd,

        "DCR_Real_vs_Fake": dcr_metrics["DCR_RF"],
        "DCR_Real_vs_Real": dcr_metrics["DCR_RR"],
        "DCR_Fake_vs_Fake": dcr_metrics["DCR_FF"],
        "NNDR_Real_vs_Fake": dcr_metrics["NNDR_RF"],
        "NNDR_Real_vs_Real": dcr_metrics["NNDR_RR"],
        "NNDR_Fake_vs_Fake": dcr_metrics["NNDR_FF"],
    })

    # Per-class density errors: WD_Class_<cls> / JSD_Class_<cls>
    for entry in wd_jsd_per_class:
        cls = entry["class"]
        metrics[f"WD_Class_{cls}"] = entry["WD"]
        metrics[f"JSD_Class_{cls}"] = entry["JSD"]

    # Per-class distance/privacy values: <key>_Class_<cls> for every key of
    # the record except "class" itself.
    for entry in dcr_nndr_per_class_records:
        cls = entry["class"]
        for k, v in entry.items():
            if k != "class":
                metrics[f"{k}_Class_{cls}"] = v

    return metrics


def _log_per_class_density_metrics(csv_path):
    """Log ``Class_<id>_<column>`` metrics from the per-class density CSV (best effort)."""
    try:
        df_per_class = pd.read_csv(csv_path)

        # Log only the columns that exist, so one absent column (e.g. "WD/JSD")
        # no longer aborts the loop after the first row.
        wanted = ("WD", "JSD", "WD/JSD")
        missing = [column for column in wanted if column not in df_per_class.columns]
        if missing:
            _log.warning(f"Per-class density columns not found, skipping them: {missing}")

        for _, row in df_per_class.iterrows():
            class_id = row["class"]
            for column in wanted:
                if column not in missing:
                    mlflow.log_metric(f"Class_{class_id}_{column}", row[column])
    except Exception as e:
        # Best-effort: these values duplicate what is already in the main
        # metrics dict, so a failure here must not fail the whole run.
        _log.warning(f"Could not log per-class metrics to MLflow: {e}")


def log_eval_experiment(cfg: Config, mle_results, dcr_metrics, dcr_nndr_per_class_records, avg_wd, avg_jsd, wd_jsd_per_class,
                        corr_results, matrix_path, ft_distr_path, ft_distr_paths_per_class, training_run_id=None, gen_time=None):
    """
    Log one evaluation run of a data generation method to MLflow.

    The run is created in the experiment
    ``<method>_<dataset>_<sampling>_seed<seed>_eval`` (created if missing)
    under the name ``experiment_<sampling>_version_<N>``. Parameters, input
    and synthetic datasets, plots, result tables and summary metrics are logged.

    Args:
        cfg: Configuration providing ``datagen_method`` (``name``,
            ``params.sampling_strategy``), ``dataset`` (``name``, ``splits``,
            ``synthetic_path``), ``paths`` and ``seed``.
        mle_results: Table indexed by the column labels listed in
            ``_MLE_SUMMARY_COLUMNS``; each column is averaged with ``.mean()``
            (presumably a pandas DataFrame with one row per model - confirm).
        dcr_metrics: Mapping with keys "DCR_RF", "DCR_RR", "DCR_FF",
            "NNDR_RF", "NNDR_RR" and "NNDR_FF".
        dcr_nndr_per_class_records: Iterable of mappings, each with a "class"
            key plus metric keys; also written to
            ``dcr_nndr_metrics_per_class.csv``.
        avg_wd: Logged as "Avg_WD".
        avg_jsd: Logged as "Avg_JSD".
        wd_jsd_per_class: Iterable of mappings with "class", "WD" and "JSD"
            keys; also written to ``density_error_per_class.csv``.
        corr_results: Logged as "Pearson_Pairwise_Correlation_Score".
        matrix_path: Path of the correlation-difference plot.
        ft_distr_path: Path of the feature distribution plot.
        ft_distr_paths_per_class: Iterable of per-class feature distribution
            plot paths.
        training_run_id: Optional id of the training run; stored as the
            "training_run_id" tag when truthy.
        gen_time: Logged as the "generation_time" parameter.
    """
    sampling = cfg.datagen_method.params.sampling_strategy
    method_name = cfg.datagen_method.name
    dataset_name = cfg.dataset.name

    experiment_name = f"{method_name}_{dataset_name}_{sampling}_seed{cfg.seed}_eval"

    next_version = get_next_version_number(experiment_name)
    run_name = f"experiment_{sampling}_version_{next_version}"

    # Ensure experiment exists
    experiment = client.get_experiment_by_name(experiment_name)
    if experiment is None:
        experiment_id = client.create_experiment(experiment_name)
    else:
        experiment_id = experiment.experiment_id

    metrics = _collect_eval_metrics(mle_results, dcr_metrics, dcr_nndr_per_class_records,
                                    avg_wd, avg_jsd, wd_jsd_per_class, corr_results)

    # All result tables of this run live in the same directory.
    results_dir = os.path.join(cfg.paths.results_dir, method_name, dataset_name, sampling)
    # os.path.join inserts the separator only when needed, so the path is
    # correct whether or not ``synth_data_dir`` ends with "/".
    synth_dir = os.path.join(cfg.paths.synth_data_dir, method_name, sampling)

    with mlflow.start_run(experiment_id=experiment_id, run_name=run_name, nested=True):

        # Log parameters
        mlflow.log_params({
            "dataset_name": dataset_name,
            "generation_method": method_name,
            "sampling_strategy": sampling,
            "seed": cfg.seed,
            "generation_time": gen_time,
        })

        if training_run_id:
            mlflow.set_tag("training_run_id", training_run_id)

        # log original train and test datasets
        log_csv(cfg.paths.processed_data_dir, cfg.dataset.splits.train.path, "Original train data")
        log_csv(cfg.paths.processed_data_dir, cfg.dataset.splits.test.path, "Original test data")

        # log synthetic dataset
        log_csv(synth_dir, cfg.dataset.synthetic_path, f"Synthetic data produced by '{method_name}'")

        # log plots
        log_artifact_with_info(matrix_path, "Absolute differences between correlation matrices")
        log_artifact_with_info(ft_distr_path, "Feature distribution plots")
        for ft_distr_path_per_class in ft_distr_paths_per_class:
            log_artifact_with_info(ft_distr_path_per_class, "Feature distribution plots per class")

        # log classification report (written elsewhere; skipped if absent)
        log_csv(results_dir, "mle_score.csv", f"Machine Learning Efficacy results for '{method_name}'")

        os.makedirs(results_dir, exist_ok=True)

        # log the density error results per class
        density_csv = os.path.join(results_dir, "density_error_per_class.csv")
        pd.DataFrame(wd_jsd_per_class).to_csv(density_csv, index=False)
        log_csv(results_dir, "density_error_per_class.csv", f"Density error results per class for '{method_name}'")

        # Create and log DCR/NNDR table
        pd.DataFrame([dcr_metrics]).to_csv(os.path.join(results_dir, "dcr_nndr_metrics.csv"), index=False)
        log_csv(results_dir, "dcr_nndr_metrics.csv", "Distance and Privacy Metrics (DCR & NNDR)")

        # Create and log DCR/NNDR table per class
        pd.DataFrame(dcr_nndr_per_class_records).to_csv(
            os.path.join(results_dir, "dcr_nndr_metrics_per_class.csv"), index=False)
        log_csv(results_dir, "dcr_nndr_metrics_per_class.csv", "Distance and Privacy Metrics (DCR & NNDR) per class")

        # log the rest of the metrics
        mlflow.log_metrics(metrics)
        _log_per_class_density_metrics(density_csv)
        _log.info("All evaluation metrics are logged to the MLflow server.")
