import numpy as np

class UOT_Example:
    '''
    Random unbalanced optimal transport (UOT) test instance.

    Attributes
    ----------
    n, m : problem sizes; transport plans have shape (n, m).
    M : (n, m) array whose columns are Dirichlet(1, ..., 1) draws
        (so every column sums to 1).
    cost : (n, m) array, ``-log(M)``.
    eta : length-n Dirichlet draw, compared with the row sums of a plan.
    theta : length-m Dirichlet draw, compared with the column sums of a plan.
    er, ec : weights of the row / column KL penalties.
    plan : plan stored by `set_plan`; not set by the constructor.
    '''

    @staticmethod
    def KL(a, b):
        '''
        Generalized KL divergence ``sum(a*log(a/b)) - sum(a) + sum(b)``.

        Zero entries of `a` follow the convention ``0 * log(0 / b) = 0``, so
        they contribute only through ``sum(b)``. Without this, a plan with an
        empty row or column made the whole result NaN (``0 * -inf``).
        '''
        a = np.asarray(a, dtype=np.float64)
        b = np.asarray(b, dtype=np.float64)
        a_full, b_full = np.broadcast_arrays(a, b)
        # Only skip exact zeros: negative entries still yield NaN through log.
        nz = a_full != 0
        cross = np.sum(a_full[nz] * (np.log(a_full[nz]) - np.log(b_full[nz])))
        return cross - np.sum(a) + np.sum(b)

    def __init__(self, n, m, er=1., ec=1.):
        self.n = n
        self.m = m
        # M is drawn before the marginals; keep this order so seeded runs reproduce.
        self.rand_M()
        self.rand_marginal()
        self.er = er
        self.ec = ec

    def set_er(self, er):
        '''Set the weight of the row-marginal KL penalty.'''
        self.er = er

    def set_ec(self, ec):
        '''Set the weight of the column-marginal KL penalty.'''
        self.ec = ec

    def rand_M(self):
        '''Draw a new random M (columns ~ Dirichlet(1)), set cost = -log(M), and print cost.'''
        self.M = np.random.dirichlet([1,] * self.n, self.m).T
        self.cost = -np.log(self.M)
        print(self.cost)

    def rand_marginal(self):
        '''Draw new random marginals theta (length m) and eta (length n), then print both.'''
        self.theta = np.random.dirichlet([1,] * self.m)
        self.eta = np.random.dirichlet([1,] * self.n)
        print(self.theta, self.eta)

    def set_plan(self, plan):
        '''Store `plan`; `value()` uses it when called without an argument.'''
        self.plan = plan

    def value(self, plan=None):
        '''
        UOT objective of `plan` (defaults to the plan stored by `set_plan`):

        ``sum(plan * cost) + er * KL(row sums, eta) + ec * KL(column sums, theta)``.
        '''
        if plan is None:
            plan = self.plan
        return (np.sum(plan * self.cost) +
                self.er * self.KL(np.sum(plan, axis=1), self.eta) +
                self.ec * self.KL(np.sum(plan, axis=0), self.theta))

    def grad(self, plan):
        '''
        Gradient of `value` with respect to `plan`, shape (n, m):

        ``cost[i, j] + er * log(row_sum[i] / eta[i]) + ec * log(col_sum[j] / theta[j])``.
        '''
        # Row term is reshaped to (n, 1) so it broadcasts across columns.
        row_term = (np.log(np.sum(plan, axis=1)) - np.log(self.eta)).reshape(-1, 1)
        col_term = np.log(np.sum(plan, axis=0)) - np.log(self.theta)
        return self.cost + self.er * row_term + self.ec * col_term
    
    
    
def pattern_well_generated(pattern: np.ndarray) -> bool:
    '''
    Check that every row and every column of `pattern` has a nonzero sum.

    Parameters
    ----------
    pattern: np.ndarray of shape (n,m), with entries 0 or 1
             1 for position with `cost + er * R + ec * C = 0`
             0 for position with `cost + er * R + ec * C > 0`
    Return
    ------
    True if no row sum and no column sum equals zero, otherwise False.
    '''
    row_totals = np.sum(pattern, axis=1)
    col_totals = np.sum(pattern, axis=0)
    rows_ok = np.all(row_totals != 0)
    cols_ok = np.all(col_totals != 0)
    return bool(rows_ok and cols_ok)
    
    
