from SCM import CausalGraph as cg
from SCM.Mappers import *

import pandas as pd

def drift_node(graph: "cg.CausalGraph", drifted_node, new_target, *, n_samples=100):
    """Simulate a concept drift in the specified node of the causal graph.

    The node's current mapper is archived, ``n_samples`` parent-value vectors
    are sampled by evaluating the graph in topological order up to (and
    including) ``drifted_node``, each vector is labelled through
    ``drifted_node.mapper.generate_untrained_example``, and the collected set
    is passed to ``drifted_node.mapper.drift`` together with ``new_target``.

    Side effects:
        * A deep copy of the pre-drift mapper is appended to
          ``drifted_node.drift_history``.
        * A snapshot ``{"nodes": [...], "mappers": {...}}`` is appended to
          ``graph.concept_history`` (the list is created if missing).
        * ``value`` of every vertex is overwritten while sampling; vertices
          that come after ``drifted_node`` in the topological order are left
          at ``None``.

    Args:
        graph: Graph exposing ``topological_sort()`` and a ``vertices``
            mapping of name -> vertex.
        drifted_node: Vertex whose mapper should drift.
        new_target: Forwarded unchanged as ``new_label_func`` to
            ``mapper.drift``; ``None`` is passed through as-is.
        n_samples: Number of training examples to sample. Defaults to 100,
            the value that used to be hard-coded.

    Raises:
        ValueError: If ``n_samples`` is smaller than 1.
    """
    # Imported locally so the function does not depend on these names
    # leaking out of a wildcard import elsewhere in the module.
    import copy

    import numpy as np

    # Validate before touching any state so a bad call leaves no trace.
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples!r}")

    # Archive the pre-drift mapper so the old concept can be recovered later.
    drifted_node.drift_history.append(copy.deepcopy(drifted_node.mapper))

    sorted_vertices = graph.topological_sort()

    features = []
    labels = []
    for _ in range(n_samples):
        # Start every sample from a clean slate.
        for v in graph.vertices.values():
            v.value = None

        for v_name in sorted_vertices:
            vtx = graph.vertices[v_name]

            if vtx.is_root():
                vtx.value = float(vtx.mapper.map(None))
            else:
                vtx.compute_value()

            # Nothing past the drifted node is needed for its parents' values.
            if vtx.name == drifted_node.name:
                break

        parent_values = np.array([p.value for p in drifted_node.parents])
        label = drifted_node.mapper.generate_untrained_example(X=parent_values)

        features.append(parent_values)
        labels.append(label)

    if not hasattr(graph, 'concept_history'):
        graph.concept_history = []

    # The snapshot holds its own copy of the archived mapper, independent of
    # the entry stored in drift_history.
    concept_snapshot = {
        "nodes": [drifted_node],
        "mappers": {
            node.name: copy.deepcopy(node.drift_history[-1])
            for node in [drifted_node]
        },
    }
    graph.concept_history.append(concept_snapshot)

    drifted_node.mapper.drift(
        np.array(features),
        np.array(labels),
        new_label_func=new_target,
    )

# Settings passed to the feature mappers below (ewma_alpha / rho keywords).
alpha = 0
rho = 0.1

# Feature vertices x1..x5. Construction order is kept fixed on purpose:
# mapper constructors may consume random state, so reordering could change
# the generated data.
node1 = cg.Vertex(
    "x1",
    mapper=NormalMapper(ewma_alpha=alpha, rho=rho),
)
node2 = cg.Vertex(
    "x2",
    mapper=UniformMapper(ewma_alpha=alpha, rho=rho),
)
node3 = cg.Vertex(
    "x3",
    mapper=TreeMapper(rho=rho),
)
node3.mapper.label_function = ThresholdFunction()
node4 = cg.Vertex(
    "x4",
    mapper=RandomMLPMapper(rho=rho),
)
node5 = cg.Vertex(
    "x5",
    mapper=SGDMapper(rho=rho),
)
node5.mapper.label_function = SineFunction()

# Target vertex "y". A previously tried alternative mapper (currently
# disabled) was OnlineGaussianCategoricalMapper(max_classes=10).
target_node = cg.Vertex(
    "y",
    mapper=RandomMLPMapper(),
)

graph = cg.CausalGraph()

# Edges listed in the positional order expected by add_edge -- presumably
# (parent, child); confirm against CausalGraph.add_edge.
for edge in (
    (node1, node3),
    (node2, node3),
    (node3, node4),
    (node3, node5),
    (node1, target_node),
    (node4, target_node),
    (node5, target_node),
):
    graph.add_edge(*edge)

# Every vertex is additionally registered explicitly after the edges.
for vertex in (node1, node2, node3, node4, node5, target_node):
    graph.add_vertex(vertex)

graph.visualize_graph()

# Empty drift schedule shared by the segments that rely only on the explicit
# drift calls made between them.
drift_points = []
drift_sizes = []
drift_types = []
drift_types_time = []


def _generate_segment(points, sizes, kinds, timings):
    """Generate one segment (dataset_size=2000) from ``graph`` as a DataFrame.

    The four arguments are forwarded as ``drift_points``, ``drift_sizes``,
    ``drift_types`` and ``drift_types_time``; missing and intervention
    probabilities are fixed at 0.
    """
    generated = graph.generate(
        dataset_size=2000,
        drift_points=points,
        drift_sizes=sizes,
        drift_types=kinds,
        drift_types_time=timings,
        missing_prob=0,
        intervention_prob=0,
    )
    return pd.DataFrame(generated)


# Segment 1: initial concept.
data1 = _generate_segment(drift_points, drift_sizes, drift_types, drift_types_time)

# Segment 2: x2's mapper drifts, then x4 and the target are re-drifted.
node2.mapper.drift()
drift_node(graph, node4, new_target=None)
drift_node(graph, target_node, new_target=None)
data2 = _generate_segment(drift_points, drift_sizes, drift_types, drift_types_time)

# Segment 3: x5 drifts towards a SineFunction label, the target follows, and
# an abrupt virtual drift is scheduled at position 0.
drift_node(graph, node5, new_target=SineFunction())
drift_node(graph, target_node, new_target=None)
data3 = _generate_segment([0], [1], ['virtual'], ['abrupt'])

# Segment 4: same schedule again with fresh argument lists; a call to
# target_node.mapper.severe_drift() used to precede it and is disabled.
data4 = _generate_segment([0], [1], ['virtual'], ['abrupt'])

# Segment 5: x3 drifts towards a LinearFunction label, the target follows.
drift_node(graph, node3, new_target=LinearFunction())
drift_node(graph, target_node, new_target=None)
data5 = _generate_segment(drift_points, drift_sizes, drift_types, drift_types_time)

segments = [data1, data2, data3, data4, data5]
data = pd.concat(segments)

data.to_csv("test.csv", index=False)