import operator

import torch
import torch.nn.functional as F
from einops import rearrange


# Data formats
#
# B  - batch size
# s  - signal length (in complex samples)
# c  - dimension of size 2 holding the real and imaginary parts
# wc - number of windows
# ws - window size (in complex samples)
#
# "stacked" -- (B, s, 2) - real and imag waveform are stacked in the last dim
# "interleaving" -- (B, 2 * s) - real and imag waveform alternate in the last dim
# "complex" -- (B, s), complex tensor
# "windows" -- (B, wc, 2 * ws) -  we split s into wc windows of size ws.
# in each window, real part occupies [0, ws), imag - [ws, 2 * ws)
# "wavenet" -- (B, 2, s)
#
# The datasets provide data in "stacked" format, while individual models may
# expect one of the other formats.


def stacked_to_wavenet(x):
    """Reorder a "stacked" (B, s, 2) tensor into the "wavenet" (B, 2, s) layout.

    No ``.contiguous()`` call is made here, so the result may share memory
    with ``x`` rather than being a fresh copy.
    """
    channels_first = rearrange(x, "b s c -> b c s")
    return channels_first


def stacked_to_interleaving(x):
    """Flatten a "stacked" (B, s, 2) tensor into the "interleaving" (B, 2 * s) layout.

    The (s, c) axes are merged with c innermost, so each real value is
    immediately followed by its imaginary partner: [re_0, im_0, re_1, im_1, ...].
    """
    merged = rearrange(x, "b s c -> b (s c)")
    return merged


def stacked_to_windows(x, window_size):
    """Cut a "stacked" (B, s, 2) tensor into non-overlapping "windows".

    The output has shape (B, s // window_size, 2 * window_size). Inside every
    window the ``window_size`` real parts come first, followed by the matching
    imaginary parts. ``s`` is expected to be a multiple of ``window_size``.
    """
    pattern = "b (wc ws) c -> b wc (c ws)"
    return rearrange(x, pattern, c=2, ws=window_size)


def stacked_to_windows_with_context(x, window_size, context_size):
    """Split a "stacked" tensor into "windows" that also carry neighbouring context.

    Windows start every ``window_size`` samples. Each window is extended with
    ``left`` samples before it and ``right`` samples after it; positions that
    fall outside the signal are zero-padded.

    Args:
        x: "stacked" tensor of shape (B, s, 2).
        window_size: stride between consecutive windows, in complex samples.
        context_size: a single integer (left context only, no right context)
            or a ``(left, right)`` pair. Any integer-like scalar is accepted
            (``int``, NumPy integers, 0-d integer tensors).

    Returns:
        Tensor of shape (B, s // window_size, 2 * (left + window_size + right)).
        Within each window the real parts come first, followed by the imaginary
        parts. If ``s`` is not a multiple of ``window_size``, the leftover tail
        samples do not get a window of their own (``Tensor.unfold`` only emits
        complete windows).

    Raises:
        ValueError: if ``x`` is not of shape (B, s, 2) or a context size is
            negative.
    """
    if x.ndim != 3 or x.shape[-1] != 2:
        raise ValueError(
            f"expected a stacked tensor of shape (B, s, 2), got {tuple(x.shape)}"
        )

    try:
        # operator.index accepts any integer-like scalar (int, numpy.int64,
        # 0-d integer tensor), not only Python int.
        left_context, right_context = operator.index(context_size), 0
    except TypeError:
        left_context, right_context = (operator.index(c) for c in context_size)

    if left_context < 0 or right_context < 0:
        # F.pad would silently crop the signal for negative amounts.
        raise ValueError(
            f"context sizes must be non-negative, got ({left_context}, {right_context})"
        )

    # Zero-pad the time axis (dim 1); the leading (0, 0) leaves the
    # real/imag axis (last dim) untouched.
    x = F.pad(x, (0, 0, left_context, right_context))

    total_window_size = left_context + window_size + right_context
    # unfold appends the window axis last: (B, wc, 2, total_window_size).
    windows = x.unfold(dimension=1, size=total_window_size, step=window_size)

    # Flatten (c, total_window_size) with c outermost so the real parts
    # precede the imaginary parts inside each window.
    batch, window_count = windows.shape[:2]
    return windows.reshape(batch, window_count, 2 * total_window_size)


def windows_to_complex(x):
    """Convert a "windows" tensor (B, wc, 2 * ws) into a "complex" tensor (B, wc * ws).

    Within each window the first ``ws`` values are real parts and the last
    ``ws`` values are the matching imaginary parts; windows are concatenated
    in order along the output signal axis.

    Raises:
        ValueError: if ``x`` is not 3-dimensional or its last dimension is odd.
    """
    batch, window_count, width = x.shape
    if width % 2:
        raise ValueError(f"last dimension must be even (2 * ws), got {width}")
    window_size = width // 2

    # (B, wc, 2 * ws) -> (B, wc, 2, ws): split every window into real/imag halves.
    halves = x.reshape(batch, window_count, 2, window_size)
    # (B, wc, 2, ws) -> (B, wc, ws, 2) -> (B, wc * ws, 2): pair each real with its imag.
    stacked = halves.transpose(-1, -2).reshape(batch, window_count * window_size, 2)
    # The regrouping can produce a strided view whose last axis is not unit-stride
    # (e.g. with a single window, wc == 1); torch.view_as_complex rejects such
    # tensors, so force a contiguous layout first.
    return torch.view_as_complex(stacked.contiguous())


