import time

from torch import nn
import torch

from SourceCode.ModelModule.SparseSoftmax import Sparsemax


class BasicMemoryMatrixAndCM12(nn.Module):
    """Outer-product addressed memory with a dense read and two min-ratio ("CM") reads.

    ``memory_matrix`` has shape ``(row_dim, col_dim, edge_embedding_dim)``.
    ``write`` adds batch contributions to it; ``read`` returns the addressed
    content concatenated with one value from ``cm_read1`` and two from ``cm_read2``.
    """

    def __init__(self, row_dim, col_dim, edge_embedding_dim, device):
        super().__init__()
        self.row_dim = row_dim
        self.col_dim = col_dim
        self.edge_embedding_dim = edge_embedding_dim
        self.device = device
        self.memory_matrix = None
        # Not referenced by any method below; kept for callers that may use it.
        self.sparse_max = Sparsemax(row_dim * col_dim)
        self.clear()

    def clear(self):
        """Reset ``memory_matrix`` to zeros on ``self.device`` (no gradient tracking)."""
        with torch.autograd.no_grad():
            shape = (self.row_dim, self.col_dim, self.edge_embedding_dim)
            self.memory_matrix = torch.zeros(*shape, device=self.device, requires_grad=False)

    def write(self, row_addresses, col_addresses, edge_embeddings):
        """Add ``sum_b row_b (x) col_b (x) edge_embedding_b`` to the memory.

        Presumably ``row_addresses`` is (B, row_dim), ``col_addresses`` is
        (B, col_dim) and ``edge_embeddings`` is (B, edge_embedding_dim) — TODO
        confirm with callers. The update is out of place (no detach), so
        autograd can track it through later reads.
        """
        slot_weights = row_addresses.unsqueeze(2) @ col_addresses.unsqueeze(1)
        contribution = slot_weights[..., None] * edge_embeddings[:, None, None]
        self.memory_matrix = self.memory_matrix + contribution.sum(dim=0)

    def read(self, source, dest, edge_embedding):
        """Return ``[addressed content | cm_read1 | cm_read2]`` along the last axis.

        ``edge_embedding`` supplies the per-channel divisors for the CM reads.
        Entries not strictly above 1e-5 (including negatives) are floored to 1e-5.
        Entries with ``|value| < 1e-5`` also get a +10000 penalty on the
        numerator, so they do not win the minimum.
        """
        slot_content = self.basic_read_attention_sum(source, dest)
        tiny = 0.00001
        zeros = torch.zeros_like(edge_embedding)
        cm_embedding = torch.where(edge_embedding > tiny, edge_embedding, zeros + tiny)
        zero_add_vec = torch.where(edge_embedding.abs() < tiny, zeros + 10000, zeros)
        first = self.cm_read1(slot_content, cm_embedding, zero_add_vec)
        second = self.cm_read2(slot_content, cm_embedding, zero_add_vec)
        return torch.cat([slot_content, first, second], dim=-1)

    def basic_read_attention_sum(self, source, dest):
        """Weight every memory slot by ``source (x) dest`` and sum over the (row, col) grid."""
        weights = source.unsqueeze(2) @ dest.unsqueeze(1)
        weighted = weights.unsqueeze(3) * self.memory_matrix.unsqueeze(0)
        return weighted.sum(dim=(1, 2))

    def cm_read2(self, basic_read_matrix, cm_embedding, zero_add_vec):
        """Return ``[min, min((value - min + penalty) / divisor)]`` along the last axis.

        Differences within 1e-4 of zero (the minimum itself included) are
        replaced by 100000, so the ratio minimum skips them.
        """
        floor_value = basic_read_matrix.min(dim=-1, keepdim=True).values
        shifted = basic_read_matrix - floor_value
        near_zero = torch.abs(shifted) < 0.0001
        shifted = torch.where(near_zero, torch.zeros_like(shifted) + 100000, shifted)
        ratios = (shifted + zero_add_vec) / cm_embedding
        best_ratio = ratios.min(dim=-1, keepdim=True).values
        return torch.cat((floor_value, best_ratio), dim=-1)

    def cm_read1(self, basic_read_matrix, cm_embedding, zero_add_vec):
        """Return ``min((value + penalty) / divisor)`` along the last axis (keepdim).

        ``zero_add_vec`` adds a large value where the embedding is ~0 and 0
        elsewhere, so those channels keep their original value only when they
        are non-negligible.
        """
        ratios = (basic_read_matrix + zero_add_vec) / cm_embedding
        return ratios.min(dim=-1, keepdim=True).values


