import pickle as pkl
from tqdm import tqdm
from joblib import Parallel, delayed

import numpy as np
import torch
from torch import Tensor
from torch.utils.data import Dataset
from torch_geometric.data import Data
from e3nn.o3 import matrix_x, matrix_y, matrix_z

class TetrahedronDataset(Dataset):
    """In-memory dataset of graphs built from ``<data_dir>/<partition>.npy``.

    Every record of the file is converted once, at construction time, into a
    ``torch_geometric.data.Data`` object with the fields ``node_feat``,
    ``node_pos``, ``edge_index``, ``edge_attr`` and ``label``. The finished
    graphs are kept on the CPU in ``self.data``.

    For the ``'test'`` partition each sample (positions and label) is
    multiplied by a freshly drawn random rotation matrix, so the test graphs
    differ between constructions unless the torch RNG is seeded.

    Args:
        data_dir: Directory that contains ``<partition>.npy``.
        partition: File stem to load; ``'test'`` additionally enables the
            random rotation.
        label_type: Key of the record entry used as the label.
        max_samples: Upper bound on the number of records kept (converted
            with ``int``).
        cutoff_rate: Fraction of the directed node pairs that is dropped,
            longest distances first. Must lie in ``[0, 1]``.
        device: Device on which the per-sample tensors are computed before
            the graph is moved back to the CPU.

    Raises:
        ValueError: If ``cutoff_rate`` is outside ``[0, 1]``.
    """

    def __init__(
        self, 
        data_dir: str = './', 
        partition: str = 'train', 
        label_type: str = 'G',
        max_samples: float = 1e8, 
        cutoff_rate: float = 0.0, 
        device: str = 'cpu'
    ):
        super().__init__()

        # Fail early and clearly: a rate above 1 would ask topk for a negative
        # k, a rate below 0 would select the masked self-loops as edges.
        if not 0.0 <= cutoff_rate <= 1.0:
            raise ValueError(f'cutoff_rate must be in [0, 1], got {cutoff_rate}')

        self.device = device
        
        self.data_dir, self.partition = data_dir, partition
        self.label_type = label_type
        
        self.max_samples = int(max_samples)
        self.cutoff_rate = cutoff_rate

        tet_list = self.load_data()

        print('Processing data ...')
        self.data = self.process(tet_list)  # computed on self.device, stored on CPU
        print(f'{partition} dataset total len: {len(self.data)}')
        # Only show an example when there is one; an empty dataset must not
        # crash the constructor with an IndexError.
        if self.data:
            print(self.data[0])


    def __len__(self):
        """Return the number of stored graphs."""
        return len(self.data)

    
    def __getitem__(self, i):
        """Return the ``i``-th pre-built graph."""
        return self.data[i]
    

    def load_data(self):
        """Load the partition file and return at most ``max_samples`` records.

        SECURITY: ``allow_pickle=True`` unpickles arbitrary Python objects and
        can execute code embedded in the file. Only load files from a trusted
        source.
        """
        tet_list = np.load(f'{self.data_dir}/{self.partition}.npy', allow_pickle=True).tolist()
        return tet_list[0:self.max_samples]
    
    
    def process(self, tet_list):
        """Convert every record into a graph and return them as a list.

        Each record is indexed with ``'pos'`` and ``self.label_type``; the
        label gets a leading dimension of size 1. Graphs are moved to the CPU
        before being stored so the whole dataset does not occupy device
        memory.
        """
        data = []
        for tet in tqdm(tet_list):
            node_pos = Tensor(tet['pos']).to(self.device)
            label = Tensor(tet[self.label_type]).unsqueeze(0).to(self.device)
            data.append(self.get_graph_step(node_pos, label).to('cpu'))
        return data
    

    def get_graph_step(self, node_pos, label):
        """Build one ``Data`` graph from node positions and a label.

        Args:
            node_pos: Node coordinates, one row per node.
            label: Label tensor; for the ``'test'`` partition it is rotated
                together with the positions, so its last dimension must be
                compatible with the 3x3 rotation matrix.

        Returns:
            ``Data`` with all-zero ``node_feat`` of shape ``[num_nodes, 1]``,
            ``node_pos``, ``edge_index``, ``edge_attr`` (Euclidean length of
            every edge, shape ``[num_edges, 1]``) and ``label``.
        """
        if self.partition == 'test':
            # Three random angles in [0, 2*pi), composed as X @ Y @ Z.
            angle = 2 * torch.pi * torch.rand([3], device=self.device)
            rotate_matrix = matrix_x(angle[0]) @ matrix_y(angle[1]) @ matrix_z(angle[2])
            node_pos, label = node_pos @ rotate_matrix, label @ rotate_matrix

        # Edge
        edge_index = self.cutoff_edge(node_pos)
        row, col = edge_index
        edge_attr = torch.norm(node_pos[row] - node_pos[col], dim=1, keepdim=True)

        # Node Feat
        node_feat = torch.zeros([node_pos.size(0), 1], device=self.device)

        return Data(
            node_feat=node_feat,
            node_pos=node_pos, 
            edge_index=edge_index,
            edge_attr=edge_attr, 
            label=label, 
        )


    def cutoff_edge(self, node_pos):
        """Return the directed edges of the distance-pruned complete graph.

        All ordered pairs of distinct nodes are candidates. The
        ``int(N * (N - 1) * (1 - cutoff_rate))`` pairs with the smallest
        Euclidean distance are kept.

        Args:
            node_pos: Node coordinates, one row per node.

        Returns:
            ``long`` tensor of shape ``[2, num_edges]`` holding source indices
            in row 0 and target indices in row 1.
        """
        # Complete Graph and Cutoff    
        num_node_r = node_pos.size(0)
        dist = torch.cdist(node_pos, node_pos, p=2)  # [num_node_r, num_node_r]
        # Push the diagonal far away so self-loops are never among the
        # smallest distances. The mask is created on the target device
        # directly instead of being built on the CPU and transferred.
        dist += torch.eye(num_node_r, device=node_pos.device) * 1e18  # [num_node_r, num_node_r]
        # int() truncates, so a fractional number of kept edges is rounded down.
        num_edge_rr_chosen = int(num_node_r * (num_node_r - 1) * (1 - self.cutoff_rate))
        _, id_chosen = torch.topk(dist.view(num_node_r * num_node_r), num_edge_rr_chosen, dim=0, largest=False)
        # Flat index -> (row, col) of the distance matrix.
        edge_rr = torch.cat([
            id_chosen.div(num_node_r, rounding_mode='trunc').unsqueeze(0), 
            id_chosen.remainder(num_node_r ).unsqueeze(0)
        ], dim=0).long()  # [2, num_edge_rr_chosen]
        return edge_rr
    
