"""
NOTE: Explanation of the data format can be found here: https://github.com/microsoft/psi/tree/master/Sources/MixedReality/HoloLensCapture/HoloLensCaptureExporter

hand for 26 joints: relative_time_stamp global_time_stamp data_active_bool (26 4x4 coordinate matrices) (26 valid_bools) (26 tracked bools)
if valid_bool is false or tracked is false, ignore the data (zero it out or something)

"""

import os
import torch
from torch.utils.data import Dataset
from torchvision.io import read_video
from torchvision.transforms.functional import normalize
import torch.nn.utils.rnn as rnn_utils
import scipy.io
import pandas as pd
import torch.nn.functional as F
from moviepy.editor import VideoFileClip

import json

# Module-level switch read by Holo.__getitem__ when crop_frames is enabled:
# True  -> resample video/IMU to the target length with F.interpolate
# False -> truncate or zero-pad to the target length (video is first subsampled by 2)
TORCH_INTERPOLATE = False

class Holo(Dataset):
    """HoloAssist clip dataset that pairs RGB video frames with IMU samples.

    Every entry of ``self.clips`` is a tuple
    ``(rgb_path, imu_paths, class_idx, start_time, end_time, fps)`` where
    ``imu_paths`` lists the accelerometer, gyroscope and magnetometer files.

    Two indexing modes exist:

    * ``test=False``: ``split_file`` lists video names; clips are the
      "Fine grained action" events from the annotation JSON and ``class_idx``
      is the integer label from ``fine_grained_actions_map.txt``.
    * ``test=True``: ``split_file`` lists ``<video>_<start>_<end>`` clip names;
      ``class_idx`` is the clip name itself and ``fps`` is ``None`` (it is read
      from the video metadata on access).
    """

    def __init__(self, split_file, video_length=50, imu_length=180, transform=None, base_path ="/home/ANON/data/HoloAssist", crop_frames=True, return_path=False, print_error=False, test=False):
        """Index all clips referenced by ``split_file``.

        Args:
            split_file: Text file with one video name (or clip name when
                ``test`` is true) per line.
            video_length: Number of frames returned when ``crop_frames`` is set.
            imu_length: Number of IMU samples returned when ``crop_frames`` is set.
            transform: Optional callable applied to every frame (CHW tensor).
            base_path: Dataset root containing the label map, the annotation
                JSON and the ``video_compress`` / ``imu`` directories.
            crop_frames: Resize video and IMU to fixed lengths when true.
            return_path: Also return the source paths from ``__getitem__``.
            print_error: Print a message for every clip skipped because an
                input file is missing.
            test: Select the clip-name indexing mode described on the class.

        Raises:
            ValueError: On an invalid class index or malformed clip name.
            KeyError: If a video or action is missing from the annotations.
        """
        self.split_file = split_file
        self.transform = transform
        self.clips = []
        self.vid_length = video_length
        self.imu_length = imu_length
        self.return_path = return_path
        self.crop_frames = crop_frames

        self.num_classes = 1887

        # Read actions into a label dictionary from the fine_grained_actions_map.txt file
        self.label_dict = {}
        with open(os.path.join(base_path, "fine_grained_actions_map.txt"), 'r') as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank/trailing lines
                label, action = line.strip().split()
                self.label_dict[action] = int(label)

        with open(self.split_file, 'r') as f:
            entries = [line.strip() for line in f if line.strip()]

        if not test:
            self._index_annotated_clips(base_path, entries, print_error)
        else:
            self._index_named_clips(base_path, entries, print_error)

    @staticmethod
    def _clip_paths(base_path, video):
        """Return ``(rgb_path, imu_paths)`` for one video name."""
        rgb_path = os.path.join(base_path, "video_compress", video, "Export_py", "Video_compress.mp4")
        imu_dir = os.path.join(base_path, "imu", video, "Export_py", "IMU")
        imu_paths = [os.path.join(imu_dir, "Accelerometer_sync.txt"),
                     os.path.join(imu_dir, "Gyroscope_sync.txt"),
                     os.path.join(imu_dir, "Magnetometer_sync.txt")]
        return rgb_path, imu_paths

    @staticmethod
    def _inputs_exist(rgb_path, imu_paths, print_error):
        """Return True when the video and all three IMU files exist."""
        if not os.path.exists(rgb_path):
            if print_error:
                print(f"Video file not found: {rgb_path}")
            return False
        # All three files are opened in __getitem__, so check all of them here
        # instead of failing later in the middle of training.
        for imu_path in imu_paths:
            if not os.path.exists(imu_path):
                if print_error:
                    print(f"IMU file not found: {imu_path}")
                return False
        return True

    def _index_annotated_clips(self, base_path, video_names, print_error):
        """Add one clip per "Fine grained action" event of every listed video."""
        with open(os.path.join(base_path, "data-annotation-trainval-v1_1.json")) as f:
            annotations = {elt['video_name']: elt for elt in json.load(f)}

        for video in video_names:
            video_annot = annotations[video]
            rgb_path, imu_paths = self._clip_paths(base_path, video)
            fps = float(video_annot["videoMetadata"]["video"]["fps"])

            if not self._inputs_exist(rgb_path, imu_paths, print_error):
                continue

            for event in video_annot['events']:
                if event['label'] != "Fine grained action":
                    continue
                attributes = event["attributes"]
                class_idx = self.label_dict[attributes["Verb"] + "-" + attributes["Noun"]]

                # NOTE(review): `>` accepts class_idx == num_classes; confirm
                # whether the label ids in the map file are 0- or 1-based.
                if class_idx > self.num_classes or class_idx < 0:
                    raise ValueError(f"{class_idx} is an invalid class index")

                # (video path, IMU path, class index, start time, end time, fps)
                self.clips.append((rgb_path, imu_paths, class_idx, event["start"], event["end"], fps))

    def _index_named_clips(self, base_path, clip_names, print_error):
        """Add one clip per ``<video>_<start>_<end>`` name."""
        for item in clip_names:
            # Video names may themselves contain underscores, so split off the
            # two trailing timestamps from the right.
            parts = item.rsplit("_", 2)
            if len(parts) != 3:
                raise ValueError(f"Error: invalid video name format: {item}")
            video, start, end = parts
            start = float(start)
            end = float(end)

            rgb_path, imu_paths = self._clip_paths(base_path, video)
            if not self._inputs_exist(rgb_path, imu_paths, print_error):
                continue

            # No label is available here: the clip name stands in for the class
            # index and fps is resolved from the video metadata on access.
            self.clips.append((rgb_path, imu_paths, item, start, end, None))

    def __len__(self):
        """Return the number of indexed clips."""
        return len(self.clips)

    @staticmethod
    def _pad_or_truncate(data, length):
        """Cut ``data`` to ``length`` rows or zero-pad it (same dtype) to ``length``."""
        count = data.shape[0]
        if count > length:
            return data[:length]
        if count < length:
            # new_zeros keeps dtype/device so padded and unpadded clips reach
            # the transform with the same dtype.
            padding = data.new_zeros((length - count, *data.shape[1:]))
            return torch.cat([data, padding])
        return data

    @staticmethod
    def _load_clip_frames(rgb_path, start_time, end_time, fps):
        """Decode ``rgb_path`` and return the frames of the clip (THWC).

        Returns ``None`` when the clip boundaries fall outside the video.
        """
        # NOTE: this decodes the whole video and slices afterwards.
        frames, audio, info = read_video(rgb_path, pts_unit="sec")  # THWC

        if fps is None:
            try:
                fps = info['video_fps']
            except KeyError as err:
                raise KeyError(f"Error: fps not found for video:{rgb_path}") from err

        # Camera is about 29.5 fps, but each video is different
        start_frame = int(start_time * fps)
        end_frame = int(end_time * fps)
        total = frames.shape[0]
        if start_frame >= total or end_frame >= total:
            print(f"Error: start_frame:{start_frame} or end_frame:{end_frame} is greater than total frames:{total}, for video:{rgb_path} from {start_time} to {end_time}")
            return None
        return frames[start_frame:end_frame]

    def _prepare_frames(self, frames):
        """Resize the clip in time (if enabled), permute to TCHW and apply the transform."""
        if self.crop_frames:
            if TORCH_INTERPOLATE:
                # permute to CTHW; interpolate needs a leading batch dimension
                frames = frames.permute(3, 0, 1, 2).float()
                frames = F.interpolate(frames.unsqueeze(0), size=(self.vid_length, frames.shape[2], frames.shape[3]), mode='trilinear', align_corners=False)
                frames = frames.squeeze(0)
                frames = frames.permute(1, 0, 2, 3)  # TCHW
            else:
                # downsample time by half, then pad or shorten to vid_length
                frames = frames[::2, :, :, :].clone()
                frames = self._pad_or_truncate(frames, self.vid_length)
                frames = frames.permute(0, 3, 1, 2)  # TCHW
        else:
            frames = frames.permute(0, 3, 1, 2)  # TCHW

        # perform transforms on each frame
        if self.transform is not None:
            frames = torch.stack([self.transform(frame) for frame in frames])
        return frames

    @staticmethod
    def _read_imu_window(imu_paths, start_time, end_time):
        """Return the IMU rows with ``start_time <= t <= end_time`` as an (N, 9) float tensor.

        Columns are accelerometer xyz, gyroscope xyz, magnetometer xyz.
        Assumes each line is ``relative_time absolute_time x y z`` (TODO
        confirm against the exporter documentation).

        Raises:
            ValueError: If the three files disagree in length or timestamps.
        """
        rows = []
        with open(imu_paths[0], 'r') as acc_file, \
                open(imu_paths[1], 'r') as gyro_file, \
                open(imu_paths[2], 'r') as mag_file:
            for acc_line in acc_file:
                acc = acc_line.split()
                # read gyro and mag data in parallel
                gyro = gyro_file.readline().split()
                mag = mag_file.readline().split()
                if not (acc and gyro and mag):
                    if not (acc or gyro or mag):
                        continue  # blank line in all three files
                    raise ValueError(f"IMU files differ in length: {imu_paths}")
                if not (acc[0] == gyro[0] == mag[0]):
                    raise ValueError(f"Time mismatch in IMU data:{acc[0]} {gyro[0]} {mag[0]}")
                if start_time <= float(acc[0]) <= end_time:
                    rows.append([float(value) for value in acc[2:5] + gyro[2:5] + mag[2:5]])
        # reshape keeps the channel dimension even when no sample is in range
        return torch.tensor(rows, dtype=torch.float32).reshape(-1, 9)

    def __getitem__(self, idx):
        """Load one clip.

        If the requested clip lies outside its video, the following clips are
        tried in order (wrapping around) and the first loadable one is returned
        together with its own label.

        Returns:
            ``(frames, imu, class_idx, pid_idx)`` with frames in TCHW and IMU in
            TC layout, both float; ``pid_idx`` is always 0. When
            ``return_path`` is set, ``rgb_path`` and ``imu_paths`` are appended.

        Raises:
            IndexError: If ``idx`` is out of range.
            RuntimeError: If no clip in the dataset can be loaded.
        """
        self.clips[idx]  # raises IndexError for out-of-range idx (keeps iteration protocol working)
        num_clips = len(self.clips)
        position = idx % num_clips

        for _ in range(num_clips):
            rgb_path, imu_paths, class_idx, start_time, end_time, fps = self.clips[position]
            frames = self._load_clip_frames(rgb_path, start_time, end_time, fps)
            if frames is not None:
                break
            print("Skipping to next item")
            position = (position + 1) % num_clips
        else:
            raise RuntimeError("No clip in the dataset could be loaded")

        """ VIDEO PROCESSING"""
        frames = self._prepare_frames(frames)

        """IMU PROCESSING"""
        accel_data = self._read_imu_window(imu_paths, start_time, end_time)

        if self.crop_frames:
            if TORCH_INTERPOLATE:
                accel_data = accel_data.permute(1, 0)  # permute to CT to interpolate
                accel_data = F.interpolate(accel_data.unsqueeze(0), size=self.imu_length, mode='linear', align_corners=False)
                accel_data = accel_data.squeeze(0)
                accel_data = accel_data.permute(1, 0)  # permute back to TC
            else:
                accel_data = self._pad_or_truncate(accel_data, self.imu_length)

        frames = frames.float()
        accel_data = accel_data.float()
        pid_idx = 0  # dummy for existing code
        if self.return_path:
            return frames, accel_data, class_idx, pid_idx, rgb_path, imu_paths
        return frames, accel_data, class_idx, pid_idx  # returns TCHW video, TC IMU

