import numpy as np
import torch
import pickle as pkl
import os


class MotionDataset():
    """
    Motion dataset of paired (start, future) frames.

    ``<data_dir>/motion.pkl`` must unpickle to ``(edges, X)``:
    ``edges`` is an iterable of node-index pairs (only ``edge[0]``/``edge[1]``
    are read). ``X`` is an indexable collection of per-case position arrays
    indexed as (frame, node, coord); the node count is taken from ``X[0].shape[1]``.

    Velocities are forward frame differences. Each item is
    ``(x_0, v_0, edge_attr, x_t, v_t)``: positions/velocities at a sampled start
    frame, the shared edge-attribute tensor, and positions/velocities
    ``delta_frame`` frames later.
    """

    # Case (sequence) indices assigned to each partition.
    _TRAIN_CASE_IDS = [20, 1, 17, 13, 14, 9, 4, 2, 7, 5, 16]
    _VAL_CASE_IDS = [3, 8, 11, 12, 15, 18]
    _TEST_CASE_IDS = [6, 19, 21, 0, 22, 10]
    # Start frames are drawn without replacement from range(_SPLIT_WINDOW),
    # _SPLIT_SAMPLES per case, with a fixed seed so the split is reproducible.
    _SPLIT_WINDOW = 300
    _SPLIT_SAMPLES = 200
    _SPLIT_SEED = 100
    # Position of each partition's mapping inside the saved split tuple.
    _PARTITION_INDEX = {'train': 0, 'val': 1, 'test': 2}

    def __init__(self, partition, max_samples=1e8, delta_frame=30, data_dir='motion/raw_data/'):
        """
        Load the raw motion data and build the sample tensors for ``partition``.

        Args:
            partition: one of 'train', 'val', 'test'.
            max_samples: upper bound on the total number of samples, divided
                evenly (floor) across the partition's cases. Float values such
                as the default ``1e8`` are accepted.
            delta_frame: frame offset between the input and target states.
                Every sampled ``start + delta_frame`` must be a valid frame index
                of the (trimmed) sequence, otherwise NumPy raises IndexError.
            data_dir: directory containing ``motion.pkl``; ``split.pkl`` is read
                from, or created in, the same directory.

        Raises:
            NotImplementedError: if ``partition`` is not recognised.
        """
        if partition not in self._PARTITION_INDEX:
            raise NotImplementedError()
        self.partition = partition

        # NOTE: pickle can execute arbitrary code on load -- only use trusted data.
        with open(os.path.join(data_dir, 'motion.pkl'), 'rb') as f:
            edges, X = pkl.load(f)

        # Velocity = forward difference; drop each sequence's last position frame
        # so X[i] and V[i] stay frame-aligned.
        V = [x[1:] - x[:-1] for x in X]
        X = [x[:-1] for x in X]

        N = X[0].shape[1]  # number of nodes

        split = self._load_or_create_split(os.path.join(data_dir, 'split.pkl'))
        mapping = split[self._PARTITION_INDEX[partition]]

        # int(): `//` on a float (e.g. the default 1e8) returns a float, which
        # NumPy rejects as a slice bound.
        each_len = int(max_samples // len(mapping))
        print("each len:", each_len)

        x_0, v_0, x_t, v_t = [], [], [], []
        for case_id, starts in mapping.items():
            st = starts[:each_len]
            x_0.append(X[case_id][st])
            v_0.append(V[case_id][st])
            x_t.append(X[case_id][st + delta_frame])
            v_t.append(V[case_id][st + delta_frame])
        # Stack all cases along the sample axis.
        x_0 = np.concatenate(x_0, axis=0)
        v_0 = np.concatenate(v_0, axis=0)
        x_t = np.concatenate(x_t, axis=0)
        v_t = np.concatenate(v_t, axis=0)

        print('Got {:d} samples!'.format(x_0.shape[0]))

        self.n_node = N

        # Symmetric 0/1 adjacency of the given edges.
        atom_edges = torch.zeros(N, N).int()
        for edge in edges:
            atom_edges[edge[0], edge[1]] = 1
            atom_edges[edge[1], edge[0]] = 1
        self.atom_edge = atom_edges

        # Fully connected message-passing graph without self loops, in row-major
        # order; edge_attr[k] is 1 when (rows[k], cols[k]) is one of `edges`.
        rows = [i for i in range(N) for j in range(N) if i != j]
        cols = [j for i in range(N) for j in range(N) if i != j]
        off_diagonal = ~torch.eye(N, dtype=torch.bool)
        self.edge_attr = atom_edges[off_diagonal].float().unsqueeze(-1)  # [N*(N-1), 1]
        self.edges = [rows, cols]  # [2, N*(N-1)]

        self.x_0, self.v_0, self.x_t, self.v_t = (
            torch.Tensor(x_0), torch.Tensor(v_0), torch.Tensor(x_t), torch.Tensor(v_t))
        self.mole_idx = torch.Tensor(np.ones(N))

    @classmethod
    def _generate_split(cls):
        """Draw start frames for every case of every partition.

        Uses a private ``RandomState`` with the same seed the legacy
        ``np.random.seed`` call used, so the split is identical but the global
        NumPy RNG is left untouched. Draw order is train, val, then test.
        """
        rng = np.random.RandomState(cls._SPLIT_SEED)
        candidates = np.arange(cls._SPLIT_WINDOW)

        def sample(case_ids):
            return {i: rng.choice(candidates, size=cls._SPLIT_SAMPLES, replace=False)
                    for i in case_ids}

        return (sample(cls._TRAIN_CASE_IDS),
                sample(cls._VAL_CASE_IDS),
                sample(cls._TEST_CASE_IDS))

    @classmethod
    def _load_or_create_split(cls, split_path):
        """Return the (train, val, test) mappings of case id -> start-frame indices.

        Loads ``split_path`` if it can be read; otherwise generates a fresh split
        and caches it at ``split_path``.
        """
        try:
            with open(split_path, 'rb') as f:
                split = pkl.load(f)
        except (OSError, EOFError, pkl.UnpicklingError):
            split = cls._generate_split()
            with open(split_path, 'wb') as f:
                pkl.dump(split, f)
            print('Generate and save split!')
        else:
            print('Got Split!')
        return split

    def __getitem__(self, i):
        return self.x_0[i], self.v_0[i], self.edge_attr, self.x_t[i], self.v_t[i]

    def __len__(self):
        return len(self.x_0)

    def get_edges(self, batch_size, n_nodes):
        """Return ``[rows, cols]`` LongTensors for ``batch_size`` disjoint graph copies.

        Copy ``b`` has its node indices shifted by ``b * n_nodes``, and the copies
        are laid out one after another. ``batch_size <= 1`` returns the
        single-graph edges unchanged.
        """
        rows = torch.LongTensor(self.edges[0])
        cols = torch.LongTensor(self.edges[1])
        if batch_size <= 1:
            return [rows, cols]
        offsets = torch.arange(batch_size).unsqueeze(1) * n_nodes  # [B, 1]
        return [(rows.unsqueeze(0) + offsets).reshape(-1),
                (cols.unsqueeze(0) + offsets).reshape(-1)]



if __name__ == '__main__':
    # Smoke test: build the training partition from a local raw_data/ directory.
    train_set = MotionDataset(
        partition='train',
        max_samples=1100,
        delta_frame=30,
        data_dir='raw_data',
    )