from functools import partial
import multiprocessing
import random

import numpy as np
from tqdm import tqdm
import datasets
import os
import yaml
import json


# Number of worker processes handed to the `datasets` map/filter/save calls:
# half of the visible CPUs, but never fewer than one. `os.cpu_count()` may
# return None, and integer-halving a single CPU yields 0, which is not a valid
# worker count.
NUM_PROC = max(1, (os.cpu_count() or 1) // 2)

# When True, all processing still runs but nothing is written to disk.
# DRY_RUN=True
DRY_RUN=False

# Output locations: the combined formatted dataset, the per-source splits, and
# the JSON file holding the size-normalized subset weights.
COMBINED_OUT_PATH = '/XXXX-30/XXXX-29/XXXX-31/proj-shared/language_datasets/processed/rift_v2_combined/train'
SRC_SPLIT_BASE_OUT_PATH = '/XXXX-30/XXXX-29/XXXX-31/proj-shared/language_datasets/processed/rift_v2_src_split'
SRC_WEIGHTS_PATH = 'scripts/rift_v2_size_normalized_weights.json'

# rift1 = datasets.load_from_disk("/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/rift/")
# rift2 = datasets.load_from_disk("/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/riftV1/")

# # Filter datasets
# rift1_filtered = rift1.filter(lambda x: x['dataset_name'] in ['wiki', 'wiki2', 'msmarco'])
# rift2_filtered = rift2.filter(lambda x: x['dataset_name'] not in ['wiki1', 'wiki2', 'msmarco'])

# # Initialize combined dictionary with features specification
# combined_rift = {
#     "query": [],
#     "document": [],
#     "negative": [],  # Using 'negative' consistently
#     "dataset_name": []
# }

# # Process RIFT1
# for data in tqdm(rift1_filtered, desc="Processing RIFT1"):
#     combined_rift["query"].append(data["query"])
#     # Handle the negatives field - ensure it's a list of strings
#     document = data["document"] if isinstance(data["document"], list) else [data["document"]]
#     combined_rift["document"].append(document)
#     negatives = data["negatives"] if isinstance(data["negatives"], list) else [data["negatives"]]
#     combined_rift["negative"].append(negatives)
#     combined_rift["dataset_name"].append(data["dataset_name"])

# # Process RIFT2
# for data in tqdm(rift2_filtered, desc="Processing RIFT2"):
#     combined_rift["query"].append(data["query"])
#     # Handle the negative field - ensure it's a list of strings
#     document = data["document"] if isinstance(data["document"], list) else [data["document"]]
#     combined_rift["document"].append(data["document"])
#     negative = data["negative"] if isinstance(data["negative"], list) else [data["negative"]]
#     combined_rift["negative"].append(negative)
#     combined_rift["dataset_name"].append(data["dataset_name"])

# # Create features specification
# features = datasets.Features({
#     "query": datasets.Value("string"),
#     "document": datasets.Sequence(datasets.Value("string")),  # Specify that document is a sequence of strings
#     "negative": datasets.Sequence(datasets.Value("string")),  # Specify that negative is a sequence of strings
#     "dataset_name": datasets.Value("string")
# })

# # Create dataset with explicit features specification
# combined_rift = datasets.Dataset.from_dict(combined_rift, features=features)
# combined_rift.save_to_disk("/XXXX-30/XXXX-29/XXXX-31/proj-shared/language_datasets/processed/rift_v2")
# NOTE: an interactive `IPython.embed()` breakpoint used to sit here, together
# with two dataset loads that were only inspected from that shell. They were
# removed because the breakpoint stalls unattended runs, and one of the loads
# read from this script's own output path, which fails before the first run
# has produced it.
data_hf = datasets.load_from_disk("/XXXX-30/XXXX-29/XXXX-31/proj-shared/language_datasets/processed/rift_v2")

print(data_hf)
# Per-dataset metadata (names and prompt prefixes) lives in the YAML file.
with open('scripts/rift_v2.yaml', 'r', encoding='utf-8') as file:
    ds_metadata = yaml.safe_load(file)

# Map each dataset name to the prompt prefixes declared for it in the YAML.
long_prompt_dict = {}
for entry in ds_metadata['datasets']:
    name = entry['name']
    print(name)
    # A missing query prefix defaults to the empty string.
    query_prefix = entry.get('query_prefix', "")
    # A missing document prefix falls back to whatever the query prefix is.
    document_prefix = entry.get('document_prefix', query_prefix)
    long_prompt_dict[name] = {"query_prefix": query_prefix, "document_prefix": document_prefix}

def combine_to_text(row):
    """Wrap a row's query, documents and negatives in prompt/meta tokens.

    The prefixes are looked up in the module-level ``long_prompt_dict`` by the
    row's ``dataset_name`` and title-cased before use. The same ``<DOC>``
    template is applied to documents and negatives regardless of whether the
    query and document prefixes are equal.

    Args:
        row: Mapping with a string ``query``, iterables of strings under
            ``document`` and ``negative``, and a ``dataset_name`` key.

    Returns:
        The same ``row`` object, modified in place, with ``query`` rewritten
        as a string and ``document`` / ``negative`` rewritten as lists.

    Raises:
        KeyError: If ``dataset_name`` has no entry in ``long_prompt_dict``.
    """
    # Resolve and title-case the prefixes once instead of once per element.
    prefixes = long_prompt_dict[row['dataset_name']]
    query_prefix = prefixes['query_prefix'].title()
    document_prefix = prefixes['document_prefix'].title()

    # Queries carry the instruction first; documents carry it last.
    row["query"] = f"<s> {query_prefix}: <QUERY> {row['query']} <QUERY>".strip()
    row["document"] = [f"<DOC> {doc} <DOC> {document_prefix}:".strip() for doc in row["document"]]
    row["negative"] = [f"<DOC> {neg} <DOC> {document_prefix}:".strip() for neg in row["negative"]]
    return row

# Apply the prompt formatting to every row, using NUM_PROC worker processes.
data_hf = data_hf.map(combine_to_text, num_proc=NUM_PROC)

# Persist the formatted, combined dataset unless this is a dry run.
if not DRY_RUN:
    data_hf.save_to_disk(COMBINED_OUT_PATH, num_proc=NUM_PROC)

# Break the combined dataset apart by source subset using a filter, and record
# each subset's row count so relative weights can be derived from the sizes.
subset_sizes = {}
def filter_and_save(subset):
    """Extract one source subset from ``data_hf``, print a sample, and save it.

    The row count of a non-empty subset is recorded in the module-level
    ``subset_sizes`` table. Unless ``DRY_RUN`` is set, the subset is written
    to ``{SRC_SPLIT_BASE_OUT_PATH}/{subset}/train``.

    A subset with no matching rows is reported and skipped: it is neither
    recorded in ``subset_sizes`` nor saved, so the weights only cover subsets
    that actually contain data.

    Args:
        subset: Value of the ``dataset_name`` column that selects the subset.
    """
    print(f"Processing subset {subset} ... ")
    subset_data = data_hf.filter(lambda x: x['dataset_name'] == subset, num_proc=NUM_PROC, batched=False)
    print(f"Split {subset}\n{subset_data}")

    num_rows = len(subset_data)
    if num_rows == 0:
        # Subset names come from the YAML metadata, so a name may match no
        # rows; indexing row 0 below would raise IndexError in that case.
        print(f"WARNING: subset {subset} has no rows; skipping")
        return

    # Fetch the sample row once rather than once per printed field.
    example = subset_data[0]
    print("################ EXAMPLE ################")
    print(f"Query: {example['query']}")
    print(f"Document: {example['document']}")
    print(f"Negative: {example['negative']}")
    print("########################################")
    subset_sizes[subset] = num_rows
    if not DRY_RUN:
        subset_data.save_to_disk(f'{SRC_SPLIT_BASE_OUT_PATH}/{subset}/train', num_proc=NUM_PROC)

# The subsets to export are the dataset names listed in the YAML metadata.
unique_subsets = [entry['name'] for entry in ds_metadata['datasets']]
print(f"Found splits: {unique_subsets}")

for subset in unique_subsets:
    filter_and_save(subset)

print(subset_sizes)

# Size-normalized weights: each subset's share of the total row count.
total_size = sum(subset_sizes.values())
weights = dict((name, count / total_size) for name, count in subset_sizes.items())

# Serialize once; the same text is printed and, outside dry runs, written out.
weights_json = json.dumps(weights, indent=4)
print(weights_json)

# Write the weights to SRC_WEIGHTS_PATH unless this is a dry run.
if not DRY_RUN:
    with open(SRC_WEIGHTS_PATH, 'w') as f:
        f.write(weights_json)
    print(f"Saved weights to {SRC_WEIGHTS_PATH}")