import os
import sys
sys.path.insert(0, os.getcwd())
import torch
import argparse

# Downstream task -> name of the per-circuit metric key stored in the embeddings.
APP_METRICS = {"maxcut": "energy", "vqe": "energy", "fidelity": "fidelity"}

_DEFAULT_DIR = os.path.join("pretrained", "dim-16")


def metric_for_app(app):
    """Return the metric key used by downstream task *app*.

    Args:
        app: Task name; one of ``APP_METRICS``' keys.

    Returns:
        The metric key (``"energy"`` or ``"fidelity"``).

    Raises:
        ValueError: If *app* is not a supported downstream task.
    """
    try:
        return APP_METRICS[app]
    except KeyError:
        raise ValueError(
            "We don't set other downstream tasks apart from maxcut, vqe and fidelity."
        ) from None


def merge_features(full_embedding, feature_embedding, metric):
    """Copy the *metric* and ``'time'`` entries into *feature_embedding* in place.

    For every index ``i`` in ``range(len(full_embedding))``, the script sets
    ``feature_embedding[i][metric]`` and ``feature_embedding[i]['time']`` from
    ``full_embedding[i]``. Entries of *feature_embedding* beyond
    ``len(full_embedding)`` are left untouched.

    Args:
        full_embedding: Indexable sequence of mappings with *metric* and 'time' keys.
        feature_embedding: Indexable sequence of mutable mappings to update.
        metric: Name of the metric key to copy.

    Returns:
        *feature_embedding* (the same object, mutated).

    Raises:
        ValueError: If *feature_embedding* has fewer entries than *full_embedding*.
    """
    if len(feature_embedding) < len(full_embedding):
        # Fail before mutating anything, instead of a bare IndexError halfway through.
        raise ValueError(
            f"feature embedding has {len(feature_embedding)} entries but full "
            f"embedding has {len(full_embedding)}; cannot merge."
        )
    # Index access (not zip) on purpose. The container type produced by torch.load
    # is not visible here. It may be a list or an int-keyed dict, and indexing
    # works for both. NOTE(review): confirm the actual type.
    for ind in range(len(full_embedding)):
        feature_embedding[ind][metric] = full_embedding[ind][metric]
        feature_embedding[ind]['time'] = full_embedding[ind]['time']
    return feature_embedding


def main(argv=None):
    """Parse CLI arguments, merge metric/time fields, and save the result."""
    parser = argparse.ArgumentParser(description="circuit_features_merging")
    parser.add_argument('--app', type=str, default='maxcut', choices=sorted(APP_METRICS))
    # os.path.join keeps the defaults valid on POSIX as well as Windows.
    parser.add_argument('--full_emb_path', type=str, default=os.path.join(
        _DEFAULT_DIR, 'maxcut-model-circuits_4_qubits_gsqas_full_embedding.pt'))
    parser.add_argument('--emb_path', type=str, default=os.path.join(
        _DEFAULT_DIR, 'maxcut-model-circuits_4_qubits_gsqas.pt'))
    # NOTE(review): the default save path equals the default full_emb_path, so by
    # default the input full-embedding file is overwritten. Confirm this is intended.
    parser.add_argument('--save_path', type=str, default=os.path.join(
        _DEFAULT_DIR, 'maxcut-model-circuits_4_qubits_gsqas_full_embedding.pt'))

    args = parser.parse_args(argv)

    metric = metric_for_app(args.app)

    # torch.load unpickles its input: only load files from trusted sources.
    full_embedding = torch.load(args.full_emb_path)
    feature_embedding = torch.load(args.emb_path)

    merge_features(full_embedding, feature_embedding, metric)
    torch.save(feature_embedding, args.save_path)
    print("save finished...")


if __name__ == '__main__':
    main()