import os.path as osp
import torch
import numpy as np
from torch.utils.data import Dataset
from scipy.signal import stft
from ..utils.signal_data_augmentation import data_augmentation

class PretrainDataset(Dataset):
    """Dataset yielding (original, augmented) IQ pairs and their STFT spectrograms.

    On construction every sample is padded with low-amplitude Gaussian noise,
    or truncated, along axis 1 so that all samples have exactly
    ``signal_length`` columns.

    NOTE: the padded/truncated arrays are written back into ``samples``
    itself, so the container passed by the caller is modified in place.

    Args:
        samples: Indexable, item-assignable container of 2-D numpy arrays.
            Axis 1 is the time axis; presumably shape ``(2, N)`` with the
            I/Q channels as rows — TODO confirm against the data loader.
        labels: Stored as ``self.labels``; not read by this class.
        SNR: Stored as ``self.SNR``; not read by this class.
        signal_length: Target number of columns for every sample.
    """

    # Gaussian noise parameters used to pad samples shorter than signal_length.
    _PAD_NOISE_MEAN = 0.0
    _PAD_NOISE_STD = 0.001
    # Same values that were previously passed positionally to scipy's stft.
    _STFT_KWARGS = {
        "fs": 1.0,
        "window": "blackman",
        "nperseg": 31,
        "noverlap": 30,
        "nfft": 128,
    }
    # Number of leading frequency rows of the STFT that are kept.
    _N_FREQ_BINS = 32

    def __init__(self, samples, labels, SNR, signal_length):
        self.samples = samples
        self.SNR = SNR
        self.labels = labels

        for i, sample in enumerate(self.samples):
            try:
                fitted = self._fit_length(sample, signal_length)
                if fitted is not sample:
                    self.samples[i] = fitted
            except (AttributeError, IndexError, TypeError, ValueError) as err:
                # Report and move on.  The previous ``return None`` here
                # aborted the loop, silently leaving every later sample
                # un-padded / un-truncated.
                print(
                    f"Error: cannot fit sample {i} to length "
                    f"{signal_length} ({err}): {sample}"
                )
                continue

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(iq, iq_augmented, stp, stp_augmented)`` for sample *idx*.

        ``iq`` / ``iq_augmented`` are the raw and augmented sample as float
        tensors; ``stp`` / ``stp_augmented`` are ``(1, 32, frames)`` float
        tensors holding the first 32 STFT rows of row 0 of each.

        Returns ``None`` if tensor conversion fails.  PyTorch's default
        collate function raises on ``None``, so a DataLoader over this
        dataset needs a ``collate_fn`` that drops such items.
        """
        sample = self.samples[idx]
        # The original sample's STFT is taken before data_augmentation runs,
        # preserving the original order in case that helper alters its input.
        stp = self._stft_rows(sample)
        augmented = data_augmentation(sample)
        stp_augmented = self._stft_rows(augmented)
        try:
            return (
                torch.Tensor(sample),
                torch.Tensor(augmented),
                torch.Tensor(np.expand_dims(stp, 0)),
                torch.Tensor(np.expand_dims(stp_augmented, 0)),
            )
        except (TypeError, ValueError, RuntimeError):
            print(f"Error: {sample}")
            return None

    @classmethod
    def _fit_length(cls, sample, signal_length):
        """Return *sample* noise-padded or truncated to ``signal_length`` columns.

        Returns *sample* itself (the same object) when it already has exactly
        ``signal_length`` columns.
        """
        n = sample.shape[1]
        if n < signal_length:
            # Pad with near-zero noise; the row count now follows the sample
            # instead of being hard-coded to 2.
            padding = np.random.normal(
                cls._PAD_NOISE_MEAN,
                cls._PAD_NOISE_STD,
                size=(sample.shape[0], signal_length - n),
            )
            return np.concatenate([sample, padding], axis=1)
        if n > signal_length:
            return sample[:, :signal_length]
        return sample

    @classmethod
    def _stft_rows(cls, iq):
        """Return the first ``_N_FREQ_BINS`` STFT rows of ``iq[0, :]`` as a real array."""
        _, _, zxx = stft(iq[0, :], **cls._STFT_KWARGS)
        # NOTE(review): ``.real`` makes explicit what ``torch.Tensor(complex)``
        # previously did implicitly (discard the imaginary part).  The
        # magnitude (``np.abs``) may have been the intent — confirm before
        # changing, as it alters the model's inputs.
        return zxx[: cls._N_FREQ_BINS, :].real