import argparse
import os

import torch
import numpy as np
import faiss
import heapq

# Command-line interface.
#
# Path templates are filled in with str.format by the code further down:
#   --gradient_path             receives (ckpt, train_file_name), in that order
#   --validation_gradient_path  receives (target_task_name, ckpt), in that order
# so the two templates take their placeholders in opposite orders.
argparser = argparse.ArgumentParser(
    description='Script for selecting the data for training')
argparser.add_argument('--gradient_path', type=str, default="{} ckpt{}",
                       help='The path to the gradient file')
argparser.add_argument('--train_file_names', type=str, nargs='+',
                       help='The name of the training file')
argparser.add_argument('--ckpts', type=int, nargs='+',
                       help="Checkpoint numbers.")
# One weight per entry of --ckpts; they are matched up by position.
argparser.add_argument('--checkpoint_weights', type=float, nargs='+',
                       help="checkpoint weights")
argparser.add_argument('--target_task_names', type=str,
                       nargs='+', help="The name of the target tasks")
argparser.add_argument('--validation_gradient_path', type=str,
                       default="{} ckpt{}", help='The path to the validation gradient file')
argparser.add_argument('--output_path', type=str, help='The path to the output')
argparser.add_argument('--load_top_info', action='store_true', help='Whether load top K info or compute them')
argparser.add_argument('--load_kde_info', action='store_true', help='Whether load kde info or compute them')
argparser.add_argument('--alpha', type=float, help='The parameter that controls the tradeoff')
argparser.add_argument('--C', type=float, help='A constant', default=5)
# Bandwidth of the kernel used for the density estimate.
argparser.add_argument('--sigma', type=float, default=1.0)
# The help text used to be a copy of the --load_kde_info one; this flag
# actually disables the final probability-assignment step and its output file.
argparser.add_argument('--no_save', action='store_true',
                       help='Skip computing and saving the assigned probabilities')


# Parse the command line once; everything below reads its settings from `args`.
args = argparser.parse_args()

# Sub-task counts per benchmark (not referenced in the visible part of the script).
N_SUBTASKS = {"mmlu": 57, "bbh": 27, "tydiqa": 9}
# Number of nearest training examples retrieved per query.
MAX_K = 5000
# Number of neighbours used for the kernel density estimate.
KDE_K = 100
SIGMA = args.sigma
ALPHA = args.alpha

# --alpha has no default. The probability-assignment step divides by it, so a
# missing value used to surface as a TypeError only after the index had been
# built and searched. Fail immediately instead, but only when that step runs.
if ALPHA is None and not args.no_save:
    argparser.error("--alpha is required unless --no_save is given")

# Output files for a non-default bandwidth get a distinguishing suffix.
postfix = f"sigma{SIGMA}" if SIGMA != 1.0 else ""

# Cached KDE values are derived from the cached neighbours, so recomputing
# the neighbours forces the KDE values to be recomputed as well.
if not args.load_top_info:
    args.load_kde_info = False
    print("KDE will be recompuated since top info will be recompuated")

# The training features are only needed when the neighbour lists or the KDE
# values have to be (re)computed; with both caches on disk they are skipped.
if args.load_top_info and args.load_kde_info:
    index = None
    print("There is no need to use the index.")
    # NOTE(review): hard-coded number of training examples; presumably only
    # valid for the dataset this script was written for -- confirm before
    # reusing cached results from a different training set.
    N = 270679
else:
    per_file_features = []
    for train_name in args.train_file_names:
        # One scaled gradient block per checkpoint, joined along the last axis.
        ckpt_parts = []
        for pos, step in enumerate(args.ckpts):
            grads = torch.load(args.gradient_path.format(step, train_name))
            ckpt_parts.append(args.checkpoint_weights[pos] * grads * 10 ** 5)
        file_features = np.concatenate(ckpt_parts, axis=-1)
        per_file_features.append(file_features)
        print(f"{train_name}: {file_features.shape[0]}")

    # Stack all training files row-wise into a single feature matrix.
    xb = np.concatenate(per_file_features, axis=0)
    print(f"xb shape: {xb.shape}")

    # Exact (brute-force) L2 index over every training example.
    feature_dim = xb.shape[-1]
    index = faiss.IndexFlatL2(feature_dim)
    index.add(xb)
    print("Index has been built.")
    N = xb.shape[0]

