import time

import numpy as np
import pandas as pd

import solve_milp
from dagsolver_utils import plot_heatmap, ExDagDataException

from data_generation_loading_utils import load_problem
from metrics_utils import calculate_metrics
import utils
from structure.dynotears import from_numpy_dynamic


def start_experiment(dataset, solver_name, seed, problem_name, d, n, p, intra_edge_ratio, w_min_inter, w_max_inter, noise_scale, noise_scale_variance, time_limit=7200, plots=False):
    """Load one problem instance, run a structure-learning solver on it and score the estimate.

    Args:
        dataset: Dataset family forwarded to ``load_problem`` (this file uses ``'dynamic'`` and ``'cds'``).
        solver_name: One of ``'exdbn'``, ``'lingam'``, ``'nts_notears'`` or ``'dynotears'``.
        seed: Random seed set via ``utils.set_random_seed`` before the problem is loaded.
        problem_name: Problem identifier forwarded to ``load_problem`` (this file uses ``'er'``, ``'sf'``, ``'cds'``).
        d: Requested number of variables; replaced by ``X.shape[1]`` once the data is loaded.
        n: Requested number of samples; replaced by ``X.shape[0]`` once the data is loaded.
        p: Number of lags; the dynotears lag matrix is split into ``p`` blocks.
        intra_edge_ratio, w_min_inter, w_max_inter, noise_scale, noise_scale_variance:
            Generation parameters forwarded to ``load_problem``.
        time_limit: Forwarded to ``solve_milp.solve``; only used by the ``'exdbn'`` solver
            (presumably seconds — confirm against solve_milp).
        plots: When True, save heatmaps of each estimated lag matrix and of ``W_est``.

    Returns:
        ``(solving_duration, shd, f1)`` where ``solving_duration`` is wall-clock seconds spent
        in the solver branch, or ``None`` when exdbn finds no solution or the estimated
        intra-slice graph is not a DAG.

    Raises:
        ValueError: If ``solver_name`` is not one of the supported solvers.
    """
    utils.set_random_seed(seed)
    # NOTE(review): the literals 1 and 1.1 and the (w_max_inter, w_min_inter) order are
    # positional arguments to load_problem, whose signature is not visible here — confirm.
    W_true, W_bi_true, B_true, B_bi_true, A_true, X, Y, tabu_edges, intra_nodes, inter_nodes = load_problem(
        dataset, problem_name, d, n, p, intra_edge_ratio, 1, w_max_inter, w_min_inter, 1.1, noise_scale, noise_scale_variance)

    # The loaded data is authoritative for the problem size (callers may pass d=None, e.g. for 'cds').
    n, d = X.shape

    start_time = time.time()
    Wbi = None
    if solver_name == 'exdbn':
        W_est, A_est, _gap, _lazy_count, _stats = solve_milp.solve(X, 0.1, 0.1, 0, Y=Y,
                                                                   B_ref=B_true,
                                                                   tabu_edges=None, time_limit=time_limit)
        if W_est is None:
            print('Error: Gurobi has not found a solution')
            return None
    elif solver_name == 'lingam':
        from solve_lingam import solve_lingam
        W_est, A_est = solve_lingam(X, Y, p)
    elif solver_name == 'nts_notears':
        from solve_nts_notears import solve_nts_notears
        W_est, A_est = solve_nts_notears(X, Y, p)
    elif solver_name == 'dynotears':
        # dynotears expects all lagged samples side by side in a single matrix.
        X_lag = np.concatenate(Y, axis=1)
        _, W_est, A_est_concated = from_numpy_dynamic(X, X_lag, w_threshold=0.1, lambda_w=0.03, lambda_a=0.03)
        # The returned inter-slice matrix stacks one d-row block per lag; split it back into p
        # matrices. Integer-array indexing (not slicing) keeps the original out-of-range IndexError.
        A_est = [A_est_concated[np.arange(d * lag, d * (lag + 1)), :] for lag in range(p)]
    else:
        # An explicit exception survives `python -O`, unlike `assert False`.
        raise ValueError(f'Unknown solver: {solver_name}')
    solving_duration = time.time() - start_time

    if not utils.is_dag(W_est):
        print('Error: Graph found is not DAG')
        return None

    best_W, best_Wbi, best_A, shd, f1 = calculate_metrics(X, Y, W_true, B_true, A_true, W_est, A_est, W_bi_true, Wbi)

    if plots:
        for lag, A_est_lag in enumerate(A_est):
            plot_heatmap(A_est_lag, inter_nodes, intra_nodes, filename=f'W_lag{lag}_est_heatmap.png')
        plot_heatmap(W_est, intra_nodes, intra_nodes, filename='W_est_heatmap.png')

    return solving_duration, shd, f1


