import os
import pprint
import random

import click
import numpy as np
import pandas as pd

from train.models import DAGMA, DCDFG, DCDI, NOBEARS, NOTEARS, Sortnregress
from train.utils import create_intervention_dataset, set_random_seed_all


# Registry of available models, keyed by class name. Names here are what
# `run_model` and the `--model` CLI option look models up by; insertion
# order is the order models run in with `--model all`.
MODEL_CLS_DCT = {
    cls.__name__: cls
    for cls in (DAGMA, NOBEARS, NOTEARS, Sortnregress, DCDFG, DCDI)
}

def generate_observational_dataset(
    save_dir=None,
):
    """Load a pre-generated observational dataset from ``save_dir``.

    Despite the name, nothing is generated here: the function only reads
    files that already exist on disk.

    Args:
        save_dir: Directory containing ``X.csv`` and ``Btrue.csv``.

    Returns:
        Tuple ``(X, B_true)``. ``X`` is a DataFrame read from ``X.csv`` with
        its first column used as the index. ``B_true`` is the comma-separated
        matrix from ``Btrue.csv`` cast to ``np.int64`` (callers use it as the
        reference graph when scoring models).

    Raises:
        TypeError: If ``save_dir`` is None.
        FileNotFoundError: If ``X.csv`` or ``Btrue.csv`` is missing.
    """
    if save_dir is None:
        raise TypeError("save_dir must be a directory path, got None")

    X_path = os.path.join(save_dir, "X.csv")
    Btrue_path = os.path.join(save_dir, "Btrue.csv")

    # Raise explicitly: the previous `assert 1 == 2` was removed under
    # `python -O`, which made this function silently return None.
    if not (os.path.exists(X_path) and os.path.exists(Btrue_path)):
        raise FileNotFoundError("Not found dataset {}".format(save_dir))

    print("Use existing dataset {}".format(save_dir))
    X = pd.read_csv(X_path, index_col=0)
    B_true = np.loadtxt(Btrue_path, delimiter=",").astype(np.int64)
    return X, B_true


def run_model(
    model_cls_name,
    dataset,
    B_true,
    model_kwargs=None,
    wandb_project=None,
    wandb_config_dict=None,
    save_dir=None,
    force=False,
):
    """Train one model, score it against ``B_true`` and optionally save it.

    Args:
        model_cls_name: Key into ``MODEL_CLS_DCT`` selecting the model class.
        dataset: Training data passed unchanged to ``model.train``.
        B_true: Reference matrix passed to ``model.compute_metrics``.
        model_kwargs: Accepted for API compatibility; currently not used.
            NOTE(review): not forwarded to the model constructor or to
            ``train`` — confirm whether it should be.
        wandb_project: Forwarded to ``model.train``.
        wandb_config_dict: Config forwarded to ``model.train``. It is copied
            before the ``"model"`` key is added, so the caller's dict is
            never mutated.
        save_dir: If given, the predicted adjacency matrix is written to
            ``<save_dir>/<model_cls_name>.csv``.
        force: Rerun even if that CSV already exists.

    Returns:
        The dict from ``model.compute_metrics`` with ``"model"`` and
        ``"train_time"`` added. Returns None when the run is skipped because
        a saved result exists, or when an ``AssertionError`` is raised while
        computing metrics or saving the matrix.
    """
    # Copy so the "model" key added below does not leak into the caller's dict.
    wandb_config_dict = dict(wandb_config_dict or {})
    model_cls = MODEL_CLS_DCT[model_cls_name]

    # Bind save_path on every path. It used to be assigned only when save_dir
    # was given, so save_dir=None crashed at the save step.
    save_path = None
    if save_dir is not None:
        save_path = os.path.join(save_dir, f"{model_cls_name}.csv")
        if os.path.exists(save_path) and not force:
            print(f"Already ran {model_cls_name}, skipping. Use force=True to rerun.")
            return None

    wandb_config_dict["model"] = model_cls_name
    model = model_cls()
    # Placeholder for ablation-study / GPU options; currently always empty.
    extra_kwargs = {}
    model.train(
        dataset,
        log_wandb=False,
        wandb_project=wandb_project,
        wandb_config_dict=wandb_config_dict,
        **extra_kwargs,
    )
    try:
        metrics_dict = model.compute_metrics(B_true)
        metrics_dict["model"] = model_cls_name
        metrics_dict["train_time"] = model._train_runtime_in_sec

        B_pred = model.get_adjacency_matrix()
        if save_path is not None:
            np.savetxt(save_path, B_pred, delimiter=",")
    except AssertionError:
        # Models that fail an internal assertion are reported and skipped
        # rather than aborting the whole benchmark.
        metrics_dict = None
        print("ignore {}".format(model_cls_name))
    return metrics_dict


@click.command()
@click.option("--seed", default=0, help="Random seed")
@click.option(
    "--model",
    # Validate up front so an unknown name gives a usage error instead of a
    # KeyError traceback after the dataset has already been loaded.
    type=click.Choice(["all", *MODEL_CLS_DCT]),
    default="all",
    help="Which models to run",
)
@click.option("--force", default=True, help="If results exist, redo anyways.")
@click.option(
    "--save_mtxs", default=True, help="Save matrices to saved_mtxs/ directory"
)
@click.option(
    "--dataset", default="", type=str, help="Use existing dataset, override n, p, d and s"
)
def _run_full_pipeline(seed, model, force, save_mtxs, dataset):
    """Run benchmark models on an existing dataset and record their metrics.

    Loads X.csv and Btrue.csv from ../../share_dataset/DATASET, trains each
    selected model and, after every successful run, writes the accumulated
    metrics table to saved_mtxs/DATASET/MODEL.csv.
    """
    dataset_name = dataset
    save_dir = f"../../share_dataset/{dataset_name}"
    if save_mtxs:
        os.makedirs(save_dir, exist_ok=True)
    X, B_true = generate_observational_dataset(save_dir)
    X_dataset = create_intervention_dataset(X)

    # NOTE(review): results.csv is read to seed the table, but this function
    # never writes it back; rows go to saved_mtxs/<dataset>/<model>.csv
    # instead. Confirm that this is intended.
    results_save_path = os.path.join(save_dir, "results.csv")
    results_df_rows = []
    if os.path.exists(results_save_path):
        results_df = pd.read_csv(results_save_path, index_col=0)
        results_df_rows = results_df.to_dict(orient="records")

    model_names = list(MODEL_CLS_DCT) if model == "all" else [model]
    mtxs_dir = f"saved_mtxs/{dataset_name}"
    for model_cls_name in model_names:
        # Re-seed before each model so every run starts from the same RNG state.
        set_random_seed_all(seed)
        metrics_dict = run_model(
            model_cls_name,
            X_dataset,
            B_true,
            model_kwargs={},
            wandb_project="benchmark",
            wandb_config_dict={},
            save_dir=save_dir if save_mtxs else None,
            force=force,
        )
        if metrics_dict is None:
            continue

        pprint.pprint(metrics_dict)
        results_df_rows.append(metrics_dict)
        results_df = pd.DataFrame.from_records(results_df_rows)
        os.makedirs(mtxs_dir, exist_ok=True)
        results_df.to_csv(os.path.join(mtxs_dir, f"{model_cls_name}.csv"))

if __name__ == "__main__":
    # Script entry point: the click command parses the CLI options itself,
    # so it is invoked with no arguments here.
    _run_full_pipeline()

