from typing import List
import copy

import numpy as np
import torch

from guacamol.distribution_matching_generator import DistributionMatchingGenerator

from guacamol_evaluation.load_model import load_model
from guacamol_evaluation.sample_smiles import sample_3d_molecules, smiles_from_3d_molecules, smiles_from_3d_molecules_with_edge_model
from bond_type_prediction.initialize_pp_model import load_trained_edge_model


class EDMSmilesSampler(DistributionMatchingGenerator):
    """
    Generator that samples SMILES strings using an EDM model.

    SMILES that were already built (from cached 3D molecules at construction
    time, or by earlier ``generate`` calls) are kept in ``generated_smiles``
    and served first; the model is only sampled for the part of a request
    that this cache cannot cover.
    """

    def __init__(self, generator_model_path: str, batch_size: int, edge_predictor_path: str = None,
                 total_number_samples: int = 10000, use_cached_3d_mols: bool = False,
                 cached_3d_mols_path: str = 'guacamol_evaluation/eval_zinc_joint_training.npz') -> None:
        """
        Args:
            generator_model_path (str): path to folder containing generator model checkpoints. e.g. outputs/edm_zinc/
            batch_size (int): batch size used to sample from the model
            edge_predictor_path (str): path to folder containing edge predictor model checkpoints. e.g. outputs/edge_model_zinc/
                If None, no edge prediction model is used.
            total_number_samples (int): a ``generate`` request for exactly this many samples
                restarts reading the SMILES cache from its beginning.
            use_cached_3d_mols (bool): if True, pre-fill the SMILES cache from the 3D molecules
                stored in ``cached_3d_mols_path``.
            cached_3d_mols_path (str): ``.npz`` archive of previously sampled 3D molecules;
                only read when ``use_cached_3d_mols`` is True.
        """
        self.model, utilities_dict = load_model(generator_model_path)
        # NOTE(review): positional unpacking relies on the insertion order of the
        # dict returned by load_model -- confirm it stays in sync with that helper.
        self.nodes_dist, self.args, self.device, self.dataset_info = utilities_dict.values()
        self.batch_size = batch_size

        if edge_predictor_path is not None:
            self.edge_model = load_trained_edge_model(edge_predictor_path)
        else:
            print('Not using the Edge prediction model')
            self.edge_model = None

        self.total_number_samples = total_number_samples
        # Accumulated raw 3D samples, keyed like the dict returned by sample_3d_molecules.
        self.generated_3d_mols = {}
        # Cache of every SMILES produced so far, plus the read cursor into it.
        self.generated_smiles = []
        self.n_smiles_used = 0

        if use_cached_3d_mols:
            with np.load(cached_3d_mols_path) as f:
                cached_3d_mols = {key: torch.from_numpy(val) for key, val in f.items()}
            print('Building smiles from cached 3D molecules')
            # Goes through the same helper as generate() so a missing edge model
            # falls back to the plain conversion instead of being passed as None.
            self.generated_smiles = list(self._build_smiles(cached_3d_mols))
            print('Done building smiles!')
            print(f'Could generate: {len(self.generated_smiles)} smiles')

    def _build_smiles(self, molecules, return_mols: bool = False):
        """Convert sampled 3D molecules to SMILES, using the edge model when one is loaded."""
        if self.edge_model is None:
            return smiles_from_3d_molecules(molecules, self.dataset_info)

        result = smiles_from_3d_molecules_with_edge_model(
            self.edge_model, molecules, min(self.batch_size, 100), self.dataset_info, return_mols=return_mols)
        if return_mols:
            # NOTE(review): the shape of the return_mols=True result is not visible
            # from this file, so it is handed back untouched -- confirm against the helper.
            return result
        # With return_mols=False the helper's result is unpacked elsewhere in this
        # class as a 3-tuple whose first item is the SMILES list; keep only that list.
        if isinstance(result, tuple):
            return result[0]
        return result

    def generate(self, number_samples: int, return_mols: bool = False) -> List[str]:
        """
        Return SMILES strings, serving cached ones first and sampling the rest.

        Args:
            number_samples (int): number of SMILES requested. A request equal to
                ``total_number_samples`` rewinds the cache cursor to the start.
            return_mols (bool): forwarded to the edge-model based conversion.

        Returns:
            List[str]: cached SMILES followed by newly generated ones. No re-sampling
            is done if the conversion yields fewer SMILES than were requested.
        """
        if number_samples == self.total_number_samples:
            self.n_smiles_used = 0

        start = self.n_smiles_used
        result = self.generated_smiles[start:start + number_samples]
        self.n_smiles_used += number_samples

        n_missing = number_samples - len(result)
        if n_missing <= 0:
            return result

        # Per the original author's note, sampling breaks when asked for a single
        # molecule, so sample two and keep only the first SMILES.
        return_just_one = n_missing == 1
        n_to_sample = max(n_missing, 2)

        molecules = sample_3d_molecules(self.model, self.nodes_dist, self.args, self.device, self.dataset_info,
                                        n_samples=n_to_sample, batch_size=self.batch_size)

        if not self.generated_3d_mols:
            self.generated_3d_mols = copy.deepcopy(molecules)
        else:
            for key in self.generated_3d_mols:
                self.generated_3d_mols[key] = torch.cat([self.generated_3d_mols[key], molecules[key]], dim=0)

        smiles = self._build_smiles(molecules, return_mols=return_mols)

        if return_just_one:
            # Slice instead of indexing so an empty conversion result does not raise.
            smiles = list(smiles[:1])

        result.extend(smiles)
        self.generated_smiles.extend(smiles)

        return result