def run_cds_experiment():
    """Run the exdbn solver once on the 'cds' problem with heatmap plotting enabled.

    Prints and returns the result of ``start_experiment``: ``(solving_duration, shd, f1)``,
    or ``None`` if the solver failed or produced a non-DAG.
    """
    # Pass `plots` by keyword: positionally, the 13th argument of start_experiment is
    # `time_limit`, so a trailing `True` would silently become a 1-second time limit.
    # d=None is fine because start_experiment takes the size from the loaded data.
    # NOTE(review): noise_scale=True looks like a leftover flag — confirm what
    # load_problem expects for the 'cds' dataset.
    result = start_experiment('cds', 'exdbn', seed=0, problem_name='cds', d=None, n=1000, p=2,
                              intra_edge_ratio=None, w_min_inter=None, w_max_inter=None,
                              noise_scale=True, noise_scale_variance=None, plots=True)
    print(result)
    return result

def run_convergence_experiment():
    """Measure how exdbn's SHD and F1 change as the solver time limit grows.

    For each dataset and parameter tuple, exdbn is run on d=25, n=250 'dynamic' problems
    for seeds 0-9 at every time limit in ``time_limits``. Runs whose ground truth cannot be
    generated, or for which the solver returns no usable graph, are skipped.

    Returns:
        A DataFrame with one row per successful run (also printed). Its 'Run time'
        column holds the configured time limit, not the measured duration.
    """
    rows = []
    time_limits = [60, 120, 300, 600, 1200, 1800, 2700, 3600, 5400, 7200]
    # Each tuple is (p, intra_edge_ratio, w_min_inter, w_max_inter, noise_scale, noise_scale_variance).
    datasets = [('sf', [(1, 3, 0.2, 0.4, 1.0, None)]), ('er', [(2, 2, 0.3, 0.5, 1.0, None)])]
    for dataset, dataset_params in datasets:
        for params in dataset_params:
            for time_limit in time_limits:
                for seed in range(10):
                    try:
                        result = start_experiment('dynamic', 'exdbn', seed, dataset, 25, 250, *params, time_limit=time_limit)
                    except ExDagDataException:
                        # For some combinations of edge ratio and number of variables the
                        # underlying ground-truth graph does not exist; skip them.
                        continue
                    if result is None:
                        # No solution / non-DAG estimate: nothing to score, and unpacking would crash.
                        continue
                    _duration, shd, f1 = result
                    rows.append((dataset, time_limit, seed, shd, f1))

    # Build the table once, after all datasets, instead of re-printing partial results.
    columns = ['Dataset', 'Run time', 'seed', 'shd', 'f1']
    df = pd.DataFrame(rows, columns=columns)
    print(df)
    return df

def run_all_experiments():
    """Benchmark every solver over a grid of generated dynamic problems.

    Sweeps algorithm x dataset x parameter tuple x d x n x seed (0-9), running each
    configuration with a 120 time limit (only honoured by exdbn). Configurations whose
    ground truth cannot be generated, or for which the solver returns no usable graph,
    are skipped.

    Returns:
        A DataFrame with one row per successful run (also printed).
    """
    rows = []
    algs = ['exdbn', 'dynotears', 'lingam', 'nts_notears']
    # Each tuple is (p, intra_edge_ratio, w_min_inter, w_max_inter, noise_scale, noise_scale_variance).
    datasets = [('er', [(1, 3, 0.2, 0.4, 0.8, 0.4), (1, 3, 0.2, 0.4, 1.0, None), (2, 2, 0.3, 0.5, 1.0, None)]), ('sf', [(1, 3, 0.2, 0.4, 1.0, None)])]
    for alg in algs:
        for dataset, dataset_params in datasets:
            for params in dataset_params:
                for d in [5, 7, 10, 15, 20, 25]:
                    for n in [50, 100, 250, 500, 1000]:
                        for seed in range(10):
                            print(alg, seed, dataset, d, n, *params)
                            try:
                                result = start_experiment('dynamic', alg, seed, dataset, d, n, *params, time_limit=120)
                            except ExDagDataException:
                                # For some combinations of edge ratio and number of variables the
                                # underlying ground-truth graph does not exist; skip them.
                                continue
                            if result is None:
                                # No solution / non-DAG estimate: nothing to score, and unpacking would crash.
                                continue
                            duration, shd, f1 = result
                            rows.append((alg, seed, dataset, d, n, *params, shd, f1, duration))

    columns = ['Algorithm', 'seed', 'dataset', 'number_of_variables', 'number_of_samples', 'p', 'intra_edge_ratio', 'w_min_inter', 'w_max_inter', "noise_scale", "noise_scale_variance", 'shd', 'f1', 'Run time']

    df = pd.DataFrame(rows, columns=columns)
    print(df)
    return df

if __name__ == "__main__":
    # Script entry point: the full benchmark sweep first, then the exdbn
    # time-limit convergence study, and finally the single 'cds' run.
    for run_experiment in (run_all_experiments, run_convergence_experiment, run_cds_experiment):
        run_experiment()



