import os
import numpy as np
import torch
import torch.nn.functional as F
import gco
import time
from lapjv import lapjv
import warnings
from match import get_onehot_matrix, mix_input

# Ignore every warning, from any module and of any category, for the whole
# interpreter as soon as this module is imported.
# NOTE(review): this also hides PyTorch/NumPy deprecation and performance
# warnings raised by code in this file; consider narrowing it with
# category=/module= arguments.
warnings.filterwarnings("ignore")

def to_one_hot(inp, num_classes, device='cuda'):
    '''Encode a 1-D tensor of class indices as a float32 one-hot matrix.

    Args:
        inp: 1-D integer (int64) tensor of class indices, one per sample.
        num_classes: width of the one-hot rows.
        device: device on which the result is allocated.

    Returns:
        Tensor of shape (inp.size(0), num_classes) holding 1.0 at each
        sample's class index and 0.0 elsewhere.
    '''
    n_rows = inp.size(0)
    columns = inp.unsqueeze(1)  # (N,) -> (N, 1), the index shape scatter_ expects
    encoded = torch.zeros((n_rows, num_classes), dtype=torch.float32, device=device)
    return encoded.scatter_(1, columns, 1)


def get_lambda(alpha=1.0, alpha2=None):
    '''Sample a mixing coefficient lambda.

    Returns 1.0 (i.e. no mixing) unless ``alpha > 0``. Otherwise draws from
    Beta(alpha, alpha), or from Beta(alpha + 0.01, alpha2 + 0.01) when a
    second shape parameter ``alpha2`` is supplied. Uses NumPy's global RNG.
    '''
    # Written as ``not alpha > 0.`` so that NaN also falls through to 1.0.
    if not alpha > 0.:
        return 1.
    if alpha2 is not None:
        return np.random.beta(alpha + 1e-2, alpha2 + 1e-2)
    return np.random.beta(alpha, alpha)

def distance(z, dist_type='l2'):
    '''Pairwise distance matrix between the rows of ``z`` (computed without grad).

    For ``z`` of shape (N, D) the result is (N, N); the last dim is reduced.

    Args:
        z: tensor whose leading dim indexes the vectors.
        dist_type: ``'l2'`` (Euclidean), ``'l22'`` (squared Euclidean),
            ``'l1'`` or ``'linf'``. Any other string starting with ``'l2'``
            yields squared Euclidean distances as well.

    Returns:
        The distance matrix, or None for an unrecognised ``dist_type``.
    '''
    with torch.no_grad():
        delta = z.unsqueeze(1) - z.unsqueeze(0)
        if dist_type.startswith('l2'):
            squared = delta.pow(2).sum(-1)
            return squared.sqrt() if dist_type == 'l2' else squared
        if dist_type == 'l1':
            return delta.abs().sum(-1)
        if dist_type == 'linf':
            return delta.abs().max(-1)[0]
    return None

def cost_matrix(width, device='cuda'):
    '''Transport cost between the cells of a ``width x width`` grid.

    Cells are numbered row-major (index m -> row m // width, col m % width).
    Entry [a, b] is the squared Euclidean grid distance between cells a and b,
    divided by ``(width - 1) ** 2`` so the largest axis-aligned step across
    the grid costs 1. For ``width == 1`` the single entry is 0 (no division
    by zero).

    Args:
        width: number of cells per side.
        device: device of the returned tensor (default keeps the historical
            CUDA placement).

    Returns:
        float32 tensor of shape (width**2, width**2).
    '''
    rows, cols = np.divmod(np.arange(width ** 2), width)
    row_gap = rows[:, None] - rows[None, :]
    col_gap = cols[:, None] - cols[None, :]
    C = (row_gap ** 2 + col_gap ** 2).astype(np.float32)
    # max(..., 1) avoids 0/0 -> NaN for a 1x1 grid (the distance is 0 anyway).
    C = C / max(width - 1, 1) ** 2
    return torch.tensor(C, device=device)

