import SCM.CausalGraph as cg
import random
import numpy as np
from SCM.Mappers import *
import pandas as pd


def get_mapper_lists():
    """Return the mapper classes used to populate a causal graph.

    Returns:
        A ``(root_mappers, edge_mappers)`` tuple. ``root_mappers`` lists the
        classes offered to vertices without parents; ``edge_mappers`` lists
        the classes offered to vertices that do have parents. Both lists are
        built fresh on every call, so callers may mutate them freely.
    """
    mappers_for_edges = [
        MLPMapping,
        TreeMapper,
        PrototypeCategoricalMapper,
        SGDMapper,
        OnlineGaussianCategoricalMapper,
        RandomRBFCategoricalMapper,
        RandomMLPMapper,
        RotatingHyperplaneMapper,
    ]
    mappers_for_roots = [NormalMapper, UniformMapper]
    return mappers_for_roots, mappers_for_edges

# Mapper classes build_connected_dag may assign to the label vertex when the
# task is anything other than 'regression'.
categorical_mappers = [
    OnlineGaussianCategoricalMapper,
    PrototypeCategoricalMapper,
    RandomRBFCategoricalMapper,
    RotatingHyperplaneMapper,
]
# Mapper classes build_connected_dag may assign to the label vertex when the
# task is 'regression'.
reg_mappers = [
    MLPMapping,
    TreeMapper,
    SGDMapper,
    RandomMLPMapper,
]

def build_connected_dag(n_features, n_roots, max_parents, task='classification'):
    """Build a random causal DAG with ``n_features`` features plus one label vertex.

    ``n_features + 1`` vertices named ``x0 .. xN`` are created; ``n_roots`` of
    them are chosen as parentless roots and one of the remaining vertices is
    renamed ``'y'`` and used as the label. The vertices are then put in a random
    order with up to two roots moved to the front, and every non-root vertex
    draws between 1 and ``max_parents`` parents from the vertices that precede
    it in that order, which keeps the graph acyclic. Roots that are not moved
    to the front are only connected if they happen to be drawn as a parent.

    Args:
        n_features: Number of feature vertices (the label vertex is extra).
        n_roots: Number of parentless vertices; must satisfy
            ``1 <= n_roots <= n_features``.
        max_parents: Upper bound on the number of parents of any vertex;
            must be at least 1.
        task: ``'regression'`` picks the label mapper from ``reg_mappers``;
            any other value picks it from ``categorical_mappers``.

    Returns:
        ``(graph, roots, label_node)``: the populated ``cg.CausalGraph``, the
        list of root vertices in the order they were sampled, and the label
        vertex.

    Raises:
        ValueError: If ``max_parents < 1``, ``n_roots < 1`` or
            ``n_roots > n_features``.
    """
    # Validate up front so bad arguments fail with an accurate message instead
    # of an unrelated error from deep inside the construction loop.
    n_features += 1  # one additional vertex is reserved for the label
    if max_parents < 1:
        raise ValueError("Max parents can't be less than 1")
    if n_roots < 1:
        raise ValueError("n_roots must be at least 1 so that non-root nodes have a candidate parent")
    if n_roots >= n_features:
        raise ValueError("n_roots must not exceed n_features: one non-root node is needed for the label")

    root_mappers, edge_mappers = get_mapper_lists()

    nodes = [cg.Vertex(f"x{i}") for i in range(n_features)]

    root_indices = random.sample(range(n_features), k=n_roots)
    root_index_set = set(root_indices)
    roots = [nodes[i] for i in root_indices]

    # The label is always a non-root vertex so that it depends on other variables.
    label_candidates = [i for i in range(n_features) if i not in root_index_set]
    label_index = random.choice(label_candidates)
    label_node = nodes[label_index]
    label_node.name = 'y'

    node_order = list(range(n_features))
    random.shuffle(node_order)

    # Move up to two roots to the front of the ordering so that every later
    # non-root vertex has at least one (usually two) candidate parents.
    if len(root_indices) >= 2:
        roots_to_place = random.sample(root_indices, 2)
    else:
        roots_to_place = list(root_indices)
    for position, root in enumerate(roots_to_place):
        idx_to_swap = node_order.index(root)
        node_order[position], node_order[idx_to_swap] = node_order[idx_to_swap], node_order[position]

    graph = cg.CausalGraph()
    for node in nodes:
        graph.add_vertex(node)

    for i, idx in enumerate(node_order):
        if idx in root_index_set:
            continue
        node = nodes[idx]

        # Parents come only from earlier positions in the order, so no cycle can form.
        candidate_parents = [nodes[j] for j in node_order[:i]]
        max_parents_this_node = min(max_parents, len(candidate_parents))

        n_parents = random.randint(1, max_parents_this_node)
        if node is label_node:
            # Prefer at least two parents for the label, but never more than
            # are allowed/available (a single root or max_parents == 1 leaves
            # only one).
            n_parents = min(max_parents_this_node, max(2, n_parents))

        for p in random.sample(candidate_parents, n_parents):
            graph.add_edge(p, node)

    for node in nodes:
        if node.is_root():
            node.mapper = random.choice(root_mappers)()
        elif node is label_node:
            # NOTE: the label mapper is drawn from NumPy's global RNG, whereas
            # everything else here uses the stdlib `random` module; seed both
            # for reproducible graphs.
            if task == 'regression':
                node.mapper = np.random.choice(reg_mappers)()
            else:
                node.mapper = np.random.choice(categorical_mappers)()
        else:
            node.mapper = random.choice(edge_mappers)()

    return graph, roots, label_node


