import torch
import numpy as np
import torch.nn.functional as F
import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.models.layer import (new_layer_config,
                                                   BatchNorm1dEdge,
                                                   BatchNorm1dNode)


class FeatureEncoder(torch.nn.Module):
    """
    Encoding node and edge features

    Builds the node encoder and edge encoder registered under
    ``model_cfg.dataset.node_encoder_name`` / ``edge_encoder_name``, each
    optionally followed by batch norm, and applies them in that order in
    :meth:`forward`.

    Args:
        dim_in (int): Input feature dimension. Replaced by
            ``model_cfg.gnn.dim_inner`` once the node encoder is built.
        model_cfg: Config to read from; defaults to the global GraphGym
            ``cfg``. NOTE: ``model_cfg.gnn.dim_edge`` is overwritten with
            ``model_cfg.gnn.dim_inner`` as a side effect of construction.
    """
    def __init__(self, dim_in, model_cfg=None):
        super(FeatureEncoder, self).__init__()
        self.dim_in = dim_in
        if model_cfg is None:
            model_cfg = cfg

        # Node encoder looked up in the GraphGym registry by name.
        node_encoder_cls = register.node_encoder_dict[model_cfg.dataset.node_encoder_name]
        self.node_encoder = node_encoder_cls(model_cfg.gnn.dim_inner, model_cfg=model_cfg)
        if model_cfg.dataset.node_encoder_bn:
            self.node_encoder_bn = BatchNorm1dNode(
                new_layer_config(model_cfg.gnn.dim_inner, -1, -1, has_act=False,
                                 has_bias=False, cfg=model_cfg))
        # Update dim_in to reflect the new dimension of the node features
        self.dim_in = model_cfg.gnn.dim_inner
        # Edge embeddings share the hidden width; this writes into the config.
        model_cfg.gnn.dim_edge = model_cfg.gnn.dim_inner

        # Edge encoder looked up in the GraphGym registry by name. The config is
        # passed positionally here (the edge encoder signature is not visible).
        edge_encoder_cls = register.edge_encoder_dict[model_cfg.dataset.edge_encoder_name]
        self.edge_encoder = edge_encoder_cls(model_cfg.gnn.dim_edge, model_cfg)
        if model_cfg.dataset.edge_encoder_bn:
            # BatchNorm1dEdge normalizes batch.edge_attr. BatchNorm1dNode (used
            # previously) normalizes batch.x, which re-normalized the node
            # features and left the edge embeddings untouched.
            self.edge_encoder_bn = BatchNorm1dEdge(
                new_layer_config(model_cfg.gnn.dim_edge, -1, -1, has_act=False,
                                 has_bias=False, cfg=model_cfg))

    def forward(self, batch, unconditional_prop=0.0):
        """Pass ``batch`` through every child module in registration order.

        Args:
            batch: Graph batch; each child receives the previous child's output.
            unconditional_prop (float): Forwarded only to children whose class
                is named ``OCBNodeEncoder``; every other child is called with
                ``batch`` alone.

        Returns:
            The batch returned by the last child (``batch`` itself if there
            are no children).
        """
        for module in self.children():
            # Matched by class name so this file need not import that encoder.
            if type(module).__name__ == 'OCBNodeEncoder':
                batch = module(batch, unconditional_prop=unconditional_prop)
            else:
                batch = module(batch)
        return batch
    

class OneHotPerturb(torch.nn.Module):
    """
    Perturb discrete inputs so the classifier sees continuous values rather than exact one-hots.

    When ``cfg.dataset.task_type == 'generative'`` (classifier used frozen as a
    guide) the batch is returned untouched. Otherwise (classifier training) the
    batch's attributes are replaced as follows and the same batch is returned:

    * ``batch.x``: the integer node type in column 0 is one-hot encoded over
      ``cfg.dataset.nnode_types`` classes, uniform noise is added, and each
      row is renormalized to sum to 1.
    * ``batch.edge_attr``: a small positive per-edge epsilon is subtracted;
      the result is 2-D (a 1-D input becomes ``(num_edges, 1)``).
    * ``batch.x_features``: a per-node scalar clamped to [-5e-2, 5e-2] is added
      to every feature of that node.
    """

    def __init__(self):
        super(OneHotPerturb, self).__init__()

    def forward(self, batch):
        if cfg.dataset.task_type == 'generative':
            # Used as a frozen guiding classifier
            return batch

        # --- Classifier training ---

        # x: one-hot node types plus uniform noise whose scale is itself random
        # (coeff ~ N(12.5, 7^2), floored at 1e-2), then renormalized per row.
        oh_x = F.one_hot(batch.x[:, 0], num_classes=cfg.dataset.nnode_types)
        coeff = max(np.random.randn() * 7 + 12.5, 1e-2)
        # Sample directly on the target device (no host allocation + copy).
        perturb = torch.rand(oh_x.shape, device=oh_x.device) * coeff * 1e-3
        noisy_x = oh_x + perturb
        batch.x = noisy_x / noisy_x.sum(dim=-1, keepdim=True)

        # edge_attr: the encoder only sees 1s because 0 edges are discarded, so
        # simply subtract a random epsilon (~N(5e-2, 2e-2^2), clamped >= 1e-5).
        # Lift a 1-D edge_attr to (num_edges, 1) and subtract the noise as a
        # column, so an edge_attr that is already 2-D is not broadcast against
        # the per-edge noise into a (num_edges, num_edges) matrix.
        edge_attr = batch.edge_attr
        if edge_attr.dim() == 1:
            edge_attr = edge_attr[:, None]
        perturb = (torch.randn(len(edge_attr), device=edge_attr.device) * 2e-2 + 5e-2).clamp(min=1e-5)
        batch.edge_attr = edge_attr - perturb[:, None]

        # x_features: shift all features of a node by the same small scalar
        # (~N(0, 1e-2^2), clamped to [-5e-2, 5e-2]).
        x_features = batch.x_features
        perturb = (torch.randn(len(x_features), device=x_features.device) * 1e-2).clamp(min=-5e-2, max=5e-2)
        batch.x_features = x_features + perturb[:, None]

        return batch

            
            
