from transformers import CONFIG_MAPPING, AutoConfig, AutoModelForQuestionAnswering

from MQSP_evaluation.distributed_utils import DistGroups

from ..base.model_util import ConcatSequenceOrigin, change_position_embedding


def init_qa_model(args):
    """Build a question-answering model and wrap the forward of its QA head.

    The model is loaded from ``args.model_name_or_path`` when that value is
    truthy; otherwise it is created from a default config selected by
    ``args.model_type``. After construction, ``change_position_embedding`` is
    applied and ``model.qa_outputs.forward`` is replaced by a wrapper that
    runs its input through ``ConcatSequenceOrigin`` (with the ``"sp"`` entry
    of ``DistGroups``) before calling the original forward.

    Args:
        args: Object providing ``model_name_or_path`` and, depending on the
            branch taken, ``checkpoint_activations`` or ``model_type``. It is
            also passed as-is to ``change_position_embedding``.

    Returns:
        The constructed model with the patched ``qa_outputs.forward``.
    """
    checkpoint = args.model_name_or_path

    if not checkpoint:
        # No checkpoint supplied: default config for the requested
        # architecture, model built from that config.
        # NOTE(review): ``args.checkpoint_activations`` is not applied to the
        # config on this path -- confirm that is intended.
        config = CONFIG_MAPPING[args.model_type]()
        model = AutoModelForQuestionAnswering.from_config(config)
    else:
        config = AutoConfig.from_pretrained(
            checkpoint,
            return_dict=True,
            gradient_checkpointing=args.checkpoint_activations,
        )
        model = AutoModelForQuestionAnswering.from_pretrained(
            checkpoint,
            # A path containing ".ckpt" is treated as a TensorFlow checkpoint.
            from_tf=".ckpt" in checkpoint,
            config=config,
        )

    # TODO: seq_length
    change_position_embedding(model, args)

    qa_head = model.qa_outputs
    # Keep a handle on the original forward so the wrapper can delegate to it.
    original_forward = qa_head.forward

    def _gathered_qa_forward(sub_sequence_output):
        # Input is presumably the local slice, (batch, sub_seq_length, hidden_size);
        # ConcatSequenceOrigin presumably joins the slices held by the "sp"
        # group along the sequence axis -- TODO confirm against model_util.
        full_sequence = ConcatSequenceOrigin.apply(sub_sequence_output, DistGroups["sp"])
        # Last dimension of the result is presumably 2 (start/end logits) -- TODO confirm.
        return original_forward(full_sequence)

    qa_head.forward = _gathered_qa_forward
    return model
