#!/usr/bin/env python

import heapq
import math
from itertools import product

import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem
from tqdm.auto import tqdm

from ts_main import read_input, parse_input_dict
from ts_utils import read_reagents, get_pareto_indices


def keep_largest(items, n):
    """Keep the n largest items of an iterable, ranked by their first element.

    Designed to work with a list of [score, SMILES] entries.

    :param items: iterable of indexable items whose first element is the score
    :param n: the maximum number of items to keep; values <= 0 keep nothing
    :return: list with at most n of the largest items. The list is arranged as a
        min-heap (smallest retained item at index 0), it is NOT sorted.
    """
    # Without this guard the loop below would read heap[0] on an empty heap
    # and raise IndexError when n is zero or negative.
    if n <= 0:
        return []
    heap = []
    for item in items:
        if len(heap) < n:
            # Still filling up: every item is accepted.
            heapq.heappush(heap, item)
        elif item[0] > heap[0][0]:
            # The new score beats the smallest retained score, so evict that one.
            heapq.heapreplace(heap, item)
    return heap


def keep_pareto_optimal(listlist, _):
    """Keep only the entries whose scores are non-dominated (Pareto-optimal).

    :param listlist: list of entries; the scores of each entry are read from index 2
    :param _: ignored; accepted so the function can be called with the same
        (items, n) arguments as keep_largest
    :return: list of the non-dominated entries, in their original relative order.
        An empty input yields an empty list.
    """
    # np.stack raises ValueError on an empty sequence, so there is nothing to do here.
    if len(listlist) == 0:
        return []
    # Extract scores into a NumPy array for efficient computation
    scores = np.stack([x[2] for x in listlist])
    # Get indices of non-dominated points
    pareto_indices = get_pareto_indices(scores)
    # Create a boolean mask for non-dominated elements
    non_dominated = np.zeros(len(listlist), dtype=bool)
    non_dominated[pareto_indices] = True
    # Return only the non-dominated elements
    return [elem for elem, nd in zip(listlist, non_dominated) if nd]


def unpack_input_dict(input_dict, num_to_select=None):
    """ Pull the evaluator, the reaction and the reagent lists out of the input dictionary
    :param input_dict: parameters from the input JSON file. If it has no "evaluator_class"
    entry yet, it is first handed to parse_input_dict.
    :param num_to_select: forwarded unchanged to read_reagents
    :return: tuple of (evaluator, RDKit reaction, list of reagent lists)
    """
    needs_parsing = input_dict.get("evaluator_class") is None
    if needs_parsing:
        # presumably parse_input_dict adds "evaluator_class" to the dict in place -- confirm in ts_main
        parse_input_dict(input_dict)
    evaluator = input_dict["evaluator_class"]
    smarts, reagent_files = input_dict["reaction_smarts"], input_dict["reagent_file_list"]
    reaction = AllChem.ReactionFromSmarts(smarts)
    reagents_per_component = read_reagents(reagent_files, num_to_select, evaluator.num_objs)
    return evaluator, reaction, reagents_per_component


def enumerate_library(json_filename, outfile_name, num_to_select):
    """ Enumerate every reagent combination and write the product SMILES and names to a csv file
    :param json_filename: JSON file with configuration options
    :param outfile_name: name of the output csv file (columns "SMILES" and "Name")
    :param num_to_select: passed on to setup_baseline
    """
    # The evaluator is not needed here, products are only enumerated, not scored
    _, rxn, reagent_lists = setup_baseline(json_filename, num_to_select)
    total_prods = math.prod(len(component) for component in reagent_lists)
    print(f"{total_prods:.2e} products")
    rows = []
    for combo in tqdm(product(*reagent_lists), total=total_prods):
        outcomes = rxn.RunReactants([reagent.mol for reagent in combo])
        if len(outcomes) == 0:
            # The reaction did not produce anything for this combination
            continue
        # Only the first product of the first outcome is kept
        mol = outcomes[0][0]
        Chem.SanitizeMol(mol)
        smiles = Chem.MolToSmiles(mol)
        name = "_".join(reagent.reagent_name for reagent in combo)
        rows.append([smiles, name])
    pd.DataFrame(rows, columns=["SMILES", "Name"]).to_csv(outfile_name, index=False)


def setup_baseline(json_filename, num_to_select=None):
    """ Shared setup for the baseline methods: read the JSON input and build the required objects
    :param json_filename: JSON file with configuration options
    :param num_to_select: number of reagents to use with exhaustive search. Set to a lower value
    for development. Setting to None uses all reagents.
    :return: evaluator class, RDKit reaction, list of lists with reagents
    """
    return unpack_input_dict(read_input(json_filename), num_to_select=num_to_select)


