import ot
import torch

import numpy as np


def mini_batch(data, weights, batch_size):
    """
    Select a subset of samples uniformly at random without replacement.

        parameters :
        --------------------------------------------------------------
        data : np.ndarray or torch.Tensor, shape (n, d)
               Samples; rows are selected by the drawn indices.
        weights : np.ndarray or torch.Tensor, shape (n,)
                  Measure on the samples.
        batch_size : int
                     Minibatch size; must not exceed n (np.random.choice
                     raises ValueError otherwise).

        returns :
        --------------------------------------------------------------
        sub_data : rows of `data` at the drawn indices
        sub_weights : `weights` at the drawn indices, renormalised to sum to 1
        idx : np.ndarray of int, shape (batch_size,)
              Indices of the drawn rows.
    """
    idx = np.random.choice(np.shape(data)[0], batch_size, replace=False)
    picked = weights[idx]
    # `.sum()` exists on both numpy arrays and torch tensors, so either type of
    # `weights` works (torch.sum would reject a numpy array).
    sub_weights = picked / picked.sum()
    return data[idx], sub_weights, idx


def muot(xs, xt, a, b, bs_s, bs_t, num_iter, M, reg_m):
    '''
    Compute the incomplete unbalanced minibatch OT (MBOT) estimate.

    Each of the `num_iter` iterations draws one source and one target
    minibatch. Samples within a batch are drawn without replacement, and
    batches are drawn independently across iterations. The iteration then
    solves an unbalanced OT problem between the renormalised minibatch
    weights. The result is the average of these minibatch costs, each
    weighted by the original mass covered by the two batches.

    Parameters:
    -------------------------------------------
    xs : ndarray or tensor, shape (ns, d)
        Source data.
    xt : ndarray or tensor, shape (nt, d)
        Target data.
    a : ndarray or tensor, shape (ns,)
        Source measure.
    b : ndarray or tensor, shape (nt,)
        Target measure.
    bs_s : int
        Source minibatch size.
    bs_t : int
        Target minibatch size.
    num_iter : int
        Number of minibatch pairs to average over; must be >= 1.
    M : ndarray or tensor, shape (ns, nt)
        Cost matrix.
    reg_m : float
        Marginal relaxation (unbalanced) parameter passed to
        ot.unbalanced.mm_unbalanced2.

    Returns
    --------------------------------------
    Weighted average of the minibatch unbalanced OT costs,
    sum_k w_k * c_k / sum_k w_k with w_k = sum(a[batch_s]) * sum(b[batch_t]).

    Raises
    --------------------------------------
    ValueError
        If num_iter < 1. Without this check, the final normalisation
        would divide by zero.

    Ref: https://github.com/kilianFatras/unbiased_minibatch_sinkhorn_GAN/blob/main/sample_complexity/mini_batch_ot.py
    '''
    if num_iter < 1:
        raise ValueError(f"num_iter must be a positive integer, got {num_iter}")

    cost = 0
    norm_coeff = 0
    for _ in range(num_iter):
        # Only the renormalised weights and the drawn indices are needed; the
        # sampled data rows themselves are not used here.
        _, sub_weights_a, id_a = mini_batch(xs, a, bs_s)
        _, sub_weights_b, id_b = mini_batch(xt, b, bs_t)

        # Integer-array (advanced) indexing already returns a copy, so no
        # explicit clone is needed (and this also works for a numpy M).
        sub_M = M[id_a, :][:, id_b]
        cur_mbot = ot.unbalanced.mm_unbalanced2(sub_weights_a, sub_weights_b, sub_M, reg_m=reg_m)

        # Weight each minibatch cost by the original mass its two batches cover.
        full_weight = a[id_a].sum() * b[id_b].sum()
        cost += full_weight * cur_mbot
        norm_coeff += full_weight

    return cost / norm_coeff