import sys
import torch
import torch.nn as nn
import torch.nn.functional as F

from pdb import set_trace

class MLP_module(nn.Module):
    """Linear -> ReLU -> BatchNorm1d block that caches its latest output.

    Args:
        in_features: Size of each input sample for the linear layer.
        out_features: Size of each output sample; also the number of
            features normalised by the batch-norm layer.
        affine: Forwarded to ``nn.BatchNorm1d``. The value is printed when
            the module is constructed.
    """

    def __init__(self, in_features, out_features, affine=True):
        super(MLP_module, self).__init__()
        self.linear = nn.Linear(in_features, out_features)
        print("affine is {}".format(affine))
        self.bn = nn.BatchNorm1d(out_features, affine=affine)
        self.relu = nn.ReLU()
        # Cache for the detached output of the most recent forward pass.
        # Created here so it can be read safely before the first forward call
        # instead of raising AttributeError.
        self.features = []

    def forward(self, x):
        """Run the block on *x* and remember a detached copy of the result.

        Args:
            x: Input tensor accepted by the linear layer.

        Returns:
            The batch-normalised activation (still attached to the autograd
            graph).
        """
        # Note the order: the ReLU is applied before batch normalisation.
        out = self.bn(self.relu(self.linear(x)))

        # Only the latest output is kept; the cache is replaced on every call.
        self.features = [out.detach()]

        return out


def get_feature(model):
    """Collect and clear the cached features of every ``MLP_module`` in *model*.

    Args:
        model: Object exposing ``named_modules()`` that yields
            ``(name, module)`` pairs.

    Returns:
        A dict mapping each ``MLP_module``'s name to the list that was stored
        in its ``features`` attribute. Each visited module is left with a new
        empty ``features`` list.
    """
    collected = {}
    for module_name, submodule in model.named_modules():
        if not isinstance(submodule, MLP_module):
            continue
        # Hand the cached list to the caller and leave an empty one behind.
        collected[module_name], submodule.features = submodule.features, []

    return collected


class MLP(nn.Module):
    """Fully connected classifier built from a stack of ``MLP_module`` blocks.

    Args:
        hidden: Width of every hidden block.
        depth: Number of ``MLP_module`` blocks before the final linear layer.
            One block is always created, even when ``depth`` is below 1.
        fc_bias: Whether the final linear layer has a bias. The value is
            printed at construction time.
        mlp_bias: Constant used to initialise the bias of the first block's
            linear layer. Skipped when ``abs(mlp_bias) <= 1e-6``, in which
            case the default ``nn.Linear`` bias initialisation is kept.
        mlp_bias_multiply: Factor applied to the bias constant after each
            block, so block ``i`` gets ``mlp_bias * mlp_bias_multiply ** i``.
        num_classes: Output size of the final linear layer.
        affine: Forwarded to every ``MLP_module``.
        in_features: Size of the flattened input. Defaults to 3072, the value
            that was previously hard-coded.
    """

    def __init__(self, hidden, depth=6, fc_bias=True, mlp_bias=0, mlp_bias_multiply=1.0,
                 num_classes=10, affine=True, in_features=3072):
        super(MLP, self).__init__()
        layers = [MLP_module(in_features, hidden, affine=affine)]
        layers += [MLP_module(hidden, hidden, affine=affine) for _ in range(depth - 1)]

        self.layers = nn.Sequential(*layers)
        self.fc = nn.Linear(hidden, num_classes, bias=fc_bias)
        print("fc_bias is {}".format(fc_bias))

        # Kaiming initialisation for the hidden linear weights only; the final
        # classifier keeps the default nn.Linear initialisation.
        for m in self.layers.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

        # Optional constant bias initialisation, scaled geometrically per block.
        mlp_bias_cur = mlp_bias
        if abs(mlp_bias) > 1e-6:
            for m in self.layers:
                assert isinstance(m, MLP_module)
                nn.init.constant_(m.linear.bias, mlp_bias_cur)
                mlp_bias_cur *= mlp_bias_multiply

    def clip_grad_bias(self):
        """Clamp negative bias gradients to zero and print the largest bias.

        For every ``MLP_module`` whose linear bias requires grad and already
        has a gradient, the gradient is replaced by a copy clipped to be
        non-negative. Biases without a gradient yet (for example before the
        first ``backward()``) are left untouched instead of raising.
        """
        # -inf rather than an arbitrary finite sentinel, so the reported value
        # is the true maximum even when every bias is very negative.
        bias_max = float("-inf")
        for m in self.modules():
            if not isinstance(m, MLP_module):
                continue
            bias = m.linear.bias
            if bias is None:
                continue
            if bias.requires_grad and bias.grad is not None:
                bias.grad = bias.grad.clip(min=0)
            # Clipping the gradient does not change the bias values, so the
            # maximum can be tracked in the same pass.
            bias_max = max(bias_max, bias.max().item())
        print("max bias is {}".format(bias_max))

    def forward(self, x, return_mid_features=False):
        """Flatten *x*, run the blocks and the classifier.

        Args:
            x: Input tensor whose first dimension is the batch dimension; all
                remaining dimensions are flattened.
            return_mid_features: When true, also return the per-block cached
                features.

        Returns:
            ``(logits, logits)``, or ``(logits, logits, features_dict)`` when
            ``return_mid_features`` is true. ``features_dict`` maps block
            names to their cached detached outputs.
        """
        # reshape (unlike view) also accepts non-contiguous inputs.
        x = x.reshape(x.shape[0], -1)
        x = self.layers(x)
        x = self.fc(x)

        # Always drain the per-block caches so they do not keep activations
        # alive after this call, even when the caller does not want them.
        features_dict = get_feature(self)

        # The logits fill both leading return slots.
        # NOTE(review): the first slot presumably once carried normalised
        # features (a disabled F.normalize call used to sit here) - confirm
        # with callers before changing the return shape.
        if return_mid_features:
            return x, x, features_dict

        return x, x
