import numpy as np
from einops import rearrange, repeat
import torch
from torch import nn, Tensor
import torch.nn.functional as F
import scipy
import math

#torch.set_default_dtype(torch.float16)

def _identity(x):
    """Return the input unchanged."""
    return x


def _square(x):
    """Element-wise square of ``x``."""
    return torch.pow(x, 2)


def _hermite2(x):
    """Second probabilists' Hermite polynomial ``(x**2 - 1)``, divided by ``sqrt(2)``.

    The ``1/sqrt(2)`` factor gives the polynomial unit second moment when ``x``
    is a standard normal variable.
    """
    return (_square(x) - 1) / math.sqrt(2)


# Activation functions selectable by name (e.g. via ``FCN.forward(act_fn=...)``).
ACT_FNS = dict(
    relu=F.relu,
    swish=F.silu,
    quadratic=_square,
    softplus=F.softplus,
    linear=_identity,
    hermite2=_hermite2,
)

class FCN(nn.Module):
    """Bias-free fully connected network ``inp_dim -> hidden_width -> ... -> out_dim``.

    The network is ``fc1`` (``inp_dim -> hidden_width``), then ``n_hid_layers - 1``
    hidden-to-hidden layers in ``self.layers``, then a linear readout ``out``
    (``hidden_width -> out_dim``). None of the layers has a bias. The activation
    is chosen per call in :meth:`forward`.

    Args:
        inp_dim: Size of the last dimension of the input.
        hidden_width: Width of every hidden layer.
        out_dim: Size of the last dimension of the output.
        n_hid_layers: Number of hidden (activated) layers, counting ``fc1``.
        init_scale: Multiplier for the weight-initialisation bound. When it is not
            ``1.0``, every weight matrix is re-drawn by :meth:`reset_params`.
        sigmoid_output: If True, the output is passed through a sigmoid.
    """

    def __init__(self, inp_dim, hidden_width, out_dim, n_hid_layers=1, init_scale=1.0,
                 sigmoid_output=False):
        super().__init__()

        self.inp_dim = inp_dim
        self.hidden_width = hidden_width
        self.out_dim = out_dim
        self.n_hid_layers = n_hid_layers
        self.sigmoid_output = sigmoid_output

        # Layers are created in input-to-output order so the RNG stream used by
        # the default initialisation is the same as building them one by one.
        self.fc1 = nn.Linear(inp_dim, hidden_width, bias=False)
        self.layers = nn.ModuleList(
            nn.Linear(hidden_width, hidden_width, bias=False)
            for _ in range(n_hid_layers - 1)
        )
        self.out = nn.Linear(hidden_width, out_dim, bias=False)

        # With init_scale == 1.0 the uniform scheme in reset_params matches
        # PyTorch's default nn.Linear weight init, so re-drawing would be redundant.
        if init_scale != 1.0:
            self.reset_params(init_scale=init_scale)

    def reset_params(self, init_scale=1.0, distribution="uniform"):
        """Re-draw every weight matrix with a fan-in-scaled initialisation.

        All layers are rescaled so that ``init_scale`` acts uniformly on the whole
        network: ``fc1``, every layer in ``self.layers``, and ``out``.

        Args:
            init_scale: Multiplier applied to the bound (uniform) or std (normal).
            distribution: ``"uniform"`` draws from
                ``U(-init_scale / sqrt(fan_in), init_scale / sqrt(fan_in))``.
                ``"normal"`` draws from a zero-mean normal with
                ``std = init_scale * sqrt(2) / sqrt(fan_in)`` (Kaiming normal
                with the ReLU gain).

        Raises:
            ValueError: If ``distribution`` is not ``"uniform"`` or ``"normal"``.
        """
        if distribution not in ("uniform", "normal"):
            raise ValueError(
                f"distribution must be 'uniform' or 'normal', got {distribution!r}"
            )
        for linear in (self.fc1, *self.layers, self.out):
            self._scaled_init_(linear.weight, init_scale, distribution)

    @staticmethod
    def _scaled_init_(weight, init_scale, distribution):
        """Fill ``weight`` in place with a fan-in-scaled uniform or normal draw."""
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight)
        # Guard against an empty (fan_in == 0) weight instead of dividing by zero.
        inv_sqrt_fan = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        if distribution == "uniform":
            bound = init_scale * inv_sqrt_fan
            nn.init.uniform_(weight, -bound, bound)
        else:
            gain = nn.init.calculate_gain("leaky_relu", 0.0)  # == sqrt(2), ReLU gain
            nn.init.normal_(weight, mean=0.0, std=init_scale * gain * inv_sqrt_fan)

    def forward(self, x, dumb1=None, act_fn='relu'):
        """Run the network on ``x``.

        Args:
            x: Input tensor whose last dimension is ``inp_dim``.
            dumb1: Optional tensor whose last dimension is ``inp_dim``. Its projection
                through ``fc1``'s weight is added to the first pre-activation.
                Because ``fc1`` has no bias, the first layer effectively sees
                ``x + dumb1`` (up to floating-point rounding).
            act_fn: Name of an entry in ``ACT_FNS``, or a callable, applied after
                ``fc1`` and after every layer in ``self.layers``.

        Returns:
            Tensor whose last dimension is ``out_dim``. It is passed through a
            sigmoid when ``sigmoid_output`` is set.

        Raises:
            KeyError: If ``act_fn`` is a string that is not a key of ``ACT_FNS``.
        """
        act = act_fn if callable(act_fn) else ACT_FNS[act_fn]

        pre = self.fc1(x)
        if dumb1 is not None:
            pre = pre + dumb1 @ self.fc1.weight.t()
        h = act(pre)
        for layer in self.layers:
            h = act(layer(h))
        y = self.out(h)
        return torch.sigmoid(y) if self.sigmoid_output else y
