import torch
import torch.nn.functional as F
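

# NOTE: reshape_by_heads and multi_head_attention are expected to be provided by
# the host POMO/ReEvo model code (they are typically defined alongside
# TSP_Decoder in the original model file). The minimal sketches below follow the
# standard POMO layout and are included only as an assumption so this snippet is
# self-contained; if the host versions are available, prefer those.

def reshape_by_heads(qkv, head_num):
    # (batch, n, head_num * key_dim) -> (batch, head_num, n, key_dim)
    batch_size, n = qkv.size(0), qkv.size(1)
    return qkv.reshape(batch_size, n, head_num, -1).transpose(1, 2)


def multi_head_attention(q, k, v, rank3_ninf_mask=None):
    # q: (batch, head_num, n, key_dim); k, v: (batch, head_num, problem, key_dim)
    batch_size, head_num, n, key_dim = q.size()

    score = torch.matmul(q, k.transpose(2, 3)) / (key_dim ** 0.5)
    # score shape: (batch, head_num, n, problem)
    if rank3_ninf_mask is not None:
        # broadcast the (batch, n, problem) mask across heads
        score = score + rank3_ninf_mask[:, None, :, :]

    weights = F.softmax(score, dim=-1)
    out = torch.matmul(weights, v)
    # out shape: (batch, head_num, n, key_dim) -> concatenate heads
    return out.transpose(1, 2).reshape(batch_size, n, head_num * key_dim)
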

def evolved_forward(decoder, encoded_last_node, ninf_mask):
    """
    Evolved forward function for TSP_Decoder.
    Args:
        decoder: TSP_Decoder instance, providing access to attributes like self.k, self.v, etc.
        encoded_last_node: torch.Tensor of shape (batch, pomo, embedding)
        ninf_mask: torch.Tensor of shape (batch, pomo, problem)
    Returns:
        probs: torch.Tensor of shape (batch, pomo, problem)
    """
    head_num = decoder.model_params['head_num']

    # Multi-Head Attention: query from the last visited node over all node keys/values
    q_last = reshape_by_heads(decoder.Wq_last(encoded_last_node), head_num=head_num)
    # q_last shape: (batch, head_num, pomo, qkv_dim)
    q = decoder.q_first + q_last
    out_concat = multi_head_attention(q, decoder.k, decoder.v, rank3_ninf_mask=ninf_mask)
    # out_concat shape: (batch, pomo, head_num * qkv_dim)
    mh_atten_out = decoder.multi_head_combine(out_concat)
    # mh_atten_out shape: (batch, pomo, embedding)

    # Single-Head Attention: score each node against the attended query
    score = torch.matmul(mh_atten_out, decoder.single_head_key)
    # score shape: (batch, pomo, problem)

    # ReEvo: dynamic bias based on problem size and last-node features.
    # The bias is non-negative (ReLU) and scaled by log(problem_size), so its
    # influence grows slowly with instance size.
    problem_size = ninf_mask.size(-1)
    scale_factor = torch.log(torch.tensor(problem_size, dtype=torch.float, device=score.device))
    dynamic_bias = torch.relu(torch.matmul(encoded_last_node, decoder.single_head_key)) * scale_factor
    # dynamic_bias shape: (batch, pomo, problem)
    score = score + dynamic_bias

    # Scale, clip with tanh, mask visited nodes, and normalize to probabilities
    sqrt_embedding_dim = decoder.model_params['sqrt_embedding_dim']
    logit_clipping = decoder.model_params['logit_clipping']
    score_scaled = score / sqrt_embedding_dim
    score_clipped = logit_clipping * torch.tanh(score_scaled)
    score_masked = score_clipped + ninf_mask
    probs = F.softmax(score_masked, dim=2)
    # probs shape: (batch, pomo, problem)

    return probs
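

if __name__ == "__main__":
    # Minimal sanity-check sketch. DummyDecoder below is NOT the real TSP_Decoder
    # from the POMO/ReEvo codebase; it only mimics the attribute layout that
    # evolved_forward relies on (Wq_last, multi_head_combine, q_first, k, v,
    # single_head_key, model_params) so that the output shape and masking
    # behavior can be checked in isolation with random weights.
    import torch.nn as nn

    batch, pomo, problem = 2, 3, 5
    embedding_dim, head_num = 16, 4
    qkv_dim = embedding_dim // head_num

    class DummyDecoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.model_params = {
                'head_num': head_num,
                'sqrt_embedding_dim': embedding_dim ** 0.5,
                'logit_clipping': 10.0,
            }
            self.Wq_last = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
            self.multi_head_combine = nn.Linear(head_num * qkv_dim, embedding_dim)
            # Stand-ins for the key/value/query caches that the decoder's setup
            # methods would normally fill from the encoder output.
            encoded_nodes = torch.randn(batch, problem, embedding_dim)
            self.k = reshape_by_heads(encoded_nodes, head_num=head_num)
            self.v = reshape_by_heads(encoded_nodes, head_num=head_num)
            self.single_head_key = encoded_nodes.transpose(1, 2)  # (batch, embedding, problem)
            self.q_first = torch.zeros(batch, head_num, pomo, qkv_dim)

    decoder = DummyDecoder()
    encoded_last_node = torch.randn(batch, pomo, embedding_dim)
    ninf_mask = torch.zeros(batch, pomo, problem)
    ninf_mask[:, :, 0] = float('-inf')  # pretend node 0 has already been visited

    probs = evolved_forward(decoder, encoded_last_node, ninf_mask)
    assert probs.shape == (batch, pomo, problem)
    assert torch.all(probs[:, :, 0] == 0)  # masked node gets zero probability
    assert torch.allclose(probs.sum(dim=-1), torch.ones(batch, pomo))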