if __name__ == '__main__':

    # Demo: sample 10,000 rows from a randomly generated SCM with 10 features plus a label
    # and write them to data_sample.csv (the graph itself goes to data_sample.json / data_sample.pkl).
    # Feel free to modify the parameters below to generate different datasets.
    # The label is built for classification by default; pass task='regression' to
    # build_connected_dag for a regression target instead.

    output = 'data_sample'
    n_features = 10
    n_samples = 10000

    graph, roots, label_node = build_connected_dag(n_features, 3, 3, task='classification')

    graph.visualize_graph(output)

    # One row per drift event. The columns feed the drift_points, drift_sizes,
    # drift_types and drift_types_time arguments of graph.generate, in that order.
    # NOTE(review): the size column presumably is the length of the drift transition
    # in samples (1 for the 'abrupt' events) - confirm against CausalGraph.generate.
    drift_schedule = [
        (1000, 500, 'real', 'incremental'),
        (2000, 1, 'real', 'abrupt'),
        (3000, 500, 'severe', 'gradual'),
        (4000, 1, 'real', 'abrupt'),
        (5000, 1, 'real', 'abrupt'),
        (6000, 250, 'severe', 'incremental'),
        (7000, 1, 'real', 'abrupt'),
        (8000, 1, 'real', 'abrupt'),
        (9000, 500, 'real', 'gradual'),
    ]
    drift_points = [event[0] for event in drift_schedule]
    drift_sizes = [event[1] for event in drift_schedule]
    drift_types = [event[2] for event in drift_schedule]
    drift_types_time = [event[3] for event in drift_schedule]

    print("Generating Data...")
    samples = graph.generate(
        n_samples,
        intervention_prob=0.1,
        drift_points=drift_points,
        drift_sizes=drift_sizes,
        drift_types_time=drift_types_time,
        drift_types=drift_types,
        missing_prob=0,
    )
    df = pd.DataFrame(samples)

    # Save the graph structure
    graph.save_graph_to_json(f'{output}.json')
    graph.save_graph(f'{output}.pkl')

    # Save the generated samples
    df.to_csv(f'{output}.csv', index=False)

    print(f"Success! Data file was saved to {output}")