import torch
import math
import numpy as np
import torch.nn as nn
import torch.nn.functional as F


def get_timestep_embedding(timesteps, embedding_dim=100, max_positions=10000.0):
    """Transformer-style sinusoidal embedding of (continuous) timesteps.

    Timesteps are scaled by 1000 and multiplied by a geometric ladder of
    frequencies running from 1 down to ``1 / max_positions``.

    Args:
        timesteps: tensor that squeezes to shape (N,), e.g. (N,) or (N, 1).
            A batch of one, shape (1, 1), is accepted.
        embedding_dim: output width; when odd, the last column is zero padding.
        max_positions: ratio between the highest and lowest frequency
            (10000 is the value used by transformers).

    Returns:
        float32 tensor of shape (N, embedding_dim): ``[sin(args), cos(args)]``.
    """
    timesteps = timesteps.squeeze()
    if timesteps.dim() == 0:
        # squeeze() collapses a batch of one to a 0-d tensor; restore the batch axis.
        timesteps = timesteps.unsqueeze(0)
    timesteps = timesteps * 1000
    assert timesteps.dim() == 1, (
        f"expected timesteps that squeeze to (N,), got shape {tuple(timesteps.shape)}")
    half_dim = embedding_dim // 2
    # Guard the denominator so embedding_dim in {2, 3} does not divide by zero.
    log_step = math.log(max_positions) / max(half_dim - 1, 1)
    freqs = torch.exp(
        torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -log_step)
    args = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = F.pad(emb, (0, 1), mode='constant')
    return emb

# def my_timestep_embedding(timesteps, embedding_dim=100, max_freq=2*np.pi*50+50-30, min_freq=50.0):
def my_timestep_embedding(timesteps, embedding_dim=100, max_freq=64, min_freq=32):
    """Sinusoidal timestep embedding with linearly spaced frequencies.

    Args:
        timesteps: tensor that squeezes to shape (N,), e.g. (N,) or (N, 1).
            A batch of one, shape (1, 1), is accepted.
        embedding_dim: output width; when odd, the last column is zero padding.
        max_freq, min_freq: endpoints of ``embedding_dim // 2`` linearly spaced
            frequencies ``f``; the arguments are ``t * f`` (no factor of 2*pi).

    Returns:
        float tensor of shape (N, embedding_dim): ``[sin(t * f), cos(t * f)]``.
    """
    timesteps = timesteps.squeeze()
    if timesteps.dim() == 0:
        # squeeze() collapses a batch of one to a 0-d tensor; restore the batch axis.
        timesteps = timesteps.unsqueeze(0)
    assert timesteps.dim() == 1, (
        f"expected timesteps that squeeze to (N,), got shape {tuple(timesteps.shape)}")
    half_dim = embedding_dim // 2
    freqs = torch.linspace(min_freq, max_freq, half_dim).to(timesteps.device)
    args = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = F.pad(emb, (0, 1), mode='constant')
    return emb

def my_timestep_embedding2(timesteps, embedding_dim=300):
    """Two-band sinusoidal embedding that switches band on the timestep value.

    Timesteps with ``t < 0.7`` are embedded with frequencies linearly spaced in
    [32, 64]; all others use frequencies in [100, 200]. Each band is
    ``[sin(t * f), cos(t * f)]``, zero-padded by one column when
    ``embedding_dim`` is odd.

    Args:
        timesteps: tensor that squeezes to shape (N,), e.g. (N,) or (N, 1).
            A batch of one, shape (1, 1), is accepted.
        embedding_dim: output width.

    Returns:
        float tensor of shape (N, embedding_dim).
    """
    timesteps = timesteps.squeeze()
    if timesteps.dim() == 0:
        # squeeze() collapses a batch of one to a 0-d tensor; restore the batch axis.
        timesteps = timesteps.unsqueeze(0)
    half_dim = embedding_dim // 2

    def _band(min_freq, max_freq):
        # sin/cos features of t at half_dim frequencies spaced linearly in [min_freq, max_freq].
        freqs = torch.linspace(min_freq, max_freq, half_dim).to(timesteps.device)
        args = timesteps.float()[:, None] * freqs[None, :]
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
        if embedding_dim % 2 == 1:  # zero pad
            emb = F.pad(emb, (0, 1), mode='constant')
        return emb

    # 1.0 selects the low band, 0.0 the high band (per row).
    low = (timesteps < 0.7).float()[:, None]
    return low * _band(32, 64) + (1 - low) * _band(100, 200.0)

