import torch
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import graph_lib
from model import utils as mutils
import ot
import time


def loss_fn(model, databatch, graph, cond=None, t=None, perturbed_batch=None):
    """
    Batch shape: [B, L] int. D given from graph
    """
    B, L = databatch.shape
    sourcebatch = graph.sample_limit(databatch.shape).to(databatch)

    #print('B:', B)


    try:
        sourcebatch_embedd = model.module.vocab_embed(sourcebatch).reshape(B, -1).detach()
        databatch_embedd = model.module.vocab_embed(databatch).reshape(B, -1).detach()
        sourcebatch_norm = (sourcebatch_embedd ** 2).sum(dim=1).reshape(-1, 1)  # Shape (B, 1)
        databatch_norm = (databatch_embedd ** 2).sum(dim=1).reshape(1, -1)  # Shape (1, B)
        M = sourcebatch_norm + databatch_norm - 2 * sourcebatch_embedd @ databatch_embedd.T 
        M = torch.clamp(M, min=0)
        M = M / M.max()

        a = np.ones(B) / B
        b = np.ones(B) / B
        
        sinkhorn_plan = ot.sinkhorn(a, b, M.cpu().numpy(), 0.0104)
        sinkhorn_plan = torch.from_numpy(sinkhorn_plan)

        flattened_plan = sinkhorn_plan.flatten()
        num_samples = B
        indices = torch.multinomial(flattened_plan, num_samples, replacement=True)

        # print('max: ', flattened_plan.max())
        # print('min: ', flattened_plan.min())
        # print('avg: ', flattened_plan.mean())
        # print('sum: ', flattened_plan.sum())

        sourcebatch_indices = indices // sinkhorn_plan.shape[1]
        databatch_indices = indices % sinkhorn_plan.shape[1]

        sourcebatch = sourcebatch[sourcebatch_indices]
        databatch = databatch[databatch_indices]
    except:
        print('Error calculating optimal transport, continuing with independent sampling this batch')
