import numpy as np
import torch
import torch.nn.functional as F

from itertools import permutations
from collections import defaultdict


def network_name(network_type, L, w1, w2, asymmetric, use_resnet):
    """Build a hyphen-separated identifier describing a network configuration.

    The result has the form ``"{type}-{L}-{w1}-{w2}-{asym}-{res}"`` where
    ``asym`` is ``'a'``/``'s'`` (asymmetric or not) and ``res`` is
    ``'res'``/``'no_res'``. Fields that do not apply to ``network_type`` are
    replaced by ``'NA'``:

    * ``'MLP'``, ``'GCN'``: ``w2``, ``asym`` and ``res`` become ``'NA'``.
    * ``'DBM_P'``: ``w2`` and ``asym`` become ``'NA'``.
    * ``'DBM_A'``, ``'SSWN'``: ``asym`` becomes ``'NA'``.
    """
    asym = 'a' if asymmetric else 's'
    res = 'res' if use_resnet else 'no_res'

    # The type checks are mutually exclusive, so an elif chain suffices.
    if network_type in ('MLP', 'GCN'):
        w2 = asym = res = 'NA'
    elif network_type == 'DBM_P':
        w2 = asym = 'NA'
    elif network_type in ('DBM_A', 'SSWN'):
        asym = 'NA'

    return "{}-{}-{}-{}-{}-{}".format(network_type, L, w1, w2, asym, res)

def loss_weight_name(lr, lm, lu, ls, lf, lb, cp):
    """Encode the learning rate and loss weights as ``lr..-m..-u..-s..-f..-b..-cp..``."""
    return f"lr{lr}-m{lm}-u{lu}-s{ls}-f{lf}-b{lb}-cp{cp}"
def agent_name(distributions, N, M, N_max, M_max):
    """Return an agent-population label.

    Gives ``"{dist}{N}x{M}"`` when both sizes equal their maxima, otherwise
    ``"{dist}{N}-{N_max}x{M}-{M_max}"``.
    """
    if not (N == N_max and M == M_max):
        return f"{distributions}{N}-{N_max}x{M}-{M_max}"
    return f"{distributions}{N}x{M}"

def dat_filename(distrib_m, distrib_f, N, setname='validation'):
    """Return ``"{m}{f}_size-{NN}_{setname}.dat"`` with ``N`` zero-padded to two digits."""
    return f"{distrib_m}{distrib_f}_size-{N:02d}_{setname}.dat"


def count_network_size(net):
    """Return the total number of elements in ``net``'s trainable parameters.

    Only parameters with ``requires_grad`` set are counted.
    """
    return sum(param.numel() for param in net.parameters() if param.requires_grad)


def _enum_sm_bf(sab, sba, midx):
    """Build the 0/1 matching encoded by ``midx`` and keep it only if stable.

    Row ``i`` of the candidate matrix gets a 1 in column ``midx[i]``. The
    matrix is allocated with ``sab.new_zeros`` and has ``sab``'s shape, dtype
    and device.

    Returns a one-element list holding the matrix when ``is_stable_match``
    accepts it, otherwise an empty list. A list is returned so that callers
    can concatenate the results.
    """
    candidate = sab.new_zeros(sab.shape[0], sab.shape[1])
    for row, col in enumerate(midx):
        candidate[row, col] = 1
    return [candidate] if is_stable_match(sab, sba, candidate) else []
        
def enumerate_stable_match_bf(sab, sba):
    """Return every stable matching of the market by brute force.

    Every injective assignment of the ``N = sab.shape[0]`` rows to distinct
    columns out of ``M = sba.shape[0]`` is tried
    (``permutations(range(M), r=N)``). Each candidate is kept if
    ``is_stable_match`` accepts it. Consequences of this enumeration:

    * It yields no candidates, and so returns ``[]``, when ``N > M``.
    * It performs ``M! / (M - N)!`` stability checks, so it is only usable
      for tiny markets.

    Returns:
        A list of 0/1 matrices, built by ``_enum_sm_bf`` with
        ``sab.new_zeros``, in permutation order.
    """
    N = sab.shape[0]
    M = sba.shape[0]
    # Flatten the per-permutation lists with a comprehension. The previous
    # ``sum(lists, [])`` re-copied the accumulated list at every step
    # (quadratic). The unused device-side scratch tensor is also gone.
    return [
        match
        for midx in permutations(range(M), r=N)
        for match in _enum_sm_bf(sab, sba, midx)
    ]


def enumerate_stable_match(sab, sba):
    """Unimplemented stub: ignores its arguments and returns ``None``.

    Presumably meant as a non-brute-force counterpart to
    ``enumerate_stable_match_bf``. TODO: confirm and implement.
    """
    # TODO: the original stub sketched priors ``m_prior_a`` / ``m_prior_b``
    # but never defined them.
    return None

