import os
import torch
from torch.utils.data import Dataset
from torchvision.io import read_video
from torchvision.transforms.functional import normalize
import torch.nn.utils.rnn as rnn_utils
import scipy.io
import pandas as pd
import torch.nn.functional as F

import torchvision.transforms as transforms
from datasets.dataset_mmact import MMACT
from datasets.dataset_ANON import CZUANONDataset
from datasets.dataset_holo import Holo
from torch.utils.data import DistributedSampler

# Module-wide switch read by RGB_IMU_Dataset.__getitem__: when True, video and
# IMU sequences are resampled to their target lengths with F.interpolate; when
# False (the default) they are strided, then truncated or zero-padded.
TORCH_INTERPOLATE = False

# Per-frame preprocessing pipeline handed to every dataset as `transform=`.
# https://torchvideo.readthedocs.io/en/latest/transforms.html
# NOTE: this assignment rebinds the name `transforms`, so from here on it refers
# to the composed pipeline and no longer to the `torchvision.transforms` module
# imported above. Code below this line must not call `transforms.<Something>`.
transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),  # Resize frames
    transforms.ToTensor(),           # Convert frames to tensors
])

def load_dataloaders(dataset, model_info, rgb_video_length, imu_length, world_size, args, return_path = False):
    """Build the train / train_2 / val / test data loaders for one dataset.

    Args:
        dataset: Dataset key; one of 'ANON-ANON', 'mmact', 'mmea', 'czu-ANON'
            or 'holo'.
        model_info: Dict that is updated in place. 'num_imu_channels' is always
            set. 'num_classes' is always set except for 'ANON-ANON', where it is
            only set when 'PID' or 'HAR' appears in model_info['tasks'].
            For every dataset except 'ANON-ANON' the existing 'project_name'
            entry is prefixed with the dataset name.
        rgb_video_length: Forwarded to each dataset as ``video_length``.
        imu_length: Forwarded as ``imu_length`` (only the MMACT dataset
            receives it here).
        world_size: ``num_replicas`` for the DistributedSampler objects
            (unused when ``args.single_gpu`` is truthy).
        args: Object providing ``batch_size`` and ``single_gpu``, plus ``rank``
            when ``single_gpu`` is falsy.
        return_path: Forwarded to every dataset except Holo.

    Returns:
        ``(train_loader, train_2_loader, val_loader, test_loader, model_info)``.

    Raises:
        NotImplementedError: If ``dataset`` is not one of the supported keys.
    """
    # expanduser honours $HOME when it is set and falls back to the platform's
    # user database otherwise, instead of failing with a bare KeyError.
    HOME_DIR = os.path.expanduser("~")

    if dataset == 'ANON-ANON':
        model_info['num_imu_channels'] = 6
        if "PID" in model_info['tasks']:
            model_info['num_classes'] = 8
        elif "HAR" in model_info['tasks']:
            model_info['num_classes'] = 27
        #NOTE: Project name doesn't include dataset here, it's default dataset
        num_workers = 4

        datapath = "Both_splits/both_40_40_10_10_#1" #all models should use the same split for comparison
        base_path = f"{HOME_DIR}/data/ANON-ANON/"
        #NOTE train_2 is the (X_rgb,Y) used for HAR, train_1 is the (X_IMU, X_rgb)
        split_files = [
            os.path.join(base_path, datapath, split_name)
            for split_name in ("train.txt", "train_2.txt", "val.txt", "test.txt")
        ]
        train_dataset, train_2_dataset, val_dataset, test_dataset = [
            RGB_IMU_Dataset(split_file, video_length=rgb_video_length, transform=transforms, base_path=base_path, return_path=return_path)
            for split_file in split_files
        ]

    elif dataset == 'mmact':
        model_info['num_imu_channels'] = 12
        model_info['num_classes'] = 35
        model_info['project_name'] = "mmact-"+model_info['project_name']
        # i've experimented around with this. looks like 12 is good. every 12 iteration of dataloader is slower, but the next 12 is faster
        # or 8 works too, it seems like 16 has a higher chance of crashing
        num_workers = 12

        presaved_mmact = True
        if presaved_mmact:
            base_path = f"{HOME_DIR}/data/mmact/ANON_presaved_all_sensors/"
        else:
            base_path = f"{HOME_DIR}/data/mmact/ANON_splits/"

        split_files = [
            os.path.join(base_path, split_name)
            for split_name in ("train_align.txt", "train_har.txt", "val.txt", "test.txt")
        ]
        train_dataset, train_2_dataset, val_dataset, test_dataset = [
            MMACT(split_file, video_length=rgb_video_length, imu_length=imu_length, transform=transforms, presaved=presaved_mmact, return_path=return_path)
            for split_file in split_files
        ]

    elif dataset == 'mmea':
        model_info['num_imu_channels'] = 6
        model_info['num_classes'] = 32
        model_info['project_name'] = "mmea-"+model_info['project_name']
        num_workers = 4

        base_path = f"{HOME_DIR}/data/UESTC-MMEA-CL/ANON_splits_all_data"
        data_path = f"{HOME_DIR}/data/UESTC-MMEA-CL/"
        split_files = [
            os.path.join(base_path, split_name)
            for split_name in ("train_align.txt", "train_har.txt", "val.txt", "test.txt")
        ]
        train_dataset, train_2_dataset, val_dataset, test_dataset = [
            RGB_IMU_Dataset(split_file, video_length=rgb_video_length, transform=transforms, dataset='mmea', base_path=data_path, return_path=return_path)
            for split_file in split_files
        ]

    elif dataset == 'czu-ANON':
        model_info['num_imu_channels'] = 60
        model_info['num_classes'] = 22
        model_info['project_name'] = "czu_ANON-"+model_info['project_name']
        num_workers = 4

        base_path = f"{HOME_DIR}/data/CZU-ANON"
        train_dataset, train_2_dataset, val_dataset, test_dataset = [
            CZUANONDataset(base_path, split_name, video_length=rgb_video_length, transform=transforms, return_path=return_path)
            for split_name in ("train_align", "train_har", "val", "test")
        ]

    elif dataset == 'holo':
        model_info['num_imu_channels'] = 9
        model_info['num_classes'] = 1887
        model_info['project_name'] = "holo-"+model_info['project_name']
        num_workers = 0

        base_path = f"{HOME_DIR}/data/HoloAssist"
        # FOR THE PURPOSES OF action recognition we use the same train and train_2 loaders
        train_dir = f"{HOME_DIR}/data/HoloAssist/train-v1.txt"
        val_dir = f"{HOME_DIR}/data/HoloAssist/val-v1.txt"

        train_dataset = Holo(train_dir, base_path = base_path, video_length=rgb_video_length, transform=transforms)
        train_2_dataset = Holo(train_dir, base_path = base_path, video_length=rgb_video_length, transform=transforms)
        val_dataset = Holo(val_dir, base_path = base_path, video_length=rgb_video_length, transform=transforms)
        #NOTE: RIGHT NOW TEST IS SAME AS VAL! NEED TO FIGURE OUT HOW TO ACTUALLY RUN TEST LATER, bc it doesn't use annot json!
        # (the test-v1.txt split under base_path is therefore not loaded yet)
        test_dataset = val_dataset
        # The HoloAssist-specific loaders that used to be built here were always
        # overwritten by the shared loader construction below, so they were removed.

    else:
        # Single formatted message; passing two arguments rendered it as a tuple.
        raise NotImplementedError(f"Dataset not implemented: {dataset}")

    if args.single_gpu:
        # NOTE(review): val/test are shuffled too, exactly as before — confirm
        # that no consumer relies on a fixed evaluation order.
        train_loader, train_2_loader, val_loader, test_loader = [
            torch.utils.data.DataLoader(split, batch_size=args.batch_size, shuffle=True, num_workers=num_workers)
            for split in (train_dataset, train_2_dataset, val_dataset, test_dataset)
        ]
    else:
        def _distributed_loader(split, drop_last):
            # One sampler per loader so each split is sharded independently.
            sampler = DistributedSampler(split, num_replicas=world_size, rank=args.rank)
            return torch.utils.data.DataLoader(
                split,
                batch_size=args.batch_size,
                num_workers=num_workers,
                drop_last=drop_last,
                sampler=sampler,
                pin_memory=True,
            )

        #drop_last is for multi-gpu training; evaluation splits keep every sample
        train_loader = _distributed_loader(train_dataset, drop_last=True)
        train_2_loader = _distributed_loader(train_2_dataset, drop_last=True)
        val_loader = _distributed_loader(val_dataset, drop_last=False)
        test_loader = _distributed_loader(test_dataset, drop_last=False)

    return train_loader, train_2_loader, val_loader, test_loader, model_info

