import os
import torch
from torch.utils.data import Dataset
from torchvision.io import read_video
from torchvision.transforms.functional import normalize
import torchvision.transforms as transforms

import torch.nn.utils.rnn as rnn_utils
import scipy.io
from tqdm import tqdm
import pandas as pd
import time

import torch.nn.functional as F

# Selects how MMACT.__getitem__ brings raw clips and IMU streams to a fixed length:
#   True  -> resample with F.interpolate (trilinear for video, linear for IMU)
#   False -> strided subsampling (every 4th frame / every 3rd IMU row), then truncate or zero-pad
TORCH_INTERPOLATE = False

#NOTE: can decrease imu_length if needed
class MMACT(Dataset):
    """
    MMAct dataset pairing an RGB clip with wrist/phone IMU recordings.

    Items are read either from the raw files listed in a split file or, with
    ``presaved=True``, from ``.pt`` files that each hold one already processed item.

    Args:
        split_file (str): Text file with one sample per line. Blank lines are ignored.
            With ``presaved=False`` every line holds six whitespace-separated fields:
            ``wrist_accel_path phone_accel_path phone_gyro_path phone_orientation_path rgb_path class_idx``.
            With ``presaved=True`` every line is the path of a ``.pt`` file.
        video_length (int): Number of frames each raw clip is cut or zero-padded to.
        imu_length (int): Number of timesteps each raw IMU stream is cut or zero-padded to.
        transform (callable, optional): Applied to every frame (CHW tensor) of a raw clip.
        return_path (bool): If True, items also carry the rgb path and the wrist-accel path.
        presaved (bool): Whether ``split_file`` lists presaved ``.pt`` items.

    Raises:
        ValueError: If a raw split line does not have six fields, or its class index
            is not an integer in 0..34.
    """

    def __init__(self, split_file, video_length=30, imu_length=180, transform=None, return_path=False, presaved=False):
        self.split_file = split_file
        self.transform = transform
        self.videos = []
        self.vid_length = video_length
        self.imu_length = imu_length
        self.return_path = return_path
        self.presaved = presaved

        # `with` guarantees the handle is closed even when a malformed line raises.
        with open(self.split_file, 'r') as f:
            for line_no, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue  # tolerate blank / trailing empty lines
                if self.presaved:
                    self.videos.append(line)
                    continue

                fields = line.split()
                if len(fields) != 6:
                    raise ValueError(f"{self.split_file}:{line_no}: expected 6 fields, got {len(fields)}")
                wrist_accel_path, phone_accel_path, phone_gyro_path, phone_orientation_path, rgb_path = fields[:5]
                class_idx = int(fields[5])
                if class_idx > 34 or class_idx < 0:
                    raise ValueError(f"{class_idx} is an invalid class index")

                pid_idx = 0  # dummy for now to extend to multitask later
                self.videos.append((rgb_path, wrist_accel_path, phone_accel_path, phone_gyro_path, phone_orientation_path, class_idx, pid_idx))

    def __len__(self):
        return len(self.videos)

    @staticmethod
    def _fit_length(seq, length):
        """Cut or zero-pad ``seq`` along dim 0 to exactly ``length`` entries, keeping its dtype."""
        n = seq.shape[0]
        if n > length:
            return seq[:length]
        if n < length:
            # Pad in the sequence's own dtype: float32 zeros would promote uint8 frames to
            # float values in 0..255, which ToPILImage-style transforms rescale as if in [0, 1].
            padding = torch.zeros(length - n, *seq.shape[1:], dtype=seq.dtype)
            return torch.cat([seq, padding])
        return seq

    @staticmethod
    def _rebase_to_home(path):
        """Re-root a stored ``<somewhere>/data/<rest>`` path under ``$HOME/data/<rest>``."""
        home_dir = os.environ['HOME']
        prefix, sep, rest = path.partition("/data/")
        if not sep or prefix == home_dir:
            return path  # nothing to re-root, or already under this machine's home
        return os.path.join(home_dir, 'data', rest)

    def _load_imu(self, path):
        """Read one IMU csv into a [imu_length, channels] tensor; return None if the file is empty."""
        try:
            table = pd.read_csv(path, header=None)
        except pd.errors.EmptyDataError:
            print("Empty data error@:", path)
            print("Skipping to next item")
            return None
        # delete the first column
        table = table.drop(table.columns[0], axis=1)
        data = torch.tensor(table.values)  # shape [timesteps, channels]

        if TORCH_INTERPOLATE:
            data = data.permute(1, 0)  # permute to CT to interpolate
            # interpolate expects a batch dim, so add it and remove it again
            data = F.interpolate(data.unsqueeze(0), size=(self.imu_length), mode='linear', align_corners=False)
            data = data.squeeze(0)
            return data.permute(1, 0)  # permute back to TC
        # keep every 3rd row, then cut / pad to imu_length
        return self._fit_length(data[::3].clone(), self.imu_length)

    def _load_raw(self, idx):
        """
        Read and preprocess the raw sample at ``idx``.

        Returns ``(frames, accel_data, class_idx, pid_idx, rgb_path, imu_path)``, or None
        when one of the sample's IMU files is empty.
        """
        rgb_path, wrist_accel_path, phone_accel_path, phone_gyro_path, phone_orientation_path, class_idx, pid_idx = self.videos[idx]
        imu_path = wrist_accel_path

        """IMU PROCESSING"""
        # Read the csv files before the video so a sample that has to be skipped
        # does not pay for decoding its clip first.
        streams = []
        for sensor_path in (wrist_accel_path, phone_accel_path, phone_gyro_path, phone_orientation_path):
            stream = self._load_imu(sensor_path)
            if stream is None:
                return None
            streams.append(stream)
        # All sensors share the time axis; stack their channels side by side.
        accel_data = torch.cat(streams, dim=1)

        """ VIDEO PROCESSING"""
        frames, _, _ = read_video(rgb_path, pts_unit="sec")  # THWC
        if TORCH_INTERPOLATE:
            # NOTE(review): frames are float in this branch; confirm the transform accepts
            # float input in the 0..255 range before enabling it.
            frames = frames.permute(3, 0, 1, 2).float()  # permute to CTHW
            frames = F.interpolate(frames.unsqueeze(0), size=(self.vid_length, frames.shape[2], frames.shape[3]), mode='trilinear', align_corners=False)
            frames = frames.squeeze(0)  # drop the batch dim added for interpolate
            frames = frames.permute(1, 0, 2, 3)  # permute to TCHW to perform image-wise transforms
        else:
            # keep every 4th frame, then cut / pad to vid_length
            frames = self._fit_length(frames[::4].clone(), self.vid_length)
            frames = frames.permute(0, 3, 1, 2)  # permute to TCHW to perform image-wise transforms

        # perform transforms on each frame
        if self.transform:
            frames = torch.stack([self.transform(frame) for frame in frames])

        return frames, accel_data, class_idx, pid_idx, rgb_path, imu_path

    def __getitem__(self, idx):
        """
        Return ``(frames, accel_data, class_idx, pid_idx)``, followed by
        ``rgb_path, imu_path`` when ``return_path`` is set.

        ``frames`` and ``accel_data`` are float tensors; raw clips come back as TCHW.
        A raw sample with an empty IMU file is replaced by the next loadable sample,
        wrapping around at the end of the split.

        Raises:
            IndexError: If ``idx`` is out of range.
            RuntimeError: If every raw sample in the split has an empty IMU file.
        """
        if self.presaved:
            # The .pt file holds frames, accel and labels as saved, plus the source paths as strings.
            # NOTE: torch.load unpickles the file -- only list trusted .pt files in the split.
            frames, accel_data, class_idx, pid_idx, rgb_path, imu_path = torch.load(self.videos[idx])
            # The stored paths come from the machine that wrote the file; re-root them
            # under this machine's $HOME/data.
            rgb_path = self._rebase_to_home(rgb_path)
            imu_path = self._rebase_to_home(imu_path)
        else:
            sample = self._load_raw(idx)  # raises IndexError for an out-of-range idx
            attempts = 1
            while sample is None:
                if attempts >= len(self.videos):
                    raise RuntimeError("every item in the split has an empty IMU file")
                # Wrap around so skipping from the last item cannot run off the end.
                idx = (idx + 1) % len(self.videos)
                sample = self._load_raw(idx)
                attempts += 1
            frames, accel_data, class_idx, pid_idx, rgb_path, imu_path = sample

        frames = frames.float()
        accel_data = accel_data.float()
        if self.return_path:
            return frames, accel_data, class_idx, pid_idx, rgb_path, imu_path
        return frames, accel_data, class_idx, pid_idx  # returns TCHW video

