from enum import Enum

import torch
import torch.nn.functional as F


def assemble_feature(log_prob, action, data, no_agent, dev):
    """Build a feature tensor by concatenating three parts along dim 2.

    The parts, in order, are:
      1. a one-hot encoding of ``action`` with ``no_agent`` classes,
      2. ``data`` with its first entry along dim 1 dropped (``data[:, 1:, :]``),
      3. ``log_prob`` with a new size-1 axis inserted at position 2.

    NOTE(review): shapes are not checked here; presumably ``action`` and
    ``log_prob`` are (batch, n) and ``data`` is (batch, n + 1, d) so that all
    parts agree on dims 0 and 1 -- confirm against callers.

    Args:
        log_prob: tensor of log-probabilities (see shape note above).
        action: integer tensor of class indices in ``[0, no_agent)``.
        data: tensor of at least rank 3.
        no_agent: number of one-hot classes.
        dev: target device for the one-hot part and for the result.

    Returns:
        The concatenated tensor, moved to ``dev``.
    """
    parts = (
        # One-hot is moved to `dev` before concatenation, as before.
        F.one_hot(action, num_classes=no_agent).to(dev),
        data[:, 1:, :],
        # unsqueeze(2) inserts the trailing feature axis (same as [:, :, None]).
        log_prob.unsqueeze(2),
    )
    return torch.cat(parts, dim=2).to(dev)


def get_src_vector(n_node: int, n_agent: int) -> list[int]:
    """Split ``n_node`` items among ``n_agent`` agents as evenly as possible.

    The first ``n_node % n_agent`` agents receive one extra item, so the
    counts differ by at most one and always sum to ``n_node``.

    Args:
        n_node: number of items to distribute; must be >= 0.
        n_agent: number of agents; must be >= 1.

    Returns:
        A list of length ``n_agent`` whose i-th entry is agent i's count.

    Raises:
        ValueError: if ``n_agent < 1`` or ``n_node < 0``.
    """
    # Validate up front: n_agent == 0 would otherwise surface as a bare
    # ZeroDivisionError, and negative inputs produce meaningless counts.
    if n_agent < 1:
        raise ValueError(f"n_agent must be positive, got {n_agent}")
    if n_node < 0:
        raise ValueError(f"n_node must be non-negative, got {n_node}")
    base, extra = divmod(n_node, n_agent)
    # The first `extra` agents absorb the remainder. Since extra < n_agent,
    # the total is base * n_agent + extra == n_node by construction, so no
    # post-hoc sum check is needed.
    return [base + 1 if i < extra else base for i in range(n_agent)]


class Loss(Enum):
    """Identifiers for the available loss variants.

    Each member's value is the string equal to its name.
    """

    iMTSP = "iMTSP"
    iMTSP2 = "iMTSP2"
    proposed = "proposed"
    proposed2 = "proposed2"
    proposed3 = "proposed3"
    proposed4 = "proposed4"
    proposed5 = "proposed5"

    # A classmethod (rather than an undecorated function) so this works both
    # as ``Loss.getValues()`` and when called on a member, e.g.
    # ``Loss.iMTSP.getValues()``; the undecorated form raised TypeError there.
    @classmethod
    def getValues(cls) -> list[str]:
        """Return the values of all members, in definition order."""
        return [member.value for member in cls]