class AccelerateBasicMemoryMatrixAndCM12(nn.Module):
    """Outer-product addressed memory whose read is a single flattened matrix product.

    ``memory_matrix`` has shape ``(row_dim, col_dim, edge_embedding_dim)``.
    The attention read flattens the ``(row, col)`` address grid and multiplies
    it with the memory reshaped to ``(row_dim * col_dim, edge_embedding_dim)``.
    ``read`` appends one ``cm_read1`` value and two ``cm_read2`` values.
    """

    def __init__(self, row_dim, col_dim, edge_embedding_dim, device):
        super().__init__()
        self.row_dim = row_dim
        self.col_dim = col_dim
        self.edge_embedding_dim = edge_embedding_dim
        self.device = device
        self.memory_matrix = None
        self.sparse_max = Sparsemax(row_dim * col_dim)  # not referenced within this class
        self.clear()

    def clear(self):
        """Replace the memory with a fresh zero tensor that has no gradient history."""
        dims = (self.row_dim, self.col_dim, self.edge_embedding_dim)
        with torch.autograd.no_grad():
            self.memory_matrix = torch.zeros(dims, device=self.device, requires_grad=False)

    def write(self, row_addresses, col_addresses, edge_embeddings):
        """Accumulate ``sum_b row_b (x) col_b (x) embedding_b`` into the memory, out of place."""
        rows = row_addresses.unsqueeze(2)
        cols = col_addresses.unsqueeze(1)
        grid = torch.matmul(rows, cols)
        per_sample = grid.unsqueeze(-1) * edge_embeddings.unsqueeze(1).unsqueeze(1)
        self.memory_matrix = self.memory_matrix + per_sample.sum(0)

    @staticmethod
    def _cm_operands(edge_embedding):
        """Build the safe divisor and the near-zero penalty used by both CM reads."""
        eps = 0.00001
        zeros = torch.zeros_like(edge_embedding)
        # Values not strictly above eps (negatives included) are floored to eps.
        divisor = torch.where(edge_embedding > eps, edge_embedding, zeros + eps)
        # Channels with |value| < eps get +10000 so they cannot win the minimum.
        penalty = torch.where(torch.abs(edge_embedding) < eps, zeros + 10000, zeros)
        return divisor, penalty

    def read(self, source, dest, edge_embedding):
        """Return ``[slot content | cm_read1 | cm_read2]`` concatenated on the last axis."""
        content = self.basic_read_attention_sum(source, dest)
        divisor, penalty = self._cm_operands(edge_embedding)
        pieces = [
            content,
            self.cm_read1(content, divisor, penalty),
            self.cm_read2(content, divisor, penalty),
        ]
        return torch.cat(pieces, dim=-1)

    def basic_read_attention_sum(self, source, dest):
        """Multiply the flattened ``source (x) dest`` address by the flattened memory."""
        batch = source.shape[0]
        flat_address = (source.unsqueeze(2) @ dest.unsqueeze(1)).view(batch, -1)
        flat_memory = self.memory_matrix.view(-1, self.edge_embedding_dim)
        return flat_address.mm(flat_memory)

    def cm_read2(self, basic_read_matrix, cm_embedding, zero_add_vec):
        """Return ``[min, min((value - min + penalty) / divisor)]`` on the last axis.

        Gaps within 1e-4 of zero are replaced by 100000 before dividing.
        """
        lowest = basic_read_matrix.min(dim=-1, keepdim=True)[0]
        gap = basic_read_matrix - lowest
        gap = torch.where(torch.abs(gap) < 0.0001, torch.zeros_like(gap) + 100000, gap)
        lowest_ratio = ((gap + zero_add_vec) / cm_embedding).min(dim=-1, keepdim=True)[0]
        return torch.cat((lowest, lowest_ratio), dim=-1)

    def cm_read1(self, basic_read_matrix, cm_embedding, zero_add_vec):
        """Return ``min((value + penalty) / divisor)`` along the last axis (keepdim)."""
        return ((basic_read_matrix + zero_add_vec) / cm_embedding).min(dim=-1, keepdim=True)[0]



