import torch
from torch import nn
from efficientnet_pytorch import EfficientNet
from torch.nn import init



class effNet(nn.Module):
    """EfficientNet backbone followed by a 3-layer MLP producing one output per sample.

    ``forward`` returns the head output together with several intermediate
    backbone feature maps and the head's weight matrices, so callers can use
    them (how they are used is outside this file).

    Args:
        model_name: EfficientNet variant name passed to ``efficientnet_pytorch``.
            Defaults to ``'efficientnet-b0'``, the previously hard-coded choice.
        pretrained: If True (default), build the backbone with
            ``EfficientNet.from_pretrained``; otherwise build it with
            ``EfficientNet.from_name`` (no pretrained weights loaded).
    """

    def __init__(self, model_name='efficientnet-b0', pretrained=True):
        super().__init__()
        if pretrained:
            self.efficient = EfficientNet.from_pretrained(model_name)
        else:
            self.efficient = EfficientNet.from_name(model_name)

        # Channel count of the backbone's final feature map ('reduction_6'),
        # which is the output of its conv head. It is 1280 for efficientnet-b0
        # but differs for other variants, so derive it instead of hard-coding.
        # NOTE: relies on the library-internal ``_conv_head`` attribute.
        num_features = self.efficient._conv_head.out_channels

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(num_features, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 1)
        # Defined but not applied in forward() (the dropout calls there were
        # disabled); kept as an attribute for compatibility.
        self.drop = nn.Dropout(p=0.5, inplace=False)

        # Small-std Gaussian init for the head weights; biases keep the
        # default nn.Linear initialisation.
        for layer in (self.fc1, self.fc2, self.fc3):
            init.normal_(layer.weight, mean=0, std=0.01)

    def forward(self, x):
        """Run the backbone and the regression head.

        Args:
            x: Input image batch. Assumed (B, 3, H, W), e.g. (B, 3, 224, 224)
                per the original shape notes. TODO confirm against callers.

        Returns:
            A tuple ``(q, features, head_weights)``:
              * ``q``: head output of shape (B, 1);
              * ``features``: ``[reduction_2, reduction_3, reduction_4,
                reduction_6]`` endpoint tensors from ``extract_endpoints``;
              * ``head_weights``: ``[fc1.weight, fc2.weight, fc3.weight]``.
        """
        endpoints = self.efficient.extract_endpoints(x)
        x2 = endpoints['reduction_2']
        x3 = endpoints['reduction_3']
        x4 = endpoints['reduction_4']
        # Final backbone feature map, e.g. (B, 1280, 7, 7) for b0 at 224x224.
        x = endpoints['reduction_6']

        pooled = torch.flatten(self.avgpool(x), 1)  # (B, C)
        q = torch.nn.functional.relu(self.fc1(pooled))  # (B, 512)
        q = torch.nn.functional.relu(self.fc2(q))  # (B, 512)
        q = self.fc3(q)  # (B, 1)
        return q, [x2, x3, x4, x], [self.fc1.weight, self.fc2.weight, self.fc3.weight]

