import torch
from typing import Optional

import numpy as np
import torch
import hienet._keys as KEY
from hienet._const import AtomGraphDataType

def add_gaussian_noise_to_position(
    batch, std, corrupt_ratio=None, all_atoms=True
):
    """Corrupt atom positions in ``batch`` with Gaussian noise for denoising.

    Steps performed:

    1.  Sample i.i.d. Gaussian noise with mean 0 and standard deviation
        ``std``, with the same shape, dtype and device as ``batch[KEY.POS]``.
    2.  If ``corrupt_ratio`` is not None, keep the noise only on a random
        subset of atoms: each atom (row of ``batch[KEY.POS]``) keeps its noise
        with probability ``corrupt_ratio``; the boolean per-atom mask is stored
        in ``batch[KEY.NOISE_MASK]``.
    3.  If ``all_atoms`` is True, noise is applied to every atom, including
        fixed ones. Otherwise only atoms with ``batch.fixed == 0`` are moved,
        and the noise of all other atoms is zeroed so that the stored target
        equals the displacement that was actually applied.
    4.  Replace ``batch[KEY.POS]`` with the noised positions, store the applied
        noise in ``batch[KEY.NOISE_VEC]`` (the denoising target) and set
        ``batch[KEY.DENOISNG_POS_FORWARD] = True``.

    No special handling is applied to structures from an MD split
    (``batch.md``); every structure in the batch is eligible for noise.

    Args:
        batch: Graph batch supporting item access with ``KEY`` constants.
            When ``all_atoms`` is False it must also expose ``batch.fixed``.
        std: Standard deviation of the Gaussian noise.
        corrupt_ratio: Optional probability that an individual atom is
            corrupted. ``None`` corrupts every (eligible) atom.
        all_atoms: If True, also move atoms marked as fixed.

    Returns:
        The same ``batch`` object, updated as described above.
    """
    pos = batch[KEY.POS]
    noise_vec = torch.zeros_like(pos).normal_(mean=0.0, std=std)

    if corrupt_ratio is not None:
        # One uniform draw per atom (first dim of pos); atoms whose draw is
        # below corrupt_ratio keep their noise, the rest get none.
        noise_mask = (
            torch.rand(pos.shape[0], dtype=pos.dtype, device=pos.device)
            < corrupt_ratio
        )
        noise_vec[~noise_mask] = 0.0
        batch[KEY.NOISE_MASK] = noise_mask

    if not all_atoms:
        # NOTE(review): assumes `batch.fixed` is a per-atom tensor in which 0
        # marks a free atom — confirm against the dataset definition.
        free_mask = batch.fixed == 0.0
        # Fixed atoms are left in place, so their denoising target must be
        # zero rather than the sampled (but never applied) noise.
        noise_vec[~free_mask] = 0.0

    # Out-of-place update: the original position tensor is not mutated, which
    # keeps it safe if it is shared elsewhere or tracked by autograd.
    batch[KEY.POS] = pos + noise_vec
    batch[KEY.NOISE_VEC] = noise_vec
    batch[KEY.DENOISNG_POS_FORWARD] = True

    return batch