from functools import partial
import multiprocessing
import random

import numpy as np
from tqdm import tqdm
import datasets
import os
import yaml
import json

# Subsets whose raw columns are not already named "query"/"document".
# Each entry lists [column used as the query, column used as the document].
dataset_keys_for_exemptions = {
    "reddit_full": ["title", "body"],
    "s2orc_citation_title_full": ["query", "pos"],
    "s2orc_abstract_citation_full": ["query", "pos"],
    "wiki_title_body_full": ["title", "text"],
    "amazon_qa_full": ["query", "pos"],
}
# Names of the subsets listed above, in the same order.
data_with_different_keys = list(dataset_keys_for_exemptions)
# Sub-directories of the raw data root that are processed, in this order.
directories = [
    "amazon_reviews_full",
    "paq_full",
    "s2orc_citation_title_full",
    "s2orc_title_abstract_full",
    "s2orc_abstract_citation_full",
    "s2orc_abstract_body_index_filtered",
    "wikianswers_full",
    "wiki_title_body_full",
    "gooaq_full",
    "codesearch_full",
    "yahoo_title_answer_full",
    "agnews_full",
    "amazon_qa_full",
    "yahoo_qa_full",
    "yahoo_title_question_full",
    "ccnews_full",
    "npr_full",
    "eli5_full",
    "cnn_full",
    "stackexchange_question_question_full",
    "stackexchange_title_body_full",
    "stackexchange_body_body_full",
    "sentence_compression_full",
    "wikihow_full",
    "altlex_full",
    "quora_full",
    "simplewiki_full",
    "squad_full",
    "reddit_full",
]

# (query_prefix, document_prefix) pair for each task type.
_PREFIXES_BY_TASK = {
    "clustering": ("clustering", "clustering"),
    "classification": ("classification", "classification"),
    "search": ("search_query", "search_document"),
}

# Task type assigned to every subset; the order fixes the key order of the
# lookup table built below.
_TASK_BY_DATASET = (
    ("reddit_full", "clustering"),
    ("amazon_reviews_full", "classification"),
    ("paq_full", "search"),
    ("s2orc_citation_title_full", "clustering"),
    ("s2orc_title_abstract_full", "clustering"),
    ("s2orc_abstract_citation_full", "clustering"),
    ("s2orc_abstract_body_index_filtered", "clustering"),
    ("wikianswers_full", "classification"),
    ("wiki_title_body_full", "classification"),
    ("gooaq_full", "search"),
    ("codesearch_full", "clustering"),
    ("yahoo_title_answer_full", "search"),
    ("agnews_full", "classification"),
    ("amazon_qa_full", "search"),
    ("yahoo_qa_full", "search"),
    ("yahoo_title_question_full", "search"),
    ("ccnews_full", "classification"),
    ("npr_full", "classification"),
    ("eli5_full", "search"),
    ("cnn_full", "classification"),
    ("stackexchange_question_question_full", "clustering"),
    ("stackexchange_title_body_full", "clustering"),
    ("stackexchange_body_body_full", "clustering"),
    ("sentence_compression_full", "classification"),
    ("wikihow_full", "classification"),
    ("altlex_full", "classification"),
    ("quora_full", "classification"),
    ("simplewiki_full", "classification"),
    ("squad_full", "search"),
)

# Subset name -> {"query_prefix": ..., "document_prefix": ...}.
# Every subset gets its own dict so entries can be edited independently.
datasets_to_query_prefix_and_document_prefix = {
    name: dict(zip(("query_prefix", "document_prefix"), _PREFIXES_BY_TASK[task]))
    for name, task in _TASK_BY_DATASET
}

# Load the long system-prompt YAML and index its prefixes by dataset name.
with open('scripts/nomic_yaml_long_sys_prompts.yaml', 'r') as file:
    prime_service = yaml.safe_load(file)

long_prompt_dict = {}
for data in prime_service['datasets']:
    # A prefix key that is absent from an entry falls back to the empty string.
    long_prompt_dict[data['name']] = {
        "query_prefix": data.get('query_prefix', ""),
        "document_prefix": data.get('document_prefix', ""),
    }
print(long_prompt_dict)