def my_timestep_embedding3(timesteps, embedding_dim=300):
    """Sinusoidal embedding plus a +/-1 flag marking ``t < 0.5``.

    The layout is ``[sin(t * f), cos(t * f), flag, 0...]``, where ``f`` holds
    ``embedding_dim // 2 - 1`` frequencies linearly spaced in [32, 64] and
    ``flag`` is +1 for ``t < 0.5`` and -1 otherwise. Zero columns pad the
    result to exactly ``embedding_dim``.

    Args:
        timesteps: tensor that squeezes to shape (N,), e.g. (N,) or (N, 1).
            A batch of one, shape (1, 1), is accepted.
        embedding_dim: output width (>= 2).

    Returns:
        float tensor of shape (N, embedding_dim).
    """
    timesteps = timesteps.squeeze()
    if timesteps.dim() == 0:
        # squeeze() collapses a batch of one to a 0-d tensor; restore the batch axis.
        timesteps = timesteps.unsqueeze(0)
    half_dim = embedding_dim // 2
    flag = (timesteps < 0.5).float() * 2 - 1
    min_freq = 32
    max_freq = 64
    freqs = torch.linspace(min_freq, max_freq, half_dim - 1).to(timesteps.device)
    args = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args), flag[:, None]], dim=-1)
    # sin/cos/flag fill 2 * half_dim - 1 columns: pad by 1 for even and by 2 for odd
    # embedding_dim, so the output width always equals embedding_dim.
    return F.pad(emb, (0, embedding_dim - emb.shape[-1]), mode='constant')

def my_timestep_embedding4(timesteps, embedding_dim=300, min_freq = 64, max_freq = 84 ):
    """Cosine-only timestep embedding with linearly spaced frequencies.

    Args:
        timesteps: tensor that squeezes to shape (N,), e.g. (N,) or (N, 1).
            A batch of one, shape (1, 1), is accepted.
        embedding_dim: number of frequencies, which is also the output width.
        min_freq, max_freq: endpoints of the linearly spaced frequencies ``f``.

    Returns:
        float tensor of shape (N, embedding_dim): ``cos(t * f)``.
    """
    timesteps = timesteps.squeeze()
    if timesteps.dim() == 0:
        # squeeze() collapses a batch of one to a 0-d tensor; restore the batch axis.
        timesteps = timesteps.unsqueeze(0)
    freqs = torch.linspace(min_freq, max_freq, embedding_dim).to(timesteps.device)
    return torch.cos(timesteps.float()[:, None] * freqs[None, :])

class GaussianFourierProjection(nn.Module):
    """Gaussian Fourier embeddings for noise levels.

    Projects a scalar input onto fixed random frequencies
    ``W ~ N(32, scale**2)`` and returns ``[sin(2*pi*x*W), cos(2*pi*x*W)]``.

    Args:
        embedding_size: number of random frequencies; the output has twice as many columns.
        scale: standard deviation of the sampled frequencies.
    """

    def __init__(self, embedding_size=50, scale=16.0):
        super().__init__()
        # Fixed, non-trainable frequencies. A registered buffer follows module.to(device)
        # instead of being pinned to 'cuda' (which failed on CPU-only machines).
        # persistent=False keeps it out of state_dict, so existing checkpoints still load.
        self.register_buffer('W', torch.randn(embedding_size) * scale + 32, persistent=False)

    def forward(self, x):
        """Embed ``x``; presumably shape (N, 1) -> (N, 2 * embedding_size) — TODO confirm."""
        x_proj = x * self.W[None, :] * 2 * np.pi
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)

