import os

import pandas as pd
import scipy
import scipy.io
import torch
import torchvision.io as io
import torchvision.transforms as transforms
from torch.utils.data import Dataset


def _collect_split(subjects, times, class_ids, class_2_idx, modality_dirs):
    """Build the CSV columns for one split.

    A recording ``{subject}_{class}_{time}.mat`` is kept only when the file
    exists in every directory of ``modality_dirs``. Rows are emitted in
    subject -> time -> class order.
    """
    rows = {'depth': [], 'sensor': [], 'skeleton': [], 'label': []}
    for s in subjects:
        for t in times:
            for c in class_ids:
                fname = f'{s}_{c}_{t}.mat'
                if all(os.path.exists(os.path.join(d, fname)) for d in modality_dirs):
                    rows['depth'].append(fname)
                    rows['sensor'].append(fname)
                    rows['skeleton'].append(fname)
                    rows['label'].append(class_2_idx[c])
    return rows


def train_test_split(root, save_root=None):
    """Write ``train.csv``, ``val.csv`` and ``test.csv`` for the dataset under ``root``.

    File names in ``root/depth_mat`` are parsed as ``{subject}_{class}_{time}.mat``.
    The distinct time (repetition) ids are split 80% / 10% / 10% into
    train / val / test, so every subject and class can appear in each split.
    Only recordings present in all of ``depth_mat``, ``sensor_mat`` and
    ``skeleton_mat`` are listed.

    Subjects, classes and times are sorted before splitting so that the
    generated CSVs are identical from run to run (iterating a ``set`` of
    strings directly is not order-stable across interpreter runs).

    Args:
        root: Dataset directory containing the three ``*_mat`` folders.
        save_root: Directory the CSV files are written to; defaults to ``root``
            and is created if missing.

    Raises:
        ValueError: If a class id is not of the form ``<letter><integer>``
            (the label is ``int(class_id[1:]) - 1``).
    """
    if save_root is None:
        save_root = root
    depth_dir = os.path.join(root, 'depth_mat')
    sensor_dir = os.path.join(root, 'sensor_mat')
    skeleton_dir = os.path.join(root, 'skeleton_mat')

    subject_set, class_set, time_set = set(), set(), set()
    for f in os.listdir(depth_dir):
        stem, ext = os.path.splitext(f)
        parts = stem.split('_')
        # Ignore stray entries (README, .DS_Store, ...) that are not recordings.
        if ext != '.mat' or len(parts) < 3:
            continue
        subject_set.add(parts[0])
        class_set.add(parts[1])
        time_set.add(parts[-1])

    subjects = sorted(subject_set)
    # Class ids look like 'a12': order them by their numeric part.
    class_ids = sorted(class_set, key=lambda c: int(c[1:]))
    class_2_idx = {c: int(c[1:]) - 1 for c in class_ids}
    # (length, text) ordering puts 't2' before 't10' without assuming a prefix.
    times_per_subject = sorted(time_set, key=lambda t: (len(t), t))

    tot_times = len(times_per_subject)
    train_end = int(0.8 * tot_times)
    val_end = int(0.9 * tot_times)
    # For a 60/20/20 split use 0.6 and 0.8 as the two cut points instead.
    split_times = {
        'train': times_per_subject[:train_end],
        'val': times_per_subject[train_end:val_end],
        'test': times_per_subject[val_end:],
    }

    os.makedirs(save_root, exist_ok=True)
    modality_dirs = (depth_dir, sensor_dir, skeleton_dir)
    for split_name, times in split_times.items():
        rows = _collect_split(subjects, times, class_ids, class_2_idx, modality_dirs)
        pd.DataFrame(rows).to_csv(os.path.join(save_root, f'{split_name}.csv'), index=False)

