import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F


class OursNet(nn.Module):
    """Backbone wrapper with three MLP heads on the shared backbone feature.

    Every head is applied to the feature returned by
    ``base(x, only_feat=True)``:

    * ``classifier``   -- inlier classifier, ``num_classes`` outputs.
    * ``proj``         -- projection layer for contrastive learning,
      ``proj_size`` outputs.
    * ``out_detector`` -- outlier detector (multi-binary classifier),
      ``2 * num_classes`` outputs.

    Args:
        base: backbone module. Must expose ``num_features`` (the size of the
            feature it produces, used as every head's input width) and accept
            ``base(x, only_feat=True)``.
        num_classes: number of inlier classes.
        cls_hidden: hidden width of the classifier MLP.
        proj_hidden: hidden width of the projection MLP.
        proj_size: output width of the projection MLP.
        out_hidden: width of both hidden layers of the outlier-detector MLP.
    """

    def __init__(self, base, num_classes,
                 cls_hidden=1024, proj_hidden=128, proj_size=128, out_hidden=1024):
        super().__init__()
        self.backbone = base
        self.num_features = base.num_features

        # Keep the construction order (classifier, proj, out_detector) fixed:
        # layer creation and weight init both draw from the global RNG, so
        # reordering would change the initial weights for a given seed.

        # Inlier classifier: Linear -> ReLU -> Linear.
        self.classifier = nn.Sequential(
            nn.Linear(self.num_features, cls_hidden),
            nn.ReLU(inplace=False),
            nn.Linear(cls_hidden, num_classes),
        )

        # Projection layer for contrastive learning.
        self.proj = nn.Sequential(
            nn.Linear(self.num_features, proj_hidden),
            nn.ReLU(inplace=False),
            nn.Linear(proj_hidden, proj_size),
        )

        # Outlier detector (multi-binary classifier): two hidden layers,
        # 2 * num_classes outputs.
        self.out_detector = nn.Sequential(
            nn.Linear(self.num_features, out_hidden),
            nn.ReLU(inplace=False),
            nn.Linear(out_hidden, out_hidden),
            nn.ReLU(inplace=False),
            nn.Linear(out_hidden, 2 * num_classes),
        )

        self.initialize_weights()

    @staticmethod
    def _init_linear_layers(layers, weight_init):
        """Apply ``weight_init`` to each ``nn.Linear`` weight in ``layers`` and zero its bias.

        Only the direct children of ``layers`` are visited, in order.
        """
        for m in layers:
            if isinstance(m, nn.Linear):
                weight_init(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def initialize_weights(self):
        """Re-initialize the Linear layers of the three heads (the backbone is untouched).

        ``classifier`` and ``out_detector`` get Xavier-normal weights, and ``proj``
        gets Kaiming-normal weights (``mode='fan_out'``, ReLU gain). All biases
        are set to zero. The heads are processed in the order classifier,
        proj, out_detector.
        """
        self._init_linear_layers(self.classifier, nn.init.xavier_normal_)
        self._init_linear_layers(
            self.proj,
            lambda w: nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu'),
        )
        self._init_linear_layers(self.out_detector, nn.init.xavier_normal_)

    def forward(self, x, only_feat=False):
        """Run the backbone once and feed its feature to all three heads.

        Args:
            x: input, passed unchanged to ``self.backbone(x, only_feat=True)``.
            only_feat: currently ignored; the full output dict is always
                returned. NOTE(review): presumably kept so this wrapper can be
                called like the backbone. Confirm whether any caller passing
                ``only_feat=True`` expects the raw backbone feature instead.

        Returns:
            dict with keys ``'logits'`` (classifier output), ``'feat_proj'``
            (projection output) and ``'logits_out'`` (outlier-detector output).
        """
        feat = self.backbone(x, only_feat=True)
        # Dict values are evaluated left to right, matching the original
        # head evaluation order.
        return {
            'logits': self.classifier(feat),
            'feat_proj': self.proj(feat),
            'logits_out': self.out_detector(feat),
        }