from tqdm import tqdm

import torch

import schnetpack.transform as trn
from schnetpack import properties
from schnetpack.data.loader import _atoms_collate_fn
from schnetpack.diffusion.sampling_analysis import check_validity, generate_bonds_data
from ase.data import covalent_radii
from ase import Atoms, build


def batch_center_systems(
    systems: torch.Tensor, indices: torch.Tensor, n_atoms: torch.Tensor, dim: int = 0
) -> torch.Tensor:
    """
    Center every system so that it has zero center of geometry.

    Args:
        systems: stacked values of all systems (e.g. atom positions); the
            entries along ``dim`` belong to the system given by ``indices``.
        indices: 1-D tensor with the system index of every entry along ``dim``.
        n_atoms: number of entries of every system, used to compute the mean.
        dim: dimension along which the entries of the systems are stacked.

    Returns:
        tensor with the same shape as ``systems`` from which the mean of the
        respective system has been subtracted.
    """
    mean = scatter_mean(systems, indices, n_atoms, dim=dim)
    # Select each entry's own system mean along ``dim``; plain ``mean[indices]``
    # would always index dimension 0, regardless of ``dim``.
    return systems - mean.index_select(dim, indices)


def scatter_mean(
    systems: torch.Tensor, indices: torch.Tensor, n_atoms: torch.Tensor, dim: int = 0
) -> torch.Tensor:
    """
    Average the entries of ``systems`` that share the same index along ``dim``.

    Args:
        systems: values to average; the entries along ``dim`` are grouped.
        indices: 1-D tensor with the group index of every entry along ``dim``.
        n_atoms: number of entries per group, used as the divisor of the sums.
        dim: dimension along which the entries are grouped.

    Returns:
        tensor shaped like ``systems`` except that ``dim`` holds one entry per
        distinct value in ``indices``.
    """
    shape = list(systems.shape)
    shape[dim] = len(indices.unique())
    totals = torch.zeros(shape, dtype=systems.dtype, device=systems.device)
    totals.index_add_(dim, indices, systems)

    # Reshape the per-group counts so that they broadcast along ``dim`` only,
    # independent of the number of dimensions of ``systems``.
    counts_shape = [1] * totals.dim()
    counts_shape[dim] = -1
    return totals / n_atoms.reshape(counts_shape)


def rmsd(reference: Atoms, sample: Atoms, keep_original=True):
    """
    Root mean squared deviation between ``reference`` and ``sample`` after
    aligning ``sample`` onto ``reference``.

    Args:
        reference: structure that serves as the alignment target.
        sample: structure that is aligned and compared to ``reference``.
        keep_original: if True, the alignment is done on a copy of ``sample``;
            otherwise the positions of ``sample`` itself are overwritten with
            the aligned positions.

    Returns:
        the root of the mean (over atoms) squared distance between the
        positions of ``reference`` and the aligned structure.
    """
    # align a copy unless the caller wants ``sample`` itself to be modified
    aligned = sample.copy() if keep_original else sample

    # rotation and translation that minimize the rmsd w.r.t. the reference
    build.minimize_rotation_and_translation(reference, aligned)

    displacement = reference.positions - aligned.positions
    squared_distances = (displacement**2).sum(-1)
    return squared_distances.mean() ** 0.5


def re_compute_neighbors(old_batch, cutoff=5.0, additional_keys=()):
    """
    Recompute the neighbor lists for a batch of systems.

    The batch is split into its systems using ``properties.idx_m``, a
    ``MatScipyNeighborList`` transform with the given cutoff is applied to
    every system and the results are collated into a new batch.

    Tensors are detached but not moved to another device here; the callers in
    this file move the batch to the CPU before calling this function.

    Args:
        old_batch: collated batch (dict of tensors) providing ``properties.idx``,
            ``n_atoms``, ``Z``, ``R``, ``cell``, ``pbc`` and ``idx_m``.
        cutoff: cutoff radius handed to the neighbor list transform.
        additional_keys: keys that are copied unchanged from ``old_batch``
            into the new batch.

    Returns:
        the newly collated batch. Entries of ``old_batch`` that are neither
        rebuilt here nor listed in ``additional_keys`` are not carried over.
    """
    neighbors_calculator = trn.MatScipyNeighborList(cutoff=cutoff)

    idx_m = old_batch[properties.idx_m]
    # one row of periodic boundary flags per system
    pbc = old_batch[properties.pbc].view(-1, 3)

    systems = []
    # j: position of the system in the per-system tensors, i: its index in idx_m
    for j, i in enumerate(idx_m.unique()):
        mask = idx_m == i
        inp = {
            properties.idx: old_batch[properties.idx][[j]].detach(),
            properties.n_atoms: old_batch[properties.n_atoms][[j]].detach(),
            properties.Z: old_batch[properties.Z][mask].detach(),
            properties.R: old_batch[properties.R][mask].detach(),
            properties.cell: old_batch[properties.cell][[j]].detach(),
            properties.pbc: pbc[j].detach(),
        }
        systems.append(neighbors_calculator(inp))

    batch = _atoms_collate_fn(systems)

    batch.update({k: old_batch[k] for k in additional_keys})

    return batch


