import glob
import os
from typing import List, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset

from utils.utils import torch_randint, torch_rand, describe_tensor
from utils.augment import apply_transforms


class NpyStorage:
    """Indexable collection of ``.npy`` arrays found in one or more directories.

    Files are discovered (non-recursively) with the pattern ``*.npy`` in each
    directory of ``root_dirs``. Within each directory the paths are sorted, so a
    given index always refers to the same file regardless of machine or
    filesystem; ``glob.glob`` itself returns entries in arbitrary order. The
    order of ``root_dirs`` is preserved.

    Args:
        root_dirs: a single directory (``str`` or path-like) or a list of
            directories.
        load_to_ram: if True, every file is loaded with ``np.load`` in the
            constructor and kept in memory; otherwise only paths are stored
            and each ``__getitem__`` call loads the file again.
        **kwargs: forwarded to every ``np.load`` call.
            NOTE(review): ``allow_pickle=True`` unpickles file contents on load,
            which can execute arbitrary code -- only use it with trusted data.

    Raises:
        AssertionError: if no ``.npy`` file is found.
    """

    def __init__(self, root_dirs, load_to_ram, **kwargs):
        # A single directory may be given as a str or a path-like object.
        if isinstance(root_dirs, (str, os.PathLike)):
            root_dirs = [root_dirs]
        # glob.glob order is filesystem-dependent; sort per directory so that
        # dataset indices are reproducible.
        paths = [
            filename
            for path in root_dirs
            for filename in sorted(glob.glob(os.path.join(path, "*.npy")))
        ]
        # Check before loading anything so an empty/misconfigured dir fails fast.
        assert paths, f"List of files cannot be empty, root_dirs: {root_dirs}"
        self.files = [np.load(p, **kwargs) for p in paths] if load_to_ram else paths
        self.load_to_ram = load_to_ram
        self.kwargs = kwargs

    def __len__(self):
        """Return the number of discovered ``.npy`` files."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return the ``idx``-th array, loading it from disk if not cached in RAM."""
        return self.files[idx] if self.load_to_ram else np.load(self.files[idx], **self.kwargs)


def mix(soi, interference, sinr_lo, sinr_hi, use_rand_phase):
    """Add ``interference`` to ``soi`` at a randomly drawn SINR.

    The SINR in dB is ``sinr_lo + (sinr_hi - sinr_lo) * torch_rand()``. The
    interference is scaled in amplitude by ``10 ** (-sinr_db / 20)``, i.e. its
    power is scaled by ``10 ** (-sinr_db / 10)``. If both inputs have equal
    power, the resulting SINR is therefore ``sinr_db``.

    When ``use_rand_phase`` is True, the scale factor is additionally rotated
    by ``exp(2j * pi * u)``, where ``u`` comes from a second ``torch_rand()``
    call (drawn after the SINR).
    NOTE(review): ``np.exp`` is applied to ``torch_rand()``'s result; if that
    is a torch tensor, this mixes NumPy and torch -- confirm in utils.utils.
    """
    span_db = sinr_hi - sinr_lo
    sinr_db = sinr_lo + span_db * torch_rand()
    gain = 10 ** (-0.5 * sinr_db / 10)
    if use_rand_phase:
        gain = gain * np.exp(1j * 2 * np.pi * torch_rand())
    return soi + gain * interference


def np_to_complex_tensor(arr, is_complex=True):
    """Convert a NumPy array to a torch tensor, folding real/imag pairs if asked.

    Args:
        arr: source array; ``torch.from_numpy`` shares its memory.
        is_complex: True if ``arr`` already has a complex dtype. If False,
            the last dimension must hold (real, imag) pairs, which
            ``torch.view_as_complex`` reinterprets as complex elements.

    Returns:
        A torch tensor; complex when ``is_complex`` is False, otherwise with
        whatever dtype ``arr`` had.
    """
    tensor = torch.from_numpy(arr)
    return tensor if is_complex else torch.view_as_complex(tensor)