# Maps an MMAct action name to its integer class index (0..34).
# get_files looks labels up here using the lower-cased clip file name without its extension.
# Task 1 has 35 classes: https://mmact19.github.io/challenge/
label_dict = {
    'pushing': 0,
    'running': 1,
    'crouching': 2,
    'using_phone': 3,
    'jumping': 4,
    'fall': 5,
    'waving_hand': 6,
    'picking_up': 7,
    'looking_around': 8,
    'pointing': 9,
    'closing': 10,
    'loitering': 11,
    'standing': 12,
    'checking_time': 13,
    'kicking': 14,
    'talking_on_phone': 15,
    'exiting': 16,
    'entering': 17,
    'setting_down': 18,
    'opening': 19,
    'carrying': 20,
    'throwing': 21,
    'walking': 22,
    'pulling': 23,
    'transferring_object': 24,
    'talking': 25,
    'using_pc': 26,
    'sitting_down': 27,
    'sitting': 28,
    'pocket_out': 29,
    'drinking': 30,
    'standing_up': 31,
    'pocket_in': 32,
    'carrying_heavy': 33,
    'carrying_light': 34}

def get_files(dir, ext):
    """
    Walk a directory tree and index every usable file with the given extension.

    For '.mp4' only directories whose path contains "cam1" are considered. For '.csv'
    files that pandas reports as empty are skipped. Files with another extension are
    reported with a warning and skipped.

    Args:
        dir (str): Root directory to walk.
        ext (str): File extension to collect, including the dot.

    Returns:
        tuple: (file_paths, file_info, file_label) where
            - file_paths (list): full paths of the collected files
            - file_info (dict): setup description -> full path; the setup description is the
              path from "/subject" onwards without the extension (and without "cam1/" for videos)
            - file_label (dict): full path -> label index looked up in label_dict by file name

    Raises:
        KeyError: If a file name (without extension, lower-cased) is not in label_dict.
        AssertionError: If two files map to the same setup description.
    """
    cam1_only = ext == '.mp4'
    found_paths, path_by_setup, label_by_path = [], {}, {}

    for root, dirs, files in os.walk(dir):
        # lets only use cam1 -- the check depends on the directory, not on the file
        if cam1_only and "cam1" not in root:
            continue

        for name in files:
            if not name.endswith(ext):
                # Let's warn the user if a file in the directory is not the expected file type
                print(f"Warning: found a file that is not {ext}:", name)
                continue

            full_path = os.path.join(root, name)

            # some acc data is empty, so parse it once here and drop it if pandas finds nothing
            if ext == '.csv':
                try:
                    pd.read_csv(full_path, header=None)
                except pd.errors.EmptyDataError:
                    print("Empty data error@:", full_path)
                    print("Skipping to next item")
                    continue

            found_paths.append(full_path)

            # the file name (minus extension) is the label
            label_by_path[full_path] = label_dict[name.replace(ext, "").lower()]

            # everything from "/subject" on identifies the take; the camera folder is ignored
            # so that videos and sensor files end up with the same key
            tail = full_path[full_path.find("/subject"):]
            if cam1_only:
                tail = tail.replace("cam1/", "")
            setup = tail.replace(ext, "")

            # two files for the same take should not be possible
            assert setup not in path_by_setup, f"ERROR: {setup} already in video_info"
            path_by_setup[setup] = full_path

    return found_paths, path_by_setup, label_by_path