class CZUANONDataset(Dataset):
    """Depth-video + IMU dataset backed by ``.mat`` files and a split CSV.

    ``root`` must contain ``depth_mat/``, ``sensor_mat/`` and ``{split}.csv``.
    The CSV needs the columns ``depth``, ``sensor`` and ``label`` (file names
    relative to the matching folder, plus an integer class label).

    Each item is ``(depth, sensor, label, 0)``; the trailing ``0`` is a
    placeholder participant id kept so the tuple matches the interface of the
    training code. With ``return_path=True`` the depth and sensor file paths
    are appended.

    Args:
        root: Dataset directory.
        split: CSV base name, e.g. ``'train'`` reads ``root/train.csv``.
        video_length: Number of depth frames after cropping/padding.
        imu_length: Number of IMU time steps after cropping/padding.
        transform: Optional callable applied to every depth frame (a
            ``(C, H, W)`` tensor); results are stacked back along time.
        crop_frames: When true, subsample in time and crop/zero-pad both
            modalities to the fixed lengths above. When false, sequences keep
            their native (variable) length.
        return_path: When true, keep depth single-channel and also return the
            source file paths.
    """

    # Temporal subsampling strides used when ``crop_frames`` is enabled.
    IMU_STRIDE = 8
    DEPTH_STRIDE = 10

    def __init__(self, root, split, video_length=30, imu_length=180, transform=None, crop_frames=True, return_path=False):
        self.root_sensor = os.path.join(root, 'sensor_mat')
        self.root_depth = os.path.join(root, 'depth_mat')
        self.root_skeleton = os.path.join(root, 'skeleton_mat')
        self.split = split
        self.data = pd.read_csv(os.path.join(root, f'{split}.csv'))
        self.crop_frames = crop_frames
        self.imu_length = imu_length
        self.vid_length = video_length
        self.transform = transform
        self.return_path = return_path

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _fit_length(seq, length):
        """Crop or zero-pad ``seq`` along dim 0 to exactly ``length`` steps."""
        steps = seq.shape[0]
        if steps >= length:
            return seq[:length]
        # new_zeros keeps dtype/device, so padded and unpadded samples reach
        # the transform with the same dtype (torch.zeros would promote
        # integer data to float32 only for short clips).
        pad = seq.new_zeros((length - steps, *seq.shape[1:]))
        return torch.cat([seq, pad])

    @staticmethod
    def _load_sensor(sensor_file):
        """Load all IMU streams and merge them into one ``(T, n_sensors * 6)`` tensor."""
        cells = scipy.io.loadmat(sensor_file)['sensor']
        streams = [cells[i, 0] for i in range(cells.shape[0])]
        # The sensors have different recording lengths, so crop to the shortest.
        min_len = min(stream.shape[0] for stream in streams)
        # Only the first 6 columns of each stream are used.
        stacked = torch.stack([torch.tensor(stream[:min_len, :6]) for stream in streams])  # (n, T, 6)
        # Pretend there is one big sensor with n * 6 channels.
        return stacked.permute(1, 0, 2).flatten(start_dim=1)  # (T, n * 6)

    def __getitem__(self, idx):
        depth_file = os.path.join(self.root_depth, self.data['depth'].iloc[idx])
        sensor_file = os.path.join(self.root_sensor, self.data['sensor'].iloc[idx])
        label = self.data['label'].iloc[idx]

        # The skeleton modality is not part of the returned sample, so its
        # file is not read here.
        depth = torch.tensor(scipy.io.loadmat(depth_file)['depth'])
        sensor = self._load_sensor(sensor_file)

        if self.crop_frames:
            sensor = self._fit_length(sensor[::self.IMU_STRIDE, :], self.imu_length)

        # Make depth look like a video: insert a channel dim.
        # NOTE(review): assumes the stored array is (T, H, W) -> (T, 1, H, W).
        depth = depth.unsqueeze(1)
        if self.crop_frames:
            depth = self._fit_length(depth[::self.DEPTH_STRIDE].clone(), self.vid_length)

        # Expand to 3 channels so the frames fit an RGB architecture.
        # return_path is used by the pipeline that needs single-channel depth.
        if not self.return_path:
            depth = depth.expand(-1, 3, -1, -1)

        # Apply the transform to each frame independently.
        if self.transform is not None:
            depth = torch.stack([self.transform(frame) for frame in depth])

        depth = depth.float()    # (T, C, H, W); T == vid_length when crop_frames
        sensor = sensor.float()  # (T, n_sensors * 6); T == imu_length when crop_frames

        # Disabled time-shift experiment, kept for reference: crop a random
        # window covering a fraction t_frac of each sequence, then zero-pad
        # back to the full length.
        # t_frac = torch.rand(1) * 0.4 + 0.6           # in [0.6, 1.0)
        # t_shift = torch.rand(1) * (1.0 - t_frac)      # window start, in [0, 1 - t_frac)
        # d0 = int(t_shift * self.vid_length)
        # depth = depth[d0:d0 + int(t_frac * self.vid_length)]
        # s0 = int(t_shift * self.imu_length)
        # sensor = sensor[s0:s0 + int(t_frac * self.imu_length)]
        # depth = torch.cat([depth, torch.zeros(self.vid_length - len(depth), *depth.shape[1:])])
        # sensor = torch.cat([sensor, torch.zeros(self.imu_length - len(sensor), *sensor.shape[1:])])

        if self.return_path:
            return depth, sensor, label, 0, depth_file, sensor_file
        # depth, sensor, action label, participant id placeholder
        return depth, sensor, label, 0

        
