#!/usr/bin/env python
# coding=utf-8
import sys
sys.path.append('./')
import torch
import torch.nn as nn


class ADM_block(nn.Module):
    """Convolutional classification head applied to ViT/DeiT token sequences.

    The token sequence is reshaped back into a ``height x width`` patch grid,
    passed through a small conv stack, globally average-pooled and fed to a
    two-layer MLP classifier.

    Example configuration (ViT-Base):
        input_channel: 768, the token embedding dimension of ViT-Base.
        output_channel: 256, the dimensionality of the penultimate-layer
            features.
        num_classes: 31, the number of classes of the dataset.
        height: patch-grid height, 14 for ViT-B-16 (196 patch tokens).
        wight: patch-grid width (historical spelling of "width"); equal to
            ``height`` for square inputs.

    Note:
        The conv stack consists of three unpadded 3x3 convolutions, the first
        with stride 2, so the patch grid must be at least 11 x 11. A 7 x 7
        grid (e.g. ViT-B-32) reshapes fine but makes ``forward`` fail inside
        the conv stack.
    """

    def __init__(self, input_channel, output_channel, num_classes, height=14, wight=14):
        super(ADM_block, self).__init__()
        self.h = height
        self.w = wight
        # No padding: a 14x14 grid shrinks to 6x6 -> 4x4 -> 2x2 before pooling.
        self.feas = nn.Sequential(nn.Conv2d(input_channel, output_channel, 3, stride=2),
                                  nn.BatchNorm2d(output_channel),
                                  nn.ReLU(inplace=True),
                                  nn.Conv2d(output_channel, output_channel, 3, stride=1),
                                  nn.BatchNorm2d(output_channel),
                                  nn.ReLU(inplace=True),
                                  nn.Conv2d(output_channel, output_channel, 3, stride=1))
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

        self.fc = nn.Sequential(nn.Linear(output_channel, output_channel),
                                nn.ReLU(inplace=True),
                                nn.Linear(output_channel, num_classes))

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6)

    def reshape_transform(self, tensor, height=14, width=14):
        """Turn a token sequence into a channels-first patch-grid feature map.

        Every leading token beyond the ``height * width`` patch tokens is
        treated as a prefix token and dropped. That is 1 for a plain ViT
        (class token, 197 tokens at 14x14) and 2 for distilled DeiT (class +
        distillation token, 198 tokens at 14x14).

        Args:
            tensor: token sequence of shape (B, N, C) with N >= height * width.
            height: patch-grid height.
            width: patch-grid width.

        Returns:
            Tensor of shape (B, C, height, width).

        Raises:
            ValueError: if N is smaller than ``height * width``.
        """
        num_tokens = tensor.shape[1]
        num_patches = height * width
        num_prefix = num_tokens - num_patches
        if num_prefix < 0:
            raise ValueError(
                f"expected at least {num_patches} tokens for a {height}x{width} "
                f"patch grid, got {num_tokens}")
        result = tensor[:, num_prefix:, :].reshape(tensor.size(0), height, width, tensor.size(2))
        # Bring the channels to the first dimension, like in CNNs.
        return result.permute(0, 3, 1, 2)

    def forward(self, x, get_feas=False):
        """Run the head on a token sequence.

        Args:
            x: token sequence of shape (B, N, input_channel).
            get_feas: if True, return only the pooled features.

        Returns:
            ``feas`` of shape (B, output_channel) when ``get_feas`` is True,
            otherwise the tuple ``(feas, logits)`` with logits of shape
            (B, num_classes).
        """
        shape_feas = self.reshape_transform(x, height=self.h, width=self.w)
        feas = self.feas(shape_feas)
        feas = torch.flatten(self.avg_pool(feas), 1)
        if get_feas:
            return feas
        out = self.fc(feas)
        return (feas, out)
# Token embedding width expected by ADM_block for each supported backbone name.
# NOTE(review): 'vit_small' maps to 768 as in the original branching; confirm
# this matches the embedding width of the backbone actually used.
_ADM_INPUT_CHANNELS = {
    'vit_large': 1024,
    'vit_base': 768,
    'vit_small': 768,
    'deit_small': 384,
    'deit_tiny': 192,
}


def obatin_vit_adm(model='vit_base', num_classes=31):
    """Build an ADM_block head sized for the named ViT/DeiT backbone.

    Args:
        model: backbone name; one of the keys of ``_ADM_INPUT_CHANNELS``.
            Unknown names fall back to an input width of 768.
        num_classes: number of output classes of the classifier.

    Returns:
        An ``ADM_block`` with 256 output channels on a 14 x 14 patch grid.
    """
    input_channel = _ADM_INPUT_CHANNELS.get(model, 768)
    head = ADM_block(input_channel=input_channel, output_channel=256,
                     num_classes=num_classes, height=14, wight=14)
    return head


if __name__ == '__main__':
    # Running this file directly does nothing; it only provides definitions.
    pass
