import random
import numpy as np
import torch
import torch_geometric
from torch_geometric.data import Data, Batch
from torch_geometric.utils import is_undirected, to_undirected, remove_self_loops, to_dense_adj, dense_to_sparse
from e3nn import o3


def create_kchains(k):
    """Create a pair of k-chain graphs that differ only at one end.

    Each graph is a path of ``k + 2`` nodes:

    - a head node at ``(head_x, -3, 0)``,
    - a backbone of ``k`` nodes at ``(0, 5*i, 0)`` for ``i in range(k)``,
    - a tail node at ``(4, 5*(k-1) + 3, 0)``.

    The two graphs are identical except for the head node's x-coordinate:
    ``-4`` in the label-0 graph and ``+4`` in the label-1 graph. Every node
    has atom type ``0``, and positions are shifted so their mean lies at the
    origin.

    Args:
        k: Number of backbone nodes; must be at least 2.

    Returns:
        A list of two ``Data`` objects (label ``0`` first, then label ``1``),
        each with ``atoms`` (int64, all zeros), ``edge_index`` (path edges
        made undirected with ``to_undirected``), ``pos`` (float32,
        mean-centred) and ``y`` (int64 tensor holding the label).

    Raises:
        ValueError: If ``k < 2``.
    """
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``. Without the check, e.g. k=0 yields 3 atoms but only
    # 2 positions, i.e. an internally inconsistent graph.
    if k < 2:
        raise ValueError(f"create_kchains requires k >= 2, got k={k}")

    num_nodes = k + 2
    backbone = [[0, 5 * i, 0] for i in range(k)]
    tail = [4, 5 * (k - 1) + 3, 0]

    dataset = []
    # The only difference between the two graphs is the head node's x-coordinate.
    for label, head_x in ((0, -4), (1, 4)):
        atoms = torch.LongTensor([0] * num_nodes)
        # Consecutive-node edges 0->1, 1->2, ..., (num_nodes-2)->(num_nodes-1).
        edge_index = torch.LongTensor(
            [list(range(num_nodes - 1)), list(range(1, num_nodes))]
        )
        pos = torch.FloatTensor([[head_x, -3, 0]] + backbone + [tail])
        pos = pos - torch.mean(pos, dim=0)
        y = torch.LongTensor([label])
        data = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)
        data.edge_index = to_undirected(data.edge_index)
        dataset.append(data)

    return dataset


def create_two_body_envs():
    """Return two 3-node star environments with identical neighbour distances.

    Node 0 sits at the origin and is joined to nodes 1 and 2. In both
    environments each neighbour is 5 units from node 0. Only the angle
    between the two neighbour vectors differs:

    - environment 0 (label 0): neighbours at ``(5, 0, 0)`` and ``(3, 0, 4)``;
    - environment 1 (label 1): neighbours at ``(5, 0, 0)`` and ``(-5, 0, 0)``.

    All nodes use atom type 0. (Distinct types ``[0, 1, 2]`` are a disabled
    alternative.)

    Returns:
        A list of two ``Data`` objects with ``atoms``, ``edge_index``
        (undirected star edges), ``pos`` (float32) and ``y``.
    """
    star_edges = [[0, 0], [1, 2]]
    layouts = (
        (0, [[0, 0, 0], [5, 0, 0], [3, 0, 4]]),
        (1, [[0, 0, 0], [5, 0, 0], [-5, 0, 0]]),
    )

    envs = []
    for label, coords in layouts:
        env = Data(
            atoms=torch.zeros(3, dtype=torch.long),
            edge_index=to_undirected(torch.tensor(star_edges, dtype=torch.long)),
            pos=torch.tensor(coords, dtype=torch.float),
            y=torch.tensor([label], dtype=torch.long),
        )
        envs.append(env)
    return envs