def interpret_data():
    """
    Report how well two modalities of the trimmed MMAct data line up.

    Scans one hard-coded directory per modality with get_files, prints how many takes each
    modality has and their difference, then lists every take (setup key) of modality 1 (M1)
    that has no counterpart in modality 2 (M2).

    The variable names are historical (video was first compared against watch accel); any pair
    of modalities can be compared by swapping in one of the commented alternatives below.
    """
    # --- Modality 1 (M1) ---
    # Uncomment to compare cross_scene
    # subdir = "/home/ANON/data/mmact/cross_scene_video/video/cross_scene/trainval/"
    # file_paths, video_info, _ = get_files(subdir, '.mp4')

    # Uncomment to compare cross_view
    # subdir = "/home/ANON/data/mmact/cross_view_video/video/cross_view/trainval/"
    # file_paths, video_info, _ = get_files(subdir, '.mp4')

    # Uncomment to compare phone accel
    subdir = "/home/ANON/data/mmact/sensor/acc_phone_clip"
    file_paths, video_info, _ = get_files(subdir, '.csv')

    # --- Modality 2 (M2) ---
    # Uncomment to compare accel watch
    accel_watch_dir = "/home/ANON/data/mmact/sensor/acc_watch_clip"
    accel_watch_paths, accel_watch_info, _ = get_files(accel_watch_dir, '.csv')

    # Uncomment to compare phone gyro
    # accel_watch_dir = "/home/ANON/data/mmact/sensor/gyro_clip"
    # accel_watch_paths, accel_watch_info, _ = get_files(accel_watch_dir, '.csv')

    # Uncomment to compare phone orientation
    # accel_watch_dir = "/home/ANON/data/mmact/sensor/orientation_clip"
    # accel_watch_paths, accel_watch_info, _ = get_files(accel_watch_dir, '.csv')

    print("Len M1:")
    len_video = len(file_paths)
    assert len_video == len(video_info), "ERROR: video paths and video info are not the same length"
    print(len_video)
    print("Len M2:")
    len_accel_watch = len(accel_watch_paths)
    assert len_accel_watch == len(accel_watch_info), "ERROR: accel watch paths and accel watch info are not the same length"
    print(len_accel_watch)

    # The value is M1 minus M2; the label used to claim the opposite.
    print("Difference M1-M2", len_video-len_accel_watch)

    count = 0
    for key in video_info:
        if key not in accel_watch_info:
            print(f"ERROR: {key} not found in accel_watch_info (M2), but is in video_info (M1)")
            count += 1

    print("Count of missing keys:", count)

    # Findings recorded from earlier runs:
    # - Between cross_view and accel_watch there is a difference of 22 collected paths, but 258 keys
    #   mismatch. That means only 8758-258 = 8500 are synchronous sequences between the two modalities.
    # - Between cross_scene and accel_watch there is a difference of -4265 (more accel than cross_scene),
    #   and 145 keys missing. That's 4471 - 154 = 4317 synchronous sequences between the two modalities.
    # - Phone accel has exactly 246 more seqs than watch accel and all the other seqs match perfectly.
    #   8736 of watch accel is totally in phone accel.
    # - Phone accel and gyro look like a perfect match, however, orientation has 3 more than phone accel.
    #
    # It makes sense to use watch_accel as the primary and match all the other modality sequences to that.
    #
    # Start off with the cross view data; if it becomes too big and sticky to work with, go to cross_scene.

