import torch
from fairseq.checkpoint_utils import prune_state_dict

def load_state_dict(self, state_dict, strict=True, args=None):
    """Load ``state_dict`` into this model, always in non-strict mode.

    The dict is first passed to ``self.upgrade_state_dict`` (which may modify
    it in place), then filtered through ``prune_state_dict`` and finally
    handed to the parent class' ``load_state_dict``.

    Args:
        state_dict: checkpoint parameters to load.
        strict: accepted for interface compatibility only; it is overridden
            with ``False`` before the parent loader is called.
        args: forwarded unchanged to ``prune_state_dict``.

    Returns:
        An empty list.

    Raises:
        RuntimeError: if the parent loader reports any unexpected keys.
    """
    # Deliberately ignore the caller's choice: the parent loader is always
    # invoked non-strictly, and only unexpected keys are treated as fatal
    # below (missing keys are not checked).
    strict = False
    self.upgrade_state_dict(state_dict)
    new_state_dict = prune_state_dict(state_dict, args)
    _missing_keys, unexpected_keys = super().load_state_dict(new_state_dict, strict)
    if unexpected_keys:
        print("unexpected_keys: ", unexpected_keys)
        # Carry the offending keys in the exception itself so they show up in
        # tracebacks and logs, not only on stdout.
        raise RuntimeError(
            "unexpected keys found while loading state_dict: {}".format(
                list(unexpected_keys)
            )
        )
    return []


def upgrade_state_dict_named(self, state_dict, name):
    """Grow the checkpoint's positional-embedding table to this module's size.

    After delegating to the parent implementation, the tensor stored under
    ``'encoder.embed_positions.weight'`` in ``state_dict`` is compared with
    this module's own ``'embed_positions.weight'``. If the checkpoint table
    has fewer rows, its last row is repeated until the row counts match, and
    the enlarged tensor replaces the entry in ``state_dict`` (in place).

    Used for positional embeddings longer than the original (1024); called
    from fairseq_model.upgrade_state_dict().
    ref: https://github.com/pytorch/fairseq/blob/666d8c26e1feb4fa1fa5369f15086999f64744e0/examples/MMPT/mmpt/models/fairseqmmmodel.py#L28

    Args:
        state_dict: checkpoint parameters; modified in place.
        name: forwarded to the parent ``upgrade_state_dict_named``.
    """
    super().upgrade_state_dict_named(state_dict, name)

    ckpt_key = 'encoder.embed_positions.weight'
    ckpt_weight = state_dict.get(ckpt_key)
    model_weight = self.state_dict().get('embed_positions.weight')
    if ckpt_weight is None or model_weight is None:
        # Nothing to resize when either side has no such positional table.
        return

    # Derive the padding from the real row counts rather than assuming the
    # checkpoint was trained with exactly 1024 positions.
    extra_rows = model_weight.size(0) - ckpt_weight.size(0)
    if extra_rows <= 0:
        # Same size, or checkpoint larger than the model: leave it untouched
        # and let the regular loading step deal with any mismatch.
        return

    # Fill the new positions with copies of the last existing row.
    padding = ckpt_weight[-1][None, :].repeat(extra_rows, 1)
    state_dict[ckpt_key] = torch.cat((ckpt_weight, padding), 0)
