import torch

emax_default = 2  # default number of parallel edges per ordered node pair in the "MG_*" generators


def get_random_problems(distribution, batch_size, problem_size, pref=None, sparse_method=None,
                        use_sparse=False, set_seed=False, seed=1000, emax=emax_default):
    """Generate a batch of random instances with two cost values per edge.

    Args:
        distribution: one of
            * ``"EUC"``   - Euclidean distances, one set of 2-D node coordinates
              per objective; returns (B, N, N, 2).
            * ``"TMAT"``  - two independent shortest-path-closed integer matrices,
              each normalised to max 1; returns (B, N, N, 2).
            * ``"XASY"``  - independent uniform [0, 1) entries; returns (B, N, N, 2).
            * ``"MG_fix"`` / ``"MG_flex"`` - multigraphs with ``emax`` parallel
              edges per ordered node pair; returns ``(edge_dists, indices)`` or,
              when ``use_sparse`` is set, a dense (B, N, N, 2) matrix built by
              ``sparsify``.
        batch_size: number of instances B.
        problem_size: number of nodes N.
        pref: preference weights forwarded to ``sparsify``; required when
            ``use_sparse`` is True for a multigraph distribution.
        sparse_method: scalarisation method forwarded to ``sparsify``.
        use_sparse: collapse multigraphs to a dense matrix via ``sparsify``.
        set_seed: if True, reseed torch's global RNG with ``seed`` first.
        seed: seed used when ``set_seed`` is True.
        emax: parallel edges per node pair for the multigraph distributions.

    Raises:
        ValueError: for an unknown ``distribution``, or when ``use_sparse`` is
            requested for a multigraph without ``pref``.
    """
    if set_seed:
        torch.manual_seed(seed)

    if distribution == "EUC":
        # Four coordinates per node: (x, y) for each of the two objectives.
        coords = torch.rand(size=(batch_size, problem_size, 4))
        return _get_problem_dists(coords)

    if distribution == "TMAT":
        first = _generate_TMAT(batch_size, problem_size)
        second = _generate_TMAT(batch_size, problem_size)
        return torch.stack((first, second), dim=3)

    if distribution == "XASY":
        return torch.rand((batch_size, problem_size, problem_size, 2))

    if distribution in ("MG_fix", "MG_flex"):
        # Fail fast, before any random data is generated.
        if use_sparse and pref is None:
            raise ValueError("pref must be provided when use_sparse=True")

        generate = (get_random_problems_mg_fix if distribution == "MG_fix"
                    else get_random_problems_mg_flex)
        edge_dists, indices = generate(distribution, batch_size, problem_size, emax=emax)

        if use_sparse:
            # The node indices are repeated identically over the batch, so the
            # first instance's indices describe every instance.
            return sparsify(pref, sparse_method, indices[0], edge_dists, problem_size, emax=emax)
        return edge_dists, indices

    raise ValueError(f"Unknown distribution: {distribution!r}")


def _generate_TMAT(batch_size, problem_size, min_val=1, max_val=1000000):
    """Generate random asymmetric cost matrices that satisfy the triangle inequality.

    Off-diagonal weights are integers drawn uniformly from ``[min_val, max_val]``
    (inclusive), and the diagonal is set to 0. The matrices are then closed under
    shortest paths by repeated min-plus squaring until a fixpoint is reached.
    Finally, every instance is divided by its own largest entry.

    Args:
        batch_size: number of instances B.
        problem_size: number of nodes N.
        min_val: smallest sampled weight; must be non-negative.
        max_val: largest sampled weight (inclusive).

    Returns:
        Floating tensor (B, N, N) with zero diagonal and per-instance maximum 1.
        An instance whose entries are all 0 (e.g. N == 1) is returned as zeros
        rather than NaN.

    Raises:
        ValueError: if ``min_val`` is negative. Negative weights can form
            negative cycles, for which the shortest-path iteration never converges.
    """
    if min_val < 0:
        raise ValueError(f"min_val must be non-negative, got {min_val}")

    problems = torch.randint(low=min_val, high=max_val + 1,
                             size=(batch_size, problem_size, problem_size))
    diag = torch.arange(problem_size)
    problems[:, diag, diag] = 0

    while True:
        # relaxed[b, i, j] = min_k problems[b, i, k] + problems[b, k, j]
        relaxed = (problems[:, :, None, :] + problems.transpose(1, 2)[:, None, :, :]).amin(dim=3)
        if torch.equal(relaxed, problems):
            break
        problems = relaxed

    max_value = problems.amax(dim=(1, 2), keepdim=True)  # (B, 1, 1), broadcasts below
    # clamp_min(1) only matters for all-zero instances and avoids 0 / 0 = NaN.
    return torch.divide(problems, max_value.clamp_min(1))