class MLP1(nn.Module):
    """Two-hidden-layer tanh MLP with an additive sinusoidal time embedding.

    The embedding ``tfc(my_timestep_embedding(t))`` is added to the
    pre-activation of both hidden layers. ``x_input`` is not concatenated with
    ``t``; time enters only through the embedding.

    Args:
        input_dim: width of ``x_input`` and of the output.
        hidden_num: width of the hidden layers and of the time embedding.
        structure_str: unused; accepted for constructor compatibility.
    """

    def __init__(self, input_dim=2, hidden_num=100, structure_str='1000-1000-1000-1000'):
        # nn.Module.__init__ must run before any attribute is assigned on the module.
        super().__init__()
        self.hidden_num = hidden_num
        # Snapshot the global CPU RNG; it is restored at the end of __init__, so building
        # this model does not advance the caller's random stream.
        current_random_state = torch.random.get_rng_state()

        self.fc1 = nn.Linear(input_dim, hidden_num, bias=True)
        self.fc2 = nn.Linear(hidden_num, hidden_num, bias=True)
        self.fc3 = nn.Linear(hidden_num, input_dim, bias=True)
        # Not used by forward(); kept so code that reads ``self.gfe`` keeps working.
        self.gfe = GaussianFourierProjection(embedding_size=int(hidden_num / 2))
        self.tfc = nn.Linear(hidden_num, hidden_num, bias=True)
        # A plain function instead of a lambda keeps the module picklable (torch.save).
        self.act = torch.tanh
        torch.random.set_rng_state(current_random_state)

    def forward(self, x_input, t):
        """Evaluate the network.

        Args:
            x_input: tensor whose last dimension is ``input_dim`` (fc1 is
                Linear(input_dim, hidden_num)); presumably (N, input_dim) — TODO confirm.
            t: timesteps accepted by ``my_timestep_embedding`` (e.g. shape (N, 1)).

        Returns:
            tensor with last dimension ``input_dim``.
        """
        # Constant +2 shift of the input (rationale not recorded in the code).
        x = self.fc1(x_input + 2)
        temb = self.tfc(my_timestep_embedding(t, embedding_dim=self.hidden_num))
        x = self.act(x + temb)
        x = self.act(self.fc2(x) + temb)
        return self.fc3(x)


# Time-conditioned MLP with an extra time-independent residual branch on x.
class MLP2(nn.Module):
    """Time-conditioned tanh MLP plus a time-independent residual branch.

    Main branch:     fc1 -> tanh -> fc2 -> tanh -> fc3 applied to ``cat([x, t])``.
    Residual branch: fc4 -> tanh -> fc5 -> tanh -> fc6 applied to ``x`` alone.
    The two branch outputs are summed.

    Args:
        input_dim: width of ``x`` and of the output (``t`` adds one input column).
        hidden_num: hidden width of both branches.
    """

    def __init__(self, input_dim=2, hidden_num=100):
        super().__init__()
        # Main branch on [x, t].
        self.fc1 = nn.Linear(input_dim + 1, hidden_num, bias=True)
        self.fc2 = nn.Linear(hidden_num, hidden_num, bias=True)
        self.fc3 = nn.Linear(hidden_num, input_dim, bias=True)

        # Residual branch on x only.
        self.fc4 = nn.Linear(input_dim, hidden_num, bias=True)
        self.fc5 = nn.Linear(hidden_num, hidden_num, bias=True)
        self.fc6 = nn.Linear(hidden_num, input_dim, bias=True)
        # A plain function instead of a lambda keeps the module picklable.
        self.act = torch.tanh

    def forward(self, x_input, t):
        """Return main(cat([x_input, t])) + residual(x_input); ``t`` must be 2-D, e.g. (N, 1)."""
        inputs = torch.cat([x_input, t], dim=1)
        x = self.fc3(self.act(self.fc2(self.act(self.fc1(inputs)))))
        # The residual branch uses its own fc5/fc6, not the main branch's fc2/fc3.
        residual = self.fc6(self.act(self.fc5(self.act(self.fc4(x_input)))))
        return x + residual

