import numpy as np
import torch
import pickle as pkl
import os
import networkx as nx
from networkx.algorithms import tree


"""
Number of heavy (non-hydrogen) atoms per molecule:
aspirin : 13
azobenzene : 14
benzene : 6
ethanol : 3
malonaldehyde : 5
naphthalene : 10
paracetamol : 11
salicylic : 10
toluene : 7
uracil : 8
"""
class MD17Dataset():
    """
    MD17 Dataset

    Loads ``<data_dir>/<molecule_type>_dft.npz``, reads its ``'R'`` and ``'z'``
    arrays, and exposes pairs of frames that are ``delta_frame`` steps apart.

    Processing done in the constructor:

    * Velocities are the difference between consecutive frames of ``R``; the
      last frame is dropped so positions and velocities have equal length.
    * Only atoms with ``z > 1`` are kept (presumably ``z`` holds atomic
      numbers, so this drops hydrogens -- confirm against the data files).
    * A train/val/test index split is read from
      ``<data_dir>/<molecule_type>_split.pkl``; if that file cannot be read it
      is generated with ``np.random`` and written back.
    * A fully connected directed edge list (no self loops) is built, with one
      0/1 attribute per edge flagging atom pairs closer than ``BOND_CUTOFF``
      in the first frame.

    Attributes:
        partition, molecule_type: the constructor arguments, stored as given.
        n_node: number of atoms kept after the ``z > 1`` filter.
        Z, mole_idx: float tensors holding the filtered ``z`` values.
        atom_edge: ``(n_node, n_node)`` int tensor, 1 where two distinct atoms
            are closer than ``BOND_CUTOFF`` in the first frame.
        edges: ``[rows, cols]`` lists of ints covering every ordered pair
            ``i != j`` in row-major order.
        edge_attr: ``(num_edges, 1)`` float tensor with the ``atom_edge`` value
            of each entry in ``edges``.
        x_0, v_0: positions / velocities at the sampled frames.
        x_t, v_t: positions / velocities ``delta_frame`` frames later.
    """

    # Fractions of the usable frames given to train / val / test (tentative setting).
    SPLIT_FRACTIONS = (0.1, 0.05, 0.05)
    # Number of frames excluded at each end of the trajectory when sampling indices.
    SPLIT_MARGIN = 10000
    # Distance below which two atoms are flagged as connected
    # (same units as 'R'; presumably Angstrom -- confirm).
    BOND_CUTOFF = 1.6
    # Order in which the partitions are stored inside the split tuple.
    _PARTITIONS = ('train', 'val', 'test')

    def __init__(self, partition: str, max_samples: int, delta_frame: int, data_dir: str, molecule_type: str):
        """
        Build one partition of the dataset.

        Args:
            partition: ``'train'``, ``'val'`` or ``'test'``.
            max_samples: keep at most this many indices of the partition.
            delta_frame: frame offset between the input and the target state.
            data_dir: directory holding the ``.npz`` file and the split cache.
            molecule_type: file-name prefix, e.g. ``'aspirin'``.

        Raises:
            NotImplementedError: if ``partition`` is not a known name.
        """
        full_dir = os.path.join(data_dir, molecule_type + '_dft.npz')
        split_dir = os.path.join(data_dir, molecule_type + '_split.pkl')
        self.partition = partition
        self.molecule_type = molecule_type

        # Read both arrays inside the context manager so the archive is closed again.
        with np.load(full_dir) as data:
            x = data['R']
            z = data['z']

        v = x[1:] - x[:-1]
        x = x[:-1]

        # Same count as len(x[margin:-margin]): frames left after trimming both ends.
        n_usable = max(x.shape[0] - 2 * self.SPLIT_MARGIN, 0)
        split = self._load_or_create_split(split_dir, n_usable, self.SPLIT_FRACTIONS, self.SPLIT_MARGIN)

        if partition not in self._PARTITIONS:
            raise NotImplementedError('Unknown partition: {!r}'.format(partition))
        st = split[self._PARTITIONS.index(partition)]
        st = st[:max_samples]

        print('mol idx:', z)
        keep = z > 1
        x = x[:, keep, ...]
        v = v[:, keep, ...]
        z = z[keep]

        x_0, v_0 = x[st], v[st]
        x_t, v_t = x[st + delta_frame], v[st + delta_frame]

        print('Got {:d} samples!'.format(x_0.shape[0]))

        mole_idx = z
        print('mole_idx shape', mole_idx.shape)
        n_node = mole_idx.shape[0]
        self.n_node = n_node

        self.Z = torch.Tensor(z)

        # Connectivity is derived once, from the first frame of the filtered trajectory.
        adjacency = self._close_pairs(x[0], self.BOND_CUTOFF)
        self.atom_edge = torch.from_numpy(adjacency.astype(np.int32))

        # Every ordered pair (i, j) with i != j; np.nonzero walks the mask row by row,
        # which is the same order as a nested "for i / for j" loop.
        off_diagonal = ~np.eye(n_node, dtype=bool)
        rows, cols = np.nonzero(off_diagonal)
        self.edges = [rows.tolist(), cols.tolist()]  # [2, edge], edges for equivariant message passing
        self.edge_attr = torch.from_numpy(adjacency[off_diagonal].astype(np.float32)).unsqueeze(-1)  # [edge, 1]

        self.x_0, self.v_0 = torch.Tensor(x_0), torch.Tensor(v_0)
        self.x_t, self.v_t = torch.Tensor(x_t), torch.Tensor(v_t)
        self.mole_idx = torch.Tensor(mole_idx)

    @staticmethod
    def _load_or_create_split(split_path, n_usable, fractions, margin):
        """Return the cached (train, val, test) index arrays, creating the cache if unreadable."""
        try:
            # NOTE(review): pickle.load can execute arbitrary code; only use trusted data dirs.
            with open(split_path, 'rb') as f:
                split = pkl.load(f)
        except (OSError, EOFError, pkl.UnpicklingError):
            # Missing, unreadable or truncated cache: fall through and regenerate it.
            pass
        else:
            print('Got Split!')
            return split

        available = np.ones(n_usable, dtype=bool)
        parts = []
        for fraction in fractions:
            # Draw without replacement from the indices no earlier partition has taken.
            pool = np.flatnonzero(available)
            chosen = np.random.choice(pool, size=int(fraction * n_usable), replace=False)
            available[chosen] = False
            # Shift back into the coordinates of the untrimmed trajectory.
            parts.append(chosen + margin)
        split = tuple(parts)

        with open(split_path, 'wb') as f:
            pkl.dump(split, f)
        print('Generate and save split!')
        return split

    @staticmethod
    def _close_pairs(positions, cutoff):
        """Return a boolean (n, n) matrix marking distinct atoms closer than ``cutoff``."""
        flat = np.asarray(positions).reshape(len(positions), -1)
        diff = flat[:, None, :] - flat[None, :, :]
        dist = np.sqrt(np.sum(diff ** 2, axis=-1))
        close = dist < cutoff
        np.fill_diagonal(close, False)  # an atom is never connected to itself
        return close

    def __getitem__(self, i):
        """Return ``(x_0[i], v_0[i], edge_attr, x_t[i])``; ``edge_attr`` is shared by all samples."""
        return self.x_0[i], self.v_0[i], self.edge_attr, self.x_t[i]

    def __len__(self) -> int:
        """Return the number of samples in this partition."""
        return len(self.x_0)

    def get_edges(self, batch_size, n_nodes):
        """
        Return the edge index as ``[rows, cols]`` LongTensors for a batch.

        For ``batch_size > 1`` the per-molecule edge list is repeated
        ``batch_size`` times, with the node indices of the b-th copy shifted by
        ``n_nodes * b``. For any other ``batch_size`` the unshifted edge list of
        a single molecule is returned.
        """
        rows = torch.LongTensor(self.edges[0])
        cols = torch.LongTensor(self.edges[1])
        if batch_size > 1:
            rows = torch.cat([rows + n_nodes * b for b in range(batch_size)])
            cols = torch.cat([cols + n_nodes * b for b in range(batch_size)])
        return [rows, cols]

if __name__ == '__main__':
    # Smoke run: build the aspirin training partition from the local 'dataset' directory.
    data = MD17Dataset(
        partition='train',
        max_samples=3000,
        delta_frame=1000,
        data_dir='dataset',
        molecule_type='aspirin',
    )
    # Disabled example of iterating the partition in batches:
    # print("data length:", len(data))
    # loader_train = torch.utils.data.DataLoader(data, batch_size=100, shuffle=True, drop_last=True)
    # for batch_idx, batch in enumerate(loader_train):
    #     batch_size, num_agents, _ = batch[0].size()
    #
    #     loc, vel, edge_attr, loc_end = batch  # __getitem__ yields four items
    #     print("loc shape:", loc.shape)
    #     print("vel shape:", vel.shape)
    #     print("edge shape:", edge_attr.shape)
    #     edges = loader_train.dataset.get_edges(batch_size, num_agents)
    #     print(edges)
        
