import random

import numpy as np
import torch

from scamuv.causal_discovery.scamuv import SCAMUV
from scamuv.data.generate_data import get_confounded_datasets

if __name__ == '__main__':
    # Seed every random source the pipeline may draw from (Python's `random`,
    # NumPy's legacy global RNG and PyTorch's CPU generator) so repeated runs
    # produce the same synthetic datasets and the same fit results.
    SEED = 42
    for seed_rng in (np.random.seed, random.seed, torch.manual_seed):
        seed_rng(SEED)

    # NOTE(review): the positional arguments (5, 5, 1, 500) are forwarded
    # unchanged to get_confounded_datasets; their meaning (presumably counts
    # of datasets / variables / confounders and a sample size) is defined in
    # scamuv.data.generate_data -- TODO confirm and consider naming them.
    datasets = get_confounded_datasets(5, 5, 1, 500)
    for idx, (data, ground_truth) in enumerate(datasets):
        print(f"Test dataset Nr. {idx}")
        # alpha=0.01 is presumably the significance level of SCAMUV's
        # independence tests -- confirm against the SCAMUV constructor.
        # NOTE(review): g_hat and ground_truth are currently unused; the
        # estimate is never compared against the true graph.
        g_hat = SCAMUV(alpha=0.01).fit(data)