if __name__=='__main__':
    # Scan a split and report the average clip length (frames / IMU samples).
    # Swap in "/home/ANON/data/HoloAssist/val-v1.txt" to scan the validation split.
    split_file = "/home/ANON/data/HoloAssist/train-v1.txt"
    d = Holo(split_file, base_path="/home/ANON/data/HoloAssist", crop_frames=False, print_error=True)

    def _mean(values):
        # An empty split (e.g. every video missing) must not crash the report.
        return sum(values)/len(values) if values else float("nan")

    video_lengths = []
    IMU_lengths = []
    length = len(d)
    # Iterating the dataset directly relies on __getitem__ raising IndexError
    # past the last clip.
    for idx,itm in enumerate(d):
        print(f"{idx}/{length} Input RGB:", itm[0].shape, "Input IMU:", itm[1].shape, "action label:", itm[2])
        video_lengths.append(itm[0].shape[0])
        IMU_lengths.append(itm[1].shape[0])

        if idx%10==0:
            print("running average video length:", _mean(video_lengths))
            print("running average IMU length:", _mean(IMU_lengths))


    print(len(d))
    print("average video length:", _mean(video_lengths))
    print("average IMU length:", _mean(IMU_lengths))

    # itm = d[58]
    # print("Input RGB:", itm[0].shape, "Input IMU:", itm[1].shape, "action label:", itm[2], "PID label:", itm[3])


    # base_path = "/home/ANON/data/HoloAssist"
    # # FOR THE PURPOSES OF action recognition we use the same train and train_2 loaders
    # train_dir = "/home/ANON/data/HoloAssist/train-v1.txt"
    # train_2_dir = "/home/ANON/data/HoloAssist/train-v1.txt"
    # val_dir = "/home/ANON/data/HoloAssist/val-v1.txt"
    # test_dir = "/home/ANON/data/HoloAssist/test-v1.txt"
    # num_workers = 8
    # train_dataset = Holo(train_dir, base_path = base_path, video_length=30) #, transform=transforms)
    # # sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=local_rank)
    # train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, num_workers=num_workers, drop_last=True) #, sampler=sampler) #drop_last is for multi-gpu training
    # for ind, data in enumerate(train_loader):
    #     print(ind, data[0].shape, data[1].shape, data[2].shape, data[3].shape)
        


    # For ANON-ANON
    # len = 366
    # average video length: 52.55668604651163 (in training)
    # average IMU length: 178.31 (in training)

    # For mmea
    # len = 3160
    # average video length: ~400
    # average IMU length: ~400

    # For holo
    # len = 18386 for val
    # len = 130867 for train 124759 after not found videos
    # average video length: ~84
    # average IMU length: ~86

    