class Repara(nn.Module):
    """Blend an ``MLP`` prediction with the input, weighted by a time-dependent std.

    ``forward`` returns ``model_1(x, t) * (1 - std) - x * std / std``, where
    ``std = sqrt(1 - exp(2 * log_mean_coeff))`` uses the linear-beta
    (0.1 -> 20) schedule evaluated at ``1 - t``.

    Args:
        input_dim, structure_str, residual, relu: forwarded to ``MLP`` (the
            module-level name, resolved when this class is instantiated).
        hidden_num: unused; kept for constructor compatibility.
    """

    def __init__(self, input_dim=2, structure_str='1000',
                 residual = False,
                 relu=True, hidden_num=100):
        super().__init__()

        print('Repara')
        self.name = f'Repara_{structure_str}'
        self.residual = residual
        self.relu = relu
        if self.residual:
            self.name += '_residual'
        if self.relu:
            self.name += '_relu'

        self.model_1 = MLP(input_dim, structure_str, self.residual, self.relu)

    def forward(self, x_input, t):
        """Evaluate the blended prediction; ``t`` must be 2-D (it is fed to ``MLP`` with ``x``)."""
        log_mean_coeff = -0.25 * (1 - t) ** 2 * (20 - 0.1) - 0.5 * (1 - t) * 0.1
        std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
        mask = std
        x_normal = self.model_1(x_input, t)
        # eps-prediction variant (unused): x_normal * (1 - mask) + x_input * mask
        # NOTE(review): at t == 1, log_mean_coeff == 0 so std == 0 and mask / std is
        # 0/0 = NaN; confirm callers never pass t == 1 exactly.
        return x_normal * (1 - mask) - x_input * mask / std



class SkipConnection(nn.Module):
    """Combine an ``MLP`` prediction and the input with learned time coefficients.

    ``forward`` returns ``model_1(x, t) * c1(t) - x * c2(t) / std(t)``, where
    ``c1`` and ``c2`` are small tanh MLPs (1 -> 30 -> 30 -> 1) of ``t`` and
    ``std = sqrt(1 - exp(2 * log_mean_coeff))`` uses the linear-beta
    (0.1 -> 20) schedule evaluated at ``1 - t``.

    Args:
        input_dim, structure_str, residual, relu: forwarded to ``MLP`` (the
            module-level name, resolved when this class is instantiated).
        hidden_num: unused; kept for constructor compatibility.
    """

    def __init__(self, input_dim=2, structure_str='1000',
                 residual = False,
                 relu=True, hidden_num=100):
        super().__init__()

        # Coefficient network c1(t).
        self.fc1_zero = nn.Linear(1, 30, bias=True)
        self.fc2_zero = nn.Linear(30, 30, bias=True)
        self.fc3_zero = nn.Linear(30, 1, bias=True)

        # Coefficient network c2(t).
        self.fc1_one = nn.Linear(1, 30, bias=True)
        self.fc2_one = nn.Linear(30, 30, bias=True)
        self.fc3_one = nn.Linear(30, 1, bias=True)

        # A plain function instead of a lambda keeps the module picklable.
        self.act = torch.tanh

        print('Skip Connection')
        self.name = f'SC_{structure_str}'
        self.residual = residual
        self.relu = relu
        if self.residual:
            self.name += '_residual'
        if self.relu:
            self.name += '_relu'

        self.model_1 = MLP(input_dim, structure_str, self.residual, self.relu)

    def forward(self, x_input, t):
        """Evaluate the combination; ``t`` must have a trailing dimension of 1 (fc1_* is Linear(1, 30))."""
        c1 = self.fc3_zero(self.act(self.fc2_zero(self.act(self.fc1_zero(t)))))
        c2 = self.fc3_one(self.act(self.fc2_one(self.act(self.fc1_one(t)))))

        log_mean_coeff = -0.25 * (1 - t) ** 2 * (20 - 0.1) - 0.5 * (1 - t) * 0.1
        std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
        x_normal = self.model_1(x_input, t)

        # NOTE(review): std is exactly 0 at t == 1, so this division yields inf/NaN there.
        return x_normal * c1 - x_input * c2 / std

