import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from src.networks import MatcherMLP, MatcherGIN, MatcherWeaveNet, MatcherWeaveNet_A, MatcherWeaveNetDual, MatcherDBM_P, MatcherDBM_A, MatcherSSWN
import os.path
from src.utils import calc_balance, calc_satisfaction, calc_fairness, binarize, is_match, is_stable_match, count_blocking_pairs, calc_SexEqualityCost, calc_EgalitarianCost, calc_BalanceCost  #enumerate_stable_match_bf
from tqdm import tqdm

class BaseModel():
    """Common device handling and checkpoint I/O for models that own ``self.net``.

    Subclasses are expected to assign ``self.net`` (a ``torch.nn.Module``)
    before calling ``save_network`` or ``load_network``.
    """
    epsilon=10.0**-7

    def __init__(self,
                 device='auto',
                 checkpoints_dir='checkpoints'):
        """Store the checkpoint directory and resolve the compute device.

        Args:
            device: ``'auto'`` (CUDA if available, otherwise CPU), any other
                device string accepted by ``torch.device``, or a non-string
                value (e.g. a ``torch.device`` or ``None``) stored as given.
            checkpoints_dir: directory used by ``save_network`` and
                ``load_network``.
        """
        self.save_dir = checkpoints_dir
        print(self.save_dir)

        self.device = device
        if isinstance(device,str):
            if device=='auto':
                self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
            self.device = torch.device(self.device)

    def to_cuda(self,x):
        """Move ``x`` to ``self.device``; return it untouched when no device is set."""
        if self.device is None:
            return x
        return x.to(self.device)

    def save_network(self,filename):
        """Write ``self.net.state_dict()`` to ``<save_dir>/<filename>``.

        The parent directory is created on demand so that the first save into
        a fresh checkpoint directory does not fail.
        """
        save_path = os.path.join(self.save_dir,filename)
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        torch.save(self.net.state_dict(),save_path)

    def load_network(self, filename):
        """Load a state dict from ``<save_dir>/<filename>`` into ``self.net``.

        Tensors are mapped onto ``self.device`` so that a checkpoint written on
        a CUDA machine can be restored on a CPU-only one (``None`` keeps the
        default ``torch.load`` behaviour).
        """
        save_path = os.path.join(self.save_dir,filename)
        param = torch.load(save_path, map_location=self.device)
        self.net.load_state_dict(param)
        
