import numpy as np
import torch
import random

class ConForDataset():
    """Consensus-formation trajectory dataset loaded from ``.npy`` files.

    Reads ``loc_<sufix>.npy``, ``vel_<sufix>.npy`` and ``edges_<sufix>.npy``
    from ``<root>/<level>/``. Each item is an (initial state, edge attributes,
    target state) tuple on a fully connected graph without self-loops.

    Args:
        partition: Split name; ``'val'`` maps to the ``'valid'`` file suffix,
            any other value is used verbatim.
        level: Sub-directory of ``root`` to read from.
        max_samples: Keep at most this many leading samples.
        num_agents: Number of agents; only 5 and 10 are supported.
        frame_0: Time index returned as the input state.
        frame_T: Time index returned as the target state. It must be a valid
            index into the stored time axis, otherwise ``__getitem__`` raises
            ``IndexError``.
        root: Directory containing the ``level`` sub-directories.

    Raises:
        ValueError: If ``num_agents`` is not 5 or 10 (raised before any file
            is read).
    """

    # Maps the supported agent counts to their file-suffix component.
    _AGENT_SUFFIX = {5: "5", 10: "10"}

    def __init__(self, partition='train', level='difficult', max_samples=7000, num_agents=5,
                 frame_0=0, frame_T=600, root='con_formation'):
        self.partition = partition
        self.level = level
        self.root = root
        # File suffix such as 'trainconsensus5' or 'validconsensus10'.
        # NOTE: attribute name 'sufix' is kept for backward compatibility.
        self.sufix = 'valid' if self.partition == 'val' else self.partition
        self.sufix += "consensus"
        self.num_agents = num_agents

        if self.num_agents not in self._AGENT_SUFFIX:
            raise ValueError("Wrong agents number %s" % self.num_agents)
        self.sufix += self._AGENT_SUFFIX[self.num_agents]

        self.max_samples = int(max_samples)
        self.frame_0 = frame_0
        self.frame_T = frame_T
        self.data, self.edges = self.load()

    def load(self):
        """Read the raw arrays from disk and preprocess them.

        Returns:
            ``((loc, vel, edge_attr), edges)``. ``loc``/``vel`` are float32
            tensors laid out as (samples, time_intervals, num_agents,
            dim_coord). ``edge_attr`` is (samples, n_edges, 1). ``edges`` is
            the ``[rows, cols]`` pair of Python int lists for a single graph.
        """
        prefix = '%s/%s/' % (self.root, self.level)
        loc = np.load(prefix + 'loc_' + self.sufix + '.npy')  # (train_nums, time_intervals, num_agents, dim_coord)
        vel = np.load(prefix + 'vel_' + self.sufix + '.npy')
        edges = np.load(prefix + 'edges_' + self.sufix + '.npy')

        loc, vel, edge_attr, edges = self.preprocess(loc, vel, edges)
        print("loc shape:", loc.shape)
        print("vel shape:", vel.shape)
        print("edge_attr shape:", edge_attr.shape)
        return (loc, vel, edge_attr), edges

    def preprocess(self, loc, vel, edges):
        """Convert raw arrays to tensors and build fully connected edge data.

        Args:
            loc: Array indexed as (samples, time, nodes, coord).
            vel: Array with the same layout as ``loc``.
            edges: Per-sample adjacency values indexed as (samples, nodes, nodes).

        Returns:
            ``(loc, vel, edge_attr, [rows, cols])``. All three tensors are
            truncated to ``self.max_samples`` samples. ``edge_attr`` has shape
            (samples, n_edges, 1) and lists the off-diagonal entries in
            row-major (i, j) order, matching ``rows``/``cols``.
        """
        loc = torch.Tensor(loc)[:self.max_samples]
        vel = torch.Tensor(vel)[:self.max_samples]
        n_nodes = loc.size(2)

        # Every ordered pair (i, j) with i != j, in row-major order.
        rows, cols = np.nonzero(~np.eye(n_nodes, dtype=bool))
        # Truncate edges like loc/vel so all three stay aligned and memory is bounded.
        edge_attr = np.asarray(edges)[:self.max_samples, rows, cols]  # (samples, n_edges)
        edge_attr = torch.as_tensor(edge_attr, dtype=torch.float32).unsqueeze(2)  # add nf dimension

        return loc, vel, edge_attr, [rows.tolist(), cols.tolist()]

    def __getitem__(self, i):
        """Return ``(loc_0, vel_0, edge_attr, loc_T, vel_T)`` for sample ``i``."""
        loc, vel, edge_attr = self.data
        loc, vel = loc[i], vel[i]
        return loc[self.frame_0], vel[self.frame_0], edge_attr[i], loc[self.frame_T], vel[self.frame_T]

    def __len__(self):
        """Number of (truncated) samples."""
        return len(self.data[0])

    def get_edges(self, batch_size, n_nodes):
        """Return edge index tensors for ``batch_size`` disjoint graph copies.

        Copy ``b`` has its node indices shifted by ``b * n_nodes``, and the
        copies are concatenated in order.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        rows = torch.LongTensor(self.edges[0])
        cols = torch.LongTensor(self.edges[1])
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1, got %s" % batch_size)
        if batch_size == 1:
            return [rows, cols]
        # Column of per-copy offsets broadcast against the single-graph edges.
        offsets = torch.arange(batch_size, dtype=torch.long).unsqueeze(1) * n_nodes
        return [(rows.unsqueeze(0) + offsets).reshape(-1),
                (cols.unsqueeze(0) + offsets).reshape(-1)]

if __name__ == "__main__":
    # Smoke check: load the default split; load() prints the resulting tensor shapes.
    dataset = ConForDataset()