class MLP(nn.Module):
    """Two parallel 4-layer MLPs on ``cat([x, t])``, chosen per sample by a hard time split.

    Samples with ``t < 0.6`` take the second network's output (``*_2`` layers);
    all others take the first (``*_1`` layers).

    NOTE(review): a second ``class MLP`` defined later in this module rebinds the
    name ``MLP``, so after import this class is not reachable under that name.

    Args:
        input_dim: width of ``x`` and of the output (``t`` adds one input column).
        hidden_num: hidden width of both networks.
    """

    def __init__(self, input_dim=2, hidden_num=10):
        super().__init__()
        self.relu = True
        self.name = 'MLP_two_model'
        if self.relu:
            self.name += '_relu'

        # Network 1 (used where t >= 0.6).
        self.fc1_1 = nn.Linear(input_dim + 1, hidden_num, bias=True)
        self.fc2_1 = nn.Linear(hidden_num, hidden_num, bias=True)
        self.fc4_1 = nn.Linear(hidden_num, hidden_num, bias=True)
        self.fc3_1 = nn.Linear(hidden_num, input_dim, bias=True)

        # Network 2 (used where t < 0.6).
        self.fc1_2 = nn.Linear(input_dim + 1, hidden_num, bias=True)
        self.fc2_2 = nn.Linear(hidden_num, hidden_num, bias=True)
        self.fc4_2 = nn.Linear(hidden_num, hidden_num, bias=True)
        self.fc3_2 = nn.Linear(hidden_num, input_dim, bias=True)

        # Plain functions instead of lambdas keep the module picklable (torch.save).
        self.act = torch.relu if self.relu else torch.tanh

    def forward(self, x_input, t):
        """Evaluate both networks and select per sample; ``t`` must be 2-D, e.g. (N, 1)."""
        inputs = torch.cat([x_input, t], dim=1)
        mask = (t < 0.6).float()

        x_1 = self.fc1_1(inputs)
        x_1 = self.act(x_1)
        x_1 = self.fc2_1(x_1)
        x_1 = self.act(x_1)
        x_1 = self.fc4_1(x_1)
        x_1 = self.act(x_1)
        x_1 = self.fc3_1(x_1)

        x_2 = self.fc1_2(inputs)
        x_2 = self.act(x_2)
        x_2 = self.fc2_2(x_2)
        x_2 = self.act(x_2)
        x_2 = self.fc4_2(x_2)
        x_2 = self.act(x_2)
        x_2 = self.fc3_2(x_2)

        return x_1 * (1 - mask) + x_2 * mask

class TwoModel(nn.Module):
    """Route each sample to one of two independent ``MLP`` networks by a hard time split.

    Samples with ``t < 0.6`` use ``model_2``; the rest use ``model_1``. Both
    networks are evaluated on the whole batch, and the mask selects rows.

    Args:
        input_dim, structure_str, residual, relu: forwarded to ``MLP`` (the
            module-level name, resolved when this class is instantiated).
    """

    def __init__(self, input_dim=2,
                 structure_str='1000',
                 residual = False,
                 relu=True,
                 ):
        super().__init__()
        print('Two Model')
        self.name = f'Two_Model_{structure_str}'
        self.residual = residual
        self.relu = relu
        if self.residual:
            self.name += '_residual'
        if self.relu:
            self.name += '_relu'

        self.model_1 = MLP(input_dim, structure_str, self.residual, self.relu)
        self.model_2 = MLP(input_dim, structure_str, self.residual, self.relu)

    def forward(self, x_input, t):
        """Evaluate both networks and select per sample; ``t`` must be 2-D, e.g. (N, 1)."""
        mask = (t < 0.6).float()
        x_1 = self.model_1(x_input, t)
        x_2 = self.model_2(x_input, t)
        return x_1 * (1 - mask) + x_2 * mask

