import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from vector_quantize_pytorch import VectorQuantize, ResidualVQ
from quant.fsq import FSQ, fsq_level_book


def generate_causal_attn_mask(seq_len, lookahead_tokens, device):
    """Return a boolean ``(seq_len, seq_len)`` attention mask.

    Entry ``[i, j]`` is ``True`` when query position ``i`` may attend to key
    position ``j``, i.e. when ``j <= i + lookahead_tokens``.

    Args:
        seq_len: Number of positions along each side of the mask.
        lookahead_tokens: How many positions past the diagonal stay visible.
        device: Device on which the mask is created.
    """
    positions = torch.arange(seq_len, device=device)
    query_idx = positions.unsqueeze(1)
    key_idx = positions.unsqueeze(0)
    return key_idx <= query_idx + lookahead_tokens


def make_causal_conv1d_weight_mask(
        weight: torch.Tensor,
        lookahead_tokens: int = 0,
        *,
        is_transposed: bool = False
) -> torch.Tensor:
    """Build a 0/1 mask over the kernel taps of a 1-D convolution weight.

    The first ``k // 2 + lookahead_tokens`` taps along the last (kernel) axis
    are set to one and the remaining taps to zero, where ``k`` is the kernel
    size. The count is clamped to ``[0, k]``, so a large negative lookahead
    yields an all-zero mask instead of wrapping around through Python's
    negative slicing.

    With ``lookahead_tokens=0`` the centre tap (index ``k // 2``) is itself
    masked out; only taps strictly below the centre are kept.

    NOTE(review): confirm that excluding the centre tap at zero lookahead is
    intended; ``generate_causal_attn_mask`` keeps the diagonal (current
    position) at zero lookahead.

    NOTE(review): ``is_transposed`` only swaps the two channel axes; the tap
    order is not flipped. In a transposed convolution the relation between
    tap index and time direction appears to be reversed compared with a
    regular convolution -- verify the mask direction used for up-sampling.

    Args:
        weight: Convolution weight with three axes, the last being the kernel.
        lookahead_tokens: Extra taps to keep beyond ``k // 2``.
        is_transposed: Whether ``weight`` has its first two axes swapped, as
            in a transposed convolution's weight.

    Returns:
        A tensor with the same shape, dtype and device as ``weight``. It is an
        expanded view, so clone it before modifying it in place.
    """
    w = weight.permute(1, 0, 2) if is_transposed else weight

    c_out, c_in, k = w.shape
    centre = k // 2
    # Clamp on both sides: above k is meaningless, and a negative bound would
    # otherwise be read by slicing as "all but the last |limit| taps".
    limit = max(0, min(centre + lookahead_tokens, k))

    tap_mask = w.new_zeros(k)
    tap_mask[:limit] = 1.0
    tap_mask = tap_mask.view(1, 1, k).expand(c_out, c_in, k)

    return tap_mask.permute(1, 0, 2) if is_transposed else tap_mask


class DownSample(nn.Module):
    """Strided 1-D convolution with an optional causal tap mask.

    The convolution uses kernel size ``2 * factor + 1``, stride ``factor`` and
    padding ``factor``. When ``is_causal`` is set, the weight is multiplied by
    the mask from ``make_causal_conv1d_weight_mask`` on every forward pass.
    """

    def __init__(self, in_channels, out_channels, factor, is_causal=False, lookahead_tokens=0):
        super().__init__()
        self.is_causal = is_causal
        self.lookahead_tokens = lookahead_tokens
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=2 * factor + 1,
            stride=factor,
            padding=factor,
        )

    def forward(self, x):
        """Apply the (optionally masked) strided convolution to ``x``."""
        if self.is_causal:
            conv = self.conv
            # The mask is a constant; gradients flow only through the weight.
            with torch.no_grad():
                tap_mask = make_causal_conv1d_weight_mask(
                    conv.weight,
                    self.lookahead_tokens,
                    is_transposed=False,
                )
            return F.conv1d(
                x,
                conv.weight * tap_mask,
                bias=conv.bias,
                stride=conv.stride,
                padding=conv.padding,
                dilation=conv.dilation,
                groups=conv.groups,
            )

        return self.conv(x)