# Number of worker processes passed as ``num_proc`` to the dataset calls below:
# half of the reported CPUs, but never fewer than one. os.cpu_count() may
# return None when the count cannot be determined, and halving a single CPU
# would give 0, which is not a usable worker count.
NUM_PROC = max(1, (os.cpu_count() or 1) // 2)

# When DRY_RUN is True the per-subset datasets, the weights file, and the
# combined dataset are not written to disk; sizes and weights are still printed.
# DRY_RUN=True
DRY_RUN = False

# Root directory holding one sub-directory of *.jsonl files per subset.
BASE_DIR = "/XXXX-30/XXXX-29/XXXX-31/proj-shared/language_datasets/raw/retrival/contrastive-index-filtered/"
# Output location of the concatenation of all processed subsets.
COMBINED_OUT_PATH = '/XXXX-30/XXXX-29/XXXX-31/proj-shared/language_datasets/processed/nomic_positive_only_combined/train'
# Output root under which each processed subset is saved separately.
SRC_SPLIT_BASE_OUT_PATH = '/XXXX-30/XXXX-29/XXXX-31/proj-shared/language_datasets/processed/nomic_positive_only_src_split'
# JSON file receiving the size-normalized weight of each subset.
SRC_WEIGHTS_PATH = 'scripts/nomic_positive_only_size_normalized_weights.json'

def format_data(row, folder_, include_long_prompt=False):
    """Wrap a row's query and document text in the prompt template.

    The query becomes ``"<s> {Prefix}: <QUERY> {query} <QUERY>"`` and the
    document becomes ``"<DOC> {document} <DOC> {Prefix}:"``, where each prefix
    is looked up for ``folder_`` and title-cased with ``str.title()``.

    Args:
        row: Mapping with "query" and "document" entries; updated in place.
        folder_: Subset name used as the key into
            ``datasets_to_query_prefix_and_document_prefix``.
        include_long_prompt: Accepted for compatibility; this function does
            not read it.

    Returns:
        The same ``row`` object with both fields rewritten.
    """
    # Rows are expected to already expose "query"/"document"; the caller
    # renames differently-named columns before mapping this function.
    prefixes = datasets_to_query_prefix_and_document_prefix[folder_]

    query_tag = prefixes['query_prefix'].title()
    query_text = f"<s> {query_tag}: <QUERY> {row['query']} <QUERY>"
    row["query"] = query_text.strip()

    document_tag = prefixes['document_prefix'].title()
    document_text = f"<DOC> {row['document']} <DOC> {document_tag}:"
    row["document"] = document_text.strip()

    return row

# Load, normalise, format, and (unless DRY_RUN) save every subset, recording
# each subset's row count in ``subset_sizes``.
subset_sizes = {}
all_datasets = []
# tqdm wraps the list itself so it can read the length and show total/ETA.
# NOTE: for a quick smoke test, iterate over a slice of ``directories`` or
# shrink a subset with ``subset_data.select(range(1000))`` before the map.
for subset in tqdm(directories, desc="Processing datasets"):
    print(f"Processing: {subset}")
    subset_data = datasets.load_dataset(
        "json",
        data_files=f"{BASE_DIR}/{subset}/*.jsonl",
        keep_in_memory=False,
        num_proc=NUM_PROC,
    )['train']

    # Rename subset-specific columns so every subset exposes "query" and
    # "document"; a column that already has the target name is left alone.
    if subset in data_with_different_keys:
        query_col = dataset_keys_for_exemptions[subset][0]
        document_col = dataset_keys_for_exemptions[subset][1]
        if query_col != "query":
            subset_data = subset_data.rename_column(query_col, "query")
        if document_col != "document":
            subset_data = subset_data.rename_column(document_col, "document")

    subset_data = subset_data.map(partial(format_data, folder_=subset), num_proc=NUM_PROC)

    # Keep only the two formatted text columns.
    extra_columns = [col for col in subset_data.column_names if col not in ("query", "document")]
    subset_data = subset_data.remove_columns(extra_columns)

    print(f"Split {subset}\n{subset_data}")
    # Indexing row 0 of an empty subset would raise, so only show an example
    # when there is at least one row.
    if len(subset_data) > 0:
        print("################ EXAMPLE ################")
        print(f"Query: {subset_data[0]['query']}")
        print(f"Document: {subset_data[0]['document']}")
        print("########################################")
    else:
        print(f"Split {subset} is empty; no example to show.")

    subset_sizes[subset] = len(subset_data)
    if not DRY_RUN:
        subset_data.save_to_disk(f"{SRC_SPLIT_BASE_OUT_PATH}/{subset}/train", num_proc=NUM_PROC)
        all_datasets.append(subset_data)

print(subset_sizes)

# Each subset's weight is its share of the total number of rows.
total_size = sum(subset_sizes.values())
weights = {name: count / total_size for name, count in subset_sizes.items()}

weights_json = json.dumps(weights, indent=4)
print(weights_json)

if not DRY_RUN:
    # Write the weights file first, then the concatenation of all subsets.
    with open(SRC_WEIGHTS_PATH, 'w') as f:
        json.dump(weights, f, indent=4)
    print(f"Saved weights to {SRC_WEIGHTS_PATH}")

    combined = datasets.concatenate_datasets(all_datasets)
    combined.save_to_disk(COMBINED_OUT_PATH, num_proc=NUM_PROC)

# Disabled sanity check: reload the saved weights and compare them with the launch config.
# with open('scripts/nomic_positive_only_size_normalized_weights.json', 'r') as file:
#     weights = json.load(file)
# with open('/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/launch_configs/retrieval/nomic_positive_only_src_separated_norm_weighted.json', 'r') as file:
#     config = json.load(file)
# import math
# precision = 1e-4  # relative tolerance passed to math.isclose
# for i, (key, value) in enumerate(weights.items()):
#     if config['train_data'][i]['data_dir'].split('/')[-2] != key:
#         print(f"Error: {config['train_data'][i]['data_dir'].split('/')[-2]} != {key}")
#     if not math.isclose(config['train_data'][i]['scheduler'][0][1][-1], value*100, rel_tol=precision):
#         print(f"Error: {config['train_data'][i]['scheduler'][0][1][-1]} != {value*100}\data_dir: {key}")
# #     print(data)
# #     print(type(data))
# #     print(data.keys())

