import copy
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..lotfr_module.linear_attention import LinearAttention


def _get_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    # deepcopy gives every clone its own parameters (no weight sharing).
    return nn.ModuleList(map(copy.deepcopy, [module] * N))


class TransformerEncoderGeneral(nn.Module):
    """Stack of alternating self- and cross-attention linear-attention stages.

    Each of the ``num_visual_layers`` stages first lets every feature set
    attend to itself (with its positional embedding) and then lets the two
    feature sets attend to each other. This encoder takes no depth input.

    Args:
        config: mapping providing ``'d_model'``, ``'nhead'`` and
            ``'num_visual_layers'``.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        self.d_model = config['d_model']
        self.nhead = config['nhead']
        self.n_visual_layers = config['num_visual_layers']
        # A single template layer, deep-copied so every stage owns its weights.
        template = TransformerEncoderLayerGeneral(self.d_model, self.nhead)
        self.selfAttn_layers = _get_clones(template, self.n_visual_layers)
        self.crossAttn_layers = _get_clones(template, self.n_visual_layers)
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform initialize every parameter with more than one dimension."""
        matrices = (p for p in self.parameters() if p.dim() > 1)
        for weight in matrices:
            nn.init.xavier_uniform_(weight)

    def forward(self, feat0, feat1, pos_embed0, pos_embed1, mask0=None, mask1=None):
        """Jointly transform two feature sets with self- and cross-attention.

        Args:
            feat0 (torch.Tensor): [N, L, C]
            feat1 (torch.Tensor): [N, S, C]
            pos_embed0 (torch.Tensor): [N, L, C] positional embedding of ``feat0``
            pos_embed1 (torch.Tensor): [N, S, C] positional embedding of ``feat1``
            mask0 (torch.Tensor): [N, L] (optional)
            mask1 (torch.Tensor): [N, S] (optional)

        Returns:
            tuple: the transformed ``(feat0, feat1)``.
        """
        assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal"
        for self_layer, cross_layer in zip(self.selfAttn_layers, self.crossAttn_layers):
            # Self-attention: each feature set attends to itself.
            feat0 = self_layer(feat0, feat0, pos_embed0, pos_embed0, mask0, mask0)
            feat1 = self_layer(feat1, feat1, pos_embed1, pos_embed1, mask1, mask1)
            # Cross-attention: both directions must read the pre-update
            # features, so both results are computed before either is rebound.
            feat0, feat1 = (
                cross_layer(feat0, feat1, pos_embed0, pos_embed1, mask0, mask1),
                cross_layer(feat1, feat0, pos_embed1, pos_embed0, mask1, mask0),
            )
        return feat0, feat1


class DepthGuidedEncoder(nn.Module):
    """Linear-attention encoder that injects depth embeddings before attention.

    Each stage runs three layers in order:

    1. a depth layer, in which each feature set attends to its own depth
       embedding;
    2. a self-attention layer;
    3. a cross-attention layer between the two feature sets.

    Args:
        config: mapping providing ``'d_model'``, ``'nhead'``,
            ``'num_visual_layers'`` and ``'num_depth_layers'``.

    Note:
        ``forward`` zips the depth, self- and cross-attention stacks, so only
        ``min(num_depth_layers, num_visual_layers)`` stages run. A warning is
        emitted when the two counts differ.
    """

    def __init__(self, config):
        super(DepthGuidedEncoder, self).__init__()

        self.config = config
        self.d_model = config['d_model']
        self.nhead = config['nhead']
        self.n_visual_layers = config['num_visual_layers']
        self.n_depth_layers = config['num_depth_layers']
        if self.n_depth_layers != self.n_visual_layers:
            # forward() zips the stacks, so the surplus layers would be built
            # (and counted as parameters) but silently never run.
            warnings.warn(
                f"DepthGuidedEncoder: num_depth_layers={self.n_depth_layers} "
                f"differs from num_visual_layers={self.n_visual_layers}; only "
                f"{min(self.n_depth_layers, self.n_visual_layers)} stage(s) will run.",
                stacklevel=2,
            )
        # setup encoder layers (one template, deep-copied per stage)
        encoder_layer = TransformerEncoderLayerGeneral(self.d_model, self.nhead)
        self.depth_layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for _ in range(self.n_depth_layers)])
        self.selfAttn_layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for _ in range(self.n_visual_layers)])
        self.crossAttn_layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for _ in range(self.n_visual_layers)])
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform initialize every parameter with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, feat0, feat1, pos_embed0, pos_embed1, depth_embed0, depth_embed1, mask0=None, mask1=None):
        """Transform two visual feature sets together with their depth embeddings.

        Args:
            feat0 (torch.Tensor): [N, L, C]
            feat1 (torch.Tensor): [N, S, C]
            pos_embed0 (torch.Tensor): [N, L, C] positional embedding of ``feat0``
            pos_embed1 (torch.Tensor): [N, S, C] positional embedding of ``feat1``
            depth_embed0 (torch.Tensor): depth embedding attended to by ``feat0``;
                it is masked with ``mask0``, so presumably [N, L, C] — confirm.
            depth_embed1 (torch.Tensor): depth embedding attended to by ``feat1``;
                masked with ``mask1``, so presumably [N, S, C] — confirm.
            mask0 (torch.Tensor): [N, L] (optional)
            mask1 (torch.Tensor): [N, S] (optional)

        Returns:
            tuple: the transformed ``(feat0, feat1)``.
        """
        assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal"
        for depth_layer, selfAttn_layer, crossAttn_layer in zip(self.depth_layers, self.selfAttn_layers, self.crossAttn_layers):
            # Fuse depth information into the features (no positional embeddings).
            feat0 = depth_layer(feat0, depth_embed0, None, None, mask0, mask0)
            feat1 = depth_layer(feat1, depth_embed1, None, None, mask1, mask1)
            # Self-attention with visual positional embeddings.
            feat0 = selfAttn_layer(feat0, feat0, pos_embed0, pos_embed0, mask0, mask0)
            feat1 = selfAttn_layer(feat1, feat1, pos_embed1, pos_embed1, mask1, mask1)
            # Cross-attention: both directions read the pre-update features,
            # hence the temporaries before reassignment.
            feat0_after = crossAttn_layer(feat0, feat1, pos_embed0, pos_embed1, mask0, mask1)
            feat1_after = crossAttn_layer(feat1, feat0, pos_embed1, pos_embed0, mask1, mask0)
            feat0 = feat0_after
            feat1 = feat1_after
        return feat0, feat1


