from typing import Any

import numpy as np
import torch
from torch_geometric.transforms import BaseTransform

class ConstantFeatureReplacement(BaseTransform):
    """Replace every node feature entry with a constant value.

    The replacement tensor has the same shape, dtype and device as the
    incoming ``data.x``; only the feature values are discarded.

    Args:
        c: Value written into every entry of ``data.x``. Defaults to ``1``.
            It is converted to ``data.x``'s dtype by ``torch.full_like``.
        **kwargs: Accepted and ignored (presumably so the transform can be
            constructed from a shared config dict — TODO confirm with callers).
    """

    def __init__(self, c: float = 1, **kwargs):
        super().__init__()
        self.c = c

    def forward(self, data: Any) -> Any:
        """Overwrite ``data.x`` with a tensor filled with ``self.c``.

        Args:
            data: A graph data object with a tensor attribute ``x``.

        Returns:
            The same ``data`` object, with ``x`` replaced.
        """
        # ``full_like`` honours the configured constant ``c`` (``ones_like``
        # would always write 1) while preserving x's shape, dtype and device.
        data.x = torch.full_like(data.x, self.c)

        return data

class StandardNormalFeatureReplacement(BaseTransform):
    """Replace node features with i.i.d. samples from a standard normal.

    Samples are drawn from a NumPy ``Generator`` created from ``self.seed``,
    which starts at 0. Unless ``fixed_seed`` is set, the seed is incremented
    after every call. Successive inputs therefore get different noise, and
    the sequence is reproducible from run to run. With ``fixed_seed=True``,
    every call reuses seed 0, so inputs of the same shape receive identical
    noise.

    Args:
        fixed_seed: If ``True``, reuse the same seed on every call.
        **kwargs: Accepted and ignored (presumably for config-driven
            construction — TODO confirm with callers).
    """

    def __init__(self, fixed_seed: bool = False, **kwargs):
        super().__init__()
        self.seed = 0
        self.fixed_seed = fixed_seed

    def forward(self, data: Any) -> Any:
        """Overwrite ``data.x`` with float32 N(0, 1) noise of the same shape.

        Args:
            data: A graph data object with a tensor attribute ``x``.

        Returns:
            The same ``data`` object, with ``x`` replaced by a
            ``torch.float`` tensor on ``data.x``'s device.
        """
        rng = np.random.default_rng(seed=self.seed)
        # Advance the seed so the next call draws different noise.
        if not self.fixed_seed:
            self.seed += 1
        samples = rng.normal(loc=0.0, scale=1.0, size=data.x.shape)

        # Keep the noise on the same device as the original features, so the
        # transform does not silently move accelerator data back to the CPU.
        data.x = torch.tensor(samples, dtype=torch.float, device=data.x.device)

        return data

