import time

from AbstractClass.AbstractModel import *
import torch



class BasicModel(AbstractModel):
    """Edge-weight model built from an embedding stack, an addressing layer and a memory.

    An input row is split along dim 1 into a source half and a destination
    half. Both halves go through ``embedding_net_1``; the results are used
    twice: passed through ``embedding_net_2`` and concatenated into an edge
    embedding, and passed through ``refine_net`` + ``attention_matrix`` to
    obtain the row/column addresses used by ``memory_matrix``.
    """

    def __init__(self, attention_matrix, embedding_net_1, embedding_net_2,
                 weight_decode_net, refine_net, memory_matrix):
        super().__init__()
        # NOTE: the assignment order is deliberately left untouched; if the
        # base class registers attributes in order (TODO confirm), reordering
        # them would change that order.
        self.refine_net = refine_net
        self.embedding_net_1 = embedding_net_1
        self.embedding_net_2 = embedding_net_2
        self.memory_matrix = memory_matrix
        self.attention_matrix = attention_matrix
        self.weight_decode_net = weight_decode_net

    def write(self, input_x, input_y):
        """Store the edges in ``input_x``, scaled by ``input_y``, into the memory.

        ``input_y`` is viewed as a single column (``view(-1, 1)``) and
        multiplied into the edge embedding before the write.
        """
        src_repr, dst_repr, edge_emb = self.get_embedding(input_x)
        src_refined, dst_refined = self.get_refined(src_repr, dst_repr)
        row_address, col_address = self.get_address(src_refined, dst_refined)
        scaled_edge = edge_emb * input_y.view(-1, 1)
        self.memory_matrix.write(row_address, col_address, scaled_edge)

    def clear(self):
        """Delegate to ``memory_matrix.clear()``."""
        self.memory_matrix.clear()

    def normalize_attention_matrix(self):
        """Delegate to ``attention_matrix.normalize()``."""
        self.attention_matrix.normalize()

    def get_embedding(self, input_x):
        """Return ``(source, dest, edge_embedding)`` for a batch of edges.

        ``source`` and ``dest`` are the ``embedding_net_1`` outputs of the
        first and second half of ``input_x`` along dim 1 (with an odd width
        the second half gets the extra column). ``edge_embedding`` is the
        dim-1 concatenation of ``embedding_net_2`` applied to each of them.
        """
        half = input_x.shape[1] // 2
        src_repr = self.embedding_net_1(input_x[:, :half])
        dst_repr = self.embedding_net_1(input_x[:, half:])
        src_part = self.embedding_net_2(src_repr)
        dst_part = self.embedding_net_2(dst_repr)
        return src_repr, dst_repr, torch.cat((src_part, dst_part), dim=1)

    def get_refined(self, source_embedding, dest_embedding):
        """Apply ``refine_net`` to the source and then to the destination embedding."""
        refined_source = self.refine_net(source_embedding)
        refined_dest = self.refine_net(dest_embedding)
        return refined_source, refined_dest

    def get_address(self, source_refined, dest_refined):
        """Map refined embeddings to ``(row_address, col_address)`` via ``attention_matrix``."""
        rows = self.attention_matrix(source_refined)
        cols = self.attention_matrix(dest_refined)
        return rows, cols

    def query(self, input_x, stream_length):
        """Predict the weight of the edges in ``input_x``.

        The memory read-out, the edge embedding and ``stream_length`` are
        concatenated along dim 1 and decoded by ``weight_decode_net``.
        """
        src_repr, dst_repr, edge_emb = self.get_embedding(input_x)
        src_refined, dst_refined = self.get_refined(src_repr, dst_repr)
        row_address, col_address = self.get_address(src_refined, dst_refined)
        read_out = self.memory_matrix.read(row_address, col_address, edge_emb)
        decoder_input = torch.cat((read_out, edge_emb, stream_length), dim=1)
        return self.weight_decode_net(decoder_input)


class ExtensionModelForExist(AbstractModel):
    """Existence-query model that borrows its components from a basic model.

    ``refine_net``, both embedding nets, ``memory_matrix`` and
    ``attention_matrix`` are taken from ``basic_model`` (the same objects,
    not copies); only ``decode_net`` is supplied separately.
    """

    def __init__(self, basic_model, decode_net):
        super().__init__()
        # Assignment order is kept as in the original in case the base class
        # is sensitive to it (TODO confirm).
        self.decode_net = decode_net
        self.refine_net = basic_model.refine_net
        self.embedding_net_1 = basic_model.embedding_net_1
        self.embedding_net_2 = basic_model.embedding_net_2

        self.memory_matrix = basic_model.memory_matrix
        self.attention_matrix = basic_model.attention_matrix

    def write(self, input_x, input_y):
        """Write the edge embeddings of ``input_x``, scaled by ``input_y``, to memory."""
        source_repr, dest_repr, edge_embedding = self.get_embedding(input_x)
        refined_pair = self.get_refined(source_repr, dest_repr)
        row_address, col_address = self.get_address(*refined_pair)
        # input_y is viewed as one column so it scales each embedding row.
        column_y = input_y.view(-1, 1)
        self.memory_matrix.write(row_address, col_address, edge_embedding * column_y)

    def clear(self):
        """Forward to ``memory_matrix.clear()``."""
        self.memory_matrix.clear()

    def normalize_attention_matrix(self):
        """Forward to ``attention_matrix.normalize()``."""
        self.attention_matrix.normalize()

    def get_embedding(self, input_x):
        """Split ``input_x`` in two along dim 1 and embed both halves.

        Returns the ``embedding_net_1`` output for each half plus the dim-1
        concatenation of their ``embedding_net_2`` outputs.
        """
        split_at = input_x.shape[1] // 2
        first_half, second_half = input_x[:, :split_at], input_x[:, split_at:]
        source_repr = self.embedding_net_1(first_half)
        dest_repr = self.embedding_net_1(second_half)
        edge_parts = (self.embedding_net_2(source_repr), self.embedding_net_2(dest_repr))
        return source_repr, dest_repr, torch.cat(edge_parts, dim=1)

    def get_refined(self, source_embedding, dest_embedding):
        """Return ``refine_net`` applied to the source, then the destination."""
        return self.refine_net(source_embedding), self.refine_net(dest_embedding)

    def get_address(self, source_refined, dest_refined):
        """Return ``attention_matrix`` applied to the source (row), then the destination (column)."""
        return self.attention_matrix(source_refined), self.attention_matrix(dest_refined)

    def query(self, input_x, stream_length):
        """Return the ``decode_net`` output for the edges in ``input_x``.

        The decoder input is the dim-1 concatenation of the memory read-out,
        the edge embedding and ``stream_length``.
        """
        source_repr, dest_repr, edge_embedding = self.get_embedding(input_x)
        refined_source, refined_dest = self.get_refined(source_repr, dest_repr)
        row_address, col_address = self.get_address(refined_source, refined_dest)
        memory_read = self.memory_matrix.read(row_address, col_address, edge_embedding)
        features = torch.cat((memory_read, edge_embedding, stream_length), dim=1)
        return self.decode_net(features)