class SparseBasicMemoryMatrixAndCM12(nn.Module):
    """Outer-product addressed memory whose attention read uses a sparse-CSR address matrix.

    ``memory_matrix`` has shape ``(row_dim, col_dim, edge_embedding_dim)``.
    ``read`` returns the addressed content concatenated with one value from
    ``cm_read1`` and two from ``cm_read2``.
    """

    def __init__(self, row_dim, col_dim, edge_embedding_dim, device):
        super().__init__()
        self.row_dim = row_dim
        self.col_dim = col_dim
        self.edge_embedding_dim = edge_embedding_dim
        self.device = device
        self.memory_matrix = None
        # Not referenced by any method below; kept for callers that may use it.
        self.sparse_max = Sparsemax(row_dim * col_dim)
        self.clear()

    def clear(self):
        """Reset ``memory_matrix`` to zeros on ``self.device`` (no gradient tracking)."""
        with torch.autograd.no_grad():
            self.memory_matrix = torch.zeros(self.row_dim, self.col_dim, self.edge_embedding_dim, device=self.device,
                                             requires_grad=False)

    def write(self, row_addresses, col_addresses, edge_embeddings):
        """Add ``sum_b row_b (x) col_b (x) edge_embedding_b`` to the memory (out of place).

        Presumably the inputs are (B, row_dim), (B, col_dim) and
        (B, edge_embedding_dim) — TODO confirm with callers.
        """
        batch_size = row_addresses.shape[0]
        row_num = row_addresses.shape[1]
        col_num = col_addresses.shape[1]
        address = row_addresses.unsqueeze(2).matmul(col_addresses.unsqueeze(1))
        # (B, R*C, 1) broadcasts against (B, 1, E); no E-fold copy of the address is needed.
        address = address.view(batch_size, row_num * col_num, 1)
        write_matrix = address.mul(edge_embeddings.unsqueeze(1)).sum(dim=0)
        write_matrix = write_matrix.view(row_num, col_num, self.edge_embedding_dim)
        self.memory_matrix = self.memory_matrix + write_matrix

    def read(self, source, dest, edge_embedding):
        """Return ``[addressed content | cm_read1 | cm_read2]`` along the last axis.

        Entries of ``edge_embedding`` not strictly above 1e-5 are floored to 1e-5
        for use as divisors. Entries with ``|value| < 1e-5`` also get +10000
        on the numerator, so they do not win the minimum.
        """
        basic_read_matrix = self.basic_read_attention_sum(source, dest)

        cm_embedding = torch.where(edge_embedding > 0.00001, edge_embedding, torch.zeros_like(edge_embedding) + 0.00001)
        zero_add_vec = torch.where(torch.abs(edge_embedding) < 0.00001, torch.zeros_like(edge_embedding) + 10000,
                                   torch.zeros_like(edge_embedding))
        min_cm_read1 = self.cm_read1(basic_read_matrix, cm_embedding, zero_add_vec)
        min_cm_read2 = self.cm_read2(basic_read_matrix, cm_embedding, zero_add_vec)
        return torch.cat((basic_read_matrix, min_cm_read1, min_cm_read2), dim=-1)

    def basic_read_attention_sum(self, source, dest):
        """Multiply the flattened ``source (x) dest`` address (as CSR) by the flattened memory."""
        addresses = source.unsqueeze(2).matmul(dest.unsqueeze(1))
        addresses = addresses.view(source.shape[0], -1)
        # NOTE(review): CSR is presumably chosen because addresses are expected
        # to be mostly zero — confirm; a dense .mm is faster otherwise.
        addresses = addresses.to_sparse_csr()
        return addresses.matmul(self.memory_matrix.view(-1, self.edge_embedding_dim))

    def cm_read2(self, basic_read_matrix, cm_embedding, zero_add_vec):
        """Return ``[min, min((value - min + penalty) / divisor)]`` along the last axis.

        Differences within 1e-4 of zero (the minimum itself included) become
        100000, so the ratio minimum skips them.
        """
        min_info, _ = basic_read_matrix.min(dim=-1, keepdim=True)
        basic_read_matrix_minusmin = basic_read_matrix - min_info
        basic_read_matrix_minusmin = torch.where(torch.abs(basic_read_matrix_minusmin) < 0.0001,
                                                 torch.zeros_like(basic_read_matrix_minusmin) + 100000,
                                                 basic_read_matrix_minusmin)
        basic_read_matrix_minusmin = basic_read_matrix_minusmin + zero_add_vec
        cm_read = basic_read_matrix_minusmin.div(cm_embedding)
        min_cm_read, _ = cm_read.min(dim=-1, keepdim=True)
        return torch.cat((min_info, min_cm_read), dim=-1)

    def cm_read1(self, basic_read_matrix, cm_embedding, zero_add_vec):
        """Return ``min((value + penalty) / divisor)`` along the last axis (keepdim)."""
        cm_basic_read_matrix = basic_read_matrix + zero_add_vec
        cm_read = cm_basic_read_matrix.div(cm_embedding)
        min_cm_read, _ = cm_read.min(dim=-1, keepdim=True)
        return min_cm_read




