import math
from typing import Iterable, Optional, Type, Union

import torch
from torch import nn


class PositionFeedForward(nn.Module):
    """Linear projection applied independently at every sequence position.

    Shape:
        Input: (N, L, d_in)
        Output: (N, L, d_out)

    With ``rank=None`` the projection is a kernel-size-1 ``nn.Conv1d``.
    Otherwise a freshly initialised ``nn.Linear(d_in, d_out)`` weight ``W`` is
    factorised by truncated SVD into ``u`` (d_out, r) and ``v`` (r, d_in), with
    ``r = min(rank, d_in, d_out)``, so that the effective weight is ``u @ v``.
    The bias of that linear layer is kept as-is.
    """

    def __init__(self, d_in, d_out, rank=None):
        super().__init__()
        if rank is None:
            self.conv = nn.Conv1d(d_in, d_out, 1)
            self.factorized = False
        else:
            layer = nn.Linear(d_in, d_out)
            w = layer.weight.data  # (d_out, d_in)
            self.bias = layer.bias
            # torch.svd is deprecated; torch.linalg.svd returns V^H directly,
            # and full_matrices=False matches torch.svd's reduced default.
            u, s, vh = torch.linalg.svd(w, full_matrices=False)
            # Split each kept singular value evenly between the two factors.
            s = torch.diag(s[:rank].sqrt())
            self.u = nn.Parameter(u[:, :rank] @ s)
            self.v = nn.Parameter(s @ vh[:rank])
            self.factorized = True

    def forward(self, x):
        """
        :param x: (N, L, d_in)
        :return: (N, L, d_out)
        """
        if self.factorized:
            w = self.u @ self.v
            return x @ w.t() + self.bias
        else:
            # Conv1d expects channels first, so swap the length/channel axes around it.
            return self.conv(x.transpose(1, 2)).transpose(1, 2)


class SeqCNN(nn.Module):
    """Sequence-level energy model combining a global embedding with a residue-level CNN.

    Two embeddings are computed and concatenated:

    * a sequence-level embedding from a linear layer over the flattened input;
    * a residue-level embedding: every position is concatenated with a
      sinusoidal positional code, projected to ``hidden_features``, passed
      through a stack of unpadded 1-D convolutions and mean-pooled over the
      positions that remain.

    An MLP maps the concatenated vector to a scalar energy.

    Shape:
        Input: (N, chain_length, vocab_size)
        Output: ``(energy, h)`` with ``energy`` of shape (N,) and ``h`` of shape
        (N, 2 * hidden_features)

    Because the convolutions are unpadded, ``chain_length`` must exceed
    ``sum(k - 1 for k in kernel_sizes)`` so at least one position survives.

    ``activation`` may be an ``nn.Module`` subclass (instantiated here, as with
    the default ``nn.SiLU``) or an already-constructed module instance.
    """

    def __init__(
        self,
        chain_length: int,
        vocab_size: int = 21,
        n_positional: int = 20,
        hidden_features: int = 128,
        kernel_sizes: Iterable[int] = (15, 5, 3),
        activation: Union[nn.Module, Type[nn.Module]] = nn.SiLU,
    ):
        super().__init__()

        self.n_positional = n_positional
        self.chain_length = chain_length
        self.vocab_size = vocab_size
        self.hidden_features = hidden_features
        self.kernel_sizes = kernel_sizes

        self.input_seq = nn.Linear(vocab_size * chain_length, hidden_features)
        self.input_aa = nn.Linear(vocab_size + n_positional, hidden_features)
        # nn.Sequential and forward() both need a module *instance*; a bare class
        # (like the default nn.SiLU) would make construction raise TypeError.
        if isinstance(activation, type):
            activation = activation()
        self.activation = activation
        self.conv_layers = nn.ModuleList(
            nn.Conv1d(hidden_features, hidden_features, kernel_size=k_size, stride=1, padding=0)
            for k_size in kernel_sizes
        )

        self.output_seq = nn.Sequential(
            nn.Linear(2 * hidden_features, hidden_features),
            self.activation,
            nn.Linear(hidden_features, 1),
        )

    def forward(self, x):
        """
        :param x: (N, chain_length, vocab_size)
        :return: ``(energy, h)`` — energy (N,), joint embedding h (N, 2 * hidden_features)
        """
        batch = x.shape[0]

        # Sequence-level embedding of the flattened input.
        z_seq = self.input_seq(x.reshape(batch, self.vocab_size * self.chain_length))

        # Residue-level embedding: append the same positional table to every sample.
        # Match dtype as well as device so torch.cat works for non-float32 inputs.
        p = positionalencoding1d(d_model=self.n_positional, length=x.shape[1])
        p = p.to(device=x.device, dtype=x.dtype).unsqueeze(0).expand(batch, -1, -1)
        z_aa = self.activation(self.input_aa(torch.cat((x, p), dim=2)))

        # Conv1d wants channels first: (N, hidden_features, L).
        z_aa = z_aa.permute(0, 2, 1)
        for conv_layer in self.conv_layers:
            z_aa = self.activation(conv_layer(z_aa))

        # Unpadded convolutions shrink the length; average over what remains.
        z_aa_seq = torch.mean(z_aa, dim=2)

        # Joint embedding.
        h = torch.cat((z_seq, z_aa_seq), dim=-1)
        energy = self.output_seq(h).squeeze(dim=-1)

        return energy, h


