import os
import numpy as np
from scipy.io import loadmat
from torch.utils.data import Dataset
from scipy.io import savemat

# Root directory of the preprocessed data. SeqAPDataset joins it with
# <ensemble_type>/<phone_name>/<experiment>/<file_name>.mat when loading.
PROCESSED_DATA_DIR = "./data/processed_pretrain/4a"

class SeqAPDataset(Dataset):
    """Dataset of AP sequence measurements loaded from preprocessed ``.mat`` files.

    On construction, one ``<file_name>.mat`` is loaded from every experiment
    directory found under
    ``PROCESSED_DATA_DIR/<ensemble_type>/<phone_name>/`` and the per-file
    arrays are concatenated along axis 0.  Indexing the dataset yields
    ``(timestamps[index], seq_csi_covariance[index])``.
    """

    def __init__(self, phone_names, window_ap_num, seq_ap_num, time_step, time_window, data_type='covariance', data_suffix='agc_caled_csi'):
        """Load and concatenate the data of every phone in ``phone_names``.

        Args:
            phone_names: iterable of phone directory names to load.
            window_ap_num, seq_ap_num, time_step, time_window: only used to
                build the name of the ensemble directory to read from.
            data_type: first component of the ``.mat`` file name.
            data_suffix: second component of the ``.mat`` file name.
        """
        super().__init__()
        # Directory name encoding the preprocessing configuration.
        self.ensemble_type = "window_ap_num-{0}_seq_ap_num-{1}_time_step-{2}_time_window-{3}".format(window_ap_num, seq_ap_num, time_step, time_window)
        self.phone_names = phone_names
        self.data_type = data_type
        self.data_suffix = data_suffix
        # The feature layout of get_data_labels is sized from rx_num * rx_num.
        self.rx_num = 2
        self.file_name = "ap_seq_{0}_{1}".format(self.data_type, data_suffix)

        # Offset added to the timestamps of each successive loaded file so
        # that different files do not share a time range (500s per file).
        self.timestamps_offset = 500

        self.total_phone_names, self.timestamps, self.ap_coords, self.seq_rssi, self.seq_csi_covariance = self.__load_data()

        # NOTE: seq_ap_num is taken from the loaded data, not from the
        # constructor argument of the same name.
        self.seq_num, self.frame_num, self.ensemble_num, self.seq_ap_num, _ = self.ap_coords.shape

        print("Initialized phone names", self.total_phone_names.shape)
        print("Initialized AP coords", self.ap_coords.shape)
        print("Initialized csi", self.seq_csi_covariance.shape)
        print("Initialized rssi", self.seq_rssi.shape)

    def get_data_labels(self, data_dim, is_normalize):
        """Flatten the loaded measurements into real-valued feature vectors.

        The stored ``self.seq_csi_covariance`` is never modified; any
        normalisation is applied to a copy.

        Args:
            data_dim: feature layout selector.
                ``2``: ``[rssi / 100 | real(cov) | imag(cov)]``.
                ``3``: ``[real(cov) | imag(cov) | abs(cov)]``.
            is_normalize: when true, every covariance matrix (last two axes)
                with a non-zero Frobenius norm is divided by that norm.

        Returns:
            ``(ap_coords, scaled_data)`` where ``scaled_data`` has shape
            ``(seq_num, frame_num, ensemble_num, seq_ap_num, vec_size)``.

        Raises:
            ValueError: if ``data_dim`` is neither 2 nor 3.
        """
        seq_num, frame_num, ensemble_num, seq_ap_num, _, _ = self.seq_csi_covariance.shape

        # Number of entries of one flattened covariance matrix.
        # Assumes each matrix is rx_num x rx_num -- TODO confirm against data.
        cov_len = self.rx_num * self.rx_num
        if data_dim == 2:
            vec_size = cov_len * data_dim + data_dim
        elif data_dim == 3:
            vec_size = data_dim * cov_len
        else:
            raise ValueError("Unsupported data_dim {0!r}; expected 2 or 3".format(data_dim))

        scaled_data = np.zeros((seq_num, frame_num, ensemble_num, seq_ap_num, vec_size))
        seq_csi_covariance = self.seq_csi_covariance

        if is_normalize:
            # Frobenius norm of every matrix at once; out-of-place division
            # keeps the data stored on the instance untouched.
            norms = np.linalg.norm(seq_csi_covariance, ord='fro', axis=(-2, -1), keepdims=True)
            # Matrices with zero norm are left as they are (divide by 1).
            seq_csi_covariance = seq_csi_covariance / np.where(norms == 0, 1.0, norms)

        seq_csi_covariance = np.reshape(seq_csi_covariance, (seq_num, frame_num, ensemble_num, seq_ap_num, -1))
        seq_rssi = self.seq_rssi / 100
        if data_dim == 2:
            # The RSSI block occupies the first data_dim slots, matching the
            # vec_size formula above. Presumably one value per receive
            # antenna -- TODO confirm.
            rssi_end = data_dim
            scaled_data[..., :rssi_end] = seq_rssi
            scaled_data[..., rssi_end:rssi_end + cov_len] = seq_csi_covariance.real
            scaled_data[..., rssi_end + cov_len:rssi_end + 2 * cov_len] = seq_csi_covariance.imag
        else:
            # Concatenate the real part, imaginary part and magnitude.
            scaled_data[..., :cov_len] = seq_csi_covariance.real
            scaled_data[..., cov_len:2 * cov_len] = seq_csi_covariance.imag
            scaled_data[..., 2 * cov_len:3 * cov_len] = np.abs(seq_csi_covariance)

        ap_coords = self.ap_coords

        print("Scaled ap_coords", ap_coords.shape)
        print("Scaled data", scaled_data.shape)
        print("Scaled timestamps", self.timestamps.shape)

        return ap_coords, scaled_data

    def __load_data(self):
        """Load every experiment file of every phone and concatenate them.

        Returns:
            ``(total_phone_names, timestamps, ap_coords, seq_rssi,
            seq_csi_covariance)``, each concatenated along axis 0 over all
            loaded files.

        Raises:
            ValueError: if no experiment directory was found at all.
        """
        timestamps = []
        ap_coords = []
        seq_rssi = []
        seq_csi_covariance = []
        total_phone_names = []
        # Running count of loaded files. Using it for the time offset keeps
        # offsets unique even when phones have different numbers of
        # experiments, and equals phone_idx * exp_num + exp_idx when they
        # all have the same number.
        loaded_files = 0
        for phone_name in self.phone_names:
            curr_phone_dir = os.path.join(PROCESSED_DATA_DIR, self.ensemble_type, phone_name)

            # Sort for a reproducible order (os.scandir order is arbitrary)
            # and skip stray non-directory entries.
            with os.scandir(curr_phone_dir) as entries:
                exp_names = sorted(entry.name for entry in entries if entry.is_dir())

            for exp_name in exp_names:
                target_file_path = os.path.join(curr_phone_dir, exp_name, self.file_name + '.mat')
                print("target_file_path", target_file_path)
                file_f = loadmat(target_file_path)
                curr_timestamps = file_f['timestamps'] + self.timestamps_offset * loaded_files
                loaded_files += 1

                timestamps.append(curr_timestamps)
                ap_coords.append(file_f['ap_coords'])
                seq_rssi.append(file_f['seq_rssi'])
                seq_csi_covariance.append(file_f['seq_csi_covariance'])

                # One phone-name label per timestamp entry of this file.
                total_phone_names.append(np.full(curr_timestamps.shape, phone_name, dtype=object))

        if not timestamps:
            raise ValueError(
                "No experiment data found under {0} for phones {1!r}".format(
                    os.path.join(PROCESSED_DATA_DIR, self.ensemble_type), list(self.phone_names)))

        total_phone_names = np.concatenate(total_phone_names, axis=0)
        timestamps = np.concatenate(timestamps, axis=0)
        ap_coords = np.concatenate(ap_coords, axis=0)
        seq_rssi = np.concatenate(seq_rssi, axis=0)
        seq_csi_covariance = np.concatenate(seq_csi_covariance, axis=0)

        return total_phone_names, timestamps, ap_coords, seq_rssi, seq_csi_covariance

    def __getitem__(self, index):
        """Return ``(timestamps[index], seq_csi_covariance[index])``."""
        return self.timestamps[index], self.seq_csi_covariance[index, :, :, :, :, :]

    def __len__(self):
        """Return the number of sequences (axis 0 of the loaded arrays)."""
        return self.seq_num

def remove_nan_entries(ap_coords, data, coords, timestamps):
    """Drop every sample whose ``coords`` entry contains a NaN.

    A sample (index along axis 0) is discarded when any value of
    ``coords`` over axes 1 and 2 is NaN; the same samples are removed from
    all four arrays.

    Returns:
        Tuple ``(ap_coords, data, coords, timestamps)`` restricted to the
        samples that were kept.
    """
    has_nan = np.any(np.isnan(coords), axis=(1, 2))
    keep = np.logical_not(has_nan)
    return tuple(arr[keep] for arr in (ap_coords, data, coords, timestamps))


if __name__ == "__main__":
    # Manual smoke run: load the 'oneplus' recordings and build the
    # normalised 2-D feature layout, printing the resulting shapes.
    time_step = 0.5
    dataset = SeqAPDataset(
        phone_names=['oneplus'],
        window_ap_num=6,
        seq_ap_num=5,
        time_step=time_step,
        time_window=time_step,  # window length equals the step here
        data_suffix='agc_caled_csi',
    )
    dataset.get_data_labels(data_dim=2, is_normalize=True)