class BasicMemoryMatrixAndCM12ForDegree(nn.Module):
    """Outer-product memory read row-wise, giving one feature vector per column slot.

    ``read`` gathers, per sample, the ``row_address``-weighted sum of memory
    rows and keeps the first ``node_embedding.shape[1]`` channels. It then
    appends per-slot min-ratio ("CM") statistics against ``node_embedding``.
    """

    def __init__(self, row_dim, col_dim, edge_embedding_dim, device):
        super().__init__()
        self.row_dim = row_dim
        self.col_dim = col_dim
        self.edge_embedding_dim = edge_embedding_dim
        self.device = device
        self.memory_matrix = None
        # Not referenced by any method below; kept for callers that may use it.
        self.sparse_max = Sparsemax(row_dim * col_dim)

        # NOTE(review): read() returns 2 * L + 3 features per column slot
        # (L = node_embedding.shape[1]). That equals this value only when
        # L == edge_embedding_dim == 1. Confirm this is intended.
        self.read_dim = edge_embedding_dim + 3 + 1
        self.clear()

    def clear(self):
        """Reset ``memory_matrix`` to zeros on ``self.device`` (no gradient tracking)."""
        with torch.autograd.no_grad():
            self.memory_matrix = torch.zeros(self.row_dim, self.col_dim, self.edge_embedding_dim, device=self.device,
                                             requires_grad=False)

    def write(self, row_addresses, col_addresses, edge_embeddings):
        """Add ``sum_b row_b (x) col_b (x) edge_embedding_b`` to the memory (out of place).

        Presumably the inputs are (B, row_dim), (B, col_dim) and
        (B, edge_embedding_dim) — TODO confirm with callers.
        """
        address = row_addresses.unsqueeze(2).matmul(col_addresses.unsqueeze(1))
        write_matrix = address.unsqueeze(-1).mul(edge_embeddings.unsqueeze(1).unsqueeze(1))
        write_matrix = write_matrix.sum(dim=0)
        self.memory_matrix = self.memory_matrix + write_matrix

    def read(self, row_address, node_embedding):
        """Return per-column-slot features of shape (B, col_dim, 2 * L + 3).

        ``L = node_embedding.shape[1]``. Each slot holds
        ``[content (L) | cm_read1 (1) | cm_read2 (2) | cm_embedding (L)]``.
        Assumes ``row_address`` is (B, row_dim) and ``L <= edge_embedding_dim`` — TODO confirm.
        """
        node_embedding_length = node_embedding.shape[1]
        # Slice the channels before broadcasting, so the (B, R, C, .) intermediate
        # only holds the L channels that are actually returned.
        memory = self.memory_matrix[:, :, :node_embedding_length]
        row_weights = row_address.reshape(row_address.shape[0], row_address.shape[1], 1, 1)
        batch_info_matrix = row_weights.mul(memory.unsqueeze(0)).sum(dim=1)
        # Divisor: values not strictly above 1e-5 (negatives included) are floored to 1e-5.
        cm_embedding = torch.where(node_embedding > 0.00001, node_embedding,
                                   torch.zeros_like(node_embedding) + 0.00001)
        # Channels with |value| < 1e-4 get +100000 on the numerator, so they do not win the min.
        zero_add_vec = torch.where(torch.abs(node_embedding) < 0.0001, torch.zeros_like(node_embedding) + 100000,
                                   torch.zeros_like(node_embedding))
        return self.fine_grit_read(batch_info_matrix, node_embedding, cm_embedding, zero_add_vec)

    def fine_grit_read(self, batch_info_matrix, node_embedding, cm_embedding, zero_add_vec):
        """Compute CM statistics per column slot and concatenate them with the slot content.

        ``node_embedding`` is not used; it stays in the signature for compatibility.
        """
        # expand() yields a broadcast view instead of repeat()'s copy; values are identical.
        cm_embedding_repeat = cm_embedding.unsqueeze(1).expand(-1, self.col_dim, -1)
        zero_add_vec_repeat = zero_add_vec.unsqueeze(1).expand(-1, self.col_dim, -1)
        cm_info1 = self.cm_read1(batch_info_matrix, cm_embedding_repeat, zero_add_vec_repeat)
        cm_info2 = self.cm_read2(batch_info_matrix, cm_embedding_repeat, zero_add_vec_repeat)
        return torch.cat((batch_info_matrix, cm_info1, cm_info2, cm_embedding_repeat), dim=-1)

    def cm_read2(self, basic_read_matrix, cm_embedding, zero_add_vec):
        """Return ``[min, min((value - min + penalty) / divisor)]`` along the last axis.

        Differences within 1e-4 of zero (the minimum itself included) become
        100000, so the ratio minimum skips them.
        """
        min_info, _ = basic_read_matrix.min(dim=-1, keepdim=True)
        basic_read_matrix_minusmin = basic_read_matrix - min_info
        basic_read_matrix_minusmin = torch.where(torch.abs(basic_read_matrix_minusmin) < 0.0001,
                                                 torch.zeros_like(basic_read_matrix_minusmin) + 100000,
                                                 basic_read_matrix_minusmin)
        basic_read_matrix_minusmin = basic_read_matrix_minusmin + zero_add_vec
        cm_read = basic_read_matrix_minusmin.div(cm_embedding)
        min_cm_read, _ = cm_read.min(dim=-1, keepdim=True)
        return torch.cat((min_info, min_cm_read), dim=-1)

    def cm_read1(self, basic_read_matrix, cm_embedding, zero_add_vec):
        """Return ``min((value + penalty) / divisor)`` along the last axis (keepdim)."""
        cm_basic_read_matrix = basic_read_matrix + zero_add_vec
        cm_read = cm_basic_read_matrix.div(cm_embedding)
        min_cm_read, _ = cm_read.min(dim=-1, keepdim=True)
        return min_cm_read



