import torch

class SDE(torch.nn.Module):
    """
    Linear two-dimensional SDE exposing a drift ``f`` and a diffusion ``g``.

    For a batch of states ``y`` of shape (batch_size, 2) the coefficients are

        f(t, y)[b, i]    = sum_j Mu[i][j] * y[b, j]
        g(t, y)[b, i, j] = Sigma[i][j] * y[b, j]

    Neither coefficient depends on ``t``. Both are evaluated with batched
    tensor operations, and the results live on the device of ``y`` and use
    its dtype (the default float dtype when ``y`` is not floating point).

    ``noise_type``, ``sde_type`` and ``state_size`` are plain metadata
    attributes. NOTE(review): the names match what torchsde-style solvers
    read from an SDE object; confirm against the code that consumes this
    module.
    """

    noise_type = 'general'
    sde_type = 'ito'

    def __init__(self, Sigma, Mu):
        """
        Initializes the SDE module.

        Args:
        - Sigma: 2x2 diffusion coefficients, indexable as ``Sigma[i][j]``
          (nested list or tensor).
        - Mu: 2x2 drift coefficients, indexable as ``Mu[i][j]``
          (nested list or tensor).
        """
        super().__init__()
        # Stored exactly as given; conversion to a tensor matching the state
        # happens at evaluation time so the caller's objects are not replaced.
        self.Sigma = Sigma
        self.Mu = Mu
        self.state_size = 2

    @staticmethod
    def _coefficients(matrix, y):
        """Return ``matrix`` as a tensor on ``y``'s device, ready to combine with ``y``."""
        # An integer state cannot carry fractional coefficients, so fall back
        # to the default float dtype in that case.
        dtype = y.dtype if y.is_floating_point() else torch.get_default_dtype()
        # as_tensor returns an already-matching tensor unchanged and converts
        # differentiably otherwise, so autograd history of a tensor-valued
        # coefficient is kept (torch.tensor() would copy and detach it).
        return torch.as_tensor(matrix, dtype=dtype, device=y.device)

    def f(self, t, y: torch.Tensor) -> torch.Tensor:
        """
        Drift function.

        Args:
        - t (float): The current time (unused; the drift is time-independent).
        - y (torch.Tensor): The state tensor of shape (batch_size, state_size).

        Returns:
        - torch.Tensor: The drift tensor of shape (batch_size, state_size),
          with ``out[b, i] = sum_j Mu[i][j] * y[b, j]``.
        """
        mu = self._coefficients(self.Mu, y)
        # Row-wise matrix-vector product; the cast is a no-op for float states.
        return y.to(mu.dtype) @ mu.t()

    def g(self, t, y: torch.Tensor) -> torch.Tensor:
        """
        Diffusion function.

        Args:
        - t (float): The current time (unused; the diffusion is time-independent).
        - y (torch.Tensor): The state tensor of shape (batch_size, state_size).

        Returns:
        - torch.Tensor: The diffusion tensor of shape
          (batch_size, state_size, brownian_size), with
          ``out[b, i, j] = Sigma[i][j] * y[b, j]``.
        """
        sigma = self._coefficients(self.Sigma, y)
        # (1, rows, cols) * (batch, 1, cols): column j is scaled by y[:, j].
        return sigma.unsqueeze(0) * y.unsqueeze(1)

# Example constants (for reference only; the caller must define its own values
# and pass the coefficient matrices to the class, e.g. SDE(Sigma, Mu)):
# mu1 = 0.5
# mu2 = -0.5
# Sigma = [[0.1, 0.2], [0.3, 0.4]]
# brownian_size = 2