def create_data_split():
    """
    Write the train_align / train_har / val / test split files for the cross-view data.

    A take is kept only if its cam1 video has a matching watch accel, phone accel, phone gyro
    and phone orientation file. Each line of a split file reads
    "watch_accel phone_accel phone_gyro phone_orientation video label".

    Takes are assigned to splits as consecutive slices in the order get_files returned the
    videos (no shuffling). Slice sizes are truncated with int(), so a few trailing takes may
    end up in no split.
    """
    watch_accel = "/home/ANON/data/mmact/sensor/acc_watch_clip"
    video_dir = "/home/ANON/data/mmact/cross_view_video/video/cross_view/trainval/"
    phone_accel = "/home/ANON/data/mmact/sensor/acc_phone_clip"
    phone_gyro = "/home/ANON/data/mmact/sensor/gyro_clip"
    phone_orientation = "/home/ANON/data/mmact/sensor/orientation_clip"

    # info is {descr: path}, labels is {path: label}; only the video labels are needed
    _, watch_accel_info, _ = get_files(watch_accel, '.csv')
    _, phone_accel_info, _ = get_files(phone_accel, '.csv')
    _, phone_gyro_info, _ = get_files(phone_gyro, '.csv')
    _, phone_orientation_info, _ = get_files(phone_orientation, '.csv')
    _, video_info, video_labels = get_files(video_dir, '.mp4')

    # The order here is the column order of the sensor paths in the split files.
    sensor_infos = (watch_accel_info, phone_accel_info, phone_gyro_info, phone_orientation_info)

    multi_dict = {}  # {descr: {"paths": [sensors..., video], "labels": [har]}}
    for key, video_path in video_info.items():
        if all(key in info for info in sensor_infos):
            multi_dict[key] = {
                "paths": [info[key] for info in sensor_infos] + [video_path],
                "labels": [video_labels[video_path]],
            }

    # Now create the training splits
    base_dir = '/home/ANON/data/mmact/ANON_splits'
    split = {
        "train_align": 0.4,
        "train_har": 0.4,
        "val": 0.1,
        "test": 0.1
    }
    os.makedirs(base_dir, exist_ok=True)

    # create a list so we can slice it into splits
    multi_list = list(multi_dict.items())
    cumsum = 0
    for key, fraction in split.items():
        split_file_name = os.path.join(base_dir, key+".txt")
        print("Creating", split_file_name)

        split_list = multi_list[cumsum:cumsum+int(len(multi_list)*fraction)]
        cumsum += len(split_list)

        # `with` closes the file even if writing fails part-way
        with open(split_file_name, "w") as split_file:
            for _, v in split_list:
                split_file.write(" ".join(v["paths"] + [str(v["labels"][0])]) + "\n")

