
import math
import torch
import torch.nn.functional as F
import numpy as np
import random

###########################################################################################
def augment_xy_data_by_8_fold(xy_data, training=False):
    """Build the eight flip/swap variants of a batch of 2-D coordinates.

    The variants are, in order: (x, y), (1-x, y), (x, 1-y), (1-x, 1-y),
    (y, x), (1-y, x), (y, 1-x), (1-y, 1-x).

    Args:
        xy_data: tensor indexed as [B, N, 2]; column 0 is x, column 1 is y.
            (The ``1 - v`` flips suggest coordinates in [0, 1] -- not
            enforced here.)
        training: when truthy the variants are concatenated along the last
            axis ([B, N, 16]); otherwise along the batch axis ([8*B, N, 2]).

    Returns:
        The concatenated augmented tensor.
    """
    # x, y shape = [B, N, 1]
    x = xy_data[:, :, [0]]
    y = xy_data[:, :, [1]]
    flip_x = 1 - x
    flip_y = 1 - y

    column_pairs = (
        (x, y),
        (flip_x, y),
        (x, flip_y),
        (flip_x, flip_y),
        (y, x),
        (flip_y, x),
        (y, flip_x),
        (flip_y, flip_x),
    )
    variants = [torch.cat(pair, dim=2) for pair in column_pairs]

    # training -> [B, N, 16]; otherwise -> [8*B, N, 2]
    return torch.cat(variants, dim=2 if training else 0)

def data_augment(batch):
    """Augment coordinates 8-fold and prepend one angle feature per variant.

    Args:
        batch: tensor indexed as [B, N, 2] (see augment_xy_data_by_8_fold).

    Returns:
        Tensor [B, N, 24]: the first 8 columns are atan(y / x) for each of
        the eight variants, followed by the 16 augmented coordinate columns.
        No guard is applied for x == 0.
    """
    # folded.shape = [B, N, 16], laid out as (x0, y0, x1, y1, ..., x7, y7)
    folded = augment_xy_data_by_8_fold(batch, training=True)

    xs = folded[:, :, 0::2]  # even columns: first coordinate of each variant
    ys = folded[:, :, 1::2]  # odd columns: second coordinate of each variant
    angles = torch.atan(ys / xs)  # [B, N, 8]

    return torch.cat((angles, folded), dim=2)
import time


