# pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu121
from typing import Callable, Optional
from torch.utils.data import Dataset
import torch
import torchvision
# import torchvision.transforms as transforms
torchvision.set_video_backend('pyav')
import pandas as pd
import numpy as np
import os
import json
from typing import Dict, Union, Any
from tqdm import tqdm
from .data_process_utils import linear_interpolation_1d, align_time_dimension, align_batch_n_sample_dimension, align_batch_n_frame_dimension

# from ...cfgs.real_car import DRIVING_CONDITIONS, SUBJECT_DICT, YOLO_FEATURE_LABELS, EEG_FEATURE_LABELS, CANBUS_FEATURE_LABELS

# driving_conditions = DRIVING_CONDITIONS
# subject_dict = SUBJECT_DICT

# visual_labels = YOLO_FEATURE_LABELS
# eeg_labels = EEG_FEATURE_LABELS
# canbus_labels = CANBUS_FEATURE_LABELS

class RealCarDataset(Dataset):
    """Real-car driving dataset pairing video with EEG and CAN bus signals.

    Items are selected from ``<dataset_dir>/split_info.json`` according to the
    configured driving conditions and subjects. Each stored data identifier has
    the form ``<split>/<condition>_<subject_type>_<subject_id>_<part>`` (joined
    with ``os.path.join``), where ``<split>`` is ``train`` or ``test``.
    """

    # Accepted values of the `mode` constructor argument.
    _VALID_MODES = ("whole", "train", "test")
    # Per-sample metadata keys that `collate_batch` turns into numpy string arrays.
    _META_KEYS = ("driving_condition", "subject_type", "subject_id")

    def __init__(self, config_dict: Dict[str, Any], mode: str, transform: Optional[Callable], logger):
        """Read the configuration and collect the data identifiers for `mode`.

        Args:
            config_dict: configuration mapping. Keys read here: ``dataset_dir``,
                ``interpolation_needed``, ``sample_rate_dict``,
                ``n_sample_per_frame``, ``driving_conditions``, ``subject_dict``,
                ``visual_network``, ``eeg_channel_labels``,
                ``canbus_channel_labels``, ``log_data_loading_process``.
            mode: ``"whole"`` (train + test identifiers), ``"train"`` or ``"test"``.
            transform: optional callable applied to each sample dict in
                ``__getitem__``.
            logger: object exposing an ``info(msg)`` method.

        Raises:
            ValueError: if `mode` is not ``"whole"``, ``"train"`` or ``"test"``.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if mode not in self._VALID_MODES:
            raise ValueError("`mode` has invalid data!")
        # whole: identifiers of both splits (used for preprocessing)
        # train / test: identifiers of that split only
        self.mode = mode

        self.dataset_dir = config_dict["dataset_dir"]
        # when True, human data is read from the resampled directory instead of
        # `origin_sample_rate_1000` (see create_subject_level_data)
        self.interpolation_needed = config_dict["interpolation_needed"]
        # sample rates of the modalities; "eeg" is the target rate and
        # "origin_human" the source rate used by human_data_interpolation
        self.sample_rate_dict = config_dict["sample_rate_dict"]
        # integer: number of EEG/CAN bus samples that correspond to one video frame
        self.n_sample_per_frame = config_dict["n_sample_per_frame"]
        self.driving_conditions = config_dict["driving_conditions"]
        # subject type -> list of selected subject ids
        self.subject_dict = config_dict["subject_dict"]
        self.visual_network = config_dict["visual_network"]
        self.eeg_channel_labels = config_dict["eeg_channel_labels"]
        self.canbus_channel_labels = config_dict["canbus_channel_labels"]
        self.log_data_loading_process = config_dict["log_data_loading_process"]

        self.transform = transform
        self.logger = logger
        # the selected identifiers depend on `mode` (see get_all_info)
        self.data_identifiers = self.get_all_info()

        # TODO move data preprocessing out of the dataset.
        # Preprocessing is NOT triggered here: `human_data_interpolation()` and
        # `video_data_preprocess()` must be run separately. "train"/"test" modes
        # read files under per-split subdirectories (identifiers carry the split
        # prefix) while those helpers write flat directories, so a separate split
        # step is presumably required beforehand -- confirm.

    def __len__(self):
        """Return the number of selected data identifiers."""
        return len(self.data_identifiers)

    def __getitem__(self, index):
        """Return the (optionally transformed) sample dict for `index`."""
        subject_data = self.create_subject_level_data(index)

        if self.transform:
            subject_data = self.transform(subject_data)

        return subject_data

    @staticmethod
    def _parse_identifier(identifier: str):
        """Split ``[<split>/]<condition>_<subject_type>_<subject_id>_<part>``.

        The optional split prefix is dropped with ``os.path.basename`` so the
        result is correct whichever separator ``os.path.join`` used.

        Returns:
            tuple ``(condition, subject_type, subject_id, part)``.
        """
        condition, subject_type, subject_id, part = os.path.basename(identifier).split("_")
        return condition, subject_type, subject_id, part

    def _is_selected(self, data_identifier: str) -> bool:
        """Return True if the identifier's condition and subject are configured."""
        condition, subject_type, subject_id, _ = self._parse_identifier(data_identifier)
        return (
            condition in self.driving_conditions
            and subject_type in self.subject_dict
            and subject_id in self.subject_dict[subject_type]
        )

    #TODO move this out of dataloader
    def human_data_interpolation(self):
        """Resample the original human (EEG + CAN bus) CSV files to the EEG rate.

        Does nothing when ``interpolation_needed`` is false, or when
        ``human/sample_rate_<eeg_sr>`` already holds as many files as
        ``human/origin_sample_rate_1000``. Otherwise every original CSV is
        resampled column by column with ``linear_interpolation_1d`` from
        ``sample_rate_dict["origin_human"]`` to ``sample_rate_dict["eeg"]`` and
        written (``time`` column, then EEG, then CAN bus columns) under the
        same file name.

        NOTE(review): create_subject_level_data reads from
        ``sample_rate_<sr>_splited``, not from the directory written here.
        """
        if not self.interpolation_needed:
            return
        final_sr = self.sample_rate_dict["eeg"]
        human_data_dir = os.path.join(self.dataset_dir, "human", f"sample_rate_{int(final_sr)}")
        origin_human_data_dir = os.path.join(self.dataset_dir, "human", "origin_sample_rate_1000")
        origin_file_name_list = os.listdir(origin_human_data_dir)
        if os.path.isdir(human_data_dir) and len(os.listdir(human_data_dir)) == len(origin_file_name_list):
            return
        # reaching here means the resampled data is missing or incomplete
        self.logger.info(f"Cannot find human data of sample rate {int(final_sr)} or the file number is incorrect, enter processing stage...")

        os.makedirs(human_data_dir, exist_ok=True)
        origin_sr = self.sample_rate_dict["origin_human"]
        channel_labels = [*self.eeg_channel_labels, *self.canbus_channel_labels]

        # generate the interpolated files one by one
        for file_name in tqdm(origin_file_name_list):
            src_data = pd.read_csv(os.path.join(origin_human_data_dir, file_name))

            n_point_final = int(len(src_data["time"]) / origin_sr * final_sr)
            dst_dict = {"time": np.arange(n_point_final) / final_sr}
            for channel in channel_labels:
                dst_dict[channel] = linear_interpolation_1d(src_data[channel], origin_sr=origin_sr, final_sr=final_sr)

            pd.DataFrame(dst_dict).to_csv(os.path.join(human_data_dir, file_name), index=False)

    @staticmethod
    def _stack_and_rescale_frames(frames) -> torch.Tensor:
        """Stack per-frame tensors along a new first axis and rescale to [0, 1].

        Frames are assumed to be 8-bit (0-255) decoded video data. The division
        is required because the torchvision classification transforms only
        rescale integer tensors; a float tensor in 0-255 would pass through
        unscaled and then be normalized with ImageNet mean/std.
        """
        return torch.stack(frames, dim=0).float() / 255.0

    def _build_visual_preprocess(self) -> Callable:
        """Return the per-video transform for ``self.visual_network``.

        Raises:
            NotImplementedError: for any network other than ``"resnet"``.
        """
        if self.visual_network == "resnet":
            # ResNet50 IMAGENET1K_V2 inference transforms: resize to 232
            # (bilinear), center crop 224, normalize with ImageNet mean/std.
            return torchvision.models.ResNet50_Weights.IMAGENET1K_V2.transforms()
        raise NotImplementedError("Not implement yet...")

    #TODO move this out of dataloader
    def video_data_preprocess(self):
        """Decode each original video and save it as a preprocessed tensor.

        Reads ``video/origin/<name>.<ext>`` and writes
        ``video/<visual_network>_preprocessed/<name>.pt``. It is skipped when
        the output directory already holds as many files as the input one.

        Raises:
            NotImplementedError: if ``visual_network`` is not supported.
            RuntimeError: if a video yields no frames.
        """
        origin_video_dir = os.path.join(self.dataset_dir, "video", "origin")
        processed_video_dir = os.path.join(self.dataset_dir, "video", f"{self.visual_network}_preprocessed")
        file_name_list = os.listdir(origin_video_dir)

        if os.path.exists(processed_video_dir) and len(os.listdir(processed_video_dir)) == len(file_name_list):
            return

        self.logger.info("video data is not been preprocessed, process video data now...")
        os.makedirs(processed_video_dir, exist_ok=True)
        if not file_name_list:
            return

        # built once for all videos; fails fast on an unsupported network
        preprocess = self._build_visual_preprocess()

        for file_name in tqdm(file_name_list):
            # splitext keeps dots inside the stem, unlike split(".")[0]
            identifier = os.path.splitext(file_name)[0]
            file_path = os.path.join(origin_video_dir, file_name)

            video = torchvision.io.VideoReader(file_path, "video")
            frames = [frame["data"] for frame in video]
            if not frames:
                raise RuntimeError(f"No video frames could be decoded from {file_path}")

            # (n_frame, C, H, W) in [0, 1] -> resized / cropped / normalized
            visual_info = preprocess(self._stack_and_rescale_frames(frames))
            torch.save(visual_info, os.path.join(processed_video_dir, f"{identifier}.pt"))

    def get_all_info(self) -> list[str]:
        """Return the selected data identifiers for ``self.mode``.

        ``split_info.json`` maps a split name to a list of
        ``<condition>_<subject_type>_<subject_id>_<part>`` identifiers. Only the
        identifiers whose condition, subject type and subject id are configured
        are kept. Each kept identifier is prefixed with its split via
        ``os.path.join``. In ``"whole"`` mode, ``train`` entries come first,
        then ``test`` entries.
        """
        split_file_path = os.path.join(self.dataset_dir, "split_info.json")
        with open(split_file_path, "r", encoding="utf-8") as fp:
            split_info = json.load(fp)

        splits = ("train", "test") if self.mode == "whole" else (self.mode,)
        return [
            os.path.join(split, data_identifier)
            for split in splits
            for data_identifier in split_info[split]
            if self._is_selected(data_identifier)
        ]

    def get_visual_info(self, identifier):
        """Load the preprocessed video tensor for `identifier`.

        The tensor is read from
        ``video/<visual_network>_preprocessed/<identifier>.pt``. Because
        `identifier` carries the split prefix, the file is expected inside a
        per-split subdirectory.
        """
        video_tensor_path = os.path.join(self.dataset_dir, "video", f"{self.visual_network}_preprocessed", f"{identifier}.pt")

        # the file holds a plain tensor, so refuse arbitrary pickled objects
        video_tensor = torch.load(video_tensor_path, weights_only=True)
        if self.log_data_loading_process:
            self.logger.info(f"Loading visual info tensor, shape is (n_frame, C, H, W) <= {video_tensor.shape}")
        return video_tensor

    def create_subject_level_data(self, index: int):
        """Build the sample dict for ``self.data_identifiers[index]``.

        Steps:
            1. Load the human CSV, from the resampled ``_splited`` directory
               when ``interpolation_needed`` is set and from
               ``origin_sample_rate_1000`` otherwise.
            2. Load the preprocessed video tensor.
            3. Stack the configured EEG and CAN bus columns into arrays of
               shape (n_sample, n_channel).
            4. Align them with the video frames via ``align_time_dimension``.
        """
        identifier = self.data_identifiers[index]
        if self.interpolation_needed:
            human_subdir = f"sample_rate_{int(self.sample_rate_dict['eeg'])}_splited"
        else:
            human_subdir = "origin_sample_rate_1000"
        file_path = os.path.join(self.dataset_dir, "human", human_subdir, f"{identifier}.csv")

        data = pd.read_csv(file_path)
        # identifier is "<split>/<condition>_<subject_type>_<subject_id>_<part>"
        driving_condition, subject_type, subject_id, _ = self._parse_identifier(identifier)

        visual_info = self.get_visual_info(identifier)
        eeg_info = np.stack([data[eeg_label] for eeg_label in self.eeg_channel_labels], axis=1)
        canbus_info = np.stack([data[canbus_label] for canbus_label in self.canbus_channel_labels], axis=1)

        visual_info, eeg_info, eeg_info_mask, canbus_info, canbus_info_mask = align_time_dimension(visual_info, eeg_info, canbus_info, n_sample_per_frame=self.n_sample_per_frame)

        return {
            "driving_condition": driving_condition,
            "subject_type": subject_type,
            "subject_id": subject_id,
            "visual_info": visual_info, # (n_frame, C, H, W)
            "eeg_info": torch.from_numpy(eeg_info), # (n_sample, n_eeg_channel)
            "eeg_info_mask": torch.from_numpy(eeg_info_mask), # (n_sample,)
            "canbus_info": torch.from_numpy(canbus_info), # (n_sample, n_canbus_channel)
            "canbus_info_mask": torch.from_numpy(canbus_info_mask) # (n_sample,)
        }

    def collate_batch(self, batch_list):
        """Collate a list of sample dicts into a single batch dict.

        Metadata strings become numpy arrays. ``visual_info`` is aligned with
        ``align_batch_n_frame_dimension``. ``eeg_info`` and ``canbus_info`` are
        aligned with ``align_batch_n_sample_dimension``, which also receives the
        per-sample ``*_mask`` lists. Each aligned key is returned together with
        a ``<key>_mask`` entry. Any other key is not carried over.
        """
        batch_dict = {key: [sample[key] for sample in batch_list] for key in batch_list[0]}

        ret_dict = {}
        for key, val_list in batch_dict.items():
            if key in self._META_KEYS:
                ret_dict[key] = np.array(val_list)
            elif key == "visual_info":
                ret_dict[key], ret_dict[f"{key}_mask"] = align_batch_n_frame_dimension(val_list)
            elif key in ("eeg_info", "canbus_info"):
                ret_dict[key], ret_dict[f"{key}_mask"] = align_batch_n_sample_dimension(info_list=val_list, info_mask_list=batch_dict[f"{key}_mask"])
            # per-sample masks are consumed above; other keys are dropped

        return ret_dict