import torch.nn as nn
import torch.nn.functional as F
import torch
import math
from rlkit.torch.core import PyTorchModule

class BERT(PyTorchModule):
    """
    Transformer-style encoder: a Linear+Tanh input projection, a stack of
    ``TransformerBlock`` layers, and a Linear output head.
    """

    def __init__(
        self,
        hidden=100,
        input_size=5,
        output_size=2,
        n_layers=6,
        attn_heads=128,
        dropout=0.1,
        use_sequence_attention=True,
        use_multihead_attention=True,
        use_channel_attention=True,
        mode='parallel',
    ):
        """
        :param hidden: width of the hidden representation used by every block
        :param input_size: size of the last dimension of the input tensor
        :param output_size: size of the last dimension of the output tensor
        :param n_layers: number of TransformerBlock layers
        :param attn_heads: number of attention heads handed to each block.
            NOTE: MultiHeadedAttention asserts ``hidden % attn_heads == 0``,
            so the defaults (hidden=100, attn_heads=128) fail that assert
            whenever multi-head sequence attention is enabled.
        :param dropout: dropout rate handed to each block
        :param use_sequence_attention: enable the sequence-attention branch
        :param use_multihead_attention: use MultiHeadedAttention rather than
            plain Attention for the sequence-attention branch
        :param use_channel_attention: enable the ChannelAttention branch
        :param mode: 'parallel' or 'serialize'; forwarded to each block
        """
        # Kept as the very first statement so that locals() is captured
        # before any other local name is bound. save_init_params comes from
        # rlkit (not visible here) -- presumably it records the constructor
        # arguments; confirm against rlkit.
        self.save_init_params(locals())
        super().__init__()

        self.hidden = hidden
        self.input_size = input_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.attn_heads = attn_heads
        self.use_sequence_attention = use_sequence_attention
        self.use_multihead_attention = use_multihead_attention
        self.use_channel_attention = use_channel_attention
        self.mode = mode

        # Input projection: input_size -> hidden, squashed by tanh.
        self.front = nn.Sequential(
            nn.Linear(self.input_size, self.hidden),
            nn.Tanh(),
        )

        # Deep stack of identical blocks; feed-forward width is 4 * hidden.
        self.transformer_blocks = nn.ModuleList()
        for _ in range(n_layers):
            self.transformer_blocks.append(
                TransformerBlock(
                    hidden=hidden,
                    attn_heads=attn_heads,
                    feed_forward_hidden=hidden * 4,
                    dropout=dropout,
                    use_sequence_attention=self.use_sequence_attention,
                    use_multihead_attention=self.use_multihead_attention,
                    use_channel_attention=self.use_channel_attention,
                    mode=self.mode,
                )
            )

        # Output head: hidden -> output_size.
        self.tail = nn.Linear(self.hidden, self.output_size)

    def forward(self, x, segment_info=None, tanh=True):
        """
        Run the input projection, every transformer block, and the output head.

        :param x: 3-D tensor whose last dimension is ``input_size`` (each
            block unpacks exactly three dimensions from its input)
        :param segment_info: accepted but not used by this method
        :param tanh: when true, apply ``torch.tanh`` to the head's output
        :return: tensor with the same leading dimensions as ``x`` and a last
            dimension of ``output_size``
        """
        x = self.front(x)
        for block in self.transformer_blocks:
            x = block.forward(x)
        out = self.tail(x)
        return torch.tanh(out) if tanh else out

class FlattenBERT(BERT):
    """BERT variant that concatenates its inputs along the last dimension."""

    def forward(self, meta_size=16, batch_size=256, *inputs, **kwargs):
        """
        Concatenate ``inputs`` along dim -1, regroup the rows into
        ``(meta_size, batch_size, -1)``, run ``BERT.forward`` and flatten the
        result back to ``(meta_size * batch_size, -1)``.

        Because ``meta_size`` and ``batch_size`` precede ``*inputs`` in the
        signature, tensors can only be supplied positionally after both of
        those have been passed positionally as well.

        :param meta_size: size of the first dimension after regrouping
        :param batch_size: size of the second dimension after regrouping
        :param inputs: tensors to concatenate along their last dimension
        :param kwargs: forwarded unchanged to ``BERT.forward``
        """
        joined = torch.cat(inputs, dim=-1)
        grouped = joined.view(meta_size, batch_size, -1)
        encoded = super().forward(grouped, **kwargs)
        return encoded.view(meta_size * batch_size, -1)

