import torch
import torchsde
import torch.nn as nn
from constants import *
import numpy as np



class PerturbSDE(torch.nn.Module):
    """Itô SDE with a sigmoid-gated latent drift and constant diffusion.

    Drift ``f`` for a state ``y`` (rows are batch elements, columns are the
    ``data_dim`` state dimensions)::

        gate  = sigmoid(sigmoid_scale * (y @ B_0 - softplus(K)))   # (batch, latent_dim)
        drift = (gate @ B_1) * intervention - y * (relu(W) + 0.01) + f_

    Diffusion ``g`` is the constant ``noise_scale * 0.1`` in every entry of a
    ``(batch, data_dim, brownian_size)`` tensor, the layout torchsde expects
    for ``noise_type = 'general'``.

    Tensors are created on the module-level ``device`` (imported from
    ``constants`` in this file).

    Args:
        batch_size: Nominal batch size. Stored for callers; ``g`` uses the
            actual batch size of ``y``.
        data_dim: Number of state dimensions.
        brownian_size: Number of independent Brownian motions.
        latent_dim: Number of latent gating units.
    """

    noise_type = 'general'
    sde_type = 'ito'

    def __init__(self, batch_size, data_dim, brownian_size=2, latent_dim=200):
        super().__init__()
        # Drawn from numpy's global RNG (N(0, 0.001)), so np.random.seed controls them.
        self.B_1 = nn.Parameter(torch.tensor(
            np.random.normal(loc=0, scale=0.001, size=(latent_dim, data_dim)),
            dtype=torch.float32, device=device))
        self.B_0 = nn.Parameter(torch.tensor(
            np.random.normal(loc=0, scale=0.001, size=(data_dim, latent_dim)),
            dtype=torch.float32, device=device))
        self.W = nn.Parameter(torch.tensor(
            np.random.normal(loc=0, scale=0.001, size=(1, data_dim)),
            dtype=torch.float32, device=device))

        # Constant drift offset and per-dimension multiplicative mask on the latent
        # contribution (all-ones by default; presumably overwritten by callers to
        # model perturbations -- confirm). Registered as non-persistent buffers so
        # they follow .to()/.cuda() without changing the state_dict layout.
        self.register_buffer(
            'f_', torch.zeros(data_dim, dtype=torch.float32, device=device), persistent=False)
        self.register_buffer(
            'intervention', torch.ones(data_dim, dtype=torch.float32, device=device),
            persistent=False)

        self.softplus = nn.Softplus()
        self.sigmoid = nn.Sigmoid()
        self.K = nn.Parameter(torch.randn(1, latent_dim, device=device, dtype=torch.float32) * 0.001)
        self.sigmoid_scale = nn.Parameter(torch.randn(1, latent_dim, device=device, dtype=torch.float32) * 0.001)

        self.batch_size = batch_size
        self.data_dim = data_dim
        self.brownian_size = brownian_size

        self.noise_scale = nn.Parameter(torch.tensor(0.1, device=device))

    def f(self, t, y):
        """Drift term.

        Args:
            t: Time (unused; the drift does not depend on it).
            y: State tensor whose last dimension is ``data_dim``.

        Returns:
            Tensor with the same shape as ``y``.
        """
        # Strictly positive per-dimension decay rate.
        decay = torch.relu(self.W) + 0.01            # (1, data_dim)
        threshold = self.softplus(self.K)            # (1, latent_dim), non-negative
        gate = self.sigmoid(self.sigmoid_scale * (y @ self.B_0 - threshold))
        # Elementwise scaling equals right-multiplying by diagflat(...) but avoids
        # materialising (data_dim, data_dim) matrices on every drift evaluation.
        return (gate @ self.B_1) * self.intervention - y * decay + self.f_

    def g(self, t, y):
        """Diffusion term: ``noise_scale * 0.1`` of shape (batch, data_dim, brownian_size)."""
        # Batch size is taken from y (not self.batch_size) so partial batches work;
        # new_full also allocates directly on y's device with y's dtype.
        return self.noise_scale * y.new_full(
            (y.shape[0], self.data_dim, self.brownian_size), 0.1)



