import torch
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import graph_lib
from model import utils as mutils
import ot
import time


def loss_fn(model, databatch, graph, cond=None, t=None, perturbed_batch=None):
    """
    Return the mean pairwise Hamming distance between a sampled source batch
    and ``databatch``.

    A source batch with the same shape as ``databatch`` is drawn via
    ``graph.sample_limit`` and cast to ``databatch``'s dtype and device.
    Every source row is compared against every data row. The number of
    positions at which the two rows differ gives a [B, B] cost matrix, and
    its mean is returned. Rows are paired independently (all-pairs); no
    transport plan is computed.

    Args:
        model: Unused by this function; kept so existing callers keep working.
        databatch: Integer tensor of shape [B, L].
        graph: Object exposing ``sample_limit(shape)``, expected to return a
            tensor of that shape. (Presumably a sample from the limiting /
            source distribution — confirm in ``graph_lib``.)
        cond: Unused.
        t: Unused.
        perturbed_batch: Unused.

    Returns:
        A 0-d float32 ``numpy.ndarray`` holding the mean distance, or ``None``
        if computing the distance matrix raised an exception (the error is
        printed).

    Raises:
        ValueError: If ``databatch`` is not 2-D.

    Note:
        The comparison materialises a [B, B, L] boolean tensor, so memory
        grows quadratically with the batch size.
    """
    if databatch.dim() != 2:
        raise ValueError(
            f"databatch must be 2-D [B, L], got shape {tuple(databatch.shape)}"
        )
    # .to(tensor) matches both dtype and device of databatch.
    sourcebatch = graph.sample_limit(databatch.shape).to(databatch)

    try:
        # Broadcast [B, 1, L] against [1, B, L]: entry [i, j, l] is True where
        # source row i and data row j disagree at position l.
        mismatches = sourcebatch.unsqueeze(1) != databatch.unsqueeze(0)
        cost = mismatches.sum(dim=2).float()  # [B, B] Hamming distances
        return cost.mean().cpu().numpy()
    except Exception as err:
        # Best effort: report and return None instead of raising. Catching
        # Exception (not a bare except) lets KeyboardInterrupt/SystemExit
        # propagate.
        print(
            f"Error computing source/data Hamming distances ({err!r}); "
            "returning None for this batch"
        )
        return None
