from os import device_encoding
import pickle as pkl
from torch_geometric import edge_index
from tqdm import tqdm
from joblib import Parallel, delayed

import numpy as np
import torch
from torch import LongTensor, Tensor
from torch.utils.data import Dataset
from torch_geometric.data import Data
from e3nn.o3 import matrix_x, matrix_y, matrix_z

class MD17Dataset(Dataset):
    """
    In-memory graph dataset built from a single MD17 trajectory.

    Each sample is a ``torch_geometric.data.Data`` object describing the
    heavy atoms (atomic number > 1) of the molecule at a start frame, with
    the positions ``delta_frame`` frames later stored as the label.

    Files read from ``data_dir``:
        ``{mol_type}.npz``        -- arrays ``R`` (positions) and ``z``
                                     (atomic numbers).
        ``{mol_type}_split.pkl``  -- pickled sequence whose entries 0/1/2
                                     hold the start-frame indices for the
                                     train/valid/test partitions.

    Args:
        mol_type: Molecule name, used as the file-name stem.
        data_dir: Directory containing the two files above.
        partition: One of ``'train'``, ``'valid'`` or ``'test'``.
        max_samples: Upper bound on the number of start frames used.
        delta_frame: Frame offset between the input and the label.
        device: Device on which the graphs are assembled; the finished
            graphs are moved back to the CPU before being stored.
    """
    def __init__(
        self, 
        mol_type: str = 'aspirin', 
        data_dir: str = './', 
        partition: str = 'train', 
        max_samples: int = 500,
        delta_frame: int = 3000, 
        device: str = 'cpu'
    ):
        super().__init__()
        self.device = device
        
        self.mol_type = mol_type
        self.data_dir, self.partition = data_dir, partition
        self.max_samples = max_samples
        self.delta_frame = delta_frame

        pos, vel, atom = self.load_data()
        # Connectivity is derived once, from the first frame of the
        # trajectory, and shared by every sample.
        self.edge_index, self.edge_hop = self.get_edge(Tensor(pos[0]))

        print('Processing data ...')
        # Graphs are assembled on self.device and returned on the CPU.
        self.data = self.process(pos, vel, atom)
        print(f'{partition} dataset total len: {len(self.data)}')
        # An empty partition (e.g. max_samples=0) must not crash the preview.
        if self.data:
            print(self.data[0])


    def __len__(self):
        """Return the number of stored graphs."""
        return len(self.data)

    
    def __getitem__(self, i):
        """Return the pre-built graph at index ``i``."""
        return self.data[i]
    

    def load_data(self):
        """
        Read the trajectory and the split file, keeping heavy atoms only.

        Side effects:
            Sets ``self.frame_0`` (start-frame indices of this partition,
            truncated to ``max_samples``) and ``self.frame_t``
            (``frame_0 + delta_frame``).

        Returns:
            ``(pos, vel, atom)`` as NumPy arrays. ``vel`` is the forward
            finite difference ``pos[k + 1] - pos[k]`` and therefore has one
            fewer frame than ``pos``. Atoms with ``z <= 1`` are removed from
            all three arrays.

        Raises:
            KeyError: If ``self.partition`` is not train/valid/test.
        """
        # NpzFile keeps the archive open until closed; the arrays are fully
        # materialised on access, so they stay valid after the ``with``.
        with np.load(f'{self.data_dir}/{self.mol_type}.npz') as data:
            pos = data['R']
            atom = data['z']
        # NOTE: unpickling executes arbitrary code -- only load split files
        # that come from a trusted source.
        with open(f'{self.data_dir}/{self.mol_type}_split.pkl', 'rb') as f:
            split = pkl.load(f)

        vel = pos[1:] - pos[:-1]

        partition_slot = {
            'train': 0,
            'valid': 1,
            'test' : 2,
        }[self.partition]
        # asarray makes the "+ delta_frame" below an element-wise offset even
        # when the pickle stores plain Python lists instead of ndarrays.
        self.frame_0 = np.asarray(split[partition_slot], dtype=np.int64)[:self.max_samples]
        self.frame_t = self.frame_0 + self.delta_frame

        # Drop hydrogens (z == 1); axis 1 of pos/vel is the atom axis.
        heavy = atom > 1
        pos = pos[:, heavy, ...]
        vel = vel[:, heavy, ...]
        atom = atom[heavy]

        return pos, vel, atom
    
    
    def process(self, pos, vel, atom):
        """
        Build one graph per start frame.

        Args:
            pos: Heavy-atom positions for every frame.
            vel: Finite-difference velocities (one frame shorter than pos).
            atom: Heavy-atom atomic numbers.

        Returns:
            List of ``Data`` objects living on the CPU, one per entry of
            ``self.frame_0``.
        """
        atom, pos, vel = Tensor(atom).unsqueeze(-1), Tensor(pos), Tensor(vel)  

        pos_0, vel_0 = pos[self.frame_0], vel[self.frame_0]
        # Only target positions are needed. Target velocities are never used,
        # and indexing ``vel`` at frame_t could run past its end because
        # ``vel`` has one fewer frame than ``pos``.
        pos_t = pos[self.frame_t]

        num_systems = len(self.frame_0)
        pos_0, pos_t, vel_0, atom = pos_0.to(self.device), pos_t.to(self.device), vel_0.to(self.device), atom.to(self.device)

        data = []
        for i in tqdm(range(num_systems)):
            data.append(self.get_graph_step(pos_0[i], vel_0[i], atom, pos_t[i]).to('cpu'))

        return data
    

    def get_graph_step(self, pos_0, vel_0, atom, pos_t):        
        """
        Assemble the graph for a single (start frame, target frame) pair.

        Args:
            pos_0: Start-frame positions, one row per atom.
            vel_0: Start-frame velocities, one row per atom.
            atom: Atomic numbers as a column vector (one row per atom).
            pos_t: Target-frame positions, one row per atom.

        Returns:
            ``Data`` with
              * ``node_feat``: ``[|vel_0|, z / max(z)]`` per atom,
              * ``node_pos`` / ``node_vel``: start-frame state,
              * ``edge_index``: the shared ``[row, col]`` index pair,
              * ``edge_attr``: ``[z_row, z_col, hop, 0, |pos_row - pos_col|]``,
              * ``label``: target-frame positions.
        """
        if self.partition == 'test':
            # Test samples get a random rigid rotation applied consistently
            # to positions, velocities and label. The angles are drawn from
            # the global torch RNG, so results vary unless it is seeded.
            angle = 2 * torch.pi * torch.rand([3], device=self.device)
            rotate_matrix = matrix_x(angle[0]) @ matrix_y(angle[1]) @ matrix_z(angle[2])
            pos_0, pos_t, vel_0 = pos_0 @ rotate_matrix, pos_t @ rotate_matrix, vel_0 @ rotate_matrix

        # Edge
        row, col = self.edge_index
        edge_attr = torch.cat([
            atom[row], 
            atom[col], 
            self.edge_hop, 
            # Constant zero column kept as-is so edge_attr keeps its width.
            torch.zeros_like(self.edge_hop), 
            torch.norm(pos_0[row]-pos_0[col], dim=1, keepdim=True),
        ], dim=1)

        # Node Feat
        node_feat = torch.cat([
            torch.norm(vel_0, dim=1, keepdim=True), 
            atom / atom.max(),
        ], dim=1)

        return Data(
            node_feat=node_feat,
            node_pos=pos_0, 
            node_vel=vel_0,
            edge_index=self.edge_index,
            edge_attr=edge_attr, 
            label=pos_t, 
        )

    def get_edge(self, pos, _lambda = 1.6):
        """
        Derive 1-hop and 2-hop edges from a distance cutoff.

        Two atoms are 1-hop neighbours when their distance is strictly
        below ``_lambda``; they are 2-hop neighbours when the squared
        adjacency matrix has a non-zero entry for the pair.

        Args:
            pos: Positions of one frame, one row per atom.
            _lambda: Distance cutoff (same unit as ``pos``).

        Returns:
            ``(edge_index, edge_hop)`` where ``edge_index`` is the list
            ``[row, col]`` of CPU ``LongTensor`` indices and ``edge_hop`` is
            a column vector on ``self.device`` holding 1.0 or 2.0 per edge.
            Edges are ordered row-major over ``(i, j)``; for a given pair the
            1-hop entry precedes the 2-hop entry.
        """
        num_node = pos.size(0)
        dist = torch.cdist(pos, pos)
        # Threshold with a boolean mask. Writing the marker value 1 into the
        # distance matrix and thresholding it again would wipe every edge
        # whenever _lambda <= 1.
        adj = (dist < _lambda).to(dist.dtype)
        # NOTE(review): the diagonal of ``adj`` is 1 (self-distance is 0), so
        # ``adj @ adj`` is non-zero for every 1-hop pair as well and each
        # 1-hop edge is therefore emitted a second time with hop == 2.
        # Behaviour kept as-is; confirm whether the duplication is intended.
        two_hop = ((adj @ adj) > 0).tolist()
        one_hop = adj.bool().tolist()

        row, col, edge_hop = [], [], []
        for i in range(num_node):
            for j in range(num_node):
                if i == j:
                    continue
                if one_hop[i][j]:
                    row.append(i)
                    col.append(j)
                    edge_hop.append(1)
                if two_hop[i][j]:
                    row.append(i)
                    col.append(j)
                    edge_hop.append(2)
        row, col ,edge_hop = LongTensor(row), LongTensor(col), Tensor(edge_hop).unsqueeze(-1)
        edge_index = [row, col]
        return edge_index, edge_hop.to(self.device)

    
if __name__ == '__main__':
    # Smoke run: build the dataset with its default arguments (aspirin,
    # train partition) from a fixed MD17 directory. Construction itself
    # prints the dataset size and the first sample.
    _md17_root = '/mnt/ai4sci_develop_fast/cenjc/Dataset/MD17'
    MyData = MD17Dataset(data_dir=_md17_root)
