import os
import os.path as osp
import random
from tqdm import tqdm
import numpy as np
import networkx as nx
import torch
import copy
data_dir = './data/CRCG-MOTIF/raw/'  # output directory for the generated train/val/test .npy splits


from molecular import generate_false_cause_dataset2, molecular_generators, feature_connection, reindex_graph
from BA3_loc import find_gd

def ensure_connected(G):
    """Return a connected version of graph ``G``.

    A connected ``G`` is handed back as-is (same object, no copy). Otherwise a
    new, independent graph induced on the largest connected component's node
    set is returned, and ``G`` itself is left untouched.
    """
    if not nx.is_connected(G):
        biggest_component = max(nx.connected_components(G), key=len)
        # subgraph() is a view onto G; copy() detaches it so callers may mutate it.
        return G.subgraph(biggest_component).copy()
    return G


def generateBaseshape(index, node_feature_mean=None, std=None):
    """Generate one motif graph with per-node features.

    ``molecular_generators[index][0]`` is called with a random size in [5, 10]
    and the feature mean/std. The largest connected component is kept, nodes
    without a ``'feature'`` attribute are dropped, and the result is relabelled
    with ``reindex_graph``.

    Args:
        index: key into ``molecular_generators`` selecting the motif generator.
        node_feature_mean: per-dimension feature means forwarded to the
            generator; defaults to ``[1.5, 2.0, 1.2, 1.3, 1.8]``.
        std: per-dimension feature standard deviations forwarded to the
            generator; defaults to ``[1.5, 2.0, 1.2, 1.3, 1.8]``.

    Returns:
        ``(motif, role_id, edge_index, node_features)``:
        ``role_id`` has exactly one entry per node: 0 everywhere except the
        last ``min(6, n)`` nodes, which get a random role in {1, 2, 3}.
        ``edge_index`` is a ``(2, E)`` long tensor of the motif's edges, or a
        single self-loop on a random node if the motif has no edges.
        ``node_features`` lists ``feature.tolist()`` for each node carrying one.

    Raises:
        ValueError: if no node of the generated motif carries a feature.
    """
    # Fresh lists per call, so a generator that mutates its inputs cannot
    # corrupt a shared default.
    if node_feature_mean is None:
        node_feature_mean = [1.5, 2.0, 1.2, 1.3, 1.8]
    if std is None:
        std = [1.5, 2.0, 1.2, 1.3, 1.8]

    # The generator's own role ids are discarded; they are recomputed below.
    motif, _ = molecular_generators[index][0](random.randint(5, 10), node_feature_mean, std)
    motif = ensure_connected(motif)

    nodes_to_remove = [node for node in motif.nodes if 'feature' not in motif.nodes[node]]
    motif.remove_nodes_from(nodes_to_remove)
    motif, _ = reindex_graph(motif)

    num_nodes = motif.number_of_nodes()
    num_marked = min(6, num_nodes)
    role_id = [0] * num_nodes
    # Slice from an explicit start: `role_id[-6:] = six_items` would *grow*
    # the list when there are fewer than 6 nodes (and `[-0:]` means "all").
    role_id[num_nodes - num_marked:] = torch.randint(1, 4, (num_marked,)).tolist()

    if motif.number_of_edges() == 0:
        # Check the graph, not the tensor: torch.tensor([]).t() is 1-D, so the
        # previous `edge_index.size(1)` test raised before reaching this fallback.
        if num_nodes == 0:
            raise ValueError(f"generateBaseshape: motif from generator {index} has no featured nodes")
        default_value = random.randint(0, num_nodes - 1)
        edge_index = torch.tensor([[default_value], [default_value]], dtype=torch.long)
    else:
        edge_index = torch.tensor(list(motif.edges()), dtype=torch.long).t().contiguous()

    node_features = [motif.nodes[node]['feature'].tolist() for node in motif.nodes if 'feature' in motif.nodes[node]]
    return motif, role_id, edge_index, node_features


