import torch
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import graph_lib
from model import utils as mutils
import ot
import time


def _pairwise_hamming(source, data):
    """Return the [len(source), len(data)] float matrix of per-pair token mismatch counts."""
    # Broadcast [Bs, 1, L] against [1, Bd, L] and count differing positions.
    return (source.unsqueeze(1) != data.unsqueeze(0)).sum(dim=2).float()


def loss_fn(model, databatch, graph, cond=None, t=None, perturbed_batch=None):
    """
    Batch shape: [B, L] int. D given from graph

    Couple a freshly drawn source batch with ``databatch`` through an
    entropic optimal-transport plan and report the mean Hamming distance
    between the resampled batches.

    Steps:
      1. Draw ``sourcebatch = graph.sample_limit(databatch.shape)`` and move
         it onto ``databatch``'s dtype/device with ``.to(databatch)``.
      2. Build the B x B Hamming cost matrix and scale it by its maximum.
      3. Solve ``ot.sinkhorn`` with uniform marginals and regularisation 0.01.
      4. Sample B (source, data) index pairs from the flattened plan, with
         replacement, and re-index both batches with them.
      5. Return the mean of the pairwise Hamming matrix of the re-indexed
         batches.

    Args:
        model, cond, t, perturbed_batch: accepted for interface
            compatibility; not used by the code in this function.
        databatch: 2-D integer tensor of shape [B, L].
        graph: object exposing ``sample_limit(shape)`` that returns a tensor.

    Returns:
        A 0-d numpy array holding the mean distance, or ``None`` when any
        ordinary exception occurs during the computation (a message is
        printed in that case).

    NOTE(review): the returned mean is taken over all B x B pairs of the
    resampled batches, not only over the B coupled pairs -- confirm that
    this is intended.
    NOTE(review): the fallback message mentions independent sampling, but
    nothing is computed after it here; the function just returns ``None``.
    """
    B, _ = databatch.shape
    sourcebatch = graph.sample_limit(databatch.shape).to(databatch)

    try:
        cost = _pairwise_hamming(sourcebatch, databatch)
        # Mismatch counts are whole numbers, so the max is either 0 or >= 1.
        # Clamping therefore leaves the normal case untouched and only
        # prevents a 0/0 -> NaN cost matrix when every pair is identical.
        cost = cost / cost.max().clamp(min=1.0)

        # Uniform marginals over the B source rows and the B data rows.
        a = np.ones(B) / B
        b = np.ones(B) / B

        plan = torch.from_numpy(ot.sinkhorn(a, b, cost.cpu().numpy(), 0.01))

        # Sample B cells of the plan; each flat index encodes (row, column).
        flat_idx = torch.multinomial(plan.flatten(), B, replacement=True)
        n_cols = plan.shape[1]
        source_idx = flat_idx // n_cols
        data_idx = flat_idx % n_cols

        sourcebatch = sourcebatch[source_idx]
        databatch = databatch[data_idx]

        return _pairwise_hamming(sourcebatch, databatch).mean().cpu().numpy()
    except Exception:
        # Best-effort: ordinary failures are reported and signalled with None.
        # KeyboardInterrupt / SystemExit are deliberately NOT caught here.
        print('Error calculating optimal transport, continuing with independent sampling this batch')
        return None
