import json
import os
import copy
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch import Tensor
import torch.distributed as dist
import torch.nn.functional as F
from transformers import AutoModel, BatchEncoding, PreTrainedModel
from transformers.modeling_outputs import ModelOutput,BaseModelOutput,Seq2SeqModelOutput

from typing import Optional, Dict

from .arguments import ModelArguments, DataArguments, \
    DenseTrainingArguments as TrainingArguments
import logging

logger = logging.getLogger(__name__)
import pdb

@dataclass
class DenseOutput(ModelOutput):
    """Output container returned by the training models in this file.

    Every field defaults to ``None``; a ``forward`` call fills in only the
    fields it actually computed (for example, ``loss`` and ``scores`` stay
    unset when ``DenseModel.forward`` receives only one of query/passage).
    """
    # Query representations produced by ``encode_query``.
    q_reps: Optional[Tensor] = None
    # Passage representations produced by ``encode_passage``.
    p_reps: Optional[Tensor] = None
    # Representations of the extra ``ance_passage`` input (set by DistillModel only).
    p_ance_reps: Optional[Tensor] = None
    # Training loss computed in ``forward``.
    loss: Optional[Tensor] = None
    # Query-passage similarity scores computed in ``forward``.
    scores: Optional[Tensor] = None

@dataclass
class AttOutput(ModelOutput):
    """Output container returned by ``DenseModelForInference.forward``.

    ``ModelOutput`` subclasses are expected to be dataclasses (as
    ``DenseOutput`` in this file already is); recent ``transformers``
    releases raise a ``TypeError`` at construction time otherwise, so the
    decorator is required for keyword construction to keep working.
    """
    # Query representations produced by ``encode_query``.
    q_reps: Optional[Tensor] = None
    # Passage representations produced by ``encode_passage``.
    p_reps: Optional[Tensor] = None
    # ``cross_attentions`` of the grounded decoder pass, or None when grounding
    # is disabled.  NOTE(review): Hugging Face models usually return this as a
    # tuple with one tensor per layer rather than a single tensor -- confirm.
    all_cross_attentions: Optional[Tensor] = None

class LinearPooler(nn.Module):
    """Linear projection applied to the first-position hidden state.

    Holds one projection for queries (``linear_q``) and one for passages
    (``linear_p``).  When ``tied`` is true both attributes refer to the very
    same ``nn.Linear`` instance, so queries and passages share weights.
    """

    def __init__(
            self,
            input_dim: int = 768,
            output_dim: int = 768,
            tied=True
    ):
        """Create the projection layer(s).

        Args:
            input_dim: size of the incoming hidden states.
            output_dim: size of the projected representation.
            tied: share one linear layer between queries and passages.
        """
        super().__init__()
        self.linear_q = nn.Linear(input_dim, output_dim)
        # Re-using the module object (not a copy) is what ties the weights.
        self.linear_p = self.linear_q if tied else nn.Linear(input_dim, output_dim)

        # Constructor arguments; written next to the weights by ``save_pooler``.
        self._config = {'input_dim': input_dim, 'output_dim': output_dim, 'tied': tied}

    def forward(self, q: Tensor = None, p: Tensor = None):
        """Project the hidden state at position 0 of ``q`` or of ``p``.

        ``q`` takes precedence: when both are given only ``q`` is projected.

        Raises:
            ValueError: if neither ``q`` nor ``p`` is provided.
        """
        if q is not None:
            return self.linear_q(q[:, 0])
        if p is not None:
            return self.linear_p(p[:, 0])
        raise ValueError('LinearPooler.forward requires either `q` or `p`.')

    def load(self, ckpt_dir: str):
        """Load weights from ``<ckpt_dir>/pooler.pt`` when that file exists.

        When ``ckpt_dir`` is None or holds no ``pooler.pt`` the freshly
        initialised weights are kept and this is only logged.
        """
        if ckpt_dir is not None:
            _pooler_path = os.path.join(ckpt_dir, 'pooler.pt')
            if os.path.exists(_pooler_path):
                logger.info(f'Loading Pooler from {ckpt_dir}')
                # NOTE: torch.load unpickles the file; only point this at
                # checkpoints from a trusted source.
                state_dict = torch.load(_pooler_path, map_location='cpu')
                self.load_state_dict(state_dict)
                return
        logger.info("Training Pooler from scratch")

    def save_pooler(self, save_path):
        """Write ``pooler.pt`` and ``pooler_config.json`` into ``save_path``.

        The directory is created when it does not exist yet.
        """
        os.makedirs(save_path, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(save_path, 'pooler.pt'))
        with open(os.path.join(save_path, 'pooler_config.json'), 'w') as f:
            json.dump(self._config, f)

