import argparse
import json
import os
from tqdm import tqdm
import torch

def parse_args():
    parser = argparse.ArgumentParser(description="Process some integers.")
    parser.add_argument('--gradient_path', type=str, required=True, help='Path to the gradient files')
    parser.add_argument('--original_data_dir', type=str, required=True, help='Path to the original data directory')
    parser.add_argument('--dim', type=int, default=8192, help='Dimension size')
    parser.add_argument('--normalize', type=bool, default=True, help='Whether to normalize the gradients')
    parser.add_argument('--select_percentage', type=float, default=5, help='Percentage of selection')
    parser.add_argument('--sub_iter_every', type=int, default=1, help='Sub iteration frequency')
    parser.add_argument('--largest', type=int, default=1, help='Largest value')
    parser.add_argument('--dataset', type=str, required=True, help='Dataset name')
    parser.add_argument('--model', type=str, required=True, help='Model name')
    parser.add_argument('--iteration', type=int, required=True, help='Iteration number')
    parser.add_argument('--log_file_path', type=str, required=True, help='Path to the log file')
    parser.add_argument('--exp_name', type=str, default='', help='Name of the experiment')
    parser.add_argument('--norm_ratio', type=float, default=1, help='The hyperparameter for the normalization')

    return parser.parse_args()

def merge_info(output_dir: str, prefix="grads"):
    """Merge the per-chunk tensors in *output_dir* into one tensor, without normalization.

    Files whose names start with *prefix* must be named ``{prefix}-{index}.<ext>``.
    They are loaded in ascending numeric ``index`` order, so ``grads-10.pt`` comes
    after ``grads-2.pt``, and concatenated along dim 0. The merged tensor is also
    saved as ``all_unormalized.pt`` in *output_dir*.

    Args:
        output_dir: directory containing the chunk files.
        prefix: filename prefix that selects the chunk files.

    Returns:
        The concatenated tensor.

    Raises:
        RuntimeError: if no file in *output_dir* starts with *prefix*.
    """
    chunk_files = sorted(
        (name for name in os.listdir(output_dir) if name.startswith(prefix)),
        # "grads-12.pt" -> 12: numeric order, not lexicographic.
        key=lambda name: int(name.split(".")[0].split("-")[1]),
    )
    if not chunk_files:
        # torch.cat([]) would fail with an opaque message; say what is missing.
        raise RuntimeError(f"No files starting with '{prefix}' found in {output_dir}.")

    merged_data = torch.cat(
        [torch.load(os.path.join(output_dir, name)) for name in chunk_files], dim=0)

    output_file = os.path.join(output_dir, "all_unormalized.pt")
    torch.save(merged_data, output_file)
    print(
        f"Saving the unnormalized {prefix} (Shape: {merged_data.shape}) to {output_file}.")
    return merged_data

def _normalize_grads(all_grads, norm_ratio):
    """Divide each row g of *all_grads* by ``||g|| * norm_ratio + (1 - norm_ratio)``.

    ``norm_ratio == 1`` gives plain L2 row normalization. Rows whose divisor is
    exactly 0 (e.g. all-zero gradients) are left unscaled instead of becoming
    NaN, because a NaN row would corrupt every later uncertainty computation.
    """
    denom = torch.norm(all_grads, dim=1, keepdim=True) * norm_ratio + (1 - norm_ratio)
    denom = denom.masked_fill(denom == 0, 1.0)
    return all_grads / denom


def _select_by_uncertainty(all_grads, select_num, sub_iter_every, largest, lam=0.1):
    """Greedily pick *select_num* row indices of *all_grads* by uncertainty.

    Each round builds ``U = X^T X + lam * I``, where X holds the rows selected so
    far. Every row g is scored by ``g^T U^{-1} g``, and the ``sub_iter_every``
    unselected rows with the highest score (``largest=True``) or lowest score are
    taken.

    Returns:
        List of selected row indices, in the order they were picked.
    """
    num_points, feat_dim = all_grads.shape
    device = all_grads.device
    eye = torch.eye(feat_dim, dtype=all_grads.dtype, device=device)
    is_selected = torch.zeros(num_points, dtype=torch.bool, device=device)
    selected_points = []

    for start in tqdm(range(0, select_num, sub_iter_every),
                      desc=f'Selecting data with sub-iteration every {sub_iter_every}'):
        # The final round may need fewer than sub_iter_every points.
        select_num_now = min(sub_iter_every, select_num - start)

        selected_features = all_grads[torch.tensor(selected_points, dtype=torch.long, device=device)]
        U = selected_features.T @ selected_features + lam * eye
        # NOTE(review): this re-inverts a feat_dim x feat_dim matrix every round;
        # a Woodbury rank-k update of U_inv would be far cheaper if runtime matters.
        U_inv = torch.linalg.inv(U)
        # diag(G U^-1 G^T) computed row-wise, without materializing an N x N matrix.
        uncertainty = ((all_grads @ U_inv) * all_grads).sum(dim=1)

        unselected = torch.nonzero(~is_selected, as_tuple=True)[0]
        top = torch.topk(uncertainty[unselected], select_num_now, largest=largest).indices
        chosen = unselected[top]
        is_selected[chosen] = True
        selected_points.extend(chosen.tolist())

    return selected_points


def main():
    """Run gradient-uncertainty data selection and record the result.

    The steps are:
    1. Merge the gradient chunks under ``{gradient_path}/gradinfo/{dim}``.
    2. Optionally normalize them.
    3. Greedily select ``select_percentage`` % of the points.
    4. Write a ``{"selection": false}`` line followed by the selected records
       (in selection order) to a JSONL file under ``.cache/``.
    5. Store that file's path as ``latest_selected_data`` in the JSON log file.
    """
    args = parse_args()

    # read the gradients
    all_grads = merge_info(os.path.join(args.gradient_path, f'gradinfo/{args.dim}'), prefix="grads")
    # linalg.inv does not support half precision, and mixing it with fp32 in matmul fails.
    if all_grads.dtype in (torch.float16, torch.bfloat16):
        all_grads = all_grads.float()
    if args.normalize:
        all_grads = _normalize_grads(all_grads, args.norm_ratio)

    # select the data (topk cannot return more points than exist)
    num_points = len(all_grads)
    select_num = min(int(num_points * args.select_percentage / 100), num_points)
    selected_points = _select_by_uncertainty(
        all_grads, select_num, args.sub_iter_every, largest=(args.largest == 1))
    print(f"Asking for {select_num} data, selected for {len(selected_points)} data")

    # read data from the original JSONL file; its first line is not a data point
    # (it is dropped here, and a {'selection': False} line is written in its place)
    with open(args.original_data_dir, 'r') as f:
        data_points = [json.loads(line) for line in f if line.strip()][1:]
    selected_records = [{'selection': False}] + [data_points[i] for i in selected_points]

    name_prefix = f'{args.exp_name}_' if args.exp_name else ''
    selected_data_path = (f'.cache/{name_prefix}{args.dataset}_{args.model}'
                          f'_ours_selected_iteration{args.iteration}.jsonl')
    os.makedirs(os.path.dirname(selected_data_path), exist_ok=True)
    with open(selected_data_path, 'w') as f:
        f.writelines(json.dumps(record) + '\n' for record in selected_records)

    # The log file stores one JSON object on its first line; update it in place.
    with open(args.log_file_path, 'r') as f:
        logs = json.loads(f.readline())
    logs['latest_selected_data'] = selected_data_path
    with open(args.log_file_path, 'w') as f:
        f.write(json.dumps(logs))
if __name__ == "__main__":
    # Run the selection pipeline only when executed as a script, not on import.
    main()