def augment_tunnel_per_fold(tot):
    """Append a signed endpoint-distance feature to every tunnel row.

    Columns 0/1 of ``tot`` are treated as one endpoint and columns 2/3 as
    the other; their Euclidean distance is appended as extra column(s).
    Rows in the second half along axis 1 (index >= tunnels // 2) get the
    negated distance.

    The sign flip is done out of place. The earlier in-place write modified
    the output of ``torch.sqrt`` (which autograd saves for backward) and
    could not be applied to an expanded tensor when ``dim >= 8``.

    Args:
        tot: tensor of shape (batch, tunnels, dim); dim == 4 by default.

    Returns:
        Tensor of shape (batch, tunnels, dim + dim // 4).

    Raises:
        ValueError: if ``tot`` is not a torch.Tensor.
    """
    if not isinstance(tot, torch.Tensor):
        raise ValueError("tot must be a torch.Tensor")

    _, tunnels, dim = tot.shape
    half = tunnels // 2

    diff_avec = tot[:, :, 0:1] - tot[:, :, 2:3]  # (batch, tunnels, 1)
    diff_bvec = tot[:, :, 1:2] - tot[:, :, 3:4]  # (batch, tunnels, 1)
    distances = torch.sqrt(diff_avec ** 2 + diff_bvec ** 2)  # (batch, tunnels, 1)

    # Negate the second half without mutating `distances`.
    signed = torch.cat((distances[:, :half], -distances[:, half:]), dim=1)
    signed = signed.expand(-1, -1, dim // 4)  # (batch, tunnels, dim // 4)

    return torch.cat((tot, signed), dim=2)  # (batch, tunnels, dim + dim // 4)

def augment_tunnel_data_by_8_fold(xy_data, training=False):
    """Build the eight flip/swap variants of tunnel endpoint data.

    Each tunnel row holds two endpoints (x, y) and (m, n). To account for
    direction, every tunnel is first duplicated along axis 1 with its
    endpoints exchanged, so N tunnels become 2N directed rows. The same
    eight flips/swaps used for plain coordinates are then applied to both
    endpoints.

    The per-fold distance feature is only computed when ``training`` is
    truthy; the inference path never used it, so it is no longer computed
    there.

    Args:
        xy_data: tensor indexed as [B, N, 4] with columns (x, y, m, n).
        training: selects the output layout (see Returns).

    Returns:
        training truthy: the eight variants, each passed through
            ``augment_tunnel_per_fold``, concatenated along the last axis
            ([B, 2N, 40] for 4-column input).
        otherwise: the eight raw variants concatenated along the batch
            axis ([8*B, 2N, 4]).
    """
    # x_1, y_1, m_1, n_1 shape = [B, N, 1]
    x_1 = xy_data[:, :, [0]]
    y_1 = xy_data[:, :, [1]]
    m_1 = xy_data[:, :, [2]]
    n_1 = xy_data[:, :, [3]]

    # Consider direction: append each tunnel again with endpoints swapped.
    x = torch.cat((x_1, m_1), dim=1)
    y = torch.cat((y_1, n_1), dim=1)
    m = torch.cat((m_1, x_1), dim=1)
    n = torch.cat((n_1, y_1), dim=1)

    column_sets = (
        (x, y, m, n),
        (1 - x, y, 1 - m, n),
        (x, 1 - y, m, 1 - n),
        (1 - x, 1 - y, 1 - m, 1 - n),
        (y, x, n, m),
        (1 - y, x, 1 - n, m),
        (y, 1 - x, n, 1 - m),
        (1 - y, 1 - x, 1 - n, 1 - m),
    )
    folds = [torch.cat(columns, dim=2) for columns in column_sets]  # each [B, 2N, 4]

    if training:
        with_distance = [augment_tunnel_per_fold(fold) for fold in folds]
        return torch.cat(with_distance, dim=2)

    return torch.cat(folds, dim=0)

def New_augment_xy_data_by_8_fold(xy_data, tunnel,training=False):
    """Build the eight flip/swap variants of a batch of 2-D coordinates.

    Args:
        xy_data: tensor indexed as [B, N, 2]; column 0 is x, column 1 is y.
        tunnel: accepted for interface compatibility; not used by this
            function.
        training: when truthy the variants are concatenated along the last
            axis ([B, N, 16]); otherwise along the batch axis ([8*B, N, 2]).

    Returns:
        The concatenated augmented tensor.
    """
    # x, y shape = [B, N, 1]
    x = xy_data[:, :, [0]]
    y = xy_data[:, :, [1]]
    mirrored_x = 1 - x
    mirrored_y = 1 - y

    folds = []
    for first, second in (
        (x, y),
        (mirrored_x, y),
        (x, mirrored_y),
        (mirrored_x, mirrored_y),
        (y, x),
        (mirrored_y, x),
        (y, mirrored_x),
        (mirrored_y, mirrored_x),
    ):
        folds.append(torch.cat((first, second), dim=2))

    if training:
        # data_augmented.shape = [B, N, 16]
        return torch.cat(folds, dim=2)

    # data_augmented.shape = [8*B, N, 2]
    return torch.cat(folds, dim=0)


###################################  POMOPOMOPOMOPOMOPOMO  #############################
def augment_xy_data_by_8_fold_POMO(problems,training=True):
    """Build the eight flip/swap variants of a batch of 2-D problems.

    Args:
        problems: tensor of shape (batch, problem, 2).
        training: defaults to True here. When truthy the variants are
            concatenated along the last axis, giving (batch, problem, 16);
            otherwise along the batch axis, giving (8*batch, problem, 2).

    Returns:
        The concatenated augmented tensor.
    """
    # x, y shape: (batch, problem, 1)
    x = problems[:, :, [0]]
    y = problems[:, :, [1]]
    inv_x = 1 - x
    inv_y = 1 - y

    layouts = [
        (x, y),
        (inv_x, y),
        (x, inv_y),
        (inv_x, inv_y),
        (y, x),
        (inv_y, x),
        (y, inv_x),
        (inv_y, inv_x),
    ]
    augmented = [torch.cat(layout, dim=2) for layout in layouts]

    concat_axis = 2 if training else 0
    return torch.cat(augmented, dim=concat_axis)

#######################################################################################
def generate_random_order(batch,num,seg):
    """Draw, per batch row, ``seg`` distinct indices from ``range(num)``.

    Args:
        batch: number of rows to generate.
        num: size of the index population (0 .. num-1).
        seg: how many distinct indices to draw per row; must be <= num.

    Returns:
        NumPy float array of shape (batch, seg); each row is an independent
        sample without replacement drawn with the ``random`` module.
    """
    assert num >= seg
    population = list(range(0, num))
    orders = np.zeros((batch, seg))
    for row in range(batch):
        # One random.sample call per row, in row order.
        orders[row, :] = random.sample(population, seg)
    return orders

#Get fixed sequence
def generate_original_seq(batch,num,seg):
    """Return the same fixed index order for every batch row.

    The order is hard-coded with 16 entries, interleaving i and i + 10 for
    i = 0..7, so ``seg`` has to be 16 whenever ``batch`` > 0 (NumPy raises
    on the row assignment otherwise). ``num`` is not used.

    Args:
        batch: number of rows.
        num: unused.
        seg: row length of the output.

    Returns:
        NumPy float array of shape (batch, seg).
    """
    fixed_order = [0, 10, 1, 11, 2, 12, 3, 13, 4, 14, 5, 15, 6, 16, 7, 17]
    table = np.zeros((batch, seg))
    for row in table:
        # `row` is a view into `table`, so this writes the table in place.
        row[:] = fixed_order
    return table

def generate_coord_from_indexes(coord,index):
    """Look up the coordinates of both endpoints of every index pair.

    Args:
        coord: tensor indexed as [B, 2N, 2] (per the original shape note).
        index: tensor of shape [B, M, 2] holding row indices into axis 1 of
            ``coord``; cast to long before the lookup.

    Returns:
        Tensor [B, M, 4]: for each pair, the 2 coordinates of the first
        index followed by the 2 coordinates of the second.
    """
    batch_size, pair_count, _ = index.shape

    # Flatten the pairs to one index per row and repeat it for both
    # coordinate columns: [B, M, 2] -> [B, 2M, 1] -> [B, 2M, 2].
    gather_index = index.reshape(batch_size, 2 * pair_count, 1).expand(-1, -1, 2).long()

    gathered = torch.gather(coord, 1, gather_index)  # [B, 2M, 2]

    # Fold consecutive rows back together: [B, 2M, 2] -> [B, M, 4].
    return gathered.view(batch_size, pair_count, -1)

def expand_all_as_tunnels(M, matrices):
    """Pad each pair matrix with self-pairs for every id it does not mention.

    For every matrix in the batch, each integer in ``range(M)`` that does
    not appear anywhere in the matrix is appended as a row ``[num, num]``,
    in ascending order.

    The ids are read with a single ``tolist()`` and the padding rows are
    appended with a single ``torch.cat`` per matrix, instead of one
    ``.item()`` per element and one ``cat`` per missing id.

    Args:
        M: number of ids (0 .. M-1).
        matrices: tensor of shape (B, K, 2) of id pairs.

    Returns:
        Tensor of shape (B, K + missing, 2). ``torch.stack`` requires every
        batch entry to have the same number of missing ids.
    """
    device = matrices.device
    outputs = []
    for matrix in matrices:
        existing_nums = set(matrix.flatten().tolist())
        missing_nums = [num for num in range(M) if num not in existing_nums]

        new_matrix = matrix.clone()
        if missing_nums:
            self_pairs = torch.tensor(
                [[num, num] for num in missing_nums], device=device
            )
            new_matrix = torch.cat((new_matrix, self_pairs), dim=0)

        outputs.append(new_matrix)

    return torch.stack(outputs, dim=0)

def expand_every_tunnels(M, matrices):
    """Pad each pair matrix with self-pairs, then append the reversed pairs.

    For every matrix in the batch the result is, stacked along axis 0:
      1. the original rows,
      2. a row ``[num, num]`` for every integer in ``range(M)`` that does
         not appear in the matrix (ascending),
      3. the original rows with columns 0 and 1 exchanged.

    The ids are read with a single ``tolist()`` and the padding rows are
    appended with a single ``torch.cat`` per matrix, instead of one
    ``.item()`` per element and one ``cat`` per missing id.

    Args:
        M: number of ids (0 .. M-1).
        matrices: tensor of shape (B, K, 2) of id pairs.

    Returns:
        Tensor of shape (B, 2*K + missing, 2). ``torch.stack`` requires
        every batch entry to have the same number of missing ids.
    """
    device = matrices.device
    outputs = []
    for matrix in matrices:
        existing_nums = set(matrix.flatten().tolist())
        missing_nums = [num for num in range(M) if num not in existing_nums]

        # Reversed copy of the original pairs (columns 0 and 1 swapped).
        exchange_matrix = torch.zeros_like(matrix)
        exchange_matrix[:, 1] = matrix[:, 0]
        exchange_matrix[:, 0] = matrix[:, 1]

        new_matrix = matrix.clone()
        if missing_nums:
            self_pairs = torch.tensor(
                [[num, num] for num in missing_nums], device=device
            )
            new_matrix = torch.cat((new_matrix, self_pairs), dim=0)

        outputs.append(torch.cat((new_matrix, exchange_matrix), dim=0))

    return torch.stack(outputs, dim=0)


def find_corresponding_tunnel(tensor1, tensor2):
    """For each queried id, find the row of the pair that contains it.

    Args:
        tensor1: tensor of shape (B, N, 2) of id pairs.
        tensor2: tensor of shape (B, K) of ids to look up.

    Returns:
        Long tensor of shape (B, K) with the index along axis 1 of
        ``tensor1`` whose pair contains the id. The final ``view`` only
        succeeds when every id matches exactly one pair.
    """
    _, _, _ = tensor1.size()  # tensor1 must be 3-D
    B, K = tensor2.size()

    queries = tensor2.unsqueeze(2).unsqueeze(3)  # (B, K, 1, 1)
    pairs = tensor1.unsqueeze(1)                 # (B, 1, N, 2)

    # contains[b, k, n] is True when pair n of batch b holds id k.
    contains = torch.eq(queries, pairs).any(dim=-1)  # (B, K, N)

    # Column 2 of the nonzero coordinates is the pair index n.
    pair_index = torch.nonzero(contains)[:, 2]
    return pair_index.view(B, K)

def find_other_element_in_pair(tensor1, tensor2):
    """Return, for each queried id, the other id in its pair.

    The partner is computed as (sum of the pair) - (queried id), so a
    self-pair [v, v] maps v to itself.

    Args:
        tensor1: tensor of shape (B, N, 2) of id pairs.
        tensor2: tensor of shape (B, M) of ids to look up.

    Returns:
        Long tensor of shape (B, M) holding the partner ids.
    """
    _, _, _ = tensor1.size()  # tensor1 must be 3-D
    _, _ = tensor2.size()     # tensor2 must be 2-D

    tunnel_index = find_corresponding_tunnel(tensor1, tensor2)  # (B, M)

    # Sum every pair once, then pick the sums of the matched pairs.
    pair_sums = tensor1.sum(dim=2)                   # (B, N)
    matched_sums = pair_sums.gather(1, tunnel_index)  # (B, M)

    partner = matched_sums - tensor2
    return partner.long()

def transform_tunnelindex_to_tunneltable(modifier, M):
    """Turn index pairs into a 0/1 membership table.

    Args:
        modifier: integer array indexed as [B, N, 2]; entries are column
            indices into the output's last axis.
        M: size of the output's last axis.

    Returns:
        NumPy float array of shape (B, N, M) where ``out[b, n, k]`` is 1 if
        ``k`` equals ``modifier[b, n, 0]`` or ``modifier[b, n, 1]``, else 0.
    """
    B, N, _ = modifier.shape
    table = np.zeros((B, N, M))

    # Visit (b, n, slot) in the same nested order as explicit loops would.
    for b, n, slot in np.ndindex(B, N, 2):
        table[b, n, modifier[b, n, slot]] = 1

    return table

def create_output_matrix_with_batch(input, N):
    """Turn index pairs into a 0/1 membership tensor.

    The output is allocated on ``input``'s device. It was previously always
    created on the CPU, which makes ``scatter_`` fail when the index tensor
    lives on another device. All batches are filled with one scatter.

    Args:
        input: tensor of shape (batch_size, M, 2) (only columns 0 and 1 are
            read); entries are cast to long and used as column indices.
        N: size of the output's last axis.

    Returns:
        ``torch.int`` tensor of shape (batch_size, M, N) where
        ``out[b, m, k]`` is 1 if ``k`` equals ``input[b, m, 0]`` or
        ``input[b, m, 1]``, else 0.
    """
    batch_size, M, _ = input.shape

    output = torch.zeros(batch_size, M, N, dtype=torch.int, device=input.device)
    endpoints = input[:, :, :2].long()  # (batch_size, M, 2)
    output.scatter_(2, endpoints, 1)

    return output

############################################################################################
class AverageMeter:
    """Accumulate a weighted running sum and expose its average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the accumulated sum and count."""
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Add ``val`` with weight ``n`` (``val`` counts as ``n`` samples)."""
        weighted = val * n
        self.sum += weighted
        self.count += n

    @property
    def avg(self):
        """Weighted mean of everything seen so far; 0 when nothing was added."""
        if not self.count:
            return 0
        return self.sum / self.count
    

##########################################################################################
def multi_head_attention(q, k, v, mask=None):
    """Scaled dot-product attention over pre-split heads.

    Args:
        q: tensor (B, n_heads, n, key_dim); n can be either 1 or N.
        k: tensor (B, n_heads, N, key_dim).
        v: tensor (B, n_heads, N, key_dim).
        mask: optional additive mask (B, group, N), broadcast over heads
            and added to the scores before the softmax.

    Returns:
        Tensor (B, n, n_heads * key_dim): per-head outputs concatenated
        along the last axis.
    """
    batch, heads, n_query, head_dim = q.shape

    # logits.shape = (B, n_heads, n, N)
    logits = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)

    if mask is not None:
        logits += mask[:, None, :, :].expand_as(logits)

    weights = F.softmax(logits, dim=3)
    # (B, n_heads, n, key_dim) -> (B, n, n_heads, key_dim)
    per_head = torch.matmul(weights, v).transpose(1, 2)
    return per_head.reshape(batch, n_query, heads * head_dim)


def make_heads(qkv, n_heads):
    """Split the last axis into heads and move the head axis forward.

    Args:
        qkv: tensor whose first two axes are (B, L); the remaining elements
            per position are divided evenly among ``n_heads``.
        n_heads: number of heads.

    Returns:
        Tensor (B, n_heads, L, per_head_dim).
    """
    batch, length = qkv.size(0), qkv.size(1)
    split = qkv.reshape(batch, length, n_heads, -1)
    # (B, L, n_heads, d) -> (B, n_heads, L, d)
    return split.permute(0, 2, 1, 3)