import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F


class ProSubNet(nn.Module):
    """Backbone wrapper with an MLP classification head and a linear projection head.

    The wrapped ``base`` module must expose a ``num_features`` attribute and
    accept an ``only_feat=True`` keyword argument when called; whatever it
    returns for that call is fed to both heads.

    Args:
        base: Backbone module producing the features.
        num_classes: Output size of the classification head.
        cls_hidden: Hidden width of the classification MLP.
        proj_hidden: Accepted for interface compatibility; not used by this class.
        proj_size: Output size of the projection head.
    """

    def __init__(self, base, num_classes,
                 cls_hidden=1024, proj_hidden=128, proj_size=128):
        super().__init__()
        self.backbone = base
        self.num_features = base.num_features
        in_dim = self.num_features

        # Inlier classifier: Linear -> ReLU -> Linear on the backbone features.
        self.classifier = nn.Sequential(
            nn.Linear(in_dim, cls_hidden),
            nn.ReLU(inplace=False),
            nn.Linear(cls_hidden, num_classes),
        )

        # Projection layer for contrastive learning (single layer, same as TF).
        self.proj = nn.Linear(in_dim, proj_size)

        # Overwrite the default PyTorch initialisation of both heads.
        self.initialize_weights()

    def initialize_weights(self):
        """Re-initialise both heads in place; the backbone is left untouched.

        Linear layers of the classifier get Xavier-normal weights, the
        projection layer gets Kaiming-normal weights (``fan_out``, ReLU gain),
        and every existing bias is set to zero.
        """

        def zero_bias(layer):
            # Layers created with ``bias=False`` have no bias tensor to reset.
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

        # Classifier head: only the Linear entries carry parameters.
        for layer in self.classifier:
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.xavier_normal_(layer.weight)
            zero_bias(layer)

        # Projection head.
        head = self.proj
        if isinstance(head, nn.Linear):
            nn.init.kaiming_normal_(head.weight, mode='fan_out', nonlinearity='relu')
            zero_bias(head)

    def forward(self, x, **kwargs):
        """Run the backbone and both heads on ``x``.

        Args:
            x: Input passed straight to the backbone.
            **kwargs: Accepted and ignored.

        Returns:
            dict with ``'logits'`` (classifier output), ``'feat'`` (backbone
            features) and ``'proj'`` (projection-head output).
        """
        features = self.backbone(x, only_feat=True)
        return {
            'logits': self.classifier(features),
            'feat': features,
            'proj': self.proj(features),
        }