def create_three_body_envs():
    """Return two 5-node star environments that differ in one neighbour.

    Node 0 is at the origin and is connected to nodes 1-4. The neighbours
    are:

    - ``a = (5, 0, 5)``,
    - ``b = (5, 5, 5)``,
    - ``(-5, -5, 5)``, i.e. ``b`` with x and y negated,
    - ``c = (0, ±5, 5)``.

    Environment 0 (label 0) uses ``+5`` for c's y-coordinate and environment
    1 (label 1) uses ``-5``; everything else is shared. All nodes use atom
    type 0.

    Returns:
        A list of two ``Data`` objects with ``atoms``, ``edge_index``
        (undirected star edges), ``pos`` (float32) and ``y``.
    """
    a = (5, 0, 5)
    b = (5, 5, 5)
    c = (0, 5, 5)

    def build(label, c_y):
        # Construct one environment; only c's y-coordinate varies.
        coords = [
            [0, 0, 0],
            [a[0], a[1], a[2]],
            [+b[0], +b[1], b[2]],
            [-b[0], -b[1], b[2]],
            [c[0], c_y, c[2]],
        ]
        edges = torch.tensor([[0, 0, 0, 0], [1, 2, 3, 4]], dtype=torch.long)
        return Data(
            atoms=torch.zeros(5, dtype=torch.long),
            edge_index=to_undirected(edges),
            pos=torch.tensor(coords, dtype=torch.float),
            y=torch.tensor([label], dtype=torch.long),
        )

    return [build(label, y_val) for label, y_val in ((0, +c[1]), (1, -c[1]))]


def create_four_body_nonchiral_envs():
    """Create two 8-node star environments that differ in one neighbour's y-sign.

    Node 0 sits at the origin and is connected to nodes 1-7:

    - nodes 1-3 are the fixed ``a`` points (all with y = 2);
    - nodes 4-6 are the ``b`` points (y = -2 before rotation), each
      right-multiplied by a fixed rotation matrix ``Q`` from ``o3.matrix_y``;
    - node 7 is ``c``, at ``(0, +5, 0)`` in environment 0 and ``(0, -5, 0)``
      in environment 1.

    Everything other than node 7 is shared between the two environments.
    All nodes use atom type 0.

    Returns:
        A list of two ``Data`` objects (label ``0`` first, then label ``1``)
        with ``atoms``, ``edge_index`` (undirected star edges), ``pos``
        (float32) and ``y``.
    """
    a_points = [[3, 2, -4], [0, 2, 5], [-3, 2, -4]]
    b_points = [[3, -2, -4], [0, -2, 5], [-3, -2, -4]]
    c_x, c_y, c_z = 0, 5, 0

    # Fixed, deterministic angle of 2*pi/10 rad -- NOT random.
    angle = 2 * torch.pi / 10
    Q = o3.matrix_y(torch.tensor(angle)).numpy()
    # ``list @ ndarray`` dispatches to NumPy matmul, giving a 1-D NumPy array
    # per point. It is computed once because both environments share it.
    b_rotated = [b @ Q for b in b_points]

    dataset = []
    for label, c_sign in ((0, 1), (1, -1)):
        atoms = torch.LongTensor([0] * 8)
        edge_index = torch.LongTensor([[0] * 7, list(range(1, 8))])
        pos = torch.FloatTensor(
            [[0, 0, 0]] + a_points + b_rotated + [[c_x, c_sign * c_y, c_z]]
        )
        y = torch.LongTensor([label])
        data = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)
        data.edge_index = to_undirected(data.edge_index)
        dataset.append(data)

    return dataset


def create_four_body_chiral_envs():
    """Build two 5-node star environments that are mirror images of each other.

    Node 0 is at the origin and is connected to nodes 1-4. Nodes 1-3 lie in
    the plane y = 0, at ``(3, 0, -4)``, ``(0, 0, 5)`` and ``(-3, 0, -4)``.
    Node 4 sits at ``(0, +5, 0)`` in environment 0 and ``(0, -5, 0)`` in
    environment 1. Because every other node has y = 0, environment 1 is the
    reflection of environment 0 through that plane. All nodes use atom
    type 0.

    Returns:
        A list of two ``Data`` objects (label ``0`` first, then label ``1``)
        with ``atoms``, ``edge_index`` (undirected star edges), ``pos``
        (float32) and ``y``.
    """
    planar_coords = [
        [0, 0, 0],
        [3, 0, -4],
        [0, 0, 5],
        [-3, 0, -4],
    ]
    apex = (0, 5, 0)

    def make_env(label, apex_y):
        # Append the out-of-plane node with the requested y sign.
        coords = planar_coords + [[apex[0], apex_y, apex[2]]]
        edges = torch.tensor([[0, 0, 0, 0], [1, 2, 3, 4]], dtype=torch.long)
        return Data(
            atoms=torch.zeros(5, dtype=torch.long),
            edge_index=to_undirected(edges),
            pos=torch.tensor(coords, dtype=torch.float),
            y=torch.tensor([label], dtype=torch.long),
        )

    first = make_env(0, +apex[1])
    second = make_env(1, -apex[1])
    return [first, second]
