import argparse
import json
import os
from tqdm import tqdm
import torch

def merge_info(output_dir: str, prefix="grads"):
    """Concatenate the ``prefix``-named shards of ``output_dir`` into one tensor.

    Every directory entry whose name starts with ``prefix`` is loaded with
    ``torch.load``. Shards are ordered by the integer found between the first
    ``-`` and the first ``.`` of the file name (e.g. ``grads-12.pt`` -> 12),
    concatenated along dimension 0 without any normalization, and the result
    is written to ``<output_dir>/all_unormalized.pt``.

    Args:
        output_dir: Directory containing the shard files; also receives the
            merged file.
        prefix: Leading part of the shard file names to pick up.

    Returns:
        The concatenated tensor.
    """
    # NOTE: the merged file is rebuilt and overwritten on every call; an
    # existing ``all_unormalized.pt`` is never read back.

    def _shard_index(file_name):
        # "grads-12.pt" -> "grads-12" -> 12
        stem = file_name.split(".")[0]
        return int(stem.split("-")[1])

    # Numeric (not lexicographic) ordering, so that shard 10 follows shard 9.
    shard_names = sorted(
        (name for name in os.listdir(output_dir) if name.startswith(prefix)),
        key=_shard_index,
    )

    shards = [torch.load(os.path.join(output_dir, name)) for name in shard_names]
    merged_data = torch.cat(shards, dim=0)

    output_file = os.path.join(output_dir, "all_unormalized.pt")
    torch.save(merged_data, output_file)
    print(
        f"Saving the unnormalized {prefix} (Shape: {merged_data.shape}) to {output_file}.")
    return merged_data


# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
argparser = argparse.ArgumentParser(
    description='Script for selecting the data for DPO training')
argparser.add_argument('--gradient_path', type=str, default="",
                       help='The path to the gradient file')
argparser.add_argument('--exp_path', type=str, default="",
                       help='The path to the experiments')
argparser.add_argument('--dim', type=int, default=8192,
                       help='The dimension of the gradient')
argparser.add_argument('--select_percentage', type=int, default=5,
                       help='The percentage of the selection')
argparser.add_argument('--sub_iter_every', type=int, default=1,
                       help='The frequency of the sub-iteration')
argparser.add_argument('--normalize', type=int, default=0,
                       help='Whether to normalize the gradients')
argparser.add_argument('--largest', type=int, default=1,
                       help='Whether to select the largest values')
# NOTE(review): --iteration is parsed but not read below; the iteration that
# ends up in the output comes from info.json (or defaults to 0).
argparser.add_argument('--iteration', type=int, default=0,
                       help='The iteration of the selection')
argparser.add_argument('--ori_data_dir', type=str, default='.cache',
                       help='The original data directory')


args = argparser.parse_args()

# NOTE(review): `device` is computed but no tensor below is moved to it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Directory holding the gradient shards, the optional info.json and the
# selection files written at the end.
grad_dir = os.path.join(args.gradient_path, f'gradinfo/{args.dim}')

# read the gradients
all_grads = merge_info(grad_dir, prefix="grads")

if args.normalize:
    # Scale every row to unit L2 norm.
    all_grads = all_grads / torch.norm(all_grads, dim=1, keepdim=True)

# read other info: start from an empty selection when there is no info.json
info_path = os.path.join(grad_dir, 'info.json')
if os.path.exists(info_path):
    with open(info_path, 'r') as f:
        info = json.load(f)
else:
    info = {'iteration': 0, 'selected_points_all': []}
iteration = info['iteration']
# Copy, so that `info['selected_points_all']` keeps the pre-existing selection
# (its length is used for the summary printed after the loop).
selected_points = list(info['selected_points_all'])

# select the data
select_num = int(len(all_grads) * args.select_percentage / 100)
selected_features = all_grads[selected_points]
lamdba = 0.1  # ridge term added to the diagonal (name kept as in the original)

# Loop invariants: the ridge term (float32, length args.dim) and the squared
# gradients.
ridge = torch.full((args.dim,), lamdba)
all_grads_sq = all_grads * all_grads

selected_points_this_iter = []
for i_ in tqdm(range(0, select_num, args.sub_iter_every), desc=f'Selecting data with sub-iteration every {args.sub_iter_every}'):
    # The last sub-iteration may have to pick fewer than `sub_iter_every`.
    select_num_now = min(args.sub_iter_every, select_num - i_)

    # calculate the uncertainty
    # Only the diagonal of (X^T X + lambda * I) is used, i.e. the per-column
    # sum of squares of the selected rows plus lambda, so compute it directly
    # instead of materialising the full dim x dim product.
    U = torch.sum(selected_features * selected_features, dim=0) + ridge
    all_uncertainty = torch.sqrt(torch.sum(all_grads_sq / U, dim=1))

    # selecting data excluding the selected points
    unselected_points = list(set(range(len(all_uncertainty))) - set(selected_points))
    # topk raises when asked for more entries than are left; stop cleanly once
    # every point has been selected.
    select_num_now = min(select_num_now, len(unselected_points))
    if select_num_now == 0:
        break
    select_now = torch.topk(all_uncertainty[unselected_points], select_num_now, largest=(args.largest == 1))[1]
    # Map positions within `unselected_points` back to row indices of all_grads.
    select_idx_now = [unselected_points[i] for i in select_now.tolist()]
    selected_points += select_idx_now
    selected_points_this_iter += select_idx_now
    selected_features = all_grads[selected_points]

print(f"Asking for {select_num} data, selected for {len(selected_points)-len(info['selected_points_all'])} data")

# The option is declared as --ori_data_dir, so argparse stores it under
# `args.ori_data_dir` (there is no `args.original_data_dir`).
data_points = [{'selection': True, 'selected_idx': selected_points, 'iteration': iteration, 'selected_points_now': selected_points_this_iter, 'source': args.ori_data_dir}]

# Output file name(s); each file receives one JSON document per line.
base_name = f'info-finish-every{args.sub_iter_every}-p{args.select_percentage}'
if args.normalize:
    output_names = [f'{base_name}-normalized.json']
    if args.largest != 1:
        # NOTE(review): with --normalize 1 and --largest != 1 the original
        # writes BOTH the "_smallest" file and the plain "-normalized" file.
        # That is kept as-is here; confirm whether the second write is
        # intended or a missing `else`.
        output_names.insert(0, f'{base_name}-normalized_smallest.json')
else:
    output_names = [f'{base_name}.json']

for output_name in output_names:
    with open(os.path.join(grad_dir, output_name), 'w') as f:
        for dp in data_points:
            f.write(json.dumps(dp) + '\n')
