import os
import numpy as np
import torch
import argparse
from scipy import signal

def get_spectrogram(waveform, n_length):
    """Return the STFT magnitude of ``waveform`` with a trailing channel axis.

    The transform uses a Hann window of ``n_length`` samples, leaves the
    overlap and FFT length at SciPy's defaults (``noverlap=None``,
    ``nfft=None``) and keeps the two-sided spectrum
    (``return_onesided=False``).

    Args:
        waveform: Signal samples handed to ``scipy.signal.stft``.
        n_length: Segment length (``nperseg``) of each STFT window.

    Returns:
        The absolute value of the STFT with one extra axis of size 1
        appended, so that the spectrogram can be used as image-like input
        data with convolution layers.
    """
    stft_output = signal.stft(
        waveform,
        fs=1.0,
        window='hann',
        nperseg=n_length,
        noverlap=None,
        nfft=None,
        return_onesided=False,
    )
    # signal.stft returns (frequencies, times, Zxx); only Zxx is used here.
    magnitude = np.abs(stft_output[2])
    # Append the `channels` dimension as the last axis.
    return np.expand_dims(magnitude, axis=-1)


def process_batch_signal(signals, n_length=100):
    """Convert every record in ``signals`` into a zero-padded spectrogram image.

    Each item of a record is passed to ``get_spectrogram`` on its own and the
    results are stacked along the last axis, giving one output channel per
    item.  (Presumably a record is laid out as (channels, samples) --
    TODO confirm against the caller.)

    Args:
        signals: Iterable of records; each record is a non-empty sequence of
            waveforms.
        n_length: STFT segment length forwarded to ``get_spectrogram``.

    Returns:
        A list with one array per record.  Axis 1 of every array has been
        padded with 19 zeros in front and 20 zeros behind.

    Raises:
        ValueError: If a record contains no waveforms.  Previously this either
            failed with an unbound local variable (first record) or silently
            reused the spectrogram of the preceding record.
    """
    signal_spectrogram_list = []
    for record in signals:
        # len() instead of truthiness: records are typically NumPy arrays,
        # whose truth value is ambiguous.
        if len(record) == 0:
            raise ValueError('process_batch_signal: record contains no waveforms.')

        # Build all per-waveform spectrograms first and join them once,
        # instead of re-copying the growing array on every iteration.
        channel_spectrograms = [get_spectrogram(waveform, n_length) for waveform in record]
        spectrogram = np.concatenate(channel_spectrograms, axis=-1)

        # get_spectrogram already returns magnitudes, so no further abs() is
        # needed before padding.
        # NOTE(review): the fixed (19, 20) padding presumably widens axis 1 to
        # a specific target size for the expected input length -- confirm.
        img = np.pad(spectrogram, ((0, 0), (19, 20), (0, 0)), 'constant', constant_values=(0))
        signal_spectrogram_list.append(img)

    return signal_spectrogram_list
    
class kFoldGenerator():
    """Serve train/validation splits from data that is already split into folds.

    ``x`` and ``y`` are sequences of equal length ``k``; element ``p`` of each
    holds the samples and targets of fold ``p``.
    """

    k = -1  # the fold number (replaced per instance in __init__)
    x_list = []  # x list with length=k (rebound per instance in __init__)
    y_list = []  # y list with length=k (rebound per instance in __init__)

    # Initialize
    def __init__(self, x, y):
        """Store the per-fold data and targets.

        Args:
            x: Sequence of per-fold data arrays.
            y: Sequence of per-fold target arrays, same length as ``x``.

        Raises:
            AssertionError: If ``x`` and ``y`` contain a different number of
                folds.
        """
        if len(x) != len(y):
            # Raised explicitly (not via an `assert` statement) so the check
            # is still performed when Python runs with -O.
            raise AssertionError('Data generator: Length of x or y is not equal to k.')
        self.k = len(x)
        self.x_list = x
        self.y_list = y

    # Get i-th fold
    def getFold(self, i):
        """Return the split that uses fold ``i`` for validation.

        Args:
            i: Index of the validation fold.  If ``i`` matches no fold in
                ``range(k)``, every fold goes into the training set and the
                validation outputs are ``None``.

        Returns:
            Tuple ``(train_data, train_targets, val_data, val_targets)``.
            The training arrays are all other folds concatenated along
            axis 0; the validation entries are the stored fold objects
            themselves (no copy).
        """
        # The empty float seed arrays keep the result dtype and the
        # zero-length shape identical to concatenating onto np.empty(...)
        # fold by fold, while allowing a single concatenate at the end.
        train_data_parts = [np.empty(shape=(0,) + self.x_list[0].shape[1:])]
        train_target_parts = [np.empty(shape=(0,) + self.y_list[0].shape[1:])]
        val_data, val_targets = None, None
        for p in range(self.k):
            if p != i:
                train_data_parts.append(self.x_list[p])
                train_target_parts.append(self.y_list[p])
            else:
                val_data = self.x_list[p]
                val_targets = self.y_list[p]
        train_data = np.concatenate(train_data_parts)
        train_targets = np.concatenate(train_target_parts)
        return train_data, train_targets, val_data, val_targets

