import torch
import torch.nn as nn

class ProjectHead(nn.Module):
    """Single linear projection head.

    Applies ``nn.Linear(in_channels, num_classes)`` to the last dimension of
    the input.

    Args:
        in_channels: Size of the input feature dimension.
        num_classes: Number of output features.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.fc = nn.Linear(in_channels, num_classes)
        self.init_weights()

    def init_weights(self):
        """Re-initialise Conv2d/Linear weights inside ``self.fc`` (Kaiming-normal).

        Only ``weight`` tensors are overwritten; biases keep whatever value the
        layer constructor gave them.
        """
        for module in self.fc.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight)

    def forward(self, x):
        """Map ``x`` of shape (..., in_channels) to (..., num_classes)."""
        return self.fc(x)


class FCNProjectHead(nn.Module):
    """Projection head for convolutional feature maps.

    Global-average-pools the spatial dimensions to 1x1, flattens everything
    after dim 0, and applies ``nn.Linear(in_channels, num_classes)``. For a
    4D input this means ``x`` is expected as (N, in_channels, H, W).

    Args:
        in_channels: Number of channels in the incoming feature map.
        num_classes: Number of output features.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.fc = nn.Linear(in_channels, num_classes)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.init_weights()

    def init_weights(self):
        """Re-initialise Conv2d/Linear weights inside ``self.fc`` (Kaiming-normal).

        Only ``weight`` tensors are overwritten; biases are left as constructed.
        """
        for module in self.fc.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight)

    def forward(self, x):
        """Pool, flatten and project: (N, C, H, W) -> (N, num_classes)."""
        pooled = self.avgpool(x)  # (N, C, 1, 1)
        return self.fc(pooled.flatten(1))


class VideoProjectHead(nn.Module):
    """Projection head for per-frame video feature maps.

    Each frame is global-average-pooled and projected independently with
    ``nn.Linear(in_channels, num_classes)``; the per-frame outputs are then
    averaged over the frame axis.

    Args:
        in_channels: Number of channels in each frame's feature map.
        num_classes: Number of output features.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.fc = nn.Linear(in_channels, num_classes)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.init_weights()

    def init_weights(self):
        """Re-initialise Conv2d/Linear weights inside ``self.fc`` (Kaiming-normal).

        Only ``weight`` tensors are overwritten; biases are left as constructed.
        """
        for module in self.fc.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight)

    def forward(self, x):
        """Map ``x`` of shape (B, F, C, H, W) to (B, num_classes).

        Raises ``ValueError`` (from tuple unpacking) if ``x`` is not 5D.
        """
        batch, frames, channels, height, width = x.shape
        # Fold frames into the batch so the 2D pool handles each frame separately.
        flat = x.reshape(batch * frames, channels, height, width)
        per_frame = self.fc(self.avgpool(flat).flatten(1))  # (B*F, num_classes)
        # Explicit size (not -1) so zero-element batches still reshape cleanly.
        num_out = per_frame.shape[-1]
        return per_frame.reshape(batch, frames, num_out).mean(dim=1)