class ExtensionModelForDegree(AbstractModel):
    """Degree-query model that reuses a basic model's networks with its own memory.

    ``refine_net``, ``attention_matrix`` and both embedding nets are taken
    from ``basic_model`` (the same objects, not copies). ``decode_net`` and
    ``memory_matrix`` are supplied separately.
    """

    def __init__(self, basic_model, decode_net, memory_matrix):
        super().__init__()
        # Assignment order is kept as in the original in case the base class
        # is sensitive to it (TODO confirm).
        self.decode_net = decode_net
        self.memory_matrix = memory_matrix
        self.refine_net = basic_model.refine_net
        self.attention_matrix = basic_model.attention_matrix
        self.embedding_net_1 = basic_model.embedding_net_1
        self.embedding_net_2 = basic_model.embedding_net_2

    def write(self, input_x, input_y):
        """Write the edge embeddings of ``input_x``, scaled by ``input_y``, to memory.

        ``input_y`` is viewed as a single column (``view(-1, 1)``) before it
        is multiplied into the edge embedding.
        """
        source, dest, edge_embedding = self.get_embedding(input_x)
        source_refined, dest_refined = self.get_refined(source, dest)
        row_address, col_address = self.get_address(source_refined, dest_refined)
        edge_embedding = edge_embedding * input_y.view(-1, 1)
        self.memory_matrix.write(row_address, col_address, edge_embedding)

    def clear(self):
        """Delegate to ``memory_matrix.clear()``."""
        self.memory_matrix.clear()

    def normalize_attention_matrix(self):
        """Delegate to ``attention_matrix.normalize()``."""
        self.attention_matrix.normalize()

    def get_embedding(self, input_x):
        """Return ``(source, dest, edge_embedding)`` for a batch of edges.

        ``input_x`` is split in two along dim 1; each half goes through
        ``embedding_net_1``. ``edge_embedding`` is the dim-1 concatenation of
        ``embedding_net_2`` applied to both results.
        """
        half = input_x.shape[1] // 2
        source = self.embedding_net_1(input_x[:, :half])
        dest = self.embedding_net_1(input_x[:, half:])
        edge_embedding = torch.cat(
            (self.embedding_net_2(source), self.embedding_net_2(dest)), dim=1
        )
        return source, dest, edge_embedding

    def get_node_embedding(self, node_input):
        """Return ``(embedding_net_1(node_input), embedding_net_2(<that result>))``."""
        node_representation = self.embedding_net_1(node_input)
        node_embedding = self.embedding_net_2(node_representation)
        return node_representation, node_embedding

    def get_refined(self, source_represetation, dest_represetation):
        """Apply ``refine_net`` to the source and then to the destination.

        The parameter names keep their original spelling so that existing
        keyword calls continue to work.
        """
        source = self.refine_net(source_represetation)
        dest = self.refine_net(dest_represetation)
        return source, dest

    def get_address(self, source_refined, dest_refined):
        """Map refined embeddings to ``(row_address, col_address)`` via ``attention_matrix``."""
        row_address = self.attention_matrix(source_refined)
        col_address = self.attention_matrix(dest_refined)
        return row_address, col_address

    def query(self, node_input, stream_length):
        """Query for degree: return the ``decode_net`` output for ``node_input``.

        Only a row address is computed (from the node itself); the memory is
        read with that address and the node embedding. ``stream_length`` gets
        an extra dim at position 1, is repeated ``memory_matrix.col_dim``
        times along it, and is appended to the read-out on the last dim.
        """
        node_representation, node_embedding = self.get_node_embedding(node_input)
        source_refined = self.refine_net(node_representation)
        row_address = self.attention_matrix(source_refined)
        basic_read_info = self.memory_matrix.read(row_address, node_embedding)
        # assumes stream_length is 2-D (batch, k) and the read-out is
        # (batch, col_dim, d), so the tiled tensor lines up on the first two
        # dims for the concatenation — TODO confirm against memory_matrix.read
        stream_length = stream_length.unsqueeze(1)
        tiled_length = stream_length.repeat(1, self.memory_matrix.col_dim, 1)
        decode_info = torch.cat((basic_read_info, tiled_length), dim=-1)
        return self.decode_net(decode_info)