class DenseModel(nn.Module):
    """Bi-encoder dense retriever trained with a cross-entropy ranking loss.

    Wraps a query encoder (``lm_q``) and a passage encoder (``lm_p``), which
    may be the same module, plus an optional pooler that is used only on the
    non-T5 code path.
    """

    def __init__(
            self,
            lm_q: PreTrainedModel,
            lm_p: PreTrainedModel,
            pooler: nn.Module = None,
            model_args: ModelArguments = None,
            data_args: DataArguments = None,
            train_args: TrainingArguments = None,
            ground_passage_num: int = 3
    ):
        """Store the encoders and the argument objects.

        Args:
            lm_q: query encoder.
            lm_p: passage encoder.
            pooler: optional pooler used when neither T5 flag is set.
            model_args: read for ``use_t5``, ``use_t5_decoder``,
                ``untie_encoder`` and ``add_pooler``.
            data_args: read for ``train_n_passages`` and, when present,
                ``ground_passage_num``.
            train_args: read for ``negatives_x_device``.
            ground_passage_num: used only when ``data_args`` does not provide
                a ``ground_passage_num`` of its own.

        Raises:
            ValueError: if ``train_args.negatives_x_device`` is set while
                ``torch.distributed`` is not initialised.
        """
        super().__init__()

        self.lm_q = lm_q
        self.lm_p = lm_p
        self.pooler = pooler
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')

        self.model_args = model_args
        self.train_args = train_args
        self.data_args = data_args
        self.config = self.lm_p.config  # make deepspeed happy
        # The value on data_args wins (as before); the explicit argument is the
        # fallback so that it is no longer silently ignored and a missing
        # data_args does not crash construction.
        self.ground_passage_num = getattr(data_args, 'ground_passage_num', ground_passage_num)

        if train_args is not None and train_args.negatives_x_device:
            if not dist.is_initialized():
                raise ValueError('Distributed training has not been initialized for representation all gather.')
            self.process_rank = dist.get_rank()
            self.world_size = dist.get_world_size()

    @staticmethod
    def _decoder_start_ids(input_ids: Tensor) -> Tensor:
        """Return a ``(batch, 1)`` tensor of zeros on the device of ``input_ids``."""
        return torch.zeros((input_ids.shape[0], 1), dtype=torch.long, device=input_ids.device)

    def forward(
            self,
            query: Dict[str, Tensor] = None,
            passage: Dict[str, Tensor] = None,
    ):
        """Encode both sides and compute the ranking loss.

        When either side is missing only the available representations are
        returned.  Otherwise scores are the dot products between all query and
        all passage representations and the loss is a cross-entropy over each
        score row.

        Returns:
            DenseOutput with ``q_reps``/``p_reps`` and, when both inputs were
            given, ``loss`` and ``scores``.
        """
        q_hidden, q_reps = self.encode_query(query)
        p_hidden, p_reps = self.encode_passage(passage)
        if q_reps is None or p_reps is None:
            return DenseOutput(
                q_reps=q_reps,
                p_reps=p_reps,
            )

        if self.train_args.negatives_x_device:
            # Score against the representations of every rank.
            q_reps = self.dist_gather_tensor(q_reps)
            p_reps = self.dist_gather_tensor(p_reps)

        scores = torch.matmul(q_reps, p_reps.transpose(0, 1))
        # The target column for query i is i * train_n_passages.
        # NOTE(review): this presumes passages are laid out in groups of
        # train_n_passages per query with the positive first -- confirm
        # against the data collator.
        target = torch.arange(
            scores.size(0),
            device=scores.device,
            dtype=torch.long
        )
        target = target * self.data_args.train_n_passages
        loss = self.cross_entropy(scores, target)
        if self.training and self.train_args.negatives_x_device:
            loss = loss * self.world_size  # counter average weight reduction
        return DenseOutput(
            loss=loss,
            scores=scores,
            q_reps=q_reps,
            p_reps=p_reps
        )

    def encode_passage(self, psg):
        """Encode a passage batch with ``lm_p``.

        Returns:
            ``(hidden_states, representations)``, or ``(None, None)`` when
            ``psg`` is None.  With ``use_t5`` the representation is the mean
            of the encoder states over dim 1; with ``use_t5_decoder`` it is
            position 0 of the model's last hidden state; otherwise it is the
            pooler output or, without a pooler, position 0 of the last hidden
            state.
        """
        if psg is None:
            return None, None
        psg = BatchEncoding(psg)
        if self.model_args.use_t5:
            decoder_input_ids = self._decoder_start_ids(psg.input_ids)
            psg_out = self.lm_p(**psg, decoder_input_ids=decoder_input_ids, return_dict=True)
            p_hidden = psg_out.encoder_last_hidden_state
            p_reps = p_hidden.mean(dim=1)
        elif self.model_args.use_t5_decoder:
            decoder_input_ids = self._decoder_start_ids(psg.input_ids)
            psg_out = self.lm_p(**psg, decoder_input_ids=decoder_input_ids, return_dict=True)
            p_hidden = psg_out.last_hidden_state
            p_reps = p_hidden[:, 0, :]
        else:
            psg_out = self.lm_p(**psg, return_dict=True)
            p_hidden = psg_out.last_hidden_state
            if self.pooler is not None:
                p_reps = self.pooler(p=p_hidden)  # D * d
            else:
                p_reps = p_hidden[:, 0]
        return p_hidden, p_reps

    def encode_query(self, qry):
        """Encode a query batch with ``lm_q``.

        With ``use_t5_decoder`` the encoder states of every
        ``ground_passage_num + 1`` consecutive input rows are concatenated
        along the sequence axis and the decoder is run once over that
        concatenation; the representation is position 0 of that decoder
        output.  NOTE(review): each group presumably holds the query plus its
        grounding passages -- confirm against the collator.

        Returns:
            ``(hidden_states, representations)``, or ``(None, None)`` when
            ``qry`` is None.
        """
        if qry is None:
            return None, None
        qry = BatchEncoding(qry)
        if self.model_args.use_t5:
            decoder_input_ids = self._decoder_start_ids(qry.input_ids)
            qry_out = self.lm_q(**qry, decoder_input_ids=decoder_input_ids, return_dict=True)
            q_hidden = qry_out.encoder_last_hidden_state
            q_reps = q_hidden.mean(dim=1)
        elif self.model_args.use_t5_decoder:
            decoder_input_ids = self._decoder_start_ids(qry.input_ids)
            qry_out = self.lm_q(**qry, decoder_input_ids=decoder_input_ids, return_dict=True)
            q_hidden = qry_out.encoder_last_hidden_state
            # (rows, len, dim) -> (groups, group_size, len, dim) -> (groups, group_size * len, dim)
            q_reps_1 = torch.reshape(
                q_hidden,
                (-1, self.ground_passage_num + 1, q_hidden.shape[-2], q_hidden.shape[-1])
            )
            q_reps_2 = torch.reshape(
                q_reps_1,
                (-1, q_reps_1.shape[1] * q_reps_1.shape[2], q_reps_1.shape[-1])
            )
            q_reps_tuple = BaseModelOutput(
                last_hidden_state=q_reps_2,
                hidden_states=None,
                attentions=None,
            )
            # The attention mask is regrouped the same way as the encoder states.
            flatten_attention_mask = torch.reshape(
                qry.attention_mask, (q_reps_2.shape[0], q_reps_2.shape[1])
            )
            ground_decoder_input_ids = torch.zeros(
                (q_reps_2.shape[0], 1), dtype=torch.long, device=qry.input_ids.device
            )
            # Passing encoder_outputs hands the regrouped states to the model
            # instead of input ids.
            grounded_out = self.lm_q(
                attention_mask=flatten_attention_mask,
                decoder_input_ids=ground_decoder_input_ids,
                encoder_outputs=q_reps_tuple,
                return_dict=True
            )
            q_hidden = grounded_out.last_hidden_state
            q_reps = q_hidden[:, 0, :]
        else:
            qry_out = self.lm_q(**qry, return_dict=True)
            q_hidden = qry_out.last_hidden_state
            if self.pooler is not None:
                q_reps = self.pooler(q=q_hidden)
            else:
                q_reps = q_hidden[:, 0]
        return q_hidden, q_reps

    @staticmethod
    def build_pooler(model_args):
        """Create a ``LinearPooler`` from ``model_args`` and load saved weights if any."""
        pooler = LinearPooler(
            model_args.projection_in_dim,
            model_args.projection_out_dim,
            tied=not model_args.untie_encoder
        )
        pooler.load(model_args.model_name_or_path)
        return pooler

    @classmethod
    def build(
            cls,
            model_args: ModelArguments,
            data_args: DataArguments,
            train_args: TrainingArguments,
            ground_passage_num: int = 3,
            **hf_kwargs,
    ):
        """Load the encoder(s) and assemble a model.

        A local directory with ``untie_encoder`` set is searched for
        ``query_model``/``passage_model`` sub-directories, falling back to the
        directory itself for both encoders.  Any other name is loaded once and
        deep-copied for the passage side when ``untie_encoder`` is set.
        ``hf_kwargs`` are forwarded to ``AutoModel.from_pretrained``.
        """
        # load local
        if os.path.isdir(model_args.model_name_or_path):
            if model_args.untie_encoder:
                _qry_model_path = os.path.join(model_args.model_name_or_path, 'query_model')
                _psg_model_path = os.path.join(model_args.model_name_or_path, 'passage_model')
                if not os.path.exists(_qry_model_path):
                    _qry_model_path = model_args.model_name_or_path
                    _psg_model_path = model_args.model_name_or_path
                logger.info(f'loading query model weight from {_qry_model_path}')
                lm_q = AutoModel.from_pretrained(
                    _qry_model_path,
                    **hf_kwargs
                )
                logger.info(f'loading passage model weight from {_psg_model_path}')
                lm_p = AutoModel.from_pretrained(
                    _psg_model_path,
                    **hf_kwargs
                )
            else:
                lm_q = AutoModel.from_pretrained(model_args.model_name_or_path, **hf_kwargs)
                lm_p = lm_q
        # load pre-trained
        else:
            lm_q = AutoModel.from_pretrained(model_args.model_name_or_path, **hf_kwargs)
            lm_p = copy.deepcopy(lm_q) if model_args.untie_encoder else lm_q

        pooler = cls.build_pooler(model_args) if model_args.add_pooler else None

        model = cls(
            lm_q=lm_q,
            lm_p=lm_p,
            pooler=pooler,
            model_args=model_args,
            data_args=data_args,
            train_args=train_args,
            ground_passage_num=ground_passage_num
        )
        return model

    def save(self, output_dir: str):
        """Save the encoder(s) and, when ``add_pooler`` is set, the pooler.

        Untied encoders go to ``query_model``/``passage_model``
        sub-directories; a tied encoder is written to ``output_dir`` itself.
        Existing sub-directories are reused so saving twice into the same
        ``output_dir`` does not raise.
        """
        if self.model_args.untie_encoder:
            qry_dir = os.path.join(output_dir, 'query_model')
            psg_dir = os.path.join(output_dir, 'passage_model')
            os.makedirs(qry_dir, exist_ok=True)
            os.makedirs(psg_dir, exist_ok=True)
            self.lm_q.save_pretrained(qry_dir)
            self.lm_p.save_pretrained(psg_dir)
        else:
            self.lm_q.save_pretrained(output_dir)

        if self.model_args.add_pooler:
            self.pooler.save_pooler(output_dir)

    def dist_gather_tensor(self, t: Optional[torch.Tensor]):
        """All-gather ``t`` from every rank and concatenate along dim 0.

        Returns None when ``t`` is None.  Requires ``world_size`` and
        ``process_rank``, which are only set when ``negatives_x_device`` was
        enabled at construction.
        """
        if t is None:
            return None
        t = t.contiguous()

        all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(all_tensors, t)

        # Put the original local tensor back into its slot -- presumably so
        # that gradients still flow through this rank's own slice.
        all_tensors[self.process_rank] = t
        all_tensors = torch.cat(all_tensors, dim=0)

        return all_tensors