# Default per-frame preprocessing pipeline for the depth frames.
# NOTE: the assignment below rebinds the module-level name `transforms` from
# the torchvision.transforms module to the composed pipeline, so the module is
# no longer reachable under that name after this point. The name is kept
# because other code may import it from here.
# https://torchvideo.readthedocs.io/en/latest/transforms.html
_frame_steps = [
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),  # bring every frame to 224x224
    transforms.ToTensor(),          # convert the PIL image back to a tensor
]
transforms = transforms.Compose(_frame_steps)

if __name__ == '__main__':
    # Smoke test: load one split, iterate it once and report the average
    # temporal length of each returned modality.
    from torch.utils.data import DataLoader

    # expanduser also works when the HOME environment variable is unset.
    HOME_DIR = os.path.expanduser('~')
    base_path = f"{HOME_DIR}/data/CZU-ANON"

    # uncomment to generate train, val, test splits
    # train_test_split(base_path, base_path)

    # try loading the dataset
    # NOTE(review): 'train_align' expects a train_align.csv under base_path;
    # train_test_split() above only writes train/val/test.csv -- confirm the
    # file is produced elsewhere.
    dataset = CZUANONDataset(base_path, 'train_align', crop_frames=True, transform=transforms)
    loader = DataLoader(dataset, batch_size=32, shuffle=True)

    print("Length of dataset:", len(dataset))
    print("Length of loader:", len(loader))

    # Accumulate per-sample time steps (batch dim is 0, time dim is 1) so the
    # average is taken over samples rather than over batches.
    total_depth_steps = 0
    total_sensor_steps = 0
    num_samples = 0

    print("Loaded dataset, looping through it")
    for i, (depth, sensor, label, _) in enumerate(loader):
        print(i, depth.shape, sensor.shape, label)
        batch_size = depth.shape[0]
        total_depth_steps += batch_size * depth.shape[1]
        total_sensor_steps += batch_size * sensor.shape[1]
        num_samples += batch_size

    print("Length of dataset:", len(dataset))
    if num_samples:
        print("Average depth length:", total_depth_steps / num_samples)
        print("Average sensor length:", total_sensor_steps / num_samples)
    else:
        print("Dataset is empty, no averages to report")
    # The skeleton modality is not returned by the dataset, so no skeleton
    # average is reported here.

    #40.18 %

        