# Precomputed transport cost matrices, each with a leading dim of size 1, keyed
# by grid width as a string ('2', '4', '8', '16'); read by mask_transport.
# Built at import time via cost_matrix (placed on CUDA).
cost_matrix_dict = {str(grid): cost_matrix(grid).unsqueeze(0) for grid in (2, 4, 8, 16)}
      

def mixup_process(out, target_reweighted, hidden=0, args=None, sc=None, A_dist=None):
    '''Mix a batch (inputs or hidden features) together with its soft targets.

    The mode is selected from ``args``, in this order of precedence:

    * ``args.match_mix`` -- Co-Mixup (see ``_co_mixup``).
    * ``hidden`` truthy -- Manifold Mixup: convex combination of ``out`` with a
      randomly permuted copy of itself.
    * ``args.box`` -- CutMix via ``mixup_box``.
    * ``args.graph`` -- Puzzle Mix via ``mixup_graph`` (uses saliency ``sc``).
    * otherwise -- Input Mixup.

    For every mode except Co-Mixup the targets are mixed with the same
    per-sample ratio as the data.

    Args:
        out: batch to mix; the first dim indexes samples.
        target_reweighted: per-sample target rows (first dim = samples).
        hidden: truthy to mix hidden features (Manifold Mixup).
        args: configuration namespace; required.
        sc: saliency map, used by Puzzle Mix and Co-Mixup.
        A_dist: optional pairwise distance matrix for Co-Mixup.

    Returns:
        ``(mixed_out.contiguous(), mixed_targets)``.

    Raises:
        ValueError: if ``args`` is None.
    '''
    if args is None:
        # Every branch below reads its configuration from ``args``.
        raise ValueError('mixup_process requires an args namespace')
    mixup_alpha = args.mixup_alpha

    if args.match_mix:
        out, target_reweighted = _co_mixup(out, target_reweighted, args, sc, A_dist)
        # Co-Mixup does not use lambda, but one was always drawn at this point;
        # keep the draw so seeded runs consume NumPy's global RNG identically.
        get_lambda(mixup_alpha)
        return out.contiguous(), target_reweighted

    indices1 = np.arange(out.size(0))
    indices2 = np.random.permutation(out.size(0))
    lam = np.array([get_lambda(mixup_alpha)]).astype('float32')

    if hidden:
        # Manifold Mixup
        out = out[indices1] * lam[0] + out[indices2] * (1 - lam[0])
        ratio = torch.ones(out.shape[0], device=out.device) * lam[0]
    elif args.box:
        # CutMix
        out, ratio = mixup_box(out[indices1], out[indices2], alpha=lam[0])
    elif args.graph:
        # PuzzleMix; block_num is drawn from {2, 4, 8, 16}.
        block_num = 2 ** np.random.randint(1, 5)
        out, ratio = mixup_graph(out, sc, indices1, indices2, block_num=block_num,
                                 alpha=lam, beta=args.beta, gamma=args.gamma, eta=args.eta,
                                 neigh_size=args.neigh_size, n_labels=args.n_labels,
                                 mean=args.mean, std=args.std, transport=args.transport,
                                 t_eps=args.t_eps, t_size=args.t_size)
    else:
        # Input Mixup
        out = out[indices1] * lam[0] + out[indices2] * (1 - lam[0])
        ratio = torch.ones(out.shape[0], device=out.device) * lam[0]

    ratio = ratio.unsqueeze(-1)
    target_reweighted = target_reweighted[indices1] * ratio + target_reweighted[indices2] * (1 - ratio)
    return out.contiguous(), target_reweighted


