import time

import pandas as pd

from dagsolver_utils import plot_heatmap, ExDagDataException
from solve_causal_learn import solve_fci
from solvers.IP4AncADMG.solver_ip4ancadmg import solve_ip4ancadmg
from solvers.IPBMandHforCGL.solver_ipbm import solve_ipbm

from data_generation_loading_utils import load_problem
from metrics_utils import calculate_metrics, calculate_metrics_pag
import utils

import solve_exmag

def start_experiment(solver_name, seed, problem_name, d, n, edge_ratio, pdir, pbidir, max_in_arrows, plots=False):
    """Load one problem instance, run one solver on it and score the estimate.

    Args:
        solver_name: One of ``'exmag'``, ``'ip4ancadmg'``, ``'ipbm'`` or ``'fci'``.
        seed: Passed to ``utils.set_random_seed`` before the problem is loaded.
        problem_name, d, n, edge_ratio, pdir, pbidir, max_in_arrows: Forwarded
            unchanged to ``load_problem``.
        plots: If True, save heatmaps of the estimated weight matrices
            (only for solvers scored with ``calculate_metrics``).

    Returns:
        ``(solving_duration, shd, f1)`` on success.
        ``(0, -1, -1)`` if ``load_problem`` raises ``ExDagDataException``
        (no ground-truth graph for this parameter combination).
        ``None`` if ExMAG finds no solution or the estimated graph is not a DAG.

    Raises:
        ValueError: If ``solver_name`` is not recognised.
    """
    utils.set_random_seed(seed)
    try:
        (W_true, W_bi_true, B_true, B_bi_true, W_lags_true, B_lags_true,
         X, Y, tabu_edges, intra_nodes, inter_nodes) = load_problem(
            problem_name, d, n, edge_ratio, pdir, pbidir, max_in_arrows)
    except ExDagDataException:
        # Cannot generate a ground-truth graph for such a combination of parameters.
        return 0, -1, -1

    start_time = time.time()
    W_est, W_bi_est = None, None
    B_est, B_bi_est = None, None
    A_est = []  # no lagged-weight estimates from any solver here
    if solver_name == "exmag":
        W_est, W_bi_est, _gap, _lazy_count, _stats = solve_exmag.solve(
            X, 1.0, 'l2', 'l1', 0.001, tabu_edges=tabu_edges,
            B_ref=B_true, mode='all_cycles', time_limit=900,
            robust=False, weights_bound=100.0)
        # Must be checked before deriving adjacency matrices:
        # `(None != 0)` is a plain bool and has no `.astype`.
        if W_est is None:
            print('Error: Gurobi has not found a solution')
            return None
        B_est = (W_est != 0).astype(int)
        B_bi_est = (W_bi_est != 0).astype(int)
    elif solver_name in ('ip4ancadmg', 'ipbm'):
        solver = solve_ip4ancadmg if solver_name == 'ip4ancadmg' else solve_ipbm
        B_est, B_bi_est = solver(X)
        # These solvers return adjacency matrices only; use them as the weights.
        W_est, W_bi_est = B_est, B_bi_est
    elif solver_name == 'fci':
        # W_est stays None, so the result is scored as a PAG below.
        B_est, B_bi_est = solve_fci(X)
    else:
        raise ValueError(f'Unknown solver_name: {solver_name!r}')
    solving_duration = time.time() - start_time

    # B and B_bi
    # B[i,j] = 1 :::: i --> j :::: direct cause
    # B[i,j] = 2 ::::  i o-> j :::: direct cause or common confounder
    # B_bi[i,j] = 1 :::: i <-> j :::: common confounder
    # B_bi[i,j] = 2 :::: i --- j :::: selection bias - common effect
    # B_bi[i,j] = 3 :::: i o-o j :::: direct cause either direction or common confounder or common effect

    if W_est is None:
        assert B_est is not None
        shd, f1 = calculate_metrics_pag(B_true, B_bi_true, B_est, B_bi_est)
        return solving_duration, shd, f1

    if not utils.is_dag(W_est):
        print('Error: Graph found is not DAG')
        return None
    if B_est is not None:
        # Weighted scoring does not support PAG marks: reject any of them.
        has_pag_marks = (B_est == 2).any() or (B_bi_est == 2).any() or (B_bi_est == 3).any()
        assert not has_pag_marks, 'PAG with weights - Not correctly implemented yet'
    _best_W, _best_Wbi, _best_A, shd, f1 = calculate_metrics(
        X, Y, W_true, B_true, W_lags_true, B_lags_true, W_est, A_est,
        W_bi_true, B_bi_true, W_bi_est)

    if plots:
        plot_heatmap(W_bi_est, intra_nodes, intra_nodes, filename='W_bi_est_heatmap.png')
        plot_heatmap(W_est, intra_nodes, intra_nodes, filename='W_est_heatmap.png')

    return solving_duration, shd, f1


