


def save_model(model, output_dir, tokenizer):
    """Persist ``model`` (and ``tokenizer``) to ``output_dir`` via ``CLTrainer``.

    A throwaway ``CLTrainer`` is built with no training arguments, dataset or
    collator purely so that its ``save_model`` routine can be reused.

    Args:
        model: the model to save.
        output_dir: destination directory passed to ``CLTrainer.save_model``.
        tokenizer: tokenizer handed to the trainer alongside the model.
    """
    # Imported lazily, as in the original, so this module can be imported
    # without pulling in the trainer package until saving is actually needed.
    from simcse.trainers import CLTrainer

    trainer_kwargs = dict(
        model=model,
        args=None,
        train_dataset=None,
        tokenizer=tokenizer,
        data_collator=None,
    )
    CLTrainer(**trainer_kwargs).save_model(output_dir=output_dir)


def _masked_mean(hidden, attention_mask):
    """Average ``hidden`` over the sequence axis, counting only unmasked positions.

    ``hidden`` is indexed as ``(batch, seq, dim)`` and ``attention_mask`` as
    ``(batch, seq)`` holding 1 for real tokens and 0 for padding -- TODO confirm
    for any non-standard caller. A row whose mask is entirely zero yields a
    zero vector rather than NaN.
    """
    # Cast the mask to the hidden-state dtype so the weighted sum and the token
    # count share one floating dtype; this is exact for 0/1 masks.
    mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
    # Clamping at 1 only affects rows with no real tokens (count 0); every
    # other row of a 0/1 mask already has an integer count >= 1.
    token_counts = mask.sum(1).clamp(min=1.0)
    return (hidden * mask).sum(1) / token_counts


def apply_pooler(args, outputs, batch):
    """Reduce encoder outputs to one embedding per example.

    Args:
        args: object exposing a ``pooler`` attribute that selects the strategy.
            The valid values are ``'cls'``, ``'all'``, ``'cls_before_pooler'``,
            ``'avg'``, ``'avg_first_last'`` and ``'avg_top2'``.
        outputs: model output exposing ``last_hidden_state`` and, depending on
            the strategy, ``pooler_output`` (for ``'cls'``) or
            ``hidden_states`` (for ``'avg_first_last'`` / ``'avg_top2'``).
            ``hidden_states`` presumably requires the model to have been run
            with hidden-state output enabled -- confirm at the call site.
        batch: mapping whose ``'attention_mask'`` entry is read by the
            averaging strategies only.

    Returns:
        The pooled tensor. For ``'all'``, this is the last hidden state
        flattened to ``(batch, -1)``; padding positions are included.

    Raises:
        NotImplementedError: if ``args.pooler`` is not a known strategy.
    """
    pooler = args.pooler
    last_hidden = outputs.last_hidden_state

    if pooler == 'cls':
        # There is a linear+activation layer after the CLS representation.
        return outputs.pooler_output
    if pooler == 'all':
        return last_hidden.reshape(last_hidden.shape[0], -1)
    if pooler == 'cls_before_pooler':
        # Raw hidden state of the first token, without the pooler head.
        return last_hidden[:, 0]
    if pooler == 'avg':
        return _masked_mean(last_hidden, batch['attention_mask'])
    if pooler == 'avg_first_last':
        hidden_states = outputs.hidden_states
        # NOTE(review): for HuggingFace encoders hidden_states[0] is usually
        # the embedding-layer output rather than the first transformer layer;
        # kept as-is for compatibility -- confirm whether [1] was intended.
        averaged = (hidden_states[0] + hidden_states[-1]) / 2.0
        return _masked_mean(averaged, batch['attention_mask'])
    if pooler == 'avg_top2':
        hidden_states = outputs.hidden_states
        averaged = (hidden_states[-1] + hidden_states[-2]) / 2.0
        return _masked_mean(averaged, batch['attention_mask'])
    raise NotImplementedError(f"Unsupported pooler: {pooler!r}")


class dotdict(dict):
    """A ``dict`` whose keys can also be read, set and deleted as attributes.

    Reading an attribute that is not a key returns ``None`` instead of raising
    ``AttributeError``. Deleting a missing attribute raises ``KeyError``.
    """

    def __getattr__(self, name):
        # This is only reached when normal attribute lookup fails, so real dict
        # methods such as ``get`` or ``items`` are never shadowed by keys.
        return dict.get(self, name)

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        del self[name]