def stacked_to_complex(x):
    """Convert a "stacked" tensor (B, s, 2) into a "complex" tensor (B, s).

    ``torch.view_as_complex`` requires the real/imag axis to have stride 1,
    which a transposed or permuted input need not satisfy. ``.contiguous()``
    returns ``x`` itself when it is already contiguous, so such inputs are
    still viewed without copying (the result shares memory with ``x``).
    Other layouts are copied into a contiguous tensor first.
    """
    return torch.view_as_complex(x.contiguous())


def interleaving_to_complex(x):
    """Convert an "interleaving" tensor (B, 2 * s) into a "complex" tensor (B, s).

    Values along the last axis are read in [re_0, im_0, re_1, im_1, ...] order.

    Raises:
        ValueError: if ``x`` is not 2-dimensional or its last dimension is odd.
    """
    batch, length = x.shape
    if length % 2:
        raise ValueError(f"last dimension must be even (2 * s), got {length}")

    # (B, 2 * s) -> (B, s, 2): consecutive pairs become (real, imag).
    stacked = x.reshape(batch, length // 2, 2)
    # A non-contiguous input (e.g. a transposed tensor) can reshape into a view
    # whose last axis is not unit-stride, which torch.view_as_complex rejects.
    return torch.view_as_complex(stacked.contiguous())


def wavenet_to_complex(x):
    """Convert a "wavenet" (B, 2, s) tensor into a "complex" (B, s) tensor.

    The real/imag axis is moved last and the tensor is made contiguous before
    calling ``torch.view_as_complex``, which requires that axis to have stride 1.
    """
    channels_last = rearrange(x, "b c s -> b s c")
    channels_last = channels_last.contiguous()
    return torch.view_as_complex(channels_last)


def wavenet_to_stacked(x):
    """Convert a "wavenet" (B, 2, s) tensor into a contiguous "stacked" (B, s, 2) tensor."""
    swapped = rearrange(x, "b c s -> b s c")
    # Force a contiguous memory layout for the permuted result.
    return swapped.contiguous()


def complex_to_interleaving(x):
    """Convert a "complex" tensor (B, s) into an "interleaving" tensor (B, 2 * s).

    The output alternates real and imaginary parts: [re_0, im_0, re_1, im_1, ...].

    Raises:
        ValueError: if ``x`` is not 2-dimensional.
    """
    batch, length = x.shape
    # torch.view_as_real refuses tensors whose conjugation is still lazy (such
    # as the result of x.conj()); resolve_conj materialises it and returns x
    # itself when there is nothing to resolve.
    pairs = torch.view_as_real(x.resolve_conj())
    # (B, s, 2) -> (B, 2 * s) with the real/imag axis innermost.
    return pairs.reshape(batch, 2 * length)


def windows_to_stacked(x):
    """Convert a "windows" tensor (B, wc, 2 * ws) into a "stacked" tensor (B, wc * ws, 2).

    Within each window the first ``ws`` values are real parts and the last
    ``ws`` values are the matching imaginary parts. The result is always
    contiguous, matching ``wavenet_to_stacked``, so it can be passed straight
    to ``torch.view_as_complex``.

    Raises:
        ValueError: if ``x`` is not 3-dimensional or its last dimension is odd.
    """
    batch, window_count, width = x.shape
    if width % 2:
        raise ValueError(f"last dimension must be even (2 * ws), got {width}")
    window_size = width // 2

    # (B, wc, 2 * ws) -> (B, wc, 2, ws): split every window into real/imag halves.
    halves = x.reshape(batch, window_count, 2, window_size)
    # (B, wc, 2, ws) -> (B, wc, ws, 2) -> (B, wc * ws, 2): pair each real with its imag.
    stacked = halves.transpose(-1, -2).reshape(batch, window_count * window_size, 2)
    # With a single window (wc == 1) the reshape can return a strided view whose
    # last axis is not unit-stride; normalise the layout.
    return stacked.contiguous()


def interleaving_to_stacked(x):
    """Regroup an "interleaving" (B, 2 * s) tensor into the "stacked" (B, s, 2) layout.

    Consecutive pairs [re_k, im_k] along the last axis become the entries of
    the new trailing axis of size 2.
    """
    pairs = rearrange(x, "b (s c) -> b s c", c=2)
    return pairs