class SleepDataProcess:
    """Precompute STFT spectrograms for every fold and save them to disk.

    The raw archive is expected to provide the entries ``'Fold_data'`` and
    ``'Fold_label'``.  One ``stft_fold_{i}.pt`` file is written to
    ``save_dir`` per fold.
    """

    def __init__(self,rawdata_path,save_dir):
        """Remember both paths and create ``save_dir`` if it is missing.

        Args:
            rawdata_path: Path of the ``.npz`` archive with the folded data.
            save_dir: Directory that receives the processed fold files.
        """
        self.save_dir = save_dir
        self.rawdata_path = rawdata_path
        os.makedirs(save_dir, exist_ok=True)

    @staticmethod
    def _to_tensors(x, y):
        """Return float32 tensors for ``x``, ``y`` and the spectrograms of ``x``."""
        stft = process_batch_signal(x)
        x_tensor = torch.tensor(x, dtype=torch.float32)
        y_tensor = torch.tensor(y, dtype=torch.float32)
        stft_tensor = torch.from_numpy(np.array(stft)).float()
        return x_tensor, y_tensor, stft_tensor

    def process_and_save_all_folds(self):
        """Build the train/validation split of every fold and save it.

        Each saved file is a dict with the keys ``train_x``, ``train_y``,
        ``train_stft``, ``val_x``, ``val_y`` and ``val_stft``, all float32
        tensors.
        """
        # NOTE(review): allow_pickle=True makes np.load unpickle arbitrary
        # objects -- only point this at archives from a trusted source.
        # The context manager closes the archive once both arrays are read.
        with np.load(self.rawdata_path, allow_pickle=True) as data:
            Fold_Data = data['Fold_data']
            Fold_Label = data['Fold_label']
        DataGenerator = kFoldGenerator(Fold_Data, Fold_Label)

        print("start preprocess")
        print(f"out dir: {self.save_dir}")

        # Iterate over the folds that actually exist instead of assuming 10:
        # with fewer folds the validation split would be None, with more
        # folds the extra ones would never be written.
        for i in range(DataGenerator.k):
            print(f"start fold {i}...")

            train_x, train_y, val_x, val_y = DataGenerator.getFold(i)

            train_x_tensor, train_y_tensor, train_stft_tensor = self._to_tensors(train_x, train_y)
            val_x_tensor, val_y_tensor, val_stft_tensor = self._to_tensors(val_x, val_y)

            fold_data = {
                'train_x': train_x_tensor,
                'train_y': train_y_tensor,
                'train_stft': train_stft_tensor,
                'val_x': val_x_tensor,
                'val_y': val_y_tensor,
                'val_stft': val_stft_tensor
            }

            save_path = os.path.join(self.save_dir, f'stft_fold_{i}.pt')
            torch.save(fold_data, save_path)

        print("All data processed and saved!")

class SleepDataLoader:
    """Read back the per-fold ``stft_fold_{i}.pt`` files from a directory."""

    def __init__(self, stft_data_dir):
        """Remember the directory that holds the saved fold files."""
        self.stft_data_dir = stft_data_dir

    def getFold(self, i):
        """Load fold ``i`` and return its six stored entries.

        Args:
            i: Fold index used to build the file name ``stft_fold_{i}.pt``.

        Returns:
            Tuple ``(train_x, train_y, train_stft, val_x, val_y, val_stft)``
            taken from the loaded dict.

        Raises:
            FileNotFoundError: If the fold file does not exist.
        """
        fold_path = os.path.join(self.stft_data_dir, f'stft_fold_{i}.pt')
        if os.path.exists(fold_path):
            stored = torch.load(fold_path)
            # Fixed key order defines the layout of the returned tuple.
            wanted = ('train_x', 'train_y', 'train_stft',
                      'val_x', 'val_y', 'val_stft')
            return tuple(stored[key] for key in wanted)

        raise FileNotFoundError(f"can not find {i} fold data file: {fold_path}")

def main():
    """Command-line entry point: preprocess all folds, optionally test loading.

    Options:
        --config / -c: Accepted for compatibility; not read by this function.
        --rawdata / -r: Path of the raw ``.npz`` archive.
        --output / -o: Directory that receives the processed fold files.
        --test / -t: After processing, load fold 0 again and print the shape
            of every tensor.
    """
    cli = argparse.ArgumentParser(description='data preprocess tool')
    cli.add_argument('--config', '-c', default='./config.config')
    cli.add_argument('--rawdata', '-r', default='./data/ISRUC_S3.npz')
    cli.add_argument('--output', '-o', default='./data_s3')
    cli.add_argument('--test', '-t', action='store_true')
    options = cli.parse_args()

    SleepDataProcess(options.rawdata, options.output).process_and_save_all_folds()

    # Nothing more to do unless the load check was requested.
    if not options.test:
        return

    print("start test data loader...")
    fold_zero = SleepDataLoader(options.output).getFold(0)
    train_x, train_y, train_stft, val_x, val_y, val_stft = fold_zero

    print(f"train_x shape: {train_x.shape}")
    print(f"train_stft shape: {train_stft.shape}")
    print(f"train_y shape: {train_y.shape}")
    print(f"val_x shape: {val_x.shape}")
    print(f"val_stft shape: {val_stft.shape}")
    print(f"val_y shape: {val_y.shape}")

# Run the preprocessing command-line tool only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()