def pattern_decompose(pattern: np.ndarray) -> list:
    '''
    Split a zero-pattern into its connected row/column blocks.

    Rows and columns form the two vertex sets of a bipartite graph, and each
    nonzero entry ``pattern[i, j]`` is an edge between row ``i`` and column
    ``j``. Every connected component of that graph becomes one group.

    Parameters
    ----------
    pattern: np.ndarray of shape (n,m), with entries 0 or 1
             1 for position with `cost + er * R + ec * C = 0`
             0 for position with `cost + er * R + ec * C > 0`
             Must satisfy `pattern_well_generated` (checked with assert).
    Return
    ------
    List of indices
    [group_1, group_2, ...]
    where
    group_i = [row_list_i, col_list_i]
    Each index list is sorted and duplicate-free, and groups are ordered by
    their smallest row index.
    '''
    assert pattern_well_generated(pattern)

    mask = np.asarray(pattern) != 0
    n, m = mask.shape
    row_seen = np.zeros(n, dtype=bool)
    col_seen = np.zeros(m, dtype=bool)
    groups = []
    for start in range(n):
        if row_seen[start]:
            continue
        row_seen[start] = True
        group_rows = [start]
        group_cols = []
        frontier = np.array([start], dtype=np.intp)
        # Breadth-first search alternating rows -> columns -> rows until no
        # unvisited neighbour remains.
        while frontier.size > 0:
            new_cols = np.flatnonzero(mask[frontier].any(axis=0) & ~col_seen)
            col_seen[new_cols] = True
            group_cols.extend(new_cols.tolist())
            new_rows = np.flatnonzero(mask[:, new_cols].any(axis=1) & ~row_seen)
            row_seen[new_rows] = True
            group_rows.extend(new_rows.tolist())
            frontier = new_rows
        groups.append([sorted(group_rows), sorted(group_cols)])
    return groups
    
    
def block_balance(row_sums: np.ndarray, col_sums: np.ndarray, group: list,
                  er: np.float64=1, ec: np.float64=1) -> tuple:
    '''
    Shift one block's potentials so that its row and column masses match.

    With ``ratio = sum(exp(row_sums[rows])) / sum(exp(col_sums[cols]))``, the
    block's rows are lowered by ``log(ratio) * ec / (er + ec)`` and its
    columns raised by ``log(ratio) * er / (er + ec)``. Afterwards the two
    exp-sums are equal, and ``er * row_sums[i] + ec * col_sums[j]`` is
    unchanged for every row ``i`` and column ``j`` of the block.

    Parameters
    ----------
    row_sums, col_sums: float arrays of row / column potentials, modified in place.
    group: [row_indices, col_indices]; repeated indices are counted once.
    er, ec: row / column weights.

    Return
    ------
    (row_sums, col_sums): the same array objects that were passed in.
    If either index list is empty, the arrays are returned unchanged.
    '''
    row, col = group
    row_idx = np.unique(np.asarray(row, dtype=np.intp))
    col_idx = np.unique(np.asarray(col, dtype=np.intp))
    if row_idx.size == 0 or col_idx.size == 0:
        # Nothing to balance against; the ratio would be 0 or inf.
        return row_sums, col_sums

    def _log_sum_exp(x):
        # Stable log(sum(exp(x))): factoring out the max avoids exp overflow
        # or underflow when potentials are large in magnitude.
        peak = np.max(x)
        if not np.isfinite(peak):
            return peak
        return peak + np.log(np.sum(np.exp(x - peak)))

    log_ratio = _log_sum_exp(row_sums[row_idx]) - _log_sum_exp(col_sums[col_idx])
    row_sums[row_idx] -= log_ratio * ec / (er + ec)
    col_sums[col_idx] += log_ratio * er / (er + ec)
    return row_sums, col_sums

