import pickle as pkl
from tqdm import tqdm
from joblib import Parallel, delayed

import numpy as np
import torch
from torch import Tensor
from torch.utils.data import Dataset
from torch_geometric.data import Data
from e3nn.o3 import matrix_x, matrix_y, matrix_z

class NBodySystemDataset(Dataset):
    """Charged N-body systems converted to one graph per system.

    Loads ``loc_{suffix}.npy``, ``vel_{suffix}.npy`` and ``charges_{suffix}.npy``
    from ``data_dir`` (``suffix = f'{partition}_charged{dataset_name}'``).
    For every system it uses positions/velocities at frame ``frame_0`` as input
    and positions at frame ``frame_T`` as the target (``label``). Each system
    becomes a ``torch_geometric.data.Data`` graph, which is built on ``device``
    and then stored on CPU.

    Args:
        dataset_name: Dataset identifier appended to the file suffix.
        data_dir: Directory holding the ``.npy`` files.
        partition: Split name used in the file suffix. For ``'test'``, every
            system is randomly rotated.
        max_samples: Maximum number of systems to keep (truncated to ``int``).
        frame_0: Frame index used as the input state.
        frame_T: Frame index whose positions are the regression target.
        cutoff_rate: Fraction in ``[0, 1]`` of the ``n * (n - 1)`` directed
            node pairs to drop. The shortest pairs are kept.
        device: Device on which graphs are built before being moved to CPU.

    Raises:
        ValueError: If ``cutoff_rate`` is outside ``[0, 1]``.
    """

    def __init__(
        self,
        dataset_name: str = '5_0_0',
        data_dir: str = './',
        partition: str = 'train',
        max_samples: float = 1e8,
        frame_0: int = 30,
        frame_T: int = 40,
        cutoff_rate: float = 0.0,
        device: str = 'cpu'
    ):
        super().__init__()
        # Validate before any file I/O. Out-of-range values would otherwise
        # surface later as an obscure topk error (or select self-loops).
        if not 0.0 <= cutoff_rate <= 1.0:
            raise ValueError(f'cutoff_rate must be in [0, 1], got {cutoff_rate}')

        self.device = device

        self.data_dir, self.partition = data_dir, partition
        self.suffix = f'{self.partition}_charged{dataset_name}'

        self.max_samples = int(max_samples)
        self.frame_0, self.frame_T = frame_0, frame_T
        self.cutoff_rate = cutoff_rate

        loc, vel, charges = self.load_data()

        print('Processing data ...')
        self.data = self.process(loc, vel, charges)  # graphs built on self.device
        print(f'{partition} dataset total len: {len(self.data)}')
        # An empty dataset (e.g. max_samples=0) has no first element to show.
        if self.data:
            print(self.data[0])

    def __len__(self) -> int:
        """Return the number of systems (graphs)."""
        return len(self.data)

    def __getitem__(self, i):
        """Return the pre-built graph of system ``i``."""
        return self.data[i]

    def load_data(self):
        """Load the raw ``(loc, vel, charges)`` numpy arrays for this split."""
        loc = np.load(f'{self.data_dir}/loc_{self.suffix}.npy')
        vel = np.load(f'{self.data_dir}/vel_{self.suffix}.npy')
        charges = np.load(f'{self.data_dir}/charges_{self.suffix}.npy')

        return loc, vel, charges

    def process(self, loc, vel, charges) -> list:
        """Convert raw arrays into a list of CPU ``Data`` graphs, one per system.

        Only the first ``self.max_samples`` systems are used. ``loc``/``vel`` are
        indexed as ``[system, frame, ...]``. ``charges`` must unpack into three
        dimensions (expected ``[num_systems, num_node_r, 1]`` — TODO confirm
        against the data generator).
        """
        n_keep = self.max_samples
        # Truncate before conversion so only the kept systems are copied.
        charges = Tensor(charges[:n_keep])
        loc = Tensor(loc[:n_keep])
        vel = Tensor(vel[:n_keep])

        num_systems, _, _ = charges.size()

        # Per-system tensors, expected [num_systems, num_node_r, 3].
        loc_0 = loc[:, self.frame_0, ...].to(self.device)
        loc_t = loc[:, self.frame_T, ...].to(self.device)
        vel_0 = vel[:, self.frame_0, ...].to(self.device)
        charges = charges.to(self.device)

        return [
            self.get_graph_step(loc_0[i], vel_0[i], charges[i], loc_t[i]).to('cpu')
            for i in tqdm(range(num_systems))
        ]

    @staticmethod
    def _normalize_charges(charges: Tensor) -> Tensor:
        """Scale charges by their largest absolute value, preserving sign.

        Unlike ``charges / charges.max()``, this keeps the sign when all charges
        are negative and never divides by zero. All-zero (or empty) charges are
        returned unchanged.
        """
        if charges.numel() == 0:
            return charges
        scale = charges.abs().max()
        return charges / scale if scale > 0 else charges

    def get_graph_step(self, loc_0, vel_0, charges, loc_t):
        """Build the ``Data`` graph for a single system.

        Node features: ``[|vel_0|, normalized charge]``. Edge features:
        ``[|loc_0[row] - loc_0[col]|, charge[row] * charge[col]]``. Edges come
        from ``cutoff_edge``; the target ``label`` is ``loc_t``.
        """
        if self.partition == 'test':
            # Independent random rotation per test system. This uses the global
            # torch RNG, so the test graphs differ between runs unless seeded.
            angle = 2 * torch.pi * torch.rand([3], device=self.device)
            rotate_matrix = matrix_x(angle[0]) @ matrix_y(angle[1]) @ matrix_z(angle[2])
            loc_0, loc_t, vel_0 = loc_0 @ rotate_matrix, loc_t @ rotate_matrix, vel_0 @ rotate_matrix

        # Edges
        edge_index = self.cutoff_edge(loc_0)
        row, col = edge_index
        edge_attr = torch.cat([
            torch.norm(loc_0[row] - loc_0[col], dim=1, keepdim=True),
            charges[row] * charges[col],
        ], dim=1)

        # Node features
        node_feat = torch.cat([
            torch.norm(vel_0, dim=1, keepdim=True),
            self._normalize_charges(charges),
        ], dim=1)

        return Data(
            node_feat=node_feat,
            node_pos=loc_0,
            node_vel=vel_0,
            edge_index=edge_index,
            edge_attr=edge_attr,
            label=loc_t,
        )

    def cutoff_edge(self, loc_0):
        """Return the shortest directed node pairs as a ``[2, E]`` long tensor.

        Starts from all ``n * (n - 1)`` ordered pairs of distinct nodes. It keeps
        the ``int(n * (n - 1) * (1 - cutoff_rate))`` pairs with the smallest
        Euclidean distance, in ascending distance order. Row 0 holds the first
        index ``i`` and row 1 the second index ``j`` of each pair ``(i, j)``.
        """
        num_node_r = loc_0.size(0)
        dist = torch.cdist(loc_0, loc_0, p=2)  # [num_node_r, num_node_r]
        # Exclude self-pairs: +inf sorts after every finite distance.
        dist.fill_diagonal_(float('inf'))
        num_edge_rr_chosen = int(num_node_r * (num_node_r - 1) * (1 - self.cutoff_rate))
        _, id_chosen = torch.topk(dist.flatten(), num_edge_rr_chosen, largest=False)
        # Decode flat indices i * num_node_r + j back into (i, j).
        return torch.stack([
            id_chosen.div(num_node_r, rounding_mode='trunc'),
            id_chosen.remainder(num_node_r),
        ])  # [2, num_edge_rr_chosen]
    