class RGB_IMU_Dataset(Dataset):
    """Paired RGB-video / IMU dataset described by a text split file.

    Supported ``dataset`` values are 'ANON-ANON' (split lines of the form
    ``"<name> <action label> <person id>"`` with 1-based labels) and 'mmea'
    (split lines of the form ``"<name> <action label>"`` with 0-based labels
    and a dummy person id).

    Each item is ``(frames, accel_data, class_idx, pid_idx)``; when
    ``return_path`` is set, ``rgb_path`` and ``imu_path`` are appended.
    ``frames`` is a float tensor laid out as TCHW and ``accel_data`` is a float
    tensor laid out as (time, channels).
    """

    # Person ids are 1-based in the split files; after the -1 shift the valid
    # range is [0, _NUM_PIDS).
    _NUM_PIDS = 8

    def __init__(self, split_file, video_length=50, imu_length=180, transform=None, base_path ="/home/ANON/data/ANON-ANON/", dataset = 'ANON-ANON', crop_frames=True, return_path=False):
        """Parse ``split_file`` into a list of samples.

        Args:
            split_file: Path of the text file listing the samples.
            video_length: Number of frames each clip is cropped/padded to.
            imu_length: Number of IMU time steps each sequence is cropped/padded to.
            transform: Optional callable applied to every frame individually.
            base_path: Root directory containing the video and sensor folders.
            dataset: 'ANON-ANON' or 'mmea'.
            crop_frames: When False, sequences keep their native length.
            return_path: When True, items also carry the two source paths.

        Raises:
            NotImplementedError: If ``dataset`` is not supported.
            ValueError: If a line is malformed or a label is out of range.
        """
        self.split_file = split_file
        self.transform = transform
        self.videos = []
        self.vid_length = video_length
        self.imu_length = imu_length
        self.return_path = return_path
        self.crop_frames = crop_frames
        self.dataset = dataset

        if dataset == 'ANON-ANON':
            self.num_classes = 27
        elif dataset == 'mmea':
            self.num_classes = 32
        else:
            raise NotImplementedError(f"{dataset} is not implemented")

        # The context manager closes the file even when a line fails validation.
        with open(self.split_file, 'r') as f:
            for line in f:
                # split() without arguments drops the trailing newline and
                # tolerates tabs or repeated spaces.
                fields = line.split()
                if not fields:
                    continue  # ignore blank lines

                if dataset == 'ANON-ANON':
                    video_name, class_idx, pid_idx = fields
                    class_idx = int(class_idx)-1 # label from 1, so we need to subtract 1
                else:  # 'mmea' (any other value was rejected above)
                    video_name, class_idx = fields
                    class_idx = int(class_idx)
                    pid_idx = 1 # dummy for now

                # Valid indices are 0 .. num_classes-1, so num_classes itself is invalid.
                if class_idx >= self.num_classes or class_idx < 0:
                    raise ValueError(f"{class_idx} is an invalid class index")
                pid_idx = int(pid_idx)-1
                if pid_idx >= self._NUM_PIDS or pid_idx < 0:
                    raise ValueError(f"{pid_idx} is an invalid PID index")

                if dataset == 'ANON-ANON':
                    rgb_path = os.path.join(base_path,"RGB", video_name+"_color.avi")
                    imu_path = os.path.join(base_path,"Inertial", video_name+"_inertial.mat")
                else:
                    rgb_path = os.path.join(base_path,"video", video_name+".mp4")
                    imu_path = os.path.join(base_path,"sensor", video_name+".csv")

                # (video path, IMU path, class index, person index)
                self.videos.append((rgb_path, imu_path, class_idx, pid_idx))

    def __len__(self):
        """Return the number of samples listed in the split file."""
        return len(self.videos)

    @staticmethod
    def _fit_length(seq, length):
        """Truncate or zero-pad ``seq`` along dim 0 to exactly ``length`` entries."""
        current = seq.shape[0]
        if current > length:
            return seq[:length]
        if current < length:
            # Pad in the sequence's own dtype. Float zeros would promote uint8
            # video frames to float while keeping the 0-255 value range, which
            # image transforms that expect float data in [0, 1] mis-scale.
            padding = torch.zeros(length - current, *seq.shape[1:], dtype=seq.dtype)
            return torch.cat([seq, padding])
        return seq

    def _load_frames(self, rgb_path):
        """Decode one video and return its frames laid out as TCHW."""
        # Original note: read_video yields THWC frames, e.g. 57,480,640,3.
        frames, _, _ = read_video(rgb_path, pts_unit="sec")

        if not self.crop_frames:
            return frames.permute(0,3,1,2) # permute to TCHW to perform image-wise transforms

        if TORCH_INTERPOLATE:
            # Resample along time with trilinear interpolation (reported as too slow).
            # NOTE(review): frames become float in the 0-255 range here, which
            # per-frame image transforms may not expect — confirm before enabling.
            frames = frames.permute(3,0,1,2).float() # permute to CTHW
            frames = F.interpolate(frames.unsqueeze(0), size=(self.vid_length, frames.shape[2], frames.shape[3]), mode='trilinear', align_corners=False)
            frames = frames.squeeze(0) # drop the batch dim added for interpolate
            return frames.permute(1,0,2,3) # permute to TCHW to perform image-wise transforms

        if self.dataset == 'ANON-ANON':
            # downsample time by half
            frames = frames[::2,:,:,:].clone()
        elif self.dataset == 'mmea':
            # downsample time by 1/10
            frames = frames[::10,:,:,:].clone()
        else:
            raise NotImplementedError(f"{self.dataset} is not implemented")

        frames = self._fit_length(frames, self.vid_length)
        return frames.permute(0,3,1,2) # permute to TCHW to perform image-wise transforms

    def _load_imu(self, imu_path):
        """Load one IMU recording as a float (time, channels) tensor.

        Returns None when an 'mmea' CSV file is empty.
        """
        if self.dataset == 'ANON-ANON':
            data = scipy.io.loadmat(imu_path)
            accel_data = torch.tensor(data['d_iner']).float()
        elif self.dataset == 'mmea':
            try:
                table = pd.read_csv(imu_path, header=None)
            except pd.errors.EmptyDataError:
                return None
            accel_data = torch.tensor(table.values)
        else:
            raise NotImplementedError(f"{self.dataset} is not implemented")

        if self.crop_frames:
            if TORCH_INTERPOLATE:
                accel_data = accel_data.permute(1,0) # permute to CT to interpolate
                accel_data = F.interpolate(accel_data.unsqueeze(0), size=(self.imu_length), mode='linear', align_corners=False)
                accel_data = accel_data.squeeze(0) # drop the batch dim added for interpolate
                accel_data = accel_data.permute(1,0) # permute back to TC
            else:
                if self.dataset == 'mmea':
                    #downsample time by half
                    accel_data = accel_data[::2,:].clone()
                accel_data = self._fit_length(accel_data, self.imu_length)

        #no need to permute, already in txc shape
        return accel_data.float()

    def __getitem__(self, idx):
        """Return the sample at ``idx``.

        If the IMU file of an 'mmea' sample is empty, the next sample is
        returned instead, wrapping around to the first one after the last.

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        # Plain indexing keeps the IndexError that sequence-style iteration
        # (`for item in dataset`) relies on to stop.
        rgb_path, imu_path, class_idx, pid_idx = self.videos[idx]

        # IMU first: it is cheap to read, and an empty file means the expensive
        # video decode for this sample can be skipped entirely.
        accel_data = self._load_imu(imu_path)
        if accel_data is None:
            print("Empty data error@:", imu_path)
            print("Skipping to next item")
            # Wrap around so the last sample does not fall off the end.
            return self.__getitem__((idx + 1) % len(self.videos))

        frames = self._load_frames(rgb_path)

        #perform transforms on each frame
        if self.transform:
            frames = torch.stack([self.transform(frame) for frame in frames])
        frames = frames.float()

        if self.return_path:
            return frames, accel_data, class_idx, pid_idx, rgb_path, imu_path
        return frames, accel_data, class_idx, pid_idx #returns TCHW video , TC IMU

if __name__=='__main__':
    # Smoke test: iterate one split at native sequence length (crop_frames=False)
    # and report the average video and IMU lengths.
    #
    # To inspect ANON-ANON instead, point split_path at e.g.
    # "/home/ANON/data/ANON-ANON/Both_splits/both_80_20_#1/train.txt" and build
    # the dataset with RGB_IMU_Dataset(split_path, crop_frames=False).
    split_path = "/home/ANON/data/UESTC-MMEA-CL/ANON_splits/train_har.txt"  # not named `dir`, which shadows the builtin
    base_path = "/home/ANON/data/UESTC-MMEA-CL/"
    d = RGB_IMU_Dataset(split_path, crop_frames=False, dataset='mmea', base_path=base_path)

    video_lengths = []
    IMU_lengths = []
    for itm in d:
        print("Input RGB:", itm[0].shape, "Input IMU:", itm[1].shape, "action label:", itm[2], "PID label:", itm[3])
        video_lengths.append(itm[0].shape[0])
        IMU_lengths.append(itm[1].shape[0])

    print(len(d))
    if video_lengths:
        print("average video length:", sum(video_lengths)/len(video_lengths))
        print("average IMU length:", sum(IMU_lengths)/len(IMU_lengths))
    else:
        # An empty split would otherwise end in a ZeroDivisionError.
        print("dataset is empty; no average lengths to report")

    # Previously recorded results
    # For ANON-ANON
    # len = 366
    # average video length: 52.55668604651163 (in training)
    # average IMU length: 178.31 (in training)

    # For mmea
    # len = 3160
    # average video length: ~400
    # average IMU length: ~400

    
