import torch.nn as nn
import torch.nn.functional as F
from model.architecture.ResNet import model_dict

class ConResNet(nn.Module):
    """Backbone encoder followed by a projection head.

    The encoder is built from ``model_dict`` using either ``args.teacher_arch``
    or ``args.student_arch``. The head is either a single linear layer or a
    two-layer MLP, selected by ``args.head``.
    """

    def __init__(self, args, feat_dim=128, is_teacher=False):
        """Build the encoder and the projection head.

        Args:
            args: config object; reads ``teacher_arch`` or ``student_arch``,
                and ``head`` (``'linear'`` or ``'mlp'``).
            feat_dim: output dimension of the projection head.
            is_teacher: if True, build ``args.teacher_arch``; otherwise
                ``args.student_arch``.

        Raises:
            NotImplementedError: if ``args.head`` is neither ``'linear'``
                nor ``'mlp'``.
        """
        super().__init__()
        arch = args.teacher_arch if is_teacher else args.student_arch
        # model_dict maps an arch name to (encoder constructor, width of the
        # encoder output); that width is the input size of the head.
        module, dim_in = model_dict[arch]
        self.encoder = module()

        if args.head == 'linear':
            self.head = nn.Linear(dim_in, feat_dim)
        elif args.head == 'mlp':
            self.head = nn.Sequential(
                nn.Linear(dim_in, dim_in),
                nn.ReLU(inplace=True),
                nn.Linear(dim_in, feat_dim),
            )
        else:
            raise NotImplementedError(
                'labeled head not supported: {}'.format(args.head))

    def reinit_head(self):
        """Re-initialise the parameters of every layer in the projection head.

        ``modules()`` is used instead of ``children()`` so that a bare
        ``nn.Linear`` head (which has no children) is reset too, and so that
        layers nested inside sub-containers are reached.
        """
        for layer in self.head.modules():
            if hasattr(layer, 'reset_parameters'):
                layer.reset_parameters()

    def forward(self, x, return_feat=False, norm=True):
        """Encode ``x`` and project it through the head.

        Args:
            x: input batch passed to the encoder.
            return_feat: if True, also return the raw encoder output.
            norm: if True, L2-normalise the projected features along dim 1
                (assumes a (batch, features) layout — TODO confirm).

        Returns:
            ``feat``, or ``(feat, encoded)`` when ``return_feat`` is True.
        """
        encoded = self.encoder(x)
        feat = self.head(encoded)
        if norm:
            feat = F.normalize(feat, dim=1)
        return (feat, encoded) if return_feat else feat


class LinearClassifier(nn.Module):
    """Single fully connected layer mapping encoder features to class logits."""

    def __init__(self, args, is_teacher = False):
        """Create the linear layer sized from the chosen architecture.

        Args:
            args: config object; reads ``teacher_arch`` or ``student_arch``,
                and ``labeled_dataset.num_classes``.
            is_teacher: when True the input width comes from
                ``args.teacher_arch``, otherwise from ``args.student_arch``.
        """
        super().__init__()
        arch_name = args.teacher_arch if is_teacher else args.student_arch
        # Only the feature width is needed here; the constructor is ignored.
        _unused_ctor, in_features = model_dict[arch_name]
        num_classes = args.labeled_dataset.num_classes
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, features):
        """Return the logits produced by ``self.fc`` for ``features``."""
        logits = self.fc(features)
        return logits