def window_start(full_length, window_length, sync_by=1):
    """Draw a random start index for a window that fits inside a signal.

    The result is ``k * sync_by`` with ``k`` drawn by
    ``torch_randint(0, (full_length - window_length) // sync_by + 1)``.
    Assuming ``torch_randint``'s upper bound is exclusive (TODO confirm in
    utils.utils), ``start + window_length <= full_length`` always holds.

    Args:
        full_length: number of samples in the signal.
        window_length: number of samples in the window.
        sync_by: the start index is a multiple of this value.

    Raises:
        ValueError: if ``sync_by`` is not positive or the window is longer
            than the signal. Both cases are checked before any random number
            is drawn.
    """
    if sync_by <= 0:
        raise ValueError(f"sync_by must be positive, got {sync_by}")
    max_start_pos = full_length - window_length
    if max_start_pos < 0:
        # Previously this surfaced as an obscure RNG-range error.
        raise ValueError(
            f"window_length ({window_length}) exceeds full_length ({full_length})"
        )
    return torch_randint(0, max_start_pos // sync_by + 1) * sync_by


def choose_window(x, window_length, sync_by=1):
    """Return a random contiguous ``window_length`` slice of ``x``.

    The start index comes from ``window_start`` over ``x.numel()`` elements
    and is a multiple of ``sync_by``. Slicing applies to the first dimension.
    NOTE(review): ``numel()`` equals the first-dimension length only for 1-D
    ``x``; presumably callers pass 1-D signals -- confirm.
    """
    begin = window_start(x.numel(), window_length, sync_by)
    end = begin + window_length
    return x[begin:end]


class RFDataset(Dataset):
    """Synthesizes SOI + interference mixtures on the fly.

    Each item pairs the ``idx``-th signal-of-interest (SOI) file with an
    interference file chosen by ``torch_randint``. It applies ``transforms``
    to the interference, cuts ``signal_length``-sample windows from both, and
    mixes them with ``mix`` at a SINR drawn between ``sinr_lo`` and ``sinr_hi``
    (dB).

    Args:
        soi: directory (or list of directories) of SOI ``.npy`` files.
        interference: directory (or list of directories) of interference
            ``.npy`` files.
        sinr_lo, sinr_hi: SINR range in dB forwarded to ``mix``.
        signal_length: number of samples in each returned window.
        sync_soi_by: the SOI window start is a multiple of this value.
        use_rand_phase: if False, no transform is applied and ``transforms``
            is ignored. This is kept for backward compatibility; presumably
            the ``"rotate"`` transform replaces the random phase ``mix`` used
            to apply (TODO confirm in utils.augment).
        load_to_ram: forwarded to ``NpyStorage``.
        transforms: transform specs passed to ``apply_transforms``. If None
            (the default), ``[{"name": "rotate"}]`` is used.
    """

    def __init__(
        self,
        soi,
        interference,
        sinr_lo,
        sinr_hi,
        signal_length,
        sync_soi_by=1,
        use_rand_phase=True,
        load_to_ram=True,
        transforms=None,
    ):
        self.soi = NpyStorage(soi, load_to_ram=load_to_ram)
        self.interference = NpyStorage(interference, load_to_ram=load_to_ram)
        self.sinr_lo = sinr_lo
        self.sinr_hi = sinr_hi
        self.signal_length = signal_length
        self.sync_soi_by = sync_soi_by
        if not use_rand_phase:
            transforms = []  # Backward compatibility
        elif transforms is None:
            # Build the default per instance: a list literal as the default
            # argument would be one object shared by every dataset.
            transforms = [{"name": "rotate"}]
        self.transforms = transforms

    def __len__(self):
        """Return the number of SOI files (one item per SOI file)."""
        return len(self.soi)

    def __getitem__(self, idx):
        """Build one mixture sample.

        Returns:
            dict with ``"mixture"`` and ``"target"`` (``torch.view_as_real``
            of the mixed window and of the SOI window) and ``"offset"`` (the
            SOI window start, a multiple of ``sync_soi_by``).
        """
        soi = np_to_complex_tensor(self.soi[idx])
        interference_idx = torch_randint(0, len(self.interference))
        interference = np_to_complex_tensor(self.interference[interference_idx])
        interference = apply_transforms(interference, self.transforms)
        offset = window_start(soi.numel(), self.signal_length, self.sync_soi_by)
        soi = soi[offset:offset + self.signal_length]
        # The interference window is independent of the SOI offset and unsynchronized.
        interference = choose_window(interference, self.signal_length)
        # Phase randomization, if any, was already done via self.transforms.
        mixture = mix(soi, interference, self.sinr_lo, self.sinr_hi, use_rand_phase=False)
        return {
            "mixture": torch.view_as_real(mixture),
            "target": torch.view_as_real(soi),
            "offset": offset,
        }


class RFMixtureDataset(Dataset):
    """Random ``signal_length`` windows cut from pre-computed mixtures.

    Each ``.npy`` file stores an object array whose ``.item()`` is a mapping
    with ``"sample_soi"`` and ``"sample_mix"`` entries. The same random window
    (start drawn over the mixture length, a multiple of ``sync_soi_by``) is
    taken from both, so target and mixture stay aligned.

    NOTE(review): files are loaded with ``allow_pickle=True``, which can run
    arbitrary code on load -- only point this at trusted data.
    """

    def __init__(
        self,
        mixtures,
        signal_length,
        sync_soi_by=1,
        load_to_ram=True,
        is_complex=True,
    ):
        # Records are pickled Python objects, hence allow_pickle.
        self.mixtures = NpyStorage(mixtures, load_to_ram=load_to_ram, allow_pickle=True)
        self.signal_length, self.sync_soi_by, self.is_complex = (
            signal_length,
            sync_soi_by,
            is_complex,
        )

    def __len__(self):
        """Return the number of mixture files."""
        return len(self.mixtures)

    def __getitem__(self, idx):
        """Return a random aligned window of mixture and SOI plus its start offset."""
        record = self.mixtures[idx].item()
        soi, mixture = (
            np_to_complex_tensor(record[key], is_complex=self.is_complex)
            for key in ("sample_soi", "sample_mix")
        )

        # Records carrying an "offset" belong to DeterministicDataset instead.
        assert "offset" not in record
        offset = window_start(mixture.numel(), self.signal_length, self.sync_soi_by)
        window = slice(offset, offset + self.signal_length)

        return {
            "mixture": torch.view_as_real(mixture[window]),
            "target": torch.view_as_real(soi[window]),
            "offset": offset,
        }


class DeterministicDataset(Dataset):
    """Pre-windowed mixtures with a stored, fixed offset (e.g. for evaluation).

    Each ``.npy`` file stores an object array whose ``.item()`` is a mapping
    with ``"sample_soi"``, ``"sample_mix"`` and ``"offset"`` entries. Samples
    are returned as stored: no random windowing is performed. If
    ``signal_length`` is given, both signals must already have that
    first-dimension length. The stored offset must be a multiple of
    ``sync_soi_by``.

    NOTE(review): files are loaded with ``allow_pickle=True``, which can run
    arbitrary code on load -- only point this at trusted data.
    """

    def __init__(
        self,
        mixtures,
        signal_length=None,
        sync_soi_by=1,
        load_to_ram=True,
        is_complex=True,
    ):
        # Records are pickled Python objects, hence allow_pickle.
        self.mixtures = NpyStorage(mixtures, load_to_ram=load_to_ram, allow_pickle=True)
        self.signal_length, self.sync_soi_by, self.is_complex = (
            signal_length,
            sync_soi_by,
            is_complex,
        )

    def __len__(self):
        """Return the number of mixture files."""
        return len(self.mixtures)

    def __getitem__(self, idx):
        """Return the stored mixture, SOI and offset, after sanity checks."""
        record = self.mixtures[idx].item()
        soi, mixture = (
            np_to_complex_tensor(record[key], is_complex=self.is_complex)
            for key in ("sample_soi", "sample_mix")
        )
        offset = record["offset"]

        expected_length = self.signal_length
        if expected_length is not None:
            for signal in (soi, mixture):
                assert signal.shape[0] == expected_length
        # The stored window start must respect the SOI synchronization grid.
        assert offset % self.sync_soi_by == 0

        return {
            "mixture": torch.view_as_real(mixture),
            "target": torch.view_as_real(soi),
            "offset": offset,
        }
