import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F


class OpenNet(nn.Module):
    """Backbone wrapper with an inlier classifier and an outlier-detector head.

    Both heads consume the same backbone feature vector. The classifier emits
    ``num_classes`` logits; the outlier detector (a multi-binary classifier)
    emits ``2 * num_classes`` logits. How those detector outputs are paired or
    interpreted is left to the caller.

    Args:
        base: Backbone module. It must expose an integer ``num_features``
            attribute and accept ``only_feat=True`` in its forward call; the
            returned features are fed to ``nn.Linear(num_features, ...)``, so
            their last dimension must equal ``num_features``.
        num_classes: Number of inlier classes.
        cls_hidden: Hidden width of the classifier MLP (only used if ``mlp``).
        out_hidden: Width of both hidden layers of the outlier-detector MLP
            (only used if ``mlp``).
        mlp: If True, the classifier has one hidden layer and the detector
            two; otherwise each head is a single ``nn.Linear``.
    """

    def __init__(self, base, num_classes,
                 cls_hidden=1024, out_hidden=1024, mlp=True):
        super().__init__()
        self.backbone = base
        self.num_features = base.num_features

        feat_dim = self.num_features
        if mlp:
            cls_widths = [feat_dim, cls_hidden, num_classes]
            out_widths = [feat_dim, out_hidden, out_hidden, 2 * num_classes]
        else:
            cls_widths = [feat_dim, num_classes]
            out_widths = [feat_dim, 2 * num_classes]

        # Build the classifier before the detector: this fixes submodule
        # registration order (and hence state_dict key order) as well as the
        # order in which nn.Linear's default initialization draws from the RNG.
        self.classifier = self._build_head(cls_widths)
        self.out_detector = self._build_head(out_widths)

        self.initialize_weights()

    @staticmethod
    def _build_head(widths):
        """Chain ``nn.Linear`` layers through ``widths`` with ReLU in between.

        ``[a, b, c]`` yields ``Sequential(Linear(a, b), ReLU, Linear(b, c))``;
        a two-entry list yields a single ``Linear``. No activation is placed
        after the last layer.
        """
        layers = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            if idx:
                layers.append(nn.ReLU(inplace=False))
            layers.append(nn.Linear(fan_in, fan_out))
        return nn.Sequential(*layers)

    def initialize_weights(self):
        """Re-initialize every ``nn.Linear`` in the classifier and detector.

        Weights receive Xavier-normal initialization and biases are set to
        zero. Layers are visited classifier-first, in sequence order. The
        backbone's parameters are not touched.
        """
        for head in (self.classifier, self.out_detector):
            linears = (m for m in head if isinstance(m, nn.Linear))
            for layer in linears:
                nn.init.xavier_normal_(layer.weight)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

    def forward(self, x, **kwargs):
        """Run the backbone once and apply both heads to its features.

        Extra keyword arguments are accepted and ignored.

        Returns:
            dict with ``'logits'`` (classifier output) and ``'logits_out'``
            (outlier-detector output).
        """
        features = self.backbone(x, only_feat=True)
        return {
            'logits': self.classifier(features),
            'logits_out': self.out_detector(features),
        }