def check_connectivity(inputs, relax_coef=1.17):
    """
    Simple and fast connectivity check of molecules based on covalent radii.

    An atom counts as connected if at least one atom of the same molecule is
    closer than ``relax_coef`` times the sum of the two covalent radii. Pairs
    with a distance of exactly 0 (in particular an atom with itself) are
    ignored.

    Args:
        inputs: batch (dict of tensors) providing ``properties.idx_m``,
            ``properties.Z`` and ``properties.R``.
        relax_coef: factor by which the sum of covalent radii is scaled to
            obtain the distance threshold.

    Returns:
        tuple of
        - tensor shaped like ``inputs[properties.idx_m]`` holding 1 for
          connected atoms and 0 otherwise,
        - fraction of molecules in which every atom is connected,
        - fraction of molecules in which every atom with Z != 1 is connected
          (atoms with Z == 1 are treated as connected).
    """
    idx_m = inputs[properties.idx_m]
    positions = inputs[properties.R]

    # Keep the radii table on the device of the batch so that the check also
    # works for batches that do not live on the CPU. The table keeps its
    # original dtype, so the comparison below is done in that precision.
    radii_table = torch.from_numpy(covalent_radii).to(device=positions.device)

    results = torch.zeros_like(idx_m)
    sum_w_H = 0
    sum_wo_H = 0
    molecule_ids = idx_m.unique()
    for m in tqdm(molecule_ids):
        mask = idx_m == m
        at_num = inputs[properties.Z][mask]
        pos = positions[mask]

        dists = torch.cdist(pos, pos)
        radii = radii_table[at_num.long()]
        covalent_dists = relax_coef * (radii[:, None] + radii[None, :])
        connected = ((dists < covalent_dists) & (0.0 < dists)).any(dim=1)
        sum_w_H += connected.all().long()
        sum_wo_H += (connected | (at_num == 1)).all().long()
        results[mask] = connected.long()

    num_mols = len(molecule_ids) * 1.0
    return results, sum_w_H / num_mols, sum_wo_H / num_mols


def sample_prior_positions(inputs, noise_schedule, t):
    """
    Sample noisy positions at diffusion step ``t`` from the positions stored
    under ``inputs["original_R"]``.

    The result is ``sqrt_alpha_bar * original_R + sqrt_beta_bar * eps`` where
    ``eps`` is Gaussian noise centered per system and both coefficients are
    taken from ``noise_schedule`` evaluated at ``t / noise_schedule.T``.

    Args:
        inputs: batch (dict of tensors) providing ``properties.R``,
            ``properties.idx_m``, ``properties.n_atoms`` and ``"original_R"``.
        noise_schedule: callable with attribute ``T`` that returns a mapping
            with the entries ``"sqrt_alpha_bar"`` and ``"sqrt_beta_bar"``.
        t: diffusion step (not normalized).

    Returns:
        the noisy positions on the device of ``inputs[properties.R]``.
    """
    positions = inputs[properties.R]
    target_device = positions.device

    normalized_t = torch.tensor(t).float() / noise_schedule.T
    schedule_params = noise_schedule(normalized_t)
    noise_scale = schedule_params["sqrt_beta_bar"].to(device=target_device)
    signal_scale = schedule_params["sqrt_alpha_bar"].to(device=target_device)

    # Gaussian noise with zero center of geometry for every system
    centered_noise = batch_center_systems(
        torch.randn_like(positions),
        inputs[properties.idx_m],
        inputs[properties.n_atoms],
    )

    clean_positions = inputs["original_R"].to(device=target_device)
    return signal_scale * clean_positions + noise_scale * centered_noise


