import pickle
import pandas as pd
import logging
from mlp_retrosyn.mlp_inference import MLPModel
from alg import molstar

def prepare_starting_molecules(filename):
    """Load the collection of starting molecules from a CSV or pickle file.

    Args:
        filename: Path string. The loader is chosen from the last three
            characters: ``csv`` is read with pandas and must contain a
            ``mol`` column; ``pkl`` is unpickled as-is.

    Returns:
        For CSV input, a ``set`` built from the ``mol`` column. For pickle
        input, whatever object the file contains (presumably a set of
        molecule strings -- this is not verified here).

    Raises:
        ValueError: If ``filename`` ends in neither ``csv`` nor ``pkl``.
    """
    logging.info('Loading starting molecules from %s', filename)

    if filename.endswith('csv'):
        starting_mols = set(pd.read_csv(filename)['mol'])
    elif filename.endswith('pkl'):
        # NOTE(security): pickle.load can execute arbitrary code embedded in
        # the file; only point this at trusted, locally produced dumps.
        with open(filename, 'rb') as f:
            starting_mols = pickle.load(f)
    else:
        # An explicit raise (rather than assert) keeps the check alive under
        # ``python -O``, where asserts are stripped and an unknown file
        # would otherwise be fed straight to pickle.
        raise ValueError(
            "Unsupported starting-molecule file %r: expected a name ending "
            "in 'csv' or 'pkl'" % (filename,)
        )

    logging.info('%d starting molecules loaded', len(starting_mols))
    return starting_mols

def prepare_mlp(templates, model_dump, device, polymer=False, use_load_model=False):
    """Construct the one-step ``MLPModel`` and log where it is loaded from.

    Args:
        templates: Template source; logged and passed to ``MLPModel`` as its
            second positional argument.
        model_dump: Trained model location; logged and passed to
            ``MLPModel`` as its first positional argument.
        device: Forwarded to ``MLPModel`` as ``device``.
        polymer: Forwarded to ``MLPModel`` as ``polymer``.
        use_load_model: Forwarded to ``MLPModel`` as ``use_load_model``.

    Returns:
        The constructed ``MLPModel`` instance.
    """
    logging.info('Templates: %s' % templates)
    logging.info('Loading trained mlp model from %s' % model_dump)

    # Options passed through unchanged to the model constructor.
    model_options = {
        'device': device,
        'polymer': polymer,
        'use_load_model': use_load_model,
    }
    return MLPModel(model_dump, templates, **model_options)

def prepare_test_polymers(filename):
    """Load test polymers from a comma-separated text file.

    The first line is treated as a header and skipped. Every following
    non-blank line must hold exactly five comma-separated fields; the first
    four are collected and the fifth is ignored. Blank lines (for example a
    trailing empty line at the end of the file) are skipped.

    Args:
        filename: Path of the file to read.

    Returns:
        A 4-tuple of lists ``(repeat_units, monomers, ring_units,
        double_units)``, each with one entry per data row, in file order.

    Raises:
        ValueError: If a non-blank data line does not have exactly five
            comma-separated fields.
    """
    logging.info('Loading test polymers from %s', filename)
    repeat_units = []
    monomers = []
    ring_units = []
    double_units = []
    with open(filename, 'r') as f:
        next(f, None)  # skip the header row (no-op on an empty file)
        # Data starts on physical line 2, so number from there for errors.
        for line_no, line in enumerate(f, start=2):
            if not line.strip():
                continue
            fields = line.split(',')
            if len(fields) != 5:
                raise ValueError(
                    '%s:%d: expected 5 comma-separated fields, got %d'
                    % (filename, line_no, len(fields))
                )
            # The fifth field (which carries the line ending) is unused.
            repeat_unit, monomer, ring_unit, double_unit, _ = fields

            repeat_units.append(repeat_unit)
            monomers.append(monomer)
            ring_units.append(ring_unit)
            double_units.append(double_unit)

    logging.info('%d polymers loaded', len(repeat_units))

    return repeat_units, monomers, ring_units, double_units

def prepare_molstar_planner(one_step, value_fn, use_gln, starting_mols,
                            expansion_beam, expansion_topk, iterations,
                            viz=False, viz_dir=None):
    """Build a planning callable that runs ``molstar`` with fixed settings.

    Args:
        one_step: Object whose ``run`` method performs a single expansion.
        value_fn: Forwarded to ``molstar`` as ``value_fn``.
        use_gln: When truthy, ``one_step.run`` is additionally given
            ``beam_size=expansion_beam``; otherwise only ``topk`` is passed.
        starting_mols: Forwarded to ``molstar`` as ``starting_mols``.
        expansion_beam: Beam size used only on the ``use_gln`` path.
        expansion_topk: Passed to ``one_step.run`` as ``topk``.
        iterations: Forwarded to ``molstar`` as ``iterations``.
        viz: Forwarded to ``molstar`` as ``viz``.
        viz_dir: Forwarded to ``molstar`` as ``viz_dir``.

    Returns:
        A callable ``plan_handle(x, y)`` that returns
        ``molstar(target_mol=x, target_mol_id=y, ...)`` with the settings
        above bound in.
    """
    # Keyword arguments for one_step.run, chosen once up front.
    if use_gln:
        run_options = {'beam_size': expansion_beam, 'topk': expansion_topk}
    else:
        run_options = {'topk': expansion_topk}

    def expansion_handle(x):
        return one_step.run(x, **run_options)

    # Everything molstar needs besides the per-call target and its id.
    fixed_options = {
        'starting_mols': starting_mols,
        'expand_fn': expansion_handle,
        'value_fn': value_fn,
        'iterations': iterations,
        'viz': viz,
        'viz_dir': viz_dir,
    }

    def plan_handle(x, y):
        return molstar(target_mol=x, target_mol_id=y, **fixed_options)

    return plan_handle