class DenseModelForInference(DenseModel):
    """Inference-only variant of ``DenseModel``: encoding without gradients.

    ``encode_query`` here returns a third element (cross-attentions) and
    ``forward`` returns an ``AttOutput`` instead of computing a loss.
    """
    POOLER_CLS = LinearPooler

    def __init__(
            self,
            lm_q: PreTrainedModel,
            lm_p: PreTrainedModel,
            pooler: nn.Module = None,
            model_args: ModelArguments = None,
            ground_passage_num: int = 3,
            use_ground: bool = False,
            **kwargs,
    ):
        """Store the encoders and inference options.

        Only ``nn.Module.__init__`` is run; ``DenseModel.__init__`` is not
        called, so none of its training attributes (loss module, data/train
        arguments, distributed info) are set.  Extra keyword arguments are
        accepted and ignored.
        """
        nn.Module.__init__(self)
        self.lm_q = lm_q
        self.lm_p = lm_p
        self.pooler = pooler
        self.model_args = model_args
        self.ground_passage_num = ground_passage_num
        self.use_ground = use_ground

    @torch.no_grad()
    def encode_passage(self, psg):
        """Run the parent class's passage encoding with gradients disabled."""
        return super().encode_passage(psg)

    @torch.no_grad()
    def encode_query(self, qry):
        """Encode a query batch with gradients disabled.

        Returns:
            ``(hidden_states, representations, cross_attentions)``.  All three
            are None when ``qry`` is None; ``cross_attentions`` is None unless
            ``use_ground`` is set.
        """
        if qry is None:
            return None, None, None
        qry = BatchEncoding(qry)
        device = qry.input_ids.device

        # Both modes begin with the same pass over the raw inputs, feeding a
        # single zero token to the decoder.
        start_ids = torch.zeros((qry.input_ids.shape[0], 1), dtype=torch.long).to(device)
        first_pass = self.lm_q(**qry, decoder_input_ids=start_ids, return_dict=True)

        if not self.use_ground:
            hidden = first_pass.last_hidden_state
            return hidden, hidden[:, 0, :], None

        # Grounded mode: concatenate the encoder states of every
        # (ground_passage_num + 1) consecutive rows along the sequence axis.
        enc_states = first_pass.encoder_last_hidden_state
        grouped = torch.reshape(
            enc_states,
            (-1, self.ground_passage_num + 1, enc_states.shape[-2], enc_states.shape[-1]),
        )
        fused = torch.reshape(
            grouped,
            (-1, grouped.shape[1] * grouped.shape[2], grouped.shape[-1]),
        )
        fused_mask = torch.reshape(qry.attention_mask, (fused.shape[0], fused.shape[1]))
        fused_ids = torch.zeros((fused.shape[0], 1), dtype=torch.long).to(device)

        # Second pass: the regrouped states are supplied as encoder_outputs and
        # attention weights are requested.
        grounded = self.lm_q(
            attention_mask=fused_mask,
            decoder_input_ids=fused_ids,
            encoder_outputs=BaseModelOutput(
                last_hidden_state=fused,
                hidden_states=None,
                attentions=None,
            ),
            return_dict=True,
            output_attentions=True,
        )
        hidden = grounded.last_hidden_state
        return hidden, hidden[:, 0, :], grounded.cross_attentions

    def forward(
            self,
            query: Dict[str, Tensor] = None,
            passage: Dict[str, Tensor] = None,
    ):
        """Encode whichever of ``query``/``passage`` is given; return an ``AttOutput``."""
        _, query_reps, cross_attn = self.encode_query(query)
        _, passage_reps = self.encode_passage(passage)
        return AttOutput(
            q_reps=query_reps,
            p_reps=passage_reps,
            all_cross_attentions=cross_attn,
        )

    @classmethod
    def build(
            cls,
            model_name_or_path: str = None,
            model_args: ModelArguments = None,
            data_args: DataArguments = None,
            train_args: TrainingArguments = None,
            use_ground: bool=False,
            ground_passage_num: int = 3,
            **hf_kwargs,
    ):
        """Load encoder(s) and an optional pooler, then assemble the model.

        ``model_name_or_path`` defaults to ``model_args.model_name_or_path``;
        at least one of the two must be given.  Separate query/passage weights
        are used when the path is a directory containing ``query_model``;
        otherwise one model is loaded and shared.  A pooler is restored only
        when both ``pooler.pt`` and ``pooler_config.json`` exist under the
        path.  ``data_args`` and ``train_args`` are accepted but not used.
        """
        assert model_name_or_path is not None or model_args is not None
        if model_name_or_path is None:
            model_name_or_path = model_args.model_name_or_path

        query_dir = os.path.join(model_name_or_path, 'query_model')
        passage_dir = os.path.join(model_name_or_path, 'passage_model')

        if os.path.isdir(model_name_or_path) and os.path.exists(query_dir):
            logger.info('found separate weight for query/passage encoders')
            logger.info(f'loading query model weight from {query_dir}')
            lm_q = AutoModel.from_pretrained(query_dir, **hf_kwargs)
            logger.info(f'loading passage model weight from {passage_dir}')
            lm_p = AutoModel.from_pretrained(passage_dir, **hf_kwargs)
        else:
            # Either a non-local name or a directory holding a single model.
            logger.info('try loading tied weight')
            logger.info(f'loading model weight from {model_name_or_path}')
            lm_q = AutoModel.from_pretrained(model_name_or_path, **hf_kwargs)
            lm_p = lm_q

        weights_file = os.path.join(model_name_or_path, 'pooler.pt')
        config_file = os.path.join(model_name_or_path, 'pooler_config.json')
        pooler = None
        if os.path.exists(weights_file) and os.path.exists(config_file):
            logger.info('found pooler weight and configuration')
            with open(config_file) as f:
                pooler_kwargs = json.load(f)
            pooler = cls.POOLER_CLS(**pooler_kwargs)
            pooler.load(model_name_or_path)

        return cls(
            lm_q=lm_q,
            lm_p=lm_p,
            pooler=pooler,
            model_args=model_args,
            use_ground=use_ground,
            ground_passage_num=ground_passage_num,
        )

