# mainly borrowed from https://github.com/OpenMotionLab/MotionGPT/blob/main/mGPT/data/humanml/dataset_t2m.py

import os
import rich
import random
import pickle
import codecs as cs
import numpy as np
from torch.utils import data
from rich.progress import track
from os.path import join as pjoin


class Text2MotionDataset(data.Dataset):
    """In-memory text-to-motion dataset (HumanML3D layout).

    For every id listed in ``<data_root>/<split>.txt`` the constructor loads
    the motion array, the motion-token array and the caption file, keeps the
    samples whose length is within the configured bounds, and stores them
    sorted by motion length. Caption lines that carry a non-zero time range
    produce additional sub-clip samples named ``<id>_<from>_<to>``.
    """

    def __init__(
        self,
        data_root,
        split,
        mean,
        std,
        max_motion_length=196,
        min_motion_length=40,
        unit_length=4,
        fps=20,
        tmpFile=True,
        tiny=False,
        debug=False,
        motion_dir=None,
        text_dir=None,
        code_path='VQVAE',
        **kwargs,
    ):
        """Load every usable sample of ``split`` into memory.

        Args:
            data_root: Directory holding ``<split>.txt`` and the data folders.
            split: Name of the split file (without ``.txt``).
            mean: Value subtracted from each motion in ``__getitem__``.
            std: Value each motion is divided by in ``__getitem__``.
            max_motion_length: Exclusive upper bound on whole-motion length.
            min_motion_length: Inclusive lower bound on motion length.
            unit_length: Ratio between frame indices and token indices; also
                the granularity motions are cropped to in ``__getitem__``.
            fps: Factor that turns a caption's from/to tags into frame
                indices (presumably the tags are in seconds -- confirm).
            tmpFile: If true, pickle the loaded data and the sorted name
                index under ``<data_root>/tmp``.
            tiny: Load only a small subset (no progress bar).
            debug: Same effect as ``tiny``.
            motion_dir: Motion folder; defaults to
                ``<data_root>/new_joint_vecs``.
            text_dir: Caption folder; defaults to ``<data_root>/texts``.
            code_path: Sub-folder of ``data_root`` holding the token arrays.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If no sample could be loaded for the split.
        """

        # Restrict the length of motion and text
        self.max_length = 20
        self.max_motion_length = max_motion_length
        self.min_motion_length = min_motion_length
        self.unit_length = unit_length

        # Data mean and std
        self.mean = mean
        self.std = std

        # Data path
        split_file = pjoin(data_root, split + '.txt')
        code_dir = pjoin(data_root, code_path)
        motion_dir = motion_dir if motion_dir is not None else pjoin(data_root, 'new_joint_vecs')
        text_dir = text_dir if text_dir is not None else pjoin(data_root, 'texts')

        # Data id list
        with cs.open(split_file, "r") as f:
            self.id_list = [line.strip() for line in f]

        # Debug mode: small subset and no progress bar
        if tiny or debug:
            names = self.id_list
            maxdata = 100
            subset = '_tiny'
        else:
            names = track(
                self.id_list,
                f"Loading HumanML3D {split}",
            )
            maxdata = 1e10
            subset = ''

        new_name_list = []
        length_list = []
        data_dict = {}
        num_skipped = 0

        for name in names:
            # Checked with `>` before each id, so slightly more than
            # `maxdata` entries can end up stored.
            if len(new_name_list) > maxdata:
                break
            try:
                motion = np.load(pjoin(motion_dir, name + ".npy"))
                if not (self.min_motion_length <= len(motion)
                        < self.max_motion_length):
                    continue

                m_token_list = np.load(pjoin(code_dir, f'{name}.npy'))

                # Read text: each line is caption#tokens#from_tag#to_tag
                text_data = []
                with cs.open(pjoin(text_dir, name + '.txt')) as f:
                    for line in f:
                        line_split = line.strip().split('#')
                        caption = line_split[0]
                        t_tokens = line_split[1].split(' ')
                        f_tag = float(line_split[2])
                        to_tag = float(line_split[3])
                        f_tag = 0.0 if np.isnan(f_tag) else f_tag
                        to_tag = 0.0 if np.isnan(to_tag) else to_tag

                        text_dict = {'caption': caption, 'tokens': t_tokens}

                        if f_tag == 0.0 and to_tag == 0.0:
                            # Caption describes the whole motion.
                            text_data.append(text_dict)
                            continue

                        # Caption describes a sub-clip: slice tokens and
                        # frames to the tagged range.
                        m_token_list_new = [
                            tokens[int(f_tag * fps / unit_length
                                       ):int(to_tag * fps / unit_length)]
                            for tokens in m_token_list
                        ]
                        motion_new = motion[int(f_tag * fps):int(to_tag * fps)]
                        # NOTE(review): the sub-clip upper bound is a fixed
                        # 200 frames rather than max_motion_length; kept as is.
                        if (len(motion_new) < self.min_motion_length
                                or len(motion_new) >= 200):
                            continue
                        new_name = '%s_%f_%f' % (name, f_tag, to_tag)
                        data_dict[new_name] = {
                            'm_token_list': m_token_list_new,
                            'motion': motion_new,
                            "length": len(motion_new),
                            'text': [text_dict]
                        }
                        new_name_list.append(new_name)
                        length_list.append(len(motion_new))

                if text_data:
                    data_dict[name] = {
                        'm_token_list': m_token_list,
                        'motion': motion,
                        "length": len(motion),
                        'text': text_data
                    }
                    new_name_list.append(name)
                    length_list.append(len(motion))
            except Exception:
                # Best-effort loading: an id with a missing or malformed file
                # is skipped. Only `Exception` is caught so KeyboardInterrupt
                # and SystemExit still stop the load.
                num_skipped += 1

        if num_skipped:
            print(f'[Text2MotionDataset] skipped {num_skipped} ids of '
                  f'{split} that failed to load')

        if not new_name_list:
            raise ValueError(
                f"No usable motion found for split '{split}' under "
                f"'{data_root}' (motions: '{motion_dir}', "
                f"texts: '{text_dir}', codes: '{code_dir}')")

        # Sort samples by motion length (stable, ascending).
        name_list, length_list = zip(
            *sorted(zip(new_name_list, length_list), key=lambda x: x[1]))

        if tmpFile:
            os.makedirs(pjoin(data_root, 'tmp'), exist_ok=True)
            with open(pjoin(data_root, f'tmp/{split}{subset}_data.pkl'),
                      'wb') as file:
                pickle.dump(data_dict, file)
            with open(pjoin(data_root, f'tmp/{split}{subset}_index.pkl'),
                      'wb') as file:
                pickle.dump(name_list, file)

        self.length_arr = np.array(length_list)
        self.data_dict = data_dict
        self.name_list = name_list
        self.nfeats = data_dict[name_list[0]]['motion'].shape[1]
        self.reset_max_len(self.max_length)
        print(f'[Text2MotionDataset] {split} dataset loading done', len(self.data_dict))

    def reset_max_len(self, length):
        """Hide all samples shorter than ``length`` frames.

        Moves ``self.pointer`` to the first position in the length-sorted
        sample list whose length is at least ``length``. ``length`` must not
        exceed ``max_motion_length``.
        """
        assert length <= self.max_motion_length
        self.pointer = np.searchsorted(self.length_arr, length)
        print("Pointer Pointing at %d" % self.pointer)
        self.max_length = length

    def __len__(self):
        """Return the number of samples at or after ``self.pointer``."""
        return len(self.name_list) - self.pointer

    def __getitem__(self, item):
        """Return one randomly captioned, randomly cropped, normalized sample.

        Returns a 12-tuple: ``(caption, None, None, motion, m_length, None,
        None, None, None, all_captions, None, fname)``. ``motion`` is the
        cropped window after subtracting ``mean`` and dividing by ``std``;
        ``all_captions`` holds one string per caption, built from the token
        list with everything after the first ``/`` of each token dropped.
        """
        fname = self.name_list[self.pointer + item]
        entry = self.data_dict[fname]
        motion = entry["motion"]
        m_length = entry["length"]
        text_list = entry["text"]

        # Randomly select a caption
        caption = random.choice(text_list)["caption"]

        all_captions = [
            ' '.join(token.split('/')[0] for token in text_dic['tokens'])
            for text_dic in text_list
        ]

        # Crop the motion to a multiple of unit_length; with probability 1/3
        # drop one extra unit to introduce small length variations.
        if self.unit_length < 10:
            coin2 = np.random.choice(["single", "single", "double"])
        else:
            coin2 = "single"

        if coin2 == "double":
            m_length = (m_length // self.unit_length - 1) * self.unit_length
        elif coin2 == "single":
            m_length = (m_length // self.unit_length) * self.unit_length

        # Random start so that the window of m_length frames fits.
        start = random.randint(0, len(motion) - m_length)
        motion = motion[start:start + m_length]

        # Z Normalization
        motion = (motion - self.mean) / self.std

        return caption, None, None, motion, m_length, None, None, None, None, all_captions, None, fname
