import math
import torch as torch
import torch.nn as nn


class GroupNorm32(nn.GroupNorm):
    """Group normalization that computes in float32 and restores the input dtype."""

    def forward(self, x):
        """Normalize ``x`` in float32 and return it in its original dtype.

        The input is cast with ``x.float()``, passed through the parent
        ``nn.GroupNorm.forward``, and the result is converted back with
        ``.type(x.dtype)``, so callers get a tensor of the dtype they passed in.
        """
        input_dtype = x.dtype
        # Run the parent normalization on a float32 copy of the input.
        normalized = super().forward(x.float())
        return normalized.type(input_dtype)


def pruned_group_norm(teams, channels32):
    """Build a ``GroupNorm32`` layer.

    Args:
        teams: Forwarded as the number of groups.
        channels32: Forwarded as the number of channels.

    Returns:
        A new ``GroupNorm32`` instance constructed from the two arguments.
    """
    norm_layer = GroupNorm32(num_groups=teams, num_channels=channels32)
    return norm_layer


def get_new_out_channel(reversed_channel_1, reversed_channel_2):
    """Merge two collections and report their distinct entries.

    Args:
        reversed_channel_1: First iterable of hashable entries.
        reversed_channel_2: Second iterable of hashable entries.

    Returns:
        A ``(count, entries)`` tuple where ``entries`` is a list holding each
        distinct value from either input once and ``count`` is its length.
        The list order is whatever set iteration yields, so it is not
        guaranteed to be sorted or to follow the input order.
    """
    combined = [*reversed_channel_1, *reversed_channel_2]
    # Deduplicate through a set built from the concatenated entries.
    unique_channels = list(set(combined))
    return len(unique_channels), unique_channels