class TransformerBlock(nn.Module):
    """
    One encoder layer: an optional attention sublayer (sequence attention,
    channel attention, or both) followed by a position-wise feed-forward
    sublayer, each wrapped in a SublayerConnection, with dropout on the output.
    """

    def __init__(
        self,
        hidden,
        attn_heads,
        feed_forward_hidden,
        dropout,
        use_sequence_attention=True,
        use_multihead_attention=True,
        use_channel_attention=False,
        mode='parallel',
    ):
        """
        :param hidden: feature size of the block's input and output
        :param attn_heads: number of heads for MultiHeadedAttention (unused
            when ``use_multihead_attention`` is false)
        :param feed_forward_hidden: inner width of the feed-forward sublayer
        :param dropout: dropout rate for the sublayers and the block output
        :param use_sequence_attention: build and apply ``sequence_attention``
        :param use_multihead_attention: pick MultiHeadedAttention over the
            plain Attention module for ``sequence_attention``
        :param use_channel_attention: build and apply ``batch_attention``
        :param mode: how the two attentions combine when both are enabled:
            'parallel' sums them inside one sublayer, 'serialize' applies
            them in two consecutive sublayers
        """
        super().__init__()
        self.use_channel_attention = use_channel_attention
        self.use_sequence_attention = use_sequence_attention
        self.mode = mode

        if use_sequence_attention:
            if use_multihead_attention:
                self.sequence_attention = MultiHeadedAttention(h=attn_heads, d_model=hidden)
            else:
                self.sequence_attention = Attention()
        if use_channel_attention:
            self.batch_attention = ChannelAttention(feature_size=hidden, reduction=16)

        self.feed_forward = PositionwiseFeedForward(
            d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout
        )
        self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        # The second attention sublayer only exists in 'serialize' mode.
        if mode == 'serialize':
            self.extra_input_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)

    def tb_1_dim(self, x):
        """View a ``(t, b, dim)`` tensor as ``(t * b, 1, dim)``."""
        rows, cols, width = x.size()
        return x.view(rows * cols, 1, width)

    def forward(self, x, mask=None):
        """
        Apply the configured attention sublayer(s), then the feed-forward
        sublayer, then dropout.

        :param x: 3-D tensor ``(t, b, dim)``
        :param mask: optional mask forwarded to the sequence attention
        :return: tensor with the same shape as ``x``
        """
        t, b, dim = x.size()

        def sequence_branch(inp):
            # Every (t, b) position is fed to the attention module as its own
            # length-1 sequence, then the result is folded back to (t, b, dim).
            flat = self.tb_1_dim(inp)
            attended = self.sequence_attention.forward(flat, flat, flat, mask=mask)
            return attended.view(t, b, dim)

        def channel_branch(inp):
            return self.batch_attention.forward(inp)

        channel_on = self.use_channel_attention
        sequence_on = self.use_sequence_attention

        if channel_on and sequence_on:
            # Any mode other than these two leaves x untouched at this stage.
            if self.mode == 'parallel':
                x = self.input_sublayer(
                    x, lambda inp: sequence_branch(inp) + channel_branch(inp)
                )
            elif self.mode == 'serialize':
                x = self.input_sublayer(x, sequence_branch)
                x = self.extra_input_sublayer(x, channel_branch)
        elif channel_on:
            x = self.input_sublayer(x, channel_branch)
        elif sequence_on:
            x = self.input_sublayer(x, sequence_branch)

        x = self.output_sublayer(x, self.feed_forward)
        return self.dropout(x)

