import torch
import pickle
import numpy as np
from torch.utils import data
from torch.utils.data import DataLoader
from os.path import join as pjoin


class DiMoP3D_Dataset(data.Dataset):
    """Motion clips paired with the scene-related data serialized next to them.

    Every per-sample container loaded in ``__init__`` is indexed by the same
    integer ``item`` in ``__getitem__``.

    Args:
        data_root: Directory containing ``train/`` and ``test/`` sub-directories.
            Falls back to ``"/dataset/GIMO/SLICES_8s_fps20/"`` when falsy.
        split: Either ``"train"`` or ``"test"``.
        mean_path: ``.npy`` file holding the per-feature mean used to
            normalize motion features.
        std_path: ``.npy`` file holding the per-feature standard deviation.

    Raises:
        ValueError: If ``split`` is not ``"train"`` or ``"test"``.
    """

    def __init__(self, data_root, split="train",
                 mean_path="/dataset/HumanML3D/Mean.npy",
                 std_path="/dataset/HumanML3D/Std.npy"):
        data_root = data_root or "/dataset/GIMO/SLICES_8s_fps20/"
        # Validate before touching the disk; an ``assert`` would vanish under ``python -O``.
        valid_splits = ("train", "test")
        if split not in valid_splits:
            raise ValueError(f"split must be one of {valid_splits}, got {split!r}")

        # Per-feature normalization statistics (the HumanML3D defaults are shape [263,]).
        self.mean = torch.Tensor(np.load(mean_path)).float()
        self.std = torch.Tensor(np.load(std_path)).float()
        # NOTE(review): not read inside this class; presumably a clip length
        # in frames minus one — confirm with external callers.
        self.fixed_length = 160 - 1

        # NOTE: torch.load / pickle.load can execute code embedded in the file;
        # only point this class at trusted dataset files.
        split_dir = pjoin(data_root, split)
        self.motion = torch.load(pjoin(split_dir, "new_joint_vecs.pth"))
        self.raw_joints = torch.load(pjoin(split_dir, "joints.pth"))
        # Context manager closes the handle instead of leaking it until GC.
        with open(pjoin(split_dir, "scene_heightmap.pkl"), 'rb') as f:
            self.scene_height = pickle.load(f)
        self.scene_base = torch.load(pjoin(split_dir, "scene_heightmap_base.pth"))
        # Loaded but not returned by __getitem__.
        self.scene_points = torch.load(pjoin(split_dir, "scene_points.pth"))
        self.scene_feats = torch.load(pjoin(split_dir, "scene_feats.pth"))
        self.objects = torch.load(pjoin(split_dir, "scene_objects.pth"))
        self.recover_args = torch.load(pjoin(split_dir, "recover.pth"))

        with open(pjoin(split_dir, "sample_name.txt"), 'r') as file:
            self.sample_name = [line.strip() for line in file]

    def __len__(self):
        """Return the number of motion samples in the split."""
        return len(self.motion)

    def __getitem__(self, item):
        """Return one sample as an 8-tuple.

        Order: normalized motion (transposed, with a singleton axis inserted
        at dim 1), raw joints, scene height map as a float tensor, scene
        height base, recover arguments, sample name, scene features, objects.
        """
        motion = self.motion[item]
        # Broadcasts against mean/std, so the last axis is the feature axis.
        motion = (motion - self.mean) / self.std
        raw_joints = self.raw_joints[item]
        scene_height = torch.Tensor(self.scene_height[item])
        scene_base = self.scene_base[item]
        recover = self.recover_args[item]
        # Presumably motion is [frames, features] -> [features, 1, frames]; confirm.
        return motion.T.float().unsqueeze(1), raw_joints, scene_height, scene_base, recover, self.sample_name[item], \
               self.scene_feats[item], self.objects[item]

    def inv_transform(self, data):
        """Undo the mean/std normalization applied in ``__getitem__``."""
        return data * self.std + self.mean


def lengths_to_mask(lengths, max_len=None):
    """Build a boolean validity mask from per-sequence lengths.

    Args:
        lengths: 1-D integer tensor with one valid length per sequence.
        max_len: Width of the mask. When ``None`` it defaults to the largest
            value in ``lengths`` (0 for an empty tensor).

    Returns:
        Bool tensor of shape ``(len(lengths), max_len)`` on ``lengths.device``
        where entry ``[i, t]`` is ``True`` iff ``t < lengths[i]``.
    """
    if max_len is None:
        max_len = int(lengths.max()) if lengths.numel() > 0 else 0
    positions = torch.arange(max_len, device=lengths.device)
    # (1, max_len) vs (batch, 1) broadcasts to (batch, max_len).
    return positions.unsqueeze(0) < lengths.unsqueeze(1)


