import torch
import torch.nn.functional as F
from tqdm import tqdm
import random
import copy


# Disabled variant of clean_data: keeps only items whose 'complexity' is below 3.5.
# def clean_data(data):
#     data_cleaned = []
#     for item in data:
#         if item['complexity'] < 3.5:
#             data_cleaned.append(item)
#     return data_cleaned

def clean_data(data):
    """Return *data* unchanged.

    Pass-through placeholder: no filtering is applied and the very same
    object that was passed in is handed back to the caller.
    """
    unfiltered = data
    return unfiltered

def _load_embedding_file(path):
    """Load one embedding file and stack its embeddings into a single tensor.

    Returns an ``(items, matrix)`` pair: ``items`` is the ``torch.load``
    result passed through ``clean_data``; ``matrix`` concatenates every
    item's ``'embedding'`` (reshaped to a single row) along dim 0.
    """
    items = clean_data(torch.load(path))
    rows = [entry['embedding'].reshape(1, -1) for entry in items]
    return items, torch.cat(rows, dim=0)


self_instruct_full, self_instruct = _load_embedding_file(
    '/mnt/public02/usr/yuanpeiwen/instruction_pool_cleaned/data/cleaned_self_instruction_instruction_embedding.pth'
)

alpaca_data_full, alpaca_data = _load_embedding_file(
    '/mnt/public02/usr/yuanpeiwen/instruction_pool_cleaned/data/cleaned_alpaca_gpt4_data_instruction_embedding.pth'
)

alpaca_evol_full, alpaca_evol = _load_embedding_file(
    '/mnt/public02/usr/yuanpeiwen/instruction_pool_cleaned/data/cleaned_alpaca_evol_instruct_70k_instruction_embedding.pth'
)

dolly_full, dolly = _load_embedding_file(
    '/mnt/public02/usr/yuanpeiwen/instruction_pool_cleaned/data/cleaned_databricks_dolly_15k_instruction_embedding.pth'
)

sharegpt_full, sharegpt = _load_embedding_file(
    '/mnt/public02/usr/yuanpeiwen/instruction_pool_cleaned/data/cleaned_ShareGPT_V3_unfiltered_cleaned_split_no_imsorry_instruction_embedding.pth'
)

# sharegpt_evol_full = clean_data(torch.load('/mnt/public02/usr/yuanpeiwen/instruction_pool/clustering/embeddings/WizardLM_evol_instruct_V2_143k_instruction_embedding.pth'))
# sharegpt_evol = torch.cat([item['embedding'].reshape(1, -1) for item in sharegpt_evol_full], dim=0)

# Merge the source datasets (alpaca_evol goes last) and tag every item with
# its position in the merged list.
full_data = self_instruct_full + alpaca_data_full + dolly_full + sharegpt_full + alpaca_evol_full
for position, entry in enumerate(full_data):
    entry['idx'] = position

# Highest 'quality' first; the sort is stable, so ties keep their merged order.
sorted_full_data = list(full_data)
sorted_full_data.sort(key=lambda entry: entry['quality'], reverse=True)

# Greedy selection with a diversity filter. Candidates are visited in the
# order of sorted_full_data (descending 'quality'); a candidate joins the pool
# only if its highest cosine similarity to every already-selected item is
# below SIMILARITY_THRESHOLD. Selection stops when the pool reaches
# TARGET_POOL_SIZE or the candidates run out, and the pool is then saved.
TARGET_POOL_SIZE = 6000      # number of items to collect
SIMILARITY_THRESHOLD = 0.9   # candidates at or above this similarity are rejected
COMPARE_CHUNK_SIZE = 5000    # max pool rows per cosine_similarity call
PROGRESS_EVERY = 1000        # report progress each time this many items are added

if not sorted_full_data:
    raise ValueError('sorted_full_data is empty; nothing to select from')

# The top-ranked item is always selected.
pool = [sorted_full_data[0]]
# GPU-resident copies of the selected embeddings (one unsqueezed row each), so
# the pool is not re-uploaded from host memory for every candidate. Holds at
# most TARGET_POOL_SIZE rows on the device.
pool_rows_gpu = [sorted_full_data[0]['embedding'].unsqueeze(0).cuda()]

for candidate in tqdm(sorted_full_data[1:]):
    embedding = candidate['embedding'].cuda()

    # Highest similarity between the candidate and any pooled item, computed
    # chunk by chunk. Kept as a tensor so the threshold comparison happens in
    # the embedding's own dtype.
    max_similarity = -1.0
    for left in range(0, len(pool_rows_gpu), COMPARE_CHUNK_SIZE):
        pool_embedding = torch.cat(pool_rows_gpu[left:left + COMPARE_CHUNK_SIZE], dim=0)
        similarity = F.cosine_similarity(embedding, pool_embedding).detach().cpu().flatten()
        chunk_max = similarity.max()
        if chunk_max > max_similarity:
            max_similarity = chunk_max

    if max_similarity < SIMILARITY_THRESHOLD:
        pool.append(candidate)
        pool_rows_gpu.append(embedding.unsqueeze(0))

        # Report only when the pool actually grew; checking on every iteration
        # would repeat the message for each rejected candidate while the pool
        # size sits at a multiple of PROGRESS_EVERY.
        if len(pool) % PROGRESS_EVERY == 0:
            print(f"Collected: {len(pool)}/{TARGET_POOL_SIZE}")
        if len(pool) == TARGET_POOL_SIZE:
            break

torch.save(pool, '/mnt/public02/usr/yuanpeiwen/instruction_pool_cleaned/baselines/cleaned_no_complexity_deita_6k.pth')

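# Disabled earlier variant: builds the pool dataset by dataset, ranking by
# quality * complexity and carrying the previous pool into the next round.
# pool = []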
# for item in datasets:
#     print(f"Processing: {item['name']}")
#     name = item['name']
#     data = item['data']
#     if len(pool) > 0:
#         data = data + pool
#     sorted_full_data  = sorted(data, key=lambda x: x['quality'] * x['complexity'], reverse=True)

#     cur_pool = []
#     cur_pool.append(sorted_full_data[0])

#     for i in tqdm(range(1, len(sorted_full_data))):
#         embedding = sorted_full_data[i]['embedding'].cuda()
#         nn_distance = -1.0
#         left = 0
#         while True:
#             right = min(left+5000, len(cur_pool))
#             pool_embedding = [cur_pool[j]['embedding'].unsqueeze(0).cuda() for j in range(left, right)]
#             pool_embedding = torch.cat(pool_embedding, dim=0)
#             distance = F.cosine_similarity(embedding, pool_embedding).detach().cpu().flatten()
#             if distance.max() > nn_distance:
#                 nn_distance = distance.max()
#             left = right
#             if right >= len(cur_pool):
#                 break
#         if nn_distance < 0.9:
#             cur_pool.append(sorted_full_data[i])
#             print(len(cur_pool))

#         if len(cur_pool) == 6000:
#             break
#     torch.save(cur_pool, f'/mnt/public02/usr/yuanpeiwen/instruction_pool_cleaned/baselines/deita_outputs/new_order_sbs_deita_{name}.pth')
#     pool = copy.deepcopy(cur_pool)
