import math
import os
import torch
import torch.nn as nn 
from torch.nn.parameter import Parameter

class ByteSubwordCombine(nn.Module):
    """
    Subword embedding that combines a per-subword lookup table with a
    byte-level representation of each subword, followed by a linear stack.

    Every subword id is mapped to a fixed-length sequence of byte ids through
    the ``subword2bytes`` buffer. The byte embeddings of that sequence are
    summed, added to the subword's own embedding, and passed through
    ``layernum`` linear layers (ReLU + dropout between layers).

    Args:
        num_embeddings: dictionary size (number of subword ids).
        embedding_dim: width of every embedding vector and of the output.
        padding_idx: accepted for interface compatibility but not used; the
            byte padding index is always ``byte_dict_size`` (128).
        layernum: number of linear layers applied after combining.
        interdim: hidden width between linear layers; 0 means
            ``embedding_dim``.
        std: accepted for interface compatibility; currently unused.
        relu_dropout: dropout probability after each hidden ReLU.
        aggre: stored as ``self.aggre``; byte aggregation is currently
            always a sum regardless of this value.
        subword_bytes_file: path of a saved long tensor mapping subword ids
            to byte ids (presumably shape ``(num_embeddings, L)`` — confirm
            against the producer of the file). If the file does not exist a
            random ``(num_embeddings, 16)`` table is generated and, when the
            path is non-empty, saved there.
        embeddings_file: optional path to a saved weight tensor for the byte
            embedding table; it must have shape
            ``(byte_dict_size + 1, embedding_dim)``.
    """
    def __init__(self, num_embeddings, embedding_dim, padding_idx=258, layernum=1, interdim=0,
                 std=1.0, relu_dropout=0.1, aggre='avg', subword_bytes_file='', embeddings_file='./embeddings') -> None:
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.layernum = layernum
        self.aggre = aggre
        self.relu_dropout = relu_dropout

        self.byte_dict_size = 128
        self.max_byte_seq_len = 16

        # Mixing weights for the subword / byte branches; not applied in forward().
        self.alpha_word = 0.5
        self.alpha_byte = 1 - self.alpha_word

        # One extra byte row (index byte_dict_size) is reserved for padding;
        # the ``padding_idx`` argument is deliberately not used.
        self.padding_idx = self.byte_dict_size
        self.embedding_file = embeddings_file
        self.subword_bytes_file = subword_bytes_file

        self.register_buffer('subword2bytes', self._load_or_create_subword2bytes(subword_bytes_file))

        # Byte embedding table: byte_dict_size byte ids + 1 padding row.
        weight = torch.load(self.embedding_file) if os.path.isfile(self.embedding_file) else None
        self.byte_embedding = nn.Embedding(self.byte_dict_size + 1, embedding_dim, _weight=weight,
                                           padding_idx=self.byte_dict_size)

        # Exposes the byte embedding weight (shared Parameter) as ``self.weight``.
        self.weight = self.byte_embedding.weight

        if interdim == 0:
            interdim = embedding_dim

        # Per-subword embedding table, indexed directly by subword id.
        self.embedding = nn.Embedding(self.num_embeddings, embedding_dim)

        # All subword ids, used to build the full output-layer matrix.
        self.register_buffer('subword', torch.arange(self.num_embeddings))

        # Layer widths: embedding_dim -> interdim -> ... -> interdim -> embedding_dim,
        # so consecutive layers always agree on their shared dimension.
        dims = [embedding_dim] + [interdim] * (self.layernum - 1) + [embedding_dim]
        self.linears = nn.ModuleList([nn.Linear(dims[i], dims[i + 1]) for i in range(self.layernum)])

    def _load_or_create_subword2bytes(self, subword_bytes_file):
        """Load the subword->bytes table, or create a random one and persist it if a path was given."""
        if os.path.isfile(subword_bytes_file):
            return torch.load(subword_bytes_file)
        table = torch.randint(0, self.byte_dict_size, (self.num_embeddings, self.max_byte_seq_len), dtype=torch.long)
        if subword_bytes_file:
            # os.path.dirname() is '' for a bare filename; os.makedirs('') would raise.
            dirname = os.path.dirname(subword_bytes_file)
            if dirname:
                os.makedirs(dirname, exist_ok=True)
            torch.save(table, subword_bytes_file)
        return table

    def construct_matrix_for_output_layer(self):
        """
        construct the embedding matrix for the decoder output layer
        if we share the embedding between the encoder and decoder, the embedding
        matrix for decoder layer is needed.

        Returns a ``(num_embeddings, embedding_dim)`` tensor, computed with
        dropout probability 0.
        """
        matrix = self.forward(self.subword, dropout=0)
        return matrix

    def forward(self, input, dropout=None):
        """
        Embed a tensor of subword ids.

        Args:
            input: long tensor of subword ids, any shape.
            dropout: dropout probability overriding ``self.relu_dropout``
                when not None.

        Returns:
            Tensor of shape ``input.shape + (embedding_dim,)``.
        """
        # input.shape + (L,) byte ids for every subword.
        bytes_input = self.subword2bytes[input]

        # Sum over the byte axis, which is always second-to-last after the
        # embedding lookup, independent of batch size or input rank.
        byte_embedding = self.byte_embedding(bytes_input).sum(dim=-2)
        word_embedding = self.embedding(input)

        relu_dropout = self.relu_dropout if dropout is None else dropout

        embed = word_embedding + byte_embedding
        for i, linear in enumerate(self.linears):
            embed = linear(embed)
            # No activation/dropout after the final layer.
            if i + 1 != self.layernum:
                embed = nn.functional.relu(embed)
                embed = nn.functional.dropout(embed, p=relu_dropout, training=self.training)
        return embed


if __name__ == "__main__":
    import sys

    # Smoke test: embed a small batch of subword ids with a saved subword->bytes table.
    # The table path may be passed as the first command-line argument.
    default_table = '/home/mengjiao/Documents/workspace/2023S/byte_subword/subword_byte_table.pt'
    subword_bytes_file = sys.argv[1] if len(sys.argv) > 1 else default_table
    if not os.path.isfile(subword_bytes_file):
        # Exit with a non-zero status; the constructor would otherwise write a random table here.
        raise SystemExit("please provide the subword to bytes table.")

    subword_bytes_table = torch.load(subword_bytes_file)
    print(subword_bytes_table)

    # One subword embedding row per table row, so ids index both tables consistently.
    embed_layer = ByteSubwordCombine(subword_bytes_table.shape[0], 3, subword_bytes_file=subword_bytes_file)
    input_ids = torch.LongTensor([[0, 0, 2, 8], [1, 3, 4, 4]])
    embed = embed_layer(input_ids)

    print(embed.size())
    print(embed)