def positionalencoding1d(d_model: int, length: int):
    """Build a sinusoidal positional-encoding table.

    Column ``2i`` holds ``sin(pos * f_i)`` and column ``2i + 1`` holds
    ``cos(pos * f_i)``, where ``f_i = 10000 ** (-2i / d_model)``.

    :param d_model: encoding width; must be even
    :param length: number of positions
    :return: tensor of shape (length, d_model)
    :raises ValueError: if ``d_model`` is odd
    adapted from https://github.com/wzlxjtu/PositionalEncoding2D
    """
    if d_model % 2:
        raise ValueError(
            "Cannot use sin/cos positional encoding with odd dim (got dim={:d})".format(d_model)
        )

    # One frequency per (sin, cos) column pair.
    rates = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float) * -(math.log(10000.0) / d_model)
    )
    # (length, 1) * (d_model / 2,) -> (length, d_model / 2)
    angles = torch.arange(0, length).unsqueeze(1).float() * rates

    table = torch.zeros(length, d_model)
    table[:, 0::2] = angles.sin()
    table[:, 1::2] = angles.cos()
    return table

class SeqCNN1d(nn.Module):
    """CNN over residues with either a sequence-level or a per-residue output head.

    Every position of ``x`` is projected to ``hidden_features`` and passed
    through a stack of 1-D convolutions. Then:

    * ``aggregate=True``: a linear embedding of the flattened input is
      concatenated with the mean of the (unpadded) convolution features, giving
      one output per sequence.
    * ``aggregate=False``: a per-position linear embedding of ``x`` is
      concatenated with the per-position convolution features, giving one output
      per residue. Here the convolutions use ``padding="same"`` so the length
      is preserved and the two tensors line up.

    Shape:
        Input: (N, L, n_tokens); ``L == chain_length`` when ``aggregate=True``
        Output: (N, num_classes) if aggregate else (N, L, num_classes), with the
        trailing axis squeezed away when ``num_classes == 1``

    ``activation`` may be an ``nn.Module`` subclass (instantiated here, as with
    the default ``nn.ReLU``) or an already-constructed module instance.
    """

    def __init__(
        self,
        chain_length: int,
        n_tokens: int = 21,
        hidden_features: int = 128,
        kernel_sizes: Iterable[int] = (15, 5, 3),
        num_classes: Optional[int] = 1,
        activation: Union[nn.Module, Type[nn.Module]] = nn.ReLU,
        aggregate: Optional[bool] = False,
    ):
        super().__init__()

        self.chain_length = chain_length
        self.n_tokens = n_tokens
        self.hidden_features = hidden_features
        self.kernel_sizes = kernel_sizes
        self.num_classes = num_classes
        self.aggregate = aggregate

        self.input_aa = nn.Linear(n_tokens, hidden_features)
        # nn.Sequential and forward() both need a module *instance*; a bare class
        # (like the default nn.ReLU) would make construction raise TypeError.
        if isinstance(activation, type):
            activation = activation()
        self.activation = activation

        # Per-residue outputs concatenate conv features with a length-L embedding,
        # so those convolutions must preserve length. The aggregate path keeps the
        # original unpadded convolutions (its output is mean-pooled anyway).
        padding = 0 if self.aggregate else "same"
        self.conv_layers = nn.ModuleList(
            nn.Conv1d(hidden_features, hidden_features, kernel_size=k_size, stride=1, padding=padding)
            for k_size in kernel_sizes
        )

        if self.aggregate:
            self.input_seq = nn.Linear(n_tokens * chain_length, hidden_features)
        else:
            self.input_seq = nn.Linear(n_tokens, hidden_features)

        self.output_seq = nn.Sequential(
            nn.Linear(2 * hidden_features, hidden_features),
            self.activation,
            nn.Linear(hidden_features, self.num_classes),
        )

    def forward(self, x):
        """
        :param x: (N, L, n_tokens)
        :return: (N[, L][, num_classes]) — see the class docstring
        """
        z_aa = self.activation(self.input_aa(x))

        # Conv1d wants channels first: (N, hidden_features, L).
        z_aa = z_aa.permute(0, 2, 1)
        for conv_layer in self.conv_layers:
            z_aa = self.activation(conv_layer(z_aa))

        if self.aggregate:
            # Sequence-level embedding joined with the pooled conv features.
            z_seq = self.input_seq(x.reshape(x.shape[0], self.n_tokens * self.chain_length))
            z_aa_seq = torch.mean(z_aa, dim=2)
            h = torch.cat((z_seq, z_aa_seq), dim=-1)
        else:
            # Residue-level embedding joined with per-position conv features.
            z_seq = self.input_seq(x)
            h = torch.cat((z_seq, z_aa.permute(0, 2, 1)), dim=-1)

        return self.output_seq(h).squeeze(dim=-1)