def _get_problem_dists(problems):
    """Compute one Euclidean distance matrix per objective from node coordinates.

    Args:
        problems: tensor (B, N, 2 * n_obj). The last dimension interleaves the
            coordinates of each objective: ``[x_0, y_0, x_1, y_1, ...]``.

    Returns:
        Tensor (B, N, N, n_obj) with
        ``dist[b, i, j, o] = sqrt((x_o[i] - x_o[j])**2 + (y_o[i] - y_o[j])**2)``.

    Raises:
        ValueError: if the coordinate dimension is odd.
    """
    _, _, n_coords = problems.shape
    if n_coords % 2:
        raise ValueError(f"expected an even number of coordinates per node, got {n_coords}")

    x_coords = problems[:, :, 0::2]  # (B, N, n_obj)
    y_coords = problems[:, :, 1::2]  # (B, N, n_obj)

    x_diff = x_coords[:, :, None, :] - x_coords[:, None, :, :]  # (B, N, N, n_obj)
    y_diff = y_coords[:, :, None, :] - y_coords[:, None, :, :]  # (B, N, N, n_obj)
    return torch.sqrt(x_diff ** 2 + y_diff ** 2)

def augment_data(dist_matrix, augmentation_factor=8):
    """Augment a batch by scaling all costs with factors spread over [0.5, 1.5].

    The factor list starts with 1, then alternates between an ascending run from
    0.5 and a descending run from 1.5. It is truncated to ``augmentation_factor``
    entries. For example, factor 8 gives
    ``[1, 0.5, 0.625, 0.75, 0.875, 1.5, 1.375, 1.25]``.

    Args:
        dist_matrix: tensor whose dim 0 is the batch dimension.
        augmentation_factor: number of copies to produce. Values <= 1 mean no
            augmentation.

    Returns:
        ``(augmented, factors)``. ``augmented`` concatenates, along dim 0, the
        original batch followed by ``dist_matrix * f`` for every further factor
        ``f``. ``factors`` is the list of factors used, starting with 1.
    """
    if augmentation_factor <= 1:
        return dist_matrix, [1]

    half = augmentation_factor // 2
    step_size = 0.5 / half

    possible_factors = [1]
    possible_factors.extend(0.5 + x * step_size for x in range(half))  # 0.5 ... < 1
    possible_factors.extend(1.5 - x * step_size for x in range(half))  # 1.5 ... > 1
    # 1 + 2 * half is augmentation_factor + 1 for even inputs and augmentation_factor
    # for odd ones. Truncate so the number of copies always matches the request.
    possible_factors = possible_factors[:augmentation_factor]

    # A single concatenation avoids re-copying the growing tensor on every step.
    aug_dist_matrix = torch.cat(
        [dist_matrix] + [dist_matrix * factor for factor in possible_factors[1:]], dim=0
    )
    return aug_dist_matrix, possible_factors


### Multigraphs ###
###################

def get_random_problems_mg_fix(distribution, batch_size, problem_size, emax = emax_default):
    """Generate multigraphs with ``emax`` mutually non-dominated parallel edges per node pair.

    For every ordered node pair (i, j), i != j, ``emax`` bi-objective costs are
    drawn uniformly from [0, 1). Each objective is then sorted independently
    across the parallel edges: objective 0 ascending, objective 1 descending.
    The re-paired costs therefore trade off against each other, so (barring
    ties) no parallel edge dominates another.

    Args:
        distribution: must be ``"MG_fix"``.
        batch_size: number of instances B.
        problem_size: number of nodes N.
        emax: parallel edges per ordered node pair.

    Returns:
        edge_dists: float tensor (B, emax * N * (N - 1), 2). The emax parallel
            edges of a node pair are stored contiguously.
        edge_indices: int32 tensor (B, emax * N * (N - 1), 2) of (from, to)
            node ids aligned with ``edge_dists``; identical for every instance.

    Raises:
        ValueError: if ``distribution`` is not ``"MG_fix"``.
    """
    if distribution != "MG_fix":
        raise ValueError(
            f"get_random_problems_mg_fix only supports 'MG_fix', got {distribution!r}"
        )

    num_pairs = problem_size * (problem_size - 1)  # ordered pairs (i, j), i != j

    matrix = torch.rand(batch_size, num_pairs, emax, 2)
    problems, _ = torch.sort(matrix, dim=2)  # each objective ascending over parallel edges
    problems[:, :, :, 1] = problems[:, :, :, 1].flip(dims=(2,))  # objective 1 descending

    # Pairs are ordered (i, j), (j, i) for every i < j in upper-triangular order.
    idx_i, idx_j = torch.triu_indices(problem_size, problem_size, offset=1)
    all_idx_pairs = torch.stack([idx_i, idx_j, idx_j, idx_i], dim=1).view(-1, 2)
    indices = all_idx_pairs.unsqueeze(0).repeat(batch_size, 1, 1).to(torch.int32)

    # Flatten so that edge k = pair * emax + alternative; repeat the node ids to match.
    edge_dists = problems.reshape(batch_size, emax * num_pairs, 2)
    edge_indices = indices.repeat_interleave(emax, dim=1)
    return edge_dists, edge_indices

