# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved

import torch
import os
from collections import OrderedDict


# Path to the fine-tuned CLIP checkpoint produced by the OpenDAS prompt-learning run.
# The directory root is taken from the MODEL_SAVE_DIR environment variable.
model_save_dir = os.getenv('MODEL_SAVE_DIR')
if model_save_dir is None:
    # Without this check the f-string below silently builds a path starting with
    # "None/", and torch.load then fails with a confusing FileNotFoundError.
    raise RuntimeError('MODEL_SAVE_DIR environment variable is not set')
opendas_model_path = f"{model_save_dir}/ade20k_150_negative/OpenDASBasic/vit_l14_c2_ep10_batch16_12+8ctx_use_both_losses_0shots/seed1/VLPromptLearner/model.pth.tar-12"
clip_ckpt = torch.load(opendas_model_path, map_location=torch.device('cpu'))

new_model = OrderedDict()
state_dict = clip_ckpt['state_dict']

# Remove a leading 'module.' from each key (presumably added because the model was
# saved while wrapped, e.g. by DataParallel/DDP -- confirm), so keys match the bare
# CLIP module. Only the prefix is stripped: str.replace would also rewrite
# 'module.' occurring later in a key (e.g. 'submodule.weight' -> 'subweight').
_wrapper_prefix = 'module.'
for k, v in state_dict.items():
    new_key = k[len(_wrapper_prefix):] if k.startswith(_wrapper_prefix) else k
    print(new_key)
    new_model[new_key] = v

# Load the trained OVSeg (MaskFormer) checkpoint onto the CPU; the loop below
# writes into its 'model' entry.
ovseg_model = torch.load('./ov-seg/weights/ovseg_swinbase_vitL14_ft_mpt.pth', map_location='cpu')
ovseg_weights = ovseg_model['model']

# Graft every fine-tuned CLIP entry into the OVSeg weights under the
# 'clip_adapter.custom_clip.' namespace, echoing each resulting key.
custom_clip_prefix = 'clip_adapter.custom_clip.'
for param_name, tensor in new_model.items():
    target_key = custom_clip_prefix + param_name
    print(target_key)
    ovseg_weights[target_key] = tensor

# If the fine-tuned CLIP checkpoint carries 'visual.mask_embedding', also copy it onto
# the adapter's base clip_model, and tell the user which OVSeg config flag must match.
# An explicit membership test replaces the former bare `except:`, which would also
# have hidden unrelated errors (and even KeyboardInterrupt).
if 'visual.mask_embedding' in new_model:
    ovseg_model['model']['clip_adapter.clip_model.visual.mask_embedding'] = new_model['visual.mask_embedding']
    print('clip_ckpt has mask_embedding, remember to set MODEL.CLIP_ADAPTER.MASK_PROMPT_FWD True during OVSeg evaluation')
else:
    print('clip_ckpt does not have mask_embedding, remember to set MODEL.CLIP_ADAPTER.MASK_PROMPT_FWD False during OVSeg evaluation')

# Write the merged OVSeg checkpoint.
# NOTE(review): the output goes under ./ov-seg-all/ while the input was read from
# ./ov-seg/ -- confirm this difference is intentional.
ovseg_output_path = './ov-seg-all/weights/ovseg_opendas_ade.pth'
# torch.save does not create missing parent directories, so create them first.
os.makedirs(os.path.dirname(ovseg_output_path), exist_ok=True)
torch.save(ovseg_model, ovseg_output_path)
