from torch import nn
import torch

class WMLP(nn.Module):
    """SELU MLP with three hidden layers, optionally conditioned on time.

    Args:
        dim: number of input features in ``x``.
        out_dim: number of output features; defaults to ``dim``.
        w: width of every hidden layer.
        time_varying: when True, ``t`` is appended to ``x`` as one extra
            input feature; when False, ``t`` is ignored.
    """

    def __init__(self, dim, out_dim=None, w=32, time_varying=False):
        super().__init__()
        self.time_varying = time_varying
        if out_dim is None:
            out_dim = dim
        in_dim = dim + (1 if time_varying else 0)
        self.net = torch.nn.Sequential(
            torch.nn.Linear(in_dim, w),
            torch.nn.SELU(),
            torch.nn.Linear(w, w),
            torch.nn.SELU(),
            torch.nn.Linear(w, w),
            torch.nn.SELU(),
            torch.nn.Linear(w, out_dim),
        )

    def forward(self, t, x, *args, **kwargs):
        """Evaluate the network.

        Args:
            t: time; a tensor viewable as ``(-1, 1)`` and expandable to
                ``(x.size(0), 1)`` (e.g. a scalar or one value per row).
                Only used when ``time_varying`` is True.
            x: 2-D input of shape ``(batch, dim)``.

        Returns:
            Tensor of shape ``(batch, out_dim)``.
        """
        if not self.time_varying:
            # The first layer was sized for ``dim`` inputs only; appending t
            # here would produce a shape mismatch.
            return self.net(x)
        # Match x's device *and* dtype so e.g. a float64 time value does not
        # promote the concatenated input away from the weights' dtype.
        t_ = t.view(-1, 1).expand(x.size(0), 1).to(device=x.device, dtype=x.dtype)
        return self.net(torch.cat([x, t_], dim=1))



class IPMLP(nn.Module):
    """Vector field defined as the x-gradient of an inner-product energy.

    A SELU MLP maps ``(x, t)`` to ``s`` with the same width as ``x``; the
    per-sample energy is ``<x, s(x, t)>`` and ``forward`` returns its gradient
    with respect to ``x``.

    Args:
        dim: number of features in ``x`` (and in the network output).
        w: width of every hidden layer.
    """

    def __init__(self, dim, w=32):
        super().__init__()
        self.NN = nn.Sequential(
            torch.nn.Linear(dim + 1, w),
            torch.nn.SELU(),
            torch.nn.Linear(w, w),
            torch.nn.SELU(),
            torch.nn.Linear(w, w),
            torch.nn.SELU(),
            torch.nn.Linear(w, dim),
        )

    def forward(self, t, x, *args, **kwargs):
        """Return d/dx of the per-sample energy ``<x, s(x, t)>``.

        Args:
            t: time; a tensor viewable as ``(-1, 1)`` and expandable to
                ``(x.size(0), 1)``.
            x: 2-D input of shape ``(batch, dim)``.

        Returns:
            Tensor of shape ``(batch, dim)``. It is built with
            ``create_graph=True`` so it can itself be differentiated.
        """
        # Gradients are needed even when the caller runs under torch.no_grad()
        # (e.g. during evaluation/integration); without this the network output
        # has no graph and autograd.grad raises.
        with torch.enable_grad():
            # If x is already part of an autograd graph, use it directly so
            # gradients keep flowing to whatever produced it. Otherwise take a
            # detached alias and flag *that*, leaving the caller's tensor's
            # requires_grad flag untouched.
            x_in = x if x.requires_grad else x.detach().requires_grad_(True)
            t_ = t.view(-1, 1).expand(x_in.size(0), 1).to(
                device=x_in.device, dtype=x_in.dtype
            )
            s = self.NN(torch.cat([x_in, t_], dim=1))
            # Row-wise inner product -> (batch, 1). The previous ``x.T @ s``
            # produced a (dim, dim) matrix and summed it to (dim, 1).
            energy = (x_in * s).sum(dim=1, keepdim=True)
            grad_x = torch.autograd.grad(
                outputs=energy,
                inputs=x_in,
                grad_outputs=torch.ones_like(energy),
                create_graph=True,
                retain_graph=True,
            )[0]
        return grad_x

