from .span_propose_attn import SpanProposeCriterionWeighted
from .span_propose_attn_llava_ov import SpanProposeCriterionWeightedLLavaOV

# Registry consulted by ``build_criterion``: maps each supported
# ``kd_loss_type`` string to the criterion it should instantiate.
criterion_list = dict(
    span_propose_attn=SpanProposeCriterionWeighted,
    span_propose_attn_llava_ov=SpanProposeCriterionWeightedLLavaOV,
)

def build_criterion(args):
    """Instantiate the criterion selected by ``args.kd_loss_type``.

    Args:
        args: Configuration object exposing a ``kd_loss_type`` attribute.
            The same object is forwarded unchanged to the criterion's
            constructor.

    Returns:
        The object produced by calling the registered criterion with ``args``.

    Raises:
        ValueError: If ``args.kd_loss_type`` is not a key of ``criterion_list``.
    """
    loss_type = args.kd_loss_type
    # Keep the ``try`` body limited to the registry lookup so that a KeyError
    # raised inside a criterion's constructor is not misreported as
    # "not found".
    try:
        criterion_cls = criterion_list[loss_type]
    except KeyError:
        raise ValueError(f"Criterion {loss_type} not found.") from None
    return criterion_cls(args)