import torch

def get_random_problems(distribution, batch_size, problem_size, emax=None, set_seed=False, seed=1000):
    """
    Generate a batch of random problem instances for the requested distribution.

    Args:
        distribution: One of "MG_fix", "MG_fix_sym", "MG_flex", "MG_flex_sym",
            "TMAT" or "XASY".
        batch_size: Number of instances to generate.
        problem_size: Number of nodes per instance.
        emax: Number of parallel edges per ordered node pair; only forwarded to
            the "MG_*" generators.
        set_seed: When true, seed torch's global RNG with ``seed`` first.
        seed: Seed used when ``set_seed`` is true.

    Returns:
        Whatever the selected generator returns: the multigraph generators
        return ``(edge_dists, edge_indices, emax)``, while the "TMAT" and
        "XASY" generators return ``(dists, indices)``.

    Raises:
        ValueError: If ``distribution`` is not one of the supported names.
    """
    if set_seed:
        torch.manual_seed(seed)

    if distribution in ("MG_fix", "MG_fix_sym"):
        return get_random_problems_mg_fix(distribution, batch_size, problem_size, emax)
    if distribution in ("MG_flex", "MG_flex_sym"):
        return get_random_problems_mg_flex(distribution, batch_size, problem_size, emax)
    if distribution == "TMAT":
        return get_random_problems_TMAT(batch_size, problem_size)
    if distribution == "XASY":
        return get_random_problems_XASY(batch_size, problem_size)

    # Fail loudly instead of handing callers a silent None to unpack.
    raise ValueError(f"Unknown distribution: {distribution!r}")

def get_random_problems_mg_fix(distribution, batch_size, problem_size, emax):
    """
    Generate random three-objective multigraph instances with ``emax`` parallel
    edges per ordered node pair.

    Within each group of parallel edges the first objective is sorted in
    ascending order and the second in descending order; the third objective
    keeps its unsorted random values.

    Args:
        distribution: Must be "MG_fix"; other variants are not implemented.
        batch_size: Number of instances.
        problem_size: Number of nodes per instance.
        emax: Number of parallel edges per ordered node pair.

    Returns:
        Tuple ``(edge_dists, edge_indices, emax)`` where ``edge_dists`` has
        shape (batch_size, emax * problem_size * (problem_size - 1), 3) and
        ``edge_indices`` is an int32 tensor of shape
        (batch_size, emax * problem_size * (problem_size - 1), 2).

    Raises:
        NotImplementedError: If ``distribution`` is anything other than
            "MG_fix" (e.g. the symmetric variant, which has no implementation).
    """
    if distribution != "MG_fix":
        # Previously this fell off the end of the function and returned None.
        raise NotImplementedError(
            f"Distribution {distribution!r} is not implemented by get_random_problems_mg_fix"
        )

    num_pairs = problem_size * (problem_size - 1)  # ordered node pairs (i, j), i != j

    matrix = torch.rand(batch_size, num_pairs, emax, 3)
    # Sort every objective independently across the parallel edges (dim 2).
    problems, _ = torch.sort(matrix, dim=2)
    # Reverse the second objective so it runs opposite to the first one.
    problems[..., 1] = problems[..., 1].flip(dims=(-1,))
    # The third objective keeps its original, unsorted random values.
    problems[..., 2] = matrix[..., 2]

    row, col = torch.triu_indices(problem_size, problem_size, offset=1)
    forward = torch.stack([row, col], dim=1)   # (i, j) with i < j
    backward = torch.stack([col, row], dim=1)  # matching (j, i)
    # Concatenating column-wise and reshaping interleaves (i, j), (j, i), ...
    pairs = torch.cat([forward, backward], dim=1).reshape(num_pairs, 2)
    indices = pairs.unsqueeze(0).repeat(batch_size, 1, 1).to(torch.int32)

    edge_dists = problems.reshape(batch_size, emax * num_pairs, 3)
    # Each node pair owns emax consecutive rows of edge_dists.
    edge_indices = indices.repeat_interleave(emax, dim=1)

    return edge_dists, edge_indices, emax

def get_random_problems_mg_flex(distribution, batch_size, problem_size, emax):
    """
    Generate random three-objective multigraph instances in which every
    dominated parallel edge is replaced by a copy of an edge that dominates it.

    ``emax`` edges are drawn uniformly for each ordered node pair. An edge is
    dominated when another edge of the same pair is strictly smaller in all
    three objectives. Each dominated edge is overwritten with a non-dominated
    edge of the same pair that dominates it, so the number of rows stays fixed
    while the number of distinct edges per pair varies.

    Args:
        distribution: Must be "MG_flex"; other variants are not implemented.
        batch_size: Number of instances.
        problem_size: Number of nodes per instance.
        emax: Number of parallel edges drawn per ordered node pair.

    Returns:
        Tuple ``(edge_dists, edge_indices, emax)`` where ``edge_dists`` has
        shape (batch_size, emax * problem_size * (problem_size - 1), 3) and
        ``edge_indices`` is an int32 tensor of shape
        (batch_size, emax * problem_size * (problem_size - 1), 2).

    Raises:
        NotImplementedError: If ``distribution`` is anything other than
            "MG_flex" (e.g. the symmetric variant, which has no implementation).
    """
    if distribution != "MG_flex":
        # Previously this fell off the end of the function and returned None.
        raise NotImplementedError(
            f"Distribution {distribution!r} is not implemented by get_random_problems_mg_flex"
        )

    num_pairs = problem_size * (problem_size - 1)  # ordered node pairs (i, j), i != j
    matrix = torch.rand(batch_size, num_pairs, emax, 3)

    # beats[b, p, i, j]: parallel edge i is strictly smaller than edge j in every objective.
    beats = (matrix.unsqueeze(3) < matrix.unsqueeze(2)).all(dim=-1)
    # dominated[b, p, j]: some other parallel edge beats edge j.
    dominated = beats.any(dim=2)
    # Only non-dominated edges may serve as replacements. Substituting an edge
    # that is itself dominated would leave dominated edges in the output.
    # Dominance is transitive, so every dominated edge has such a replacement.
    usable = beats & ~dominated.unsqueeze(3)

    result = matrix.clone()
    for i in range(emax):
        for j in range(emax):
            mask = usable[:, :, i, j]
            # result[:, :, j] is a view, so the masked write updates result in place.
            result[:, :, j][mask] = matrix[:, :, i][mask]

    row, col = torch.triu_indices(problem_size, problem_size, offset=1)
    forward = torch.stack([row, col], dim=1)   # (i, j) with i < j
    backward = torch.stack([col, row], dim=1)  # matching (j, i)
    # Concatenating column-wise and reshaping interleaves (i, j), (j, i), ...
    pairs = torch.cat([forward, backward], dim=1).reshape(num_pairs, 2)
    indices = pairs.unsqueeze(0).repeat(batch_size, 1, 1).to(torch.int32)

    edge_dists = result.reshape(batch_size, emax * num_pairs, 3)
    # Each node pair owns emax consecutive rows of edge_dists.
    edge_indices = indices.repeat_interleave(emax, dim=1)

    return edge_dists, edge_indices, emax
    
