import torch
from torch import nn as nn


class MLP(nn.Module):
    """Residual MLP with optional Fourier input features and output shaping.

    The network has three stages:

    * ``model1``: optional ``FourierFeatureLayer`` followed by ``Linear`` and
      ``SiLU``.
    * ``model2``: ``n_layers - 1`` blocks of ``Linear`` + ``Dropout`` +
      ``LeakyReLU`` (an empty stage when ``n_layers <= 1``).
    * ``model3``: a single ``Linear`` applied to ``model2(h) + h`` where ``h``
      is the output of ``model1`` (one skip connection around ``model2``).

    Args:
        in_dim: Size of the last dimension of the input.
        hid_dim: Width of the hidden layers.
        n_layers: Controls depth; ``n_layers - 1`` hidden blocks are created.
        out_dim: Size of the last dimension of the output.
        dropout: Dropout probability used inside the hidden blocks.
        center_init: If true, the mean over dim 0 of the raw output of the
            first forward pass is stored in the ``centering`` buffer and
            subtracted from every output (including the first one).
        mult_output: Scalar multiplier applied to the (centered) output. It is
            not applied when ``learnable_range`` is enabled.
        n_fourier_features: Number of Fourier features; required when
            ``fourier_sigma`` is given. The first ``Linear`` then receives
            ``2 * n_fourier_features`` inputs.
        fourier_sigma: Passed as ``scale`` to ``FourierFeatureLayer``. ``None``
            disables the Fourier feature layer.
        learnable_range: If true, the output is ``sigmoid(y)`` rescaled to a
            per-output interval ``[-softplus(ranges[:, 0]),
            softplus(ranges[:, 1])]`` with learnable ``ranges``.
        sigmoid_output: If true (and ``learnable_range`` is false), the output
            is ``sigmoid(mult_output * y)``.

    Raises:
        ValueError: If ``fourier_sigma`` is given without
            ``n_fourier_features``.
    """

    def __init__(
        self,
        in_dim: int,
        hid_dim: int,
        n_layers: int,
        out_dim: int,
        dropout: float = 0.1,
        center_init: bool = True,
        mult_output: float = 1.0,
        n_fourier_features: int = None,
        fourier_sigma: float = None,
        learnable_range: bool = False,
        sigmoid_output: bool = False,
    ):
        super(MLP, self).__init__()
        self.in_dim = in_dim
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.out_dim = out_dim
        self.dropout = dropout
        self.center_init = center_init
        self.learnable_range = learnable_range
        self.sigmoid_output = sigmoid_output
        self.mult_output = mult_output
        if center_init:
            # Both are buffers so they follow the module across devices and
            # are saved in the state dict.
            self.register_buffer("centering", torch.zeros(1, self.out_dim))
            self.register_buffer("uninitialized_buffer", torch.tensor(True))
        if learnable_range:
            # Column 0 parametrizes the lower bound, column 1 the upper bound.
            self.ranges = nn.Parameter(torch.ones(self.out_dim, 2))

        self.n_fourier_features = n_fourier_features
        self.fourier_sigma = fourier_sigma
        self._build_network()

    def _build_network(self):
        """Create the ``model1`` / ``model2`` / ``model3`` stages."""
        use_fourier = self.fourier_sigma is not None
        if use_fourier and self.n_fourier_features is None:
            raise ValueError("n_fourier_features must be set when fourier_sigma is given")

        # Input stage: optional Fourier embedding, then projection to hid_dim.
        input_stage = []
        first_in_dim = self.in_dim
        if use_fourier:
            input_stage.append(
                FourierFeatureLayer(self.in_dim, self.n_fourier_features, scale=self.fourier_sigma)
            )
            # The Fourier layer concatenates sin and cos features.
            first_in_dim = 2 * self.n_fourier_features
        input_stage.append(nn.Linear(first_in_dim, self.hid_dim))
        input_stage.append(nn.SiLU())
        self.model1 = nn.Sequential(*input_stage)

        # Hidden stage: n_layers - 1 blocks (none when n_layers <= 1).
        hidden_stage = []
        for _ in range(self.n_layers - 1):
            hidden_stage.append(nn.Linear(self.hid_dim, self.hid_dim))
            hidden_stage.append(nn.Dropout(p=self.dropout))
            hidden_stage.append(nn.LeakyReLU())
        self.model2 = nn.Sequential(*hidden_stage)

        # Output stage.
        self.model3 = nn.Sequential(nn.Linear(self.hid_dim, self.out_dim))

    def forward(self, x, context=None):
        """Run the network on ``x``.

        Args:
            x: Input tensor whose last dimension is ``in_dim``.
            context: Accepted for interface compatibility; not used.

        Returns:
            The post-processed output tensor (see the class docstring for how
            ``center_init``, ``mult_output``, ``learnable_range`` and
            ``sigmoid_output`` affect it).
        """
        x_skip = self.model1(x)
        y = self.model3(self.model2(x_skip) + x_skip)

        if self.center_init:
            if self.uninitialized_buffer.item():
                # First call: remember the batch mean of the raw output.
                self.centering = y.mean(0, keepdim=True).detach()
                # Flip the flag in place so the buffer keeps its device
                # instead of being replaced by a new CPU tensor.
                self.uninitialized_buffer.fill_(False)
            y = y - self.centering

        if self.learnable_range:
            # softplus keeps the upper bound positive and the lower bound
            # negative; mult_output is intentionally not applied here.
            right_range = nn.functional.softplus(self.ranges[:, 1])
            left_range = -nn.functional.softplus(self.ranges[:, 0])
            return torch.sigmoid(y) * (right_range - left_range) + left_range

        if self.sigmoid_output:
            return torch.sigmoid(self.mult_output * y)

        return self.mult_output * y