def addOtherShape(indexs, G, num, feature_value=None):
    """Attach ``num`` distractor motifs to a deep copy of ``G``.

    Each distractor uses a generator index drawn uniformly from 1..26 minus
    ``indexs``. It is built with ``generateBaseshape`` and joined to the
    growing graph with ``feature_connection``.

    Args:
        indexs: generator indices reserved for base motifs; never drawn here.
        G: base graph. It is deep-copied and not modified.
        num: number of distractor motifs to attach; must be >= 1.
        feature_value: value used for every entry of each distractor's feature
            mean and std. When ``None`` (default), the module-level global
            ``index`` is read. This is what the function has always done
            implicitly; pass the value explicitly to avoid that hidden coupling.

    Returns:
        ``(G_other, role_id, edge_index, additional_nodes_count, node_features)``.
        ``role_id``, ``edge_index`` and ``node_features`` are whatever the last
        ``feature_connection`` call returned. ``additional_nodes_count`` is the
        total node count of the attached motifs.

    Raises:
        ValueError: if ``num < 1``, because there would be no
            ``feature_connection`` result to return.
    """
    if num < 1:
        raise ValueError(f"addOtherShape: num must be >= 1, got {num}")

    if feature_value is None:
        # NOTE(review): implicit read of the module-global `index` set by the
        # generation loop, so distractor features track the base motif's
        # index. Confirm this (rather than the distractor's own index) is intended.
        feature_value = index

    # Loop-invariant: the pool of non-reserved generator indices.
    candidates = [i for i in range(1, 27) if i not in indexs]

    G_other = copy.deepcopy(G)
    additional_nodes_count = 0
    for _ in range(num):
        additional_index = random.choice(candidates)
        motif, _, _, _ = generateBaseshape(
            additional_index,
            node_feature_mean=[feature_value] * 5,
            std=[feature_value] * 5,
        )
        G_other, role_id, edge_index, node_features = feature_connection(G_other, motif)
        additional_nodes_count += motif.number_of_nodes()
    return G_other, role_id, edge_index, additional_nodes_count, node_features