class CNNTokenizer(nn.Module):
    """Stack of ByteNet blocks that embeds per-position token features into ``d_model`` channels.

    One block is built per entry of ``kernel_sizes`` (all with dilation 1 and no
    low-rank factorisation). The first block maps ``n_tokens`` input channels to
    ``d_model``; every later block maps ``d_model`` to ``d_model``.

    Shape:
        Input: (N, L, n_tokens)
        Output: (N, L', d_model) — L' == L presumably for odd kernel sizes; confirm
    """

    def __init__(
        self,
        n_tokens: int = 21,
        d_model: int = 256,
        kernel_sizes: Iterable[int] = (15, 5, 3)
    ):
        super().__init__()
        blocks = []
        in_features = n_tokens
        for kernel_size in kernel_sizes:
            blocks.append(
                ByteNetBlock(
                    in_features,
                    d_model,
                    d_model,
                    kernel_size,
                    dilation=1,
                    rank=None,
                )
            )
            # After the first block every block sees d_model channels.
            in_features = d_model

        self.layers = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor, input_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run ``x`` through every block in order, forwarding the same ``input_mask`` to each."""
        out = x
        for block in self.layers:
            out = block(out, input_mask=input_mask)
        return out


class MaskedConv1d(nn.Conv1d):
    """A masked 1-dimensional convolution layer operating on channels-last input.

    Takes the same arguments as torch.nn.Conv1d, except that the padding is set
    automatically to ``dilation * (kernel_size - 1) // 2``, which preserves the
    length for odd kernel sizes with stride 1.

    Shape:
        Input: (N, L, in_channels)
        input_mask: (N, L) or (N, L, 1) boolean, optional; True marks positions
            (e.g. padding) whose features are zeroed before the convolution
        Output: (N, L, out_channels)
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
    ):
        """
        :param in_channels: input channels
        :param out_channels: output channels
        :param kernel_size: the kernel width
        :param stride: filter shift
        :param dilation: dilation factor
        :param groups: perform depth-wise convolutions
        :param bias: adds learnable bias to output
        """
        padding = dilation * (kernel_size - 1) // 2
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding=padding,
        )

    def forward(self, x, input_mask: Optional[torch.Tensor] = None):
        """
        :param x: (N, L, in_channels)
        :param input_mask: optional boolean (N, L) or (N, L, 1); True positions are zeroed
        :return: (N, L, out_channels)
        """
        if input_mask is not None:
            # Broadcast an (N, L) mask across the channel axis; an (N, L, 1)
            # mask already broadcasts.
            if input_mask.dim() < x.dim():
                input_mask = input_mask.unsqueeze(-1)
            # Out-of-place: the in-place masked_fill_ mutated the caller's tensor
            # and broke autograd when x was saved for backward (e.g. a ReLU output).
            x = x.masked_fill(input_mask, 0.0)
        return super().forward(x.transpose(1, 2)).transpose(1, 2)


