from typing import List
import torch.nn as nn
import torch
from greatx.nn.layers import Sequential, activations
from greatx.nn.layers.rung_conv import RUNGConv 


def get_mcp_att_func(gamma, ep=0.01):
    """Build the MCP-style edge re-weighting function.

    NOTE(review): "MCP" presumably stands for minimax concave penalty;
    confirm against the method's reference.

    Parameters
    ----------
    gamma :
        Threshold on ``z = sqrt(w + ep)``. Entries with ``z > gamma``
        receive a weight of zero.
    ep :
        Offset added to ``w`` before taking the square root, which keeps
        ``1 / (2 * z)`` finite when ``w`` is zero.

    Returns
    -------
    callable
        ``att(w)`` taking a tensor ``w`` and returning a new tensor of the
        same shape holding ``1 / (2 * z) - 1 / (2 * gamma)`` where
        ``z <= gamma`` and ``0`` where ``z > gamma``.
    """
    def att(w):
        # Out-of-place add: the caller's tensor must not be shifted by
        # ``ep`` as a side effect, and in-place updates fail on leaf
        # tensors that require grad as well as on integer tensors.
        z = (w + ep).sqrt()
        weights = 1 / (2 * z) - 1 / (2 * gamma)
        # The mask is taken on ``z`` (not on the transformed weights), so
        # entries above the threshold are zeroed regardless of the value
        # the formula would give them.
        return weights.masked_fill(z > gamma, 0)
    return att


class RUNG(nn.Module):
    """MLP feature transform followed by a single :class:`RUNGConv` step.

    Parameters
    ----------
    in_channels :
        Size of each input feature vector.
    out_channels :
        Size of each output vector produced by the final linear layer.
    hids :
        Widths of the hidden linear layers of the MLP; each hidden layer
        is followed by a ReLU activation and dropout.
    w_func :
        Callable passed to :class:`RUNGConv` as ``w_func``. When ``None``,
        the function built by ``get_mcp_att_func(gamma=gamma, ep=ep)`` is
        used.
    lam_hat, quasi_newton, eta, prop_step :
        Forwarded unchanged to :class:`RUNGConv`.
    dropout :
        Dropout probability applied after every hidden activation.
    gamma, ep :
        Only used to build the default ``w_func`` when none is supplied.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 hids: List[int] = [64],
                 w_func: callable = None,
                 lam_hat: float = 0.9,
                 dropout: float = 0.5,
                 gamma: float = 0.5,
                 ep: float = 0.01,
                 quasi_newton: bool = True,
                 eta: float = 0.01,
                 prop_step: int = 10):

        super().__init__()

        # Fall back to the MCP weighting built from ``gamma`` and ``ep``.
        if w_func is None:
            w_func = get_mcp_att_func(gamma=gamma, ep=ep)

        # Chain of layer widths: input size followed by every hidden size.
        widths = [in_channels, *hids]
        blocks = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            blocks += [
                nn.Linear(fan_in, fan_out),
                activations.get('relu'),
                nn.Dropout(dropout),
            ]
        # Output projection from the last width (the input size when
        # ``hids`` is empty).
        blocks.append(nn.Linear(widths[-1], out_channels))
        self.mlp = Sequential(*blocks)

        conv_kwargs = dict(
            lam_hat=lam_hat,
            w_func=w_func,
            quasi_newton=quasi_newton,
            eta=eta,
            prop_step=prop_step,
        )
        self.rung_conv = RUNGConv(**conv_kwargs)

    def reset_parameters(self):
        """Reset the MLP and then the propagation layer."""
        for module in (self.mlp, self.rung_conv):
            module.reset_parameters()

    def forward(self, x, edge_index, edge_weight=None):
        """Apply the MLP to ``x`` and propagate the result with RUNGConv."""
        hidden = self.mlp(x)
        return self.rung_conv(hidden, edge_index, edge_weight)