class BaseSMModel(BaseModel):
    """Loss terms, a single optimisation step and a validation loop for matching networks.

    Subclasses must assign ``self.net`` before ``set_optimizer``, ``train`` or
    ``validate`` is called.
    """
    # Softmax over the last axis and over the second-to-last axis of a score matrix.
    sm_col = nn.Softmax(dim=-1)
    sm_row = nn.Softmax(dim=-2)

    def __init__(self,
                 *args,
                 lr=0.00001,
                 lambda_m=1.0,
                 lambda_u=1.0,
                 lambda_s=0.,
                 lambda_f=0.,
                 lambda_b=0.,
                 constraint_p=2.0,
                 **kwargs):
        """Store the learning rate and the loss weights.

        Args:
            lr: learning rate handed to Adam in ``set_optimizer``.
            lambda_m: weight of the matrix-constraint loss.
            lambda_u: weight of the unstability loss.
            lambda_s: weight of the satisfaction loss.
            lambda_f: weight of the fairness loss.
            lambda_b: weight of the satisfaction-with-fairness (balance) loss.
            constraint_p: when ``>= 0`` it is the norm order of the
                normed-correlation constraint; a negative value selects the
                ``criterion_matrix_constraint_HV`` constraint instead.
            *args, **kwargs: forwarded to ``BaseModel.__init__``.
        """
        super(BaseSMModel,self).__init__(*args,**kwargs)
        self.lr = lr

        self.lambda_m = lambda_m
        self.lambda_u = lambda_u
        self.lambda_s = lambda_s
        self.lambda_f = lambda_f
        self.lambda_b = lambda_b
        self.constraint_p=constraint_p

    def set_optimizer(self):
        """Create an Adam optimizer (betas ``(0.5, 0.999)``) over ``self.net``'s parameters."""
        self.optimizer = torch.optim.Adam(self.net.parameters(),lr=self.lr,betas=(0.5,0.999))

    @staticmethod
    def criterion_satisfaction_maximize(m,sab,sba):
        """Return the negated ``calc_satisfaction`` so that minimising maximises it."""
        return - calc_satisfaction(m,sab,sba)

    @staticmethod
    def criterion_satisfaction_with_fairness(m,sab,sba):
        """Return the negated ``calc_balance`` so that minimising maximises it."""
        return - calc_balance(m,sab,sba)

    @staticmethod
    def criterion_fairness(m,sab,sba):
        """Return ``calc_fairness`` unchanged (it is minimised directly)."""
        return calc_fairness(m,sab,sba)

    def criterion(self,M,Ss,Ns,isTrain=True):
        """Accumulate every loss term over a batch and combine them with the lambda weights.

        Args:
            M: batch of network outputs; sample ``b`` is cropped to
                ``M[b][:na][:, :nb]``.
            Ss: pair ``[Sab, Sba]`` of batched score matrices; sample ``b`` is
                cropped to ``(na, nb)`` and ``(nb, na)`` respectively.
            Ns: pair ``[Na, Nb]`` holding the per-sample sizes.
            isTrain: currently unused; kept for interface compatibility.

        Returns:
            Tuple ``(total, l_mat_constraint, l_unstability, l_satisfaction,
            l_fairness, l_satisfaction_with_fairness)``. Every entry is summed
            over the batch (not averaged).
        """
        batch_size = M.shape[0]
        l_mat_constraint=0
        l_unstability=0
        l_satisfaction=0
        l_fairness=0
        l_satisfaction_with_fairness = 0

        for b in range(batch_size):
            na = int(Ns[0][b])
            nb = int(Ns[1][b])
            # Crop away padding so that only the valid na x nb problem is scored.
            m = M[b][:na][:,:nb]
            sab = Ss[0][b][:na][:,:nb]
            sba = Ss[1][b][:nb][:,:na]

            mc = self.sm_col(m)
            mr = self.sm_row(m)

            if self.constraint_p>=0:
                l_mat_constraint += self.criterion_matrix_constraint_normed_correlation(m,p=self.constraint_p)
            else:
                l_mat_constraint += self.criterion_matrix_constraint_HV(mc,mr)

            # Each remaining term is the mean of its value on the two softmax-normalised matrices.
            l_unstability += (self.criterion_unstability_HV(sab,sba,mc)+self.criterion_unstability_HV(sab,sba,mr))/2
            l_satisfaction += (self.criterion_satisfaction_maximize(mc,sab,sba)+self.criterion_satisfaction_maximize(mr,sab,sba))/2
            l_satisfaction_with_fairness += (self.criterion_satisfaction_with_fairness(mc,sab,sba)+self.criterion_satisfaction_with_fairness(mr,sab,sba))/2
            l_fairness += (self.criterion_fairness(mc,sab,sba)+self.criterion_fairness(mr,sab,sba))/2

        l = self.lambda_m * l_mat_constraint
        l += self.lambda_u * l_unstability
        l += self.lambda_s * l_satisfaction
        l += self.lambda_b * l_satisfaction_with_fairness
        l += self.lambda_f * l_fairness

        return l, l_mat_constraint, l_unstability, l_satisfaction, l_fairness, l_satisfaction_with_fairness

    # Mean absolute difference between the two softmax-normalised matrices.
    # Scripted inside the class body, so it is called without an implicit ``self``.
    @torch.jit.script
    def criterion_matrix_constraint_HV(mc,mr):
        # minimize diff of Mcol and Mrow
        return abs(mc-mr).mean()

    # Normed-correlation constraint on the raw score matrix ``m``:
    #   1 - Z * sum( (e / ||e||_p along the last axis) * (e / ||e||_p along the second-to-last axis) )
    # with e = exp(clamp(m, min=epsilon)) and Z = (N+M) / (2*N*M).
    @torch.jit.script
    def criterion_matrix_constraint_normed_correlation(m,p:float=2,epsilon:float=10**-7):
        m_exp = torch.clamp(m,epsilon).exp() # m is clamped from below at epsilon before exponentiating.
        mc_norm = m_exp.norm(p,dim=-1,keepdim=True)
        mr_norm = m_exp.norm(p,dim=-2,keepdim=True)
        N,M = m.shape[-2:]
        Z = (N+M)/(2*N*M)
        dM = 1-((m_exp)/mc_norm * (m_exp)/mr_norm).sum()*(Z)
        return dM

    # Unstability loss. With N = m.shape[0] and M = m.shape[1], for every column c:
    #   unsab[i] = sum_j m[i,j] * max(sab[i,c] - sab[i,j], epsilon)
    #   unsba[i] = sum_k m[k,c] * max(sba[c,i] - sba[c,k], epsilon)
    # and the result is the sum over c and i of unsab[i] * unsba[i].
    @torch.jit.script
    def criterion_unstability_HV(sab,sba,m,epsilon:float=10**-7):
        uns=m.new_zeros(1)
        assert sab.shape[0]==sba.shape[1] and sab.shape[0]==m.shape[0]
        N = m.shape[0]
        M = m.shape[1]
        for c in range(M): 
            sab_selected = sab[:,c:c+1] # to keep dimension, sab[:,c] is implemented as sab[:,c:c+1]
            sab_selected = torch.repeat_interleave(sab_selected,M,dim=1)        
            # Row i: how much row i prefers column c over each column j, weighted by m[i,j].
            unsab = (m*torch.clamp(sab_selected-sab,epsilon)).sum(dim=1)
            
            sba_selected = sba[c:c+1,:] # keep dimension.
            sba_selected = torch.repeat_interleave(sba_selected,N,dim=0)
            _sba = sba_selected.t()
            
            # Broadcast column c of m across an N x N matrix: _m[k,i] = m[k,c].
            _m = m.new_zeros(N,N)
            _m += m[:,c:c+1]
            
            # Entry i: how much column c prefers row i over each row k, weighted by m[k,c].
            unsba = (_m*torch.clamp(sba_selected-_sba,epsilon)).sum(dim=0)
            # Element-wise product unsab[i]*unsba[i], summed over i.
            if c==0:
                uns = (unsab*unsba).sum()
            else:
                uns += (unsab*unsba).sum()

        return uns

    def train(self,ss,ns,forwarding_depth=-1):
        """Run one optimisation step on a batch.

        Args:
            ss: pair of batched score matrices; each is moved to ``self.device``.
            ns: pair of per-sample sizes, passed to ``criterion`` as is.
            forwarding_depth: passed as ``depth=`` to the network, but only when
                the network exposes a truthy ``support_depth`` attribute.

        Returns:
            ``(M, l_mc, l_uns, l_sat, l_fair, l_balance)``: the network output
            and the batch-summed loss components. The total loss is divided by
            the batch size before ``backward``; the returned components are not.

        NOTE(review): this method does not switch the network to train mode, and
        ``validate`` leaves it in eval mode -- confirm the caller calls
        ``self.net.train()`` between validation and training.
        """
        ss = [self.to_cuda(s) for s in ss]
        self.optimizer.zero_grad()
        if hasattr(self.net,"support_depth") and self.net.support_depth:
            M = self.net.forward(ss,depth=forwarding_depth)
        else:
            M = self.net.forward(ss)

        loss,l_mc,l_uns,l_sat,l_fair,l_balance = self.criterion(M,ss,ns)

        loss = loss/M.shape[0]
        loss.backward()
        self.optimizer.step()
        return M, l_mc, l_uns, l_sat, l_fair, l_balance

    def validate(self,dl_val,device):
        """Evaluate the network on a validation loader.

        Args:
            dl_val: iterable of batches; only the first four items of each batch
                (``Sab, Sba, Na, Nb``) are used.
            device: device the batch tensors are moved to.

        Returns:
            Per-sample averages ``(L_m, L_u, L_f, L_b, stable_rate, SEq, Bal)``:
            matrix-constraint, unstability, fairness and balance losses, the
            fraction of samples for which ``is_match`` and ``is_stable_match``
            both hold, and the mean sex-equality and balance costs.

        Raises:
            ValueError: if ``dl_val`` yields no samples.
        """
        self.net.eval()
        num_stable_matching = 0
        L_m = 0.0
        L_u = 0.0
        L_f = 0.0
        L_b = 0.0
        num_samples = 0

        SEq = 0
        Bal = 0

        with torch.no_grad():
            print("Validation")
            for batch in tqdm(dl_val):
                Sab, Sba, Na, Nb = [b.to(device) for b in batch[:4]]
                num_samples+=Sab.shape[0]
                Ss = [Sab,Sba]
                Ns = [Na,Nb]
                M = self.net.forward(Ss)
                _,l_mc,l_uns,_,l_fair,l_balance = self.criterion(M,Ss,Ns)
                L_m += l_mc
                L_u += l_uns
                L_f += l_fair
                L_b += l_balance

                for i, (sab, sba, na, nb, m) in enumerate(zip(Sab,Sba,Na,Nb,M)):
                    # Element-wise minimum of the two softmax normalisations, then binarised.
                    m = torch.minimum(self.sm_col(m),self.sm_row(m))
                    m = binarize(m,na,nb)

                    M[i] = self.to_cuda(m)
                    if is_match(M[i]) and is_stable_match(sab,sba,M[i]):
                        num_stable_matching+=1

                    # Map scores back to rank costs using THIS sample's sizes.
                    # Using the whole-batch size vectors here would broadcast a
                    # (B,) tensor against the (N, M) score matrix, which fails
                    # or mixes in other samples' sizes.
                    # Assumes scores lie in [0.1, 1.0] -- TODO confirm against the dataset.
                    cab = torch.round((nb-1)*(1-(sab-0.1)/0.9))
                    cba = torch.round((na-1)*(1-(sba-0.1)/0.9))
                    SEq += calc_SexEqualityCost(m,cab,cba)
                    Bal += calc_BalanceCost(m,cab,cba)

        if num_samples == 0:
            raise ValueError('validate() received an empty validation loader.')

        return L_m.detach().cpu()/num_samples, L_u.detach().cpu()/num_samples, L_f.detach().cpu()/num_samples, L_b.detach().cpu()/num_samples, float(num_stable_matching)/num_samples, SEq.detach().cpu()/num_samples, Bal.detach().cpu()/num_samples
    
                

    