class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` copies of ``encoder_layer`` used as self-attention.

    Args:
        encoder_layer: template layer; ``num_layers`` deep copies are stacked.
        num_layers (int): number of stacked layers.
        norm: optional module applied to the final output.
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src, pos_encoder, src_shape, source_mask=None):
        """Run ``src`` through every layer, then the optional final ``norm``.

        Each layer receives the running features as both query input and
        source, ``src_shape`` for both shape arguments, no query mask, and
        ``source_mask`` as the key/value mask.
        """
        feats = src
        for block in self.layers:
            feats = block(feats, feats, pos_encoder, src_shape, src_shape,
                          x_mask=None, source_mask=source_mask)
        return feats if self.norm is None else self.norm(feats)
    

class TransformerEncoderLayer(nn.Module):
    """Linear-attention encoder layer whose positional encoding is a callable.

    ``x`` attends to ``source`` (self-attention when both are the same tensor)
    through multi-head ``LinearAttention``. The merged message is
    layer-normalized, then fused with ``x`` by a feed-forward network over
    their channel-wise concatenation. The result is normalized again and added
    to ``x`` as a residual.

    Args:
        d_model (int): feature size C of ``x`` and ``source``.
        nhead (int): number of attention heads; must divide ``d_model``.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``nhead``.
    """

    def __init__(self, d_model, nhead):
        super(TransformerEncoderLayer, self).__init__()
        # Fail early: otherwise the per-head view()/merge in forward() fails
        # later with an opaque shape error.
        if d_model % nhead != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by nhead ({nhead})")
        self.dim = d_model // nhead  # per-head feature size D
        self.nhead = nhead
        # MultiHead Linear Attention projections
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)

        self.self_attn = LinearAttention()
        self.merge = nn.Linear(d_model, d_model, bias=False)

        # Maps concat([x, message]) (2*C channels) back to C channels.
        self.FFN = nn.Sequential(
            nn.Linear(d_model*2, d_model*2, bias=False),
            nn.ReLU(True),
            nn.Linear(d_model*2, d_model, bias=False)
        )

        # Layer norms for the attention message and the FFN output (no dropout).
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, source, pos_encoder=None, x_shape=None, src_shape=None,
                x_mask=None, source_mask=None, dp_query=None, dp_key=None):
        """Attend from ``x`` to ``source`` and return the residually updated ``x``.

        Args:
            x (torch.Tensor): [N, L, C] features producing the queries.
            source (torch.Tensor): [N, S, C] features producing keys and values.
            pos_encoder (callable, optional): called as ``pos_encoder(x, x_shape)``
                and ``pos_encoder(source, src_shape)``. Its outputs are used as
                the query and key inputs; values always use the raw ``source``.
            x_shape, src_shape: forwarded verbatim to ``pos_encoder``.
            x_mask (torch.Tensor): [N, L] (optional)
            source_mask (torch.Tensor): [N, S] (optional)
            dp_query (torch.Tensor, optional): added to the query input
                (presumably a depth positional embedding); must broadcast with it.
            dp_key (torch.Tensor, optional): added to the key input. Each of
                ``dp_query``/``dp_key`` is applied independently when given.

        Returns:
            torch.Tensor: [N, L, C]
        """
        bs = x.size(0)
        # Embed 2d PE: the encoder output replaces the raw features for
        # query/key; values stay un-encoded.
        if pos_encoder is not None:
            query = pos_encoder(x, x_shape)
            key = pos_encoder(source, src_shape)
        else:
            query, key = x, source
        value = source
        # Embed depth PE (each embedding is optional on its own).
        if dp_query is not None:
            query = query + dp_query
        if dp_key is not None:
            key = key + dp_key

        # multi-head attention: split channels into (nhead, dim)
        query = self.q_proj(query).view(bs, -1, self.nhead, self.dim)  # [N, L, (H, D)]
        key = self.k_proj(key).view(bs, -1, self.nhead, self.dim)  # [N, S, (H, D)]
        value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)  # [N, S, (H, D)]

        message = self.self_attn(query, key, value, q_mask=x_mask, kv_mask=source_mask)  # [N, L, (H, D)]
        # reshape (not view) also accepts a non-contiguous attention output.
        message = self.merge(message.reshape(bs, -1, self.nhead * self.dim))  # [N, L, C]
        message = self.norm1(message)

        # feed-forward network over [x, message] concatenated on channels
        message = self.FFN(torch.cat([x, message], dim=2))
        message = self.norm2(message)

        return x + message


class TransformerEncoderLayerGeneral(nn.Module):
    """Linear-attention encoder layer taking additive query/key embeddings.

    ``x`` attends to ``source`` (self-attention when both are the same tensor)
    through multi-head ``LinearAttention``. The merged message is
    layer-normalized, then fused with ``x`` by a feed-forward network over
    their channel-wise concatenation. The result is normalized again and added
    to ``x`` as a residual.

    Args:
        d_model (int): feature size C of ``x`` and ``source``.
        nhead (int): number of attention heads; must divide ``d_model``.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``nhead``.
    """

    def __init__(self, d_model, nhead):
        super(TransformerEncoderLayerGeneral, self).__init__()
        # Fail early: otherwise the per-head view()/merge in forward() fails
        # later with an opaque shape error.
        if d_model % nhead != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by nhead ({nhead})")
        self.dim = d_model // nhead  # per-head feature size D
        self.nhead = nhead
        # MultiHead Linear Attention projections
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)

        self.self_attn = LinearAttention()
        self.merge = nn.Linear(d_model, d_model, bias=False)

        # Maps concat([x, message]) (2*C channels) back to C channels.
        self.FFN = nn.Sequential(
            nn.Linear(d_model*2, d_model*2, bias=False),
            nn.ReLU(True),
            nn.Linear(d_model*2, d_model, bias=False)
        )

        # Layer norms for the attention message and the FFN output (no dropout).
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, source, query_embed=None, key_embed=None, x_mask=None, source_mask=None):
        """Attend from ``x`` to ``source`` and return the residually updated ``x``.

        Args:
            x (torch.Tensor): [N, L, C] features producing the queries.
            source (torch.Tensor): [N, S, C] features producing keys and values.
            query_embed (torch.Tensor, optional): added to ``x`` to form the
                query input (e.g. a positional embedding); must broadcast with ``x``.
            key_embed (torch.Tensor, optional): added to ``source`` to form the
                key input. Values always use the raw ``source``. Each embedding
                is applied independently when given.
            x_mask (torch.Tensor): [N, L] (optional)
            source_mask (torch.Tensor): [N, S] (optional)

        Returns:
            torch.Tensor: [N, L, C]
        """
        bs = x.size(0)
        # Embed PE: embeddings shape queries and keys only, never values.
        query = x if query_embed is None else x + query_embed
        key = source if key_embed is None else source + key_embed
        value = source

        # multi-head attention: split channels into (nhead, dim)
        query = self.q_proj(query).view(bs, -1, self.nhead, self.dim)  # [N, L, (H, D)]
        key = self.k_proj(key).view(bs, -1, self.nhead, self.dim)  # [N, S, (H, D)]
        value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)  # [N, S, (H, D)]

        message = self.self_attn(query, key, value, q_mask=x_mask, kv_mask=source_mask)  # [N, L, (H, D)]
        # reshape (not view) also accepts a non-contiguous attention output.
        message = self.merge(message.reshape(bs, -1, self.nhead * self.dim))  # [N, L, C]
        message = self.norm1(message)

        # feed-forward network over [x, message] concatenated on channels
        message = self.FFN(torch.cat([x, message], dim=2))
        message = self.norm2(message)

        return x + message