import torch

# Merge per-task prompt keys into a single tensor.
#
# Specify your pre-trained checkpoints below. Order matters: the i-th path
# supplies slot i (along dim 1) of the merged tensor. Each file is loaded with
# torch.load and its 'keys' entry is used.
#
# NOTE(review): an earlier version of this script bound these files to
# variables named chartqa/docvqa/iconqa/medicalqa, which does not match the
# file names below (referring_qa/detail_description/complex_reasoning/
# conversation). The paths are unchanged; confirm they point at the intended
# per-task checkpoints and are listed in the slot order used during training.
TASK_PROMPT_KEY_PATHS = (
    './output/prompt-key/referring_qa_prompt_key.pth',        # slot 0
    './output/prompt-key/detail_description_prompt_key.pth',  # slot 1
    './output/prompt-key/complex_reasoning_prompt_key.pth',   # slot 2
    './output/prompt-key/conversation_prompt_key.pth',        # slot 3
)
MERGED_PROMPT_KEY_PATH = 'output/prompt-key/task_prompt_key.pth'


def merge_task_prompt_keys(task_keys, dim=1):
    """Concatenate slot ``i`` of the ``i``-th tensor along ``dim``.

    Args:
        task_keys: Sequence of tensors. ``task_keys[i]`` must have more than
            ``i`` entries along ``dim``.
        dim: Axis that indexes task slots (1 in this script).

    Returns:
        A tensor whose size along ``dim`` equals ``len(task_keys)``. Slot ``i``
        is ``task_keys[i]`` restricted to index ``i`` along ``dim``. All other
        axes are taken unchanged.

    Raises:
        ValueError: If ``task_keys`` is empty, or if a tensor lacks ``dim`` or
            is too small along it to provide its slot. Plain slicing such as
            ``keys[:, 3:4]`` would instead return an empty slice and silently
            produce an undersized merged tensor.
    """
    if not task_keys:
        raise ValueError("task_keys must contain at least one tensor")

    slot_slices = []
    for slot, keys in enumerate(task_keys):
        if keys.dim() <= dim or keys.size(dim) <= slot:
            raise ValueError(
                f"tensor {slot} has shape {tuple(keys.shape)}; it needs more "
                f"than {slot} entries along dim {dim} to supply slot {slot}"
            )
        # narrow keeps the sliced axis (length 1) so torch.cat can stack slots.
        slot_slices.append(keys.narrow(dim, slot, 1))
    return torch.cat(slot_slices, dim=dim)


task_prompt_keys = [torch.load(path)['keys'] for path in TASK_PROMPT_KEY_PATHS]
prompt_key = merge_task_prompt_keys(task_prompt_keys, dim=1)

print(f"prompt_key.shape: {prompt_key.shape}")

torch.save(prompt_key, MERGED_PROMPT_KEY_PATH)