class SparseBasicMemoryMatrixAndCM12ForDegree(nn.Module):
    """Outer-product memory read row-wise via one matmul, one feature vector per column slot.

    Despite the class name, ``read`` uses a dense ``torch.matmul`` against the
    memory flattened to ``(row_dim, col_dim * edge_embedding_dim)``.
    """

    def __init__(self, row_dim, col_dim, edge_embedding_dim, device):
        super().__init__()
        self.row_dim = row_dim
        self.col_dim = col_dim
        self.edge_embedding_dim = edge_embedding_dim
        self.device = device
        self.memory_matrix = None
        # Not referenced by any method below; kept for callers that may use it.
        self.sparse_max = Sparsemax(row_dim * col_dim)
        # NOTE(review): read() returns 2 * L + 3 features per column slot
        # (L = node_embedding.shape[1]). That equals this value only when
        # L == edge_embedding_dim == 1. Confirm this is intended.
        self.read_dim = edge_embedding_dim + 3 + 1
        self.clear()

    def clear(self):
        """Reset ``memory_matrix`` to zeros on ``self.device`` (no gradient tracking)."""
        with torch.autograd.no_grad():
            self.memory_matrix = torch.zeros(self.row_dim, self.col_dim, self.edge_embedding_dim, device=self.device,
                                             requires_grad=False)

    def write(self, row_addresses, col_addresses, edge_embeddings):
        """Add ``sum_b row_b (x) col_b (x) edge_embedding_b`` to the memory (out of place).

        Presumably the inputs are (B, row_dim), (B, col_dim) and
        (B, edge_embedding_dim) — TODO confirm with callers.
        """
        address = row_addresses.unsqueeze(2).matmul(col_addresses.unsqueeze(1))
        write_matrix = address.unsqueeze(-1).mul(edge_embeddings.unsqueeze(1).unsqueeze(1))
        write_matrix = write_matrix.sum(dim=0)
        self.memory_matrix = self.memory_matrix + write_matrix

    def read(self, row_address, node_embedding):
        """Return per-column-slot features of shape (B, col_dim, 2 * L + 3).

        ``L = node_embedding.shape[1]``. Each slot holds
        ``[content (L) | cm_read1 (1) | cm_read2 (2) | cm_embedding (L)]``.
        Assumes ``row_address`` is (B, row_dim) and ``L <= edge_embedding_dim`` — TODO confirm.
        """
        batch_size = row_address.shape[0]
        node_embedding_length = node_embedding.shape[1]
        flat_memory = self.memory_matrix.view(self.row_dim, -1)
        batch_info_matrix = torch.matmul(row_address, flat_memory).view(batch_size, self.col_dim, -1)
        batch_info_matrix = batch_info_matrix[:, :, :node_embedding_length]

        # Divisor: values not strictly above 1e-5 (negatives included) are floored to 1e-5.
        cm_embedding = torch.where(node_embedding > 0.00001, node_embedding, torch.zeros_like(node_embedding) + 0.00001)
        # Channels with |value| < 1e-4 get +100000 on the numerator, so they do not win the min.
        zero_add_vec = torch.where(torch.abs(node_embedding) < 0.0001, torch.zeros_like(node_embedding) + 100000,
                                   torch.zeros_like(node_embedding))
        return self.fine_grit_read(batch_info_matrix, node_embedding, cm_embedding, zero_add_vec)

    def fine_grit_read(self, batch_info_matrix, node_embedding, cm_embedding, zero_add_vec):
        """Compute CM statistics per column slot and concatenate them with the slot content.

        ``node_embedding`` is not used; it stays in the signature for compatibility.
        """
        # expand() yields a broadcast view instead of repeat()'s copy; values are identical.
        cm_embedding_repeat = cm_embedding.unsqueeze(1).expand(-1, self.col_dim, -1)
        zero_add_vec_repeat = zero_add_vec.unsqueeze(1).expand(-1, self.col_dim, -1)
        cm_info1 = self.cm_read1(batch_info_matrix, cm_embedding_repeat, zero_add_vec_repeat)
        cm_info2 = self.cm_read2(batch_info_matrix, cm_embedding_repeat, zero_add_vec_repeat)
        return torch.cat((batch_info_matrix, cm_info1, cm_info2, cm_embedding_repeat), dim=-1)

    def cm_read2(self, basic_read_matrix, cm_embedding, zero_add_vec):
        """Return ``[min, min((value - min + penalty) / divisor)]`` along the last axis.

        Differences within 1e-4 of zero (the minimum itself included) become
        100000, so the ratio minimum skips them.
        """
        min_info, _ = basic_read_matrix.min(dim=-1, keepdim=True)
        basic_read_matrix_minusmin = basic_read_matrix - min_info
        basic_read_matrix_minusmin = torch.where(torch.abs(basic_read_matrix_minusmin) < 0.0001,
                                                 torch.zeros_like(basic_read_matrix_minusmin) + 100000,
                                                 basic_read_matrix_minusmin)
        basic_read_matrix_minusmin = basic_read_matrix_minusmin + zero_add_vec
        cm_read = basic_read_matrix_minusmin.div(cm_embedding)
        min_cm_read, _ = cm_read.min(dim=-1, keepdim=True)
        return torch.cat((min_info, min_cm_read), dim=-1)

    def cm_read1(self, basic_read_matrix, cm_embedding, zero_add_vec):
        """Return ``min((value + penalty) / divisor)`` along the last axis (keepdim)."""
        cm_basic_read_matrix = basic_read_matrix + zero_add_vec
        cm_read = cm_basic_read_matrix.div(cm_embedding)
        min_cm_read, _ = cm_read.min(dim=-1, keepdim=True)
        return min_cm_read


