import os
import time
import numpy as np
import pandas as pd
import itertools
import torch

from scipy.special import comb
from scipy.io import loadmat, savemat
from torch.utils.data import Dataset

class Preprocess_Dataset(Dataset):
    def __init__(self, file_path: str, window_ap_num: int, seq_ap_num: int, time_step: float, time_window: float, frame_num: int):
        """Preprocess one raw recording and write the resulting AP sequences to disk.

        The constructor runs the whole pipeline: load the ``.mat`` file, drop
        rows whose AP name is NaN, map AP names to coordinates, filter the
        packets, compute the AGC-scaled CSI and the covariance matrices, build
        windowed AP sequences, expand them into AP-subset ensembles and save
        both covariance variants. It returns early (leaving the instance only
        partially populated) when the file yields no usable data.

        Args:
            file_path: path of the raw ``.mat`` file. The name of its parent
                directory is used as the phone name and its base name (up to
                the first dot) as the file name when saving.
            window_ap_num: number of distinct APs collected per frame.
            seq_ap_num: number of APs in every ensemble subset.
            time_step: distance between consecutive window starts.
            time_window: length of each window.
            frame_num: number of frames kept per window.
        """
        self.rx_num = 2         # For this dataset, it is fixed.
        self.file_path = file_path
        self.window_ap_num = window_ap_num
        self.seq_ap_num = seq_ap_num
        self.time_step = time_step
        self.time_window = time_window
        self.ensemble_num = int(comb(self.window_ap_num, self.seq_ap_num))
        self.frame_num = frame_num

        # Use os.path instead of splitting on '/' so that paths assembled with
        # os.path.join (which uses the platform separator) are parsed correctly
        # and a bare file name does not raise IndexError.
        self.phone_name = os.path.basename(os.path.dirname(self.file_path))
        self.file_name = os.path.basename(self.file_path).split('.')[0]

        self.__load_data()
        self.remove_nan_names()

        # No data in this file
        if np.size(self.ap_names) == 0:
            print("* No data in for '{}'".format(file_path))
            return

        self.ap_coords = self.ap_name2ap_coords()

        self.__filter()

        # No data left after filtering
        if np.size(self.csi) == 0:
            print("* No data in for '{}'".format(file_path))
            return

        # vectorized AGC scaling of the CSI
        self.agc_caled_csi = self.__cal_agc_vec()

        # vectorized covariance of the raw CSI
        self.csi_covariance = self.__cal_covariance_vec(self.csi)

        # per-carrier loop covariance of the AGC-scaled CSI
        self.agc_caled_csi_covariance = self.__cal_covariance(self.agc_caled_csi)

        result = self.__construct_ap_seq_data(window_ap_num, time_step, time_window)

        if result is None:
            return  # nothing to save (would be `continue` when looping over several files)

        (seq_timestamps, seq_ap_coords, seq_ap_names, seq_rssi, seq_csi_covariance, seq_agc_caled_csi_covariance) = result

        print("seq data: ")
        print("seq_timestamps", seq_timestamps.shape)
        print("seq_ap_coords, seq_ap_names", seq_ap_coords.shape, seq_ap_names.shape)
        print("seq_rssi, seq_csi_covariance, seq_agc_caled_csi_covariance", seq_rssi.shape, seq_csi_covariance.shape, seq_agc_caled_csi_covariance.shape)

        # generate ensemble data
        (ensemble_timestamps, ensemble_ap_coords, ensemble_ap_names, ensemble_rssi, ensemble_csi_covariance, ensemble_agc_caled_csi_covariance) = self.__gen_ensemble_ap_seq(seq_timestamps, seq_ap_coords, seq_ap_names, seq_rssi, seq_csi_covariance, seq_agc_caled_csi_covariance)
        print("ensemble data: ")
        print("ensemble_timestamps", ensemble_timestamps.shape)
        print("ensemble_ap_coords, ensemble_ap_names, ", ensemble_ap_coords.shape, ensemble_ap_names.shape)
        print("ensemble_seq_rssi, ensemble_seq_csi_covariance, ensemble_seq_agc_caled_csi_covariance", ensemble_rssi.shape, ensemble_csi_covariance.shape, ensemble_agc_caled_csi_covariance.shape)

        # save data: once with the raw-CSI covariance, once with the AGC-scaled one
        self.__save_data(ensemble_timestamps, ensemble_ap_coords, ensemble_ap_names, ensemble_rssi, ensemble_csi_covariance, 'csi')
        self.__save_data(ensemble_timestamps, ensemble_ap_coords, ensemble_ap_names, ensemble_rssi, ensemble_agc_caled_csi_covariance, 'agc_caled_csi')

    def ap_name2ap_coords(self, ap_info_path: str = './document/5b5c/ap_hash.xlsx'):
        """Look up the (x_axis, y_axis) coordinates for every row of ``self.ap_names``.

        Args:
            ap_info_path: Excel sheet with the columns ``ap_name``, ``x_axis``
                and ``y_axis``. Defaults to the path that was previously
                hard-coded.

        Returns:
            Array of shape ``(len(self.ap_names), 2)`` holding x in column 0
            and y in column 1.

        Raises:
            IndexError: if no row of the sheet matches an AP name.
        """
        df = pd.read_excel(ap_info_path)
        # Loop-invariant: convert the lookup column to text once.
        ap_name_column = df['ap_name'].astype(str)

        sample_num = self.ap_names.shape[0]
        ap_coords = np.zeros((sample_num, 2))

        # Many packets share an AP, so resolve every distinct name only once.
        coords_by_name = {}

        for sample_idx in range(sample_num):
            name_key = str(int(self.ap_names[sample_idx][0]))

            if name_key not in coords_by_name:
                # NOTE(review): this is a substring match and the first hit wins,
                # so e.g. "1" also matches "10" — confirm the naming scheme of
                # the sheet makes that unambiguous.
                matches = df[ap_name_column.str.contains(name_key, na=False)]
                if matches.empty:
                    raise IndexError(
                        "AP name '{}' not found in '{}'".format(name_key, ap_info_path)
                    )
                coords_by_name[name_key] = (
                    matches['x_axis'].values[0],
                    matches['y_axis'].values[0],
                )

            x_axis, y_axis = coords_by_name[name_key]
            ap_coords[sample_idx, 0] = x_axis
            ap_coords[sample_idx, 1] = y_axis

        return ap_coords

    def __load_data(self):
        file_f = loadmat(self.file_path)

        self.timestamps = file_f['timestamps']
        self.ap_names = file_f['ap_names']
        self.csi, self.rssi = file_f['csi_data'], file_f['rssi']
        self.channel = file_f['channel']
        
        self.csi_sts_nums = np.array(list(map(lambda curr_csi: curr_csi.shape[1], self.csi[0, :]))).reshape(-1, 1)

        # change the timestamps to relative time of the first packet
        timestamps_diff = self.timestamps - self.timestamps[0]
        self.timestamps = timestamps_diff

        return
    
    # check nan values and remove them
    def remove_nan_names(self):
        nan_mask = np.isnan(self.ap_names).any(axis=(1))

        self.timestamps = self.timestamps[~nan_mask]
        self.ap_names = self.ap_names[~nan_mask]
        self.rssi = self.rssi[~nan_mask]
        self.csi = self.csi[:,~nan_mask]
        self.channel = self.channel[~nan_mask]
        self.csi_sts_nums = self.csi_sts_nums[~nan_mask]

        print("filtered_timestamps", self.timestamps.shape)
        print("filtered_channel, filtered_ap_names", self.channel.shape, self.ap_names.shape)
        print("filtered_rssi, filtered_data", self.rssi.shape, self.csi.shape)        
        
        return 
    
    def __construct_ap_seq_data(self, K, time_step, time_window):
        # Algorithm logic:
        # Input: Mixed data from N APs, K (predefined number of APs in AP Seq)
        # Output:
        # 1. Move forward with a certain step size and window
        # 2. First determine how many valid APs are around a single moment
        #   1. For moments with too few APs (less than K), discard the data for that moment.
        #   2. For moments with more APs (greater than or equal to K), select AP data.
        #       1. Top K-RSSI
        #       2. Discreteness of CSI subcarriers
        # 3. After selecting the data, combine the AP data
        
        time_merge_timestamps, time_merge_ap_coords, time_merge_ap_names = list(), list(), list()
        time_merge_rssi, time_merge_csi_covariance, time_merge_agc_caled_csi_covariance = list(), list(), list()
        
        start_timestamp, stop_timestamp = int(self.timestamps[0]), int(self.timestamps[-1])+1
        left_timestamp = start_timestamp

        seq_count = 0
        for left_timestamp in np.arange(start_timestamp, stop_timestamp, time_step):
            merge_timestamps, merge_ap_coords, merge_ap_names = list(), list(), list()
            merge_rssi, merge_csi_covariance, merge_agc_caled_csi_covariance = list(), list(), list()
            right_timestamp = left_timestamp + time_window
            
            # choose data in the current window
            selected_indices = np.where((left_timestamp <= self.timestamps) & (self.timestamps <= right_timestamp))[0]
            (selected_timestamps, selected_ap_coords, selected_ap_names, selected_rssi, selected_csi_covariance, selected_agc_caled_csi_covariance) = self.__select_data(selected_indices)
            # sort by RSSI
            selected_rssi_mean = np.mean(selected_rssi, axis=1)
            sorted_indices = np.argsort(selected_rssi_mean)[::-1]   # get the index of AP RSSI in descending order
            (sorted_timestamps, sorted_ap_coords, sorted_ap_names, sorted_rssi, sorted_csi_covariance, sorted_agc_caled_csi_covariance) = self.__choose_data(selected_timestamps, selected_ap_coords,  selected_ap_names, selected_rssi, selected_csi_covariance, selected_agc_caled_csi_covariance, sorted_indices)

            # get the unique AP names
            unique_aps_num = len(np.unique(sorted_ap_coords))

            if unique_aps_num < K:
                continue

            # get the top K AP names
            top_k_ap_names = []
            for ap_name in sorted_ap_names:
                if ap_name not in top_k_ap_names:
                    top_k_ap_names.append(ap_name)
                if len(top_k_ap_names) == K:
                    break

            # get the index of the top K AP names, and choose the data
            top_k_data_indices = [i for i, ap_name in enumerate(sorted_ap_names) if ap_name in top_k_ap_names]
            (top_k_timestamps, top_k_ap_coords, top_k_ap_names, top_k_rssi, top_k_csi_covariance, top_k_agc_caled_csi_covariance) = self.__choose_data(sorted_timestamps, sorted_ap_coords, sorted_ap_names, sorted_rssi, sorted_csi_covariance, sorted_agc_caled_csi_covariance, top_k_data_indices)
            (list_timestamps, list_ap_coords, list_ap_names, list_rssi, list_csi_covariance, list_agc_caled_csi_covariance) = self.np2list(top_k_timestamps, top_k_ap_coords, top_k_ap_names, top_k_rssi, top_k_csi_covariance, top_k_agc_caled_csi_covariance)

            while True:
                flag, chosen_indices, curr_timestamps, curr_ap_coords, curr_ap_names, curr_rssi, curr_csi_covariance, curr_agc_caled_csi_covariance = self.get_seq_data(list_timestamps, list_ap_coords, list_ap_names, list_rssi, list_csi_covariance, list_agc_caled_csi_covariance, K)

                if not flag:
                    break

                curr_timestamps = np.mean(curr_timestamps)

                merge_timestamps.append(curr_timestamps)
                merge_ap_coords.append(curr_ap_coords)
                merge_ap_names.append(curr_ap_names)
                
                merge_rssi.append(curr_rssi)
                merge_csi_covariance.append(curr_csi_covariance)
                merge_agc_caled_csi_covariance.append(curr_agc_caled_csi_covariance)

                # remove the selected data, and continue to find the next data
                list_ap_names, list_ap_coords, list_rssi, list_csi_covariance, list_agc_caled_csi_covariance = self.remove_data(chosen_indices, list_ap_names, list_ap_coords, list_rssi, list_csi_covariance, list_agc_caled_csi_covariance)
            
                seq_count += 1
            
            if len(merge_timestamps) < self.frame_num:
                continue

            combined_timestamps = np.stack(merge_timestamps, axis=0)
            combined_ap_coords = np.stack(merge_ap_coords, axis=0)
            combined_ap_names = np.stack(merge_ap_names, axis=0)
            combined_rssi = np.stack(merge_rssi, axis=0)
            combined_csi_covariance = np.stack(merge_csi_covariance, axis=0)
            combined_agc_caled_csi_covariance = np.stack(merge_agc_caled_csi_covariance, axis=0)
            
            # 在1s内的所有帧数中选择12帧
            rand_nums = torch.randperm((combined_timestamps.shape[0]))[:self.frame_num]
            sort_rand_nums = np.sort(rand_nums.numpy())
            combined12_timestamps = combined_timestamps[sort_rand_nums]
            combined12_ap_coords = combined_ap_coords[sort_rand_nums, :, :]
            combined12_ap_names = combined_ap_names[sort_rand_nums, :, :]
            combined12_rssi = combined_rssi[sort_rand_nums, :, :]
            combined12_csi_covariance = combined_csi_covariance[sort_rand_nums, :, :, :]
            combined12_agc_caled_csi_covariance = combined_agc_caled_csi_covariance[sort_rand_nums, :, :, :]

            # 时间帧和坐标需要聚合(需要考虑，未必聚合)
            # combined12_timestamps = np.mean(combined12_timestamps)
            # combined12_coords = np.mean(combined12_coords, axis=0)
            time_merge_timestamps.append(combined12_timestamps)
            time_merge_ap_coords.append(combined12_ap_coords)
            time_merge_ap_names.append(combined12_ap_names)
                
            time_merge_rssi.append(combined12_rssi)
            time_merge_csi_covariance.append(combined12_csi_covariance)
            time_merge_agc_caled_csi_covariance.append(combined12_agc_caled_csi_covariance)

        if len(time_merge_timestamps) == 0:
            print(f"[WARNING] Empty data for file: {self.file_path}, skip.")
            return None
            
        stacked_timestamps = np.stack(time_merge_timestamps, axis=0)
        stacked_ap_coords = np.stack(time_merge_ap_coords, axis=0)
        stacked_ap_names = np.stack(time_merge_ap_names, axis=0)
        stacked_rssi = np.stack(time_merge_rssi, axis=0)
        stacked_csi_covariance = np.stack(time_merge_csi_covariance, axis=0)
        stacked_agc_caled_csi_covariance = np.stack(time_merge_agc_caled_csi_covariance, axis=0)
        
        stacked_timestamps = stacked_timestamps[:, :, np.newaxis]
        if stacked_ap_coords.shape[0]==1:
            pass
        else:
            stacked_ap_coords = np.squeeze(stacked_ap_coords)
            stacked_rssi = np.squeeze(stacked_rssi)
        
        return stacked_timestamps, stacked_ap_coords, stacked_ap_names, stacked_rssi, stacked_csi_covariance, stacked_agc_caled_csi_covariance
    
    def remove_data(self, indices_to_remove, list_ap_names, list_ap_coords, list_rssi, list_csi_covariance, list_agc_caled_csi_covariance):
        """Return copies of the five parallel lists without the given positions.

        Args:
            indices_to_remove: positions to drop from every list.
            list_ap_names, list_ap_coords, list_rssi, list_csi_covariance,
            list_agc_caled_csi_covariance: parallel per-packet lists.

        Returns:
            Tuple of the five pruned lists, in the order
            ``(ap_names, ap_coords, rssi, csi_covariance, agc_caled_csi_covariance)``.
        """
        def _without_removed(values):
            # keep every element whose position is not scheduled for removal
            return [item for position, item in enumerate(values) if position not in indices_to_remove]

        return (
            _without_removed(list_ap_names),
            _without_removed(list_ap_coords),
            _without_removed(list_rssi),
            _without_removed(list_csi_covariance),
            _without_removed(list_agc_caled_csi_covariance),
        )

    def get_seq_data(self, list_timestamps, list_ap_coords, list_ap_names, list_rssi, list_csi_covariance, list_agc_caled_csi_covariance, K):
        curr_timestamps = np.zeros((K, 1))
        curr_ap_coords = np.zeros((K, 2))
        curr_ap_names = np.zeros((K, 1))
        curr_rssi = np.zeros((K, 2))
        curr_csi_covariance = np.zeros((K, self.rx_num, self.rx_num), dtype=complex)
        curr_agc_caled_csi_covariance = np.zeros((K, self.rx_num, self.rx_num), dtype=complex)

        # chosen_indices is used to identify the index of the selected data, which needs to be removed later
        chosen_ap_num = 0
        chosen_indices = list()
        data_num = len(list_ap_names)
        for data_idx in range(data_num):
            temp_timestamp = list_timestamps[data_idx]
            temp_ap_coords = list_ap_coords[data_idx]
            temp_ap_name = list_ap_names[data_idx]
            temp_rssi = list_rssi[data_idx]
            temp_csi_covariance = list_csi_covariance[data_idx]
            temp_agc_caled_csi_covariance = list_agc_caled_csi_covariance[data_idx]
            
            if np.isin(temp_ap_name, curr_ap_names):
                continue

            curr_timestamps[chosen_ap_num] = temp_timestamp
            curr_ap_coords[chosen_ap_num, :] = temp_ap_coords
            curr_ap_names[chosen_ap_num] = temp_ap_name
            curr_rssi[chosen_ap_num] = temp_rssi
            curr_csi_covariance[chosen_ap_num, :, :] = temp_csi_covariance
            curr_agc_caled_csi_covariance[chosen_ap_num, :, :] = temp_agc_caled_csi_covariance

            chosen_ap_num += 1
            
            chosen_indices.append(data_idx)

            if chosen_ap_num == K:
                break

        # flag is used to identify whether the data that meets the requirements has been found after the current traversal
        if chosen_ap_num == K:
            flag = True
        elif chosen_ap_num < K:
            flag = False
        
        return flag, chosen_indices, curr_timestamps, curr_ap_coords, curr_ap_names, curr_rssi, curr_csi_covariance, curr_agc_caled_csi_covariance
    
    def __gen_ensemble_ap_seq(self, seq_timestamps, seq_ap_coords, seq_ap_names, seq_rssi, seq_csi_covariance, seq_agc_caled_csi_covariance):
        # coords and timestamps is not needed to ensemble, because they are the same for all combinations
        # 生成所有从 7 个 AP 中选 6 个的组合，类似于数据增强
        seq_num, frame_num, window_ap_num, height, width = seq_csi_covariance.shape

        # generate all combinations indicies of the APs
        ensemble_combinations = np.array(list(itertools.combinations(range(window_ap_num), seq_ap_num)))

        ensemble_seq_ap_coords = np.zeros((seq_num, frame_num, self.ensemble_num, seq_ap_num, 2))
        ensemble_seq_ap_names = np.zeros((seq_num, frame_num, self.ensemble_num, seq_ap_num, 1))
        ensemble_seq_rssi = np.zeros((seq_num, frame_num, self.ensemble_num, seq_ap_num, 2))
        ensemble_seq_csi_covariance = np.zeros((seq_num, frame_num, self.ensemble_num, seq_ap_num, height, width), dtype=complex)
        ensemble_seq_agc_caled_csi_covariance = np.zeros((seq_num, frame_num, self.ensemble_num, seq_ap_num, height, width), dtype=complex)

        # loop through the combination indices and generate the result array
        for comb_idx, comb in enumerate(ensemble_combinations):
            ensemble_seq_ap_coords[:,:,comb_idx,:,:] = seq_ap_coords[:, :, comb, :]
            ensemble_seq_ap_names[:,:,comb_idx,:,:] = seq_ap_names[:, :,comb, :]
            ensemble_seq_rssi[:,:,comb_idx,:,:] = seq_rssi[:, :,comb, :]
            ensemble_seq_csi_covariance[:,:,comb_idx,:,:,:] = seq_csi_covariance[:, :,comb, :, :]
            ensemble_seq_agc_caled_csi_covariance[:,:,comb_idx,:,:,:] = seq_agc_caled_csi_covariance[:, :,comb,:,:]

        return (seq_timestamps, ensemble_seq_ap_coords, ensemble_seq_ap_names, ensemble_seq_rssi, ensemble_seq_csi_covariance, ensemble_seq_agc_caled_csi_covariance)

    def np2list(self, top_k_timestamps, top_k_ap_coords, top_k_ap_names, top_k_rssi, top_k_csi_covariance, top_k_agc_caled_csi_covariance):
        list_timestamps = top_k_timestamps.tolist()
        list_ap_coords = top_k_ap_coords.tolist()
        list_ap_names = top_k_ap_names.tolist()
        list_rssi = top_k_rssi.tolist()
        list_csi_covariance = top_k_csi_covariance.tolist()
        list_agc_caled_csi_covariance = top_k_agc_caled_csi_covariance.tolist()

        # list to np.array
        list_timestamps = [np.array(sublist) for sublist in list_timestamps]
        list_ap_coords = [np.array(sublist) for sublist in list_ap_coords]
        list_ap_names = [np.array(sublist) for sublist in list_ap_names]
        list_rssi = [np.array(sublist) for sublist in list_rssi]
        list_csi_covariance = [np.array(sublist) for sublist in list_csi_covariance]
        list_agc_caled_csi_covariance = [np.array(sublist) for sublist in list_agc_caled_csi_covariance]

        return list_timestamps, list_ap_coords, list_ap_names, list_rssi, list_csi_covariance, list_agc_caled_csi_covariance
    
    def __choose_data(self, timestamps, ap_coords, ap_names, rssi, csi_covariance, agc_caled_csi_covariance, indices):
        chosen_timestamps = timestamps[indices]
        chosen_ap_coords = ap_coords[indices]
        chosen_ap_names = ap_names[indices]
        chosen_rssi = rssi[indices, :]
        chosen_csi_covariance = csi_covariance[indices, :, :]
        chosen_agc_caled_csi_covariance = agc_caled_csi_covariance[indices, :, :]

        return (chosen_timestamps, chosen_ap_coords, chosen_ap_names, chosen_rssi, chosen_csi_covariance, chosen_agc_caled_csi_covariance) 

    def __select_data(self, selected_indices):
        selected_timestamps = self.timestamps[selected_indices]
        selected_ap_coords = self.ap_coords[selected_indices]
        selected_ap_names = self.ap_names[selected_indices]
        selected_rssi = self.rssi[selected_indices, :]
        selected_csi_covariance = self.csi_covariance[selected_indices, :, :]
        selected_agc_caled_csi_covariance = self.agc_caled_csi_covariance[selected_indices,:,:]
            
        return (selected_timestamps, selected_ap_coords, selected_ap_names, selected_rssi, selected_csi_covariance, selected_agc_caled_csi_covariance)

    # drop packets that are not 5GHz (channel / 1e3 truncates to <= 2) or whose CSI stream count is not 1
    def __filter(self):
        channel_num = (self.channel / 1e3).astype(int)
        filter_channel_idx = channel_num > 2
        filter_sts_idx = self.csi_sts_nums == 1
        filter_idx = filter_channel_idx & filter_sts_idx

        self.__filter_data(filter_idx)

        self.csi = np.array([np.array(curr_csi) for curr_csi in self.csi])

        return
    
    def __cal_agc_vec(self):
        """
        Fully vectorized implementation of AGC calculation
        """
        # calculate the sum of the square of the amplitudes of the CSI
        amp_square_sum = np.sum(np.abs(self.csi) ** 2, axis=-1)  # shape: (batch_size, rx_num, tx_num)
        
        # transform RSSI to linear scale and expand the dimension to match tx_num
        rssi_linear = 10 ** (self.rssi[..., np.newaxis] / 10)  # shape: (batch_size, rx_num, 1)
        
        # calculate the AGC scalar
        eps = 1e-10
        agc_scalar = np.sqrt(rssi_linear / (amp_square_sum + eps))  # shape: (batch_size, rx_num, tx_num)
        
        # expand the dimension to match carrier_num and apply to CSI
        caled_csi = agc_scalar[..., np.newaxis] * self.csi  # shape: (batch_size, rx_num, tx_num, carrier_num)

        return caled_csi
    
    def __cal_covariance_vec(self, csi):
        point_num, rx_num, tx_num, carrier_num = csi.shape[0], csi.shape[1], csi.shape[2], csi.shape[3]
        
        # Reshape CSI to combine tx_num and carrier_num dimensions
        # New shape: (point_num, rx_num, tx_num*carrier_num)
        csi_reshaped = csi.reshape(point_num, rx_num, -1)
        # calculate the covariance matrix
        covariances = np.matmul(csi_reshaped, csi_reshaped.conj().transpose(0, 2, 1))
        
        # Adjust the shape of the covariance matrix to (point_num, rx_num, rx_num)
        covariances = covariances.reshape(point_num, rx_num, rx_num)
            
        return covariances
    
    def __cal_agc(self):
        agc_scalar = np.array(list(map(self.__get_agc_scalar, self.csi, self.rssi)))
        # print("agc_scalar", agc_scalar.shape)
        # print("csi", self.csi.shape)
        caled_csi = agc_scalar[:, :, :, np.newaxis] * self.csi

        return caled_csi

    def __get_agc_scalar(self, csi, rssi):
        rx_num, tx_num, carrier_num = csi.shape
        agc_scalar = np.zeros((rx_num, tx_num))
        eps = 1e-10  # Small epsilon to prevent division by zero

        for rx_idx in range(rx_num):
            for tx_idx in range(tx_num):
                amp_square_sum = np.sum(np.abs(csi[rx_idx, tx_idx, :]) ** 2)
                agc_scalar[rx_idx, tx_idx] = np.sqrt(10**(rssi[rx_idx]/10) / amp_square_sum + eps)
        
        return agc_scalar
    
    def __cal_covariance(self, csi):
        point_num = csi.shape[0]
        
        covariances = np.zeros((point_num, self.rx_num, self.rx_num), dtype=complex)
        for point_idx, curr_csi in enumerate(csi):
            rx_num, tx_num, carrier_num = curr_csi.shape
            covariance = 0
            for carrier_idx in range(carrier_num):
                covariance += np.matmul(curr_csi[:, :, carrier_idx], curr_csi[:, :, carrier_idx].T.conj())
            covariances[point_idx, :, :] = covariance
            
        return covariances

    def __save_data(self, timestamps, ap_coords, ap_names, rssi, csi_covariance, suffix):
        """Write one set of sequence arrays to ``ap_seq_covariance_<suffix>.mat``.

        The file is placed in
        ``PROCESSED_DATA_DIR/<construction type>/<phone_name>/<file_name>``,
        where the construction type encodes ``window_ap_num``, ``seq_ap_num``,
        ``time_step`` and ``time_window``. The directory is created if needed.

        NOTE: ``PROCESSED_DATA_DIR`` is read as a module-level global; in this
        file it is only assigned under ``if __name__ == "__main__"``.

        Args:
            timestamps, ap_coords, ap_names, rssi, csi_covariance: arrays
                stored under the keys ``timestamps``, ``ap_coords``,
                ``ap_names``, ``seq_rssi`` and ``seq_csi_covariance``.
            suffix: tag appended to the output file name.
        """
        data_construction_type = "window_ap_num-{0}_seq_ap_num-{1}_time_step-{2}_time_window-{3}".format(self.window_ap_num, self.seq_ap_num, self.time_step, self.time_window)
        processed_dir = os.path.join(PROCESSED_DATA_DIR, data_construction_type)
        save_dir = os.path.join(processed_dir, self.phone_name, self.file_name)
        print("curr_save_dir: ", save_dir)

        # exist_ok avoids the check-then-create race of a separate exists() test
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, "ap_seq_covariance_{}.mat".format(suffix))

        savemat(save_path, {
            'timestamps': timestamps,
            'ap_coords': ap_coords,
            'ap_names': ap_names,
            'seq_rssi': rssi,
            'seq_csi_covariance': csi_covariance
        })

        return

    def __filter_data(self, data_idx):
        data_idx = data_idx.reshape(-1)
        self.timestamps = self.timestamps[data_idx==1]
        self.rssi = self.rssi[data_idx==1, :]
        self.channel = self.channel[data_idx==1]
        self.csi = self.csi[0, data_idx==1]
        self.ap_coords = self.ap_coords[data_idx==1]

        return