def _co_mixup(out, target_reweighted, args, sc, A_dist):
    '''Co-Mixup over consecutive groups of ``args.m_part`` samples (no grad).

    For each group, the block-pooled saliency ``sc`` (normalised to sum to 1
    per sample) is negated into a cost, and a blend of the identity with the
    group's (renormalised) block of ``A_dist`` forms ``A``. Both are passed to
    ``get_onehot_matrix``, whose result is passed to ``mix_input`` (both from
    the ``match`` module; their exact semantics are not visible here).

    NOTE(review): only ``out.shape[0] // m_part`` full groups are processed, so
    trailing samples are dropped and the returned batch can be smaller than
    the input; a batch smaller than ``m_part`` makes ``torch.cat`` fail.
    '''
    m_block_num = args.m_block_num
    m_part = args.m_part

    with torch.no_grad():
        n_input = out.shape[0]
        width = out.shape[-1]

        if A_dist is None:
            A_dist = torch.eye(n_input, device='cuda')
        A_base = torch.eye(m_part, device='cuda')

        if m_block_num == -1:
            m_block_num = 2 ** np.random.randint(1, 5)

        block_size = width // m_block_num
        sc = F.avg_pool2d(sc, block_size)

        out_list = []
        target_list = []
        for i in range(n_input // m_part):
            part = slice(i * m_part, (i + 1) * m_part)
            sc_part = sc[part]
            sc_norm = sc_part / sc_part.view(m_part, -1).sum(1).view(m_part, 1, 1)
            cost = -sc_norm

            A_dist_part = A_dist[part, part]
            A_dist_part = A_dist_part / torch.sum(A_dist_part) * m_part
            A = (1 - args.lam_dist) * A_base + args.lam_dist * A_dist_part

            # mixing labels for this group
            mask_onehot = get_onehot_matrix(cost.detach(), A, m_part,
                                            beta=args.m_beta, gamma=args.m_gamma, eta=args.m_eta,
                                            mixup_alpha=args.mixup_alpha,
                                            thres=args.m_thres, thres_type=args.m_thres_type,
                                            set_resolve=args.set_resolve, niter=args.m_niter,
                                            device='cuda')
            # mixed images and corresponding soft targets
            output_part, target_part = mix_input(mask_onehot, out[part], target_reweighted[part])

            out_list.append(output_part)
            target_list.append(target_part)

        out = torch.cat(out_list, dim=0)
        target_reweighted = torch.cat(target_list, dim=0)
    return out, target_reweighted

  

def graphcut_multi(unary1, unary2, pw_x, pw_y, alpha, beta, eta, n_labels=2, eps=1e-8):
    '''Solve a block-labelling problem with gco's alpha-beta swap.

    Label k (k = 0 .. n_labels-1) stands for lam = k / (n_labels - 1). Its unary
    cost is ``(1 - lam) * unary1 + lam * unary2 + prior[k]``. The prior is
    ``eta * -log(P(k)) / block_num**2``, where P is the binomial probability
    ``C(n, k) * alpha**(n-k) * (1-alpha)**k`` with ``n = n_labels - 1``. This
    reproduces the original hand-written priors for 2, 3 and 4 labels exactly.
    The label-to-label cost is ``(i - j)**2 / n**2``, and the per-edge weights
    are ``pw_x + beta`` / ``pw_y + beta``. All costs are scaled by
    ``1000 * block_num**2`` and truncated to int32 for gco.

    Args:
        unary1, unary2: per-block costs; ``unary1.shape[0]`` is ``block_num``.
        pw_x, pw_y: per-edge pairwise weights along the two grid axes.
        alpha: prior mixing parameter.
        beta: constant added to every pairwise weight.
        eta: prior strength.
        n_labels: number of labels (>= 2).
        eps: keeps the logs finite.

    Returns:
        (block_num, block_num) array whose value per block is ``1 - lam`` of
        the chosen label.

    Raises:
        ValueError: if ``n_labels < 2``.
    '''
    from math import comb

    if n_labels < 2:
        raise ValueError(f'n_labels must be >= 2, got {n_labels}')
    block_num = unary1.shape[0]
    large_val = 1000 * block_num ** 2
    n = n_labels - 1

    prior = eta * np.array([-np.log(comb(n, k) * alpha ** (n - k) * (1 - alpha) ** k + eps)
                            for k in range(n_labels)]) / block_num ** 2

    unary_cost = (large_val * np.stack([(1 - lam) * unary1 + lam * unary2 + prior[i]
                                        for i, lam in enumerate(np.linspace(0, 1, n_labels))],
                                       axis=-1)).astype(np.int32)
    pairwise_cost = np.array([[(i - j) ** 2 / n ** 2 for j in range(n_labels)]
                              for i in range(n_labels)], dtype=np.float32)

    pw_x = (large_val * (pw_x + beta)).astype(np.int32)
    pw_y = (large_val * (pw_y + beta)).astype(np.int32)

    labels = 1.0 - gco.cut_grid_graph(unary_cost, pairwise_cost, pw_x, pw_y, algorithm='swap') / n
    mask = labels.reshape(block_num, block_num)

    return mask

  
def neigh_penalty(input1, input2, k):
    '''Seam penalty between adjacent blocks taken from two image sources.

    Along each spatial axis, takes the difference between a pixel of
    ``input1`` and the next pixel of ``input2``. It keeps only positions that
    sit on a boundary between consecutive size-``k`` blocks (indices k-1,
    2k-1, ...), averages the absolute value over channels, then averages over
    the ``k`` pixels along each block edge.

    For inputs of shape (B, C, b*k, b*k) the results have shapes
    (B, b-1, b) and (B, b, b-1).
    '''
    row_seam = (input1[:, :, :-1, :] - input2[:, :, 1:, :])[:, :, k - 1::k, :]
    col_seam = (input1[:, :, :, :-1] - input2[:, :, :, 1:])[:, :, :, k - 1::k]

    row_seam = F.avg_pool2d(row_seam.abs().mean(dim=1), kernel_size=(1, k))
    col_seam = F.avg_pool2d(col_seam.abs().mean(dim=1), kernel_size=(k, 1))
    return row_seam, col_seam


def mixup_box(input1, input2, alpha=0.5):
    '''CutMix: paste one random rectangle of ``input2`` into ``input1``.

    The box has side lengths ``sqrt(1 - alpha)`` times the image height and
    width. It is centred at a uniformly random point (NumPy global RNG) and
    clipped to the image, so it covers at most a ``1 - alpha`` fraction of
    the pixels. The same box is used for the whole batch.

    Note: ``input1`` is modified in place and is also the returned tensor.

    Args:
        input1, input2: (batch, channels, height, width) tensors.
        alpha: target fraction of pixels kept from ``input1``.

    Returns:
        (mixed, ratio): ``ratio`` is a float32 tensor of shape (batch,) on
        ``input1``'s device holding the actual fraction of pixels still from
        ``input1``.
    '''
    batch_size, _, height, width = input1.shape

    rx = np.random.uniform(0, height)
    ry = np.random.uniform(0, width)
    rh = np.sqrt(1 - alpha) * height
    rw = np.sqrt(1 - alpha) * width
    x1 = int(np.clip(rx - rh / 2, a_min=0., a_max=height))
    x2 = int(np.clip(rx + rh / 2, a_min=0., a_max=height))
    y1 = int(np.clip(ry - rw / 2, a_min=0., a_max=width))
    y2 = int(np.clip(ry + rw / 2, a_min=0., a_max=width))
    input1[:, :, x1:x2, y1:y2] = input2[:, :, x1:x2, y1:y2]

    kept = 1 - (x2 - x1) * (y2 - y1) / (width * height)
    # Follow the input's device instead of assuming CUDA.
    ratio = torch.full((batch_size,), kept, dtype=torch.float32, device=input1.device)
    return input1, ratio


def mixup_graph(input, grad, indices1, indices2, block_num=2, alpha=0.5, beta=0., gamma=0., eta=0.2, neigh_size=2, n_labels=2, mean=None, std=None, transport=False, t_eps=10.0, t_size=16):
    '''Puzzle Mix: block-wise mix of ``input[indices1]`` with ``input[indices2]``.

    A per-sample ``block_num x block_num`` mask is found with ``graphcut_multi``.
    Its unary terms come from the block-pooled, per-sample-normalised saliency
    ``grad``, and its pairwise terms penalise seams between neighbouring blocks
    (``neigh_penalty``). If ``transport`` is set, both sources are first
    rearranged with ``mask_transport`` / ``transport_image``.

    Args:
        input: image batch indexed as (batch, channels, height, width); only
            the last dim is used as ``width`` (presumably square images).
        grad: saliency map; presumably (batch, height, width) so the
            per-sample normalisation broadcasts correctly -- confirm with caller.
        indices1, indices2: index arrays selecting the two sources.
        block_num: number of mask blocks per side.
        alpha: prior mixing ratio; a scalar, a length-1 array (broadcast to
            every sample) or one value per sample.
        beta, gamma: pairwise weights (effective pairwise scale is
            ``beta / block_num / 16 * gamma``).
        eta: prior strength for ``graphcut_multi``.
        neigh_size: pooling size for seam penalties (capped at the block size).
        n_labels: number of mixing levels per block.
        mean, std: de-normalisation ``input * std + mean`` applied before the
            seam penalties; a ``None`` value skips that step.
        transport: whether to apply the transport step.
        t_eps: transport distance weight.
        t_size: transport block size (capped at the block size; ``-1`` means
            use the mask block size).

    Returns:
        (mixed, ratio): ``mask * input1 + (1 - mask) * input2`` and a float32
        tensor with each sample's mean mask value.
    '''
    device = input.device
    input1 = input[indices1].clone()
    input2 = input[indices2].clone()

    batch_size, _, _, width = input1.shape
    block_size = width // block_num
    neigh_size = min(neigh_size, block_size)
    t_size = min(t_size, block_size)

    # prior parameter: broadcast a scalar / single value to every sample
    alpha = np.atleast_1d(alpha)
    if alpha.shape[0] == 1:
        alpha = np.ones([batch_size]) * alpha[0]
    beta = beta / block_num / 16

    # unary term: block-pooled saliency normalised to sum to 1 per sample
    grad_pool = F.avg_pool2d(grad, block_size)
    unary_torch = grad_pool / grad_pool.view(batch_size, -1).sum(1).view(batch_size, 1, 1)
    unary1_torch = unary_torch[indices1]
    unary2_torch = unary_torch[indices2]

    # pairwise terms are measured on de-normalised pixels
    pixels = input
    if std is not None:
        pixels = pixels * std
    if mean is not None:
        pixels = pixels + mean
    input_pool = F.avg_pool2d(pixels, neigh_size)
    input1_pool = input_pool[indices1]
    input2_pool = input_pool[indices2]

    # pw_*[:, a, b]: seam penalty between a block from source a and the next
    # block from source b, where index 1 = input1 and index 0 = input2
    pw_x = torch.zeros([batch_size, 2, 2, block_num - 1, block_num], device=device)
    pw_y = torch.zeros([batch_size, 2, 2, block_num, block_num - 1], device=device)

    k = block_size // neigh_size

    pw_x[:, 0, 0], pw_y[:, 0, 0] = neigh_penalty(input2_pool, input2_pool, k)
    pw_x[:, 0, 1], pw_y[:, 0, 1] = neigh_penalty(input2_pool, input1_pool, k)
    pw_x[:, 1, 0], pw_y[:, 1, 0] = neigh_penalty(input1_pool, input2_pool, k)
    pw_x[:, 1, 1], pw_y[:, 1, 1] = neigh_penalty(input1_pool, input1_pool, k)

    pw_x = beta * gamma * pw_x
    pw_y = beta * gamma * pw_y

    # move part of the pairwise terms into the unaries; what remains per edge
    # is (p10 + p01 - p11 - p00) / 2
    unary1 = unary1_torch.clone()
    unary2 = unary2_torch.clone()

    unary2[:, :-1, :] += (pw_x[:, 1, 0] + pw_x[:, 1, 1]) / 2.
    unary1[:, :-1, :] += (pw_x[:, 0, 1] + pw_x[:, 0, 0]) / 2.
    unary2[:, 1:, :] += (pw_x[:, 0, 1] + pw_x[:, 1, 1]) / 2.
    unary1[:, 1:, :] += (pw_x[:, 1, 0] + pw_x[:, 0, 0]) / 2.

    unary2[:, :, :-1] += (pw_y[:, 1, 0] + pw_y[:, 1, 1]) / 2.
    unary1[:, :, :-1] += (pw_y[:, 0, 1] + pw_y[:, 0, 0]) / 2.
    unary2[:, :, 1:] += (pw_y[:, 0, 1] + pw_y[:, 1, 1]) / 2.
    unary1[:, :, 1:] += (pw_y[:, 1, 0] + pw_y[:, 0, 0]) / 2.

    pw_x = (pw_x[:, 1, 0] + pw_x[:, 0, 1] - pw_x[:, 1, 1] - pw_x[:, 0, 0]) / 2
    pw_y = (pw_y[:, 1, 0] + pw_y[:, 0, 1] - pw_y[:, 1, 1] - pw_y[:, 0, 0]) / 2

    unary1 = unary1.detach().cpu().numpy()
    unary2 = unary2.detach().cpu().numpy()
    pw_x = pw_x.detach().cpu().numpy()
    pw_y = pw_y.detach().cpu().numpy()

    # solve graphcut per sample
    mask = []
    ratio = np.zeros([batch_size])
    for i in range(batch_size):
        mask.append(graphcut_multi(unary2[i], unary1[i], pw_x[i], pw_y[i], alpha[i], beta, eta, n_labels))
        ratio[i] = mask[i].sum()

    # optimal mask; stacking first avoids torch's very slow conversion of a
    # list of ndarrays
    mask = torch.tensor(np.stack(mask), dtype=torch.float32, device=device)
    mask = mask.unsqueeze(1)

    # transport
    if transport:
        if t_size == -1:
            t_block_num = block_num
            t_size = block_size
        elif t_size < block_size:
            # block_size % t_size should be 0
            t_block_num = width // t_size
            mask = F.interpolate(mask, size=t_block_num)
            grad_pool = F.avg_pool2d(grad, t_size)
            unary_torch = grad_pool / grad_pool.view(batch_size, -1).sum(1).view(batch_size, 1, 1)
            unary1_torch = unary_torch[indices1]
            unary2_torch = unary_torch[indices2]
        else:
            t_block_num = block_num

        # input1
        plan = mask_transport(mask, unary1_torch, eps=t_eps)
        input1 = transport_image(input1, plan, batch_size, t_block_num, t_size)

        # input2
        plan = mask_transport(1 - mask, unary2_torch, eps=t_eps)
        input2 = transport_image(input2, plan, batch_size, t_block_num, t_size)

    # final mask and mixed ratio
    mask = F.interpolate(mask, size=width)
    ratio = torch.tensor(ratio / block_num ** 2, dtype=torch.float32, device=device)

    return mask * input1 + (1 - mask) * input2, ratio


def mask_transport(mask, grad_pool, eps=0.01):
    '''Greedy approximation of an optimal block transport plan.

    Moving source block i to destination block j costs
    ``eps * C[i, j] - grad_pool[i] * [mask[j] > 0]``. Here ``C`` is the
    normalised squared grid distance from ``cost_matrix_dict``, so putting
    salient blocks where the mask is positive is rewarded and long moves are
    penalised.

    Each of ``block_num`` rounds proceeds as follows:
      * Every row proposes its cheapest column.
      * When several rows propose the same column, the cheapest one wins.
      * Each loser's cost for that column is raised by 1, so it tries another
        column in the next round.

    NOTE(review): collisions are resolved with ``argmin`` over ``plan * cost``,
    where non-proposers contribute 0. A proposer therefore only wins if its
    cost is negative.

    Args:
        mask: tensor whose last dim is ``block_num`` and that reshapes to
            (batch, block_num**2); presumably (batch, 1, block_num, block_num).
        grad_pool: per-block saliency reshapable to (batch, block_num**2, 1).
        eps: weight of the transport distance.

    Returns:
        0/1 tensor of shape (batch, block_num**2, block_num**2). Entry
        [b, i, j] == 1 means source block i goes to destination j. Rows that
        never win a column are all zero.
    '''
    block_num = mask.shape[-1]

    n_iter = int(block_num)
    C = cost_matrix_dict[str(block_num)]

    z = (mask > 0).float()
    cost = eps * C - grad_pool.reshape(-1, block_num ** 2, 1) * z.reshape(-1, 1, block_num ** 2)

    for _ in range(n_iter):
        # every source row proposes its cheapest destination column
        row_best = cost.min(-1)[1]
        plan = torch.zeros_like(cost).scatter_(-1, row_best.unsqueeze(-1), 1)

        # column resolve: per column keep only the cheapest proposer
        cost_fight = plan * cost
        col_best = cost_fight.min(-2)[1]
        plan_win = torch.zeros_like(cost).scatter_(-2, col_best.unsqueeze(-2), 1) * plan
        plan_lose = (1 - plan_win) * plan

        # Penalise rejected proposals so the next round reassigns them; this
        # must happen inside the loop or every round repeats the first one.
        cost += plan_lose

    return plan_win
    

def transport_image(img, plan, batch_size, block_num, block_size):
    '''Rearrange the square blocks of ``img`` according to a transport plan.

    ``img`` is cut into a ``block_num x block_num`` grid of
    ``block_size x block_size`` patches numbered row-major (p = r * block_num + c).
    Output patch q is ``sum_p plan[b, p, q] * patch_p``, so with a 0/1 plan
    source patch p is moved to destination q. Any number of channels is
    supported (taken from ``img.shape[1]``).

    Args:
        img: (batch_size, channels, block_num * block_size, block_num * block_size).
        plan: (batch_size, block_num**2, block_num**2) transport plan.
        batch_size, block_num, block_size: grid geometry.

    Returns:
        Tensor with the same shape as ``img``.
    '''
    channels = img.shape[1]
    n_patch = block_num ** 2
    side = block_num * block_size

    # split into patches: (B, C, bs, bs, n_patch, 1)
    patches = img.reshape([batch_size, channels, block_num, block_size, side]).transpose(-2, -1)
    patches = patches.reshape([batch_size, channels, block_num, block_num, block_size, block_size]).transpose(-2, -1)
    patches = patches.reshape([batch_size, channels, n_patch, block_size, block_size]).permute(0, 1, 3, 4, 2).unsqueeze(-1)

    # (B, 1, 1, 1, n_patch, n_patch) @ (B, C, bs, bs, n_patch, 1) -> destination patches
    moved = plan.transpose(-2, -1)[:, None, None, None].matmul(patches).squeeze(-1).permute(0, 1, 4, 2, 3)

    # reassemble the patch grid into an image
    moved = moved.reshape([batch_size, channels, block_num, block_num, block_size, block_size])
    moved = moved.transpose(-2, -1).reshape([batch_size, channels, block_num, side, block_size])
    moved = moved.transpose(-2, -1).reshape([batch_size, channels, side, side])

    return moved

  

