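'''
Helper functions for running retrosynthesis searches built on the syntheseus
package (Retro*-style search) and on DESP: single-step model and inventory
loading, search setup, saving of search outputs, and route analysis.
'''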
import os
import pickle
import time
import pandas as pd
import torch
from dataclasses import dataclass
from enum import Enum
from typing import Any

from syntheseus.reaction_prediction.inference import *
from syntheseus.search.mol_inventory import SmilesListInventory
from syntheseus.search.graph.and_or import AndNode, OrNode
from syntheseus.search.analysis import diversity
from syntheseus.reaction_prediction.inference.config import ModelConfig
from syntheseus.cli.eval_single_step import BaseEvalConfig

from multiguide.syntheseus.visualize import visualize_andor
from multiguide.syntheseus.single_step_models.root_aligned_fixed import RootAlignedFixedModel
from syntheseus import Molecule
from syntheseus.search.algorithms.best_first import retro_star
from syntheseus.search.node_evaluation.common import ReactionModelLogProbCost
from syntheseus.search.node_evaluation.common import ConstantNodeEvaluator
from syntheseus.search.analysis.route_extraction import (
    iter_routes_time_order,
)
from syntheseus.search.graph.and_or import AndNode
from multiguide.syntheseus.value_node_evaluator import ValueNodeEvaluator
from multiguide.syntheseus.value_node_with_property import ValueNodeWithPropertyEvaluator
from multiguide.helpers import PROJECT_ROOT
from multiguide.syntheseus.retro_star_search import MolIsPurchasableGuidedCost
from multiguide.syntheseus.retro_star_search import RetroStarSearchWithPropertyFilter
from multiguide.syntheseus.single_step_models.neural_sym import NeuralSymPredictor
from multiguide.desp.DESP import DESP
from multiguide.desp.retro_predictor import RetroPredictor

from multiguide.helpers import PROJECT_ROOT, device

# Placeholder default for config fields that must be supplied by the user.
# "???" is the same string omegaconf uses for its MISSING marker; presumably
# these dataclasses are consumed as structured configs (TODO confirm).
MISSING: Any = "???"

class NewBackwardModelClass(Enum):
    """Single-step backward (retrosynthesis) model classes selectable by name.

    Each member's value is the model class itself. The members cover the
    models imported from ``syntheseus.reaction_prediction.inference`` plus the
    project-local ``RootAlignedFixedModel``.
    """

    Chemformer = ChemformerModel
    GLN = GLNModel
    Graph2Edits = Graph2EditsModel
    LocalRetro = LocalRetroModel
    MEGAN = MEGANModel
    MHNreact = MHNreactModel
    RetroKNN = RetroKNNModel
    RootAligned = RootAlignedModel
    # Project-local variant; not part of the syntheseus model set.
    RootAlignedFixed = RootAlignedFixedModel

@dataclass
class BackwardModelConfig(ModelConfig):
    """Config for loading one of the supported backward models."""

    # Which model to load. Defaults to the MISSING sentinel, so it must be set
    # explicitly. NOTE(review): presumably this overrides a field of the same
    # name on syntheseus' ModelConfig so that NewBackwardModelClass (incl.
    # RootAlignedFixed) can be selected -- confirm.
    model_class: NewBackwardModelClass = MISSING

@dataclass
class EvalConfig(BackwardModelConfig, BaseEvalConfig):
    """Config for running evaluation on a given dataset."""

    # No fields of its own: combines the model-selection fields from
    # BackwardModelConfig with the evaluation fields from syntheseus'
    # BaseEvalConfig (dataclass fields are merged following the MRO).
    pass

def flatten_evaluation_results(evaluation_df):
    """Turn one mapping of per-prediction lists into a list of per-prediction rows.

    The number of rows is ``len(evaluation_df['reactant_predictions'])``.
    For every row, a value that is a list with exactly that many entries
    contributes its element at the row's position. Any other value (scalars,
    lists of a different length) is copied unchanged into every row.

    Args:
        evaluation_df: mapping of column name -> value; must contain
            'reactant_predictions'.

    Returns:
        A list of dicts, one per prediction.
    """
    n_rows = len(evaluation_df['reactant_predictions'])

    def _cell(value, position):
        # Only lists aligned with the prediction count are split per row.
        if isinstance(value, list) and len(value) == n_rows:
            return value[position]
        return value

    return [
        {column: _cell(value, position) for column, value in evaluation_df.items()}
        for position in range(n_rows)
    ]

def get_model(config):
    '''
    Build the single-step (backward) reaction model selected in the config.

    Args:
        config: experiment config. Reads config.single_step_model.model_type,
            config.single_step_model.model_dir (a path relative to
            PROJECT_ROOT/checkpoints, or None) and
            config.single_step_model.default_num_results.

    Returns:
        The instantiated model, created with use_cache=True. Every model except
        'neuralsym' also receives model_dir and the module-level `device`.

    Raises:
        ValueError: if model_type is unknown, or if 'neuralsym' is requested
            without a model_dir.
    '''
    ss_cfg = config.single_step_model
    model_type = ss_cfg.model_type
    print(f'======= using {model_type}')
    if ss_cfg.model_dir is not None:
        model_dir = os.path.join(PROJECT_ROOT, 'checkpoints', ss_cfg.model_dir)
    else:
        model_dir = None

    if model_type == 'neuralsym':
        # NeuralSym is configured through `setup` (checkpoint + templates)
        # rather than constructor arguments, and needs an explicit checkpoint.
        if model_dir is None:
            raise ValueError(
                "Model type neuralsym requires config.single_step_model.model_dir to be set"
            )
        templates_path = os.path.join(PROJECT_ROOT,
                                      'data',
                                      'desp_data',
                                      'idx2template_retro.json')
        model = NeuralSymPredictor(use_cache=True,
                                   default_num_results=ss_cfg.default_num_results)
        model.setup(model_dir, templates_path)
        return model

    # All remaining models share the same constructor signature.
    model_classes = {
        'chemformer': ChemformerModel,
        'graph2edits': Graph2EditsModel,
        'localretro': LocalRetroModel,
        'megan': MEGANModel,
        'mhnreact': MHNreactModel,
        'retroknn': RetroKNNModel,
        'gln': GLNModel,
    }
    model_cls = model_classes.get(model_type)
    if model_cls is None:
        raise ValueError(f"Model type {model_type} not found")
    return model_cls(use_cache=True,
                     default_num_results=ss_cfg.default_num_results,
                     model_dir=model_dir,
                     device=device)

def search_desp(config, product_smi, conditional_starting_material=None, conditional_target=None):
    '''
    Run a single DESP search for one target molecule.

    A RetroPredictor is loaded from PROJECT_ROOT/data/desp_data
    (model_retro.pt, idx2template_retro.json) and wrapped in DESP using the
    strategy in config.search.strategy. It runs on CUDA device 0 when
    available, else on the CPU.

    Args:
        config: experiment config; reads config.search.strategy and
            config.search.stop_on_first_solution.
        product_smi: SMILES of the target molecule.
        conditional_starting_material: starting-material SMILES; passed to
            DESP as a one-element list (also when it is None).
        conditional_target: currently unused.

    Returns:
        (result, search_stats): `result` is the tuple returned by DESP.search
        (per the original note: (result, route, searcher)); `search_stats`
        holds num_model_calls, num_nodes_explored, time_taken and
        num_iterations (read from result[0][1]).
    '''
    # TODO: could make this read checkpoint from checkpoints directory
    desp_data_dir = os.path.join(PROJECT_ROOT, 'data', 'desp_data')
    predictor = RetroPredictor(
        model_path=os.path.join(desp_data_dir, 'model_retro.pt'),
        templates_path=os.path.join(desp_data_dir, 'idx2template_retro.json'),
    )
    if torch.cuda.is_available():
        desp_device = 0
    else:
        desp_device = 'cpu'
    planner = DESP(
        strategy=config.search.strategy,
        retro_predictor=predictor,
        device=desp_device,
    )

    t_start = time.time()
    result = planner.search(
        product_smi,                        # target SMILES
        [conditional_starting_material],    # starting-material SMILES list
        stop_on_first_solution=config.search.stop_on_first_solution,
    )
    searcher = result[2]
    search_stats = {
        'num_model_calls': searcher.num_model_calls,
        'num_nodes_explored': searcher.num_nodes_explored,
        'time_taken': time.time() - t_start,
        'num_iterations': result[0][1],
    }
    return result, search_stats

def run_for_one_mol_desp(config, product_smi, product_idx, starting_material_smi=None):
    '''
    Run the DESP search for one molecule and pickle its outputs.

    The route (result[1]) and the search statistics are written to
    PROJECT_ROOT/experiments/search/<search.type>/<experiment_name>/
    strategy_<search.strategy>/graphs_for_mol<product_idx>/ as
    'output_graph.pkl' and 'search_stats.pkl'.

    Args:
        config: experiment config (search and general sections).
        product_smi: SMILES of the target molecule.
        product_idx: index of the target, used to name the output directory.
        starting_material_smi: optional starting-material SMILES for DESP.

    Returns:
        (result, search_stats) as returned by search_desp.
    '''
    result, search_stats = search_desp(config, product_smi, starting_material_smi)
    print(result)
    out_dir = os.path.join(
        PROJECT_ROOT,
        'experiments',
        'search',
        config.search.type,
        config.general.experiment_name,
        f'strategy_{config.search.strategy}',
        f'graphs_for_mol{product_idx}'
    )
    os.makedirs(out_dir, exist_ok=True)
    route = result[1]
    # Context managers guarantee the pickles are flushed and the handles
    # closed, instead of relying on garbage collection.
    with open(os.path.join(out_dir, 'output_graph.pkl'), 'wb') as f:
        pickle.dump(route, f)
    with open(os.path.join(out_dir, 'search_stats.pkl'), 'wb') as f:
        pickle.dump(search_stats, f)
    return result, search_stats

def set_search_algorithm(config, model, target_smi=None):
    '''
    Build the Retro*-style search algorithm used for multi-step planning.

    Args:
        config: experiment config. Reads config.search.* (dummy_inventory,
            remove_target_from_inventory, heuristic, guided and the search
            limits) and, depending on the heuristic,
            config.classifier_guidance or config.value_function.
        model: single-step reaction model used to expand molecules.
        target_smi: SMILES of the search target. When given and
            config.search.remove_target_from_inventory is set, the target is
            removed from the inventory.

    Returns:
        A RetroStarSearchWithPropertyFilter wired with the inventory, cost
        functions and value function.

    Raises:
        ValueError: if the heuristic is unknown, or the config section it
            needs (classifier_guidance / value_function) is not set.
    '''
    print(f'======= using dummy_inventory {config.search.dummy_inventory}')
    if config.search.dummy_inventory:
        # Dummy inventory with just two purchasable molecules.
        inventory = SmilesListInventory(
            smiles_list=["Cc1ccc(B(O)O)cc1", "O=Cc1ccc(I)cc1"]
        )
    else:
        inventory = get_inventory(config)

    # Optionally remove the target from the inventory.
    # NOTE: this mutates the inventory's private `_mol_set` in place.
    if config.search.remove_target_from_inventory and target_smi is not None:
        target_mol = Molecule(target_smi)
        print(f'len(original_inventory): {len(inventory._mol_set)}')
        inventory._mol_set.discard(target_mol)
        print(f'len(inventory._mol_set): {len(inventory._mol_set)}')
        print('removed target from inventory')

    # 1: OrNode cost function.
    # Following the original Retro* paper, purchasable molecules cost 0 and
    # all others cost infinity (retro_star's default). A different cost
    # function is needed if purchasable molecules should have non-zero cost.
    # Alternative kept for experiments: MolIsPurchasableGuidedCost().
    or_node_cost_fn = retro_star.MolIsPurchasableCost()

    # 2: AndNode cost function.
    # Reaction cost is derived from the "probability" value stored in
    # `reaction.metadata` by the single-step model (normalize=False).
    and_node_cost_fn = ReactionModelLogProbCost(normalize=False)

    # 3: search heuristic (value function).
    # 'constant' is retro*-0 (always 0, the most optimistic). Guided search,
    # or 'value_function_with_property', uses the property-aware evaluator.
    # 'value_function' uses a learned value function.
    print(f'======= using value function {config.search.heuristic}')
    print(f'======= using guided {config.search.guided}')
    heuristic = config.search.heuristic
    if heuristic == 'constant':
        retro_star_value_function = ConstantNodeEvaluator(0.0)
    elif config.search.guided or heuristic == 'value_function_with_property':
        # Explicit check instead of `assert` so it is not stripped under -O.
        if not config.classifier_guidance:
            raise ValueError(
                'Classifier guidance must be specified when heuristic is value_function_with_property'
            )
        retro_star_value_function = ValueNodeWithPropertyEvaluator(config)
    elif heuristic == 'value_function':
        if not config.value_function:
            raise ValueError(
                'Value function must be specified when heuristic is value_function'
            )
        retro_star_value_function = ValueNodeEvaluator(config)
    else:
        raise ValueError(f"Heuristic {config.search.heuristic} not found")

    return RetroStarSearchWithPropertyFilter(
        config=config,
        reaction_model=model,
        mol_inventory=inventory,
        or_node_cost_fn=or_node_cost_fn,
        and_node_cost_fn=and_node_cost_fn,
        value_function=retro_star_value_function,
        limit_reaction_model_calls=config.search.limit_reaction_model_calls,
        time_limit_s=config.search.time_limit_s,
        limit_iterations=config.search.limit_iterations,
        stop_on_first_solution=config.search.stop_on_first_solution
    )

def run_for_one_mol_retro_star(
    config,
    search_algorithm,
    smi,
    smi_idx,
    ground_truth_data=None,
    starting_material_smi=None
):
    '''
    Run the retro-star search for a given molecule.

    The target Molecule's metadata is seeded from `ground_truth_data` before
    the search. The graph and stats are then handed to
    save_info_and_extract_routes.

    Args:
        config: config object; reads
            config.classifier_guidance.eval.tanimoto_weight here and passes
            the whole config on to save_info_and_extract_routes.
        search_algorithm: search object (e.g. from set_search_algorithm);
            it is reset before the run.
        smi: SMILES string of the target molecule.
        smi_idx: index of the molecule (names the output directory).
        ground_truth_data: dict with at least the keys 'mol_to_rxn_type',
            'depth_to_ground_truth_reactants' and 'mol_to_tanimoto'.
            Required despite the None default.
        starting_material_smi: optional starting-material SMILES stored in
            the target's metadata.

    Returns:
        (output_graph, search_stats): the search graph and a dict with
        num_nodes_explored, time_taken, num_model_calls and
        num_search_iterations.

    Raises:
        ValueError: if ground_truth_data is None.
    '''
    if ground_truth_data is None:
        raise ValueError('ground_truth_data is required to seed the target molecule metadata')

    test_mol = Molecule(smi)
    depth = 0
    branching_factor = 0
    # Seed the root node's metadata; NOTE(review): presumably read by the
    # property-guided value function / search -- confirm.
    test_mol.metadata['ground_truth_data'] = ground_truth_data
    test_mol.metadata['branching_factor'] = branching_factor
    test_mol.metadata['reaction_type_to_synthesize'] = ground_truth_data['mol_to_rxn_type'][smi]
    test_mol.metadata['sample_prev_reaction_type'] = ''
    test_mol.metadata['true_prev_reaction_type'] = ''
    test_mol.metadata['starting_material'] = starting_material_smi
    test_mol.metadata['depth'] = depth
    test_mol.metadata['ground_truth_reactants'] = ground_truth_data['depth_to_ground_truth_reactants'][depth][branching_factor]
    test_mol.metadata['parent_product'] = ''
    test_mol.metadata['true_tanimoto'] = ground_truth_data['mol_to_tanimoto'][smi]
    test_mol.metadata['sample_tanimoto'] = ground_truth_data['mol_to_tanimoto'][smi]
    test_mol.metadata['reaction_tanimoto'] = 0
    test_mol.metadata['max_tanimoto_property_weight'] = config.classifier_guidance.eval.tanimoto_weight

    search_algorithm.reset()
    start_time = time.time()
    print('======= running search with retro_star')
    output_graph, num_search_iterations = search_algorithm.run_from_mol(test_mol)
    search_stats = {
        'num_nodes_explored': len(output_graph),
        'time_taken': time.time() - start_time,
        'num_model_calls': search_algorithm.reaction_model.num_calls(),
        'num_search_iterations': num_search_iterations,
    }
    print(f"Search stats: {search_stats}")
    save_info_and_extract_routes(config, output_graph, smi_idx, search_stats)
    return output_graph, search_stats
    
def save_info_and_extract_routes(config, output_graph, smi_idx, search_stats):
    '''
    Optionally save the search graph/stats and optionally extract routes.

    Args:
        config: experiment config. Reads config.search.save_output_graph,
            config.search.extract_routes, config.search.max_routes_to_extract
            and the fields used to build the output directory
            (search.type, general.experiment_name, search.strategy).
        output_graph: graph returned by the search.
        smi_idx: index of the target molecule (names the output directory).
        search_stats: dict of search statistics, pickled next to the graph.

    Returns:
        The list of routes from iter_routes_time_order (in the order they
        were found) when config.search.extract_routes is set, otherwise None.
    '''
    if config.search.save_output_graph:
        output_graph_dir = os.path.join(
            PROJECT_ROOT,
            'experiments',
            'search',
            config.search.type,
            config.general.experiment_name,  # experiment subfolder
            f'strategy_{config.search.strategy}',
            f'graphs_for_mol{smi_idx}'
        )
        os.makedirs(output_graph_dir, exist_ok=True)
        print(f'======= saving output graph to {output_graph_dir}')
        with open(os.path.join(output_graph_dir, 'output_graph.pkl'), 'wb') as f:
            pickle.dump(output_graph, f)
        with open(os.path.join(output_graph_dir, 'search_stats.pkl'), 'wb') as f:
            pickle.dump(search_stats, f)

    if not config.search.extract_routes:
        return None

    # Extract the routes simply in the order they were found.
    print('======= extracting routes')
    start_time = time.time()
    routes = list(iter_routes_time_order(output_graph, max_routes=config.search.max_routes_to_extract))
    print(f'Extracted {len(routes)} routes in {time.time() - start_time} seconds')

    for idx, route in enumerate(routes, start=1):
        # Count distinct reaction (AndNode) nodes in the route.
        num_reactions = len({n for n in route if isinstance(n, AndNode)})
        print(f"Route {idx} consists of {num_reactions} reactions")
    return routes

def get_inventory(config):
    '''
    Load the building-block inventory, building and caching it on first use.

    The cache is PROJECT_ROOT/data/desp_data/<config.search.inventory_file>.pkl.
    If it is missing, the 'mol' column of
    PROJECT_ROOT/data/<config.search.inventory_file>.csv is turned into a
    SmilesListInventory, which is then pickled to the cache path.

    Args:
        config: experiment config; reads config.search.inventory_file.

    Returns:
        The inventory object (loaded from the cache or freshly built).
    '''
    print('======= creating inventory')
    inventory_path = os.path.join(PROJECT_ROOT,
                                  'data',
                                  'desp_data',
                                  f'{config.search.inventory_file}.pkl')
    print(f'======= inventory path: {inventory_path}, exists: {os.path.exists(inventory_path)}')
    if os.path.exists(inventory_path):
        # NOTE: pickle.load can execute arbitrary code; only use trusted cache files.
        with open(inventory_path, 'rb') as f:
            inventory = pickle.load(f)
    else:
        print('======= loading bbs')
        start_time = time.time()
        bb_csv_path = os.path.join(PROJECT_ROOT,
                                   'data',
                                   f'{config.search.inventory_file}.csv')
        df = pd.read_csv(bb_csv_path, index_col=0)
        bbs = df['mol'].tolist()
        print(f'loaded bbs in {time.time() - start_time} seconds')
        # TODO: fix this, original version does not have print_every
        # (the upstream SmilesListInventory may not accept this argument).
        inventory = SmilesListInventory(
            smiles_list=bbs,
            print_every=1000000
        )
        # Write the cache to a temporary file and atomically move it into
        # place, so an interrupted run cannot leave a truncated cache that
        # later runs would try to load.
        os.makedirs(os.path.dirname(inventory_path), exist_ok=True)
        tmp_path = inventory_path + '.tmp'
        with open(tmp_path, 'wb') as f:
            pickle.dump(inventory, f)
        os.replace(tmp_path, inventory_path)
    print('======= done loading inventory.')
    return inventory

def get_unique_routes(output_graph, routes):
    '''
    Select a set of mutually distinct routes from a search graph.

    Each route is converted to a synthesis graph, and
    diversity.estimate_packing_number is applied with the reaction-Jaccard
    distance.

    Args:
        output_graph: search graph providing `to_synthesis_graph`.
        routes: iterable of routes (node collections) from that graph.

    Returns:
        Whatever diversity.estimate_packing_number returns for these routes.
    '''
    synthesis_graphs = list(map(output_graph.to_synthesis_graph, routes))
    # radius 0.99 rather than 1.0 because the distance comparison uses ">",
    # not ">=".
    return diversity.estimate_packing_number(
        routes=synthesis_graphs,
        distance_metric=diversity.reaction_jaccard_distance,
        radius=0.99,
    )

def compute_property(route, property_checkpoint_path, property_name, config):
    '''
    Score one property along a route and annotate its nodes.

    For property_name == 'yield', each reaction (AndNode) is scored on its
    '.'-joined reactant SMILES, and molecule nodes get empty data. For any
    other property, each molecule (OrNode) is scored on its SMILES, and
    reaction nodes get empty data. `step` counts the AndNodes seen so far,
    in iteration order. Scored nodes get node.data = {'step', property_name}.

    NOTE(review): get_property_score is neither defined nor explicitly
    imported in this file; confirm where it comes from.

    Returns:
        {'step': [...], property_name: [...]} with one entry per scored node.
    '''
    route_properties = {'step': [],
                        property_name: []}
    scores_reactions = property_name == 'yield'

    def _score_and_record(node, smiles, current_step):
        score = get_property_score(smiles,
                                   path=property_checkpoint_path,
                                   config=config)
        node.data = {'step': current_step,
                     property_name: score}
        route_properties['step'].append(current_step)
        route_properties[property_name].append(score)

    step = 0
    for node in route:
        if isinstance(node, AndNode):
            if scores_reactions:
                reactant_smiles = '.'.join([r.smiles for r in node.reaction.reactants])
                _score_and_record(node, reactant_smiles, step)
            else:
                node.data = {}
            step += 1
        elif isinstance(node, OrNode):
            if scores_reactions:
                node.data = {}
            else:
                # e.g. toxicity: scored per molecule
                _score_and_record(node, node.mol.smiles, step)
    return route_properties

def compute_all_properties(route, toxicity_checkpoint_path, yield_checkpoint_path, config):
    '''
    Score every molecule (OrNode) on a route with the available property models.

    Reaction nodes (AndNode) get empty data and advance the `step` counter.
    Molecule nodes get node.data = {'step', 'toxicity', 'yield'}. A score is
    None when the corresponding checkpoint path is None.

    Args:
        route: iterable of route nodes.
        toxicity_checkpoint_path: checkpoint for the toxicity model, or None.
        yield_checkpoint_path: checkpoint for the yield model, or None.
        config: forwarded to get_property_score.

    Returns:
        dict with lists 'step', 'toxicity' and 'yield' (one entry per
        molecule node). 'synthesis_accessibility' and
        'natural_product_likeness' are kept for compatibility but are not
        computed and stay empty.

    NOTE(review): get_property_score is neither defined nor explicitly
    imported in this file; confirm where it comes from.
    '''
    route_properties = {'step': [],
                        'synthesis_accessibility': [],
                        'natural_product_likeness': [],
                        'toxicity': [],
                        'yield': []}
    step = 0
    for n in list(route):
        if isinstance(n, AndNode):
            n.data = {}
            step += 1
        elif isinstance(n, OrNode):
            # Default to None so a missing checkpoint does not leave the
            # score unbound (previously an UnboundLocalError).
            toxicity_score = None
            yield_score = None
            if toxicity_checkpoint_path is not None:
                toxicity_score = get_property_score(n.mol.smiles,
                                                    path=toxicity_checkpoint_path,
                                                    config=config)
            if yield_checkpoint_path is not None:
                yield_score = get_property_score(n.mol.smiles,
                                                 path=yield_checkpoint_path,
                                                 config=config)
            n.data = {'step': step,
                      'toxicity': toxicity_score,
                      'yield': yield_score}
            route_properties['step'].append(step)
            route_properties['toxicity'].append(toxicity_score)
            route_properties['yield'].append(yield_score)
        # TODO: add yield for and nodes? maybe also forward reaction prediction (NLL under transformer)
    return route_properties

def visualize_routes(config, output_graph, routes, smi_idx):
    '''
    Render each route of a search graph to a PDF with visualize_andor.

    Files are written to
    PROJECT_ROOT/experiments/<general.experiment_name>/graphs_for_mol<smi_idx>/route_<k>.pdf,
    with k starting at 1. The directory is created if needed.
    NOTE(review): this layout omits the 'search/<type>/.../strategy_<...>'
    levels used when saving graphs elsewhere in this module; confirm intent.

    Args:
        config: experiment config; reads config.general.experiment_name.
        output_graph: search graph containing the routes.
        routes: iterable of routes (node collections) to draw.
        smi_idx: index of the target molecule.
    '''
    for idx, route in enumerate(routes, start=1):
        print(f'======= visualizing route {idx}')
        out_dir = os.path.join(PROJECT_ROOT,
                               "experiments",
                               config.general.experiment_name,
                               f'graphs_for_mol{smi_idx}')
        # Nothing else in this function creates the directory, so ensure it
        # exists before writing the PDF.
        os.makedirs(out_dir, exist_ok=True)
        path = os.path.join(out_dir, f"route_{idx}.pdf")
        print(f'======= saving to {path}')
        visualize_andor(
            output_graph, filename=path, nodes=route
        )