class PerturbSDE_monotonic(torch.nn.Module):
    """Itô SDE whose latent drift passes each module signal through a sum of sigmoids.

    For a state ``y`` (rows are batch elements, columns are the ``data_dim``
    state dimensions)::

        m      = y @ B_0 - softplus(K)                                  # (batch, latent_dim)
        z      = sum_h A[..., h] * sigmoid(B[..., h] * m + c[..., h]) + d
        drift  = (z @ B_1) * intervention - y * (relu(W) + 0.01) + f_

    NOTE(review): ``A`` and ``B`` are not sign-constrained here, so each latent
    unit's response is not guaranteed to be monotonic in ``m`` despite the
    class name. Enforcing that would need e.g. softplus/exp on ``A`` and ``B``.

    Diffusion ``g`` is the constant 0.1 in every entry of a
    ``(batch, data_dim, brownian_size)`` tensor (torchsde ``'general'`` noise).

    Tensors are created on the module-level ``device`` (imported from
    ``constants`` in this file).

    Args:
        batch_size: Nominal batch size. Stored for callers; ``g`` uses the
            actual batch size of ``y``.
        data_dim: Number of state dimensions.
        brownian_size: Number of independent Brownian motions.
        latent_dim: Number of latent modules.
        hidden_units: Number of sigmoid units summed per latent module.
    """

    noise_type = 'general'
    sde_type = 'ito'

    def __init__(self, batch_size, data_dim, brownian_size=2, latent_dim=200, hidden_units=400):
        super().__init__()

        # Created directly on `device`: calling .to(device) on an nn.Parameter
        # returns a plain (non-Parameter) tensor whenever the device changes,
        # which would leave these unregistered and therefore never trained.
        self.A = nn.Parameter(torch.zeros(1, latent_dim, hidden_units, device=device))
        self.B = nn.Parameter(torch.zeros(1, latent_dim, hidden_units, device=device))
        self.c = nn.Parameter(torch.zeros(1, 1, hidden_units, device=device))
        self.d = nn.Parameter(torch.zeros(1, latent_dim, device=device))
        self.act = nn.Sigmoid()

        # Drawn from numpy's global RNG (N(0, 0.001)), so np.random.seed controls them.
        self.B_1 = nn.Parameter(torch.tensor(
            np.random.normal(loc=0, scale=0.001, size=(latent_dim, data_dim)),
            dtype=torch.float32, device=device))
        self.B_0 = nn.Parameter(torch.tensor(
            np.random.normal(loc=0, scale=0.001, size=(data_dim, latent_dim)),
            dtype=torch.float32, device=device))
        self.W = nn.Parameter(torch.tensor(
            np.random.normal(loc=0, scale=0.001, size=(1, data_dim)),
            dtype=torch.float32, device=device))

        # Constant drift offset and per-dimension multiplicative mask on the latent
        # contribution (all-ones by default; presumably overwritten by callers to
        # model perturbations -- confirm). Non-persistent buffers follow
        # .to()/.cuda() without changing the state_dict layout.
        self.register_buffer(
            'f_', torch.zeros(data_dim, dtype=torch.float32, device=device), persistent=False)
        self.register_buffer(
            'intervention', torch.ones(data_dim, dtype=torch.float32, device=device),
            persistent=False)

        self.softplus = nn.Softplus()
        self.sigmoid = nn.Sigmoid()
        self.K = nn.Parameter(torch.randn(1, latent_dim, device=device, dtype=torch.float32) * 0.001)

        self.batch_size = batch_size
        self.data_dim = data_dim
        self.brownian_size = brownian_size

    def f(self, t, y):
        """Drift term.

        Args:
            t: Time (unused; the drift does not depend on it).
            y: State tensor of shape (batch, data_dim).

        Returns:
            Tensor of shape (batch, data_dim).
        """
        # Strictly positive per-dimension decay rate.
        decay = torch.relu(self.W) + 0.01                      # (1, data_dim)
        module_signal = y @ self.B_0 - self.softplus(self.K)    # (batch, latent_dim)

        # Broadcast each module signal against its hidden sigmoid units:
        # (batch, latent, 1) * (1, latent, H) + (1, 1, H) -> (batch, latent, H).
        z = self.act(self.B * module_signal.unsqueeze(-1) + self.c)
        z = torch.sum(z * self.A, dim=-1) + self.d               # (batch, latent_dim)

        # Elementwise scaling equals right-multiplying by diagflat(...) but avoids
        # materialising (data_dim, data_dim) matrices on every drift evaluation.
        return (z @ self.B_1) * self.intervention - y * decay + self.f_

    def g(self, t, y):
        """Diffusion term: constant 0.1 of shape (batch, data_dim, brownian_size)."""
        # Batch size is taken from y (not self.batch_size) so partial batches work;
        # new_full also allocates directly on y's device with y's dtype.
        return y.new_full((y.shape[0], self.data_dim, self.brownian_size), 0.1)