class AccelerateBasicMemoryMatrixAndCM12ForDegree(nn.Module):
    """Outer-product memory read row-wise via one matmul, one feature vector per column slot.

    ``read`` multiplies ``row_address`` with the memory flattened to
    ``(row_dim, col_dim * edge_embedding_dim)`` and keeps the first
    ``node_embedding.shape[1]`` channels per slot. It then appends per-slot
    min-ratio ("CM") statistics.
    """

    def __init__(self, row_dim, col_dim, edge_embedding_dim, device):
        super().__init__()
        self.row_dim = row_dim
        self.col_dim = col_dim
        self.edge_embedding_dim = edge_embedding_dim
        self.device = device
        self.memory_matrix = None
        self.sparse_max = Sparsemax(row_dim * col_dim)  # not referenced within this class
        # NOTE(review): read() yields 2 * L + 3 features per slot (L = node_embedding width),
        # which equals this value only when L == edge_embedding_dim == 1. Confirm.
        self.read_dim = edge_embedding_dim + 3 + 1
        self.clear()

    def clear(self):
        """Replace the memory with a fresh zero tensor that has no gradient history."""
        dims = (self.row_dim, self.col_dim, self.edge_embedding_dim)
        with torch.autograd.no_grad():
            self.memory_matrix = torch.zeros(dims, device=self.device, requires_grad=False)

    def write(self, row_addresses, col_addresses, edge_embeddings):
        """Accumulate ``sum_b row_b (x) col_b (x) embedding_b`` into the memory, out of place."""
        grid = torch.matmul(row_addresses.unsqueeze(2), col_addresses.unsqueeze(1))
        per_sample = grid[..., None] * edge_embeddings[:, None, None]
        self.memory_matrix = self.memory_matrix + per_sample.sum(0)

    def read(self, row_address, node_embedding):
        """Return ``[content | cm_read1 | cm_read2 | divisor]`` for every column slot."""
        count = row_address.shape[0]
        width = node_embedding.shape[1]
        rows_flat = self.memory_matrix.view(self.row_dim, -1)
        gathered = torch.matmul(row_address, rows_flat).view(count, self.col_dim, -1)
        gathered = gathered[:, :, :width]
        blank = torch.zeros_like(node_embedding)
        # Divisor: values not strictly above 1e-5 (negatives included) become 1e-5.
        divisor = torch.where(node_embedding > 0.00001, node_embedding, blank + 0.00001)
        # Channels with |value| < 1e-4 get +100000 so they do not win the minimum.
        penalty = torch.where(node_embedding.abs() < 0.0001, blank + 100000, blank)
        return self.fine_grit_read(gathered, node_embedding, divisor, penalty)

    def fine_grit_read(self, batch_info_matrix, node_embedding, cm_embedding, zero_add_vec):
        """Tile the divisor/penalty across column slots and concatenate the CM statistics.

        ``node_embedding`` is not used here.
        """
        tile = (1, self.col_dim, 1)
        divisor = cm_embedding.unsqueeze(1).repeat(*tile)
        penalty = zero_add_vec.unsqueeze(1).repeat(*tile)
        features = (
            batch_info_matrix,
            self.cm_read1(batch_info_matrix, divisor, penalty),
            self.cm_read2(batch_info_matrix, divisor, penalty),
            divisor,
        )
        return torch.cat(features, dim=-1)

    def cm_read2(self, basic_read_matrix, cm_embedding, zero_add_vec):
        """Return ``[min, min((value - min + penalty) / divisor)]`` on the last axis.

        Gaps within 1e-4 of zero are replaced by 100000 before dividing.
        """
        lowest = basic_read_matrix.min(dim=-1, keepdim=True).values
        gap = basic_read_matrix - lowest
        gap = torch.where(torch.abs(gap) < 0.0001, torch.zeros_like(gap) + 100000, gap)
        lowest_ratio = ((gap + zero_add_vec) / cm_embedding).min(dim=-1, keepdim=True).values
        return torch.cat((lowest, lowest_ratio), dim=-1)

    def cm_read1(self, basic_read_matrix, cm_embedding, zero_add_vec):
        """Return ``min((value + penalty) / divisor)`` along the last axis (keepdim)."""
        ratios = (basic_read_matrix + zero_add_vec) / cm_embedding
        return ratios.min(dim=-1, keepdim=True).values

