'''
Calculate the cosine similarity of prompts.
'''

import torch
import os
import pickle

def calc_sim_tensor(tensor_1, tensor_2):
    '''
    Return the cosine similarity of two prompt tensors as a Python float.

    Each tensor is first averaged over its last dimension; the cosine
    similarity is then taken along dim 0 of the two averaged tensors.

    NOTE(review): presumably both inputs are 2-D with the same size along
    dim 0, so that the result is a single scalar -- confirm against callers.

    Args:
        tensor_1: first tensor; the computation runs on its device.
        tensor_2: second tensor; moved to tensor_1's device if necessary.

    Returns:
        The similarity value as a float.
    '''
    # The two tensors may come from different places (e.g. one moved to the
    # GPU, the other restored by torch.load onto its saved device), so align
    # them before combining; this is a no-op when they already match.
    tensor_2 = tensor_2.to(tensor_1.device)

    # Only the numeric value is returned, so no autograd graph is needed.
    with torch.no_grad():
        pooled_1 = torch.mean(tensor_1, dim=-1)
        pooled_2 = torch.mean(tensor_2, dim=-1)
        similarity = torch.nn.functional.cosine_similarity(
            pooled_1, pooled_2, dim=0, eps=1e-6)
    return float(similarity.item())

if __name__ == '__main__':
    # For every target dataset, rank the source prompts by cosine similarity
    # to each tuned target prompt and store the indices of the top-1/3/5
    # most similar source prompts as pickle files.

    # Target datasets whose tuned prompts are compared with the source prompts.
    dataset_lst = ['wnli', 'mrpc', 'superglue-rte', 'superglue-multirc', 'superglue-boolq', 'superglue-wic',
                   'superglue-wsc.fixed', 'superglue-cb']
    # Order matters: the indices written to the pickle files are positions
    # in this list.
    source_tasks = ["mnli", "sst2", "qnli", "qqp", "squad", "record"]
    top_k_values = [1, 3, 5]

    model_name = 't5-3b'

    # model_name is spelled with a hyphen here, but the lookup keys use an
    # underscore; normalise so that e.g. 't5-base' does not silently fall
    # through to the 3B source prompts. Anything else keeps using the 3B path.
    source_dirs = {
        't5_base': '/export/home/OpenPrompt/mixture_prompt/source_prompts/',
        't5_large': '/export/home/OpenPrompt/mixture_prompt/attempt_large_source/',
    }
    source_path = source_dirs.get(
        model_name.replace('-', '_'),
        '/export/home/OpenPrompt/mixture_prompt/attempt_3b_source/')

    # Use the GPU when one is present, but keep the script usable on
    # CPU-only machines. Loading with map_location also guarantees that the
    # source and target tensors end up on the same device.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # The source prompts do not depend on the dataset or the target prompt,
    # so load them once instead of once per target prompt.
    source_prompts = [
        torch.load(source_path + source + '_prompt.pt', map_location=device)
        for source in source_tasks
    ]

    # NOTE: the target prompts are located relative to the current working
    # directory, so the script has to be started from its own folder.
    target_root = os.path.join('..', '..', 'outputs', 'baselines',
                               'init_pt_fewshot_table', model_name)

    summary_dir = os.path.join(
        '/export/home/OpenPrompt/mixture_prompt/outputs/baselines/sim-prompt/summary',
        model_name)
    # open(..., 'wb') does not create missing parent directories.
    os.makedirs(summary_dir, exist_ok=True)

    for dataset in dataset_lst:
        # One entry per target prompt found on disk: source-prompt indices
        # ordered from most to least similar.
        ranked_idxs = []
        for i in range(20):
            p = os.path.join(target_root, dataset, 'prefix_embeddings_' + str(i) + '.pt')
            if not os.path.exists(p):
                # Target prompts are numbered consecutively; stop at the
                # first missing index.
                break
            tensor_2 = torch.load(p, map_location=device)
            sim_lst = [(calc_sim_tensor(tensor_1, tensor_2), j)
                       for j, tensor_1 in enumerate(source_prompts)]
            # sorted() is stable, so equal similarities keep source order.
            sim_lst = sorted(sim_lst, key=lambda x: x[0], reverse=True)
            ranked_idxs.append([j for _, j in sim_lst])

        print('*' * 20)
        print(dataset)

        for num in top_k_values:
            max_idxs = [ranking[:num] for ranking in ranked_idxs]
            out_file = os.path.join(summary_dir, dataset + '_' + str(num) + '.pickle')
            with open(out_file, 'wb') as handle:
                pickle.dump(max_idxs, handle, protocol=pickle.HIGHEST_PROTOCOL)