class ByteNetBlock(nn.Module):
    """Residual block from ByteNet paper (https://arxiv.org/abs/1610.10099).

    Pipeline: LayerNorm -> ReLU -> position-wise projection (d_in -> d_h)
    -> LayerNorm -> ReLU -> masked convolution over the length axis
    -> LayerNorm -> ReLU -> position-wise projection (d_h -> d_out).
    The input is added back to the result only when ``d_in == d_out``.

    Shape:
       Input: (N, L, d_in)
       input_mask: (N, L) boolean, optional; True positions are zeroed before the convolution
       Output: (N, L, d_out)

    """

    def __init__(
        self, d_in, d_h, d_out, kernel_size, dilation=1, groups=1,
        rank=None
    ):
        super().__init__()
        # Built before the projections so parameter initialisation draws from
        # the RNG in the same order as before.
        self.conv = MaskedConv1d(
            d_h, d_h, kernel_size=kernel_size, dilation=dilation, groups=groups
        )

        self.res_connection = d_in == d_out

        self.sequence1 = nn.Sequential(
            nn.LayerNorm(d_in),
            nn.ReLU(),
            PositionFeedForward(d_in, d_h, rank=rank),
            nn.LayerNorm(d_h),
            nn.ReLU(),
        )
        self.sequence2 = nn.Sequential(
            nn.LayerNorm(d_h),
            nn.ReLU(),
            PositionFeedForward(d_h, d_out, rank=rank),
        )

    def forward(self, x, input_mask=None):
        """
        :param x: (batch, length, in_channels)
        :param input_mask: optional boolean (batch, length); True positions are zeroed before the convolution
        :return: (batch, length, out_channels)
        """
        hidden = self.sequence1(x)
        hidden = self.conv(hidden, input_mask=input_mask)
        out = self.sequence2(hidden)
        return x + out if self.res_connection else out



class Transformer(nn.Module):
    """
    Encoder-decoder transformer over a tokenised sequence.

    Pipeline: CNNTokenizer -> add sinusoidal positional encoding -> nn.Transformer
    (the same tensor is used as source and target) -> linear head.

    When the head has a single output (``return_energies=True``, or
    ``num_classes == 1``), the per-position features are mean-pooled over
    non-padded positions first and the result is (N, 1). Otherwise the result
    is per-position: (N, L, output_shape).

    Shape:
        x: (N, L, n_tokens + context_tokens) with L <= length
        mask: (N, L) key padding mask, optional; True marks padded positions
    """

    def __init__(
        self,
        n_tokens: int = 512,
        d_model: int = 256,
        n_head: int = 8,
        dim_feedforward: int = 2048,
        encoder_depth: int = 3,
        decoder_depth: int = 3,
        activation: str = "gelu",
        length: int = 301,
        return_energies: bool = False,
        context_tokens: int = 0,
        num_classes: int = 21,
    ):
        super().__init__()
        # (length, d_model) table registered as a buffer so it follows .to()/.cuda().
        self.register_buffer("positional_encoding", positionalencoding1d(d_model, length))
        self.tokenizer = CNNTokenizer(n_tokens=n_tokens + context_tokens, d_model=d_model)

        self.length = length
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=n_head,
            num_encoder_layers=encoder_depth,
            num_decoder_layers=decoder_depth,
            dim_feedforward=dim_feedforward,
            dropout=0.0,
            activation=activation,
            batch_first=True,
        )

        self.num_classes = num_classes
        self.mlp = nn.Linear(d_model,
                             1 if return_energies else self.num_classes)

    @property
    def output_shape(self) -> int:
        """Number of features produced by the linear head."""
        return self.mlp.weight.size(0)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None,
                attention_mask: torch.Tensor = None) -> torch.Tensor:
        """
        :param x: (batch, L, n_tokens + context_tokens) with L <= self.length
        :param mask: optional key padding mask (batch, L); True (or non-zero) marks padding
        :param attention_mask: accepted for interface compatibility; currently unused
        :return: (batch, 1) when output_shape == 1, else (batch, L, output_shape)
        """
        x = self.tokenizer(x)
        # Slice the precomputed table so sequences shorter than `length` broadcast.
        x = x + self.positional_encoding[: x.shape[1]]

        # Memory is the encoder output of the same padded sequence, so the
        # decoder's cross-attention must ignore the padded positions too.
        x = self.transformer(
            x,
            x,
            src_key_padding_mask=mask,
            tgt_key_padding_mask=mask,
            memory_key_padding_mask=mask,
        )
        if self.output_shape == 1:
            if mask is None:
                x = torch.mean(x, dim=1)
            else:
                # Masked mean: padded positions (True / non-zero) do not contribute.
                keep = (mask == 0).unsqueeze(-1).to(x.dtype)
                x = (x * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1.0)
        return self.mlp(x)
