import torch
import torch.nn as nn
import torch.nn.functional as F
from modeling_moebert import MoEBertModel
from transformers import BertConfig, DistilBertModel


class MoEBertEncoder(nn.Module):
    """Wrap ``MoEBertModel`` and pool its token states into one vector per sequence.

    Every constructor argument except ``max_len`` is forwarded into a
    ``BertConfig`` (with ``model_type="xlm-roberta"`` and any extra
    ``**kwargs``), and that config is used to build the ``MoEBertModel``.
    """

    def __init__(self, 
        vocab_size = 120067,
        max_orig_positional_len = 2048,
        hidden_size = 768,
        num_hidden_layers = 6,
        num_attention_heads = 12,
        intermediate_size = 3072,
        intermediate_size_expert = 3072,
        num_expert_heads = 0,
        pad_token_id = 0,
        max_len = 16,
        hidden_act = 'relu',
        token_moe = True,
        moe_type = 'topk',
        topk = 1,
        hash_list_path = None,
        num_experts = 1,
        num_sparse_layers = 3,
        gradient_checkpointing = False,
        **kwargs
    ):
        super(MoEBertEncoder, self).__init__()
        # NOTE(review): ``max_len`` is accepted but not forwarded anywhere;
        # confirm whether it was meant to reach the config.
        self.config = BertConfig(
            vocab_size = vocab_size,
            max_orig_positional_len = max_orig_positional_len,
            hidden_size = hidden_size,
            num_hidden_layers = num_hidden_layers,
            num_attention_heads = num_attention_heads,
            intermediate_size = intermediate_size,
            intermediate_size_expert = intermediate_size_expert,
            num_expert_heads = num_expert_heads,
            pad_token_id = pad_token_id,
            hidden_act = hidden_act,
            model_type="xlm-roberta",
            token_moe = token_moe,
            moe_type = moe_type,
            topk = topk,
            hash_list_path = hash_list_path,
            num_experts = num_experts,
            num_sparse_layers = num_sparse_layers,
            gradient_checkpointing = gradient_checkpointing,
            **kwargs
        )
        self.encoder = MoEBertModel(self.config)

    def forward(self, in_data, router_labels = None, mean_pooling = True):
        """Encode a batch and pool it to one embedding per sequence.

        Args:
            in_data: mapping with ``'input_ids'`` and ``'attention_mask'``
                tensors. The mask is presumably (batch, seq) with non-zero
                entries for real tokens; TODO confirm against the tokenizer.
            router_labels: passed to the encoder only when
                ``config.num_experts != 1``; otherwise ignored.
            mean_pooling: if True, average the last hidden states over the
                positions where the attention mask is non-zero; if False,
                return the first token's hidden state (CLS pooling).

        Returns:
            ``(embeddings, router_logits)``. ``router_logits`` is the encoder
            output's ``router_logits`` attribute, or ``None`` if it has none.
        """
        input_ids = in_data['input_ids']
        attention_mask = in_data['attention_mask']
        if self.config.num_experts != 1:
            output = self.encoder(input_ids, attention_mask=attention_mask, router_labels = router_labels)
        else:
            output = self.encoder(input_ids, attention_mask=attention_mask)
        hidden_state = output.last_hidden_state
        if mean_pooling:
            # Pool with the caller's padding mask. ``output.attentions`` is, by the
            # HF output convention, attention weights (often None), not this mask.
            mask = attention_mask.unsqueeze(-1).to(hidden_state.dtype)
            summed = (hidden_state * mask).sum(dim=1)
            # Clamp so an all-padding row pools to zeros instead of 0/0 = NaN;
            # finfo.tiny keeps the clamp meaningful for half precision too.
            counts = mask.sum(dim=1).clamp(min=torch.finfo(mask.dtype).tiny)
            embeddings = summed / counts
        else: # CLS
            embeddings = hidden_state[:, 0, :]
        return embeddings, getattr(output, 'router_logits', None)


