import torch
import numpy as np
from model.util import sanitize_sacred_arguments

class SequenceTransformer(torch.nn.Module):
    """
    Time-dependent transformer for sequence data. Each layer concatenates the
    current token representations with a layer-specific time embedding and a
    fixed sinusoidal positional encoding, passes the result through a small
    MLP, and then applies multi-head self-attention.
    """

    def __init__(
        self, input_dim, t_limit=1, att_hidden_dims=[128, 128, 128],
        att_num_heads=[8, 8, 8], mlp_hidden_dim=64, time_embed_std=30,
        embed_size=256, pos_enc_size=64
    ):
        """
        Initialize a time-dependent transformer for sequence data.
        Arguments:
            `input_dim`: dimension of input data, D
            `t_limit`: maximum time horizon
            `att_hidden_dims`: hidden dimension of attention layers (any
                sequence of ints, one entry per layer; it is never mutated)
            `att_num_heads`: number of heads in each attention layer; must
                have the same length as `att_hidden_dims`
            `mlp_hidden_dim`: recorded in `creation_args` but not used by this
                class; the MLP before attention layer i has width
                `att_hidden_dims[i]`
            `time_embed_std`: standard deviation of random weights to sample for
                time embeddings
            `embed_size`: size of the time embeddings, E; must be even
            `pos_enc_size`: size of positional encodings, P
        """
        super().__init__()

        # Snapshot the constructor arguments before any other local variable
        # exists, so helper locals such as `num_att_layers` are not recorded
        # as if they were constructor arguments
        self.creation_args = dict(locals())
        del self.creation_args["self"]
        del self.creation_args["__class__"]
        self.creation_args = sanitize_sacred_arguments(self.creation_args)

        assert embed_size % 2 == 0
        assert len(att_hidden_dims) == len(att_num_heads)
        num_att_layers = len(att_hidden_dims)

        self.t_limit = t_limit
        self.pos_enc_size = pos_enc_size

        # Random embedding layer for time; the random weights are set at the
        # start and are not trainable
        self.time_embed_rand_weights = torch.nn.Parameter(
            torch.randn(embed_size // 2) * time_embed_std,
            requires_grad=False
        )

        # Per-layer dense layers that turn the shared time features into a
        # layer-specific time embedding of size E
        self.time_dense_layers = torch.nn.ModuleList([
            torch.nn.Sequential(
                torch.nn.Linear(embed_size, embed_size),
                torch.nn.Sigmoid(),
                torch.nn.Linear(embed_size, embed_size)
            ) for _ in range(num_att_layers)
        ])

        # Dense layers before each attention layer; layer i maps the
        # concatenation [tokens, time embedding, positional encoding] to width
        # `att_hidden_dims[i]`
        hidden_dims = [input_dim] + list(att_hidden_dims)
        self.pre_att_dense_layers = torch.nn.ModuleList([
            torch.nn.Sequential(
                torch.nn.Linear(
                    hidden_dims[i] + embed_size + pos_enc_size,
                    hidden_dims[i + 1]
                ),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_dims[i + 1], hidden_dims[i + 1]),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_dims[i + 1], hidden_dims[i + 1]),
                torch.nn.LayerNorm(hidden_dims[i + 1])
            ) for i in range(num_att_layers)
        ])

        # Attention layers
        self.att_layers = torch.nn.ModuleList([
            torch.nn.MultiheadAttention(
                att_hidden_dims[i], att_num_heads[i], batch_first=True
            ) for i in range(num_att_layers)
        ])
        self.post_att_layers = torch.nn.ModuleList([
            torch.nn.Sequential(
                torch.nn.LayerNorm(att_hidden_dims[i]),
                torch.nn.Dropout(0.1)
            ) for i in range(num_att_layers)
        ])

        # Final dense projector back to the input dimension
        self.final_dense = torch.nn.Linear(att_hidden_dims[-1], input_dim)

    @staticmethod
    def swish(x):
        """
        Swish activation, x * sigmoid(x). Defined as a method rather than a
        lambda attribute so that the module stays picklable.
        """
        return x * torch.sigmoid(x)

    def _get_positional_encoding(self, seq_len, seq_dim, device=None):
        """
        Computes a sinusoidal positional encoding for a sequence of tokens.
        Arguments:
            `seq_len`: number of tokens, L'
            `seq_dim`: dimension of each encoding, D' (may be odd)
            `device`: device on which to build the encoding (default: the
                default torch device)
        Returns an L' x D' float tensor of encodings, to be concatenated with
        the token representations. Column 2i holds
        sin(pos / 10000^(2i / D')) and column 2i + 1 holds the cosine of the
        same argument.
        """
        base = 1e4

        positions = torch.arange(seq_len, dtype=torch.float32, device=device)
        even_dims = torch.arange(
            0, seq_dim, 2, dtype=torch.float32, device=device
        )
        # Shape: L' x ceil(D' / 2)
        trig_arg = positions[:, None] / torch.pow(base, even_dims / seq_dim)[None]

        pos_enc = torch.empty((seq_len, seq_dim), device=device)
        pos_enc[:, 0::2] = torch.sin(trig_arg)
        # With odd D' there is one fewer cosine column than sine column
        pos_enc[:, 1::2] = torch.cos(trig_arg[:, :seq_dim // 2])
        return pos_enc

    def forward(self, xt, t, mask=None):
        """
        Forward pass of the network.
        Arguments:
            `xt`: B x L x D tensor containing the sequences to train on
            `t`: B-tensor containing the times to train the network for each
                input
            `mask`: B x L boolean tensor denoting which positions are masked
                (passed as `key_padding_mask` to every attention layer)
        Returns a B x L x D tensor which consists of the prediction.
        """
        batch_size, seq_len = xt.shape[0], xt.shape[1]

        # Get the time embeddings for `t`
        # We embed the time as cos((t/T) * (2pi) * z) and sin((t/T) * (2pi) * z)
        time_embed_args = (t[:, None] / self.t_limit) * (2 * np.pi) * \
            self.time_embed_rand_weights[None, :]  # Shape: B x (E / 2)
        time_embed = self.swish(torch.cat(
            [torch.sin(time_embed_args), torch.cos(time_embed_args)], dim=1
        ))  # Shape: B x E

        # Positional encodings, built directly on the input's device
        pos_enc = self._get_positional_encoding(
            seq_len, self.pos_enc_size, device=xt.device
        )
        pos_enc = pos_enc[None].expand(batch_size, -1, -1)  # Shape: B x L x P

        x = xt
        for time_dense, pre_att_dense, att_layer, post_att in zip(
            self.time_dense_layers, self.pre_att_dense_layers,
            self.att_layers, self.post_att_layers
        ):
            # Layer-specific time embedding, broadcast over positions
            layer_time_embed = time_dense(time_embed)[:, None].expand(
                -1, seq_len, -1
            )  # Shape: B x L x E

            # Concatenate time embedding and positional encoding, then MLP
            x = torch.cat([x, layer_time_embed, pos_enc], dim=2)
            x = pre_att_dense(x)

            # Self-attention, ignoring padded keys
            x = att_layer(
                x, x, x, key_padding_mask=mask, need_weights=False
            )[0]
            x = post_att(x)

        # Final dense
        return self.final_dense(x)

    def loss(self, pred_values, true_values, weights=None, mask=None):
        """
        Computes the loss of the neural network.
        Arguments:
            `pred_values`: a B x L x D tensor of predictions from the network
            `true_values`: a B x L x D tensor of true values to predict
            `weights`: if provided, a tensor broadcastable with B x L x D; the
                squared error is divided by it before averaging
            `mask`: B x L boolean tensor that is True at positions masked by
                padding, which are excluded from the loss; if None, every
                position counts
        Returns a scalar loss: the squared error averaged across the D
        dimension, then averaged across all unmasked (batch, position) pairs.
        If every position is masked, the loss is 0.
        """
        squared_error = torch.square(true_values - pred_values)
        if weights is not None:
            squared_error = squared_error / weights

        mean_error = torch.mean(squared_error, dim=2)  # Shape: B x L
        if mask is None:
            return torch.mean(mean_error)

        # Zero out padded positions (out of place) and average over the rest;
        # clamp the count so an all-padding batch gives 0 instead of 0/0 = NaN
        mask = mask.to(torch.bool)
        mean_error = mean_error.masked_fill(mask, 0)
        num_unmasked = torch.sum(torch.logical_not(mask)).clamp(min=1)
        return torch.sum(mean_error) / num_unmasked


class SequenceTransformer2(torch.nn.Module):
    """
    Time-dependent transformer for sequence data. It projects the
    concatenation of tokens, a time embedding and a sinusoidal positional
    encoding into the attention width once. It then runs a stack of
    `torch.nn.TransformerEncoderLayer`s and projects back to the input
    dimension.
    """

    def __init__(
        self, input_dim, t_limit=1, num_att_layers=3, att_hidden_dim=128,
        att_num_heads=8, att_mlp_hidden_dim=64, time_embed_std=30,
        embed_size=256, pos_enc_size=64
    ):
        """
        Initialize a time-dependent transformer for sequence data.
        Arguments:
            `input_dim`: dimension of input data, D
            `t_limit`: maximum time horizon
            `num_att_layers`: number of attention layers
            `att_hidden_dim`: hidden dimension of attention layers
            `att_num_heads`: number of heads in each attention layer
            `att_mlp_hidden_dim`: hidden dimension of the feedforward MLP in
                each attention layer (`dim_feedforward`)
            `time_embed_std`: standard deviation of random weights to sample for
                time embeddings
            `embed_size`: size of the time embeddings, E; must be even
            `pos_enc_size`: size of positional encodings, P
        """
        super().__init__()

        # Record exactly the constructor arguments (no other locals exist yet)
        self.creation_args = dict(locals())
        del self.creation_args["self"]
        del self.creation_args["__class__"]
        self.creation_args = sanitize_sacred_arguments(self.creation_args)

        assert embed_size % 2 == 0

        self.t_limit = t_limit
        self.pos_enc_size = pos_enc_size

        # Random embedding layer for time; the random weights are set at the
        # start and are not trainable
        self.time_embed_rand_weights = torch.nn.Parameter(
            torch.randn(embed_size // 2) * time_embed_std,
            requires_grad=False
        )

        # Dense layers turning the random time features into a time embedding
        self.time_dense_layers = torch.nn.Sequential(
            torch.nn.Linear(embed_size, embed_size),
            torch.nn.Sigmoid(),
            torch.nn.Linear(embed_size, embed_size)
        )

        # Projects [tokens, time embedding, positional encoding] into the
        # attention width
        self.init_dense = torch.nn.Linear(
            input_dim + embed_size + pos_enc_size, att_hidden_dim
        )

        encoder = torch.nn.TransformerEncoderLayer(
            att_hidden_dim, att_num_heads, dim_feedforward=att_mlp_hidden_dim,
            batch_first=True
        )
        self.transformer = torch.nn.TransformerEncoder(encoder, num_att_layers)

        self.final_dense = torch.nn.Linear(att_hidden_dim, input_dim)

    @staticmethod
    def swish(x):
        """
        Swish activation, x * sigmoid(x). Defined as a method rather than a
        lambda attribute so that the module stays picklable.
        """
        return x * torch.sigmoid(x)

    def _get_positional_encoding(self, seq_len, seq_dim, device=None):
        """
        Computes a sinusoidal positional encoding for a sequence of tokens.
        Arguments:
            `seq_len`: number of tokens, L'
            `seq_dim`: dimension of each encoding, D' (may be odd)
            `device`: device on which to build the encoding (default: the
                default torch device)
        Returns an L' x D' float tensor of encodings, to be concatenated with
        the token representations. Column 2i holds
        sin(pos / 10000^(2i / D')) and column 2i + 1 holds the cosine of the
        same argument.
        """
        base = 1e4

        positions = torch.arange(seq_len, dtype=torch.float32, device=device)
        even_dims = torch.arange(
            0, seq_dim, 2, dtype=torch.float32, device=device
        )
        # Shape: L' x ceil(D' / 2)
        trig_arg = positions[:, None] / torch.pow(base, even_dims / seq_dim)[None]

        pos_enc = torch.empty((seq_len, seq_dim), device=device)
        pos_enc[:, 0::2] = torch.sin(trig_arg)
        # With odd D' there is one fewer cosine column than sine column
        pos_enc[:, 1::2] = torch.cos(trig_arg[:, :seq_dim // 2])
        return pos_enc

    def forward(self, xt, t, mask=None):
        """
        Forward pass of the network.
        Arguments:
            `xt`: B x L x D tensor containing the sequences to train on
            `t`: B-tensor containing the times to train the network for each
                input
            `mask`: B x L boolean tensor denoting which positions are masked
                (passed as `src_key_padding_mask` to the encoder)
        Returns a B x L x D tensor which consists of the prediction.
        """
        batch_size, seq_len = xt.shape[0], xt.shape[1]

        # Get the time embeddings for `t`
        # We embed the time as cos((t/T) * (2pi) * z) and sin((t/T) * (2pi) * z)
        time_embed_args = (t[:, None] / self.t_limit) * (2 * np.pi) * \
            self.time_embed_rand_weights[None, :]  # Shape: B x (E / 2)
        time_embed = self.swish(torch.cat(
            [torch.sin(time_embed_args), torch.cos(time_embed_args)], dim=1
        ))  # Shape: B x E
        time_embed = self.time_dense_layers(time_embed)
        time_embed = time_embed[:, None].expand(-1, seq_len, -1)  # B x L x E

        # Positional encodings, built directly on the input's device
        pos_enc = self._get_positional_encoding(
            seq_len, self.pos_enc_size, device=xt.device
        )
        pos_enc = pos_enc[None].expand(batch_size, -1, -1)  # Shape: B x L x P

        x = self.init_dense(torch.cat([xt, time_embed, pos_enc], dim=2))
        x = self.transformer(x, src_key_padding_mask=mask)
        return self.final_dense(x)

    def loss(self, pred_values, true_values, weights=None, mask=None):
        """
        Computes the loss of the neural network.
        Arguments:
            `pred_values`: a B x L x D tensor of predictions from the network
            `true_values`: a B x L x D tensor of true values to predict
            `weights`: if provided, a tensor broadcastable with B x L x D; the
                squared error is divided by it before averaging
            `mask`: B x L boolean tensor that is True at positions masked by
                padding, which are excluded from the loss; if None, every
                position counts
        Returns a scalar loss: the squared error averaged across the D
        dimension, then averaged across all unmasked (batch, position) pairs.
        If every position is masked, the loss is 0.
        """
        squared_error = torch.square(true_values - pred_values)
        if weights is not None:
            squared_error = squared_error / weights

        mean_error = torch.mean(squared_error, dim=2)  # Shape: B x L
        if mask is None:
            return torch.mean(mean_error)

        # Zero out padded positions (out of place) and average over the rest;
        # clamp the count so an all-padding batch gives 0 instead of 0/0 = NaN
        mask = mask.to(torch.bool)
        mean_error = mean_error.masked_fill(mask, 0)
        num_unmasked = torch.sum(torch.logical_not(mask)).clamp(min=1)
        return torch.sum(mean_error) / num_unmasked
       

class SequenceTransformer2_y(torch.nn.Module):
    """
    Transformer that maps a sequence to a single score in (0, 1). Tokens plus
    a sinusoidal positional encoding are projected into the attention width
    and encoded by a `torch.nn.TransformerEncoder`. The outputs at non-padded
    positions are summed, and a final linear layer followed by a sigmoid
    produces the score.
    """

    def __init__(
        self, input_dim, num_att_layers=3, att_hidden_dim=128,
        att_num_heads=8, att_mlp_hidden_dim=64, time_embed_std=30,
        embed_size=256, pos_enc_size=64
    ):
        """
        Initialize a sequence-level transformer predictor.
        Arguments:
            `input_dim`: dimension of input data, D
            `num_att_layers`: number of attention layers
            `att_hidden_dim`: hidden dimension of attention layers, H
            `att_num_heads`: number of heads in each attention layer
            `att_mlp_hidden_dim`: hidden dimension of the feedforward MLP in
                each attention layer (`dim_feedforward`)
            `time_embed_std`: recorded in `creation_args` but not used by this
                class
            `embed_size`: recorded in `creation_args` but not used by this
                class
            `pos_enc_size`: size of positional encodings, P
        """
        super().__init__()

        # Record exactly the constructor arguments (no other locals exist yet)
        self.creation_args = dict(locals())
        del self.creation_args["self"]
        del self.creation_args["__class__"]
        self.creation_args = sanitize_sacred_arguments(self.creation_args)

        self.pos_enc_size = pos_enc_size

        # Projects [tokens, positional encoding] into the attention width
        self.init_dense = torch.nn.Linear(
            input_dim + pos_enc_size, att_hidden_dim
        )

        encoder = torch.nn.TransformerEncoderLayer(
            att_hidden_dim, att_num_heads, dim_feedforward=att_mlp_hidden_dim,
            batch_first=True
        )
        self.transformer = torch.nn.TransformerEncoder(encoder, num_att_layers)

        # Maps the pooled B x H representation to one logit per sequence
        self.final_dense = torch.nn.Linear(att_hidden_dim, 1)

    @staticmethod
    def swish(x):
        """
        Swish activation, x * sigmoid(x). Not used by `forward`; kept as an
        attribute for backward compatibility. Defined as a method rather than
        a lambda so the module stays picklable.
        """
        return x * torch.sigmoid(x)

    def _get_positional_encoding(self, seq_len, seq_dim, device=None):
        """
        Computes a sinusoidal positional encoding for a sequence of tokens.
        Arguments:
            `seq_len`: number of tokens, L'
            `seq_dim`: dimension of each encoding, D' (may be odd)
            `device`: device on which to build the encoding (default: the
                default torch device)
        Returns an L' x D' float tensor of encodings, to be concatenated with
        the token representations. Column 2i holds
        sin(pos / 10000^(2i / D')) and column 2i + 1 holds the cosine of the
        same argument.
        """
        base = 1e4

        positions = torch.arange(seq_len, dtype=torch.float32, device=device)
        even_dims = torch.arange(
            0, seq_dim, 2, dtype=torch.float32, device=device
        )
        # Shape: L' x ceil(D' / 2)
        trig_arg = positions[:, None] / torch.pow(base, even_dims / seq_dim)[None]

        pos_enc = torch.empty((seq_len, seq_dim), device=device)
        pos_enc[:, 0::2] = torch.sin(trig_arg)
        # With odd D' there is one fewer cosine column than sine column
        pos_enc[:, 1::2] = torch.cos(trig_arg[:, :seq_dim // 2])
        return pos_enc

    def forward(self, xt, mask=None):
        """
        Forward pass of the network.
        Arguments:
            `xt`: B x L x D tensor containing the sequences to score
            `mask`: B x L boolean tensor that is True at padded positions;
                these are ignored by attention and excluded from pooling
        Returns a B x 1 tensor of sigmoid outputs, one per sequence.
        """
        batch_size, seq_len = xt.shape[0], xt.shape[1]

        # Positional encodings, built directly on the input's device
        pos_enc = self._get_positional_encoding(
            seq_len, self.pos_enc_size, device=xt.device
        )
        pos_enc = pos_enc[None].expand(batch_size, -1, -1)  # Shape: B x L x P

        x = self.init_dense(torch.cat([xt, pos_enc], dim=2))
        x = self.transformer(x, src_key_padding_mask=mask)  # Shape: B x L x H

        if mask is not None:
            # The encoder still produces (nonzero) outputs at padded positions
            # unless its nested-tensor fast path is taken; zero them so the
            # pooled representation does not depend on padding
            x = x.masked_fill(mask.to(torch.bool)[:, :, None], 0)

        pooled = torch.sum(x, dim=1)  # Shape: B x H
        return torch.sigmoid(self.final_dense(pooled))

    def loss(self, pred_values, true_values, weights=None, mask=None):
        """
        Computes the mean-squared-error loss of the network.
        Arguments:
            `pred_values`: a tensor of predictions from the network (B x 1
                when produced by `forward`)
            `true_values`: a tensor of true values; it should have the same
                shape as `pred_values`. NOTE: a B-shaped target broadcasts
                against B x 1 predictions to B x B
            `weights`: if provided, a tensor broadcastable with the squared
                error; the squared error is divided by it before averaging
            `mask`: ignored; accepted so the signature matches the other
                models' `loss`
        Returns a scalar: the mean over all elements of the (optionally
        weighted) squared error.
        """
        squared_error = torch.square(true_values - pred_values)
        if weights is not None:
            squared_error = squared_error / weights
        return torch.mean(squared_error)
    
if __name__ == "__main__":
    # Smoke test: push one batch of real data through SequenceTransformer2
    # and compute its loss against an all-ones target
    import feature.sequence_dataset as sequence_dataset

    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    seq_data = sequence_dataset.SeqDataset(
        sequence_dataset.AMPSeqLoader(), sequence_dataset.PROTEIN_ALPHABET, 32
    )
    seq_data.on_epoch_start()

    # The dataset yields whole batches itself, hence batch_size=None and an
    # identity collate function
    loader = torch.utils.data.DataLoader(
        seq_data, batch_size=None, num_workers=4,
        collate_fn=(lambda item: item)
    )
    first_batch = next(iter(loader))
    xt = first_batch["x"].to(DEVICE).float()
    mask = first_batch["mask"].to(DEVICE)
    t = torch.rand(len(xt)).to(DEVICE)

    model = SequenceTransformer2(20).to(DEVICE)
    pred = model(xt, t, mask=mask)
    loss = model.loss(pred, torch.ones_like(pred), mask=mask)
