import numpy as np
import torch
from torch import nn


class SineLayer(nn.Module):
    """Linear map followed by ``sin(omega_0 * x)``.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        bias: Whether the linear map has a bias term.
        is_first: Selects the weight range used by ``init_weights``.
        omega_0: Factor multiplying the linear output before the sine.
    """

    def __init__(
        self, in_features, out_features, bias=True, is_first=False, omega_0=30
    ):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.in_features = in_features

        # The linear layer is created before init_weights() so its default
        # initialisation is overwritten (weights only; the bias is kept).
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.init_weights()

    def init_weights(self):
        """Re-draw the linear weights from a symmetric uniform range.

        The range is ``+-1 / in_features`` for a first layer and
        ``+-sqrt(6 / in_features) / omega_0`` otherwise.
        """
        if self.is_first:
            bound = 1 / self.in_features
        else:
            bound = np.sqrt(6 / self.in_features) / self.omega_0

        with torch.no_grad():
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, input):
        """Return ``sin(omega_0 * linear(input))``."""
        return torch.sin(self.omega_0 * self.linear(input))

    def forward_with_intermediate(self, input):
        """Return ``(activation, pre_activation)`` for ``input``.

        ``pre_activation`` is ``omega_0 * linear(input)`` and ``activation``
        is its sine.
        """
        pre_activation = self.omega_0 * self.linear(input)
        return torch.sin(pre_activation), pre_activation


class Siren(nn.Module):
    """Stack of ``SineLayer`` modules with an optional plain linear output.

    The network consists of one first ``SineLayer``, ``hidden_layers``
    further ``SineLayer`` modules of width ``hidden_features``, and a final
    layer mapping to ``out_features``.

    Args:
        in_features: Size of the last dimension of the input coordinates.
        hidden_features: Width of every hidden layer.
        hidden_layers: Number of hidden ``SineLayer`` modules after the first.
        out_features: Size of the last dimension of the output.
        outermost_linear: If true the final layer is an ``nn.Linear`` (no
            sine); otherwise it is another ``SineLayer``.
        first_omega_0: ``omega_0`` of the first layer.
        hidden_omega_0: ``omega_0`` of all other layers; also used for the
            weight range of the final linear layer.
    """

    def __init__(
        self,
        in_features,
        hidden_features,
        hidden_layers,
        out_features,
        outermost_linear=False,
        first_omega_0=30.0,
        hidden_omega_0=30.0,
    ):
        super().__init__()

        # Layers are created strictly in network order (first, hidden, last).
        layers = [
            SineLayer(
                in_features, hidden_features, is_first=True, omega_0=first_omega_0
            )
        ]
        layers.extend(
            SineLayer(
                hidden_features,
                hidden_features,
                is_first=False,
                omega_0=hidden_omega_0,
            )
            for _ in range(hidden_layers)
        )

        if outermost_linear:
            head = nn.Linear(hidden_features, out_features)
            # Same weight range a non-first SineLayer uses.
            bound = np.sqrt(6 / hidden_features) / hidden_omega_0
            with torch.no_grad():
                head.weight.uniform_(-bound, bound)
        else:
            head = SineLayer(
                hidden_features,
                out_features,
                is_first=False,
                omega_0=hidden_omega_0,
            )
        layers.append(head)

        self.net = nn.Sequential(*layers)

    def forward(self, coords):
        """Evaluate the network on ``coords``.

        Returns:
            ``(output, coords)`` where the returned ``coords`` is a detached
            copy of the argument with ``requires_grad`` enabled, and
            ``output`` was computed from that copy (presumably so callers
            can differentiate the output with respect to the coordinates).
        """
        tracked = coords.clone().detach().requires_grad_(True)
        return self.net(tracked), tracked