def random_baseline_general(
    input_dict, num_trials,
    outfile_name="scores.csv",
    num_to_save=100,
    ascending_output=False,
    filter_func=None,
    score_colname="score"
):
    """ Randomly combine reagents to create products and evaluate them

    For every trial one reagent is drawn at random from each reagent list, the reaction is
    run, and the first product (if any) is scored. All scored products are written to a csv
    file with the columns "SMILES", "Name" and the score column.

    :param input_dict: parameters from the input JSON file
    :param num_trials: number of random products to create
    :param outfile_name: name of the output file
    :param num_to_save: number of molecules to save to the output csv file.
    NOTE: currently not applied, the filtering step is disabled and all products are written.
    :param ascending_output: sort ascending if True
    :param filter_func: filtering function to apply, keep_largest or keep_pareto_optimal.
    NOTE: currently not applied, see num_to_save.
    :param score_colname: name of the score column in the output file
    """
    evaluator, rxn, reagent_lists = unpack_input_dict(input_dict)
    len_list = [len(x) for x in reagent_lists]
    # math.prod works on Python ints, so a very large library size cannot overflow
    total_prods = math.prod(len_list)
    print(f"{total_prods:.2e} products")
    score_list = []
    for _ in tqdm(range(num_trials)):
        # The upper bound of np.random.randint is exclusive: passing the list length
        # (not length - 1) lets the last reagent be drawn and supports one-element lists.
        reagents = [
            reagent_list[np.random.randint(0, num_choices)]
            for reagent_list, num_choices in zip(reagent_lists, len_list)
        ]
        prod = rxn.RunReactants([x.mol for x in reagents])
        if len(prod) == 0:
            # The reaction gave no product for this combination, nothing to score
            continue
        product_mol = prod[0][0]
        Chem.SanitizeMol(product_mol)
        score = evaluator.evaluate(product_mol)
        product_smiles = Chem.MolToSmiles(product_mol)
        product_name = "_".join(x.reagent_name for x in reagents)
        score_list.append([product_smiles, product_name, score])
    # Format DataFrame
    score_df = pd.DataFrame(score_list, columns=["SMILES", "Name", score_colname]).round(3)
    # Sort only if a single objective is considered, i.e. the score column is numeric.
    # Checking the dtype (instead of the first element) also works for an empty frame
    # and for NumPy integer scores.
    if pd.api.types.is_numeric_dtype(score_df[score_colname]):
        # sort_values returns a new DataFrame, so the result has to be kept
        score_df = score_df.sort_values(by=score_colname, ascending=ascending_output)
    score_df.to_csv(outfile_name, index=False)


def exhaustive_baseline(input_dict, num_to_select=None, num_to_save=100, invert_score=False):
    """ Score every reagent combination and write the results to "exhaustive_scores.csv"
    :param input_dict: parameters from the input JSON file
    :param num_to_select: Number of reagents to use, set to a lower number for development.  Set to None to use all.
    :param num_to_save: number of molecules to save to the output csv file.
    NOTE: currently unused, every scored product is written.
    :param invert_score: set to True when more negative values are better; scores are then
    multiplied by -1 before they are stored
    """
    evaluator, rxn, reagent_lists = unpack_input_dict(input_dict, num_to_select)
    total_prods = math.prod(len(component) for component in reagent_lists)
    print(f"{total_prods:.2e} products")
    rows = []
    for combo in tqdm(product(*reagent_lists), total=total_prods):
        outcomes = rxn.RunReactants([reagent.mol for reagent in combo])
        if len(outcomes) == 0:
            # No product for this combination
            continue
        # Only the first product of the first outcome is scored
        mol = outcomes[0][0]
        Chem.SanitizeMol(mol)
        smiles = Chem.MolToSmiles(mol)
        name = "_".join(reagent.reagent_name for reagent in combo)
        score = evaluator.evaluate(mol)
        if invert_score:
            score = score * -1.0
        rows.append([score, smiles, name])
    score_df = pd.DataFrame(rows, columns=["score", "SMILES", "Name"])
    # Best (largest) score first
    ranked_df = score_df.sort_values(by="score", ascending=False)
    ranked_df.to_csv("exhaustive_scores.csv", index=False)


def exhaustively_enumerate_lib(input_dict, output_filename="exhaustively_enumerated_scores.csv"):
    """ Score every reagent combination, appending each result to a csv file as it is computed
    :param input_dict: parameters from the input JSON file
    :param output_filename: name of the output file. It is (over)written when the first
    product is scored; later products are appended without a header.
    """
    evaluator, rxn, reagent_lists = unpack_input_dict(input_dict, None)
    total_prods = math.prod(len(component) for component in reagent_lists)
    print(f"{total_prods:.2e} products")
    header_written = False
    for combo in tqdm(product(*reagent_lists), total=total_prods):
        outcomes = rxn.RunReactants([reagent.mol for reagent in combo])
        if len(outcomes) == 0:
            # No product for this combination, nothing to write
            continue
        # Only the first product of the first outcome is scored
        mol = outcomes[0][0]
        Chem.SanitizeMol(mol)
        smiles = Chem.MolToSmiles(mol)
        name = "_".join(reagent.reagent_name for reagent in combo)
        score = evaluator.evaluate(mol)
        row_df = pd.DataFrame([[smiles, name, score]], columns=["SMILES", "Name", "score"])
        if header_written:
            row_df.to_csv(output_filename, index=False, mode="a", header=False)
        else:
            # First row: start a fresh file that includes the header line
            row_df.to_csv(output_filename, index=False, mode="w")
            header_written = True


def main():
    """ Run the random baseline on the quinazoline fingerprint-similarity example """
    # Only referenced by the exhaustive baseline call, which is disabled below
    num_to_select = -1
    config = read_input("examples/quinazoline_fp_sim.json")
    # exhaustive_baseline(config, num_to_select=num_to_select)
    # enumerate 50K random molecules
    trial_count = 50000
    random_baseline_general(config, num_trials=trial_count, num_to_save=trial_count)


# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