def block_values(pattern: np.ndarray, cost: np.ndarray,
                 rows: list=None, cols: list=None,
                 er: np.float64=1., ec: np.float64=1.) -> tuple:
    '''
    Compute row/column potentials that make the pattern entries tight.

    Each connected block of `pattern` is traversed breadth-first, starting
    from its first remaining nonzero entry in row-major order. Potentials are
    assigned so that ``cost + er * R + ec * C == 0`` on the traversed
    entries, and then the block is rescaled with `block_balance`.

    Parameters
    ----------
    pattern: np.ndarray of shape (n,m), with entries 0 or 1
             1 for position with `cost + er * R + ec * C = 0`
             0 for position with `cost + er * R + ec * C > 0`

    cost: np.ndarray of shape (n,m),
          pre-processed cost matrix, `true_cost - er * ln(r_i) - ec * ln(c_j)`

    rows, cols: unused; kept only so existing call signatures still work.

    er, ec: weights of the row / column potentials.

    Return
    ------
    (groups, row_sums, col_sums)
    groups: list of [row_list, col_list], one per block; each list is
            sorted and duplicate-free.
    row_sums: np.ndarray of shape (n,); rows in no block stay 0.
    col_sums: np.ndarray of shape (m,); columns in no block stay 0.
    '''
    assert pattern.shape == cost.shape
    n, m = pattern.shape
    row_sums = np.zeros(n, dtype=np.float64)
    col_sums = np.zeros(m, dtype=np.float64)

    # Pattern entries not yet visited, in row-major order.
    flag = np.argwhere(pattern != 0)
    ret = []
    while len(flag) > 0:
        comp_rows = []
        comp_cols = []
        r0, c0 = flag[0]
        inc_rows = [r0]
        inc_cols = [c0]
        # Seed: split the first entry's cost evenly between its row and column.
        row_sums[r0] = -cost[r0, c0] / 2 / er
        col_sums[c0] = -cost[r0, c0] / 2 / ec
        flag = flag[1:, :]
        ir = len(inc_rows)
        ic = len(inc_cols)
        while ir + ic > 0:
            for i in inc_rows:
                t = flag[flag[:, 0] == i][:, 1]
                flag = flag[flag[:, 0] != i]
                for k in t:
                    # A nonzero value means this column was already reached
                    # from another row of the block; report before overwriting.
                    if col_sums[k] != 0:
                        print("col_sums[%d] =" % k, col_sums[k])
                    col_sums[k] = (-cost[i, k] - row_sums[i] * er) / ec
                inc_cols += list(t)
            comp_rows += inc_rows
            inc_rows = []
            ic = len(inc_cols)
            for i in inc_cols:
                t = flag[flag[:, 1] == i][:, 0]
                flag = flag[flag[:, 1] != i]
                for k in t:
                    if row_sums[k] != 0:
                        print("row_sums[%d] =" % k, row_sums[k])
                    row_sums[k] = (-cost[k, i] - col_sums[i] * ec) / er
                inc_rows += list(t)
            comp_cols += inc_cols
            inc_cols = []
            ir = len(inc_rows)

        # A row/column reached from several columns/rows in the same sweep is
        # listed more than once. Deduplicate so block_balance counts it once.
        group = [sorted(set(comp_rows)), sorted(set(comp_cols))]
        row_sums, col_sums = block_balance(row_sums, col_sums, group, er, ec)
        ret.append(group)

    return ret, row_sums, col_sums
    
    
def verify_sums(cost, row_sums, col_sums, pattern, er=1., ec=1.):
    '''
    Check the potentials against the pattern, with an absolute tolerance of 1e-8.

    Returns True when ``cost + er * row_sums[:, None] + ec * col_sums`` has
    no entry below -1e-8, and every entry where `pattern` is nonzero has
    magnitude at most 1e-8 (after multiplying by `pattern`). Otherwise
    returns False.
    '''
    tol = 1e-8
    slack = cost + er * row_sums.reshape(-1, 1) + ec * col_sums
    feasible = not np.any(slack < -tol)
    tight = not np.any(np.abs(pattern * slack) > tol)
    return feasible and tight


if __name__ == "__main__":
    # No command-line behaviour is defined; running the module does nothing.
    ...