class UpSample(nn.Module):
    """Transposed 1-D convolution with an optional causal tap mask.

    The transposed convolution uses kernel size ``2 * factor``, stride
    ``factor``, padding ``factor // 2 + factor % 2`` and output padding
    ``factor % 2``. When ``is_causal`` is set, the weight is multiplied by the
    mask from ``make_causal_conv1d_weight_mask`` on every forward pass.

    NOTE(review): the mask keeps the lowest-index taps, the same rule applied
    to the regular convolutions in this file. For a transposed convolution
    the relation between tap index and time direction may be reversed --
    confirm this gives the intended causality.
    """

    def __init__(self, in_channels, out_channels, factor, is_causal=False, lookahead_tokens=0):
        super().__init__()
        self.is_causal = is_causal
        self.lookahead_tokens = lookahead_tokens
        odd = factor % 2
        self.conv = nn.ConvTranspose1d(
            in_channels,
            out_channels,
            kernel_size=2 * factor,
            stride=factor,
            padding=(factor // 2) + odd,
            output_padding=odd,
        )

    def forward(self, x):
        """Apply the (optionally masked) transposed convolution to ``x``."""
        if self.is_causal:
            conv = self.conv
            # The mask is a constant; gradients flow only through the weight.
            with torch.no_grad():
                tap_mask = make_causal_conv1d_weight_mask(
                    conv.weight,
                    self.lookahead_tokens,
                    is_transposed=True,
                )
            return F.conv_transpose1d(
                x,
                conv.weight * tap_mask,
                bias=conv.bias,
                stride=conv.stride,
                padding=conv.padding,
                output_padding=conv.output_padding,
                groups=conv.groups,
                dilation=conv.dilation,
            )

        return self.conv(x)


class ConvBlock1d(nn.Module):
    """GroupNorm -> SiLU -> Conv1d, with an optional causal tap mask.

    The number of normalisation groups is ``min(num_groups, in_channels)``.
    When ``is_causal`` is set, the convolution weight is multiplied by the
    mask from ``make_causal_conv1d_weight_mask`` on every forward pass.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        dilation: int = 1,
        num_groups: int = 8,
        is_causal: bool = False,
        lookahead_tokens: int = 0
    ):
        super().__init__()
        self.is_causal = is_causal
        self.lookahead_tokens = lookahead_tokens

        groups = min(num_groups, in_channels)
        self.norm = nn.GroupNorm(num_groups=groups, num_channels=in_channels)
        self.activation = nn.SiLU()
        self.project = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
        )

    def forward(self, x):
        """Normalise, activate and convolve ``x``."""
        hidden = self.activation(self.norm(x))

        if not self.is_causal:
            return self.project(hidden)

        conv = self.project
        # The mask is a constant; gradients flow only through the weight.
        with torch.no_grad():
            tap_mask = make_causal_conv1d_weight_mask(conv.weight, self.lookahead_tokens).clone()

        return F.conv1d(
            hidden,
            conv.weight * tap_mask,
            bias=conv.bias,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            groups=conv.groups,
        )


class ResnetBlock1d(nn.Module):
    """Two stacked ``ConvBlock1d`` layers with a residual shortcut.

    The first block uses the given kernel size, stride and padding; the
    second uses ``ConvBlock1d``'s defaults. The shortcut is a 1x1 convolution
    when the channel count changes and an identity otherwise.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        is_causal: bool = False,
        lookahead_tokens: int = 0
    ):
        super().__init__()
        causal_cfg = {"is_causal": is_causal, "lookahead_tokens": lookahead_tokens}

        self.block1 = ConvBlock1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            **causal_cfg,
        )
        self.block2 = ConvBlock1d(
            in_channels=out_channels,
            out_channels=out_channels,
            **causal_cfg,
        )

        if in_channels == out_channels:
            self.to_out = nn.Identity()
        else:
            self.to_out = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)

    def forward(self, x):
        """Return ``block2(block1(x))`` plus the shortcut of ``x``."""
        hidden = self.block1(x)
        hidden = self.block2(hidden)
        shortcut = self.to_out(x)
        return hidden + shortcut