def sample_R(
    inputs,
    model,
    noise_schedule,
    cutoff=5.0,
    T=1000,
    start=None,
    random=True,
    use_forces=True,
    save_progress=True,
    progress_stride=1,
    use_cpu=False,
    recompute_neighbors=False,
    use_orig=False,
):
    """
    Sample positions R by iterative denoising with a model that predicts the
    noise for a given diffusion step.

    Args:
        inputs: batch (dict of tensors). ``inputs[properties.R]`` of the passed
            dict is overwritten with the centered start positions.
        model: callable that takes the batch and returns a mapping with the
            entry ``"eps_pred"``.
        noise_schedule: callable ``noise_schedule(t, stage=...)`` returning a
            mapping with the noise parameters used below.
        cutoff: neighbor list cutoff used when ``recompute_neighbors`` is True.
        T: number of diffusion steps, used to normalize the step index.
        start: step to start denoising from. ``None`` means ``T - 1``. A
            negative value (without ``use_orig``) starts from the positions in
            ``inputs`` at step ``T - 1``.
        random: if True start from Gaussian noise, otherwise from positions
            drawn with ``sample_prior_positions`` at step ``start``.
        use_forces: if True the sign of the model output is flipped, i.e. the
            model output is treated as the negative of the noise.
        save_progress: if True intermediate positions are stored and returned.
        progress_stride: only steps ``t`` with ``t % progress_stride == 0`` are
            stored.
        use_cpu: if False, the batch is moved to the CPU for the neighbor list
            computation and back to the original device afterwards, and the
            results are returned on the CPU. If True, no transfers are done.
        recompute_neighbors: if True the neighbors are recomputed with
            ``cutoff`` in every step; otherwise they are computed once with a
            very large cutoff before sampling.
        use_orig: if True start from the positions in ``inputs`` and keep
            ``start`` as given.

    Returns:
        tuple ``(zt, molecules)`` with the final positions and the tensor of
        stored intermediate positions (``None`` if ``save_progress`` is False).
    """

    n_samples = len(inputs[properties.n_atoms])
    device = inputs[properties.R].device
    if start is None:
        start = T - 1

    if use_orig:
        zt = inputs[properties.R].clone()
    elif start < 0:
        zt = inputs[properties.R].clone()
        start = T - 1
    elif random:
        zt = torch.randn_like(inputs[properties.R])
    else:
        zt = sample_prior_positions(inputs, noise_schedule, start)

    zt = batch_center_systems(zt, inputs[properties.idx_m], inputs[properties.n_atoms])

    inputs[properties.R] = zt

    # compute the neighbors only once, with a very large cutoff, so that they
    # do not have to be updated while the positions change
    if not recompute_neighbors:
        if not use_cpu:
            inputs = {p: inputs[p].cpu() for p in inputs}

        inputs = re_compute_neighbors(inputs, cutoff=5000.0)

        if not use_cpu:
            inputs = {p: inputs[p].to(device=device) for p in inputs}

    if save_progress:
        # NOTE(review): if start % progress_stride == 0 and progress_stride > 1,
        # the slot of the initial state (index -1) is overwritten by the first
        # stored step. Left unchanged to keep the returned shape.
        n = (start + 1) // progress_stride
        molecules = torch.zeros((n + 1,) + inputs[properties.R].size(), device=device)
        molecules[-1] = zt
    else:
        molecules = None

    # TODO (from the original author): the alphas are computed from
    # alphas_bar[1:], so the chain ends with alpha_2 instead of alpha_1 and
    # only T - 1 effective denoising steps are done. Fixing this (e.g. by
    # prepending alpha_1) also requires changes in transforms.py.
    for t in tqdm(range(start, -1, -1)):
        t_array = torch.full((n_samples, 1), fill_value=t, device=device)
        t_array = t_array / T

        # noise parameters of step t, broadcast from systems to atoms
        t_noise_params_train = noise_schedule(t_array, stage="fit")
        t_noise_params_test = noise_schedule(t_array, stage="test")
        sqrt_beta_t_bar = t_noise_params_train["sqrt_beta_bar"][
            inputs[properties.idx_m]
        ]
        beta_t = t_noise_params_test["beta_full"][inputs[properties.idx_m]]
        sqrt_alpha_t = t_noise_params_test["sqrt_alpha_full"][inputs[properties.idx_m]]
        sqrt_sigma_t = t_noise_params_test["sqrt_sigma"][inputs[properties.idx_m]]

        inputs[properties.R] = zt

        # recompute neighbors only if requested (expensive: done on the CPU
        # for every system separately)
        if recompute_neighbors:
            if not use_cpu:
                inputs = {p: inputs[p].cpu() for p in inputs}

            inputs = re_compute_neighbors(inputs, cutoff=cutoff)

            if not use_cpu:
                inputs = {p: inputs[p].to(device=device) for p in inputs}

        # provide the normalized diffusion step of every atom to the model
        inputs["diff_step"] = t_array.flatten()[inputs[properties.idx_m]]

        eps_t = model(inputs)["eps_pred"].detach()

        # the model output is defined as forces, i.e. the negative of the noise
        if use_forces:
            eps_t = -eps_t

        # mean of p(zs | zt) for a noise-predicting model
        mu = (1.0 / sqrt_alpha_t) * (zt - (beta_t / sqrt_beta_t_bar) * eps_t)

        # No noise is added in the last step (t == 0): the mean is used as the
        # final sample. The zeros are created from the current step's
        # parameters so that this also works if the loop starts at t == 0.
        if t == 0:
            sigma = torch.zeros_like(sqrt_sigma_t, device=device)
        else:
            sigma = sqrt_sigma_t

        # sample zs given the parameters derived from zt
        eps_x = torch.randn_like(eps_t)
        eps_x = batch_center_systems(
            eps_x, inputs[properties.idx_m], inputs[properties.n_atoms]
        )
        zs = mu + sigma * eps_x

        # project down to avoid numerical runaway of the center of geometry
        zs = batch_center_systems(
            zs, inputs[properties.idx_m], inputs[properties.n_atoms]
        )

        if save_progress and t % progress_stride == 0:
            molecules[t // progress_stride] = zs

        zt = zs  # next step

    if not use_cpu:
        return zt.cpu(), molecules.cpu() if molecules is not None else molecules
    else:
        return zt, molecules


def sample_R_time(
    inputs,
    model,
    noise_schedule,
    cutoff=5.0,
    T=1000,
    start=None,
    random=True,
    use_forces=True,
    save_progress=True,
    progress_stride=1,
    use_cpu=False,
    recompute_neighbors=False,
    aggregate_atomwise=True,
    max_steps=2000,
    convergence_step=0,
    check_stability=False,
    bonds_data=None,
    min_steps=0,
    return_stability=False,
):
    """
    Sample positions R by iterative denoising with a model that predicts both
    the noise and the current diffusion step.

    Args:
        inputs: batch (dict of tensors). ``inputs[properties.R]`` of the passed
            dict is overwritten with the centered start positions.
        model: callable that takes the batch and returns a mapping with the
            entries ``"eps_pred"`` and ``"diff_step_pred"``.
        noise_schedule: callable ``noise_schedule(t, stage=...)`` returning a
            mapping with the noise parameters used below.
        cutoff: neighbor list cutoff used when ``recompute_neighbors`` is True.
        T: number of diffusion steps, used to convert between normalized time
            and step index.
        start: ``None`` means ``T - 1``. A negative value starts from the
            positions in ``inputs``. Otherwise it is the step passed to
            ``sample_prior_positions`` when ``random`` is False and the step
            recorded for the initial state.
        random: if True start from Gaussian noise, otherwise from positions
            drawn with ``sample_prior_positions``.
        use_forces: if True the sign of the model output is flipped, i.e. the
            model output is treated as the negative of the noise.
        save_progress: if True intermediate positions and steps are stored.
        progress_stride: only iterations ``i`` with ``i % progress_stride == 0``
            are stored.
        use_cpu: if False, the batch is moved to the CPU for the neighbor list
            computation and back afterwards, and the final positions are
            returned on the CPU. If True, no transfers are done.
        recompute_neighbors: if True the neighbors are recomputed with
            ``cutoff`` in every iteration; otherwise once with a very large
            cutoff before sampling.
        aggregate_atomwise: if True and the model predicts one time per atom,
            the predictions are averaged over the atoms of every system.
        max_steps: maximum number of iterations.
        convergence_step: a system is done once its rounded mean predicted
            step is ``<= convergence_step``.
        check_stability: if True, systems reported as stable by
            ``check_validity`` are also marked as done (only in iterations
            with ``i > min_steps``).
        bonds_data: mapping whose values are passed to ``check_validity``;
            created with ``generate_bonds_data()`` if None and needed.
        min_steps: see ``check_stability``.
        return_stability: if True (and ``check_stability``), return the results
            of the most recent stability check.

    Returns:
        tuple ``(zt, molecules, time_steps, num_steps, res_stability)``:
        final positions, stored positions and predicted steps (most recent
        first, ``None`` if ``save_progress`` is False), the iteration count at
        which every system was first marked as done (or the total number of
        iterations if it never was), and the stability results (``None`` if
        not requested or if the stability check never ran).
    """

    device = inputs[properties.R].device

    if check_stability and bonds_data is None:
        bonds_data = generate_bonds_data()

    if start is None:
        start = T - 1

    if start < 0:
        zt = inputs[properties.R].clone()
    elif random:
        zt = torch.randn_like(inputs[properties.R])
    else:
        zt = sample_prior_positions(inputs, noise_schedule, start)

    zt = batch_center_systems(zt, inputs[properties.idx_m], inputs[properties.n_atoms])
    inputs[properties.R] = zt

    # compute the neighbors only once, with a very large cutoff, so that they
    # do not have to be updated while the positions change
    if not recompute_neighbors:
        if not use_cpu:
            inputs = {p: inputs[p].cpu() for p in inputs}

        inputs = re_compute_neighbors(inputs, cutoff=5000.0)

        if not use_cpu:
            inputs = {p: inputs[p].to(device=device) for p in inputs}

    if save_progress:
        molecules = [zt.clone()]
        time_steps = [
            torch.full(
                (len(inputs[properties.idx_m]),), fill_value=start, device=device
            )
        ]
    else:
        molecules = None
        time_steps = None

    num_steps = torch.zeros(
        len(inputs[properties.n_atoms]), dtype=torch.long, device=device
    )

    # results of the most recent stability check; stays None if the check
    # never runs (e.g. the loop ends before ``i > min_steps``)
    last_stability = None

    i = 0
    mask_not_stable = torch.tensor([True] * len(zt), device=device)
    pbar = tqdm(total=max_steps)
    while i < max_steps:

        # predict time and noise
        inputs[properties.R].requires_grad = True
        pred = model(inputs)
        eps_t = pred["eps_pred"].detach()
        t_array = pred["diff_step_pred"].detach()
        inputs[properties.R].requires_grad = False

        # a 2-D prediction is reduced to the index of its largest entry,
        # normalized by T
        if len(t_array.shape) == 2:
            t_array = t_array.argmax(-1) / (T * 1.0)

        # the model output is defined as forces, i.e. the negative of the noise
        if use_forces:
            eps_t = -eps_t

        # average if atomwise prediction
        if aggregate_atomwise and len(t_array) == len(inputs[properties.idx_m]):
            t_array = scatter_mean(
                t_array, inputs[properties.idx_m], inputs[properties.n_atoms]
            )

        # broadcast system-wise predictions to the atoms
        if len(t_array) != len(inputs[properties.idx_m]):
            t_array = t_array[inputs[properties.idx_m]]

        # predicted (integer) diffusion step of every atom
        t = torch.round(t_array * T).long()

        # load noise parameters
        t_noise_params_train = noise_schedule(t_array, stage="fit")
        t_noise_params_test = noise_schedule(t_array, stage="test")
        sqrt_beta_t_bar = t_noise_params_train["sqrt_beta_bar"][:, None]
        beta_t = t_noise_params_test["beta_full"][:, None]
        sqrt_alpha_t = t_noise_params_test["sqrt_alpha_full"][:, None]
        sqrt_sigma_t = t_noise_params_test["sqrt_sigma"][:, None]

        # ToDo: pre-compute the coefficients in the noise schedule to overcome
        # rounding errors using float32
        # mean of p(zs | zt) for a noise-predicting model
        mu = (1.0 / sqrt_alpha_t) * (zt - (beta_t / sqrt_beta_t_bar) * eps_t)

        # no noise where the predicted step is 0: the mean is used as sample
        sigma = torch.zeros_like(sqrt_sigma_t, device=device)
        _mask = t != 0
        sigma[_mask] = sqrt_sigma_t[_mask]

        # sample zs given the parameters derived from zt
        eps_x = torch.randn_like(eps_t, device=device)
        eps_x = batch_center_systems(
            eps_x, inputs[properties.idx_m], inputs[properties.n_atoms]
        )
        zs = mu + sigma * eps_x

        # project down to avoid numerical runaway of the center of geometry
        zs = batch_center_systems(
            zs, inputs[properties.idx_m], inputs[properties.n_atoms]
        )

        # only update atoms of systems that are not done yet
        zt[mask_not_stable] = zs[mask_not_stable]  # next step
        inputs[properties.R] = zt

        # recompute neighbors only if requested (expensive: done on the CPU
        # for every system separately)
        if recompute_neighbors:
            if not use_cpu:
                inputs = {p: inputs[p].cpu() for p in inputs}

            inputs = re_compute_neighbors(inputs, cutoff=cutoff)

            if not use_cpu:
                inputs = {p: inputs[p].to(device=device) for p in inputs}

        if save_progress and i % progress_stride == 0:
            molecules.append(zt.clone())
            time_steps.append(t)

        # a system is done once its mean predicted step reached convergence_step
        mean_t = scatter_mean(t, inputs[properties.idx_m], inputs[properties.n_atoms])
        done = torch.round(mean_t) <= convergence_step

        if check_stability and i > min_steps:
            (
                bonds,
                stable_atoms,
                stable_molecules,
                connected,
                stable_atoms_wo_h,
                stable_molecules_wo_h,
                connected_wo_h,
            ) = check_validity(inputs, *bonds_data.values(), progress_bar=False)
            last_stability = {
                "bonds": bonds,
                "stable_atoms": stable_atoms,
                "stable_molecules": stable_molecules,
                "connected": connected,
                "stable_atoms_wo_h": stable_atoms_wo_h,
                "stable_molecules_wo_h": stable_molecules_wo_h,
                "connected_wo_h": connected_wo_h,
            }
            stable = torch.tensor(stable_molecules, device=device, dtype=torch.bool)
            done = done | stable

        i += 1
        pbar.update(1)

        mask_not_stable = ~done[inputs[properties.idx_m]]
        # remember the iteration at which a system was marked as done first
        num_steps[done & (num_steps == 0)] = i

        if done.all():
            break

    # systems that never finished get the total number of iterations
    num_steps[num_steps == 0] = i

    pbar.close()

    if save_progress:
        # most recent state first
        time_steps.reverse()
        molecules.reverse()
        time_steps = torch.cat(time_steps).reshape(-1, *time_steps[0].shape).cpu()
        molecules = torch.cat(molecules).reshape(-1, *molecules[0].shape).cpu()
    num_steps = num_steps.cpu()

    if check_stability and return_stability:
        res_stability = last_stability
    else:
        res_stability = None

    if not use_cpu:
        return zt.cpu(), molecules, time_steps, num_steps, res_stability
    else:
        return zt, molecules, time_steps, num_steps, res_stability


def sample_multinomial(dist: torch.Tensor, categories: torch.Tensor = None):
    """
    Draw one sample from every categorical distribution in ``dist``.

    Args:
        dist: weights handed to ``torch.multinomial``; the last dimension
            enumerates the categories.
        categories: optional tensor of category values. If given, the sampled
            indices are used to index it.

    Returns:
        the sampled entries of ``categories`` if it is given, otherwise the
        sampled indices. The sampling dimension is squeezed out.
    """
    samples_idx = torch.multinomial(dist, 1, replacement=True).squeeze(-1)
    if categories is None:
        return samples_idx
    return categories[samples_idx]