import pickle

import rdkit.Chem.AllChem as Chem
import torch
import torch.nn.functional as F

from visualization.visualizer_edm import plot_data3d
from geo_ldm.latent_diffuser import EnLatentDiffusion

# Metadata for the ZINC250k dataset (heavy atoms only: there is no 'H' in the
# encoder and 'with_h' is False). smiles_to_3d reads 'atom_encoder',
# plot_different_time_steps reads 'atom_decoder', and the whole dict is
# passed to plot_data3d.
dataset_info = {
    'name': 'zinc250k',
    # Element symbol -> atom-type index; 'atom_decoder' is the inverse list.
    'atom_encoder': {'C': 0, 'N': 1, 'O': 2, 'F': 3, 'P': 4, 'S': 5, 'Cl': 6, 'Br': 7, 'I': 8},
    'atom_decoder': ['C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br', 'I'],
    # Largest molecule size; equals the largest key in 'n_nodes'.
    'max_n_nodes': 38,
    # Presumably a size histogram: number of heavy atoms -> number of
    # molecules of that size in the dataset. TODO confirm.
    'n_nodes': {6: 2, 7: 5, 8: 12, 9: 60, 10: 172, 11: 634, 12: 1042, 13: 1550, 14: 2454, 15: 3828, 16: 5554, 
                17: 7598, 18: 10252, 19: 12791, 20: 15683, 21: 18075, 22: 16394, 23: 18530, 24: 20444, 25: 20151, 
                26: 17047, 27: 13878, 28: 8863, 29: 6857, 30: 5366, 31: 4232, 32: 3311, 33: 2308, 34: 1551, 35: 902, 
                36: 347, 37: 117, 38: 1},
    # Presumably total atom counts per type, keyed by 'atom_encoder' index
    # (0-8). TODO confirm.
    'atom_types': {0: 3753634, 1: 621757, 2: 508274, 3: 69973, 4: 116, 5: 90704, 6: 37828, 7: 11241, 8: 795},
    # Per-atom-type plot styling (one entry per type, 9 each). The 'CN'
    # strings look like matplotlib color-cycle specs; confirm in plot_data3d.
    'colors_dic': ['C7', 'C0', 'C3', 'C1', 'C4', 'C8', 'C9', 'C11', 'C12'],
    'radius_dic': [0.3] * 9,
    # Hydrogens are not encoded; smiles_to_3d removes them before returning.
    'with_h': False
}

def smiles_to_3d(smiles, randomSeed=85):
    """Embed a SMILES string in 3D; return heavy-atom coordinates and type indices.

    Stereochemistry is stripped first. Hydrogens are then added for
    conformer embedding and MMFF optimization, and removed again, so only
    heavy atoms are returned.

    Args:
        smiles: SMILES string of the molecule.
        randomSeed: Seed forwarded to RDKit's ``EmbedMolecule`` so the same
            SMILES yields the same geometry across calls.

    Returns:
        ``(pos, atom_indices)``, where:
        - ``pos`` is a float ``torch.Tensor`` of shape ``(n_heavy_atoms, 3)``.
        - ``atom_indices`` is a list of ints indexing
          ``dataset_info['atom_decoder']``, in RDKit atom order.

    Raises:
        ValueError: if the SMILES cannot be parsed or the 3D embedding fails.
        KeyError: if the molecule contains an element that is not in
            ``dataset_info['atom_encoder']``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None; fail here
        # instead of inside RemoveStereochemistry with an opaque error.
        raise ValueError(f"Could not parse SMILES: {smiles!r}")

    # Strip stereo, then round-trip through SMILES so the molecule is rebuilt
    # from the stereo-free string.
    Chem.RemoveStereochemistry(mol)
    mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
    Chem.SanitizeMol(mol)

    # Embed and force-field optimize with explicit hydrogens.
    mol = Chem.AddHs(mol)
    conf_id = Chem.EmbedMolecule(mol, maxAttempts=1000, randomSeed=randomSeed)
    if conf_id < 0:
        # EmbedMolecule returns -1 on failure. Without this check the error
        # only surfaces later as "Bad Conformer Id" from GetConformer().
        raise ValueError(f"3D embedding failed for SMILES: {smiles!r}")
    Chem.MMFFOptimizeMolecule(mol)

    # Drop hydrogens; the heavy atoms keep their optimized coordinates.
    mol = Chem.RemoveHs(mol)
    pos = torch.Tensor(mol.GetConformer().GetPositions())

    atom_encoder = dataset_info["atom_encoder"]
    atom_indices = [atom_encoder[atom.GetSymbol()] for atom in mol.GetAtoms()]

    return pos, atom_indices


def plot_3d_mol(pos, atom_indices, save_path):
    """Save a rendering of one molecule to ``save_path`` using fixed styling.

    This is a thin wrapper over ``plot_data3d`` with the module's
    ``dataset_info``. It uses a white, transparent background and turns off
    both edge plotting and ``spheres_3d``.
    """
    render_options = dict(
        bg='white',
        plot_edges=False,
        spheres_3d=False,
        transparent=True,
    )
    plot_data3d(pos, atom_indices, dataset_info, save_path=save_path, **render_options)

def sample_random_3d(pos, atom_indices, num_atom_types=9):
    """Draw a random "molecule" with the same number of atoms as the input.

    Args:
        pos: Tensor whose size is reused for the random positions.
        atom_indices: Sequence whose length sets how many atom types to draw.
        num_atom_types: Number of atom categories to sample uniformly from.
            Defaults to 9, the length of ``dataset_info['atom_decoder']``.

    Returns:
        ``(random_pos, random_atom_indices)``, where:
        - ``random_pos`` holds standard-normal samples of size ``pos.size()``.
        - ``random_atom_indices`` is an int64 tensor of
          ``len(atom_indices)`` indices drawn uniformly from
          ``range(num_atom_types)``.
    """
    random_pos = torch.randn(pos.size())
    # Equal, unnormalized weights -> uniform over the atom types.
    uniform_types = torch.distributions.Categorical(probs=torch.ones(num_atom_types))
    random_atom_indices = uniform_types.sample(torch.Size((len(atom_indices),)))
    return random_pos, random_atom_indices

def _load_latent_diffusion(exp_name):
    """Rebuild an ``EnLatentDiffusion`` from the args pickled for run ``exp_name``.

    The VAE and dynamics networks are ``torch.nn.Identity`` placeholders, so
    no trained weights are loaded. The caller only uses the noise-schedule
    helpers: ``sample_combined_position_feature_noise``,
    ``inflate_batch_array``, ``gamma``, ``alpha`` and ``sigma``.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load run
    # directories you produced yourself.
    with open(f'outputs/{exp_name}/args.pickle', 'rb') as f:
        args = pickle.load(f)
    return EnLatentDiffusion(
        vae=torch.nn.Identity(),
        trainable_ae=args.trainable_ae,
        dynamics=torch.nn.Identity(),
        in_node_nf=len(dataset_info['atom_decoder']) + int(args.include_atomic_numbers) + 3,
        n_dims=3,
        timesteps=args.diffusion_steps,
        noise_schedule=args.diffusion_noise_schedule,
        noise_precision=args.diffusion_noise_precision,
        loss_type=args.diffusion_loss_type,
        norm_values=args.normalize_factors,
        include_charges=args.include_atomic_numbers,
        joint_training=args.joint_training,
        use_eps_correction=args.use_eps_correction,
        joint_space=args.joint_space,
        lambda_joint_loss=args.lambda_joint_loss,
    )