def get_random_problems_mg_flex(distribution, batch_size, problem_size, emax = emax_default):
    """Generate multigraphs whose parallel edges are reduced to their Pareto front.

    For every ordered node pair (i, j), i != j, ``emax`` bi-objective costs are
    drawn uniformly from [0, 1). Any parallel edge that another parallel edge of
    the same pair beats in both objectives (strictly smaller costs) is
    overwritten with a copy of a non-dominated edge that dominates it. Each pair
    keeps ``emax`` slots, but the number of distinct edges varies from pair to
    pair; duplicates fill the freed slots.

    Args:
        distribution: must be ``"MG_flex"``.
        batch_size: number of instances B.
        problem_size: number of nodes N.
        emax: parallel edges per ordered node pair.

    Returns:
        edge_dists: float tensor (B, emax * N * (N - 1), 2). The emax parallel
            edges of a node pair are stored contiguously.
        edge_indices: int32 tensor (B, emax * N * (N - 1), 2) of (from, to)
            node ids aligned with ``edge_dists``; identical for every instance.

    Raises:
        ValueError: if ``distribution`` is not ``"MG_flex"``.
    """
    if distribution != "MG_flex":
        raise ValueError(
            f"get_random_problems_mg_flex only supports 'MG_flex', got {distribution!r}"
        )

    num_pairs = problem_size * (problem_size - 1)  # ordered pairs (i, j), i != j
    matrix = torch.rand(batch_size, num_pairs, emax, 2)

    # dominates[b, e, i, j] is True when parallel edge i has strictly smaller
    # costs than parallel edge j in both objectives.
    dominates = (matrix[:, :, None, :, :] > matrix[:, :, :, None, :]).all(dim=-1)
    is_dominated = dominates.any(dim=2)  # (B, E, emax), indexed by edge j

    # Copy only from *non-dominated* dominators. Copying from an arbitrary
    # dominator can leave a dominated edge behind when dominance chains
    # (k beats i beats j) exist, which is possible once emax >= 3. When several
    # candidates exist, the highest index is used; with emax == 2 there is at
    # most one candidate.
    candidates = dominates & ~is_dominated[:, :, :, None]
    alt_ids = torch.arange(emax)
    source = torch.where(candidates, alt_ids[:, None], torch.tensor(-1)).amax(dim=2)
    source = torch.where(source >= 0, source, alt_ids)  # non-dominated edges keep themselves
    result = torch.gather(matrix, 2, source[..., None].expand(-1, -1, -1, 2))

    # Pairs are ordered (i, j), (j, i) for every i < j in upper-triangular order.
    idx_i, idx_j = torch.triu_indices(problem_size, problem_size, offset=1)
    all_idx_pairs = torch.stack([idx_i, idx_j, idx_j, idx_i], dim=1).view(-1, 2)
    indices = all_idx_pairs.unsqueeze(0).repeat(batch_size, 1, 1).to(torch.int32)

    # Flatten so that edge k = pair * emax + alternative; repeat the node ids to match.
    edge_dists = result.reshape(batch_size, emax * num_pairs, 2)
    edge_indices = indices.repeat_interleave(emax, dim=1)
    return edge_dists, edge_indices
    
def sparsify(pref, method, edge_to_node, dists, problem_size, emax = emax_default):
    """Collapse a multigraph into a dense cost matrix by keeping the best parallel edge.

    Each edge's two costs are first scalarised with the preference weights. For
    every node pair, the parallel edge with the lowest scalarised cost is then
    selected (the first one on ties), and its two costs are written into a dense
    matrix.

    Args:
        pref: two weights; ``pref[0]`` scales objective 0 and ``pref[1]``
            scales objective 1.
        method: ``"Linear"`` selects the weighted sum; any other value selects
            the weighted maximum ``max(pref[0] * c0, pref[1] * c1)``.
        edge_to_node: (E_tot, 2) tensor of (from, to) node ids. The parallel
            edges of one pair occupy consecutive groups of ``emax`` rows; only
            the first row of each group is read.
        dists: (B, E_tot, 2) edge costs aligned with ``edge_to_node``.
        problem_size: number of nodes N.
        emax: parallel edges per node pair; ``E_tot`` must be divisible by it.

    Returns:
        (B, N, N, 2) tensor holding the selected edge costs. Entries without an
        edge (including the diagonal) are 0. The dtype and device follow
        ``dists`` instead of being forced to float32 on the CPU.
    """
    dists1 = dists[:, :, 0]
    dists2 = dists[:, :, 1]
    if method == "Linear":
        dists_tot = pref[0] * dists1 + pref[1] * dists2
    else:
        dists_tot = torch.maximum(pref[0] * dists1, pref[1] * dists2)

    B = dists_tot.shape[0]

    # Edge k = pair * emax + alternative, so a reshape groups the alternatives.
    best_alt = dists_tot.reshape(B, -1, emax).argmin(dim=2)  # (B, E)
    pair_costs = torch.stack((dists1, dists2), dim=2).reshape(B, -1, emax, 2)
    dists_min = torch.gather(
        pair_costs, 2, best_alt[:, :, None, None].expand(-1, -1, 1, 2)
    ).squeeze(2)  # (B, E, 2)

    idx_from = edge_to_node[::emax, 0]
    idx_to = edge_to_node[::emax, 1]

    dist_matrix = dists.new_zeros(B, problem_size, problem_size, 2)
    dist_matrix[:, idx_from, idx_to, :] = dists_min
    return dist_matrix