import numpy as np
import os
import pickle
import torch
from torch_geometric.utils import remove_self_loops
from .trajdata import TrajData
from .trajdataset import TrajDataset
from .motion_utils import load_data_cmu_3d_all, load_data_cmu_3d_n, load_data_cmu_3d_all_new
from utils.misc import CMU_Transform
from torch.utils.data import DataLoader

class DummyOpt:
    """Lightweight stand-in for a command-line options object.

    Only carries the fields set below; an instance is handed to the CMU
    loading helper in ``CMU.preprocess_raw``.
    """

    def __init__(self):
        # -1 is the value this module always uses for the sample count.
        self.test_sample_num = -1
        self.cuda_idx = 'cuda:0'


class CMU(TrajDataset):
    """CMU motion-capture trajectory dataset.

    Each sample is one window of ``input_n + output_n`` frames of 3-D joint
    positions (multiplied by ``scale``), their finite-difference velocities
    and a fixed skeleton graph produced by ``CMU_Transform``. Building or
    loading the processed data is driven by the ``TrajDataset`` base class
    (not visible here), which is presumably what calls ``preprocess_raw`` /
    ``postprocess`` — TODO confirm.
    """

    all_actions = ["basketball", "basketball_signal", "directing_traffic", "jumping",
                   "running", "soccer", "walking", "washwindow", "pretrain"]

    def __init__(self, root, name, input_n, output_n, force_length=None,
                 return_index=False, scale=1, force_reprocess=False, act='all'):
        """Store the configuration and let ``TrajDataset`` build/load the data.

        Args:
            root: Data directory; raw data is read from ``<root>/train/`` or
                ``<root>/test/``.
            name: Split name, e.g. ``cmu_train``, ``cmu_val`` or ``cmu_test``.
                ``self.name`` becomes ``<name>_<act>``.
            input_n: Number of observed frames per window.
            output_n: Number of frames to predict per window.
            force_length: Optional upper bound on ``len(self)``.
            return_index: If True, every sample carries ``system_id = idx``.
            scale: Factor applied to joint positions; velocities inherit it
                because they are differences of the scaled positions.
            force_reprocess: Passed through to ``TrajDataset.__init__``.
            act: ``'all'`` or one entry of ``all_actions``.
        """
        self.root = root
        self.input_n = input_n
        self.output_n = output_n
        self.scale = scale
        self.return_index = return_index
        self.force_length = force_length
        self.act = act
        assert act == 'all' or act in self.all_actions
        # Keep the bare split name: split detection in preprocess_raw must not
        # be fooled by words that happen to appear in the action suffix.
        self.split_name = name
        name = name + '_' + act

        super().__init__(root, name, force_reprocess)  # name: cmu_train, cmu_val, cmu_test
        print(f'{name} using {len(self)} data points.')

    def processed_file(self):
        """Return ``<root>/<self.name>_new.pt`` (presumably used by TrajDataset — TODO confirm)."""
        return os.path.join(self.root, self.name + '_new.pt')

    def preprocess_raw(self):
        """Load the raw CMU windows for this split and pickle them to disk.

        The split is taken from the bare name given to ``__init__``: names
        containing ``_train`` read ``<root>/train/``; names containing ``val``
        or ``test`` read ``<root>/test/``.

        Writes ``(all_seqs, dim_use, act_index_map, act_all)`` to
        ``self.processed_path``, where ``act_index_map[act]`` is a numpy array
        of the sample indices whose entry in ``act_all`` equals ``act``.

        Raises:
            NotImplementedError: if the split cannot be inferred from the name.
        """
        opt = DummyOpt()
        input_n, output_n = self.input_n, self.output_n
        # Fall back to the full name if __init__ was bypassed.
        split = getattr(self, 'split_name', self.name)
        if '_train' in split:
            path_to_data = self.root + '/train/'
        elif 'val' in split or 'test' in split:
            path_to_data = self.root + '/test/'
        else:
            raise NotImplementedError()

        actions = self.all_actions if self.act == 'all' else [self.act]

        all_seqs, _dim_ignore, dim_use, act_all = load_data_cmu_3d_all_new(
            opt, path_to_data, actions, input_n, output_n)

        # Per-action sample indices, later used for per-action evaluation.
        act_index_map = {
            act: np.array([i for i, label in enumerate(act_all) if label == act])
            for act in actions
        }

        # NOTE: pickle is only safe because this file is produced locally.
        with open(self.processed_path, 'wb') as f:
            pickle.dump((all_seqs, dim_use, act_index_map, act_all), f)

    def postprocess(self):
        """Convert the loaded arrays to tensors and build the skeleton graph.

        Reads ``self.data`` (the first three items of the pickled tuple) and sets:
          * ``self.x``: scaled joint positions, ``[B, T, N, 3]``
          * ``self.v``: frame-to-frame differences of ``self.x``, ``[B, T, N, 3]``
          * ``self.h``: per-joint ``[speed, coordinate 1]``, ``[B, T, N, 2]``
          * ``self.dim_use`` (long tensor) and ``self.act_index_map``
          * ``self.edge_index`` / ``self.edge_attr`` from ``CMU_Transform``
        """
        all_seqs, dim_use, act_index_map = self.data[0], self.data[1], self.data[2]
        seqs = torch.Tensor(all_seqs)
        dim_use = torch.Tensor(dim_use).long()
        # Keep the used coordinate columns and group them per joint: [B, T, N, 3].
        positions = seqs[..., dim_use].reshape(seqs.size(0), seqs.size(1), len(dim_use) // 3, 3)
        self.x = positions * self.scale  # scaling is applied exactly once, here

        # Velocity is the difference of the already-scaled positions, so it
        # must NOT be scaled again. Frame 0 has no predecessor and copies frame 1.
        v = torch.zeros_like(self.x)
        v[:, 1:] = self.x[:, 1:] - self.x[:, :-1]
        v[:, 0] = v[:, 1]
        self.v = v  # [B, T, N, 3]

        # TODO: try to add z-axis coordinate to self.h
        coord1 = self.x[..., 1].unsqueeze(-1)  # coordinate index 1 of each joint
        speed = torch.norm(self.v, dim=-1, keepdim=True)  # [B, T, N, 1]
        self.h = torch.cat((speed, coord1), dim=-1)  # [B, T, N, 2]
        self.dim_use = dim_use
        self.act_index_map = act_index_map

        # Skeleton bones, written 1-based and converted to 0-based node ids.
        edges_1based = [
            [5, 6], [6, 7], [7, 8],
            [1, 2], [2, 3], [3, 4],
            [5, 9], [1, 9],
            [9, 10],
            [10, 14], [14, 15], [15, 16], [16, 17], [17, 18], [16, 19],
            [10, 20], [20, 21], [21, 22], [22, 23], [23, 24], [22, 25],
            [10, 11], [11, 12], [12, 13],
        ]
        edges = [[src - 1, dst - 1] for src, dst in edges_1based]
        self.edge_index, self.edge_attr = CMU_Transform(max_hop=3, fc=True)(edges, N=self.x.size(2))

    def __len__(self):
        """Number of windows, capped at ``force_length`` when it is set."""
        n = self.x.size(0)
        return n if self.force_length is None else min(n, self.force_length)

    def __getitem__(self, idx):
        """Return window ``idx`` as a ``TrajData`` graph.

        ``x`` and ``v`` are permuted from ``[T, N, 3]`` to ``[N, 3, T]``. The
        node features ``h`` are a one-hot joint id of width 25, not ``self.h``.
        """
        x = self.x[idx]  # [T, N, 3]
        v = self.v[idx]  # [T, N, 3]
        # One-hot joint identity: deliberately breaks permutation equivariance.
        # Width 25 is hard-coded (the skeleton above has 25 joints).
        h = torch.nn.functional.one_hot(torch.arange(x.size(1)).long(), 25)  # [N, 25]

        data = TrajData(x=x.permute(1, 2, 0), v=v.permute(1, 2, 0), h=h,
                        edge_index=self.edge_index, edge_attr=self.edge_attr)
        if self.return_index:
            data.system_id = torch.tensor(idx, dtype=torch.long)
        return data


if __name__ == '__main__':
    # Smoke test: build (or reuse) the processed data for a few actions on
    # both splits and report how many windows each one contains.
    data_dir = 'data/cmu'
    input_n = 10
    output_n = 25
    # Other selectable actions:
    # ["basketball", "running", "jumping", "soccer", "walking", "pretrain"]
    actions = ["basketball_signal", "directing_traffic", "washwindow", "pretrain"]
    splits = ['cmu_test', 'cmu_train']
    for act in actions:
        for split in splits:
            dataset = CMU(root=data_dir, name=split, input_n=input_n, output_n=output_n,
                          return_index=False, scale=0.01, act=act, force_reprocess=False)
            # Report the actual split instead of always claiming "Training".
            print('>>> {} ({}) dataset length: {:d}'.format(split, act, len(dataset)))

