import os
import sys
import unittest

import numpy as np
import pandas as pd
from flaky import flaky

# Put the directory two levels above this file at the front of the import path
# so the project-local packages imported below can be found (presumably the
# repository root -- confirm against the project layout).
sys.path.insert(
    0,
    os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)),
)

from benchmark.data.generator import CAMPolyMechanism
from causal_discovery.scamuv import SCAMUV

# Arguments handed to every SCAMUV instance built by the tests below.
ALPHA = 0.1  # first positional argument (presumably a significance level -- confirm)
REGRESSION = 'xgboost'  # second positional argument: name of the regression backend
CV = 2  # passed as ``cv=`` (presumably the number of cross-validation folds -- confirm)

# Size of the randomised experiments.
N_SAMPLES = 1_000  # rows drawn per synthetic data set
NUM_RUNS = 20  # independent data sets evaluated per test


class IsLeafTest(unittest.TestCase):
    """Statistical checks of ``SCAMUV.get_unconfounded_leaf`` on tiny synthetic graphs.

    Every test repeatedly (``NUM_RUNS`` times) draws ``N_SAMPLES`` rows from a
    fixed causal structure, asks a fresh ``SCAMUV`` instance for the
    unconfounded leaf and passes when the answer is wrong in fewer than half of
    the runs. The data is random and unseeded, so individual outcomes vary
    between executions.
    """

    @staticmethod
    def _uniform_column():
        """Return an ``(N_SAMPLES, 1)`` column of uniform noise for a root variable."""
        return np.random.uniform(size=(N_SAMPLES, 1))

    @staticmethod
    def _child_column(causes):
        """Return an ``(N_SAMPLES, 1)`` column caused by the columns of ``causes``.

        A fresh ``CAMPolyMechanism`` (constructed with the number of cause
        columns) is applied to ``causes`` and uniform noise is added. The
        mechanism is evaluated before the noise is drawn.
        """
        effect = CAMPolyMechanism(causes.shape[1])(causes)
        return np.expand_dims(effect + np.random.uniform(size=N_SAMPLES), axis=1)

    def _error_rate(self, sample, node_names, expected_leaf):
        """Return the fraction of ``NUM_RUNS`` runs with a wrong answer.

        :param sample: zero-argument callable returning the list of observed
            ``(N_SAMPLES, 1)`` columns for one run, ordered like ``node_names``.
        :param node_names: names of the observed columns.
        :param expected_leaf: name the algorithm should return, or ``None``
            when it is expected to report that there is no unconfounded leaf.
        :return: mean of the per-run "incorrect" flags (also printed).
        """
        incorrect = []
        for _ in range(NUM_RUNS):
            # A fresh list per run, passed twice as the same object, mirrors the
            # original inline code in case the algorithm mutates its arguments.
            nodes = list(node_names)
            df = pd.DataFrame(np.concatenate(sample(), axis=1), columns=nodes)
            df = df / df.std()  # rescale every column to unit standard deviation

            algo = SCAMUV(ALPHA, REGRESSION, cv=CV)
            algo.data = df
            found = algo.get_unconfounded_leaf(nodes, nodes)
            if expected_leaf is None:
                incorrect.append(found is not None)
            else:
                incorrect.append(expected_leaf != found)
        rate = np.mean(incorrect)
        print(rate)
        return rate

    def _assert_mostly_correct(self, sample, node_names, expected_leaf):
        """Fail unless the expected answer is produced in more than half of the runs."""
        rate = self._error_rate(sample, node_names, expected_leaf)
        self.assertLess(
            rate, .5,
            msg='expected leaf {!r} for nodes {} was not returned in {:.0%} of {} runs'.format(
                expected_leaf, list(node_names), rate, NUM_RUNS),
        )

    # Collider x -> z <- y with independently drawn x and y; z is the leaf.
    @flaky(max_runs=3)
    def test_leaf_collider(self):
        def sample():
            x = self._uniform_column()
            y = self._uniform_column()
            nz = np.random.uniform(size=N_SAMPLES)
            z = np.expand_dims(CAMPolyMechanism(1)(x) + CAMPolyMechanism(1)(y) + nz, axis=1)
            return [x, y, z]

        # TODO why is this case so difficult?
        # Debugging hint: on a failing run, compare histograms of
        # ``algo._get_delta(node, nodes)`` for x, y and z.
        self._assert_mostly_correct(sample, ['x', 'y', 'z'], 'z')

    # Chain x -> y -> z; z is the leaf.
    @flaky(max_runs=3)
    def test_leaf_chain(self):
        def sample():
            x = self._uniform_column()
            y = self._child_column(x)
            z = self._child_column(y)
            return [x, y, z]

        self._assert_mostly_correct(sample, ['x', 'y', 'z'], 'z')

    # x -> y and (x, y) -> z with every variable observed; z is the leaf.
    @flaky(max_runs=3)
    def test_leaf_confounders(self):
        def sample():
            x = self._uniform_column()
            y = self._child_column(x)
            z = self._child_column(np.concatenate([x, y], axis=1))
            return [x, y, z]

        self._assert_mostly_correct(sample, ['x', 'y', 'z'], 'z')

    # Same structure as test_leaf_confounders, but x is generated and then
    # withheld from the data, so no leaf should be reported (None expected).
    # NOTE(review): unlike the other tests this one is not marked @flaky even
    # though it is equally random -- confirm whether that is intentional.
    def test_non_leaf_hidden_confounders(self):
        def sample():
            x = self._uniform_column()
            y = self._child_column(x)
            z = self._child_column(np.concatenate([x, y], axis=1))
            return [y, z]

        self._assert_mostly_correct(sample, ['y', 'z'], None)

    # def test_non_leaf_hidden_mediator(self):
    #    incorrect = []
    #    for _ in range(NUM_RUNS):
    #        x = np.random.normal(size=(N_SAMPLES, 1))
    #        y = np.expand_dims(CAMPolyMechanism(1)(x) + np.random.normal(size=N_SAMPLES), axis=1)
    #        z = np.expand_dims(CAMPolyMechanism(1)(y) + np.random.normal(size=N_SAMPLES), axis=1)
    #        data = np.concatenate([x, z], axis=1)
    #        nodes = ['x', 'z']
    #        df = pd.DataFrame(data, columns=nodes)
    #        df = df / df.std()
    #
    #        algo = SCAMUV(ALPHA)
    #        algo.data = df
    #        incorrect.append(algo.get_unconfounded_leaf(nodes, nodes) is not None)
    #    print(np.mean(incorrect))
    #    self.assertLess(np.mean(incorrect), .5)
    # TODO: test_non_leaf_hidden_mediator (commented out above) does not work yet; fix it and re-enable it.


# Allow running this module directly as a script: unittest.main() runs the
# tests defined in this file.
if __name__ == '__main__':
    unittest.main()