class MLPSkip(nn.Module):
    """MLP whose hidden layers are each wrapped in an additive skip connection.

    Args:
        input_dim: Number of input features.
        hid_dim: Width of every hidden layer.
        output_dim: Number of output features.
        n_layers: Number of residual hidden layers; must be at least 1.
        activation: Activation module applied after the input layer and after
            every hidden layer.
    """

    def __init__(self, input_dim, hid_dim, output_dim, n_layers, activation=nn.ReLU()):
        super().__init__()

        assert n_layers >= 1, "Need at least one hidden layer"
        assert isinstance(
            activation, nn.Module
        ), "activation must be a PyTorch module (e.g., nn.ReLU())"

        self.activation = activation
        self.input_layer = nn.Linear(input_dim, hid_dim)

        residual_linears = []
        for _ in range(n_layers):
            residual_linears.append(nn.Linear(hid_dim, hid_dim))
        self.hidden_layers = nn.ModuleList(residual_linears)

        self.output_layer = nn.Linear(hid_dim, output_dim)

    def forward(self, x, context=None):
        """Apply the network to ``x``; ``context`` is accepted but not used."""
        hidden = self.input_layer(x)
        hidden = self.activation(hidden)

        # Each hidden layer adds its activated output onto its own input.
        for linear in self.hidden_layers:
            hidden = self.activation(linear(hidden)) + hidden

        return self.output_layer(hidden)


class Swish(nn.Module):
    """Swish activation: the input multiplied by its own sigmoid.

    Code taken from https://github.com/facebookresearch/flow_matching.git
    """

    def forward(self, x):
        """Return ``sigmoid(x) * x`` element-wise."""
        gate = torch.sigmoid(x)
        return gate * x

class MLPSwish(nn.Module):
    """Time-conditioned MLP with Swish activations.

    Code taken from https://github.com/facebookresearch/flow_matching.git

    The input ``x`` is flattened to rows of ``input_dim`` features, the time
    ``t`` is appended to every row, and the result is passed through
    ``n_layers`` linear layers (Swish after every layer except a final
    ``hidden_dim -> input_dim`` layer). The output is reshaped back to the
    shape of ``x``.

    NOTE: with ``n_layers == 1`` the stack consists of a single
    ``Linear(input_dim + time_dim, hidden_dim)`` followed by Swish, so the
    final reshape in ``forward`` only succeeds when ``hidden_dim == input_dim``.

    Args:
        input_dim: Number of features per row of ``x`` (and of the output).
        time_dim: Number of time features appended to every row.
        n_layers: Number of linear layers in the stack.
        hidden_dim: Width of the hidden layers.
    """

    def __init__(
        self, input_dim: int = 2, time_dim: int = 1, n_layers: int = 5, hidden_dim: int = 128
    ):
        super().__init__()

        self.input_dim = input_dim
        self.time_dim = time_dim
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim

        layers = []
        for i in range(self.n_layers):
            if i == 0:
                # First layer consumes the concatenated [x, t] features.
                layers.append(nn.Linear(input_dim + time_dim, hidden_dim))
                layers.append(Swish())
            elif i < self.n_layers - 1:
                layers.append(nn.Linear(hidden_dim, hidden_dim))
                layers.append(Swish())
            else:
                # Last layer projects back to the input dimensionality.
                layers.append(nn.Linear(hidden_dim, input_dim))
        self.main = nn.Sequential(*layers)

    def forward(self, x, t):
        """Evaluate the network at inputs ``x`` and times ``t``.

        Args:
            x: Tensor whose elements can be viewed as rows of ``input_dim``
                features.
            t: Tensor viewable as ``(-1, time_dim)``; it must either have one
                row (broadcast to every row of ``x``) or as many rows as the
                flattened ``x``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        sz = x.size()
        x = x.reshape(-1, self.input_dim)
        t = t.reshape(-1, self.time_dim).float()

        # Broadcast a single time row to every input row. Expanding to
        # time_dim columns (rather than a hard-coded 1) keeps the
        # concatenated width equal to what the first Linear layer expects.
        t = t.expand(x.shape[0], self.time_dim)
        h = torch.cat([x, t], dim=1)
        output = self.main(h)

        return output.reshape(*sz)


class FourierFeatureLayer(nn.Module):
    """Random Fourier feature embedding with a fixed projection matrix.

    A matrix ``B`` of shape ``(input_dim, output_dim)`` is drawn from a
    standard normal distribution, multiplied by ``scale``, and stored as a
    non-trainable parameter.
    """

    def __init__(self, input_dim, output_dim, scale=10.0):
        super().__init__()
        random_projection = torch.randn(input_dim, output_dim) * scale
        self.B = nn.Parameter(random_projection, requires_grad=False)

    def forward(self, x):
        """Return ``[sin(x @ B), cos(x @ B)]`` concatenated on the last dim.

        The projection is used as-is (no ``2 * pi`` factor), so the last
        dimension of the result is ``2 * output_dim``.
        """
        projected = torch.matmul(x, self.B)
        features = (projected.sin(), projected.cos())
        return torch.cat(features, dim=-1)