def get_random_problems_TMAT(batch_size, problem_size):
    """
    Build three-objective instances from three independently generated TMAT
    distance matrices and convert them to edge attributes and edge indices.
    """
    # One normalised distance matrix per objective, generated in sequence.
    per_objective = [_generate_TMAT(batch_size, problem_size) for _ in range(3)]
    # Put the objective axis last: (B, N, N, 3).
    stacked = torch.stack(per_objective, dim=-1)
    return _get_problem_dists(stacked)

def get_random_problems_XASY(batch_size, problem_size):
    """
    Build three-objective instances whose distance matrices are drawn
    independently and uniformly from [0, 1), then convert them to edge
    attributes and edge indices.
    """
    shape = (batch_size, problem_size, problem_size, 3)
    return _get_problem_dists(torch.rand(shape))

def _generate_TMAT(batch_size, problem_size, min_val=1, max_val=1000000):
    problems = torch.randint(low=min_val, high=max_val+1, size=(batch_size,
        problem_size, problem_size))
    problems[:, torch.arange(problem_size), torch.arange(problem_size)] = 0
    while True:
        old_problems = problems.clone()
        problems, _ = (problems[:, :, None, :] + problems[:, None, :,
            :].transpose(2,3)).min(dim=3)
        if (problems == old_problems).all():
            break

    max_value = problems.amax(dim=(1, 2), keepdim=True)  # Shape (B, 1, 1)
    problems_max = max_value.expand(batch_size, problem_size, problem_size)

    return torch.divide(problems, problems_max)

def _get_problem_dists(problems):
    """ 
    Converts set of distance matrices (B, N, N, Nobj) to edge attributes (B, E, Nobj) and edge indices (B, E, 2)
    """
    B, N, _, Nobj = problems.shape

    idx_i, idx_j = torch.triu_indices(N, N, offset=1)
    idx_pairs = torch.stack([idx_i, idx_j], dim=1)  # shape (num_edges, 2)
    reverse_idx_pairs = torch.stack([idx_j, idx_i], dim=1)  # shape (num_edges, 2)
    all_idx_pairs = torch.cat([idx_pairs, reverse_idx_pairs], dim=1).view(-1, 2)  # Flatten into pairs
    indices = all_idx_pairs.unsqueeze(0).repeat(B, 1, 1).to(torch.int32)
    i_indices, j_indices = all_idx_pairs[:, 0], all_idx_pairs[:, 1]

    dists = problems[:, i_indices, j_indices, :]

    return dists, indices

def augment_data(dists, edge_to_node, augmentation_factor=8):
    """
    Augment a batch by appending rescaled copies of the distances.

    The scale factors are 1 (the unmodified batch) followed by values stepping
    up from 0.5 and then down from 1.5, truncated so that exactly
    ``augmentation_factor`` factors are used. With ``augmentation_factor <= 1``
    only the unmodified batch is returned.

    Args:
        dists: Edge distances; copies are concatenated along dim 0.
        edge_to_node: Tensor repeated unchanged alongside each scaled copy.
        augmentation_factor: Total number of copies, including the original.

    Returns:
        Tuple ``(aug_dists, aug_edge_to_node, possible_factors)`` where the
        tensors hold ``len(possible_factors)`` stacked copies and
        ``possible_factors`` is the list of scale factors in stacking order.
    """
    # Always defined, so a factor of 1 (or less) no longer hits an unbound name.
    possible_factors = [1]
    if augmentation_factor > 1:
        half = augmentation_factor // 2
        step_size = 0.5 / half

        possible_factors.extend(0.5 + x * step_size for x in range(half))
        possible_factors.extend(1.5 - x * step_size for x in range(half))  # 0.5 ... 1 ... 1.5

        # 1 + 2 * half candidates exist: one too many for an even factor and
        # exactly enough for an odd one, so truncate to the requested count
        # instead of always dropping the last entry.
        possible_factors = possible_factors[:augmentation_factor]

    # The first copy is the untouched input; the rest are scaled.
    scaled = [dists] + [dists * factor for factor in possible_factors[1:]]
    aug_dists = torch.cat(scaled, dim=0)
    aug_edge_to_node = torch.cat([edge_to_node] * len(possible_factors), dim=0)

    return aug_dists, aug_edge_to_node, possible_factors