def is_match(m):
    """Return True unless some column or row of ``m`` sums to more than 1.

    A sum above 1 means an agent has duplicated connections. The row check
    is skipped (short-circuited) when the column check already fails.
    """
    duplicated = (
        bool((torch.sum(m, dim=0) > 1).any())
        or bool((torch.sum(m, dim=1) > 1).any())
    )
    return not duplicated

# Module-level softmax layers: ``sm_col`` normalizes over the last axis and
# ``sm_row`` over the second-to-last axis.
sm_col, sm_row = torch.nn.Softmax(dim=-1), torch.nn.Softmax(dim=-2)
    
@torch.jit.script
def binarize(m, na: int, nb: int):
    """Turn a score matrix ``m`` into a one-hot 0/1 assignment matrix.

    The scores are first softmax-normalized along the last axis and, separately,
    along the second-to-last axis. The element-wise minimum of the two is then
    taken.

    * When ``na >= nb``, each row keeps a single 1 at its argmax column
      (``one_hot`` with ``nb`` classes).
    * Otherwise each column keeps a single 1 at its argmax row (``one_hot``
      with ``na`` classes, transposed).

    Args:
        m: score tensor. The indexing implies shape ``(na, nb)`` (TODO:
            confirm with callers).
        na: number of rows, used as the class count when columns choose.
        nb: number of columns, used as the class count when rows choose.

    Returns:
        An int64 one-hot tensor of shape ``(na, nb)``, assuming ``m`` is
        ``(na, nb)``.

    NOTE(review): when ``na > nb`` every row receives a 1, so some column
    necessarily receives several. Confirm the branch direction is intended.
    """
    # Functional softmax is natively scriptable and numerically identical to
    # the module-level nn.Softmax(dim=-1/-2) layers. Using it keeps this
    # scripted function independent of globals resolved at compile time.
    m = torch.minimum(torch.softmax(m, dim=-1), torch.softmax(m, dim=-2))

    if na >= nb:
        m = F.one_hot(m.argmax(dim=1), num_classes=nb)
    else:
        m = F.one_hot(m.argmax(dim=0), num_classes=na).t()
    return m

@torch.jit.script
def is_stable_match(sab,sba,m):
    """Return True if matching ``m`` has no blocking pair, else False.

    Shapes implied by the indexing and broadcasting below:
    ``N = sab.shape[0]`` and ``M = sba.shape[0]``, with ``sab`` of shape
    (N, M), ``sba`` of shape (M, N) and ``m`` of shape (N, M).

    For each column ``c``, the pair (row i, column c) is treated as blocking
    when both of the following hold:

    * Row i scores column c strictly higher in ``sab`` than the column(s)
      that ``m`` assigns to it.
    * Column c scores row i strictly higher in ``sba`` than the row(s) that
      ``m`` assigns to it.

    This presumes that a higher score means a stronger preference and that
    ``m`` is 0/1 with at most one 1 per row and column (TODO: confirm with
    callers).

    NOTE(review): a row or column with no 1 in ``m`` contributes zero on its
    side, so pairs involving an unmatched agent are never reported as
    blocking. Confirm this is intended for unbalanced markets (N != M).

    NOTE(review): the two non-negative gap terms are multiplied before the
    ``> 0`` test, so extremely small gaps could underflow to zero.
    """
    N = sab.shape[0]
    M = sba.shape[0]


    for c in range(M):
        sab_selected = sab[:,c:c+1] # to keep dimension, sab[:,c] is implemented as sab[:,c:c+1]
        # (N, M): every entry of row i holds sab[i, c].
        sab_selected = sab_selected.repeat_interleave(M,dim=1)
        # unsab[i] > 0 iff row i scores column c above its assigned column.
        # The mean over dim 1 divides by M.
        unsab = (m*torch.clamp(sab_selected-sab,min=0)).mean(dim=1)

        sba_selected = sba[c:c+1,:] # keep dimension.
        # (N, N): every row holds sba[c, :].
        sba_selected = sba_selected.repeat_interleave(N,dim=0)
        # _sba[k, l] = sba[c, k]
        _sba = sba_selected.t()
        # _m[k, l] = m[k, c]: row k's assignment to column c, repeated over l.
        _m = m[:,c:c+1]
        _m = _m.repeat_interleave(N,dim=1)
        # unsba[l] > 0 iff column c scores row l above its assigned row.
        # The mean over dim 0 divides by N.
        unsba = (_m*torch.clamp(sba_selected-_sba,min=0)).mean(dim=0)
        # Positive iff some row i is unsatisfied on both sides with column c,
        # i.e. (i, c) is a blocking pair.
        envy = (unsab*unsba).sum()
        if envy>0:
            return False
    return True
             
