# Copyright (2024) Bytedance Ltd. and/or its affiliates

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import random
import warnings
import traceback
import argparse
from omegaconf import OmegaConf
from tqdm import tqdm
from torchvision import transforms as T
import torch
from torch.utils.data import Dataset,DataLoader
import numpy as np
import imageio
from decord import VideoReader, cpu
from concurrent.futures import ThreadPoolExecutor, as_completed
from einops import rearrange

# from mdt.datasets.utils.dataset_util import euler2rotm, rotm2euler
# from mdt.datasets.utils.video_transforms import Resize_Preprocess, ToTensorVideo
# from mdt.datasets.utils.util import update_paths
from scipy.spatial.transform import Rotation as R  
import decord

class Dataset_mix(Dataset):
    """Mixture dataset over several pre-encoded (latent) video datasets.

    Every item bundles the latent frames of three camera views, the
    normalized robot state for the same frames and the first text annotation.
    The source dataset of an item is drawn at random according to
    ``args.prob``; ``index`` only selects the sample inside that dataset.
    """

    def __init__(
            self,
            args,
            mode = 'val',
    ):
        """Load the sample lists and normalization statistics of all datasets.

        Args:
            args: Configuration object. This class reads ``data_root_path``,
                ``dataset_names`` and ``dataset_cfgs`` ('+'-separated
                strings), ``prob``, ``annotation_name``, ``skip_his``,
                ``num_history`` and ``num_frames`` from it.
            mode: Split name; selects ``exp_cfg/<cfg>/<mode>_sample.json``.
        """
        super().__init__()
        self.args = args
        self.mode = mode

        # dataset structure
        # dataset_dir/dataset_name/annotation_name/mode/traj
        # dataset_dir/dataset_name/video/mode/traj
        # dataset_dir/dataset_name/latent_video/mode/traj

        # samples: {'ann_file': xxx, 'frame_idx': xxx, 'dataset_name': xxx}

        # Per-dataset bookkeeping; all four lists share the same ordering.
        self.video_path_all = []  # per dataset: one root directory per sample
        self.samples_all = []     # per dataset: list of sample dicts
        self.samples_len = []     # per dataset: number of samples
        self.norm_all = []        # per dataset: (state_p01, state_p99)

        data_root_path = args.data_root_path
        dataset_names = args.dataset_names.split('+')
        dataset_cfgs = args.dataset_cfgs.split('+')
        self.prob = args.prob
        for dataset_name, dataset_cfg in zip(dataset_names, dataset_cfgs):
            # The sample list is looked up by config name ...
            data_json_path = f'exp_cfg/{dataset_cfg}/{mode}_sample.json'

            with open(data_json_path, "r") as f:
                samples = json.load(f)
            video_path = [os.path.join(data_root_path, sample['dataset_name']) for sample in samples]
            print(f"ALL dataset, {len(samples)} samples in total")
            self.video_path_all.append(video_path)
            self.samples_all.append(samples)
            self.samples_len.append(len(samples))

            # ... while the normalization statistics are looked up by dataset name.
            with open(f'exp_cfg/{dataset_name}/stat.json', "r") as f:
                data_stat = json.load(f)
            # A leading axis is added so the bounds broadcast over frames.
            state_p01 = np.array(data_stat['state_01'])[None,:]
            state_p99 = np.array(data_stat['state_99'])[None,:]
            self.norm_all.append((state_p01, state_p99))

        self.max_id = max(self.samples_len)
        print('samples_len:',self.samples_len, 'max_id:',self.max_id)

    def __len__(self):
        """Return the size of the largest constituent dataset."""
        return self.max_id

    def _load_latent_video(self, video_path, frame_ids):
        """Load a saved latent tensor and pick the requested frames.

        Args:
            video_path: Path of a file readable by ``torch.load``.
            frame_ids: Frame indices. Indices at or beyond the number of
                stored frames are replaced by the last frame index.

        Returns:
            The loaded tensor indexed along its first dimension.
        """
        with open(video_path,'rb') as file:
            video_tensor = torch.load(file)
            video_tensor.requires_grad = False

        max_frames = video_tensor.size()[0]
        frame_ids = [int(frame_id) if frame_id < max_frames else max_frames-1 for frame_id in frame_ids]
        return video_tensor[frame_ids]

    def _get_frames(self, label, frame_ids, cam_id, pre_encode, video_dir, use_img_cond=False):
        """Return the frames of one camera for the given frame ids.

        Args:
            label: Parsed annotation dict; ``label['latent_videos'][cam_id]``
                must hold a ``'latent_video_path'`` entry.
            frame_ids: Frame indices to fetch.
            cam_id: Camera index; must not be ``None``.
            pre_encode: Must be true; only pre-encoded latents are supported.
            video_dir: Directory the relative video path is joined onto.
            use_img_cond: Only read by the raw-video branch.

        Returns:
            The frames selected by ``_load_latent_video``.
        """
        assert cam_id is not None
        assert pre_encode == True
        if pre_encode:
            # Directly load the pre-encoded video latents.
            video_path = label['latent_videos'][cam_id]['latent_video_path']
            video_path = os.path.join(video_dir,video_path)
            if 'calvin' in video_path:
                # For paths containing 'calvin' the file name is forced to '0.pt'.
                items = video_path.split('/')
                items[-1]='0.pt'
                video_path = '/'.join(items)
            try:
                frames = self._load_latent_video(video_path, frame_ids)
            except Exception:
                # Best-effort fallback to the alternative latent directory.
                # `Exception` (rather than a bare except) lets
                # KeyboardInterrupt / SystemExit propagate.
                video_path = video_path.replace("latent_videos", "latent_videos_svd")
                frames = self._load_latent_video(video_path, frame_ids)
        else:
            # NOTE(review): unreachable while the assert above is active, and
            # `self.preprocess` is not defined in this class -- confirm before
            # enabling raw-video loading.
            if use_img_cond:
                frame_ids = frame_ids[0]
            video_path = label['videos'][cam_id]['video_path']
            video_path = os.path.join(video_dir,video_path)
            vr = decord.VideoReader(video_path)
            frames = vr[frame_ids].asnumpy()
            frames = torch.from_numpy(frames).permute(2,0,1).unsqueeze(0) # (h, w, c) -> (1, c, h, w)
            # resize the video to self.args.video_size
            frames = self.preprocess(frames)
        return frames

    def _get_obs(self, label, frame_ids, cam_id, pre_encode, video_dir):
        """Fetch frames for ``cam_id``, or for a random camera if it is ``None``.

        When ``cam_id`` is ``None`` the camera is drawn from ``self.cam_ids``
        if that attribute exists, otherwise from the camera indices present
        in ``label``.

        Returns:
            Tuple ``(frames, used_cam_id)``.
        """
        if cam_id is None:
            # `cam_ids` is never set by __init__; fall back to the cameras
            # listed in the annotation instead of raising AttributeError.
            cam_ids = getattr(self, 'cam_ids', None)
            if cam_ids is None:
                video_key = 'latent_videos' if pre_encode else 'videos'
                cam_ids = list(range(len(label[video_key])))
            temp_cam_id = random.choice(cam_ids)
        else:
            temp_cam_id = cam_id
        frames = self._get_frames(label, frame_ids, cam_id = temp_cam_id, pre_encode = pre_encode, video_dir=video_dir)
        return frames, temp_cam_id

    def normalize_bound(
        self,
        data: np.ndarray,
        data_min: np.ndarray,
        data_max: np.ndarray,
        clip_min: float = -1,
        clip_max: float = 1,
        eps: float = 1e-8,
    ) -> np.ndarray:
        """Map ``data`` linearly from ``[data_min, data_max]`` to ``[-1, 1]`` and clip.

        ``eps`` guards against division by zero when the bounds coincide.
        The result is clipped to ``[clip_min, clip_max]``.
        """
        ndata = 2 * (data - data_min) / (data_max - data_min + eps) - 1
        return np.clip(ndata, clip_min, clip_max)

    def denormalize_bound(
        self,
        data: np.ndarray,
        data_min: np.ndarray,
        data_max: np.ndarray,
        clip_min: float = -1,
        clip_max: float = 1,
        eps=1e-8,
    ) -> np.ndarray:
        """Map ``data`` linearly from ``[clip_min, clip_max]`` to ``[data_min, data_max]``.

        ``eps`` is accepted for signature symmetry with ``normalize_bound``
        but is not used.
        """
        clip_range = clip_max - clip_min
        rdata = (data - clip_min) / clip_range * (data_max - data_min) + data_min
        return rdata

    def __getitem__(self, index):
        """Build one training item.

        Args:
            index: Sample index; wrapped modulo the size of the randomly
                chosen dataset.

        Returns:
            Dict with
            ``'text'``: ``label['texts'][0]``;
            ``'latent'``: float tensor of shape
            ``(num_frames + num_history, 4, 72, 40)`` holding the three camera
            latents in rows 0-23, 24-47 and 48-71 of the third dimension;
            ``'action'``: float tensor with the normalized cartesian position
            and gripper position (last column) for every selected frame.
        """
        # Pick the source dataset according to the configured probabilities.
        dataset_id = np.random.choice(len(self.samples_all), p=self.prob)

        samples = self.samples_all[dataset_id]
        video_path = self.video_path_all[dataset_id]
        state_p01, state_p99 = self.norm_all[dataset_id]

        index = index % len(samples)
        sample = samples[index]
        sampled_video_dir = video_path[index]
        ann_file = sample['ann_file']
        if self.args.annotation_name not in ann_file:
            ann_file = ann_file.replace('annotation', self.args.annotation_name)
        ann_file = f'{sampled_video_dir}/{ann_file}'
        with open(ann_file, "r") as f:
            label = json.load(f)

        joint_len = len(label['observation.state.joint_position'])-1
        # Largest frame id whose state index (frame id * 3) is still valid.
        frame_len = joint_len // 3
        # Random temporal stride for the future frames.
        skip = random.randint(1, 2)
        # With probability 0.15 the history stride is 0, i.e. every history
        # frame equals the current frame (no real history).
        skip_his = 0 if random.random() < 0.15 else self.args.skip_his

        # Frame ids: history frames, the current frame, then future frames.
        frame_now = sample['frame_ids'][0]
        frame_ids_all = []
        for i in range(self.args.num_history,0,-1):
            frame_ids_all.append(int(frame_now - i*skip_his))
        frame_ids_all.append(frame_now)
        for i in range(1, self.args.num_frames):
            frame_ids_all.append(int(frame_now + i*skip))

        # Clamp to the valid frame range of this trajectory.
        frame_ids_all = np.clip(np.array(frame_ids_all), 0, frame_len).tolist()
        frame_ids_all = [int(frame_id) for frame_id in frame_ids_all]

        assert len(frame_ids_all) == self.args.num_frames + self.args.num_history, f"frame_ids_all length {len(frame_ids_all)} != {self.args.num_frames + self.args.num_history}"

        # State arrays are indexed at three times the video frame id
        # (presumably states are recorded at 3x the latent frame rate -- TODO confirm).
        obs_id = np.array(frame_ids_all)*3

        data = dict()
        data['text'] = label['texts'][0]

        # Video data: stack the three camera latents along the third dimension.
        cond_cam_id1 = 0
        cond_cam_id2 = 1
        cond_cam_id3 = 2
        latnt_cond1,_ = self._get_obs(label, frame_ids_all, cond_cam_id1, pre_encode=True, video_dir=sampled_video_dir)
        latnt_cond2,_ = self._get_obs(label, frame_ids_all, cond_cam_id2, pre_encode=True, video_dir=sampled_video_dir)
        latnt_cond3,_ = self._get_obs(label, frame_ids_all, cond_cam_id3, pre_encode=True, video_dir=sampled_video_dir)
        latent = torch.zeros((self.args.num_frames+self.args.num_history, 4, 72, 40), dtype=torch.float32)
        latent[:,:,0:24] =  latnt_cond1
        latent[:,:,24:48] = latnt_cond2
        latent[:,:,48:72] = latnt_cond3
        data['latent'] = latent.float()

        # Action data: the observed cartesian state with the gripper position
        # appended as the last column, normalized with this dataset's bounds.
        obs_cartesian = np.array(label['observation.state.cartesian_position'])[obs_id]
        obs_gripper = np.array(label['observation.state.gripper_position'])[obs_id][..., np.newaxis]
        action = np.concatenate((obs_cartesian, obs_gripper), axis=-1)
        action = self.normalize_bound(action, state_p01, state_p99)
        data['action'] = torch.tensor(action).float()

        return data
        

if __name__ == "__main__":
    # Smoke test: build the validation dataset from the hydra config and
    # iterate over it once, printing a short summary of every batch.
    from hydra import compose, initialize
    with initialize(config_path="../../conf", job_name="VPP_xhand_train.yaml"):
        cfg = compose(config_name="VPP_xhand_train")

    train_dataset = Dataset_mix(cfg.dataset_args,mode="val")
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.dataset_args.batch_size,
        shuffle=cfg.dataset_args.shuffle,
    )
    for data in tqdm(train_loader,total=len(train_loader)):
        # __getitem__ only returns 'text', 'latent' and 'action'; the
        # previously printed 'ann_file' key does not exist in the batch.
        print(data['text'], data['latent'].shape, data['action'].shape)

    