class ParaLayer(nn.Module):
    def __init__(
        self, in_features, out_features, nf, bias=True, is_first=False, omega_0=30
    ):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first

        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.nf = nf

        if nf > 5:

            ws = torch.arange(15, 15 + 25).float()
            self.ws = nn.Parameter(ws, requires_grad=True)
            self.phis = nn.Parameter(requires_grad=True)
            self.bs = nn.Parameter(requires_grad=True)
            self.init_weights()

        else:
            self.linear = nn.Linear(in_features, out_features, bias=bias)
            self.ws = nn.Parameter(torch.ones(nf), requires_grad=True)
            self.bs = nn.Parameter(torch.ones(nf), requires_grad=True)
            self.phis = nn.Parameter(torch.zeros(nf), requires_grad=True)

            self.siren_init_weights()

    def siren_init_weights(self):
        with torch.no_grad():
            if self.is_first:
                self.linear.weight.uniform_(-1 / self.in_features, 1 / self.in_features)
            else:
                self.linear.weight.uniform_(
                    -np.sqrt(6 / self.in_features) / self.omega_0,
                    np.sqrt(6 / self.in_features) / self.omega_0,
                )

    def init_weights(self):
        with torch.no_grad():
            uniform_samples = torch.rand(self.nf)

            # Scale and shift the samples to the range [-π, π]
            lower_bound = -torch.tensor([3.14159265358979323846])  # -π
            upper_bound = torch.tensor([3.14159265358979323846])  # π
            scaled_samples = lower_bound + (upper_bound - lower_bound) * uniform_samples

            self.phis = nn.Parameter(scaled_samples, requires_grad=True)

            # Mean and diversity for Laplace random variable Y
            mean_y = 0
            diversity_y = 2 / (4 * self.nf)
            # Generate Laplace random variable Y
            laplace_samples = torch.distributions.laplace.Laplace(
                mean_y, diversity_y
            ).sample((self.nf,))

            # Compute C from Y
            c_samples = torch.sign(laplace_samples) * torch.sqrt(
                torch.abs(laplace_samples)
            )
            self.bs = nn.Parameter(c_samples, requires_grad=True)

    def forward(self, input):
        temp = self.linear(input)
        return self.param_act(temp)

    def param_act(self, linout):
        ws, bs, phis = (self.ws, self.bs, self.phis)
        linoutx = linout.unsqueeze(-1).repeat_interleave(ws.shape[0], dim=3)
        wsx = ws.expand(linout.shape[0], linout.shape[1], linout.shape[2], -1)
        bsx = bs.expand(linout.shape[0], linout.shape[1], linout.shape[2], -1)
        phisx = phis.expand(linout.shape[0], linout.shape[1], linout.shape[2], -1)
        temp = bsx * (torch.sin((wsx * linoutx) + phisx))
        temp2 = torch.sum(temp, 3)
        return temp2


class STAFNet(nn.Module):
    """Stack of ``ParaLayer`` modules with an optional plain linear output.

    The network consists of one first ``ParaLayer``, ``hidden_layers``
    further ``ParaLayer`` modules of width ``hidden_features``, and a final
    layer mapping to ``out_features``.

    Args:
        in_features: Size of the last dimension of the input coordinates.
        hidden_features: Width of every hidden layer.
        hidden_layers: Number of hidden ``ParaLayer`` modules after the first.
        out_features: Size of the last dimension of the output.
        nf: Number of sinusoidal terms passed to every ``ParaLayer``.
        outermost_linear: If true the final layer is a default-initialised
            ``nn.Linear``; otherwise it is another ``ParaLayer``.
        first_omega_0: ``omega_0`` passed to the first layer.
        hidden_omega_0: ``omega_0`` passed to all other ``ParaLayer`` modules.
    """

    def __init__(
        self,
        in_features,
        hidden_features,
        hidden_layers,
        out_features,
        nf,
        outermost_linear=False,
        first_omega_0=30.0,
        hidden_omega_0=30.0,
    ):
        super().__init__()

        # Layers are created strictly in network order (first, hidden, last).
        layers = [
            ParaLayer(
                in_features,
                hidden_features,
                is_first=True,
                nf=nf,
                omega_0=first_omega_0,
            )
        ]
        layers.extend(
            ParaLayer(
                hidden_features,
                hidden_features,
                nf=nf,
                is_first=False,
                omega_0=hidden_omega_0,
            )
            for _ in range(hidden_layers)
        )

        if outermost_linear:
            head = nn.Linear(hidden_features, out_features)
        else:
            head = ParaLayer(
                hidden_features,
                out_features,
                nf=nf,
                is_first=False,
                omega_0=hidden_omega_0,
            )
        layers.append(head)

        self.net = nn.Sequential(*layers)

    def forward(self, coords):
        """Evaluate the network on ``coords``.

        Returns:
            ``(output, coords)`` where the returned ``coords`` is a detached
            copy of the argument with ``requires_grad`` enabled, and
            ``output`` was computed from that copy (presumably so callers
            can differentiate the output with respect to the coordinates).
        """
        tracked = coords.clone().detach().requires_grad_(True)
        return self.net(tracked), tracked