# check nan values and remove them
def remove_nan_entries(timestamps, ap_coords, coords, ap_names, rssi, data):
    """Drop every entry whose ``coords`` slice contains a NaN.

    An entry is dropped when any value of ``coords[i]`` (checked over axes 1
    and 2) is NaN; the same first-axis mask is applied to all six arrays and
    the resulting shapes are printed.

    Returns:
        ``(timestamps, ap_coords, coords, ap_names, rssi, data)`` after
        filtering.
    """
    keep_mask = ~np.isnan(coords).any(axis=(1, 2))

    (
        filtered_timestamps,
        filtered_ap_coords,
        filtered_coords,
        filtered_ap_names,
        filtered_rssi,
        filtered_data,
    ) = [array[keep_mask] for array in (timestamps, ap_coords, coords, ap_names, rssi, data)]

    print("filtered_timestamps, filtered_coords", filtered_timestamps.shape, filtered_coords.shape)
    print("filtered_ap_coords, filtered_ap_names", filtered_ap_coords.shape, filtered_ap_names.shape)
    print("filtered_rssi, filtered_data", filtered_rssi.shape, filtered_data.shape)

    return filtered_timestamps, filtered_ap_coords, filtered_coords, filtered_ap_names, filtered_rssi, filtered_data


if __name__ == "__main__":
    # Preprocess every recording found under DATA_DIR/<phone>/.
    DATA_DIR = "./data/raw/5b5c/"
    # Must stay a module-level name: Preprocess_Dataset.__save_data reads it as a global.
    PROCESSED_DATA_DIR = "./data/processed_pretrain/5b5c/"

    seq_ap_nums = [5]
    frame_num = 11

    for seq_ap_num in seq_ap_nums:
        # one AP more than the subset size, so every ensemble leaves one AP out
        window_ap_num = seq_ap_num + 1

        time_step = 0.5  # Unit: s
        time_window = time_step

        target_data_dir = DATA_DIR
        phone_names = [name for name in os.listdir(target_data_dir) if os.path.isdir(os.path.join(target_data_dir, name))]
        for phone_name in phone_names:
            curr_phone_dir = os.path.join(target_data_dir, phone_name)

            # the context manager closes the directory iterator even if a file fails
            with os.scandir(curr_phone_dir) as exp_entries:
                for exp_entry in exp_entries:
                    # only regular files can be recordings; skip sub-directories
                    if not exp_entry.is_file():
                        continue

                    curr_phone_exp_file = os.path.join(curr_phone_dir, exp_entry.name)

                    print("curr_phone_exp_file: ", curr_phone_exp_file)
                    Preprocess_Dataset(file_path=curr_phone_exp_file, window_ap_num=window_ap_num, seq_ap_num=seq_ap_num, time_step=time_step, time_window=time_window, frame_num=frame_num)


