import math
import torch
import torch.nn as nn


def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings. The first
             ``dim // 2`` columns are cosines, the next ``dim // 2`` columns
             are sines, and an odd ``dim`` gets one trailing zero column.
    """
    half = dim // 2
    # Frequencies decay geometrically from 1 towards 1 / max_period.
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=timesteps.device)
    # Outer product: one row per timestep, one column per frequency.
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Pad with an explicitly shaped (N, 1) zero column. Slicing
        # ``embedding[:, :1]`` would give zero columns when dim == 1
        # (half == 0), so the output would be [N x 0] instead of [N x dim].
        pad = embedding.new_zeros(embedding.shape[0], 1)
        embedding = torch.cat([embedding, pad], dim=-1)
    return embedding


class Translator(nn.Module):
    """
    Convolutional encoder/decoder that adds a learned residual to its input.

    The encoder stacks ``net_depth`` stages. Stage ``d`` is a 3x3
    reflect-padded convolution to ``net_width * 2 ** d`` channels, an optional
    normalisation layer, an activation, and (except after the last stage) a
    2x pooling layer. The decoder receives the encoder features concatenated
    channel-wise with an embedding of ``n // n_classes`` (hence twice the
    encoder's output channels), upsamples by 2 between stages, and ends with a
    plain convolution back to ``channel`` channels.

    :param channel: number of channels of the input (and of the output).
    :param net_width: channel count of the first encoder stage.
    :param net_depth: number of encoder stages (and decoder stages).
    :param net_act: 'sigmoid', 'relu' or 'leakyrelu'.
    :param net_norm: 'batch', 'instance', 'group' or 'identity' (no norm).
    :param net_pooling: 'maxpooling' or 'avgpooling'.
    :raises ValueError: if an activation, normalisation or pooling name that
                        is actually needed is not one of the options above.
    """

    def __init__(self, channel, net_width, net_depth, net_act, net_norm, net_pooling):
        super().__init__()
        in_channels = channel
        layers = []
        for d in range(net_depth):
            stage_channels = net_width * (2 ** d)
            layers += [nn.Conv2d(in_channels, stage_channels, kernel_size=(3, 3),
                                 padding=1, padding_mode='reflect')]
            in_channels = stage_channels
            if net_norm != 'identity':
                layers += [self._get_normlayer(net_norm, in_channels)]
            layers += [self._get_activation(net_act)]
            # The last stage keeps its resolution; only earlier stages pool.
            if d != net_depth - 1:
                layers += [self._get_pooling(net_pooling)]
        self.encoder = nn.Sequential(*layers)

        # MLP applied to the sinusoidal embedding; its width equals the
        # encoder's output channel count so both can be concatenated.
        self.embed = nn.Sequential(
            nn.Linear(in_channels, in_channels),
            nn.SiLU(),
            nn.Linear(in_channels, in_channels),
        )
        self.dim = in_channels

        # The decoder input is [features, embedding] along the channel axis.
        in_channels *= 2
        layers = []
        for d in range(net_depth - 1, -1, -1):
            # One upsample per encoder pooling layer.
            if d != net_depth - 1:
                layers += [nn.Upsample(scale_factor=2)]
            out_channels = net_width * (2 ** (d - 1)) if d != 0 else channel
            layers += [nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3),
                                 padding=1, padding_mode='reflect')]
            # The final convolution (d == 0) is left without norm/activation.
            if net_norm != 'identity' and d != 0:
                layers += [self._get_normlayer(net_norm, out_channels)]
            if d != 0:
                layers += [self._get_activation(net_act)]
            in_channels = out_channels
        self.decoder = nn.Sequential(*layers)

    def forward(self, x, n_classes=10):
        """
        Return ``x`` plus the decoder's output for ``x``.

        :param x: either a 5-D Tensor whose two leading dims ``(b, n)`` are
                  flattened before encoding, or a Tensor whose leading dim is
                  ``n`` and which is encoded as is.
        :param n_classes: divisor used to derive the conditioning value
                          ``n // n_classes`` (presumably the number of items
                          per class -- confirm against callers).
        :return: a Tensor with the same shape as ``x``.

        NOTE(review): the residual add needs the decoder output to match the
        spatial size of ``x``; with 2x pooling/upsampling this presumably
        requires H and W to be divisible by ``2 ** (net_depth - 1)``.
        """
        batched = x.dim() == 5
        if batched:
            b, n = x.shape[:2]
            x_ = x.flatten(0, 1)
        else:
            b, n = 1, x.shape[0]
            x_ = x
        feat = self.encoder(x_)

        # Every sample is conditioned on the same scalar, n // n_classes.
        steps = torch.full((b * n,), n // n_classes, dtype=torch.long, device=x.device)
        emb = self.embed(timestep_embedding(steps, self.dim))
        # Broadcast over the feature map without copying; torch.cat below
        # materialises the result.
        emb = emb[:, :, None, None].expand(-1, -1, feat.shape[2], feat.shape[3])

        output = self.decoder(torch.cat([feat, emb], dim=1))
        if batched:
            return x + output.unflatten(0, (b, n))
        return x + output

    def _get_activation(self, net_act):
        """Return the activation module named by ``net_act``.

        :raises ValueError: if ``net_act`` is not a known activation name.
        """
        if net_act == 'sigmoid':
            return nn.Sigmoid()
        if net_act == 'relu':
            return nn.ReLU(inplace=True)
        if net_act == 'leakyrelu':
            return nn.LeakyReLU(negative_slope=0.01)
        # Raise instead of exit(): library code must not kill the interpreter.
        raise ValueError('unknown activation function: %s' % net_act)

    def _get_pooling(self, net_pooling):
        """Return the 2x2, stride-2 pooling module named by ``net_pooling``.

        :raises ValueError: if ``net_pooling`` is not a known pooling name.
        """
        if net_pooling == 'maxpooling':
            return nn.MaxPool2d(kernel_size=2, stride=2)
        if net_pooling == 'avgpooling':
            return nn.AvgPool2d(kernel_size=2, stride=2)
        raise ValueError('unknown net_pooling: %s' % net_pooling)

    def _get_normlayer(self, net_norm, channels):
        """Return the normalisation module named by ``net_norm``.

        'instance' is realised as a GroupNorm with one group per channel and
        'group' as a GroupNorm with 32 groups. 'identity' yields ``None``.

        :raises ValueError: if ``net_norm`` is not a known normalisation name.
        """
        if net_norm == 'batch':
            return nn.BatchNorm2d(channels, affine=True)
        if net_norm == 'instance':
            return nn.GroupNorm(channels, channels, affine=True)
        if net_norm == 'group':
            return nn.GroupNorm(32, channels, affine=True)
        if net_norm == 'identity':
            return None
        raise ValueError('unknown net_norm: %s' % net_norm)