class MLP(nn.Module):
    """Feed-forward stack of Linear / GELU / Dropout layers.

    The hidden width is ``int(in_dim * mlp_ratio)``. The stack is one input
    block, ``inner_layers`` hidden-to-hidden blocks, and a final Linear
    followed by Dropout.
    """

    def __init__(self, in_dim, out_dim, mlp_ratio=4.0, dropout=0.0, inner_layers=0):
        super().__init__()
        width = int(in_dim * mlp_ratio)

        def hidden_block(fan_in):
            # One Linear -> GELU -> Dropout unit ending at the hidden width.
            return [nn.Linear(fan_in, width), nn.GELU(), nn.Dropout(dropout)]

        stack = hidden_block(in_dim)
        for _ in range(inner_layers):
            stack += hidden_block(width)
        stack += [nn.Linear(width, out_dim), nn.Dropout(dropout)]

        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack."""
        return self.net(x)


def build_rope_frequencies(d_model: int, max_position: int) -> torch.Tensor:
    """Precompute unit-magnitude complex rotations for rotary embeddings.

    Args:
        d_model: Feature width being rotated; ``d_model // 2`` frequencies
            are generated.
        max_position: Number of positions to tabulate.

    Returns:
        Complex tensor of shape ``(max_position, d_model // 2)`` whose entry
        ``[p, i]`` is ``exp(1j * p * 10000 ** (-2 * i / d_model))``.
    """
    n_pairs = d_model // 2
    exponents = -2 * torch.arange(n_pairs).float() / d_model
    inv_freq = 10000.0 ** exponents
    positions = torch.arange(max_position, dtype=torch.float).unsqueeze(1)
    phase = positions * inv_freq
    magnitude = torch.ones_like(phase)
    return torch.polar(magnitude, phase)


def apply_rotary_emb(x: torch.Tensor,
                     freqs_cis: torch.Tensor,
                     start_idx: int = 0) -> torch.Tensor:
    """Rotate consecutive feature pairs of ``x`` by precomputed complex phases.

    The last axis of ``x`` is read as ``D // 2`` (real, imaginary) pairs, and
    each pair is multiplied by the matching entry of ``freqs_cis``. The
    computation is carried out in float32 and the result is cast back to the
    dtype of ``x``.

    Non-contiguous inputs (for example transposed or permuted tensors) are
    accepted; the pair tensor is made contiguous before it is reinterpreted
    as complex.

    Args:
        x: Tensor of shape ``(B, T, D)`` with even ``D``.
        freqs_cis: Complex table of shape ``(positions, D // 2)``.
        start_idx: Row of ``freqs_cis`` used for the first position of ``x``.

    Returns:
        Tensor of shape ``(B, T, D)`` with the same dtype as ``x``.
    """
    B, T, D = x.shape
    freqs = freqs_cis[start_idx:start_idx + T]
    # view_as_complex needs a last axis of size 2 with stride 1, which a
    # strided input does not guarantee; contiguous() is a no-op otherwise.
    pairs = x.float().reshape(B, T, D // 2, 2).contiguous()
    rotated = torch.view_as_complex(pairs) * freqs.unsqueeze(0)
    return torch.view_as_real(rotated).reshape(B, T, D).type_as(x)


class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Optionally rotates queries and keys with rotary position embeddings and
    optionally applies the mask from ``generate_causal_attn_mask``.

    Args:
        c: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        is_causal: Whether to mask attention scores.
        lookahead_tokens: Lookahead passed to the mask generator.
        use_rope: Whether to rotate queries and keys.
        max_position: Number of positions in the rotary table.
    """

    def __init__(self, c, num_heads, is_causal=False, lookahead_tokens=0, use_rope=False, max_position=40960):
        super().__init__()
        assert c % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = c // num_heads
        self.is_causal = is_causal
        self.use_rope = use_rope
        self.lookahead_tokens = lookahead_tokens

        if not self.use_rope:
            self.freqs_cis = None
        else:
            # Non-persistent: the table is rebuilt rather than saved in checkpoints.
            rope_table = build_rope_frequencies(self.head_dim, max_position)
            self.register_buffer("freqs_cis", rope_table, persistent=False)

        self.query = nn.Linear(c, c)
        self.key = nn.Linear(c, c)
        self.value = nn.Linear(c, c)
        self.out_proj = nn.Linear(c, c)
        self.scale = self.head_dim ** 0.5

    def forward(self, x):
        """Attend over the second axis of ``x`` (shape ``(B, N, C)``)."""
        B, N, C = x.size()
        heads, dim = self.num_heads, self.head_dim

        def split_heads(t):
            # (B, N, C) -> (B, heads, N, dim)
            return t.view(B, N, heads, dim).transpose(1, 2)

        q = split_heads(self.query(x))
        k = split_heads(self.key(x))
        v = split_heads(self.value(x))

        if self.use_rope and self.freqs_cis is not None:
            table = self.freqs_cis[:N]
            # Fold heads into the batch axis so each head is rotated independently.
            q = apply_rotary_emb(q.reshape(B * heads, N, dim), table).view(B, heads, N, dim)
            k = apply_rotary_emb(k.reshape(B * heads, N, dim), table).view(B, heads, N, dim)

        scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale

        if self.is_causal:
            allowed = generate_causal_attn_mask(N, self.lookahead_tokens, device=scores.device)
            blocked = ~allowed.unsqueeze(0).unsqueeze(1)
            scores = scores.masked_fill(blocked, float('-inf'))

        weights = F.softmax(scores, dim=-1)
        context = torch.matmul(weights, v)

        # (B, heads, N, dim) -> (B, N, C)
        merged = context.transpose(1, 2).contiguous().view(B, N, C)
        return self.out_proj(merged)


class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: self-attention then MLP, each residual.

    Dropout is applied to the attention output via ``dropout1``, and the same
    rate is passed on to the MLP.
    """

    def __init__(self, c, num_heads, mlp_ratio=4.0, dropout=0.1, is_causal=False, lookahead_tokens=0, use_rope=False, max_position=40960):
        super().__init__()
        self.norm1 = nn.LayerNorm(c)
        self.attn = MultiHeadSelfAttention(
            c,
            num_heads,
            is_causal=is_causal,
            lookahead_tokens=lookahead_tokens,
            use_rope=use_rope,
            max_position=max_position,
        )
        self.dropout1 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(c)
        self.mlp = MLP(in_dim=c, out_dim=c, mlp_ratio=mlp_ratio, dropout=dropout)

    def forward(self, x):
        """Apply the attention and MLP sub-layers with residual connections."""
        attended = self.attn(self.norm1(x))
        x = x + self.dropout1(attended)
        return x + self.mlp(self.norm2(x))


class DownSampler(nn.Module):
    """Encoder: patch MLP -> residual/strided conv stack -> transformers -> 1x1 conv.

    The input of shape ``(b, n * patch_channels)`` is cut into ``n`` patches,
    each embedded by an MLP. A ResNet block is followed by one ``DownSample``
    (factor 2) plus ``resnet_count`` ResNet blocks per channel transition in
    ``channels``. Optional transformer blocks run on the resulting sequence,
    and a 1x1 convolution maps to ``emb_channels``.

    With ``use_norm`` the per-position channel norm is factored out of the
    patch embedding. Its logarithm is projected to ``proj_channels`` features
    and concatenated, so the conv stack always receives ``mlp_channels``.

    Convolutions receive a lookahead of 1 when ``lookahead_tokens > 0`` and 0
    otherwise. Transformer blocks receive ``lookahead_tokens`` divided by the
    total down-sampling ratio and by the number of transformer blocks.
    """

    def __init__(self, patch_channels, channels, emb_channels, resnet_count, use_norm=False,
                 num_transformer_blocks=0, num_heads=8, mlp_ratio=4.0, dropout=0.0,
                 mlp_channels=128, proj_channels=64, is_causal=False, lookahead_tokens=0, use_rope=False):
        super().__init__()
        self.patch_channels = patch_channels
        self.mlp_channels = mlp_channels
        self.use_norm = use_norm
        self.proj_channels = proj_channels

        if use_norm:
            self.norm_proj = nn.Linear(1, proj_channels)

        # The log-norm features take up proj_channels of the mlp_channels budget.
        feature_channels = mlp_channels - (proj_channels if use_norm else 0)
        self.mlp = MLP(in_dim=patch_channels, out_dim=feature_channels)

        conv_cfg = dict(is_causal=is_causal, lookahead_tokens=int(lookahead_tokens > 0))

        stages = [
            ResnetBlock1d(in_channels=mlp_channels,
                          out_channels=channels[0],
                          kernel_size=7,
                          padding=3,
                          **conv_cfg)
        ]
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            stages.append(DownSample(in_channels=c_in, out_channels=c_out, factor=2, **conv_cfg))
            for _ in range(resnet_count):
                stages.append(ResnetBlock1d(in_channels=c_out, out_channels=c_out, **conv_cfg))
        self.net = nn.Sequential(*stages)

        total_stride = 2 ** (len(channels) - 1)
        attn_blocks = []
        for _ in range(num_transformer_blocks):
            # The division by num_transformer_blocks stays inside the loop so
            # that zero blocks never triggers a division by zero.
            attn_blocks.append(
                TransformerBlock(
                    c=channels[-1],
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    dropout=dropout,
                    is_causal=is_causal,
                    lookahead_tokens=lookahead_tokens // total_stride // num_transformer_blocks,
                    use_rope=use_rope,
                )
            )
        self.transformer_blocks = nn.ModuleList(attn_blocks)

        self.final_conv = nn.Conv1d(in_channels=channels[-1], out_channels=emb_channels, kernel_size=1)

    def forward(self, x):
        """Encode ``x`` of shape ``(b, n * patch_channels)`` into a latent sequence."""
        assert x.shape[1] % self.patch_channels == 0
        n = x.shape[1] // self.patch_channels
        feature_channels = self.mlp_channels - (self.proj_channels if self.use_norm else 0)

        patches = rearrange(x, "b (n pc) -> (b n) pc", n=n, pc=self.patch_channels)
        embedded = self.mlp(patches)
        x = rearrange(embedded, "(b n) c -> b c n", n=n, c=feature_channels)

        if self.use_norm:
            c = x.shape[1]
            norm = ((x ** 2).sum(dim=1).sqrt() / c ** 0.25).unsqueeze(1)
            log_norm = torch.log(norm + 1e-6)
            # norm_proj acts on the last axis, so move channels there and back.
            log_norm = self.norm_proj(log_norm.transpose(1, 2)).transpose(1, 2)
            x = torch.cat((x / norm, log_norm), dim=1)

        x = self.net(x)

        # Transformer blocks expect (batch, sequence, channels).
        x = x.permute(0, 2, 1)
        for block in self.transformer_blocks:
            x = block(x)
        x = x.permute(0, 2, 1)

        return self.final_conv(x)


class UpSampler(nn.Module):
    """Decoder: 1x1 conv -> transformers -> transposed/residual conv stack -> patch MLP.

    A latent of shape ``(b, emb_channels, tokens)`` is lifted to
    ``channels[-1]`` and passed through optional transformer blocks. One
    ``UpSample`` (factor 2) plus ``resnet_count`` ResNet blocks follow per
    channel transition in ``channels``, walked in reverse. A final ResNet
    block outputs ``mlp_channels`` channels and an MLP maps each position to
    one patch of ``patch_channels`` samples.

    With ``use_norm`` the last ``proj_channels`` channels of the conv stack
    output are projected to a single log-gain per position. The gain is
    exponentiated and multiplied into the remaining
    ``mlp_channels - proj_channels`` channels before the patch MLP.

    Convolutions receive a lookahead of 1 when ``lookahead_tokens > 0`` and 0
    otherwise. Transformer blocks receive ``lookahead_tokens`` divided by the
    total up-sampling ratio and by the number of transformer blocks.
    """

    def __init__(self, patch_channels, channels, emb_channels, resnet_count, use_norm=False,
                 num_transformer_blocks=0, num_heads=8, mlp_ratio=4.0, dropout=0.0,
                 mlp_channels=128, proj_channels=64, is_causal=False, lookahead_tokens=0, use_rope=False):
        super().__init__()
        self.patch_channels = patch_channels
        self.mlp_channels = mlp_channels

        self.use_norm = use_norm
        self.proj_channels = proj_channels

        if self.use_norm:
            self.norm_proj = nn.Linear(proj_channels, 1)

        self.first_conv = nn.Conv1d(in_channels=emb_channels, out_channels=channels[-1], kernel_size=1)

        lookahead = 1 if lookahead_tokens > 0 else 0

        # The lookahead expression is evaluated per block, so zero blocks
        # never triggers a division by zero.
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(
                c=channels[-1],
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
                is_causal=is_causal,
                lookahead_tokens=lookahead_tokens // (2**(len(channels) - 1)) // num_transformer_blocks,
                use_rope=use_rope
            )
            for _ in range(num_transformer_blocks)
        ])

        layers = []

        for i in range(len(channels) - 2, -1, -1):
            layers.append(UpSample(in_channels=channels[i + 1],
                                   out_channels=channels[i],
                                   factor=2,
                                   is_causal=is_causal,
                                   lookahead_tokens=lookahead))
            for _ in range(resnet_count):
                layers.append(ResnetBlock1d(in_channels=channels[i],
                                            out_channels=channels[i],
                                            is_causal=is_causal,
                                            lookahead_tokens=lookahead))

        layers.append(ResnetBlock1d(in_channels=channels[0],
                                    out_channels=mlp_channels,
                                    kernel_size=7,
                                    padding=3,
                                    is_causal=is_causal,
                                    lookahead_tokens=lookahead))

        self.net = nn.Sequential(*layers)
        self.mlp = MLP(in_dim=mlp_channels - (self.proj_channels if use_norm else 0), out_dim=patch_channels)

    def forward(self, x):
        """Decode a latent ``(b, emb_channels, tokens)`` into ``(b, n * patch_channels)``."""
        x = self.first_conv(x)

        # Transformer blocks expect (batch, sequence, channels).
        x = x.permute(0, 2, 1)
        for block in self.transformer_blocks:
            x = block(x)
        x = x.permute(0, 2, 1)

        x = self.net(x)

        feature_channels = self.mlp_channels - (self.proj_channels if self.use_norm else 0)

        if self.use_norm:
            # Split at the feature boundary so the log-norm part always has
            # exactly proj_channels channels, matching norm_proj's input
            # width, even when mlp_channels != 2 * proj_channels.
            x, log_norm = x[:, :feature_channels], x[:, feature_channels:]
            log_norm = self.norm_proj(torch.transpose(log_norm, 1, 2))
            log_norm = torch.transpose(log_norm, 1, 2)
            norm = torch.exp(log_norm)
            x = x * norm

        b = x.shape[0]
        x = rearrange(x, "b c n -> (b n) c", c=feature_channels)
        x = self.mlp(x)
        x = rearrange(x, "(b n) pc -> b (n pc)", b=b, pc=self.patch_channels)
        return x


class Autoencoder(nn.Module):
    """Autoencoder built from ``DownSampler`` and ``UpSampler`` with an optional quantiser.

    Exactly one bottleneck is active, chosen in this order:

    * ``use_fsq``: finite scalar quantisation; the latent width is
      ``len(fsq_levels)``.
    * ``use_vq`` / ``use_rvq``: (residual) vector quantisation over
      ``channels[-1]`` features. If both flags are set, plain VQ is used.
    * otherwise: a continuous latent with ``latent_dim`` channels.

    Args:
        patch_channels: Samples per input patch.
        channels: Channel widths of the conv stack; each transition halves
            the sequence length.
        resnet_count: ResNet blocks per resolution.
        signal_length: Half of the expected flat input length (see
            ``get_input_length``).
        fsq_levels / fsq_bits: FSQ configuration; give at most one of them.
        rvq_tokens: Number of residual quantisers.
        vq_bits / vq_codebook_size: Codebook size, given as ``2 ** vq_bits``
            or directly; one of them is required in VQ/RVQ mode.
        vq_codebook_dim: Codebook dimension for both VQ and RVQ.
        latent_dim: Latent channels when no quantiser is used.

    Raises:
        ValueError: If FSQ is requested without levels, or VQ/RVQ without a
            codebook size.
    """

    def __init__(self, patch_channels,
                 channels,
                 resnet_count,
                 signal_length,
                 use_fsq=False,
                 use_norm=False,
                 num_transformer_blocks=0,
                 fsq_levels=None,
                 fsq_bits=None,
                 use_vq=False,
                 use_rvq=False,
                 rvq_tokens=2,
                 vq_bits=None,
                 vq_codebook_size=None,
                 is_causal=False,
                 lookahead_tokens=0,
                 use_rope=False,
                 vq_codebook_dim=8,
                 latent_dim=3):
        super().__init__()
        self.patch_channels = patch_channels
        self.channels = channels
        self.use_fsq = use_fsq
        self.use_vq = use_vq or use_rvq
        self.signal_length = signal_length

        if self.use_fsq:
            if fsq_bits is not None:
                assert fsq_levels is None, "Cannot define both fsq_levels and fsq_bits"
                fsq_levels = fsq_level_book[fsq_bits]
            if fsq_levels is None:
                raise ValueError("fsq_levels should be present when using FSQ")
            self.fsq = FSQ(fsq_levels)
            emb_channels = len(fsq_levels)
        elif self.use_vq:
            codebook_size = 2**vq_bits if vq_bits is not None else vq_codebook_size
            if codebook_size is None:
                raise ValueError("vq_bits or vq_codebook_size should be present when using VQ")

            if use_vq:
                self.vq = VectorQuantize(dim=channels[-1],
                                         codebook_size=codebook_size,
                                         kmeans_init=True,
                                         kmeans_iters=10,
                                         codebook_dim=vq_codebook_dim)
            else:
                # Honour vq_codebook_dim here too (it was hard-coded to 8,
                # which equals the default, so default behaviour is unchanged).
                self.vq = ResidualVQ(dim=channels[-1],
                                     num_quantizers=rvq_tokens,
                                     codebook_size=codebook_size,
                                     codebook_dim=vq_codebook_dim,
                                     kmeans_init=True,
                                     kmeans_iters=10)

            emb_channels = channels[-1]
        else:
            emb_channels = latent_dim

        self.down_sampler = DownSampler(patch_channels,
                                        channels,
                                        emb_channels,
                                        resnet_count,
                                        use_norm=use_norm,
                                        num_transformer_blocks=num_transformer_blocks,
                                        is_causal=is_causal,
                                        lookahead_tokens=lookahead_tokens,
                                        use_rope=use_rope)
        self.up_sampler = UpSampler(patch_channels,
                                    channels,
                                    emb_channels,
                                    resnet_count,
                                    use_norm=use_norm,
                                    num_transformer_blocks=num_transformer_blocks,
                                    is_causal=is_causal,
                                    lookahead_tokens=lookahead_tokens,
                                    use_rope=use_rope)

    def _token_axis(self):
        """Return the axis that indexes tokens in ``encode`` output / ``decode`` input."""
        # Quantised codes are laid out (batch, tokens, ...); the continuous
        # latent keeps the encoder layout (batch, channels, tokens).
        return 1 if (self.use_fsq or self.use_vq) else 2

    def forward(self, x):
        """Reconstruct ``x``.

        Returns:
            The reconstruction, or in VQ/RVQ mode a tuple of the
            reconstruction and the mean quantisation loss.
        """
        x = self.down_sampler(x)
        if self.use_fsq:
            x = rearrange(x, "b c n -> b n c")
            x = self.fsq(x)
            x = rearrange(x, "b n c -> b c n")
        elif self.use_vq:
            x = rearrange(x, "b c n -> b n c")
            x, _, quant_loss = self.vq(x)
            x = rearrange(x, "b n c -> b c n")
            x = self.up_sampler(x)
            return x, torch.mean(quant_loss)
        x = self.up_sampler(x)
        return x

    def get_input_length(self):
        """Return the expected flat input length, ``2 * signal_length``."""
        return (self.signal_length * 2)

    def get_token_count(self):
        """Return the number of latent positions produced for one input."""
        return (self.signal_length * 2) // (self.patch_channels * 2 ** (len(self.channels) - 1))

    def encode(self, x):
        """Encode ``x`` to code indices (FSQ/VQ/RVQ) or to the continuous latent.

        The token count is checked on axis 1 for quantised codes and on
        axis 2 for the continuous latent, whose axis 1 holds channels.
        """
        assert x.shape[1] == self.get_input_length()
        x = self.down_sampler(x)
        if self.use_fsq:
            x = rearrange(x, "b c n -> b n c")
            x = self.fsq(x)
            x = self.fsq.codes_to_indices(x)
        elif self.use_vq:
            x = rearrange(x, "b c n -> b n c")
            _, x, _ = self.vq(x)
        assert x.shape[self._token_axis()] == self.get_token_count()
        return x

    def decode(self, x):
        """Decode the output of ``encode`` back to a flat signal.

        The token count is checked on axis 1 for quantised codes and on
        axis 2 for the continuous latent, whose axis 1 holds channels.
        """
        assert x.shape[self._token_axis()] == self.get_token_count()
        if self.use_fsq:
            x = self.fsq.indices_to_codes(x)
            x = rearrange(x, "b n c -> b c n")
        elif self.use_vq:
            x = self.vq.get_output_from_indices(x)
            x = rearrange(x, "b n c -> b c n")
        x = self.up_sampler(x)
        assert x.shape[1] == self.get_input_length()
        return x


if __name__ == "__main__":
    # Smoke test: push a random batch through an encoder/decoder pair built
    # with the same settings and print the tensor shape at each stage.
    shared_kwargs = dict(patch_channels=16, channels=[32, 32, 32, 32], emb_channels=8, resnet_count=2)
    encoder = DownSampler(**shared_kwargs)
    decoder = UpSampler(**shared_kwargs)

    signal = torch.randn((3, 40960))
    print("inp", signal.shape)

    latent = encoder(signal)
    print("emb", latent.shape)

    reconstruction = decoder(latent)
    print("out", reconstruction.shape)