def add_modality():
    """
    Abandoned attempt at adding a modality without changing the current data split.

    The idea was to only rewrite the split .txt files, since the new dataset would be a subset
    of the old one. That does not work: the presaved files store whole tensors, so they have to
    be presaved again anyway. Modify create_data_split and re-save instead.

    The function still scans the three directories and pairs videos with watch accel files,
    but the result is not returned or written anywhere.
    """
    watch_accel = "/home/ANON/data/mmact/sensor/acc_watch_clip"
    video_dir = "/home/ANON/data/mmact/cross_view_video/video/cross_view/trainval/"
    phone_accel = "/home/ANON/data/mmact/sensor/acc_phone_clip"

    # First pull out the existing data (info is {descr: path}, labels is {path: label})
    watch_accel_paths, watch_accel_info, watch_accel_labels = get_files(watch_accel, '.csv')
    video_paths, video_info, video_labels = get_files(video_dir, '.mp4')
    # phone accel is scanned but not used in the pairing below
    phone_accel_paths, phone_accel_info, phone_accel_labels = get_files(phone_accel, '.csv')

    # {descr: {paths: [accel, video], labels: [har]}} for every video that has a watch accel file
    multi_dict = {
        key: {"paths": [watch_accel_info[key], video_path], "labels": [video_labels[video_path]]}
        for key, video_path in video_info.items()
        if key in watch_accel_info
    }

def create_presaved():
    """
    Run every raw item of the selected splits through MMACT once and save it as a .pt file.

    For each split this writes the processed items to <presave_dir>/<split>/ and a new split
    file <presave_dir>/<split>.txt listing the saved .pt paths, one per line.

    This code assumes you have already created the ANON_splits data splits.
    """
    use_transforms = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),  # Resize frames
        transforms.ToTensor(),           # Convert frames to tensors
    ])
    # Here lets create a directory of presaved tensors:
    presave_dir = "/home/ANON/data/mmact/ANON_presaved_all_sensors"
    os.makedirs(presave_dir, exist_ok=True)
    # splits = ["train_align", "train_har", "val", "test"]
    # splits = ["train_align"]
    # splits = ["train_har"]
    # splits = ["val"]
    splits = ["test"]

    # for every split, loop through the dataset and save the raw tensors
    for split in splits:
        dataset = MMACT(f"/home/ANON/data/mmact/ANON_splits/{split}.txt", transform=use_transforms, return_path=True)
        save_split_dir = os.path.join(presave_dir, split)
        os.makedirs(save_split_dir, exist_ok=True)

        # also create a new split file for the new paths of just the tensors not the raw data;
        # `with` makes sure it is flushed and closed even if loading or saving an item fails
        with open(os.path.join(presave_dir, f"{split}.txt"), "w") as new_split_file:
            for itm in tqdm(dataset):
                # name each tensor after the accel path it came from (last element of the item,
                # because return_path=True), so we can always reference it to debug
                tensor_name = itm[-1].replace('/','_').replace('.csv','')
                new_file_dir = os.path.join(save_split_dir, f"{tensor_name}.pt")
                torch.save(itm, new_file_dir)
                new_split_file.write(new_file_dir + "\n")

