import torch
import os
import json
import numpy as np
import argparse
from tqdm import tqdm
import math
# Default to physical GPU 3, but respect a CUDA_VISIBLE_DEVICES value the caller
# already exported (e.g. `CUDA_VISIBLE_DEVICES=0 python this_script.py`).
# NOTE(review): torch is imported above this line; this relies on CUDA being
# initialized lazily (on first CUDA use) — confirm if GPU work is added. The
# visible code loads data with map_location='cpu' and makes no CUDA calls.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")
from torch import nn
# Alpaca-style prompt templates, chosen by whether an example has extra input.
# "prompt_input" contains the placeholders {instruction} and {input};
# "prompt_no_input" contains only {instruction}. Presumably filled with
# str.format(...) elsewhere — this file does not reference PROMPT_DICT.
PROMPT_DICT = dict(
    prompt_input=(
        "Below is an instruction that describes a task, "
        "paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n"
        "### Input:\n{input}\n\n"
        "### Response:"
    ),
    prompt_no_input=(
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n"
        "### Response:"
    ),
)


def parse_args(argv=None):
    """Parse command-line options for the KL-ranking script.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with the options below. In this file ``main()``
        reads ``pt_data_path_T2``, ``json_data_path`` and
        ``overlap_data_path``; the other options are accepted but not read
        by ``main()``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--pt_data_path_T0", type=str, default='../data/neft_p/alpaca_neft_G_shorten_1.pt')
    parser.add_argument("--pt_data_path_T1", type=str, default='../data/neft_new/alpaca_no_stable_noise_240_average.pt')
    parser.add_argument("--pt_data_path_T2", type=str, default='../data/neft_new/final_test/alpaca_uniform_noise10_240.pt',
                        help="torch-saved records whose 'kl_target' values are ranked")
    parser.add_argument("--json_data_path", type=str, default='../data/alpaca_data.json',
                        help="JSON list of examples aligned index-by-index with --pt_data_path_T2")
    # main() reads args.overlap_data_path, but this option used to be missing,
    # so every run died with AttributeError. None means "no overlap file given".
    parser.add_argument("--overlap_data_path", type=str, default=None,
                        help="optional JSON list of items to locate in --json_data_path")

    parser.add_argument("--json_save_path", type=str, default='../data/alpaca_data_neft.json')
    parser.add_argument("--model_name_or_path", type=str, default='../llama2/Llama-2-7b-hf')
    parser.add_argument("--prompt", type=str, default='alpaca')
    return parser.parse_args(argv)


def _find_overlap_positions(overlap_data, json_data):
    """Return ``(item, index)`` for every item of ``overlap_data`` found in ``json_data``.

    ``index`` is the first position of an equal entry, matching what
    ``json_data.index(item)`` returns. Entries are keyed by their canonical
    JSON encoding (``sort_keys=True``), so each lookup is O(1) instead of a
    linear scan. NOTE: this differs from ``==`` only for numerically-equal
    values with different JSON spellings (e.g. ``1`` vs ``1.0``).
    """
    first_index = {}
    for idx, entry in enumerate(json_data):
        # setdefault keeps the FIRST occurrence, as list.index would.
        first_index.setdefault(json.dumps(entry, sort_keys=True), idx)

    positions = []
    for item in overlap_data:
        idx = first_index.get(json.dumps(item, sort_keys=True))
        if idx is not None:
            positions.append((item, idx))
    return positions


def main():
    """Rank examples by ascending ``kl_target`` and dump the lowest-KL subset.

    Loads per-example records from ``args.pt_data_path_T2`` (each record's
    ``'kl_target'`` is converted with ``.item()``) and the example list from
    ``args.json_data_path``. Record ``i`` is assumed to correspond to
    ``json_data[i]`` — TODO confirm the two files are aligned.

    Writes two JSON files (paths below):

    * the ``_ID.json`` file holds ``[kl_target, original_index]`` pairs for the
      ``top_k`` lowest-KL examples, in ascending KL order. The original index
      used to be replaced by the rank, which repeated the list position and
      made the file impossible to map back to ``json_data``.
    * the data file holds the corresponding ``json_data`` entries, in the same
      order.

    If ``--overlap_data_path`` is set, the function also prints how many of
    that file's items occur in ``json_data``.
    """
    args = parse_args()
    print(args)

    # Number of lowest-KL examples to keep; also baked into the output names.
    top_k = 26000
    id_save_path = '../data/average/kl_final/alpaca_uniform_noise10_240_26000_ID.json'
    data_save_path = '../data/average/kl_final/alpaca_uniform_noise10_240_26000.json'

    # NOTE(review): torch.load unpickles the file; only load trusted data.
    pt_data_T2 = torch.load(args.pt_data_path_T2, map_location=torch.device('cpu'))
    with open(args.json_data_path, "r", encoding="utf-8") as f:
        json_data = json.load(f)

    # The option may be absent or None; previously this path crashed the script.
    overlap_data_path = getattr(args, "overlap_data_path", None)
    if overlap_data_path:
        with open(overlap_data_path, "r", encoding="utf-8") as f:
            overlap_data = json.load(f)
        overlap_position = _find_overlap_positions(overlap_data, json_data)
        print(f"{len(overlap_position)} of {len(overlap_data)} overlap items found in json data")

    kl_target = [pt_data_T2[i]['kl_target'].item() for i in range(len(pt_data_T2))]

    # sorted() is stable, so ties keep their original index order (same as
    # sorting (kl, index) tuples by kl). Slicing avoids an IndexError when there
    # are fewer than top_k examples.
    order = sorted(range(len(kl_target)), key=kl_target.__getitem__)[:top_k]

    kl_target_sorted_ids = [(kl_target[idx], idx) for idx in order]
    with open(id_save_path, "w") as fw:
        json.dump(kl_target_sorted_ids, fw, indent=4)

    selected_examples = [json_data[idx] for idx in order]
    with open(data_save_path, "w") as fw:
        json.dump(selected_examples, fw, indent=4)


if __name__ == '__main__':
    main()