@torch.jit.script
def count_blocking_pairs(sab, sba, m):
    """Count the blocking pairs of matching ``m`` and return the count as a float.

    Shapes implied by the indexing and broadcasting below:
    ``N = sab.shape[0]`` and ``M = sba.shape[0]``, with ``sab`` of shape
    (N, M), ``sba`` of shape (M, N) and ``m`` of shape (N, M).

    The pair (row i, column c) is counted when both of the following hold:

    * ``m[i, j] * (sab[i, c] - sab[i, j]) > 0`` for the column ``j`` that
      ``m`` assigns to row i.
    * ``m[k, c] * (sba[c, i] - sba[c, k]) > 0`` for the row ``k`` that ``m``
      assigns to column c.

    This presumes that a higher score means a stronger preference and that
    ``m`` is 0/1 with at most one 1 per row and column (TODO: confirm). Under
    that assumption each side's per-pair indicator is 0 or 1.

    NOTE(review): a row or column with no 1 in ``m`` never contributes, so
    pairs involving an unmatched agent are not counted.
    """
    N = sab.shape[0]
    M = sba.shape[0]

    n_blocking_pair = 0
    for c in range(M):
        # Row side. sab_selected[i, j] = sab[i, c]; unsab_target[i] counts
        # assigned columns that row i scores below column c.
        sab_selected = sab[:, c:c + 1].repeat_interleave(M, dim=1)
        unsab_target = (m * (sab_selected - sab) > 0).sum(dim=1)

        # Column side. sba_selected[k, l] = sba[c, l], _sba[k, l] = sba[c, k]
        # and _m[k, l] = m[k, c]. unsba_target[l] counts assigned rows that
        # column c scores below row l.
        sba_selected = sba[c:c + 1, :].repeat_interleave(N, dim=0)
        _sba = sba_selected.t()
        _m = m[:, c:c + 1].repeat_interleave(N, dim=1)
        unsba_target = (_m * (sba_selected - _sba) > 0).sum(dim=0)

        # Explicit Tensor -> int conversion keeps the loop-carried counter an
        # int, instead of relying on implicit int += Tensor in TorchScript.
        n_blocking_pair += int((unsab_target * unsba_target).sum())
    return float(n_blocking_pair)



def batch_eye(size, batch_size):
    """Return ``batch_size`` copies of the ``size`` x ``size`` identity matrix.

    The copies are independent (made with ``repeat``, not views) and stacked
    into a tensor of shape ``(batch_size, size, size)``.
    """
    return torch.eye(size).unsqueeze(0).repeat(batch_size, 1, 1)

def calc_fairness_np(m, sab, sba):
    """NumPy version: ``|sum(m * sab) - sum(m.T * sba)| / sab.shape[0]``."""
    side_a = (m * sab).sum()
    side_b = (m.transpose() * sba).sum()
    return np.abs(side_a - side_b) / sab.shape[0]
def calc_satisfaction_np(m, sab, sba):
    """NumPy version: ``(sum(m * sab) + sum(m.T * sba)) / sab.shape[0]``."""
    side_a = (m * sab).sum()
    side_b = (m.transpose() * sba).sum()
    return (side_a + side_b) / sab.shape[0]

def calc_fairness(m, sab, sba):
    """Torch version: ``|sum(m * sab) - sum(m.t() * sba)| / sab.shape[0]``, as a tensor."""
    side_a = (m * sab).sum()
    side_b = (m.t() * sba).sum()
    return (side_a - side_b).abs() / sab.shape[0]

def calc_satisfaction(m, sab, sba):
    """Torch version: ``(sum(m * sab) + sum(m.t() * sba)) / sab.shape[0]``, as a tensor."""
    side_a = (m * sab).sum()
    side_b = (m.t() * sba).sum()
    return (side_a + side_b) / sab.shape[0]

def calc_balance(m, sab, sba):
    """Return ``min(sum(m * sab), sum(m.t() * sba)) / sab.shape[0]``.

    Python's built-in ``min`` is used, so on a tie the row-side total is the
    one returned.
    """
    side_a = (m * sab).sum()
    side_b = (m.t() * sba).sum()
    return min(side_a, side_b) / sab.shape[0]


def calc_SexEqualityCost(m, cab, cba):
    """Return ``|sum(m * cab) - sum(m.t() * cba)|`` (not normalized)."""
    cost_a = (m * cab).sum()
    cost_b = (m.t() * cba).sum()
    return (cost_a - cost_b).abs()

def calc_EgalitarianCost(m, cab, cba):
    """Return ``sum(m * cab) + sum(m.t() * cba)`` (not normalized)."""
    cost_a = (m * cab).sum()
    cost_b = (m.t() * cba).sum()
    return cost_a + cost_b

def calc_BalanceCost(m, cab, cba):
    """Return ``max(sum(m * cab), sum(m.t() * cba))`` (not normalized).

    Python's built-in ``max`` is used, so on a tie the row-side total is the
    one returned.
    """
    cost_a = (m * cab).sum()
    cost_b = (m.t() * cba).sum()
    return max(cost_a, cost_b)