def plot_different_time_steps(pos, atom_indices, ts, exp_name='ldm_training_no_norm'):
    """Plot a molecule after forward-diffusing it to each timestep in ``ts``.

    One noise sample ``eps`` is drawn up front and reused for every
    timestep, so the frames are ``z_t = alpha_t * xh + sigma_t * eps`` with
    the same ``eps`` throughout rather than independent draws. Each frame is
    saved to ``figures/poster/mol3d_t_{timestep}.png``.

    Args:
        pos: ``(n_atoms, 3)`` tensor of atom coordinates.
        atom_indices: Length-``n_atoms`` sequence (list, array or int tensor)
            of indices into ``dataset_info['atom_decoder']``.
        ts: Iterable of timesteps fed to ``vdm.gamma``. Presumably these are
            normalized to [0, 1]; TODO confirm against the noise schedule.
        exp_name: Run directory under ``outputs/`` whose ``args.pickle``
            configures the diffusion model. Defaults to the run that was
            previously hard-coded here.
    """
    import os

    vdm = _load_latent_diffusion(exp_name)
    n_atoms = pos.size(0)
    n_atom_types = len(dataset_info['atom_decoder'])

    # Sample zt ~ Normal(alpha_t x, sigma_t): eps is drawn once and shared by
    # all timesteps below.
    eps = vdm.sample_combined_position_feature_noise(
        n_samples=1, n_nodes=n_atoms, node_mask=torch.ones((1, n_atoms, 1)))

    # Build xh = [positions | atom-type one-hot | 3 extra feature channels]
    # with shape (1, n_atoms, 3 + n_atom_types + 3), all in pos's dtype.
    x = pos.unsqueeze(0)
    # as_tensor also accepts integer tensors (e.g. sample_random_3d output),
    # which the legacy torch.Tensor constructor rejects.
    h_cat = F.one_hot(torch.as_tensor(atom_indices, dtype=torch.int64),
                      num_classes=n_atom_types).unsqueeze(0).to(x.dtype)
    # Every atom gets class 0 in the 3 extra channels (presumably a
    # placeholder for integer/charge features; confirm). Sized by n_atoms:
    # a hard-coded 25 breaks torch.cat for any other molecule size.
    h_int = F.one_hot(torch.zeros((n_atoms,), dtype=torch.int64),
                      num_classes=3).unsqueeze(0).to(x.dtype)
    xh = torch.cat([x, h_cat, h_int], dim=2)

    os.makedirs('figures/poster', exist_ok=True)
    # Inference only: if vdm.gamma has trainable parameters, no_grad keeps
    # autograd-tracked tensors out of the plotting code.
    with torch.no_grad():
        for timestep in ts:
            t = torch.Tensor([timestep])
            gamma_t = vdm.inflate_batch_array(vdm.gamma(t), x)

            # Compute alpha_t and sigma_t from gamma.
            alpha_t = vdm.alpha(gamma_t, x)
            sigma_t = vdm.sigma(gamma_t, x)
            # Sample z_t given x, h for timestep t, from q(z_t | x, h)
            z_t = alpha_t * xh + sigma_t * eps

            perturbed_pos = z_t[0, :, :3]
            # Decode atom types by argmax over the atom-type channels only.
            perturbed_atom_indices = z_t[0, :, 3:3 + n_atom_types].argmax(-1).numpy()
            plot_3d_mol(perturbed_pos, perturbed_atom_indices,
                        save_path=f'figures/poster/mol3d_t_{timestep}.png')