import torch
import torch_geometric
from models.algorithm_reasoner import sinkhorn_normalize
class Batchle:
    # Bare attribute holder: it defines nothing itself, callers attach
    # whatever fields they need to an instance after construction.
    pass

def sinkhorn_normalize_dense(batch, y, temperature, num_nodes_per_batch, steps=10, add_noise=False):
    """Sinkhorn-normalise per-edge scores in log space on a dense adjacency.

    The edge scores are scattered into one dense matrix per graph, the
    diagonal and every position that carries no edge are masked with a large
    negative value, and rows and columns are then alternately normalised with
    ``logsumexp``.

    Args:
        batch: Object exposing ``edge_index`` and ``batch`` (node-to-graph
            assignment), both forwarded to
            ``torch_geometric.utils.to_dense_adj``.
        y: Edge scores, passed as ``edge_attr`` to ``to_dense_adj``
            (presumably one scalar per edge -- confirm against callers).
        temperature: Divisor applied to the (optionally perturbed) scores
            before normalisation.
        num_nodes_per_batch: Unused; kept so the signature stays compatible
            with existing callers.
        steps: Number of row-then-column normalisation rounds.
        add_noise: If true, Gumbel noise is added to the scores before the
            temperature scaling.

    Returns:
        The normalised log-scores as the dense tensor produced by
        ``to_dense_adj`` (graphs x nodes x nodes); call ``.exp()`` to get
        the normalised weights.
    """
    # Finite stand-in for -inf so that fully masked rows/columns cannot
    # produce NaNs inside logsumexp.
    neg_large = 1e5

    logits = torch_geometric.utils.to_dense_adj(batch.edge_index, edge_attr=y, batch=batch.batch)

    # to_dense_adj fills positions without an edge with 0, which would be a
    # perfectly valid logit.  Build an explicit edge-presence mask so those
    # positions are excluded from the normalisation like the diagonal is.
    has_edge = torch_geometric.utils.to_dense_adj(batch.edge_index, batch=batch.batch) > 0
    # Create the diagonal mask directly as bool on the data's device.
    diagonal = torch.eye(logits.shape[1], dtype=torch.bool, device=logits.device)
    logits = logits.masked_fill(diagonal | ~has_edge, -neg_large)

    if add_noise:
        # Gumbel(0, 1) samples; the small constants guard against log(0).
        gumbel = -torch.log(-torch.log(torch.rand_like(logits) + 1e-12) + 1e-12)
        logits = logits + gumbel

    logits = logits / temperature

    for _ in range(steps):
        # Row normalisation followed by column normalisation, in log space.
        logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
        logits = logits - torch.logsumexp(logits, dim=-2, keepdim=True)

    return logits

def sinkhorn_normalize_dense_logsoftmax(batch, y, temperature, num_nodes_per_batch, steps=10, add_noise=False):
    """Sinkhorn-normalise per-edge scores on a dense adjacency via log-softmax.

    The edge scores are scattered into one dense matrix per graph, the
    diagonal and every position that carries no edge are masked with a large
    negative value, and ``log_softmax`` is then applied alternately over the
    last and the second-to-last dimension.

    Args:
        batch: Object exposing ``edge_index`` and ``batch`` (node-to-graph
            assignment), both forwarded to
            ``torch_geometric.utils.to_dense_adj``.
        y: Edge scores, passed as ``edge_attr`` to ``to_dense_adj``
            (presumably one scalar per edge -- confirm against callers).
        temperature: Divisor applied to the (optionally perturbed) scores
            before normalisation.
        num_nodes_per_batch: Unused; kept so the signature stays compatible
            with existing callers.
        steps: Number of row-then-column normalisation rounds.
        add_noise: If true, Gumbel noise is added to the scores before the
            temperature scaling.

    Returns:
        The normalised log-scores as the dense tensor produced by
        ``to_dense_adj`` (graphs x nodes x nodes); call ``.exp()`` to get
        the normalised weights.
    """
    # Finite stand-in for -inf so that fully masked rows/columns cannot
    # produce NaNs inside the softmax.
    neg_large = 1e5

    logits = torch_geometric.utils.to_dense_adj(batch.edge_index, edge_attr=y, batch=batch.batch)

    # to_dense_adj fills positions without an edge with 0, which would be a
    # perfectly valid logit.  Build an explicit edge-presence mask so those
    # positions are excluded from the normalisation like the diagonal is.
    has_edge = torch_geometric.utils.to_dense_adj(batch.edge_index, batch=batch.batch) > 0
    # Create the diagonal mask directly as bool on the data's device.
    diagonal = torch.eye(logits.shape[1], dtype=torch.bool, device=logits.device)
    logits = logits.masked_fill(diagonal | ~has_edge, -neg_large)

    if add_noise:
        # Gumbel(0, 1) samples; the small constants guard against log(0).
        gumbel = -torch.log(-torch.log(torch.rand_like(logits) + 1e-12) + 1e-12)
        logits = logits + gumbel

    logits = logits / temperature

    for _ in range(steps):
        # Row normalisation followed by column normalisation, in log space.
        logits = torch.log_softmax(logits, dim=-1)
        logits = torch.log_softmax(logits, dim=-2)

    return logits

if __name__ == '__main__':
    # Interactive fuzz check: draw random edge scores for two fully connected
    # graphs, run both dense normalisers, assert that they agree and report
    # whether each graph's row-wise argmax has no repeated column.  The loop
    # never terminates on its own; it drops into the debugger every round.
    # (A comparison against the imported `sinkhorn_normalize` used to live
    # here as well but is disabled.)
    BS = 2
    NN = 6
    STEPS = 10

    batchle = Batchle()
    batchle.edge_index, _ = torch_geometric.utils.dense_to_sparse(torch.ones(2, NN, NN))
    batchle.num_nodes = NN * BS
    batchle.batch = torch.tensor([0] * NN + [1] * NN)

    while True:
        y = torch.randn(batchle.edge_index.shape[1])
        print(batchle.edge_index, y)

        y_dense_normed = sinkhorn_normalize_dense(batchle, y, .05, NN, steps=STEPS)
        y_dense_normed_lsm = sinkhorn_normalize_dense_logsoftmax(batchle, y, .05, NN, steps=STEPS)

        # Reshape once and reuse the per-graph views below.
        dense_view = y_dense_normed.view(BS, NN, NN)
        lsm_view = y_dense_normed_lsm.view(BS, NN, NN)

        print(dense_view.exp())
        print(lsm_view.exp())
        assert torch.allclose(lsm_view, dense_view, rtol=1e-2)

        assignment = dense_view.argmax(dim=-1)
        print(assignment)
        for row in assignment:
            # True when no column index is chosen twice within the graph.
            print((row.unique().equal(row.sort().values)))
        breakpoint()

    # Scratch note kept from the original (its meaning is not evident here):
    #   1-> 0 -> 2
    #   2-> 1
