""" A script to create a subset of the full pcqm4m data set for a specific functional property """

import argparse
from collections import defaultdict
import json
import os
import random
random.seed(1234)

import numpy as np
from tqdm import tqdm

# Supported values for the ``--property`` option. For each property:
#   key      -> the entry in a graph's ``properties`` dict whose value is summarised
#   text_idx -> which element of the graph's ``text`` list is kept in the subset
FILTERING_PARAMETERS = dict(
    valency=dict(key='a number of valence electrons', text_idx=0),
    qed=dict(key='a weighted quantitative estimation of drug-likeness', text_idx=1),
)

def save_pcqm4m_dataset_subset(
    *,
    dataset_path: str,
    dataset_size: int,
    graph_cutoff_size: int,
    property_name: str,
    test_size: int = 1000,
    val_size: int = 1000,
    seed: int = 1234,
) -> str:
    """Save a reproducible test/val/train subset of a directory of graph JSON files.

    Every ``*.json`` file directly inside ``dataset_path`` is loaded. Graphs whose
    ``nodes`` list is longer than ``graph_cutoff_size`` are skipped. Percentiles,
    mean and standard deviation of the selected property over the kept graphs are
    printed.

    The kept file names are sorted and then shuffled with a fixed seed. They are
    split into ``test`` (first ``test_size``), ``val`` (next ``val_size``) and
    ``train`` (up to ``dataset_size`` after that). Each selected graph is written
    to a new sibling directory named
    ``{dataset_name}_{property_name}_{graph_cutoff_size}_{dataset_size}``. In the
    written copy, ``text`` is replaced by the single entry for ``property_name``.

    Args:
        dataset_path: Directory holding the source graph ``.json`` files.
        dataset_size: Maximum number of graphs in the ``train`` split.
        graph_cutoff_size: Graphs with more nodes than this are excluded.
        property_name: A key of ``FILTERING_PARAMETERS``.
        test_size: Number of graphs in the ``test`` split (default 1000).
        val_size: Number of graphs in the ``val`` split (default 1000).
        seed: Seed of the private RNG used for the shuffle (default 1234).

    Returns:
        The path of the newly created dataset directory.

    Raises:
        ValueError: ``property_name`` is unknown, or no graph passes the cutoff.
        FileExistsError: The output directory already exists.
    """
    # Raise instead of ``assert``: assertions are stripped under ``python -O``.
    if property_name not in FILTERING_PARAMETERS:
        raise ValueError(
            f"Unknown property {property_name!r}; "
            f"expected one of {sorted(FILTERING_PARAMETERS)}"
        )
    property_key = FILTERING_PARAMETERS[property_name]['key']
    text_idx = FILTERING_PARAMETERS[property_name]['text_idx']

    # Resolve to an absolute path so that relative inputs ("data/pcqm4m") and
    # trailing slashes both work. The previous manual split always prefixed "/",
    # which turned relative paths into wrong absolute ones.
    source_dir = os.path.abspath(dataset_path)
    dataset_name = os.path.basename(source_dir)
    new_dataset_name = f"{dataset_name}_{property_name}_{graph_cutoff_size}_{dataset_size}"
    new_dataset_path = os.path.join(os.path.dirname(source_dir), new_dataset_name)
    # Fail fast, before the potentially long scan, rather than clobbering a
    # previous run's output.
    if os.path.exists(new_dataset_path):
        raise FileExistsError(f"Output directory already exists: {new_dataset_path}")

    file_names = []
    property_values = []
    for file_name in tqdm(os.listdir(source_dir)):
        if not file_name.endswith(".json"):
            continue
        with open(os.path.join(source_dir, file_name), 'r', encoding='utf-8') as data_json:
            graph = json.load(data_json)
        if len(graph['nodes']) > graph_cutoff_size:
            continue
        property_values.append(graph['properties'][property_key])
        file_names.append(file_name)

    if not file_names:
        # Otherwise np.percentile fails with an opaque error on an empty list,
        # and an empty output directory would be left behind.
        raise ValueError(
            f"No graphs with at most {graph_cutoff_size} nodes found in {source_dir}"
        )

    quantiles = [0, 25, 50, 75, 100]
    print(f"Percentiles at {quantiles} - {np.percentile(property_values, quantiles)}")
    print(f"Mean - {np.mean(property_values)}")
    print(f"Standard Deviation - {np.std(property_values)}")

    requested = test_size + val_size + dataset_size
    if len(file_names) < requested:
        print(
            f"Warning: only {len(file_names)} graphs available but {requested} "
            "requested; later splits will be smaller than asked for"
        )

    # Sort first so the result does not depend on os.listdir order. Shuffle with
    # a private RNG so the split is reproducible regardless of any other use of
    # the global ``random`` state in this process.
    file_names.sort()
    random.Random(seed).shuffle(file_names)
    train_start = test_size + val_size
    split_file_names = {
        'test': file_names[:test_size],
        'val': file_names[test_size:train_start],
        'train': file_names[train_start:train_start + dataset_size],
    }

    os.mkdir(new_dataset_path)
    for split_name, split_files in split_file_names.items():
        for index, file_name in enumerate(tqdm(split_files)):
            new_file_name = f"{new_dataset_name}_{split_name}_{index}.json"
            new_file_path = os.path.join(new_dataset_path, new_file_name)
            with open(os.path.join(source_dir, file_name), 'r', encoding='utf-8') as data_json:
                graph = json.load(data_json)
            # Keep only the text entry that corresponds to the chosen property.
            graph['text'] = graph['text'][text_idx]
            with open(new_file_path, 'w', encoding='utf-8') as split_file:
                json.dump(graph, split_file)

    print(f"Created new data set at {new_dataset_path}")
    for split_name, split_files in split_file_names.items():
        print(f"{split_name} split contains {len(split_files)} graphs")
    return new_dataset_path


if __name__ == '__main__':
    # Command-line entry point: parse options and build the subset.
    parser = argparse.ArgumentParser(
        description='Subsample the pcqm4m data set in a repeatable way'
    )
    parser.add_argument(
        '--dataset-path', type=str, required=True,
        help='directory containing the source graph .json files',
    )
    parser.add_argument(
        '--dataset-size', type=int, required=True,
        help='maximum number of graphs in the train split',
    )
    parser.add_argument(
        '--graph-cutoff-size', type=int, required=True,
        help='graphs with more nodes than this are skipped',
    )
    # Validate here so a typo yields a clear CLI error instead of a traceback.
    parser.add_argument(
        '--property', type=str, required=True,
        choices=sorted(FILTERING_PARAMETERS),
        help='property used for statistics and text selection',
    )
    args = parser.parse_args()
    save_pcqm4m_dataset_subset(
        dataset_path=args.dataset_path,
        dataset_size=args.dataset_size,
        graph_cutoff_size=args.graph_cutoff_size,
        property_name=args.property
    )
