import torch
import math
from torch import nn


_act_init_dict = {"relu": nn.ReLU, "leaky_relu": nn.LeakyReLU, "silu": nn.SiLU}


def timestep_embedding(timesteps, dim, max_period=10000):
    # Heavily borrowed from https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/nn.py
    """Create sinusoidal timestep embeddings.

    Args:
        timesteps (torch.Tensor): a 1-D tensor of N timesteps (integer or
            floating point; values are converted to float32 before use).
        dim (int): the dimension of the output.
        max_period (int, optional): controls the minimum frequency of the
            embeddings. Defaults to 10000.

    Returns:
        torch.Tensor: an N x dim floating-point tensor of positional
        embeddings. For a timestep t, its embedding is a vector p where
        p_i = sin(w_k * t) if i = 2k and p_i = cos(w_k * t) if i = 2k + 1,
        with w_k = max_period ** (-k / (dim // 2)) (i.e.
        1 / max_period ** (2k / dim) for even dim). When dim is odd, the
        final column is left as zero. The output dtype is the dtype of
        ``timesteps`` if that is floating point, otherwise float32.
    """
    half = dim // 2
    # Build the frequency table directly on the input's device rather than
    # allocating on CPU and copying it over.
    exponents = torch.arange(half, dtype=torch.float32, device=timesteps.device) / half
    freqs = torch.exp(-math.log(max_period) * exponents)
    phases = timesteps[:, None].float() * freqs[None]
    # Integer timesteps must not dictate the output dtype: writing sin/cos
    # into an integer buffer would truncate every value toward zero.
    out_dtype = timesteps.dtype if timesteps.is_floating_point() else phases.dtype
    p = timesteps.new_zeros(len(timesteps), dim, dtype=out_dtype)
    # The explicit -1 stop keeps exactly `half` sin columns for odd dim,
    # leaving the last column zero.
    p[:, 0:-1:2] = torch.sin(phases)
    p[:, 1::2] = torch.cos(phases)
    return p


class ResidualBlock(nn.Module):
    """Residual MLP block conditioned on an auxiliary embedding.

    Computes ``x + fc2(act(fc1(x) + emb(temb)))``.

    Args:
        dim (int): feature size of ``x`` and of the block output.
        t_dim (int): feature size of the conditioning embedding ``temb``.
        h_dim (int): hidden size; must equal ``dim`` (asserted).
        act (str, optional): key into ``_act_init_dict`` ("relu",
            "leaky_relu" or "silu"). Defaults to "relu".
        group_norm (bool, optional): if True, ``nn.GroupNorm(32, h_dim)`` is
            applied to the activated hidden features before the output
            projection, so ``h_dim`` must be divisible by 32. Defaults to False.
    """

    def __init__(self, dim, t_dim, h_dim, act="relu", group_norm=False):
        assert dim == h_dim
        super().__init__()
        # Registration order (fc1, emb, fc2, act) fixes the state_dict key
        # order and the sequence of random draws during initialisation.
        self.fc1 = nn.Linear(dim, h_dim)
        self.emb = nn.Linear(t_dim, h_dim)
        self.fc2 = (
            nn.Sequential(nn.GroupNorm(32, h_dim), nn.Linear(h_dim, dim))
            if group_norm
            else nn.Linear(h_dim, dim)
        )
        activation_cls = _act_init_dict[act]
        self.act = activation_cls()

    def forward(self, x, temb):
        """Return ``x`` plus the embedding-conditioned residual update.

        NOTE(review): with group_norm=True, GroupNorm treats dim 1 as the
        channel axis, so ``x`` presumably is (batch, dim); confirm for callers
        passing higher-rank inputs.
        """
        pre_activation = self.fc1(x) + self.emb(temb)
        return x + self.fc2(self.act(pre_activation))