def test_presaved_vs_og_dataset():
    """
    Compare the presaved test split against the same split read from the raw files.

    First times building each dataset and reading its first six items, printing their shapes
    and labels. Then walks both datasets in lockstep and prints every item whose frames, IMU
    data or labels differ. Finally prints the two timings.
    """
    def preview(build_dataset):
        # Time dataset construction plus the first six items, printing each of them.
        began = time.time()
        ds = build_dataset()
        for pos, sample in enumerate(ds):
            print(f"Item {pos}/{len(ds)}")
            print("Input RGB:", sample[0].shape, "Input IMU:", sample[1].shape, "action label:", sample[2], "PID label:", sample[3])
            if pos == 5:
                break
        return ds, time.time() - began

    print("Testing new dataset")
    # Now let's test the new dataset
    dataset_presaved, presaved_time = preview(
        lambda: MMACT("/home/ANON/data/mmact/ANON_presaved_all_sensors/test.txt", presaved=True, return_path=True)
    )

    use_transforms = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),  # Resize frames
        transforms.ToTensor(),           # Convert frames to tensors
    ])
    d, original_time = preview(
        lambda: MMACT("/home/ANON/data/mmact/ANON_splits/test.txt", transform=use_transforms, return_path=True)
    )

    print("Comparing the two datasets")
    for raw_item, saved_item in zip(d, dataset_presaved):
        # with return_path=True the last two elements are (rgb_path, imu_path)
        if not torch.equal(raw_item[0], saved_item[0]):
            print("RGB not equal, path:", raw_item[-2], saved_item[-2])
        if not torch.equal(raw_item[1], saved_item[1]):
            print("IMU not equal, path:", raw_item[-1], saved_item[-1])
        if raw_item[2] != saved_item[2]:
            print("Action label not equal, path:", raw_item[-2], saved_item[-2])
        if raw_item[3] != saved_item[3]:
            print("PID label not equal")
    print("If no errors printed they are equivalent")

    print("Time taken for presaved:", presaved_time)
    print("Time taken for original:", original_time)


# Script entry point. The steps are meant to be run one at a time by (un)commenting them:
#   1. create_data_split()            -> writes the raw split .txt files
#   2. create_presaved()              -> saves processed items as .pt files (currently active)
#   3. test_presaved_vs_og_dataset()  -> checks the presaved items against the raw ones
if __name__=='__main__':
    # Scratch code: load a raw split and measure the average clip / IMU lengths.
    # dir = "/home/ANON/data/mmact/ANON_splits/test.txt"

    # d = MMACT(dir)
    
    # video_lengths = []
    # imu_lenghts = []
    # for i, itm in enumerate(d):
    #     print(f"Item {i}/{len(d)}")
    #     print("Input RGB:", itm[0].shape, "Input IMU:", itm[1].shape, "action label:", itm[2], "PID label:", itm[3])
    #     video_lengths.append(itm[0].shape[0])
    #     imu_lenghts.append(itm[1].shape[0])
    #     if i==3:
    #         break

    # print(len(d))
    # print("average video length:", sum(video_lengths)/len(video_lengths))
    # print("average imu length:", sum(imu_lenghts)/len(imu_lenghts))
    # Recorded results from an earlier run:
    # # average video length: 226.9035294117647
    # # average imu length: 768.9341176470588

    # create_data_split()

    create_presaved()

    # you should probably run this test with just a few presaved tensors
    # test_presaved_vs_og_dataset()
    
    


    
    
    
        
# Next Steps: put this in a dataset and read the actual video to a tensor and accel data to a tensor
    





    
