import random
import numpy as np
import torch
from torch.nn import functional as F
import torch_geometric
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
from torch_geometric.utils import is_undirected, to_undirected, remove_self_loops, to_dense_adj, dense_to_sparse



def get_edges(n_nodes):
    """Return the edge list of a complete directed graph without self-loops.

    Args:
        n_nodes: Number of nodes; nodes are indexed ``0 .. n_nodes - 1``.

    Returns:
        A two-element list ``[sources, targets]``. Position ``k`` of both
        lists together describes one ordered pair ``(i, j)`` with ``i != j``.
        Pairs are ordered by source index first, then by target index.
    """
    node_ids = range(n_nodes)
    pairs = [(src, dst) for src in node_ids for dst in node_ids if src != dst]
    sources = [src for src, _ in pairs]
    targets = [dst for _, dst in pairs]
    return [sources, targets]

def create_1_wl():
    """Build a two-graph dataset of 6-point configurations.

    Each graph carries an ``atoms`` tensor of six zeros, an ``edge_index``
    containing every ordered node pair ``(i, j)`` with ``i != j``, a
    ``(6, 3)`` float ``pos`` tensor, and a single-element label ``y``.
    The two graphs differ only in the sign of the z-coordinate of their
    first two points; the first graph is labelled 0 and the second 1.

    NOTE(review): the original source annotated the coordinates as
    "laplacians of 2-regular, 6 vertices non isomorphic graphs"; that
    provenance is not verifiable from this file -- confirm.

    Returns:
        list: ``[graph_label_0, graph_label_1]`` of ``Data`` objects.
    """
    n_points = 6
    # Four points shared by both environments.
    shared_tail = [
        [2.03951216, 6.27697301, 0.0],
        [-2.03951216, -6.27697301, 0.0],
        [-3.3, 4.04133444e-16, 4.4],
        [3.3, -8.08266887e-16, -4.4],
    ]
    # Environment k is labelled k; only the z sign of the first two points flips.
    coords_by_label = (
        [[6.73555740e-17, 1.1, 2.2], [-2.02066722e-16, -1.1, -2.2]] + shared_tail,
        [[6.73555740e-17, 1.1, -2.2], [-2.02066722e-16, -1.1, 2.2]] + shared_tail,
    )

    sources, targets = get_edges(n_points)  # fully connected, no self-loops

    graphs = []
    for label, coords in enumerate(coords_by_label):
        graph = Data(
            atoms=torch.LongTensor([0] * n_points),
            edge_index=torch.LongTensor([sources, targets]),
            pos=torch.FloatTensor(coords),
            y=torch.LongTensor([label]),
        )
        # Both edge directions are already present; the call is kept as in
        # the original.
        graph.edge_index = to_undirected(graph.edge_index)
        graphs.append(graph)

    return graphs
def create_incompleteness_1_a_envs():
    """Build a two-graph dataset of 5-point configurations.

    Each graph carries an ``atoms`` tensor of five zeros, an ``edge_index``
    containing every ordered node pair ``(i, j)`` with ``i != j``, a
    ``(5, 3)`` float ``pos`` tensor, and a single-element label ``y``.
    The two graphs differ only in the z-coordinate of the third point:
    ``-1`` for the graph labelled 0 and ``+1`` for the graph labelled 1.

    Returns:
        list: ``[graph_label_0, graph_label_1]`` of ``Data`` objects.
    """
    n_points = 5
    sources, targets = get_edges(n_points)  # fully connected, no self-loops

    graphs = []
    # Label k uses the k-th z value for the third point.
    for label, third_z in enumerate((-1, 1)):
        coords = [
            [-2, 0, -2],
            [2, 0, 2],
            [0, 1, third_z],
            [1, 1, 0],
            [-1, -1, 0],
        ]
        graph = Data(
            atoms=torch.LongTensor([0] * n_points),
            edge_index=torch.LongTensor([sources, targets]),
            pos=torch.FloatTensor(coords),
            y=torch.LongTensor([label]),
        )
        # Both edge directions are already present; the call is kept as in
        # the original.
        graph.edge_index = to_undirected(graph.edge_index)
        graphs.append(graph)

    return graphs
def create_incompleteness_1_b_envs():
    """Build a two-graph dataset of 7-point configurations.

    Each graph carries an ``atoms`` tensor of seven zeros, an ``edge_index``
    containing every ordered node pair ``(i, j)`` with ``i != j``, a
    ``(7, 3)`` float ``pos`` tensor, and a single-element label ``y``.
    The first six points are shared; the graphs differ only in the last
    point, which is ``[0, 0, 1]`` for label 0 and ``[0, 0, -1]`` for label 1.

    Returns:
        list: ``[graph_label_0, graph_label_1]`` of ``Data`` objects.
    """
    n_points = 7
    # Six points common to both environments.
    shared_points = [
        [-2, 0, -2],
        [2, 0, 2],
        [1, 1, 0],
        [-1, -1, 0],
        [1, 2, 0],
        [-1, 2, 0],
    ]
    sources, targets = get_edges(n_points)  # fully connected, no self-loops

    graphs = []
    # Label k uses the k-th z value for the final point.
    for label, last_z in enumerate((1, -1)):
        coords = shared_points + [[0, 0, last_z]]
        graph = Data(
            atoms=torch.LongTensor([0] * n_points),
            edge_index=torch.LongTensor([sources, targets]),
            pos=torch.FloatTensor(coords),
            y=torch.LongTensor([label]),
        )
        # Both edge directions are already present; the call is kept as in
        # the original.
        graph.edge_index = to_undirected(graph.edge_index)
        graphs.append(graph)

    return graphs
def create_incompleteness_1_c_envs():
    """Build a two-graph dataset of 6-point configurations.

    Each graph carries an ``atoms`` tensor of six zeros, an ``edge_index``
    containing every ordered node pair ``(i, j)`` with ``i != j``, a
    ``(6, 3)`` float ``pos`` tensor, and a single-element label ``y``.
    The first five points are shared; the graphs differ only in the last
    point, which is ``[0, 0, 1]`` for label 0 and ``[0, 0, -1]`` for label 1.

    Returns:
        list: ``[graph_label_0, graph_label_1]`` of ``Data`` objects.
    """
    n_points = 6
    # Five points common to both environments.
    shared_points = [
        [-2, 0, -2],
        [2, 0, 2],
        [1.47633, 1.84294, 0],
        [-0.74309, 1.455, 0],
        [-0.72972, 2.82605, 0],
    ]
    sources, targets = get_edges(n_points)  # fully connected, no self-loops

    graphs = []
    # Label k uses the k-th z value for the final point.
    for label, last_z in enumerate((1, -1)):
        coords = shared_points + [[0, 0, last_z]]
        graph = Data(
            atoms=torch.LongTensor([0] * n_points),
            edge_index=torch.LongTensor([sources, targets]),
            pos=torch.FloatTensor(coords),
            y=torch.LongTensor([label]),
        )
        # Both edge directions are already present; the call is kept as in
        # the original.
        graph.edge_index = to_undirected(graph.edge_index)
        graphs.append(graph)

    return graphs