
import torch
import numpy as np


def get_random_problems(batch_size, problem_size, tunnel_size):
    """Sample random node coordinates together with random tunnel endpoint pairs.

    Args:
        batch_size: number of instances to generate.
        problem_size: number of nodes per instance.
        tunnel_size: number of tunnels per instance; each tunnel uses two
            distinct node indices drawn by ``generate_random_order``.

    Returns:
        A 3-tuple ``(problems, first_ends, second_ends)``:
        - ``problems``: tensor of shape (batch, problem, 2), drawn from
          ``torch.rand`` (uniform on [0, 1)).
        - ``first_ends`` / ``second_ends``: the two endpoint node indices of
          every tunnel, each of shape (batch, tunnel_size).
    """
    coords = torch.rand(size=(batch_size, problem_size, 2))
    # endpoint_pairs.shape: (batch, tunnel, 2) -- last axis holds the two ends
    endpoint_pairs = generate_random_order(batch_size, problem_size, tunnel_size)
    first_ends = endpoint_pairs[:, :, 0]
    second_ends = endpoint_pairs[:, :, 1]
    return coords, first_ends, second_ends

def augment_xy_data_by_8_fold(problems):
    """Return the 8 reflections/transpositions of 2-D points, stacked on the batch axis.

    Each variant is built from ``x``, ``y``, ``1 - x`` and ``1 - y``. The
    variants, in output order, are:
    (x, y), (1-x, y), (x, 1-y), (1-x, 1-y), (y, x), (1-y, x), (y, 1-x), (1-y, 1-x).

    Args:
        problems: tensor of shape (batch, problem, 2) holding (x, y) coordinates.

    Returns:
        Tensor of shape (8*batch, problem, 2). Rows ``k*batch:(k+1)*batch``
        hold the k-th variant listed above.
    """
    # Index with a list so the trailing axis is kept: shape (batch, problem, 1)
    xs = problems[:, :, [0]]
    ys = problems[:, :, [1]]
    flipped_xs = 1 - xs
    flipped_ys = 1 - ys

    variant_pairs = (
        (xs, ys),
        (flipped_xs, ys),
        (xs, flipped_ys),
        (flipped_xs, flipped_ys),
        (ys, xs),
        (flipped_ys, xs),
        (ys, flipped_xs),
        (flipped_ys, flipped_xs),
    )
    variants = [torch.cat(pair, dim=2) for pair in variant_pairs]
    return torch.cat(variants, dim=0)

import random
def generate_random_order(batch, nb_nodes, nb_tunnels):
    """Draw, per instance, ``2 * nb_tunnels`` distinct node indices grouped in pairs.

    Sampling uses Python's ``random.sample`` (one call per instance). Seeding
    the ``random`` module therefore makes the output reproducible; seeding
    torch or numpy does not.

    Args:
        batch: number of instances.
        nb_nodes: number of nodes to choose from; indices lie in [0, nb_nodes).
        nb_tunnels: number of pairs per instance.

    Returns:
        ``torch.int64`` tensor of shape (batch, nb_tunnels, 2). Within one
        instance, all ``2 * nb_tunnels`` indices are distinct.

    Raises:
        ValueError: if ``nb_tunnels`` is negative or ``nb_nodes < 2 * nb_tunnels``.
    """
    n_picks = 2 * nb_tunnels
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if nb_tunnels < 0 or nb_nodes < n_picks:
        raise ValueError(
            f"need nb_tunnels >= 0 and nb_nodes >= 2 * nb_tunnels, "
            f"got nb_nodes={nb_nodes}, nb_tunnels={nb_tunnels}"
        )
    # Integer storage: these are node indices. The previous float64 np.zeros
    # buffer made the returned tensor unusable for indexing without a cast.
    order = np.empty((batch, n_picks), dtype=np.int64)
    for i in range(batch):
        # `range` is a valid Sequence population. The draws are identical to
        # sampling from list(range(...)), so seeded results are unchanged.
        order[i] = random.sample(range(nb_nodes), n_picks)
    return torch.from_numpy(order.reshape(batch, nb_tunnels, 2))