from typing import List
import copy
import math
import rdkit.Chem.AllChem as Chem

import numpy as np
import torch
from tqdm import tqdm

from guacamol.distribution_matching_generator import DistributionMatchingGenerator

from guacamol_evaluation.load_model import load_ldm_model
from guacamol_evaluation.sample_smiles import sample_3d_molecules, smiles_from_3d_molecules, smiles_from_3d_molecules_with_edge_model
from guacamol_evaluation.ldm_sampling import sample_from_ldm
from qm9.data.molecule_class import Molecule, MoleculeBatch
from qm9.analyze_joint_training import is_valid
from synthetic_coordinates.rdkit_helpers import smiles_to_mol


class LDMSmilesScaffolder(DistributionMatchingGenerator):
    """
    Generator that samples SMILES strings by completing a fixed scaffold with
    the model returned by `load_ldm_model`.

    Generated SMILES are cached on the instance (`self.generated_smiles`) and
    served from that cache on subsequent `generate` calls before new
    molecules are sampled.
    """

    def __init__(self, exp_folder: str, batch_size: int,
                 updated_nodes_dist, scaffold_smiles, diffusion_steps, jump_len, jump_n_sample,
                 total_number_samples: int = 10000, use_cached_3d_mols=False, ckpt_prefix='') -> None:
        """
        Args:
            exp_folder (str): path to folder containing generator model checkpoints. e.g. outputs/edm_zinc/
            batch_size (int): batch size used to sample from the model
            updated_nodes_dist: node-count distribution that replaces the one returned by
                `load_ldm_model`; must expose `sample(n)`. Intended to be restricted to the
                subset of molecules containing the scaffold.
            scaffold_smiles: SMILES string of the scaffold that every sample has to complete
            diffusion_steps: value assigned to `model.T`
            jump_len: forwarded to `model.complete_scaffold`
            jump_n_sample: forwarded to `model.complete_scaffold`
            total_number_samples (int): calling `generate` with exactly this number of samples
                rewinds the cache read position to the beginning
            use_cached_3d_mols (bool): if True, pre-fill the SMILES cache from
                'guacamol_evaluation/eval_zinc_joint_training.npz'. This needs an `edge_model`
                attribute, which this class does not set itself.
            ckpt_prefix (str): forwarded to `load_ldm_model`

        Raises:
            AttributeError: if `use_cached_3d_mols` is True and no `edge_model` attribute is available.
        """
        self.model, utilities_dict = load_ldm_model(exp_folder, ckpt_prefix=ckpt_prefix, mol_optimizer=True)
        # NOTE: positional unpacking relies on the insertion order of the dict returned by load_ldm_model
        (self.nodes_dist, self.args, self.device, self.dataset_info,
         _property_norms, _conditioning, self.prop_dist) = utilities_dict.values()
        self.dtype = torch.float32
        self.model.set_device_and_dtype(self.device, self.dtype)
        # update diffusion steps
        self.model.T = diffusion_steps

        # update nodes_dist on just subset with the scaffold
        self.nodes_dist = updated_nodes_dist

        # set up scaffold
        self.scaffold_smiles = scaffold_smiles
        self.scaffold_mol = Molecule(self.device, self.dtype, smiles=scaffold_smiles)
        scaffold_graph_single = self.scaffold_mol.get_graph()
        self.scaffold_n_atoms = scaffold_graph_single['num_atoms'].cpu()

        # set attributes
        self.jump_len = jump_len
        self.jump_n_sample = jump_n_sample

        self.batch_size = batch_size

        self.total_number_samples = total_number_samples
        self.generated_3d_mols = {}
        self.generated_smiles = []
        self.n_smiles_used = 0

        if use_cached_3d_mols:
            # Nothing in this class assigns `edge_model`; fail early with an explicit
            # message instead of an opaque AttributeError after the cache file is loaded.
            # getattr keeps working for subclasses that provide the attribute themselves.
            edge_model = getattr(self, 'edge_model', None)
            if edge_model is None:
                raise AttributeError(
                    "use_cached_3d_mols=True requires an `edge_model` attribute, "
                    "but LDMSmilesScaffolder does not load one."
                )
            with np.load('guacamol_evaluation/eval_zinc_joint_training.npz') as f:
                cached_3d_mols = {key: torch.from_numpy(val) for key, val in f.items()}
            print('Building smiles from cached 3D molecules')
            smiles, _, _ = smiles_from_3d_molecules_with_edge_model(edge_model, cached_3d_mols, min(self.batch_size, 100), self.dataset_info, return_mols=False)
            print('Done building smiles!')
            self.generated_smiles = smiles
            print(f'Could generate: {len(self.generated_smiles)} smiles')

    def generate(self, number_samples: int, return_mols: bool = False) -> List[str]:
        """
        Return `number_samples` SMILES, serving cached ones first and sampling the rest.

        Args:
            number_samples (int): number of SMILES requested
            return_mols (bool): accepted for interface compatibility; currently ignored

        Returns:
            List[str]: cached SMILES followed by newly sampled ones. Newly sampled
            SMILES are also appended to `self.generated_smiles`.
        """
        # a request for the full sample budget restarts reading the cache from the beginning
        if number_samples == self.total_number_samples:
            self.n_smiles_used = 0

        start = self.n_smiles_used
        cached_smiles = self.generated_smiles[start:start + number_samples]
        self.n_smiles_used += number_samples

        number_samples_to_generate = number_samples - len(cached_smiles)
        if number_samples_to_generate <= 0:
            return cached_smiles

        # single-molecule requests are handled inside sample_with_scaffold
        smiles = self.sample_with_scaffold(number_samples_to_generate)

        cached_smiles.extend(smiles)
        self.generated_smiles.extend(smiles)

        return cached_smiles

    def sample_with_scaffold(self, n_samples):
        """
        Sample `n_samples` scaffold completions in batches and convert them to SMILES.

        Args:
            n_samples: number of molecules to sample; values <= 0 yield an empty list

        Returns:
            list with the output of `model.get_smiles_from_x_h` for all batches, concatenated
        """
        if n_samples <= 0:
            return []

        self.model.eval()
        batch_size = min(self.batch_size, n_samples)

        n_batches = math.ceil(n_samples / batch_size)  # account for remainder
        remainder = n_samples % batch_size
        generated_smiles = []
        for i in tqdm(range(n_batches)):
            is_partial_last_batch = i == n_batches - 1 and remainder != 0
            n_mols = remainder if is_partial_last_batch else batch_size

            # The model breaks when asked to generate a single molecule, so sample two
            # and keep only the first. Doing this per batch also covers a remainder
            # batch of size 1 (e.g. 101 samples with batch size 100).
            n_to_sample = max(n_mols, 2)

            nodesxsample = self.nodes_dist.sample(n_to_sample)
            # construct a batch out of the molecule to sample different completions
            scaffold_graph = MoleculeBatch(batch_size=n_to_sample, molecule=self.scaffold_mol).graph

            x, h, node_mask = self.model.complete_scaffold(scaffold_graph, fix_noise=False,
                                    add_nodes=nodesxsample - self.scaffold_n_atoms, resampling_times=10,
                                    use_jumps=True, jump_len=self.jump_len, jump_n_sample=self.jump_n_sample,
                                    dataset_info=self.dataset_info)

            smiles = self.model.get_smiles_from_x_h(x, h, node_mask, self.dataset_info)
            if n_to_sample != n_mols:
                # drop the extra molecule that was only sampled as a workaround
                smiles = smiles[:n_mols]
            generated_smiles.extend(smiles)

            # NOTE: every valid generated SMILES is expected to contain the scaffold by
            # construction; this is not verified here. A substructure check via
            # smiles_to_mol(..., only_explicit_H=True) and HasSubstructMatch can be
            # added when debugging.

        return generated_smiles