class DistillModel(nn.Module):
    """Bi-encoder trained with a distillation loss plus a ranking loss.

    Besides the regular query/passage batch, ``forward`` takes an additional
    ``ance_passage`` batch and per-passage ``target`` scores that are turned
    into a soft label distribution.
    """

    def __init__(
            self,
            lm_q: PreTrainedModel,
            lm_p: PreTrainedModel,
            pooler: nn.Module = None,
            model_args: ModelArguments = None,
            data_args: DataArguments = None,
            train_args: TrainingArguments = None,
    ):
        """Store the encoders, loss modules and argument objects.

        Raises:
            ValueError: if ``train_args.negatives_x_device`` is set while
                ``torch.distributed`` is not initialised.
        """
        super().__init__()

        self.lm_q = lm_q
        self.lm_p = lm_p
        self.pooler = pooler
        self.kl_loss = nn.KLDivLoss(reduction="batchmean")
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
        self.model_args = model_args
        self.train_args = train_args
        self.data_args = data_args
        self.config = self.lm_p.config
        # train_args defaults to None, so guard before reading from it.
        if train_args is not None and train_args.negatives_x_device:
            if not dist.is_initialized():
                raise ValueError('Distributed training has not been initialized for representation all gather.')
            self.process_rank = dist.get_rank()
            self.world_size = dist.get_world_size()

    @staticmethod
    def _decoder_start_ids(input_ids: Tensor) -> Tensor:
        """Return a ``(batch, 1)`` tensor of zeros on the device of ``input_ids``."""
        return torch.zeros((input_ids.shape[0], 1), dtype=torch.long, device=input_ids.device)

    def forward(
            self,
            query: Dict[str, Tensor] = None,
            passage: Dict[str, Tensor] = None,
            ance_passage: Dict[str, Tensor] = None,
            target: Tensor= None,
    ):
        """Encode the inputs and compute the training loss.

        When any of ``query``, ``passage``, ``ance_passage`` or ``target`` is
        missing only the available representations are returned.

        With ``train_args.use_relevant`` the loss is a weighted sum of a
        KL-divergence term (per-query passage scores vs. softmaxed ``target``)
        and a cross-entropy term over the query x ance-passage scores.
        Otherwise both score blocks are concatenated, log-softmaxed and
        compared to the concatenated targets with a mean-squared error.

        Returns:
            DenseOutput with ``loss``, ``scores`` (the per-query passage
            scores; log-probabilities in the ``use_relevant`` branch) and the
            representations.
        """
        q_hidden, q_reps = self.encode_query(query)
        p_hidden, p_reps = self.encode_passage(passage)
        p_ance_hidden, p_ance_reps = self.encode_passage(ance_passage)

        if q_reps is None or p_reps is None or target is None or p_ance_reps is None:
            return DenseOutput(
                q_reps=q_reps,
                p_reps=p_reps,
                p_ance_reps=p_ance_reps
            )

        # Group passages per query: (n_query, passages_per_query, dim).
        p_distill_reps = torch.reshape(p_reps, (q_reps.shape[0], -1, q_reps.shape[-1]))
        target = torch.reshape(target, (q_reps.shape[0], p_distill_reps.shape[1]))
        # Each query is scored only against its own passage group.
        q_distill_reps = torch.unsqueeze(q_reps, 1)
        distill_scores = torch.matmul(q_distill_reps, p_distill_reps.transpose(1, 2))
        distill_scores = torch.squeeze(distill_scores, 1)
        s_target = F.softmax(self.train_args.softmax_temperature * target, dim=-1)

        if self.train_args.negatives_x_device:
            q_gather_reps = self.dist_gather_tensor(q_reps)
            # From here on p_ance_reps holds the gathered tensor, which is
            # also what is returned below.
            p_ance_reps = self.dist_gather_tensor(p_ance_reps)
            ance_scores = torch.matmul(q_gather_reps, p_ance_reps.transpose(0, 1))
        else:
            ance_scores = torch.matmul(q_reps, p_ance_reps.transpose(0, 1))

        if self.train_args.use_relevant:
            distill_scores = F.log_softmax(distill_scores, dim=-1)
            distill_loss = self.kl_loss(distill_scores, s_target)

            # The target column for query i is i * train_n_passages.
            ance_target = torch.arange(
                ance_scores.size(0),
                device=ance_scores.device,
                dtype=torch.long
            )
            ance_target = ance_target * self.data_args.train_n_passages
            ance_loss = self.cross_entropy(ance_scores, ance_target)

            if self.training and self.train_args.negatives_x_device:
                ance_loss = ance_loss * self.world_size  # counter average weight reduction
            # loss_lambda is divided by 10 to obtain the mixing weight.
            lbd = float(self.train_args.loss_lambda / 10)
            loss = lbd * self.train_args.distill_loss_balance * distill_loss + (1 - lbd) * ance_loss
        else:
            # NOTE(review): with negatives_x_device, ance_scores is built from
            # gathered tensors while distill_scores is local, so their first
            # dimensions differ and this concatenation will fail -- confirm
            # the intended behaviour for that configuration.
            combined_scores = torch.cat((distill_scores, ance_scores), -1)
            combined_scores = F.log_softmax(combined_scores, dim=-1)
            ance_target = torch.zeros(ance_scores.size(), device=ance_scores.device)
            combined_targets = torch.cat((s_target, ance_target), -1)
            # Functional form avoids building a new loss module on every step.
            loss = F.mse_loss(combined_scores, combined_targets)
            if self.training and self.train_args.negatives_x_device:
                loss = loss * self.world_size

        return DenseOutput(
            loss=loss,
            scores=distill_scores,
            q_reps=q_reps,
            p_reps=p_reps,
            p_ance_reps=p_ance_reps
        )

    def encode_passage(self, psg):
        """Encode a passage batch with ``lm_p``.

        Returns:
            ``(hidden_states, representations)``, or ``(None, None)`` when
            ``psg`` is None.  With ``use_t5`` the representation is the mean
            of the encoder states over dim 1; with ``use_t5_decoder`` it is
            position 0 of the model's last hidden state; otherwise it is the
            pooler output or, without a pooler, position 0 of the last hidden
            state.
        """
        if psg is None:
            return None, None
        psg = BatchEncoding(psg)
        if self.model_args.use_t5:
            decoder_input_ids = self._decoder_start_ids(psg.input_ids)
            psg_out = self.lm_p(**psg, decoder_input_ids=decoder_input_ids, return_dict=True)
            p_hidden = psg_out.encoder_last_hidden_state
            p_reps = p_hidden.mean(dim=1)
        elif self.model_args.use_t5_decoder:
            decoder_input_ids = self._decoder_start_ids(psg.input_ids)
            psg_out = self.lm_p(**psg, decoder_input_ids=decoder_input_ids, return_dict=True)
            p_hidden = psg_out.last_hidden_state
            p_reps = p_hidden[:, 0, :]
        else:
            psg_out = self.lm_p(**psg, return_dict=True)
            p_hidden = psg_out.last_hidden_state
            if self.pooler is not None:
                p_reps = self.pooler(p=p_hidden)  # D * d
            else:
                p_reps = p_hidden[:, 0]
        return p_hidden, p_reps

    def encode_query(self, qry):
        """Encode a query batch with ``lm_q``.

        Mirrors ``encode_passage``: mean of encoder states for ``use_t5``,
        position 0 of the last hidden state for ``use_t5_decoder``, otherwise
        pooler output or position 0 of the last hidden state.

        Returns:
            ``(hidden_states, representations)``, or ``(None, None)`` when
            ``qry`` is None.
        """
        if qry is None:
            return None, None
        qry = BatchEncoding(qry)
        if self.model_args.use_t5:
            decoder_input_ids = self._decoder_start_ids(qry.input_ids)
            qry_out = self.lm_q(**qry, decoder_input_ids=decoder_input_ids, return_dict=True)
            q_hidden = qry_out.encoder_last_hidden_state
            q_reps = q_hidden.mean(dim=1)
        elif self.model_args.use_t5_decoder:
            decoder_input_ids = self._decoder_start_ids(qry.input_ids)
            qry_out = self.lm_q(**qry, decoder_input_ids=decoder_input_ids, return_dict=True)
            q_hidden = qry_out.last_hidden_state
            q_reps = q_hidden[:, 0, :]
        else:
            qry_out = self.lm_q(**qry, return_dict=True)
            q_hidden = qry_out.last_hidden_state
            if self.pooler is not None:
                q_reps = self.pooler(q=q_hidden)
            else:
                q_reps = q_hidden[:, 0]
        return q_hidden, q_reps

    @staticmethod
    def build_pooler(model_args):
        """Create a ``LinearPooler`` from ``model_args`` and load saved weights if any."""
        pooler = LinearPooler(
            model_args.projection_in_dim,
            model_args.projection_out_dim,
            tied=not model_args.untie_encoder
        )
        pooler.load(model_args.model_name_or_path)
        return pooler

    @classmethod
    def build(
            cls,
            model_args: ModelArguments,
            data_args: DataArguments,
            train_args: TrainingArguments,
            **hf_kwargs,
    ):
        """Load the encoder(s) and assemble a model.

        A local directory with ``untie_encoder`` set is searched for
        ``query_model``/``passage_model`` sub-directories, falling back to the
        directory itself for both encoders.  Any other name is loaded once and
        deep-copied for the passage side when ``untie_encoder`` is set.
        ``hf_kwargs`` are forwarded to ``AutoModel.from_pretrained``.
        """
        # load local
        if os.path.isdir(model_args.model_name_or_path):
            if model_args.untie_encoder:
                _qry_model_path = os.path.join(model_args.model_name_or_path, 'query_model')
                _psg_model_path = os.path.join(model_args.model_name_or_path, 'passage_model')
                if not os.path.exists(_qry_model_path):
                    _qry_model_path = model_args.model_name_or_path
                    _psg_model_path = model_args.model_name_or_path
                logger.info(f'loading query model weight from {_qry_model_path}')
                lm_q = AutoModel.from_pretrained(
                    _qry_model_path,
                    **hf_kwargs
                )
                logger.info(f'loading passage model weight from {_psg_model_path}')
                lm_p = AutoModel.from_pretrained(
                    _psg_model_path,
                    **hf_kwargs
                )
            else:
                lm_q = AutoModel.from_pretrained(model_args.model_name_or_path, **hf_kwargs)
                lm_p = lm_q
        # load pre-trained
        else:
            lm_q = AutoModel.from_pretrained(model_args.model_name_or_path, **hf_kwargs)
            lm_p = copy.deepcopy(lm_q) if model_args.untie_encoder else lm_q

        pooler = cls.build_pooler(model_args) if model_args.add_pooler else None

        model = cls(
            lm_q=lm_q,
            lm_p=lm_p,
            pooler=pooler,
            model_args=model_args,
            data_args=data_args,
            train_args=train_args,
        )
        return model

    def save(self, output_dir: str):
        """Save the encoder(s) and, when ``add_pooler`` is set, the pooler.

        Untied encoders go to ``query_model``/``passage_model``
        sub-directories; a tied encoder is written to ``output_dir`` itself.
        Existing sub-directories are reused so saving twice into the same
        ``output_dir`` does not raise.
        """
        if self.model_args.untie_encoder:
            qry_dir = os.path.join(output_dir, 'query_model')
            psg_dir = os.path.join(output_dir, 'passage_model')
            os.makedirs(qry_dir, exist_ok=True)
            os.makedirs(psg_dir, exist_ok=True)
            self.lm_q.save_pretrained(qry_dir)
            self.lm_p.save_pretrained(psg_dir)
        else:
            self.lm_q.save_pretrained(output_dir)

        if self.model_args.add_pooler:
            self.pooler.save_pooler(output_dir)

    def dist_gather_tensor(self, t: Optional[torch.Tensor]):
        """All-gather ``t`` from every rank and concatenate along dim 0.

        Returns None when ``t`` is None.  Requires ``world_size`` and
        ``process_rank``, which are only set when ``negatives_x_device`` was
        enabled at construction.
        """
        if t is None:
            return None
        t = t.contiguous()

        all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(all_tensors, t)

        # Put the original local tensor back into its slot -- presumably so
        # that gradients still flow through this rank's own slice.
        all_tensors[self.process_rank] = t
        all_tensors = torch.cat(all_tensors, dim=0)

        return all_tensors

