from torch.utils.data import Dataset
import torch
import numpy as np
import os
import glob


class SequenceDataset(Dataset):
    """Random fixed-length crops of signals stored as ``.npy`` files.

    Item ``idx`` loads file ``idx`` from ``root_dir`` and cuts a window of
    ``signal_length`` consecutive samples at a random offset. The window is
    split into real/imaginary parts with ``torch.view_as_real``, cast to
    float32, and flattened. The result is a tensor of shape
    ``(2 * signal_length,)`` with real and imaginary values interleaved.
    Files are expected to hold 1-D complex arrays (``view_as_real`` needs a
    complex dtype, and the final reshape needs a 1-D window).

    Args:
        root_dir: Directory searched (non-recursively) for ``*.npy`` files.
        signal_length: Number of samples in each crop.
        sync: If True, crop offsets are restricted to multiples of ``shift``.
        shift: Offset alignment used when ``sync`` is True; must be positive.

    Raises:
        ValueError: If ``sync`` is True and ``shift`` is not positive.
    """

    def __init__(self, root_dir, signal_length, sync=False, shift=16):
        if sync and shift <= 0:
            raise ValueError(f"shift must be positive when sync=True, got {shift}")
        # glob order is filesystem-dependent; sort so idx -> file is reproducible.
        self.soi_files = sorted(glob.glob(os.path.join(root_dir, "*.npy")))
        self.signal_length = signal_length
        self.sync = sync
        self.shift = shift

    def __len__(self):
        return len(self.soi_files)

    def _random_crop(self, signal, path):
        """Return a random ``signal_length`` window of ``signal`` (loaded from ``path``)."""
        max_start = signal.shape[0] - self.signal_length
        if max_start < 0:
            raise ValueError(
                f"{path} has {signal.shape[0]} samples, fewer than "
                f"signal_length={self.signal_length}"
            )
        if self.sync:
            # Pick the aligned slot directly so every multiple of `shift` that
            # still fits is equally likely. Rounding a uniform offset down
            # would under-sample the last slot.
            start = np.random.randint(max_start // self.shift + 1) * self.shift
        else:
            start = np.random.randint(max_start + 1)
        return signal[start:start + self.signal_length]

    def __getitem__(self, idx):
        path = self.soi_files[idx]
        window = self._random_crop(np.load(path), path)
        window = torch.view_as_real(torch.from_numpy(window)).to(torch.float32)
        return window.reshape((2 * self.signal_length,))


class MixtureDataset(Dataset):
    """Synthetic mixtures of a signal of interest (SOI) and a random interferer.

    Item ``idx`` takes a random ``signal_length`` crop of SOI file ``idx``. It
    adds a random crop of a randomly chosen interference file, scaled by a
    random SINR drawn uniformly from ``[noise_L, noise_R)`` dB and, optionally,
    rotated by a uniform random phase. The complex mixture is split with
    ``torch.view_as_real``, cast to float32, and flattened to shape
    ``(2 * signal_length,)`` with real/imag values interleaved. Files are
    expected to hold 1-D complex arrays.

    Args:
        soi_dir: Directory searched (non-recursively) for SOI ``*.npy`` files.
        interference_dir: Directory searched for interference ``*.npy`` files.
        signal_length: Number of samples in each crop.
        use_rand_phase: If True, multiply the interference by ``exp(j*2*pi*u)``
            with ``u`` uniform in ``[0, 1)``.
        noise_L: Lower bound of the SINR range in dB (despite the name, this
            controls interference strength, not additive noise).
        noise_R: Upper bound of the SINR range in dB.
    """

    def __init__(self, soi_dir, interference_dir, signal_length, use_rand_phase, noise_L=-33, noise_R=3):
        # glob order is filesystem-dependent; sort so indices are reproducible.
        self.soi_files = sorted(glob.glob(os.path.join(soi_dir, "*.npy")))
        self.interference_files = sorted(glob.glob(os.path.join(interference_dir, "*.npy")))
        self.signal_length = signal_length
        self.use_rand_phase = use_rand_phase
        self.noise_L = noise_L
        self.noise_R = noise_R

    def __len__(self):
        return len(self.soi_files)

    def _random_crop(self, signal, path):
        """Return a random ``signal_length`` window of ``signal`` (loaded from ``path``)."""
        max_start = signal.shape[0] - self.signal_length
        if max_start < 0:
            raise ValueError(
                f"{path} has {signal.shape[0]} samples, fewer than "
                f"signal_length={self.signal_length}"
            )
        start = np.random.randint(max_start + 1)
        return signal[start:start + self.signal_length]

    def __getitem__(self, idx):
        soi_path = self.soi_files[idx]
        soi = self._random_crop(np.load(soi_path), soi_path)

        if not self.interference_files:
            raise ValueError("no *.npy interference files were found")
        interference_path = self.interference_files[np.random.randint(len(self.interference_files))]
        interference = self._random_crop(np.load(interference_path), interference_path)

        sinr_db = self.noise_L + np.random.rand() * (self.noise_R - self.noise_L)
        # Amplitude gain 10 ** (-sinr_db / 20): if SOI and interference have
        # equal average power, the mixture's SINR is sinr_db dB.
        coeff = 10 ** (-0.5 * sinr_db / 10)
        if self.use_rand_phase:
            coeff = coeff * np.exp(1j * 2 * np.pi * np.random.rand())
        mixture = soi + coeff * interference

        mixture = torch.view_as_real(torch.from_numpy(mixture)).to(torch.float32)
        return mixture.reshape((2 * self.signal_length,))