def addNoise(G, G_other, num_nodes_to_add):
    """Add ``num_nodes_to_add`` random noise nodes to a deep copy of ``G_other``.

    Each noise node gets a feature drawn from
    ``N([1.5, 2.0, 1.2, 1.3, 1.8], [1.5, 2.0, 1.2, 1.3, 1.8])``. It is then
    linked to those of up to 3 randomly sampled pre-existing nodes that are
    not nodes of ``G``. Sampling is retried until the graph is connected, so
    noise never attaches directly to ``G``'s nodes.

    Args:
        G: base motif graph. Presumably its node labels are preserved inside
            ``G_other`` (TODO confirm against ``feature_connection``).
        G_other: graph to perturb. It is deep-copied and not modified.
        num_nodes_to_add: number of noise nodes to add.

    Returns:
        ``(G_noisy, role_id, edge_index, node_features)``.
        ``role_id`` has one entry per node: 0 except the last ``min(6, n)``
        nodes, which get a random role in {1, 2, 3}.
        ``edge_index`` is a ``(2, E)`` long tensor.
        ``node_features`` lists ``feature.tolist()`` for each node carrying one.

    Raises:
        ValueError: if noise is requested but ``G_other`` has no node outside
            ``G`` to attach to. Previously this case looped forever.
    """
    G_noisy = copy.deepcopy(G_other)

    if num_nodes_to_add > 0 and all(G.has_node(n) for n in G_noisy.nodes):
        raise ValueError("addNoise: G_other has no node outside G to attach noise nodes to")

    for _ in range(num_nodes_to_add):
        # Snapshot the pool *before* adding the new node so it can never be
        # sampled as its own neighbour (which used to create self-loops).
        # random.sample requires a real sequence on Python 3.11+.
        population = list(G_noisy.nodes)
        sample_size = min(3, len(population))

        node_id = G_noisy.number_of_nodes()
        while G_noisy.has_node(node_id):  # only differs if labels are not 0..n-1
            node_id += 1
        G_noisy.add_node(node_id)
        G_noisy.nodes[node_id]['feature'] = np.random.normal([1.5, 2.0, 1.2, 1.3, 1.8], [1.5, 2.0, 1.2, 1.3, 1.8])

        # NOTE(review): if G_other is disconnected in a way a single noise node
        # cannot bridge, this retry loop still never terminates.
        while True:
            targets = [n for n in random.sample(population, sample_size) if not G.has_node(n)]
            G_noisy.add_edges_from((node_id, n) for n in targets)
            if nx.is_connected(G_noisy):
                break
            G_noisy.remove_edges_from((node_id, n) for n in targets)

    node_features = [G_noisy.nodes[node]['feature'].tolist() for node in G_noisy.nodes if 'feature' in G_noisy.nodes[node]]

    num_nodes = G_noisy.number_of_nodes()
    num_marked = min(6, num_nodes)
    role_id = [0] * num_nodes
    # Explicit start index: `[-6:]` would grow the list for graphs with < 6 nodes.
    role_id[num_nodes - num_marked:] = torch.randint(1, 4, (num_marked,)).tolist()

    if G_noisy.number_of_edges() == 0:
        # torch.tensor([]).t() stays 1-D; keep the (2, E) edge_index layout.
        edge_index = torch.empty((2, 0), dtype=torch.long)
    else:
        edge_index = torch.tensor(list(G_noisy.edges()), dtype=torch.long).t().contiguous()
    return G_noisy, role_id, edge_index, node_features

# Build the train / val / test splits and save each one to ``data_dir``.
#
# This loop must stay at module level: addOtherShape reads the module-global
# ``index`` assigned below.
#
# np.save stores each dict as a 0-d object array; read a split back with
# ``np.load(path, allow_pickle=True).item()``.
os.makedirs(data_dir, exist_ok=True)  # np.save does not create missing directories
indexs = [1, 2, 3, 4, 5]
splits = (('train', 500), ('val', 200), ('test', 200))
for split_name, num_graphs in splits:
    edge_index_list, label_list = [], []
    ground_truth_list, role_id_list, pos_list = [], [], []
    e_mean, n_mean = [], []
    node_features_list = []
    for _ in tqdm(range(num_graphs)):
        # Base-motif generator index; the class label and the base motif's
        # feature mean/std are all derived from it as index - 1.
        index = random.choice(indexs)
        label = index - 1

        G, _, _, _ = generateBaseshape(index, node_feature_mean=[label] * 5, std=[label] * 5)
        G_other, _, _, additional_nodes_count, _ = addOtherShape(indexs, G, 3)
        # One noise node per distractor node; only the noisy graph's outputs are kept.
        G, role_id, edge_index, node_features = addNoise(G, G_other, additional_nodes_count)

        role_id = np.array(role_id)
        node_features_list.append(node_features)
        label_list.append(label)
        e_mean.append(G.number_of_edges())
        n_mean.append(G.number_of_nodes())
        role_id_list.append(role_id)
        edge_index_list.append(edge_index)
        pos_list.append(np.array(list(nx.spring_layout(G).values())))
        ground_truth_list.append(find_gd(edge_index, role_id))

    print("#Graphs: %d    #Nodes: %.2f    #Edges: %.2f " % (len(ground_truth_list), np.mean(n_mean), np.mean(e_mean)))
    np.save(
        osp.join(data_dir, f'{split_name}.npy'),
        {
            'node_features': node_features_list,
            'edge_index': edge_index_list,
            'label': label_list,
            'ground_truth': ground_truth_list,
            'role_id': role_id_list,
            'pos': pos_list,
        },
    )