class MLP(nn.Module):
    """Configurable time-conditioned MLP on ``cat([x, t])`` with optional residual blocks.

    Args:
        input_dim: width of ``x`` and of the output; the network input is
            ``input_dim + 1`` wide because ``t`` is appended.
        structure_str: hidden widths joined by ``'-'``, e.g. ``'100-100-100'``.
        residual: add identity skips around hidden blocks (see ``forward``).
        relu: use ReLU hidden activations instead of Tanh.
        dropout: append ``Dropout(p=0.2)`` after every hidden block.
    """

    def __init__(self, input_dim=2,
                 structure_str='100',
                 residual = False,
                 relu=False,
                 dropout=False,
                 ):
        super().__init__()
        print('Residual MLP')
        self.residual = residual
        self.relu = relu
        # Optional constant input shift applied in forward(); currently always disabled.
        self.plus_input = False
        self.dropout = dropout
        self.name = f'MLP_{structure_str}'
        if self.residual:
            self.name += '_residual'
        if self.relu:
            self.name += '_relu'
            if self.plus_input:
                self.name += '_plusinput'
        if self.dropout:
            self.name += '_dropout'

        current_dim = input_dim + 1
        hidden_dims = [int(i) for i in structure_str.split('-')]
        # Plain list iterated by forward(); the same module objects are registered
        # through self.model below, which is what parameters()/state_dict() see.
        self.layers = []
        for hidden_dim in hidden_dims:
            activation = nn.ReLU() if self.relu else nn.Tanh()
            self.layers.append(nn.Sequential(nn.Linear(current_dim, hidden_dim, bias=True),
                                             activation))
            if self.dropout:
                self.layers.append(nn.Dropout(p=0.2))
            current_dim = hidden_dim
        self.layers.append(nn.Linear(current_dim, input_dim))
        self.model = nn.Sequential(*self.layers)

    def _initialize_weights(self):
        """Re-initialise every Linear with Kaiming-normal (ReLU gain) weights and zero biases.

        Not called by ``__init__``.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x_input, t):
        """Evaluate the network on ``cat([x_input, t])``; ``t`` must be 2-D, e.g. (N, 1)."""
        if self.relu and self.plus_input:
            x_input = x_input + 6
        x = torch.cat([x_input, t], dim=1)
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            skip = x
            x = layer(x)
            # Residual skips wrap hidden blocks only. They never wrap the input projection
            # (index 0) or the output head (last index). They never wrap a Dropout, since
            # x + dropout(x) would double the activations. They apply only where the block
            # keeps the width, so a narrowing block is not broken by a shape mismatch.
            if (self.residual and 0 < i < last
                    and not isinstance(layer, nn.Dropout)
                    and x.shape == skip.shape):
                x = x + skip
        return x

class RandomFeature(nn.Module):
    """Random-feature model: frozen random ReLU features of ``[x, t]`` with a trainable linear readout.

    ``forward(x, t) = relu(cat([x, t]) @ W) @ theta``, where ``W`` (shape
    (d + 1, p)) is sampled once and frozen, and ``theta`` (shape (p, d)) is learned.
    """

    def __init__(self, d=2, p=10):
        """
        :param d: data dimension; ``x`` has ``d`` columns and ``t`` adds one more.
        :param p: number of random features.
        """
        super().__init__()
        self.d, self.p = d, p
        # Frozen random projection, sampled before theta (fixes the RNG draw order).
        self.W = nn.Parameter(torch.randn(d + 1, p), requires_grad=False)
        # Trainable readout from the p features back to d outputs.
        self.theta = nn.Parameter(torch.randn(p, d))
        self.name = f'RandomFeatureModel_p{self.p}'

    def forward(self, x, t):
        """
        :param x: inputs of shape [N, d]
        :param t: times of shape [N, 1]
        :return: outputs of shape [N, d]
        """
        joined = torch.cat([x, t], dim=1)   # [N, d + 1]
        features = F.relu(joined @ self.W)  # [N, p]
        return features @ self.theta        # [N, d]



class UnetFC(nn.Module):
    """Fully-connected "U-Net" on ``cat([x, t])`` with concatenation skip connections.

    ``config_string`` lists encoder widths, e.g. ``'32-64-128'``. The encoder
    maps the ``input_dim + 1`` inputs through these widths. The decoder walks
    them back in reverse, concatenating the matching encoder activation before
    every decoder layer and before the final projection to ``input_dim``.

    Args:
        input_dim: width of ``x`` and of the output.
        config_string: encoder widths joined by ``'-'``.
        activation: ``'relu'``, ``'sigmoid'`` or ``'tanh'``.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    """

    def __init__(self, input_dim=2,
                 config_string='32-64-128-256-512',
                 activation='relu'):
        # Zero-argument super() binds to UnetFC; naming an unrelated class here raises TypeError.
        super().__init__()
        print('Using 1D U-Net.')

        layers = list(map(int, config_string.split('-')))

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        # Not used by this class; kept as a public attribute for compatibility.
        self.skip_connections = []

        # Hidden encoder layers are built first and the input layer is inserted at index 0
        # afterwards; this fixes the order in which initialisation draws from the RNG.
        for i in range(len(layers) - 1):
            self.encoder.append(
                nn.Sequential(
                    nn.Linear(layers[i], layers[i + 1]),
                    self._get_activation(activation)
                )
            )
        self.encoder.insert(0,
                            nn.Sequential(
                                nn.Linear(input_dim + 1, layers[0]),
                                self._get_activation(activation))
                            )

        # Decoder mirrors the encoder; each input is doubled by the concatenated skip.
        for i in range(len(layers) - 1, 0, -1):
            self.decoder.append(
                nn.Sequential(
                    nn.Linear(layers[i] * 2, layers[i - 1]),
                    self._get_activation(activation)
                )
            )
        # Final projection back to input_dim; *2 for the last skip connection.
        self.final_layer = nn.Linear(layers[0] * 2, input_dim)

    def _get_activation(self, activation):
        """Return a fresh activation module for ``activation``; raise ValueError if unknown."""
        if activation == 'relu':
            return nn.ReLU()
        elif activation == 'sigmoid':
            return nn.Sigmoid()
        elif activation == 'tanh':
            return nn.Tanh()
        else:
            raise ValueError(f"Unsupported activation function: {activation}")

    def forward(self, x, t):
        """Evaluate the network; ``t`` must be 2-D, e.g. (N, 1). Returns shape (N, input_dim)."""
        x = torch.cat([x, t], dim=1)

        # Encoding path: keep every activation for the skip connections.
        enc_outputs = []
        for enc in self.encoder:
            x = enc(x)
            enc_outputs.append(x)

        # Decoding path: concatenate the matching encoder activation, deepest first.
        for i, dec in enumerate(self.decoder):
            x = torch.cat((x, enc_outputs[-(i + 1)]), dim=1)
            x = dec(x)

        # Final layer with the shallowest skip connection.
        x = torch.cat((x, enc_outputs[0]), dim=1)
        return self.final_layer(x)

class UNet(torch.nn.Module):
    """Small convolutional U-Net for flattened 28x28 single-channel inputs, conditioned on time.

    The flat input ``x`` (last dimension 784) is reshaped to (..., 1, 28, 28), and
    ``t`` (last dimension 1) is broadcast to a constant second channel. The encoder
    has three conv stages at 28x28, 14x14 and 7x7. Two transposed convs bring the
    result back to 28x28, each after concatenating the matching encoder map, and a
    final 3x3 conv produces one channel, which is flattened back to (..., 784).

    Args:
        d: unused; kept for constructor compatibility.
    """

    def __init__(self, d):
        super().__init__()
        chs = [256, 256, 256]
        self.chs = chs

        # Encoder stages: (2, 28, 28) -> (chs[0], 28, 28) -> (chs[1], 14, 14) -> (chs[2], 7, 7).
        # Every stage after the first starts with a 2x2 max-pool.
        stages = []
        for level, (c_in, c_out) in enumerate(zip([2] + chs[:-1], chs)):
            pooling = [torch.nn.MaxPool2d(kernel_size=2, stride=2)] if level > 0 else []
            stages.append(torch.nn.Sequential(
                *pooling,
                torch.nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                torch.nn.ReLU(),
            ))
        self._convs = torch.nn.ModuleList(stages)

        # Decoder: (chs[2], 7, 7) -> (chs[1], 14, 14); then, after concatenating the
        # 14x14 encoder map, (2 * chs[1], 14, 14) -> (chs[0], 28, 28).
        up_stages = [
            torch.nn.Sequential(
                torch.nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2,
                                         padding=1, output_padding=1),
                torch.nn.ReLU(),
            )
            for c_in, c_out in ((chs[-1], chs[-2]), (chs[-2] * 2, chs[-3]))
        ]
        # Head: after concatenating the 28x28 encoder map, (2 * chs[0], 28, 28) -> (1, 28, 28).
        up_stages.append(torch.nn.Sequential(
            torch.nn.Conv2d(chs[0] * 2, 1, kernel_size=3, padding=1),
        ))
        self._tconvs = torch.nn.ModuleList(up_stages)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Map ``x`` (..., 784) and ``t`` (..., 1) to an output of shape (..., 784)."""
        image = x.reshape(*x.shape[:-1], 1, 28, 28)
        time_plane = t[..., None, None].expand(*t.shape[:-1], 1, 28, 28)
        h = torch.cat((image, time_plane), dim=-3)

        skips = []
        for stage in self._convs:
            h = stage(h)
            skips.append(h)
        # The deepest map feeds the decoder directly rather than as a skip.
        skips.pop()

        h = self._tconvs[0](h)
        for stage in self._tconvs[1:]:
            # Skips are consumed deepest-first: 14x14 map, then 28x28 map.
            h = stage(torch.cat((h, skips.pop()), dim=-3))
        return h.reshape(*h.shape[:-3], -1)
