from functools import partial
import multiprocessing
import random

import numpy as np
from tqdm import tqdm
import datasets
import os
import yaml
import json


# Use half of the available CPUs for `datasets` worker processes.
# os.cpu_count() may return None, and on a single-CPU host 1 // 2 == 0, which
# `datasets` rejects as a num_proc value, so always use at least one worker.
NUM_PROC = max(1, (os.cpu_count() or 1) // 2)

# Set to True to run the whole pipeline (and print examples/weights) without
# writing any dataset or weights file to disk.
DRY_RUN = False

COMBINED_OUT_PATH = '/XXXX-30/XXXX-29/XXXX-31/proj-shared/language_datasets/processed/nomic_supervised_combined_v3/train'
SRC_SPLIT_BASE_OUT_PATH = '/XXXX-30/XXXX-29/XXXX-31/proj-shared/language_datasets/processed/nomic_supervised_src_split_v3'
SRC_WEIGHTS_PATH = 'scripts/nomic_supervised_size_normalized_weights_v3.json'

# Load the supervised training pairs from the Hugging Face Hub (train split).
data_hf = datasets.load_dataset("jxm/nomic_embed_supervised", keep_in_memory=False, num_proc=NUM_PROC)['train']

print(data_hf)
# Per-source metadata; each entry of ds_metadata['datasets'] is later read for
# its 'bucket' and optional 'query_prefix' / 'document_prefix' keys.
# Decode explicitly as UTF-8 so non-ASCII prefixes do not depend on the locale.
with open('scripts/nomic_supervised.yaml', 'r', encoding='utf-8') as file:
    ds_metadata = yaml.safe_load(file)

# Map each source subset name (the second-to-last path component of its
# bucket) to the prefixes used when templating its queries and documents.
# A missing query_prefix becomes ""; a missing document_prefix falls back to
# the subset's query prefix.
long_prompt_dict = {}
for entry in ds_metadata['datasets']:
    subset_name = entry['bucket'].split('/')[-2]
    print(subset_name)
    q_prefix = entry.get('query_prefix', "")
    d_prefix = entry.get('document_prefix', q_prefix)
    long_prompt_dict[subset_name] = {"query_prefix": q_prefix, "document_prefix": d_prefix}

def combine_to_text(row):
    """Rewrite one example's query, document and negatives into prompt form.

    Prefixes are looked up in the module-level ``long_prompt_dict`` by
    ``row['dataset']`` and title-cased with ``str.title``. The query becomes
    ``"<s> {QueryPrefix}: <QUERY> {query} <QUERY>"``. The document and each
    negative become ``"<DOC> {text} <DOC> {DocumentPrefix}:"``.

    The row is updated in place and returned, which is the shape expected of a
    ``Dataset.map`` callback. Alternative templates that were experimented with
    (prefix-first documents, ``<TASK>`` tags, untagged prefixes) were dropped
    from this function; recover them from version control if needed.

    Raises:
        KeyError: if ``row['dataset']`` has no entry in ``long_prompt_dict``.
    """
    prompts = long_prompt_dict[row['dataset']]
    query_tag = prompts['query_prefix'].title()
    doc_tag = prompts['document_prefix'].title()

    def wrap_doc(text):
        # Documents and hard negatives share the same template.
        return f"<DOC> {text} <DOC> {doc_tag}:"

    # These strings start with '<' and end with '>' or ':', so there is no
    # surrounding whitespace to strip.
    row["query"] = f"<s> {query_tag}: <QUERY> {row['query']} <QUERY>"
    row["document"] = wrap_doc(row['document'])
    row["negative"] = list(map(wrap_doc, row["negative"]))
    return row

# Apply the prompt template to every row in parallel, then (outside dry runs)
# persist the combined dataset before it is split per source.
data_hf = data_hf.map(function=combine_to_text, num_proc=NUM_PROC)
if not DRY_RUN:
    data_hf.save_to_disk(dataset_path=COMBINED_OUT_PATH, num_proc=NUM_PROC)

# Break the combined dataset apart by source subset using a filter, and record
# each subset's row count so relative weights can be computed afterwards.
subset_sizes = {}  # subset name -> number of rows; filled by filter_and_save


def filter_and_save(subset):
    """Select one source subset, print a sample, record its size and save it.

    Keeps the rows of the module-level ``data_hf`` whose ``dataset`` column
    equals ``subset``. It then prints a summary plus one example row and stores
    the row count in ``subset_sizes[subset]``. Unless ``DRY_RUN`` is set, it
    writes the subset to ``{SRC_SPLIT_BASE_OUT_PATH}/{subset}/train``.

    Args:
        subset: Source-dataset name, compared against the ``dataset`` column.
    """
    print(f"Processing subset {subset} ... ")
    # Pass only the ``dataset`` column to the predicate and test a batch at a
    # time, so the large text columns are not decoded just to decide
    # membership. The selected rows are the same as a row-wise equality check.
    subset_data = data_hf.filter(
        lambda names: [name == subset for name in names],
        input_columns="dataset",
        batched=True,
        num_proc=NUM_PROC,
    )
    print(f"Split {subset}\n{subset_data}")
    print("################ EXAMPLE ################")
    if len(subset_data) == 0:
        # A subset listed in the YAML may have no rows in the HF data;
        # indexing [0] would raise IndexError and abort the whole run.
        print("(no rows)")
    else:
        example = subset_data[0]
        print(f"Query: {example['query']}")
        print(f"Document: {example['document']}")
        negatives = example['negative']
        # Rows without hard negatives would otherwise raise IndexError here.
        print(f"Negative: {negatives[0] if negatives else '(none)'}")
    print("########################################")
    subset_sizes[subset] = len(subset_data)
    if not DRY_RUN:
        subset_data.save_to_disk(f'{SRC_SPLIT_BASE_OUT_PATH}/{subset}/train', num_proc=NUM_PROC)

# Subset names in YAML order. dict.fromkeys keeps the first occurrence of each
# name, so a bucket listed twice is not filtered and saved twice.
unique_subsets = list(dict.fromkeys(d['bucket'].split('/')[-2] for d in ds_metadata['datasets']))
print(f"Found splits: {unique_subsets}")

for subset in unique_subsets:
    filter_and_save(subset)

print(subset_sizes)

# Relative size-normalized weights: each subset's share of all rows.
total_size = sum(subset_sizes.values())
if total_size == 0:
    # Fail with a clear message instead of a bare ZeroDivisionError.
    raise ValueError("No rows were found for any subset; cannot compute size-normalized weights.")
weights = {k: v / total_size for k, v in subset_sizes.items()}

print(json.dumps(weights, indent=4))

# Save the weights (skipped on dry runs).
if not DRY_RUN:
    with open(SRC_WEIGHTS_PATH, 'w', encoding='utf-8') as f:
        json.dump(weights, f, indent=4)
    print(f"Saved weights to {SRC_WEIGHTS_PATH}")