class GMLP(nn.Module):
    """Vector field defined as the x-gradient of the summed network output.

    A SELU MLP maps ``(x, t)`` to ``s`` with the same width as ``x``; the
    per-sample energy is ``sum_j s_j`` and ``forward`` returns its gradient
    with respect to ``x``.

    Args:
        dim: number of features in ``x`` (and in the network output).
        w: width of every hidden layer.
    """

    def __init__(self, dim, w=64):
        super().__init__()
        self.NN = nn.Sequential(
            nn.Linear(dim + 1, w),
            nn.SELU(),
            nn.Linear(w, w),
            nn.SELU(),
            nn.Linear(w, w),
            nn.SELU(),
            nn.Linear(w, dim),
        )

    def forward(self, t, x, *args, **kwargs):
        """Return d/dx of ``s(x, t).sum(dim=1)``.

        Args:
            t: time; a tensor viewable as ``(-1, 1)`` and expandable to
                ``(x.size(0), 1)``.
            x: 2-D input of shape ``(batch, dim)``.

        Returns:
            Tensor of shape ``(batch, dim)``. It is built with
            ``create_graph=True``, so it stays differentiable w.r.t. the
            network parameters. Because ``x`` is detached below, no gradient
            flows back into whatever produced ``x``.
        """
        # Needed so this also works when the caller is under torch.no_grad();
        # otherwise the network output has no graph and autograd.grad raises.
        with torch.enable_grad():
            # Independent leaf copy of x that we are free to mark as
            # requiring grad without touching the caller's tensor.
            x_leaf = x.detach().clone().requires_grad_(True)
            t_ = t.view(-1, 1).expand(x_leaf.size(0), 1).to(
                device=x_leaf.device, dtype=x_leaf.dtype
            )
            s = self.NN(torch.cat([x_leaf, t_], dim=1))
            energy = s.sum(dim=1, keepdim=True)  # (batch, 1)
            grad_x = torch.autograd.grad(
                outputs=energy,
                inputs=x_leaf,
                grad_outputs=torch.ones_like(energy),
                create_graph=True,
                retain_graph=True,
            )[0]
        return grad_x

class EMLP(nn.Module):
    """Vector field defined as the x-gradient of a squared-residual energy.

    A SELU MLP maps ``(x, t)`` to ``s`` with the same width as ``x``; the
    per-sample energy is ``||x - s(x, t)||^2`` and ``forward`` returns its
    gradient with respect to ``x``.

    Args:
        dim: number of features in ``x`` (and in the network output).
        w: width of every hidden layer.
    """

    def __init__(self, dim, w=64):
        super().__init__()
        self.NN = nn.Sequential(
            nn.Linear(dim + 1, w),
            nn.SELU(),
            nn.Linear(w, w),
            nn.SELU(),
            nn.Linear(w, w),
            nn.SELU(),
            nn.Linear(w, dim),
        )

    def forward(self, t, x, *args, **kwargs):
        """Return d/dx of ``((x - s(x, t)) ** 2).sum(dim=1)``.

        Args:
            t: time; a tensor viewable as ``(-1, 1)`` and expandable to
                ``(x.size(0), 1)``. It is expected to already be on ``x``'s
                device (no device transfer happens here).
            x: 2-D input of shape ``(batch, dim)``.

        Returns:
            Tensor of shape ``(batch, dim)``. It is built with
            ``create_graph=True``. Because ``x`` is detached, no gradient
            flows back into whatever produced ``x``.
        """
        # IMPORTANT: deliberately no cross-device transfers in the forward
        # pass; inputs are assumed to already live on the right device.
        # Enable grad so this also works when the caller is under
        # torch.no_grad(); otherwise autograd.grad raises.
        with torch.enable_grad():
            x_with_grad = x.detach().requires_grad_(True)
            # Only the dtype is aligned (no device move): a float64 time
            # value would otherwise promote the concatenated input to float64
            # and clash with float32 weights.
            t_shaped = t.view(-1, 1).expand(x.size(0), 1).to(dtype=x_with_grad.dtype)
            s = self.NN(torch.cat([x_with_grad, t_shaped], dim=1))
            energy = ((x_with_grad - s) ** 2).sum(dim=1, keepdim=True)  # (batch, 1)
            grad_x = torch.autograd.grad(
                outputs=energy,
                inputs=x_with_grad,
                grad_outputs=torch.ones_like(energy),
                create_graph=True,
                retain_graph=True,
            )[0]
        return grad_x