def collate_tensors(batch):
    """Zero-pad same-rank tensors to a common shape and stack them.

    Each axis is padded at its end up to the largest size that axis takes
    anywhere in ``batch``.

    Args:
        batch: Non-empty sequence of tensors with the same number of dims.

    Returns:
        Tensor of shape ``(len(batch), *max_sizes)`` created from ``batch[0]``
        (same dtype and device); slot ``i`` holds ``batch[i]`` in its leading
        corner and zeros elsewhere.
    """
    first = batch[0]
    ndim = first.dim()
    padded_shape = [len(batch)]
    for axis in range(ndim):
        padded_shape.append(max(t.size(axis) for t in batch))
    out = first.new_zeros(size=tuple(padded_shape))
    for idx, tensor in enumerate(batch):
        # Basic int/slice indexing yields a view, so the in-place add writes into `out`.
        corner = (idx,) + tuple(slice(0, tensor.size(axis)) for axis in range(ndim))
        out[corner].add_(tensor)
    return out


def collate(batch):
    """Merge per-sample dicts into a padded motion tensor and a condition dict.

    ``None`` entries in ``batch`` are skipped. Every remaining sample must
    provide ``'inp'``. The first remaining sample decides which optional keys
    (``'lengths'``, ``'text'``, ``'tokens'``, ``'action'``, ``'action_text'``)
    are gathered from all samples.

    Returns:
        ``(motion, cond)``. ``motion`` is the zero-padded stack of the ``'inp'``
        tensors. ``cond['y']`` holds ``'mask'`` and ``'lengths'``, plus any
        optional keys that were present.
    """
    samples = [s for s in batch if s is not None]
    template = samples[0]
    inputs = [s['inp'] for s in samples]

    if 'lengths' not in template:
        # Fall back to len(inp[0][0]), i.e. the size of axis 2 of 'inp'.
        lengths = [len(s['inp'][0][0]) for s in samples]
    else:
        lengths = [s['lengths'] for s in samples]

    motion = collate_tensors(inputs)
    lengths_tensor = torch.as_tensor(lengths)
    # (batch, T) -> (batch, 1, 1, T) so the mask broadcasts against the
    # padded motion tensor (presumably laid out as (batch, J, 1, T)).
    mask = lengths_to_mask(lengths_tensor, motion.shape[-1])[:, None, None, :]

    y = {'mask': mask, 'lengths': lengths_tensor}
    for key in ('text', 'tokens', 'action', 'action_text'):
        if key not in template:
            continue
        gathered = [s[key] for s in samples]
        if key == 'action':
            # Action entries become one tensor with an extra axis at dim 1.
            gathered = torch.as_tensor(gathered).unsqueeze(1)
        y[key] = gathered

    return motion, {'y': y}


# an adapter to our collate func
def t2m_collate(batch):
    """Adapt tuple-style samples to the dict format consumed by ``collate``.

    Tuple positions read: ``b[4]`` is motion (transposed, cast to float, with
    a singleton axis at dim 1), ``b[2]`` is text, ``b[6]`` is tokens and
    ``b[5]`` is the sequence length.

    NOTE(review): this layout does not match the 8-tuple returned by
    ``DiMoP3D_Dataset.__getitem__`` in this file, where index 4 is
    ``recover``. It presumably targets a different text-to-motion dataset;
    confirm before wiring it into a DataLoader.
    """
    def _as_sample(b):
        motion = torch.tensor(b[4].T).float().unsqueeze(1)  # [seqlen, J] -> [J, 1, seqlen]
        return {'inp': motion, 'text': b[2], 'tokens': b[6], 'lengths': b[5]}

    return collate([_as_sample(b) for b in batch])


def get_dimop3d_dataset_loader(data_dir, batch_size, split='train', num_workers=8):
    """Build a DataLoader over ``DiMoP3D_Dataset`` for one split.

    For ``split == 'train'`` the loader uses ``batch_size``, shuffles, and
    drops the final partial batch. Any other split is served one sample at a
    time, in order, with nothing dropped.

    Args:
        data_dir: Dataset root forwarded to ``DiMoP3D_Dataset``.
        batch_size: Batch size for the training split; ignored otherwise.
        split: Split name forwarded to ``DiMoP3D_Dataset``.
        num_workers: Number of DataLoader worker processes (default 8).

    Returns:
        A ``torch.utils.data.DataLoader`` using the default collate function.
    """
    dataset = DiMoP3D_Dataset(data_dir, split=split)

    training = split == 'train'
    effective_batch_size = batch_size if training else 1

    return DataLoader(
        dataset,
        batch_size=effective_batch_size,
        shuffle=training,
        num_workers=num_workers,
        drop_last=training,
    )
