import sys
sys.path.append("/here/is/code/M2F-PINN")
import torch
import torch.nn as nn
import math


class TrainableFourierFeatureEmbedding(nn.Module):
    """Fourier feature embedding with a learnable projection matrix and scale.

    Computes ``x_proj = x @ (B * scale).T`` and returns
    ``cat([sin(2*pi*x_proj), cos(2*pi*x_proj)], dim=-1)``, so an input of
    shape ``(..., input_dims)`` becomes ``(..., 2 * mapping_size)``.
    ``B`` (initialised from a standard normal) and the one-element ``scale``
    are both ``nn.Parameter``s and are trained with the rest of the model.

    Args:
        input_dims: Size of the last dimension of the input.
        mapping_size: Number of projection directions (rows of ``B``).
        initial_scale: Starting value of the learnable ``scale``.
        device: Initial device for the parameters; defaults to CUDA when
            available, otherwise CPU.
    """

    def __init__(self, input_dims, mapping_size, initial_scale=10.0, device=None):
        super().__init__()
        self.input_dims = input_dims
        self.mapping_size = mapping_size
        # Initial placement only. forward() follows the parameters' actual
        # device, so a later ``module.to(...)`` is respected.
        self.device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Drawn on the CPU generator first, then moved (keeps seeding behaviour).
        b_init = torch.randn((mapping_size, input_dims))
        self.B = nn.Parameter(b_init.to(self.device))

        self.scale = nn.Parameter(torch.tensor([float(initial_scale)], device=self.device))

    def forward(self, x):
        """Embed ``x`` of shape ``(..., input_dims)`` into ``(..., 2 * mapping_size)``."""
        B = self.B
        # Match the parameters' current device/dtype rather than the device
        # recorded at construction and a hard-coded float32. This keeps
        # ``.to(...)`` / ``.double()`` working.
        x = x.to(device=B.device, dtype=B.dtype)
        effective_B = B * self.scale
        x_proj = x @ effective_B.T
        angles = 2. * math.pi * x_proj
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    @property
    def output_dims(self):
        """Width of the embedding: one sine and one cosine per projection."""
        return self.mapping_size * 2


class FourierFeatureEmbedding(nn.Module):
    """Fixed (non-trainable) random Fourier feature embedding.

    Maps ``x`` of shape ``(..., input_dims)`` to
    ``cat([sin(2*pi*x@B.T), cos(2*pi*x@B.T)], dim=-1)`` of shape
    ``(..., 2 * mapping_size)``. ``B`` is a registered buffer: it is saved in
    the ``state_dict`` and moved by ``module.to(...)``/``.double()``, but it is
    not a trainable parameter.

    Args:
        input_dims: Size of the last dimension of the input.
        mapping_size: Number of projection directions (rows of ``B``).
        scale: Multiplier for the randomly drawn matrix
            (``B = randn(mapping_size, input_dims) * scale``). It is NOT applied
            to an explicitly provided ``B_matrix``.
        B_matrix: Optional explicit projection matrix of shape
            ``(mapping_size, input_dims)``. Anything ``torch.as_tensor`` accepts
            (tensor, nested list, numpy array). Non-floating-point matrices are
            converted to the default floating dtype.
        device: Initial device for ``B``; defaults to CUDA when available,
            otherwise CPU.

    Raises:
        ValueError: If ``B_matrix`` does not have shape
            ``(mapping_size, input_dims)``.
    """

    def __init__(self, input_dims, mapping_size, scale=10.0, B_matrix=None, device=None):
        super().__init__()
        self.input_dims = input_dims
        self.mapping_size = mapping_size
        self.scale = scale
        # Initial placement only. forward() follows the buffer's actual device,
        # so a later ``module.to(...)`` is respected.
        self.device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if B_matrix is not None:
            B = torch.as_tensor(B_matrix)
            expected = (mapping_size, input_dims)
            if tuple(B.shape) != expected:
                raise ValueError(f"Provided B_matrix has shape {B.shape}, expected {expected}")
            if not B.is_floating_point():
                # An integer B would make the float matmul in forward() fail.
                B = B.to(torch.get_default_dtype())
        else:
            # Entries ~ N(0, scale^2); B is (mapping_size, input_dims).
            B = torch.randn((mapping_size, input_dims)) * scale
        self.register_buffer('B', B.to(self.device))

    def forward(self, x):
        """Embed ``x`` of shape ``(..., input_dims)`` into ``(..., 2 * mapping_size)``."""
        B = self.B
        if B is None:
            # Only reachable if the buffer was reassigned to None after
            # construction; the input is then passed through unchanged.
            return x.to(self.device)

        # Cast to the buffer's current device and dtype (not a fixed float32),
        # so float64 matrices and ``.double()`` work.
        x = x.to(device=B.device, dtype=B.dtype)
        x_proj = x @ B.T  # (..., mapping_size)
        angles = 2. * math.pi * x_proj
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    @property
    def output_dims(self):
        """Embedding width, or ``input_dims`` when the projection is disabled."""
        return self.mapping_size * 2 if self.B is not None else self.input_dims


class FeedForward(nn.Module):
    """Plain MLP: ``Linear -> activation`` for each hidden width, then a final ``Linear``.

    Args:
        input_dimensions: Size of the last dimension of the input.
        output_dimensions: Size of the last dimension of the output.
        layers_config: Iterable of hidden-layer widths, in order. If it is
            empty, the network is a single ``Linear`` layer.
        device: If given, all layers are moved to this device. Otherwise they
            are left where ``nn.Linear`` created them.
        activation: Zero-argument callable returning an ``nn.Module``. It is
            called once per hidden layer; defaults to ``nn.GELU`` (e.g. pass
            ``nn.SiLU`` or ``nn.Tanh`` to swap it).
    """

    def __init__(self, input_dimensions, output_dimensions, layers_config, device=None, activation=nn.GELU):
        super().__init__()
        self.layers = nn.ModuleList()
        current_dim = input_dimensions
        for hidden_dim in layers_config:
            self.layers.append(nn.Linear(current_dim, hidden_dim))
            # A fresh instance per layer, so stateful activations (e.g. PReLU)
            # are not shared between layers.
            self.layers.append(activation())
            current_dim = hidden_dim
        self.layers.append(nn.Linear(current_dim, output_dimensions))
        if device is not None:
            self.to(device)

    def forward(self, x):
        """Apply the layers in order and return the final ``Linear`` output."""
        for layer in self.layers:
            x = layer(x)
        return x