class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Dropout -> Linear."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        """
        :param d_model: input and output feature size
        :param d_ff: width of the inner layer
        :param dropout: dropout rate applied after the activation
        """
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = GELU()

    def forward(self, x):
        """Map ``x`` from d_model to d_ff and back to d_model."""
        expanded = self.w_1(x)
        activated = self.activation(expanded)
        return self.w_2(self.dropout(activated))

class SublayerConnection(nn.Module):
    """
    Wrap a sublayer with layer norm, dropout and an optional residual add.
    The norm is applied to the sublayer's input (before the sublayer runs),
    not to the sum.
    """

    def __init__(self, size, dropout):
        """
        :param size: feature size handed to LayerNorm
        :param dropout: dropout rate applied to the sublayer's output
        """
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer, residual=True):
        """
        Compute ``dropout(sublayer(norm(x)))`` and, when ``residual`` is
        true, add ``x`` to it.

        :param x: input tensor
        :param sublayer: callable applied to the normalised input; its output
            must be addable to ``x`` when ``residual`` is true
        :param residual: whether to add the input back onto the result
        """
        transformed = self.dropout(sublayer(self.norm(x)))
        if not residual:
            return transformed
        return x + transformed

class GELU(nn.Module):
    """
    GELU activation computed with the tanh approximation:
    ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))``.
    """

    def forward(self, x):
        """Apply the activation element-wise to ``x``."""
        scale = math.sqrt(2 / math.pi)
        inner = scale * (x + 0.044715 * x.pow(3))
        return 0.5 * x * (1 + torch.tanh(inner))

class LayerNorm(nn.Module):
    """
    Normalise over the last dimension with a learnable gain and bias.

    Uses ``Tensor.std`` with its default settings and adds ``eps`` to the
    standard deviation itself (not to the variance).
    """

    def __init__(self, features, eps=1e-6):
        """
        :param features: size of the last dimension of the inputs
        :param eps: constant added to the standard deviation before dividing
        """
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))   # gain
        self.b_2 = nn.Parameter(torch.zeros(features))  # bias
        self.eps = eps

    def forward(self, x):
        """Return ``a_2 * (x - mean) / (std + eps) + b_2`` over the last dim."""
        centered = x - x.mean(-1, keepdim=True)
        denom = x.std(-1, keepdim=True) + self.eps
        return self.a_2 * centered / denom + self.b_2

class ChannelAttention(nn.Module):
    """
    Gate the input with a weight produced by a bottleneck MLP followed by
    average pooling over the last dimension.
    """

    def __init__(self, feature_size=128, reduction=16, activation=nn.Tanh, output_activation=nn.Sigmoid):
        """
        :param feature_size: size of the last dimension of the inputs
        :param reduction: divisor for the bottleneck width
            (``feature_size // reduction``)
        :param activation: module class instantiated between the two Linears
        :param output_activation: module class instantiated after the second
            Linear
        """
        super().__init__()
        # Pools the last dimension of its input down to length 1.
        self.avg_pool = nn.AdaptiveAvgPool1d(1)

        # NOTE: this is 0 whenever feature_size < reduction.
        bottleneck = feature_size // reduction
        # Down-project, activate, up-project, squash.
        self.fc = nn.Sequential(
            nn.Linear(feature_size, bottleneck),
            activation(),
            nn.Linear(bottleneck, feature_size),
            output_activation(),
        )

    def forward(self, x, use_softmax=False, weight_only=False):
        """
        :param x: input tensor whose last dimension is ``feature_size``
        :param use_softmax: additionally apply softmax over dim 1 of the
            pooled weight
        :param weight_only: return the weight instead of ``x * weight``
        :return: the weight (last dimension of size 1) or the gated input
        """
        weight = self.avg_pool(self.fc(x))
        if use_softmax:
            weight = F.softmax(weight, dim=1)
        return weight if weight_only else x * weight

class Attention(nn.Module):
    """
    Compute scaled dot-product attention.
    """

    def forward(self, query, key, value, mask=None, dropout=None):
        """
        :param query: tensor ``(..., q_len, d_k)``
        :param key: tensor ``(..., k_len, d_k)``
        :param value: tensor ``(..., k_len, d_v)``
        :param mask: optional tensor broadcastable to the score tensor
            ``(..., q_len, k_len)``; positions where ``mask == 0`` receive
            (effectively) zero attention weight
        :param dropout: optional callable applied to the attention weights
        :return: ``softmax(query @ key^T / sqrt(d_k)) @ value``
        """
        d_k = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            # Fill with the most negative finite value of the score dtype
            # rather than a hard-coded -1e9: -1e9 cannot be represented in
            # float16, so the fixed constant breaks half-precision scores.
            # For float32/float64 the softmax result is unchanged (masked
            # weights still underflow to 0; a fully masked row is still
            # uniform).
            fill_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask == 0, fill_value)

        p_attn = F.softmax(scores, dim=-1)

        if dropout is not None:
            p_attn = dropout(p_attn)

        return torch.matmul(p_attn, value)


class MultiHeadedAttention(nn.Module):
    """
    Multi-head attention: project query/key/value, split them into ``h``
    heads of size ``d_model // h``, attend per head, merge and project again.
    """

    def __init__(self, h, d_model, dropout=0.1):
        """
        :param h: number of heads; must divide ``d_model`` evenly
        :param d_model: feature size of the inputs and of the output
        :param dropout: dropout rate applied to the attention weights
        """
        super().__init__()
        assert d_model % h == 0

        # Keys and values share the same per-head size.
        self.d_k = d_model // h
        self.h = h

        # One projection each for query, key and value (in that order).
        self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])
        self.output_linear = nn.Linear(d_model, d_model)
        self.attention = Attention()

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """
        :param query: tensor ``(batch, q_len, d_model)``
        :param key: tensor ``(batch, k_len, d_model)``
        :param value: tensor ``(batch, k_len, d_model)``
        :param mask: optional mask handed unchanged to the attention module;
            it has to broadcast against scores of shape
            ``(batch, h, q_len, k_len)``
        :return: tensor ``(batch, q_len, d_model)``
        """
        batch_size = query.size(0)

        def split_heads(projection, tensor):
            # (batch, len, d_model) -> (batch, h, len, d_k)
            projected = projection(tensor)
            return projected.view(batch_size, -1, self.h, self.d_k).transpose(1, 2)

        # 1) Project and split into heads.
        q_proj, k_proj, v_proj = self.linear_layers
        query = split_heads(q_proj, query)
        key = split_heads(k_proj, key)
        value = split_heads(v_proj, value)

        # 2) Attend within every head.
        attended = self.attention(query, key, value, mask=mask, dropout=self.dropout)

        # 3) Merge the heads back to (batch, len, h * d_k) and project.
        merged = attended.transpose(1, 2).contiguous()
        merged = merged.view(batch_size, -1, self.h * self.d_k)
        return self.output_linear(merged)