# For every target task: find (or load) the MAX_K nearest training examples of
# each query, estimate (or load) a kernel density value for each of them, and
# unless --no_save is given, turn those into a probability vector over the
# training examples that is written to disk.
for target_task_name in args.target_task_names:
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    output_dir = os.path.join(args.output_path, target_task_name)
    os.makedirs(output_dir, exist_ok=True)

    top_indices_file = os.path.join(
        output_dir, f"top_indices_{target_task_name}_{MAX_K}nn.npy")
    top_dists_file = os.path.join(
        output_dir, f"top_dists_{target_task_name}_{MAX_K}nn.npy")

    if not args.load_top_info:
        # Query features: one scaled gradient block per checkpoint, joined
        # along the last axis (same scaling as the training features).
        xq = []
        for i, ckpt in enumerate(args.ckpts):
            validation_path = args.validation_gradient_path.format(
                target_task_name, ckpt)
            validation_info = torch.load(validation_path)
            xq.append(args.checkpoint_weights[i] * validation_info * 10 ** 5)
        xq = np.concatenate(xq, axis=-1)
        print(f"task: {target_task_name}, xq shape: {xq.shape}")
        top_dists, top_indices = index.search(xq, MAX_K)
        top_indices = top_indices.astype(int)
        np.save(top_indices_file, top_indices)
        np.save(top_dists_file, top_dists)
    else:
        top_indices = np.load(top_indices_file).astype(int)
        top_dists = np.load(top_dists_file)

    # Order each query's neighbours by increasing distance and convert the
    # stored squared L2 distances to plain distances.
    sorted_indices = np.argsort(top_dists, axis=-1)
    static_indices = np.indices(top_dists.shape)[0]
    top_dists = np.sqrt(top_dists[static_indices, sorted_indices])
    top_indices = top_indices[static_indices, sorted_indices]
    print(f"top_indices shape: {top_indices.shape}")

    kde_file = os.path.join(
        output_dir, f"top_kdes_{target_task_name}_{MAX_K}nn_sigma{SIGMA}.npy")
    if not args.load_kde_info:
        # Distinct training examples that appear in any neighbour list.
        top_indices_set = list(set(top_indices.reshape(-1)))
        top_features = xb[top_indices_set]
        print(f"size of top set: {len(top_indices_set)}")

        # Approximate (IVF) index over the retrieved examples only.
        nlist = int(np.sqrt(top_features.shape[0]))
        quantizer = faiss.IndexFlatL2(top_features.shape[-1])
        index_local = faiss.IndexIVFFlat(quantizer, top_features.shape[-1], nlist)
        index_local.train(top_features)
        print("Index training finished.")
        index_local.add(top_features)
        print("Index adding finished.")
        index_local.nprobe = 10
        D2, _ = index_local.search(top_features, KDE_K)

        # Kernel value 1 - d^2 / sigma^2, clipped at zero; the KDE value of a
        # point is the sum of the kernel over its KDE_K neighbours.
        kernel = 1 - D2 / (SIGMA ** 2)
        print(f'A point has {(kernel > 0).sum(axis=-1).mean() - 1} near-duplicates on average')
        kernel = kernel * (kernel > 0)
        kde = kernel.sum(axis=-1)
        print(f"KDE shape: {kde.shape}")

        # Scatter the KDE values into an array addressed by training index and
        # gather them back in neighbour-list layout. This replaces a dict
        # lookup wrapped in np.vectorize, which ran one Python call per entry.
        kde_lookup = np.zeros(xb.shape[0], dtype=kde.dtype)
        kde_lookup[top_indices_set] = kde
        top_kdes = kde_lookup[top_indices]
        np.save(kde_file, top_kdes)
    else:
        top_kdes = np.load(kde_file)

    if not args.no_save:
        M = top_indices.shape[0]
        C = args.C
        # lastK[j]: last neighbour rank of query j that was fully processed.
        lastK = [0] * M
        # Heap entries are (count, k, j), where count is the running sum of
        # 1 / kde over neighbours 0..k of query j; the smallest count pops first.
        heap = [(1.0 / top_kdes[j][0], 0, j) for j in range(M)]
        heapq.heapify(heap)
        # Running sum of dist / kde over neighbours 0..k of each query.
        dist_weighted_sum = [top_dists[j][0] / top_kdes[j][0] for j in range(M)]
        s = 0
        # increase on the transportation cost
        cost = np.zeros(M)
        total_cost = 0
        while len(heap) > 0:
            count, curr_k, curr_j = heapq.heappop(heap)
            s = count
            # if increase s by a little bit, the 0, 1, ..., curr_k has to transport probability mass to curr_k + 1
            total_cost -= cost[curr_j]
            cost[curr_j] = top_dists[curr_j][curr_k + 1] * count - dist_weighted_sum[curr_j]
            total_cost += cost[curr_j]
            # If the condition breaks, the current s will be the final s
            if ALPHA / C * total_cost >= (1 - ALPHA) * M:
                break
            lastK[curr_j] = curr_k
            # Rank curr_k + 1 must stay a valid column for the lookups above,
            # so the sweep stops extending a query at rank MAX_K - 2.
            if curr_k < MAX_K - 2:
                count += 1.0 / top_kdes[curr_j][curr_k + 1]
                heapq.heappush(heap, (count, curr_k + 1, curr_j))
                dist_weighted_sum[curr_j] += top_dists[curr_j][curr_k + 1] / top_kdes[curr_j][curr_k + 1]
        print(f"s: {s}")
        print(f"K stats - average: {np.mean(lastK)} max: {np.max(lastK)} min: {np.min(lastK)}")
        print(f"Reaches {MAX_K - 2}: {(np.array(lastK) == MAX_K - 2).sum()}")

        # Each query spreads a total mass of 1 / M: neighbours 0..lastK[j] get
        # 1 / (M * s * kde) each and the remainder goes to neighbour lastK[j] + 1.
        global_probs = np.zeros(N)
        for j in range(M):
            prob_sum = 0
            for k in range(lastK[j] + 1):
                share = 1 / M / s / top_kdes[j][k]
                global_probs[top_indices[j][k]] += share
                prob_sum += share
            global_probs[top_indices[j][lastK[j] + 1]] += max(1.0 / M - prob_sum, 0)
            # Sanity checks on the sweep result, with tolerance for rounding.
            assert 1.0 / M - prob_sum >= -1e-9, f'{1.0 / M - prob_sum}'
            assert (1.0 / M - prob_sum) * top_kdes[j][lastK[j] + 1] * M * s <= 1 + 1e-9 or lastK[j] == MAX_K - 2, f'{(1.0 / M - prob_sum) * top_kdes[j][lastK[j] + 1] * M * s}'
        print(f'sum of probs: {global_probs.sum()}')
        print(f'non-zero entries: {(global_probs > 0).sum()}')

        output_file = os.path.join(output_dir, f"prob_alpha{ALPHA}{postfix}.npy")
        np.save(output_file, global_probs)
        print("Saved assigned probability to {}".format(output_file))


