import os
import random
import sys
import time
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
from cdt.metrics import SHD

from benchmark.data.generate_data import get_pag_skel_with_scam_orientations

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from causal_discovery.scamuv import SCAMUV
from data.generate_data import get_confounded_datasets
from utils.algo_wrappers import CAMUV
from utils.metrics import dtop

if __name__ == '__main__':
    # Benchmark SCAM-UV against CAM-UV on the datasets produced by
    # get_confounded_datasets. For every dataset, record for each method:
    #   - D_top: topological-order error from dtop (bidirected edges ignored),
    #   - SHD:   structural Hamming distance (cdt.metrics.SHD) to the reference,
    #   - time:  fit runtime in seconds.
    # Results are written to <result_dir>/scam.csv and <result_dir>/camuv.csv.

    # Seed every RNG the algorithms may draw from so runs are reproducible.
    seed = 42
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    # NOTE(review): this path is relative to the current working directory,
    # not to this script's location; run the script from its own folder.
    result_dir = os.path.join('..', 'data')
    scam_results = defaultdict(list)
    camuv_results = defaultdict(list)

    # NOTE(review): the meaning of the positional arguments (5, 5, 1, 500) is
    # defined in data.generate_data.get_confounded_datasets; confirm there.
    for i, (data, ground_truth) in enumerate(get_confounded_datasets(5, 5, 1, 500)):
        print("Test dataset Nr. {}".format(i))

        # --- SCAM-UV ---
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_t = time.perf_counter()
        algo = SCAMUV(alpha=0.01, verbose=False)
        g_hat = algo.fit(data)
        end_t = time.perf_counter()

        order = algo.order
        # Reference graph both methods are scored against; computed outside
        # the timed regions so it does not inflate either runtime.
        marginal_ground_truth = get_pag_skel_with_scam_orientations(ground_truth, data.keys())
        scam_results['D_top'].append(dtop(marginal_ground_truth, order, ignore_bidirected=True))
        scam_results['SHD'].append(SHD(marginal_ground_truth, g_hat))
        scam_results['time'].append(end_t - start_t)

        # --- CAM-UV ---
        start_t = time.perf_counter()
        algo = CAMUV(alpha=0.01)
        g_hat = algo(data)
        end_t = time.perf_counter()

        # Placeholder: no causal order is taken from the CAM-UV wrapper here.
        # NOTE(review): 0 reads like a perfect D_top score; consider NaN.
        camuv_results['D_top'].append(0)
        camuv_results['SHD'].append(SHD(marginal_ground_truth, g_hat))
        camuv_results['time'].append(end_t - start_t)

    # Join paths properly (plain concatenation produced '../datascam.csv')
    # and make sure the output directory exists before writing.
    os.makedirs(result_dir, exist_ok=True)
    pd.DataFrame(scam_results).to_csv(os.path.join(result_dir, 'scam.csv'))
    pd.DataFrame(camuv_results).to_csv(os.path.join(result_dir, 'camuv.csv'))
