from typing import List
import copy

import numpy as np
import torch

from guacamol.distribution_matching_generator import DistributionMatchingGenerator

from guacamol_evaluation.load_model import load_ldm_model
from guacamol_evaluation.sample_smiles import sample_3d_molecules, smiles_from_3d_molecules, smiles_from_3d_molecules_with_edge_model
from guacamol_evaluation.ldm_sampling import sample_from_ldm


class LDMSmilesSampler(DistributionMatchingGenerator):
    """
    Generator that samples SMILES strings using a model loaded with ``load_ldm_model``.

    Every SMILES string produced is appended to ``self.generated_smiles``; later
    calls to ``generate`` are served from that cache first and only sample from
    the model for the part of the request the cache cannot cover.
    """

    def __init__(self, exp_folder: str, batch_size: int, total_number_samples: int = 10000, use_cached_3d_mols: bool = False, ckpt_prefix: str = '') -> None:
        """
        Args:
            exp_folder (str): path to folder containing generator model checkpoints. e.g. outputs/edm_zinc/
            batch_size (int): batch size used to sample from the model
            total_number_samples (int): a ``generate`` call requesting exactly this many
                samples restarts reading from the beginning of the SMILES cache
            use_cached_3d_mols (bool): if True, pre-fill the SMILES cache from the 3D molecules
                stored in 'guacamol_evaluation/eval_zinc_joint_training.npz'. This path needs an
                ``edge_model`` attribute, which this class does not create itself.
            ckpt_prefix (str): forwarded unchanged to ``load_ldm_model``

        Raises:
            AttributeError: if ``use_cached_3d_mols`` is True and no ``edge_model``
                attribute is available on the instance.
        """
        self.model, utilities_dict = load_ldm_model(exp_folder, ckpt_prefix=ckpt_prefix)
        # NOTE(review): positional unpacking relies on the insertion order of the dict
        # returned by load_ldm_model -- confirm against that function before reordering.
        (self.nodes_dist, self.args, self.device, self.dataset_info,
         _property_norms, _conditioning, self.prop_dist) = utilities_dict.values()
        self.batch_size = batch_size

        self.total_number_samples = total_number_samples
        self.generated_3d_mols = {}
        self.generated_smiles = []
        self.n_smiles_used = 0

        if use_cached_3d_mols:
            # Nothing in this class assigns `edge_model`; look it up defensively so a
            # missing model fails with an actionable message before any file is read.
            edge_model = getattr(self, 'edge_model', None)
            if edge_model is None:
                raise AttributeError(
                    "use_cached_3d_mols=True requires an 'edge_model' attribute, "
                    "but LDMSmilesSampler does not define one"
                )

            with np.load('guacamol_evaluation/eval_zinc_joint_training.npz') as f:
                cached_3d_mols = {key: torch.from_numpy(val) for key, val in f.items()}
            print('Building smiles from cached 3D molecules')
            smiles, _, _ = smiles_from_3d_molecules_with_edge_model(edge_model, cached_3d_mols, min(self.batch_size, 100), self.dataset_info, return_mols=False)
            print('Done building smiles!')
            self.generated_smiles = smiles
            print(f'Could generate: {len(self.generated_smiles)} smiles')

    def generate(self, number_samples: int, return_mols: bool = False) -> List[str]:
        """
        Return SMILES strings, reading from the cache first and sampling the remainder.

        Args:
            number_samples (int): number of SMILES strings requested. A request equal to
                ``total_number_samples`` resets the cache read position to the start.
            return_mols (bool): accepted for interface compatibility; not used here.

        Returns:
            List[str]: the cached SMILES for this request followed by any newly sampled ones.
        """
        if number_samples == self.total_number_samples:
            self.n_smiles_used = 0

        start = self.n_smiles_used
        # Slicing yields a new list, so extending it below does not touch the cache.
        cached_smiles = self.generated_smiles[start:start + number_samples]
        number_samples_to_generate = number_samples - len(cached_smiles)
        self.n_smiles_used += number_samples

        if number_samples_to_generate <= 0:
            return cached_smiles

        # code breaks if asked to generate 1 mol. so sample 2 and return the first one
        return_just_one = number_samples_to_generate == 1
        n_to_sample = 2 if return_just_one else number_samples_to_generate

        molecules, smiles = sample_from_ldm(self.model, self.nodes_dist, self.args, self.device, self.dataset_info,
                                            n_samples=n_to_sample, batch_size=self.batch_size)

        # NOTE(review): in the return_just_one case both sampled molecules are stored
        # here while only one SMILES is kept below -- confirm whether the 3D cache is
        # meant to stay aligned with generated_smiles.
        if not self.generated_3d_mols:
            self.generated_3d_mols = copy.deepcopy(molecules)
        else:
            for key in self.generated_3d_mols:
                self.generated_3d_mols[key] = torch.cat([self.generated_3d_mols[key], molecules[key]], dim=0)

        if return_just_one:
            # Slice instead of indexing so an empty sampling result does not raise IndexError.
            smiles = list(smiles[:1])

        cached_smiles.extend(smiles)
        self.generated_smiles.extend(smiles)

        return cached_smiles