class CustomEmbeddingEncoder(nn.Module):
    """Project ``MoEBertEncoder`` sentence vectors to ``embedding_dim`` and L2-normalise them.

    Args:
        embedding_dim: output width of the projection head.
        **kwargs: forwarded unchanged to ``MoEBertEncoder``.
    """

    def __init__(self, embedding_dim, **kwargs):
        super().__init__()
        self.encoder = MoEBertEncoder(**kwargs)
        backbone_width = self.encoder.encoder.config.hidden_size
        # Keep the Dropout -> Linear -> Tanh order: it fixes the state_dict
        # keys of ``transform`` and therefore checkpoint compatibility.
        head_layers = [
            nn.Dropout(0.1),
            nn.Linear(backbone_width, embedding_dim),
            nn.Tanh(),
        ]
        self.transform = nn.Sequential(*head_layers)

    def forward(self, input_ids, attention_mask, router_labels = None, mean_pooling = True):
        """Embed one batch.

        Returns:
            ``(embeddings, router_logits)``. ``embeddings`` is the projected
            pooled vector after ``F.normalize`` (L2, along dim 1).
            ``router_logits`` is passed through from the wrapped encoder.
        """
        features = {'input_ids': input_ids, 'attention_mask': attention_mask}
        pooled, router_logits = self.encoder(
            features,
            router_labels=router_labels,
            mean_pooling=mean_pooling,
        )
        projected = self.transform(pooled)
        unit_vectors = F.normalize(projected)
        return unit_vectors, router_logits


class SiameseEncoder(nn.Module):
    """Encode query and passage batches with one shared ``CustomEmbeddingEncoder``.

    Both sides go through the same ``self.embedding`` module, so the two towers
    share weights.
    """

    def __init__(self, embedding_dim, **kwargs):
        super(SiameseEncoder, self).__init__()
        self.embedding = CustomEmbeddingEncoder(embedding_dim, **kwargs)

    def _encode(self, batch_data, prefix, mean_pooling):
        """Embed the ``prefix`` side (``'query'`` or ``'passage'``) of ``batch_data``."""
        # Router labels are optional: MoEBertEncoder ignores them when
        # num_experts == 1, so batches for such models need not carry hash ids.
        try:
            router_labels = batch_data[prefix + '_hash_ids']
        except KeyError:
            router_labels = None
        return self.embedding(
            batch_data[prefix + '_input_ids'],
            batch_data[prefix + '_attention_mask'],
            router_labels = router_labels,
            mean_pooling = mean_pooling,
        )

    def forward(self, batch_data, mean_pooling = True):
        """Embed the query and passage sides of ``batch_data``.

        Args:
            batch_data: mapping with ``{query,passage}_input_ids`` and
                ``{query,passage}_attention_mask``, plus optional
                ``{query,passage}_hash_ids`` router labels.
            mean_pooling: forwarded to the embedding encoder.

        Returns:
            ``(query_emb, passage_emb, query_router_logits, passage_router_logits)``.
        """
        query_emb, query_router_logits = self._encode(batch_data, 'query', mean_pooling)
        passage_emb, passage_router_logits = self._encode(batch_data, 'passage', mean_pooling)
        return query_emb, passage_emb, query_router_logits, passage_router_logits


if __name__ == '__main__':
    # Smoke test: build a hash-routed MoE Siamese encoder and run one tiny batch.
    model = SiameseEncoder(128, moe_type = 'hash', num_experts = 16, topk = 1, intermediate_size_expert = 768)
    # Disable dropout so repeated runs of the demo give the same embeddings.
    model.eval()
    in_data = {
        'query_input_ids': torch.randint(0, 1000, (2, 10)),
        'query_attention_mask': torch.tensor([
            [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]
        ]),
        'query_hash_ids': torch.tensor([
            [0, 0, 1, 1, 1, 0, 0, 0, 0, 0],
            [1, 1, 1, 4, 9, 0, 0, 1, 0, 0]
        ]),
        'passage_input_ids': torch.randint(0, 1000, (2, 10)),
        'passage_attention_mask': torch.tensor([
            [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]
        ]),
        'passage_hash_ids': torch.tensor([
            [0, 0, 1, 1, 1, 0, 0, 0, 0, 0],
            [1, 1, 1, 4, 9, 0, 0, 0, 0, 0]
        ]),
    }
    with torch.no_grad():
        query_emb, passage_emb, query_router_logits, passage_router_logits = model(in_data)
    print(query_emb.shape, passage_emb.shape)
    # Router logits are None when the encoder output has no ``router_logits``.
    if query_router_logits is not None and passage_router_logits is not None:
        # Element-wise average of the two towers' router logits
        # (presumably one entry per sparse layer; confirm in MoEBertModel).
        router_logits = tuple((a + b) / 2 for a, b in zip(query_router_logits, passage_router_logits))
        print(len(router_logits))