import numpy as np
import torch
import random

class NBodyDataset():
    """N-body trajectory dataset backed by pre-generated ``.npy`` files.

    For one split it loads ``loc``, ``vel``, ``edges`` and ``charges`` arrays,
    converts them to tensors and serves per-sample tuples via ``__getitem__``.
    """

    # File-name suffix appended to the split name for each supported dataset.
    _NAME_SUFFIXES = {
        "nbody": "_charged5_initvel1",
        "nbody_small": "_charged5_initvel1",
        "nbody_small_out_dist": "_charged5_initvel1",
        "nbody_large": "_charged10_initvel1large",
    }

    def __init__(self, partition='train', max_samples=1e8, dataset_name="nbody", T = 8,
                 data_dir='dataset'):
        """Load and preprocess one split of the dataset.

        Args:
            partition: split name; ``'val'`` is mapped to the ``'valid'`` file prefix.
            max_samples: keep at most this many samples (converted with ``int``).
            dataset_name: one of the keys of ``_NAME_SUFFIXES``.
            T: accepted for backward compatibility and stored as ``self.T``; it
                does NOT select the target frame (``frame_T`` is fixed at 10).
            data_dir: directory containing the ``.npy`` files.

        Raises:
            ValueError: if ``dataset_name`` is not supported.
        """
        self.partition = partition
        # Validation files use the "valid" prefix rather than "val".
        self.sufix = 'valid' if self.partition == 'val' else self.partition
        self.dataset_name = dataset_name
        if dataset_name not in self._NAME_SUFFIXES:
            raise ValueError("Wrong dataset name %s" % self.dataset_name)
        self.sufix += self._NAME_SUFFIXES[dataset_name]

        self.max_samples = int(max_samples)
        self.data_dir = data_dir
        # NOTE(review): T is unused by the loader; frame_T below stays 10 to keep
        # existing behavior. Confirm which target frame is intended.
        self.T = T
        self.data, self.edges = self.load()
        self.frame_0 = 0
        self.frame_T = 10

    def load(self):
        """Read the split's arrays from ``data_dir`` and preprocess them.

        Returns:
            ``((loc, vel, edge_attr, charges), edges)`` where ``edges`` is the
            ``[rows, cols]`` index-list pair shared by every sample.
        """
        import os

        def _path(kind):
            # e.g. <data_dir>/loc_train_charged5_initvel1.npy
            return os.path.join(self.data_dir, kind + '_' + self.sufix + '.npy')

        loc = np.load(_path('loc'))
        vel = np.load(_path('vel'))
        edges = np.load(_path('edges'))
        charges = np.load(_path('charges'))

        loc, vel, edge_attr, edges, charges = self.preprocess(loc, vel, edges, charges)
        return (loc, vel, edge_attr, charges), edges

    def preprocess(self, loc, vel, edges, charges):
        """Convert raw arrays to tensors, truncate them and build the edge list.

        ``loc``/``vel`` have their last two axes swapped so that, per the original
        layout comment, they become (samples, time, n_nodes, dim). ``n_nodes`` is read
        from axis 3 of the raw ``loc``. Every returned tensor is truncated to
        ``max_samples`` samples.

        Returns:
            ``(loc, vel, edge_attr, [rows, cols], charges)``. The graph is fully
            connected without self-loops, with pairs enumerated row-major.
            ``edge_attr[s, k, 0] == edges[s, rows[k], cols[k]]``.
        """
        # cast to torch and swap n_nodes <--> n_features dimensions
        loc = torch.Tensor(loc).transpose(2, 3)[0:self.max_samples]
        vel = torch.Tensor(vel).transpose(2, 3)[0:self.max_samples]
        n_nodes = loc.size(2)
        charges = torch.Tensor(charges[0:self.max_samples])

        pairs = [(i, j) for i in range(n_nodes) for j in range(n_nodes) if i != j]
        rows = [i for i, _ in pairs]
        cols = [j for _, j in pairs]

        # Truncate like loc/vel/charges so all tensors cover the same samples, then
        # gather every off-diagonal entry at once: (samples, n_edges) -> add nf dim.
        edge_values = np.asarray(edges)[0:self.max_samples][:, rows, cols]
        edge_attr = torch.Tensor(edge_values).unsqueeze(2)

        return loc, vel, edge_attr, [rows, cols], charges

    def __getitem__(self, i):
        """Return ``(loc_0, vel_0, edge_attr, charges, loc_T)`` for sample ``i``.

        ``loc_0``/``vel_0`` are taken at frame ``frame_0``; ``loc_T`` is the
        position at frame ``frame_T``.
        """
        loc, vel, edge_attr, charges = self.data
        loc, vel, edge_attr, charges = loc[i], vel[i], edge_attr[i], charges[i]
        return loc[self.frame_0], vel[self.frame_0], edge_attr, charges, loc[self.frame_T]

    def __len__(self):
        """Number of (truncated) samples."""
        return len(self.data[0])

    def get_edges(self, batch_size, n_nodes):
        """Return ``[rows, cols]`` LongTensors for ``batch_size`` disjoint graph copies.

        Copy ``b`` has its node indices shifted by ``b * n_nodes``. Copies are
        concatenated in batch order. For ``batch_size <= 1`` the unshifted
        single-graph edges are returned.
        """
        rows = torch.LongTensor(self.edges[0])
        cols = torch.LongTensor(self.edges[1])
        if batch_size <= 1:
            return [rows, cols]
        offsets = (torch.arange(batch_size, dtype=torch.long) * n_nodes).unsqueeze(1)
        return [(rows.unsqueeze(0) + offsets).reshape(-1),
                (cols.unsqueeze(0) + offsets).reshape(-1)]

if __name__ == "__main__":
    # Smoke check: constructing the default training split loads and preprocesses
    # the .npy files, so any missing file or shape mismatch surfaces here.
    NBodyDataset(partition='train')