from statistics import mean

import torch
from sklearn.metrics import (accuracy_score, f1_score, precision_score,
                             recall_score)
from torch import nn as nn
from torchvision import models
from utils.utils import (Aggregation_Separation_Loss,
                         Aggregation_Separation_Loss_with_multi_distribution)

from networks.pvt import pvt_small
from networks.swin_transformer.swin_transformer import swin_t
from networks.t2t_vit.models import t2t_vit_14
from networks.vit.pytorch_pretrained_vit.model import ViT


class BaseNetwork(nn.Module):
    """Classifier made of a VGG16 backbone followed by an embedding head.

    Data flows ``feature -> pool -> flatten -> dnn -> space -> classifier``.
    ``forward`` returns the output of ``space`` (called the "distribution" in
    this file) together with the class logits.  The module owns its SGD
    optimizer and cross-entropy criterion and offers per-epoch train/test
    loops, a prediction helper and checkpoint save/load.
    """

    def __init__(self, config) -> None:
        """Create the layers, the optimizer and the classification loss.

        Args:
            config: Object exposing every attribute read in ``parse_config``.
        """
        super().__init__()
        self.config = config
        self.parse_config()
        network = models.vgg.vgg16(pretrained=False)
        self.feature = network.features
        self.pool = network.avgpool
        self.dnn = network.classifier
        # ``space`` expects 1000 input features from ``dnn`` and projects
        # them down to ``spatial_dimension``.
        self.space = nn.Sequential(
            nn.Linear(1000, 512),
            nn.ReLU(inplace=False),
            nn.Dropout(),
            nn.Linear(512, 128),
            nn.ReLU(inplace=False),
            nn.Linear(128, self.spatial_dimension),
            nn.ReLU(inplace=False),
        )
        self.classifier = nn.Linear(self.spatial_dimension, self.class_nums)

        self.optimizer = torch.optim.SGD(params=self.parameters(),
                                         lr=self.learning_rate,
                                         weight_decay=5e-4,
                                         momentum=0.9)
        self.loss_function_cls = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        """Return ``(distribution, logits)`` for the input batch ``x``."""
        x = self.feature(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = self.dnn(x)
        distribution = self.space(x)
        x = self.classifier(distribution)
        return distribution, x

    def save_state_dict(self):
        """Write the state dict to ``model_save_file`` and return a message."""
        torch.save(self.state_dict(), self.model_save_file)
        return 'save model at {}\n'.format(self.model_save_file)

    def load_state_dict(self, state_dict=None, strict=True):
        """Load weights from ``state_dict`` or, by default, from disk.

        Args:
            state_dict: Mapping to load.  When ``None`` (the no-argument
                calling convention) the mapping is read from
                ``self.model_load_file`` and mapped onto ``self.device``.
            strict: Forwarded to ``nn.Module.load_state_dict``.

        Returns:
            The value returned by ``nn.Module.load_state_dict``.
        """
        if state_dict is None:
            state_dict = torch.load(self.model_load_file,
                                    map_location=self.device)
        return super().load_state_dict(state_dict, strict=strict)

    @staticmethod
    def _classification_metrics(groundtruths, predictions):
        """Return accuracy and macro precision/recall/F1 as a dict."""
        macro = dict(average='macro', zero_division=0)
        return {
            "accuracy": accuracy_score(groundtruths, predictions),
            "precision": precision_score(groundtruths, predictions, **macro),
            "recall": recall_score(groundtruths, predictions, **macro),
            "f1_score": f1_score(groundtruths, predictions, **macro),
        }

    def train_one_epoch(self, trainloader):
        """Run one optimisation pass over ``trainloader``.

        Args:
            trainloader: Iterable yielding ``(image, label)`` batches.

        Returns:
            Dict with the mean classification loss (``loss_cls``) and the
            ``accuracy``/``precision``/``recall``/``f1_score`` of the epoch.

        Raises:
            statistics.StatisticsError: If ``trainloader`` yields no batch.
        """
        predictions = []
        groundtruths = []
        losses_cls = []

        self.train()
        for image, label in trainloader:
            self.optimizer.zero_grad()
            image = image.to(self.device)
            label = label.to(self.device)
            _, output = self(image)

            loss_cls = self.loss_function_cls(output, label)
            # The graph is rebuilt every step and back-propagated once, so it
            # does not need to be retained after this call.
            loss_cls.backward()
            self.optimizer.step()

            predictions.extend(output.argmax(dim=1).detach().cpu().tolist())
            groundtruths.extend(label.detach().cpu().tolist())
            losses_cls.append(loss_cls.item())

        metrics = self._classification_metrics(groundtruths, predictions)
        return {"loss_cls": mean(losses_cls), **metrics}

    def test_one_epoch(self, testloader):
        """Evaluate on ``testloader`` without updating any weight.

        Args:
            testloader: Iterable yielding ``(image, label)`` batches.

        Returns:
            Dict with the same keys as ``train_one_epoch``.

        Raises:
            statistics.StatisticsError: If ``testloader`` yields no batch.
        """
        predictions = []
        groundtruths = []
        losses_cls = []

        with torch.no_grad():
            self.eval()
            for image, label in testloader:
                image = image.to(self.device)
                label = label.to(self.device)
                _, output = self(image)

                loss_cls = self.loss_function_cls(output, label)

                predictions.extend(
                    output.argmax(dim=1).detach().cpu().tolist())
                groundtruths.extend(label.detach().cpu().tolist())
                losses_cls.append(loss_cls.item())

        metrics = self._classification_metrics(groundtruths, predictions)
        return {"loss_cls": mean(losses_cls), **metrics}

    def predict(self, dataloader):
        """Collect labels, predicted classes and distributions.

        Args:
            dataloader: Iterable yielding ``(image, label)`` batches.

        Returns:
            Tuple ``(groundtruths, predictions, distributions)`` of Python
            lists with one entry per sample.
        """
        predictions = []
        groundtruths = []
        distributions = []
        with torch.no_grad():
            self.eval()
            for image, label in dataloader:
                image = image.to(self.device)
                label = label.to(self.device)
                distribution, output = self(image)
                predictions.extend(
                    output.argmax(dim=1).detach().cpu().tolist())
                groundtruths.extend(label.detach().cpu().tolist())
                distributions.extend(distribution.detach().cpu().tolist())

        return groundtruths, predictions, distributions

    def parse_config(self):
        """Copy the settings used by this class from ``self.config``."""
        self.pretrain = self.config.pretrain
        self.spatial_dimension = self.config.spatial_dimension
        self.class_nums = self.config.class_nums
        self.learning_rate = self.config.learning_rate
        self.constraint = self.config.constraint
        self.device = self.config.device
        self.model_save_file = self.config.model_save_file
        self.model_load_file = self.config.model_load_file
        self.lambda_cls = self.config.lambda_cls
        self.lambda_penalty = self.config.lambda_penalty
        self.lambda_inner = self.config.lambda_inner
        self.lambda_outer = self.config.lambda_outer
        # Look the criterion class up by name instead of eval()-ing the
        # config string, so the value can only select an attribute of
        # ``torch.nn`` and cannot run arbitrary code.
        self.distance_loss = getattr(nn, self.config.distance_loss)()


class BaseNetwork_with_ASLoss(nn.Module):
    """VGG16-based classifier trained with an extra distribution loss.

    Data flows ``feature -> pool -> flatten -> dnn -> space -> classifier``.
    ``forward`` returns the output of ``space`` (the "distribution") and the
    class logits.  Training combines cross-entropy on the logits with the
    three terms returned by ``Aggregation_Separation_Loss`` on the
    distribution, weighted by the ``lambda_*`` settings.
    """

    def __init__(self, config) -> None:
        """Create the layers, the optimizer and both loss functions.

        Args:
            config: Object exposing every attribute read in ``parse_config``.
        """
        super().__init__()
        self.config = config
        self.parse_config()
        network = models.vgg.vgg16(pretrained=False)
        self.feature = network.features
        self.pool = network.avgpool
        self.dnn = network.classifier
        # ``space`` expects 1000 input features from ``dnn`` and projects
        # them down to ``spatial_dimension``.
        self.space = nn.Sequential(
            nn.Linear(1000, 512),
            nn.ReLU(inplace=False),
            nn.Dropout(),
            nn.Linear(512, 128),
            nn.ReLU(inplace=False),
            nn.Linear(128, self.spatial_dimension),
            nn.ReLU(inplace=False),
        )
        self.classifier = nn.Linear(self.spatial_dimension, self.class_nums)

        self.optimizer = torch.optim.SGD(params=self.parameters(),
                                         lr=self.learning_rate,
                                         weight_decay=5e-4,
                                         momentum=0.9)
        self.loss_function_cls = torch.nn.CrossEntropyLoss()
        self.loss_function_distribution = Aggregation_Separation_Loss(
            self.distance_loss, self.constraint)

    def forward(self, x):
        """Return ``(distribution, logits)`` for the input batch ``x``."""
        x = self.feature(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = self.dnn(x)
        distribution = self.space(x)
        x = self.classifier(distribution)
        return distribution, x

    def save_state_dict(self):
        """Write the state dict to ``model_save_file`` and return a message."""
        torch.save(self.state_dict(), self.model_save_file)
        return 'save model at {}\n'.format(self.model_save_file)

    def load_state_dict(self, state_dict=None, strict=True):
        """Load weights from ``state_dict`` or, by default, from disk.

        Args:
            state_dict: Mapping to load.  When ``None`` (the no-argument
                calling convention) the mapping is read from
                ``self.model_load_file`` and mapped onto ``self.device``.
            strict: Forwarded to ``nn.Module.load_state_dict``.

        Returns:
            The value returned by ``nn.Module.load_state_dict``.
        """
        if state_dict is None:
            state_dict = torch.load(self.model_load_file,
                                    map_location=self.device)
        return super().load_state_dict(state_dict, strict=strict)

    @staticmethod
    def _classification_metrics(groundtruths, predictions):
        """Return accuracy and macro precision/recall/F1 as a dict."""
        macro = dict(average='macro', zero_division=0)
        return {
            "accuracy": accuracy_score(groundtruths, predictions),
            "precision": precision_score(groundtruths, predictions, **macro),
            "recall": recall_score(groundtruths, predictions, **macro),
            "f1_score": f1_score(groundtruths, predictions, **macro),
        }

    @classmethod
    def _epoch_report(cls, groundtruths, predictions, losses_cls,
                      losses_inner, losses_outer, losses_penalty):
        """Combine per-batch losses and predictions into the result dict."""
        metrics = cls._classification_metrics(groundtruths, predictions)
        return {
            "loss_inner": mean(losses_inner),
            "loss_outer": mean(losses_outer),
            "loss_penalty": mean(losses_penalty),
            "loss_cls": mean(losses_cls),
            **metrics,
        }

    def train_one_epoch(self, trainloader):
        """Run one optimisation pass over ``trainloader``.

        The optimised quantity is ``lambda_cls * loss_cls + lambda_penalty *
        loss_penalty + lambda_inner * loss_inner - lambda_outer * loss_outer``.

        Args:
            trainloader: Iterable yielding ``(image, label)`` batches.

        Returns:
            Dict with the epoch means ``loss_inner``, ``loss_outer``,
            ``loss_penalty`` and ``loss_cls`` plus ``accuracy``,
            ``precision``, ``recall`` and ``f1_score``.

        Raises:
            statistics.StatisticsError: If ``trainloader`` yields no batch.
        """
        predictions = []
        groundtruths = []
        losses_cls = []
        losses_inner = []
        losses_outer = []
        losses_penalty = []

        self.train()
        for image, label in trainloader:
            self.optimizer.zero_grad()
            image = image.to(self.device)
            label = label.to(self.device)
            distribution, output = self(image)

            loss_inner, loss_outer, loss_penalty = (
                self.loss_function_distribution(distribution, label))
            loss_cls = self.loss_function_cls(output, label)
            # The outer term is subtracted, unlike the other three.
            loss = (self.lambda_cls * loss_cls
                    + self.lambda_penalty * loss_penalty
                    + self.lambda_inner * loss_inner
                    - self.lambda_outer * loss_outer)
            # NOTE(review): retain_graph=True is kept because the
            # distribution loss is defined outside this file and may reuse
            # graph state between steps -- confirm before removing.
            loss.backward(retain_graph=True)
            self.optimizer.step()

            predictions.extend(output.argmax(dim=1).detach().cpu().tolist())
            groundtruths.extend(label.detach().cpu().tolist())
            losses_cls.append(loss_cls.item())
            losses_inner.append(loss_inner.item())
            losses_outer.append(loss_outer.item())
            losses_penalty.append(loss_penalty.item())

        return self._epoch_report(groundtruths, predictions, losses_cls,
                                  losses_inner, losses_outer, losses_penalty)

    def test_one_epoch(self, testloader):
        """Evaluate on ``testloader`` without updating any weight.

        Args:
            testloader: Iterable yielding ``(image, label)`` batches.

        Returns:
            Dict with the same keys as ``train_one_epoch``.

        Raises:
            statistics.StatisticsError: If ``testloader`` yields no batch.
        """
        predictions = []
        groundtruths = []
        losses_cls = []
        losses_inner = []
        losses_outer = []
        losses_penalty = []
        with torch.no_grad():
            self.eval()
            for image, label in testloader:
                image = image.to(self.device)
                label = label.to(self.device)
                distribution, output = self(image)

                loss_inner, loss_outer, loss_penalty = (
                    self.loss_function_distribution(distribution, label))
                loss_cls = self.loss_function_cls(output, label)

                predictions.extend(
                    output.argmax(dim=1).detach().cpu().tolist())
                groundtruths.extend(label.detach().cpu().tolist())
                losses_cls.append(loss_cls.item())
                losses_inner.append(loss_inner.item())
                losses_outer.append(loss_outer.item())
                losses_penalty.append(loss_penalty.item())

        return self._epoch_report(groundtruths, predictions, losses_cls,
                                  losses_inner, losses_outer, losses_penalty)

    def predict(self, dataloader):
        """Collect labels, predicted classes and distributions.

        Args:
            dataloader: Iterable yielding ``(image, label)`` batches.

        Returns:
            Tuple ``(groundtruths, predictions, distributions)`` of Python
            lists with one entry per sample.
        """
        predictions = []
        groundtruths = []
        distributions = []
        with torch.no_grad():
            self.eval()
            for image, label in dataloader:
                image = image.to(self.device)
                label = label.to(self.device)
                distribution, output = self(image)
                predictions.extend(
                    output.argmax(dim=1).detach().cpu().tolist())
                groundtruths.extend(label.detach().cpu().tolist())
                distributions.extend(distribution.detach().cpu().tolist())

        return groundtruths, predictions, distributions

    def parse_config(self):
        """Copy the settings used by this class from ``self.config``."""
        self.pretrain = self.config.pretrain
        self.spatial_dimension = self.config.spatial_dimension
        self.class_nums = self.config.class_nums
        self.learning_rate = self.config.learning_rate
        self.constraint = self.config.constraint
        self.device = self.config.device
        self.model_save_file = self.config.model_save_file
        self.model_load_file = self.config.model_load_file
        self.lambda_cls = self.config.lambda_cls
        self.lambda_penalty = self.config.lambda_penalty
        self.lambda_inner = self.config.lambda_inner
        self.lambda_outer = self.config.lambda_outer
        # Look the criterion class up by name instead of eval()-ing the
        # config string, so the value can only select an attribute of
        # ``torch.nn`` and cannot run arbitrary code.
        self.distance_loss = getattr(nn, self.config.distance_loss)()


class BaseNetwork_with_ASLoss_multi_distribution(nn.Module):
    """VGG16-based classifier whose distribution loss spans several layers.

    Data flows ``feature -> pool -> flatten -> dnn -> space_1 -> space_2 ->
    space_3 -> classifier``.  The outputs of ``dnn`` and of the three
    ``space_*`` stages are numbered 0..3; ``forward`` returns those selected
    by ``distribution_layers`` together with the class logits.
    """

    def __init__(self, config) -> None:
        """Create the layers, the optimizer and both loss functions.

        Args:
            config: Object exposing every attribute read in ``parse_config``.
        """
        super().__init__()
        self.config = config
        self.parse_config()
        network = models.vgg.vgg16(pretrained=False)
        self.feature = network.features
        self.pool = network.avgpool
        self.dnn = network.classifier
        # ``space_1`` expects 1000 input features from ``dnn``.
        self.space_1 = nn.Sequential(
            nn.Linear(1000, 512),
            nn.ReLU(inplace=False),
        )
        self.space_2 = nn.Sequential(
            nn.Dropout(),
            nn.Linear(512, 128),
            nn.ReLU(inplace=False),
        )
        self.space_3 = nn.Sequential(
            nn.Linear(128, self.spatial_dimension),
            nn.ReLU(inplace=False),
        )
        self.classifier = nn.Linear(self.spatial_dimension, self.class_nums)

        self.optimizer = torch.optim.SGD(params=self.parameters(),
                                         lr=self.learning_rate,
                                         weight_decay=5e-4,
                                         momentum=0.9)
        self.loss_function_cls = torch.nn.CrossEntropyLoss()
        self.loss_function_distribution = Aggregation_Separation_Loss_with_multi_distribution(
            self.distance_loss, self.constraint)

    def forward(self, x):
        """Return ``(distributions, logits)`` for the input batch ``x``.

        ``distributions`` is a list holding, in the order given by
        ``self.distribution_layers``, the selected intermediate outputs.
        """
        x = self.feature(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        distribution_0 = self.dnn(x)
        distribution_1 = self.space_1(distribution_0)
        distribution_2 = self.space_2(distribution_1)
        distribution_3 = self.space_3(distribution_2)
        x = self.classifier(distribution_3)
        all_distributions = [
            distribution_0, distribution_1, distribution_2, distribution_3
        ]
        distributions = [
            all_distributions[idx] for idx in self.distribution_layers
        ]
        return distributions, x

    def save_state_dict(self):
        """Write the state dict to ``model_save_file`` and return a message."""
        torch.save(self.state_dict(), self.model_save_file)
        return 'save model at {}\n'.format(self.model_save_file)

    def load_state_dict(self, state_dict=None, strict=True):
        """Load weights from ``state_dict`` or, by default, from disk.

        Args:
            state_dict: Mapping to load.  When ``None`` (the no-argument
                calling convention) the mapping is read from
                ``self.model_load_file`` and mapped onto ``self.device``.
            strict: Forwarded to ``nn.Module.load_state_dict``.

        Returns:
            The value returned by ``nn.Module.load_state_dict``.
        """
        if state_dict is None:
            state_dict = torch.load(self.model_load_file,
                                    map_location=self.device)
        return super().load_state_dict(state_dict, strict=strict)

    @staticmethod
    def _classification_metrics(groundtruths, predictions):
        """Return accuracy and macro precision/recall/F1 as a dict."""
        macro = dict(average='macro', zero_division=0)
        return {
            "accuracy": accuracy_score(groundtruths, predictions),
            "precision": precision_score(groundtruths, predictions, **macro),
            "recall": recall_score(groundtruths, predictions, **macro),
            "f1_score": f1_score(groundtruths, predictions, **macro),
        }

    @classmethod
    def _epoch_report(cls, groundtruths, predictions, losses_cls,
                      losses_inner, losses_outer, losses_penalty):
        """Combine per-batch losses and predictions into the result dict."""
        metrics = cls._classification_metrics(groundtruths, predictions)
        return {
            "loss_inner": mean(losses_inner),
            "loss_outer": mean(losses_outer),
            "loss_penalty": mean(losses_penalty),
            "loss_cls": mean(losses_cls),
            **metrics,
        }

    def train_one_epoch(self, trainloader):
        """Run one optimisation pass over ``trainloader``.

        The optimised quantity is ``lambda_cls * loss_cls + lambda_penalty *
        loss_penalty + lambda_inner * loss_inner - lambda_outer * loss_outer``.

        Args:
            trainloader: Iterable yielding ``(image, label)`` batches.

        Returns:
            Dict with the epoch means ``loss_inner``, ``loss_outer``,
            ``loss_penalty`` and ``loss_cls`` plus ``accuracy``,
            ``precision``, ``recall`` and ``f1_score``.

        Raises:
            statistics.StatisticsError: If ``trainloader`` yields no batch.
        """
        predictions = []
        groundtruths = []
        losses_cls = []
        losses_inner = []
        losses_outer = []
        losses_penalty = []

        self.train()
        for image, label in trainloader:
            self.optimizer.zero_grad()
            image = image.to(self.device)
            label = label.to(self.device)
            distributions, output = self(image)

            loss_inner, loss_outer, loss_penalty = (
                self.loss_function_distribution(distributions, label,
                                                self.lambda_distribution))
            loss_cls = self.loss_function_cls(output, label)
            # The outer term is subtracted, unlike the other three.
            loss = (self.lambda_cls * loss_cls
                    + self.lambda_penalty * loss_penalty
                    + self.lambda_inner * loss_inner
                    - self.lambda_outer * loss_outer)
            # NOTE(review): retain_graph=True is kept because the
            # distribution loss is defined outside this file and may reuse
            # graph state between steps -- confirm before removing.
            loss.backward(retain_graph=True)
            self.optimizer.step()

            predictions.extend(output.argmax(dim=1).detach().cpu().tolist())
            groundtruths.extend(label.detach().cpu().tolist())
            losses_cls.append(loss_cls.item())
            losses_inner.append(loss_inner.item())
            losses_outer.append(loss_outer.item())
            losses_penalty.append(loss_penalty.item())

        return self._epoch_report(groundtruths, predictions, losses_cls,
                                  losses_inner, losses_outer, losses_penalty)

    def test_one_epoch(self, testloader):
        """Evaluate on ``testloader`` without updating any weight.

        Args:
            testloader: Iterable yielding ``(image, label)`` batches.

        Returns:
            Dict with the same keys as ``train_one_epoch``.

        Raises:
            statistics.StatisticsError: If ``testloader`` yields no batch.
        """
        predictions = []
        groundtruths = []
        losses_cls = []
        losses_inner = []
        losses_outer = []
        losses_penalty = []
        with torch.no_grad():
            self.eval()
            for image, label in testloader:
                image = image.to(self.device)
                label = label.to(self.device)
                distributions, output = self(image)

                loss_inner, loss_outer, loss_penalty = (
                    self.loss_function_distribution(
                        distributions, label, self.lambda_distribution))
                loss_cls = self.loss_function_cls(output, label)

                predictions.extend(
                    output.argmax(dim=1).detach().cpu().tolist())
                groundtruths.extend(label.detach().cpu().tolist())
                losses_cls.append(loss_cls.item())
                losses_inner.append(loss_inner.item())
                losses_outer.append(loss_outer.item())
                losses_penalty.append(loss_penalty.item())

        return self._epoch_report(groundtruths, predictions, losses_cls,
                                  losses_inner, losses_outer, losses_penalty)

    def predict(self, dataloader):
        """Collect labels, predicted classes and per-layer distributions.

        ``forward`` returns a list of tensors, one per selected layer, so the
        outputs are accumulated layer by layer instead of being treated as a
        single tensor (which raised ``AttributeError``).

        Args:
            dataloader: Iterable yielding ``(image, label)`` batches.

        Returns:
            Tuple ``(groundtruths, predictions, distributions)``.  The first
            two are flat lists with one entry per sample.
            ``distributions[i]`` is the list of per-sample vectors produced
            by the ``i``-th entry of ``self.distribution_layers``.
        """
        predictions = []
        groundtruths = []
        distributions = [[] for _ in self.distribution_layers]
        with torch.no_grad():
            self.eval()
            for image, label in dataloader:
                image = image.to(self.device)
                label = label.to(self.device)
                layer_outputs, output = self(image)
                predictions.extend(
                    output.argmax(dim=1).detach().cpu().tolist())
                groundtruths.extend(label.detach().cpu().tolist())
                for collected, layer_output in zip(distributions,
                                                   layer_outputs):
                    collected.extend(layer_output.detach().cpu().tolist())

        return groundtruths, predictions, distributions

    def parse_config(self):
        """Copy the settings used by this class from ``self.config``."""
        self.pretrain = self.config.pretrain
        self.spatial_dimension = self.config.spatial_dimension
        self.class_nums = self.config.class_nums
        self.learning_rate = self.config.learning_rate
        self.constraint = self.config.constraint
        self.device = self.config.device
        self.model_save_file = self.config.model_save_file
        self.model_load_file = self.config.model_load_file
        self.lambda_cls = self.config.lambda_cls
        self.lambda_penalty = self.config.lambda_penalty
        self.lambda_inner = self.config.lambda_inner
        self.lambda_outer = self.config.lambda_outer
        # Look the criterion class up by name instead of eval()-ing the
        # config string, so the value can only select an attribute of
        # ``torch.nn`` and cannot run arbitrary code.
        self.distance_loss = getattr(nn, self.config.distance_loss)()
        # These two settings live in the mapping stored at ``config.config``.
        self.distribution_layers = self.config.config['distribution_layers']
        self.lambda_distribution = self.config.config['lambda_distribution']


class Vgg16(BaseNetwork):
    """``BaseNetwork`` whose backbone is a torchvision VGG16.

    The backbone is built with ``pretrained=config.pretrain``.
    """

    def __init__(self, config) -> None:
        """Swap in the configured VGG16, then rebuild head and optimizer."""
        super().__init__(config)
        backbone = models.vgg.vgg16(pretrained=self.pretrain)
        self.feature, self.pool, self.dnn = (backbone.features,
                                             backbone.avgpool,
                                             backbone.classifier)
        self.classifier = nn.Linear(in_features=self.spatial_dimension,
                                    out_features=self.class_nums)

        # Build a fresh optimizer over the parameters as they are now, after
        # the module replacements above.
        sgd_options = dict(lr=self.learning_rate,
                           weight_decay=5e-4,
                           momentum=0.9)
        self.optimizer = torch.optim.SGD(self.parameters(), **sgd_options)


class Vgg16_with_ASLoss(BaseNetwork_with_ASLoss):
    """``BaseNetwork_with_ASLoss`` whose backbone is a torchvision VGG16.

    The backbone is built with ``pretrained=config.pretrain``.
    """

    def __init__(self, config) -> None:
        """Swap in the configured VGG16, then rebuild head and optimizer."""
        super().__init__(config)
        backbone = models.vgg.vgg16(pretrained=self.pretrain)
        self.feature, self.pool, self.dnn = (backbone.features,
                                             backbone.avgpool,
                                             backbone.classifier)
        self.classifier = nn.Linear(in_features=self.spatial_dimension,
                                    out_features=self.class_nums)

        # Build a fresh optimizer over the parameters as they are now, after
        # the module replacements above.
        sgd_options = dict(lr=self.learning_rate,
                           weight_decay=5e-4,
                           momentum=0.9)
        self.optimizer = torch.optim.SGD(self.parameters(), **sgd_options)


class ResNet50(BaseNetwork):
    """``BaseNetwork`` whose backbone is a torchvision ResNet-50.

    The backbone is built with ``pretrained=config.pretrain``.
    """

    def __init__(self, config) -> None:
        """Swap in the ResNet-50 pieces, then rebuild head and optimizer."""
        super().__init__(config)
        backbone = models.resnet50(pretrained=self.pretrain)
        # Everything in the backbone up to, but excluding, pooling and ``fc``.
        conv_stages = (backbone.conv1, backbone.bn1, backbone.relu,
                       backbone.maxpool, backbone.layer1, backbone.layer2,
                       backbone.layer3, backbone.layer4)
        self.feature = nn.Sequential(*conv_stages)
        self.pool, self.dnn = backbone.avgpool, backbone.fc
        self.classifier = nn.Linear(in_features=self.spatial_dimension,
                                    out_features=self.class_nums)

        # Build a fresh optimizer over the parameters as they are now.
        sgd_options = dict(lr=self.learning_rate,
                           weight_decay=5e-4,
                           momentum=0.9)
        self.optimizer = torch.optim.SGD(self.parameters(), **sgd_options)


class ResNet50_with_ASLoss(BaseNetwork_with_ASLoss):
    """``BaseNetwork_with_ASLoss`` whose backbone is a torchvision ResNet-50.

    The backbone is built with ``pretrained=config.pretrain``.
    """

    def __init__(self, config) -> None:
        """Swap in the ResNet-50 pieces, then rebuild head and optimizer."""
        super().__init__(config)
        backbone = models.resnet50(pretrained=self.pretrain)
        # Everything in the backbone up to, but excluding, pooling and ``fc``.
        conv_stages = (backbone.conv1, backbone.bn1, backbone.relu,
                       backbone.maxpool, backbone.layer1, backbone.layer2,
                       backbone.layer3, backbone.layer4)
        self.feature = nn.Sequential(*conv_stages)
        self.pool, self.dnn = backbone.avgpool, backbone.fc
        self.classifier = nn.Linear(in_features=self.spatial_dimension,
                                    out_features=self.class_nums)

        # Build a fresh optimizer over the parameters as they are now.
        sgd_options = dict(lr=self.learning_rate,
                           weight_decay=5e-4,
                           momentum=0.9)
        self.optimizer = torch.optim.SGD(self.parameters(), **sgd_options)


class EfficientNet_b3(BaseNetwork):
    """``BaseNetwork`` whose backbone is a torchvision EfficientNet-B3.

    The backbone is built with ``pretrained=config.pretrain``.
    """

    def __init__(self, config) -> None:
        """Swap in the EfficientNet-B3, then rebuild head and optimizer."""
        super().__init__(config)
        backbone = models.efficientnet_b3(pretrained=self.pretrain)
        self.feature, self.pool, self.dnn = (backbone.features,
                                             backbone.avgpool,
                                             backbone.classifier)
        self.classifier = nn.Linear(in_features=self.spatial_dimension,
                                    out_features=self.class_nums)
        # Build a fresh optimizer over the parameters as they are now.
        sgd_options = dict(lr=self.learning_rate,
                           weight_decay=5e-4,
                           momentum=0.9)
        self.optimizer = torch.optim.SGD(self.parameters(), **sgd_options)


class EfficientNet_b3_with_ASLoss(BaseNetwork_with_ASLoss):
    """``BaseNetwork_with_ASLoss`` on a torchvision EfficientNet-B3 backbone.

    The backbone is built with ``pretrained=config.pretrain``.
    """

    def __init__(self, config) -> None:
        """Swap in the EfficientNet-B3, then rebuild head and optimizer."""
        super().__init__(config)
        backbone = models.efficientnet_b3(pretrained=self.pretrain)
        self.feature, self.pool, self.dnn = (backbone.features,
                                             backbone.avgpool,
                                             backbone.classifier)
        self.classifier = nn.Linear(in_features=self.spatial_dimension,
                                    out_features=self.class_nums)
        # Build a fresh optimizer over the parameters as they are now.
        sgd_options = dict(lr=self.learning_rate,
                           weight_decay=5e-4,
                           momentum=0.9)
        self.optimizer = torch.optim.SGD(self.parameters(), **sgd_options)


class EfficientNet_b3_with_ASLoss_multi_distribution(
        BaseNetwork_with_ASLoss_multi_distribution):
    """Multi-distribution ASLoss network on an EfficientNet-B3 backbone.

    The backbone is built with ``pretrained=config.pretrain``.
    """

    def __init__(self, config) -> None:
        """Swap in the EfficientNet-B3, then rebuild head and optimizer."""
        super().__init__(config)
        backbone = models.efficientnet_b3(pretrained=self.pretrain)
        self.feature, self.pool, self.dnn = (backbone.features,
                                             backbone.avgpool,
                                             backbone.classifier)
        self.classifier = nn.Linear(in_features=self.spatial_dimension,
                                    out_features=self.class_nums)
        # Build a fresh optimizer over the parameters as they are now.
        sgd_options = dict(lr=self.learning_rate,
                           weight_decay=5e-4,
                           momentum=0.9)
        self.optimizer = torch.optim.SGD(self.parameters(), **sgd_options)


class ResNeXt50(BaseNetwork):
    """``BaseNetwork`` whose backbone is torchvision's ``resnext50_32x4d``.

    The backbone is built with ``pretrained=config.pretrain``.
    """

    def __init__(self, config) -> None:
        """Swap in the ResNeXt-50 pieces, then rebuild head and optimizer."""
        super().__init__(config)
        backbone = models.resnext50_32x4d(pretrained=self.pretrain)
        # Everything in the backbone up to, but excluding, pooling and ``fc``.
        conv_stages = (backbone.conv1, backbone.bn1, backbone.relu,
                       backbone.maxpool, backbone.layer1, backbone.layer2,
                       backbone.layer3, backbone.layer4)
        self.feature = nn.Sequential(*conv_stages)
        self.pool, self.dnn = backbone.avgpool, backbone.fc
        self.classifier = nn.Linear(in_features=self.spatial_dimension,
                                    out_features=self.class_nums)

        # Build a fresh optimizer over the parameters as they are now.
        sgd_options = dict(lr=self.learning_rate,
                           weight_decay=5e-4,
                           momentum=0.9)
        self.optimizer = torch.optim.SGD(self.parameters(), **sgd_options)


class ResNeXt50_with_ASLoss(BaseNetwork_with_ASLoss):
    """``BaseNetwork_with_ASLoss`` on torchvision's ``resnext50_32x4d``.

    The backbone is built with ``pretrained=config.pretrain``.
    """

    def __init__(self, config) -> None:
        """Swap in the ResNeXt-50 pieces, then rebuild head and optimizer."""
        super().__init__(config)
        backbone = models.resnext50_32x4d(pretrained=self.pretrain)
        # Everything in the backbone up to, but excluding, pooling and ``fc``.
        conv_stages = (backbone.conv1, backbone.bn1, backbone.relu,
                       backbone.maxpool, backbone.layer1, backbone.layer2,
                       backbone.layer3, backbone.layer4)
        self.feature = nn.Sequential(*conv_stages)
        self.pool, self.dnn = backbone.avgpool, backbone.fc
        self.classifier = nn.Linear(in_features=self.spatial_dimension,
                                    out_features=self.class_nums)

        # Build a fresh optimizer over the parameters as they are now.
        sgd_options = dict(lr=self.learning_rate,
                           weight_decay=5e-4,
                           momentum=0.9)
        self.optimizer = torch.optim.SGD(self.parameters(), **sgd_options)


class Vit(BaseNetwork):
    """``BaseNetwork`` variant driven by a ``ViT('B_16_imagenet1k')`` model.

    The parent's ``feature``/``pool``/``dnn`` stages are cleared and a single
    ``network`` module feeds ``space`` directly.
    """

    def __init__(self, config) -> None:
        """Install the ViT backbone and rebuild head and optimizer.

        Args:
            config: Same object the parent expects; ``config.input_size`` is
                additionally passed to ``ViT`` as ``image_size``.
        """
        super().__init__(config)
        # The stages created by the parent are not used by this forward
        # pass, so they are reset to None.
        self.feature = None
        self.pool = None
        self.dnn = None
        self.network = ViT('B_16_imagenet1k',
                           pretrained=self.pretrain,
                           image_size=self.config.input_size)
        self.classifier = nn.Linear(self.spatial_dimension, self.class_nums)
        # Build a fresh optimizer over the parameters as they are now.
        self.optimizer = torch.optim.SGD(params=self.parameters(),
                                         lr=self.learning_rate,
                                         weight_decay=5e-4,
                                         momentum=0.9)

    def forward(self, x):
        """Return ``(distribution, logits)`` computed via ``self.network``."""
        x = self.network(x)
        distribution = self.space(x)
        x = self.classifier(distribution)
        return distribution, x


class Vit_with_ASLoss(BaseNetwork_with_ASLoss):
    """``BaseNetwork_with_ASLoss`` driven by a ``ViT('B_16_imagenet1k')``.

    The parent's ``feature``/``pool``/``dnn`` stages are cleared and a single
    ``network`` module feeds ``space`` directly.
    """

    def __init__(self, config) -> None:
        """Install the ViT backbone and rebuild head and optimizer.

        Args:
            config: Same object the parent expects; ``config.input_size`` is
                additionally passed to ``ViT`` as ``image_size``.
        """
        super().__init__(config)
        # The stages created by the parent are not used by this forward
        # pass, so they are reset to None.
        self.feature = None
        self.pool = None
        self.dnn = None
        self.network = ViT('B_16_imagenet1k',
                           pretrained=self.pretrain,
                           image_size=self.config.input_size)
        self.classifier = nn.Linear(self.spatial_dimension, self.class_nums)
        # Build a fresh optimizer over the parameters as they are now.
        self.optimizer = torch.optim.SGD(params=self.parameters(),
                                         lr=self.learning_rate,
                                         weight_decay=5e-4,
                                         momentum=0.9)

    def forward(self, x):
        """Return ``(distribution, logits)`` computed via ``self.network``."""
        x = self.network(x)
        distribution = self.space(x)
        x = self.classifier(distribution)
        return distribution, x


class SwinTransformer(BaseNetwork):
    """``BaseNetwork`` variant driven by the project's ``swin_t`` model.

    The parent's ``feature``/``pool``/``dnn`` stages are cleared and a single
    ``network`` module feeds ``space`` directly.
    """

    def __init__(self, config) -> None:
        """Install the ``swin_t`` backbone and rebuild head and optimizer."""
        super().__init__(config)
        # The stages created by the parent are not used by this forward
        # pass, so they are reset to None.
        self.feature = None
        self.pool = None
        self.dnn = None
        self.network = swin_t(pretrained=self.pretrain)
        self.classifier = nn.Linear(self.spatial_dimension, self.class_nums)
        # Build a fresh optimizer over the parameters as they are now.
        self.optimizer = torch.optim.SGD(params=self.parameters(),
                                         lr=self.learning_rate,
                                         weight_decay=5e-4,
                                         momentum=0.9)

    def forward(self, x):
        """Return ``(distribution, logits)`` computed via ``self.network``."""
        x = self.network(x)
        distribution = self.space(x)
        x = self.classifier(distribution)
        return distribution, x


class SwinTransformer_with_ASLoss(BaseNetwork_with_ASLoss):
    """``BaseNetwork_with_ASLoss`` driven by the project's ``swin_t`` model.

    The parent's ``feature``/``pool``/``dnn`` stages are cleared and a single
    ``network`` module feeds ``space`` directly.
    """

    def __init__(self, config) -> None:
        """Install the ``swin_t`` backbone and rebuild head and optimizer."""
        super().__init__(config)
        # The stages created by the parent are not used by this forward
        # pass, so they are reset to None.
        self.feature = None
        self.pool = None
        self.dnn = None
        self.network = swin_t(pretrained=self.pretrain)
        self.classifier = nn.Linear(self.spatial_dimension, self.class_nums)
        # Build a fresh optimizer over the parameters as they are now.
        self.optimizer = torch.optim.SGD(params=self.parameters(),
                                         lr=self.learning_rate,
                                         weight_decay=5e-4,
                                         momentum=0.9)

    def forward(self, x):
        """Return ``(distribution, logits)`` computed via ``self.network``."""
        x = self.network(x)
        distribution = self.space(x)
        x = self.classifier(distribution)
        return distribution, x


class PVT(BaseNetwork):
    """``BaseNetwork`` variant driven by the project's ``pvt_small`` model.

    The parent's ``feature``/``pool``/``dnn`` stages are cleared and a single
    ``network`` module feeds ``space`` directly.
    """

    def __init__(self, config) -> None:
        """Install the ``pvt_small`` backbone and rebuild head/optimizer."""
        super().__init__(config)
        # The stages created by the parent are not used by this forward
        # pass, so they are reset to None.
        self.feature = None
        self.pool = None
        self.dnn = None
        self.network = pvt_small(pretrained=self.pretrain)
        self.classifier = nn.Linear(self.spatial_dimension, self.class_nums)
        # Build a fresh optimizer over the parameters as they are now.
        self.optimizer = torch.optim.SGD(params=self.parameters(),
                                         lr=self.learning_rate,
                                         weight_decay=5e-4,
                                         momentum=0.9)

    def forward(self, x):
        """Return ``(distribution, logits)`` computed via ``self.network``."""
        x = self.network(x)
        distribution = self.space(x)
        x = self.classifier(distribution)
        return distribution, x


class PVT_with_ASLoss(BaseNetwork_with_ASLoss):
    """``BaseNetwork_with_ASLoss`` driven by the project's ``pvt_small``.

    The parent's ``feature``/``pool``/``dnn`` stages are cleared and a single
    ``network`` module feeds ``space`` directly.
    """

    def __init__(self, config) -> None:
        """Install the ``pvt_small`` backbone and rebuild head/optimizer."""
        super().__init__(config)
        # The stages created by the parent are not used by this forward
        # pass, so they are reset to None.
        self.feature = None
        self.pool = None
        self.dnn = None
        self.network = pvt_small(pretrained=self.pretrain)
        self.classifier = nn.Linear(self.spatial_dimension, self.class_nums)
        # Build a fresh optimizer over the parameters as they are now.
        self.optimizer = torch.optim.SGD(params=self.parameters(),
                                         lr=self.learning_rate,
                                         weight_decay=5e-4,
                                         momentum=0.9)

    def forward(self, x):
        """Return ``(distribution, logits)`` computed via ``self.network``."""
        x = self.network(x)
        distribution = self.space(x)
        x = self.classifier(distribution)
        return distribution, x


class T2T_ViT(BaseNetwork):
    """``BaseNetwork`` variant driven by the project's ``t2t_vit_14`` model.

    The parent's ``feature``/``pool``/``dnn`` stages are cleared and a single
    ``network`` module feeds ``space`` directly.
    """

    def __init__(self, config) -> None:
        """Install the ``t2t_vit_14`` backbone and rebuild head/optimizer."""
        super().__init__(config)
        # The stages created by the parent are not used by this forward
        # pass, so they are reset to None.
        self.feature = None
        self.pool = None
        self.dnn = None
        self.network = t2t_vit_14(pretrained=self.pretrain)
        self.classifier = nn.Linear(self.spatial_dimension, self.class_nums)
        # Build a fresh optimizer over the parameters as they are now.
        self.optimizer = torch.optim.SGD(params=self.parameters(),
                                         lr=self.learning_rate,
                                         weight_decay=5e-4,
                                         momentum=0.9)

    def forward(self, x):
        """Return ``(distribution, logits)`` computed via ``self.network``."""
        x = self.network(x)
        distribution = self.space(x)
        x = self.classifier(distribution)
        return distribution, x


class T2T_ViT_with_ASLoss(BaseNetwork_with_ASLoss):
    """``BaseNetwork_with_ASLoss`` driven by the project's ``t2t_vit_14``.

    The parent's ``feature``/``pool``/``dnn`` stages are cleared and a single
    ``network`` module feeds ``space`` directly.
    """

    def __init__(self, config) -> None:
        """Install the ``t2t_vit_14`` backbone and rebuild head/optimizer."""
        super().__init__(config)
        # The stages created by the parent are not used by this forward
        # pass, so they are reset to None.
        self.feature = None
        self.pool = None
        self.dnn = None
        self.network = t2t_vit_14(pretrained=self.pretrain)
        self.classifier = nn.Linear(self.spatial_dimension, self.class_nums)
        # Build a fresh optimizer over the parameters as they are now.
        self.optimizer = torch.optim.SGD(params=self.parameters(),
                                         lr=self.learning_rate,
                                         weight_decay=5e-4,
                                         momentum=0.9)

    def forward(self, x):
        """Return ``(distribution, logits)`` computed via ``self.network``."""
        x = self.network(x)
        distribution = self.space(x)
        x = self.classifier(distribution)
        return distribution, x