#################### ModelSM ####################
    
class ModelSM(BaseSMModel):
    """Stable-matching model that builds the requested matcher network and its optimizer."""

    def __init__(self,
                 *args,
                 network_type='WeaveNet',
                 sab_shape=None,
                 L=18, # number of layers
                 w1=64, # output channels of middle layers
                 w2=64, # output channels of inside convolution of FW layer.
                 use_resnet=True,
                 asymmetric=False,
                 **kwargs):
        """Instantiate the network named by ``network_type``, move it to the device, create the optimizer.

        Args:
            network_type: one of ``'WeaveNet'``, ``'WeaveNet_A'``,
                ``'WeaveNetDual'``, ``'DBM_P'``, ``'DBM_A'``, ``'SSWN'``,
                ``'MLP'`` or ``'GIN'``.
            sab_shape: shape of the score matrix; required for ``'MLP'`` and
                ``'GIN'`` only, whose constructors receive its last two entries.
            L: number of layers.
            w1: passed as ``D`` (or ``mid_nc`` for the MLP).
            w2: passed as ``inner_conv_out_channels`` or, for the ``_A``
                variants, ``key_query_channels``; unused by ``DBM_P``, ``MLP``
                and ``GIN``.
            use_resnet: forwarded to every network except ``MLP`` and ``GIN``.
            asymmetric: forwarded to ``WeaveNet`` and ``WeaveNet_A`` only.
            *args, **kwargs: forwarded to ``BaseSMModel.__init__``.

        Raises:
            RuntimeError: if ``network_type`` is not recognised.
            ValueError: if ``sab_shape`` is missing for ``'MLP'`` or ``'GIN'``.
        """
        super(ModelSM,self).__init__(*args,**kwargs)

        if network_type=='WeaveNet':
            self.net = MatcherWeaveNet(L=L, D=w1, inner_conv_out_channels=w2,use_resnet=use_resnet,asymmetric=asymmetric)
        elif network_type=='WeaveNet_A':
            self.net = MatcherWeaveNet_A(L=L, D=w1, key_query_channels=w2,use_resnet=use_resnet,asymmetric=asymmetric)
        elif network_type=='WeaveNetDual':
            self.net = MatcherWeaveNetDual(L=L, D=w1,  inner_conv_out_channels=w2, use_resnet=use_resnet)
        elif network_type=='DBM_P':
            self.net = MatcherDBM_P(L=L, D=w1, use_resnet=use_resnet)
        elif network_type=='DBM_A':
            self.net = MatcherDBM_A(L=L, D=w1, key_query_channels=w2,use_resnet=use_resnet)
        elif network_type=='SSWN':
            self.net = MatcherSSWN(L=L, D=w1, inner_conv_out_channels=w2,use_resnet=use_resnet)
        elif network_type in ('MLP','GIN'):
            # These two networks are built for a fixed problem size, so the
            # shape must be validated explicitly (an assert would vanish under -O).
            if sab_shape is None:
                raise ValueError('`sab_shape` is required for network `{}`.'.format(network_type))
            N=sab_shape[-2]
            M=sab_shape[-1]

            if network_type=='MLP':
                self.net = MatcherMLP(N,M,
                                      input_channels=1,
                                      mid_nc=w1,
                                      n_layers=L,
                                      norm='batch')
            else:
                self.net = MatcherGIN(N,M,self.device,L,w1)
        else:
            # Reported before any shape check so that a misspelt name is always
            # diagnosed as such, whether or not sab_shape was supplied.
            raise RuntimeError('Unknown network `{}`.'.format(network_type))

        self.net = self.to_cuda(self.net)
        self.set_optimizer()
        

       
