import math
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from ..base_model import BaseModel
from .utils.weights_init import weights_init_kaiming, weights_init_classifier

class BasicModel(BaseModel):
    """Thin wrapper that exposes the backbone's raw feature output.

    The constructor only forwards the backbone name to ``BaseModel``; the
    forward pass only runs ``self.base`` on the input batch.

    NOTE: an earlier experimental head (global average pooling -> batch
    norm -> linear(512) -> LeakyReLU "channel fusion" branch) used to be
    sketched here as commented-out code. It was never active, so the model
    returns a single tensor-like output, not a pair.
    """

    def __init__(self, name='res50'):
        """Build the model on top of the backbone selected by ``name``.

        Args:
            name: Identifier passed to ``BaseModel`` as its ``basenet``
                keyword argument. Defaults to ``'res50'``.
        """
        super(BasicModel, self).__init__(basenet=name)

    def forward(self, batch_imgs):
        """Run the backbone on a batch of images.

        Args:
            batch_imgs: Input batch, laid out as B x C x H x W.

        Returns:
            Whatever ``self.base`` produces for the batch, unmodified.
            (``self.base`` is presumably created by ``BaseModel.__init__``
            -- confirm against the parent class.)
        """
        return self.base(batch_imgs)