def run_cds_experiment():
    """Run the ExMAG solver once on the 'cds' problem with heatmap plotting on.

    Uses seed 0 and n=1000; the remaining generator parameters are left as None.
    The result of ``start_experiment`` is discarded.
    """
    start_experiment(
        solver_name='exmag',
        seed=0,
        problem_name='cds',
        d=None,
        n=1000,
        edge_ratio=None,
        pdir=None,
        pbidir=None,
        max_in_arrows=None,
        plots=True,
    )

def run_all_experiments(*, algs=None, datasets=None, ds=(5, 10, 15, 20),
                        ns=(50, 100, 500, 1000), seeds=range(10), experiment_fn=None):
    """Run every solver over the benchmark grid and collect the scores in a DataFrame.

    The grid is iterated in the order algorithm, dataset, dataset-parameter tuple,
    number of variables, number of samples, seed (seed varies fastest).

    Keyword Args:
        algs: Solver names; defaults to ``['exmag', 'ip4ancadmg', 'ipbm', 'fci']``.
        datasets: List of ``(problem_name, [(edge_ratio, pdir, pbidir, max_in_arrows), ...])``;
            defaults to the ERMAG and bow-free ADMG settings below.
        ds: Numbers of variables to try.
        ns: Numbers of samples to try.
        seeds: Random seeds to try.
        experiment_fn: Callable with ``start_experiment``'s positional signature;
            defaults to ``start_experiment``.

    Returns:
        pandas.DataFrame with one row per run. Runs for which the experiment
        returned ``None`` (solver failure) have ``None``/NaN in SHD, F1 and Run time.
    """
    import itertools

    if algs is None:
        algs = ['exmag', 'ip4ancadmg', 'ipbm', 'fci']
    if datasets is None:
        datasets = [
            ('ermag', [(2, None, None, None), (3, None, None, None),
                       (5, None, None, None), (7, None, None, None)]),
            ('bowfree_admg', [(None, 0.4, 0.3, 3), (None, 0.2, 0.15, 3),
                              (None, 0.4, 0.3, None), (None, 0.2, 0.15, None),
                              (None, 0.2, 0.3, None)]),
        ]
    if experiment_fn is None:
        experiment_fn = start_experiment

    # Flatten (dataset, [params...]) into (dataset, params) pairs so the whole
    # grid becomes a single product while keeping the original loop order.
    dataset_grid = [(dataset, params)
                    for dataset, dataset_params in datasets
                    for params in dataset_params]

    rows = []
    for alg, (dataset, params), d, n, seed in itertools.product(algs, dataset_grid, ds, ns, seeds):
        print(alg, seed, dataset, d, n, *params)
        result = experiment_fn(alg, seed, dataset, d, n, *params)
        # start_experiment returns None on solver failure; keep the row with empty scores
        # instead of crashing on tuple unpacking.
        duration, shd, f1 = result if result is not None else (None, None, None)
        rows.append((alg, seed, dataset, d, n, *params, shd, f1, duration))

    columns = ['Algorithm', 'seed', 'dataset', 'number_of_variables', 'number_of_samples',
               'edge_ratio', 'pdir', 'pbidir', 'max_in_arrows', 'SHD', 'F1', 'Run time']

    df = pd.DataFrame(rows, columns=columns)
    print(df)
    return df

if __name__ == "__main__":
    # Full benchmark grid first, then the single ExMAG run on the 'cds' problem.
    for _entry_point in (run_all_experiments, run_cds_experiment):
        _entry_point()



