import random

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable
from torchvision.models import resnet50

# np.random.seed(0)
# random.seed(0)
# torch.random.manual_seed(0)
# torch.manual_seed(0)
# torch.cuda.manual_seed_all(0)


class MyClassifier(nn.Module):
    """Base class that adds binary-classification evaluation helpers.

    Subclasses provide ``forward``. Labels are ``+1`` / ``-1`` by default, or
    ``1`` / ``0`` when ``is_logistic`` is true.
    """

    def _model_device(self):
        """Return the device of the first parameter (CPU if there is none)."""
        first_param = next(self.parameters(), None)
        if first_param is None:
            return torch.device('cpu')
        return first_param.device

    def zero_one_loss(self, h, t, is_logistic=False):
        """Compute precision, recall, error rate and predicted-positive count.

        Args:
            h: Array of predicted labels.
            t: Array of true labels, same length as ``h``.
            is_logistic: If true the negative label is ``0``, otherwise ``-1``.

        Returns:
            Tuple ``(precision, recall, error_rate, n_predicted_positive)``.
            Precision and recall are ``0.0`` when there is no true positive.
        """
        self.eval()
        positive = 1
        negative = 0 if is_logistic else -1

        n_p = (t == positive).sum()
        n_n = (t == negative).sum()
        size = n_p + n_n

        n_pp = (h == positive).sum()
        t_p = ((h == positive) * (t == positive)).sum()
        t_n = ((h == negative) * (t == negative)).sum()
        f_p = n_n - t_n
        f_n = n_p - t_p

        precision = (0.0 if t_p == 0 else t_p / (t_p + f_p))
        recall = (0.0 if t_p == 0 else t_p / (t_p + f_n))

        return precision, recall, 1 - (t_p + t_n) / size, n_pp

    def error(self, DataLoader, is_logistic=False):
        """Evaluate the model on every batch and return ``zero_one_loss``.

        Args:
            DataLoader: Iterable yielding ``(data, target)`` tensor pairs.
            is_logistic: If true, predictions are ``sigmoid(output) > 0.5``
                mapped to ``1`` / ``0``; otherwise ``sign(output)``.

        Returns:
            The tuple returned by ``zero_one_loss`` over all batches.
        """
        targets_all = np.array([])
        prediction_all = np.array([])
        self.eval()
        # Follow the model's own device instead of relying on a global name.
        device = self._model_device()
        with torch.no_grad():
            for data, target in DataLoader:
                data = data.to(device, non_blocking=True)
                t = target.detach().cpu().numpy()
                size = len(t)
                if is_logistic:
                    h = np.reshape(
                        torch.sigmoid(self(data)).detach().cpu().numpy(), size)
                    h = np.where(h > 0.5, 1, 0).astype(np.int32)
                else:
                    h = np.reshape(
                        torch.sign(self(data)).detach().cpu().numpy(), size)

                targets_all = np.hstack((targets_all, t))
                prediction_all = np.hstack((prediction_all, h))

        return self.zero_one_loss(prediction_all, targets_all, is_logistic)

    def evalution_with_density(self, DataLoader, prior):
        """Evaluate using per-batch density-ratio thresholding.

        Args:
            DataLoader: Iterable yielding ``(data, target)`` tensor pairs.
            prior: Class prior used to scale outputs and pick the threshold.

        Returns:
            The tuple returned by ``zero_one_loss`` with ``+1`` / ``-1`` labels.
        """
        targets_all = np.array([])
        prediction_all = np.array([])
        self.eval()
        device = self._model_device()
        with torch.no_grad():
            for data, target in DataLoader:
                data = data.to(device)
                t = target.detach().cpu().numpy()
                size = len(t)
                # Raw model outputs f(x) for this batch.
                h = np.reshape(self(data).detach().cpu().numpy(), size)
                # Turn outputs into labels via the density-ratio threshold.
                h = self.predict_with_density_threshold(h, target, prior)

                targets_all = np.hstack((targets_all, t))
                prediction_all = np.hstack((prediction_all, h))

        return self.zero_one_loss(prediction_all, targets_all)

    def predict_with_density_threshold(self, f_x, target, prior):
        """Label the top ``int(size * prior)`` density ratios as positive.

        Args:
            f_x: 1-D array of model outputs.
            target: Unused; kept for interface compatibility.
            prior: Class prior; ``f_x / prior`` is the density ratio.

        Returns:
            ``int32`` array of ``sign(density_ratio - threshold)``.
        """
        density_ratio = f_x / prior
        # Ascending sort so the largest ratios sit at the end.
        sorted_density_ratio = np.sort(density_ratio)
        size = len(density_ratio)

        n_pi = int(size * prior)
        if n_pi <= 0:
            # No sample should be positive; indexing [size] would overflow.
            threshold = np.inf
        elif n_pi >= size:
            # Every sample should be positive; index -1 would wrap around.
            threshold = -np.inf
        else:
            # Midpoint between the last negative and the first positive.
            threshold = (sorted_density_ratio[size - n_pi] +
                         sorted_density_ratio[size - n_pi - 1]) / 2
        h = np.sign(density_ratio - threshold).astype(np.int32)
        return h


class LinearClassifier(MyClassifier, nn.Module):
    """Single linear layer mapping a flattened input to one score."""

    def __init__(self, dim):
        super().__init__()

        self.input_dim = dim
        self.l = nn.Linear(dim, 1)

    def forward(self, x):
        """Reshape ``x`` to ``(-1, input_dim)`` and apply the linear layer."""
        flat = x.view(-1, self.input_dim)
        return self.l(flat)


class ThreeLayerPerceptron(MyClassifier, nn.Module):
    """Perceptron with one 100-unit ReLU hidden layer and a scalar output."""

    def __init__(self, dim):
        super().__init__()

        self.input_dim = dim
        self.l1 = nn.Linear(dim, 100)
        self.l2 = nn.Linear(100, 1)

        self.af = F.relu

    def forward(self, x):
        """Reshape ``x`` to ``(-1, input_dim)`` and return the output score."""
        flat = x.view(-1, self.input_dim)
        hidden = self.af(self.l1(flat))
        return self.l2(hidden)


# class FourLayerPerceptron(MyClassifier, nn.Module):
#     def __init__(self, dim):
#         super(FourLayerPerceptron, self).__init__()
#
#         # d-512-128-1
#         # d-500-500-1
#         self.input_dim = dim
#         self.l1 = nn.Linear(dim, 300, bias=False)
#         # self.b1 = nn.BatchNorm1d(300)
#         self.l2 = nn.Linear(300, 300, bias=False)
#         # self.b2 = nn.BatchNorm1d(300)
#         self.l3 = nn.Linear(300, 1)
#
#         self.af = F.softsign
#
#     def forward(self, x):
#         x = x.view(-1, self.input_dim)
#         h = self.l1(x)
#         h = self.b1(h)
#         h = self.af(h)
#         h = self.l2(h)
#         h = self.b2(h)
#         h = self.af(h)
#         h = self.l3(h)
#         return h


class FourLayerPerceptron(MyClassifier, nn.Module):
    """Two bias-free 300-unit Softsign layers followed by a linear classifier.

    ``act_func`` is accepted but not read; the activation is always Softsign.
    ``b_list`` holds batch-norm modules that are not part of ``layers``.
    """

    def __init__(self, dim, act_func='softsign'):
        super().__init__()

        width = 300
        self.input_dim = dim

        self.af = nn.Softsign()

        # Hidden linear layers: dim -> 300 -> 300, no bias.
        self.l_list = [
            nn.Linear(n_in, width, bias=False) for n_in in (dim, width)
        ]

        # Kept as a plain attribute; not used by forward.
        self.b_list = [nn.BatchNorm1d(width) for _ in range(2)]

        self.layers = nn.ModuleList(
            [nn.Sequential(linear, self.af) for linear in self.l_list])

        self.classifier = nn.Linear(width, 1)

    def forward(self, x):
        """Reshape ``x`` to ``(-1, input_dim)`` and return the output score."""
        h = x.view(-1, self.input_dim)

        for block in self.layers:
            h = block(h)

        return self.classifier(h)


class MixFourLayerPerceptron(MyClassifier, nn.Module):
    """Two-hidden-layer Softsign perceptron with optional input mixing.

    ``act_func`` is accepted but not read; the activation is always Softsign.
    ``b_list`` holds batch-norm modules that are not part of ``layers``.
    """

    def __init__(self, dim, act_func='relu'):
        super().__init__()

        width = 300
        self.input_dim = dim

        self.af = nn.Softsign()

        # Hidden linear layers: dim -> 300 -> 300, no bias.
        self.l_list = [
            nn.Linear(n_in, width, bias=False) for n_in in (dim, width)
        ]

        # Kept as a plain attribute; not used by forward.
        self.b_list = [nn.BatchNorm1d(width) for _ in range(2)]

        self.layers = nn.ModuleList(
            [nn.Sequential(linear, self.af) for linear in self.l_list])

        self.classifier = nn.Linear(width, 1)

    def forward(self, x, x2=None, l=None, mix_layer=1000, flag_feature=False):
        """Run the network, optionally interpolating with a second input.

        Args:
            x: Primary input, reshaped to ``(-1, input_dim)``.
            x2: Optional second input, reshaped the same way.
            l: Mixing weight; the mix is ``l * h + (1 - l) * h2``.
            mix_layer: Index of the layer after which the mix happens;
                ``-1`` mixes the raw inputs.
            flag_feature: If true, also return the classifier's input.

        Returns:
            The output, or ``(output, features)`` when ``flag_feature``.
        """
        h = x.view(-1, self.input_dim)
        h2 = x2.view(-1, self.input_dim) if x2 is not None else x2
        paired = h2 is not None

        # Mixing at -1 blends the inputs before any layer runs.
        if mix_layer == -1 and paired:
            h = l * h + (1. - l) * h2

        for depth, block in enumerate(self.layers):
            h = block(h)
            # The second stream is only advanced up to the mix layer.
            if depth <= mix_layer and paired:
                h2 = block(h2)
                if depth == mix_layer:
                    h = l * h + (1. - l) * h2

        output = self.classifier(h)

        if flag_feature:
            return output, h
        return output


class MultiLayerPerceptron(MyClassifier, nn.Module):
    """Four bias-free 300-unit layers with batch norm, then a linear output.

    Args:
        dim: Flattened input dimension.
        act_func: ``'relu'`` or ``'softsign'``.

    Raises:
        ValueError: If ``act_func`` is not one of the supported names.
    """

    def __init__(self, dim, act_func='relu'):
        super(MultiLayerPerceptron, self).__init__()

        self.input_dim = dim
        self.num_classifier = 1

        self.l1 = nn.Linear(dim, 300, bias=False)
        self.b1 = nn.BatchNorm1d(300)
        self.l2 = nn.Linear(300, 300, bias=False)
        self.b2 = nn.BatchNorm1d(300)
        self.l3 = nn.Linear(300, 300, bias=False)
        self.b3 = nn.BatchNorm1d(300)
        self.l4 = nn.Linear(300, 300, bias=False)
        self.b4 = nn.BatchNorm1d(300)
        self.l5 = nn.Linear(300, self.num_classifier)
        if act_func == 'relu':
            self.af = F.relu
        elif act_func == 'softsign':
            self.af = F.softsign
        else:
            # Fail here rather than with a missing attribute inside forward.
            raise ValueError(
                "act_func must be 'relu' or 'softsign', got {!r}".format(
                    act_func))

    def forward(self, x):
        """Reshape ``x`` to ``(-1, input_dim)`` and return the output score."""
        h = x.view(-1, self.input_dim)
        hidden_stages = (
            (self.l1, self.b1),
            (self.l2, self.b2),
            (self.l3, self.b3),
            (self.l4, self.b4),
        )
        # Each hidden stage is linear -> batch norm -> activation.
        for linear, norm in hidden_stages:
            h = self.af(norm(linear(h)))
        return self.l5(h)


class MultiLayerPerceptron_labeled2(MyClassifier, nn.Module):
    """Four bias-free 300-unit layers with layer norm, then a linear output.

    Args:
        dim: Flattened input dimension.
        act_func: ``'relu'`` or ``'softsign'``.

    Raises:
        ValueError: If ``act_func`` is not one of the supported names.
    """

    def __init__(self, dim, act_func='relu'):
        super(MultiLayerPerceptron_labeled2, self).__init__()

        self.input_dim = dim
        self.num_classifier = 1

        self.l1 = nn.Linear(dim, 300, bias=False)
        self.b1 = nn.LayerNorm(300)
        self.l2 = nn.Linear(300, 300, bias=False)
        self.b2 = nn.LayerNorm(300)
        self.l3 = nn.Linear(300, 300, bias=False)
        self.b3 = nn.LayerNorm(300)
        self.l4 = nn.Linear(300, 300, bias=False)
        self.b4 = nn.LayerNorm(300)
        self.l5 = nn.Linear(300, self.num_classifier)
        if act_func == 'relu':
            self.af = F.relu
        elif act_func == 'softsign':
            self.af = F.softsign
        else:
            # Fail here rather than with a missing attribute inside forward.
            raise ValueError(
                "act_func must be 'relu' or 'softsign', got {!r}".format(
                    act_func))

    def forward(self, x):
        """Reshape ``x`` to ``(-1, input_dim)`` and return the output score."""
        h = x.view(-1, self.input_dim)
        hidden_stages = (
            (self.l1, self.b1),
            (self.l2, self.b2),
            (self.l3, self.b3),
            (self.l4, self.b4),
        )
        # Each hidden stage is linear -> layer norm -> activation.
        for linear, norm in hidden_stages:
            h = self.af(norm(linear(h)))
        return self.l5(h)


class MixMultiLayerPerceptron(MyClassifier, nn.Module):
    """Four linear/batch-norm/activation blocks with optional input mixing.

    Args:
        dim: Flattened input dimension.
        act_func: ``'relu'`` or ``'softsign'``.

    Raises:
        ValueError: If ``act_func`` is not one of the supported names.
    """

    def __init__(self, dim, act_func='relu'):
        super(MixMultiLayerPerceptron, self).__init__()

        self.input_dim = dim
        self.num_classifier = 1

        if act_func == 'relu':
            self.af = nn.ReLU()
        elif act_func == 'softsign':
            self.af = nn.Softsign()
        else:
            # Without this, building `layers` fails on a missing attribute.
            raise ValueError(
                "act_func must be 'relu' or 'softsign', got {!r}".format(
                    act_func))

        self.l_list = [
            nn.Linear(dim, 300, bias=False),
            nn.Linear(300, 300, bias=False),
            nn.Linear(300, 300, bias=False),
            nn.Linear(300, 300, bias=False),
        ]

        self.b_list = [
            nn.BatchNorm1d(300),
            nn.BatchNorm1d(300),
            nn.BatchNorm1d(300),
            nn.BatchNorm1d(300),
        ]

        self.layers = nn.ModuleList([
            nn.Sequential(self.l_list[i], self.b_list[i], self.af)
            for i in range(4)
        ])

        self.classifier = nn.Linear(300, 1)

    def forward(self, x, x2=None, l=None, mix_layer=1000, flag_feature=False):
        """Run the network, optionally interpolating with a second input.

        Args:
            x: Primary input, reshaped to ``(-1, input_dim)``.
            x2: Optional second input, reshaped the same way.
            l: Mixing weight; the mix is ``l * h + (1 - l) * h2``.
            mix_layer: Index of the layer after which the mix happens;
                ``-1`` mixes the raw inputs.
            flag_feature: If true, also return the classifier's input.

        Returns:
            The output, or ``(output, features)`` when ``flag_feature``.
        """
        x = x.view(-1, self.input_dim)
        if x2 is not None:
            x2 = x2.view(-1, self.input_dim)

        h, h2 = x, x2
        # Mixing at -1 blends the inputs before any layer runs.
        if mix_layer == -1 and h2 is not None:
            h = l * h + (1. - l) * h2

        for i, layer_module in enumerate(self.layers):
            h = layer_module(h)
            # The second stream is only advanced up to the mix layer.
            if i <= mix_layer and h2 is not None:
                h2 = layer_module(h2)
                if i == mix_layer:
                    h = l * h + (1. - l) * h2

        output = self.classifier(h)

        if flag_feature:
            return output, h
        return output


class LeNet(MyClassifier, nn.Module):
    """LeNet-style network for single-channel images with a scalar output.

    ``bn_conv1`` and ``bn_conv2`` are registered but not used by any stage.
    """

    def __init__(self, dim):
        super().__init__()

        self.input_dim = dim

        self.conv1 = nn.Conv2d(
            in_channels=1, out_channels=6, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5)
        self.bn_conv1 = nn.BatchNorm2d(6)
        self.bn_conv2 = nn.BatchNorm2d(16)
        self.mp = nn.MaxPool2d(kernel_size=2)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(120, 84)
        self.bn_fc1 = nn.BatchNorm1d(84)

        # Convolutional stages: conv -> max-pool -> ReLU (no pool on the last).
        self.layer1 = nn.Sequential(self.conv1, self.mp, self.relu)
        self.layer2 = nn.Sequential(self.conv2, self.mp, self.relu)
        self.layer3 = nn.Sequential(self.conv3, self.relu)

        self.layers = nn.ModuleList([self.layer1, self.layer2, self.layer3])

        # Fully connected head: linear -> batch norm -> ReLU, then the output.
        self.layer4 = nn.Sequential(self.fc1, self.bn_fc1, self.relu)
        self.classifier = nn.Linear(84, 1)

    def forward(self, x):
        """Apply the conv stages, flatten per sample, and return the score."""
        h = x
        for stage in self.layers:
            h = stage(h)

        flat = h.view(h.shape[0], -1)
        return self.classifier(self.layer4(flat))


class MixLeNet(MyClassifier, nn.Module):
    """LeNet-style network with optional mixing of two inputs.

    ``bn_conv1`` and ``bn_conv2`` are registered but not used by any stage.
    """

    def __init__(self, dim):
        super().__init__()

        self.input_dim = dim

        self.conv1 = nn.Conv2d(
            in_channels=1, out_channels=6, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5)
        self.bn_conv1 = nn.BatchNorm2d(6)
        self.bn_conv2 = nn.BatchNorm2d(16)
        self.mp = nn.MaxPool2d(kernel_size=2)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(120, 84)
        self.bn_fc1 = nn.BatchNorm1d(84)

        # Convolutional stages: conv -> max-pool -> ReLU (no pool on the last).
        self.layer1 = nn.Sequential(self.conv1, self.mp, self.relu)
        self.layer2 = nn.Sequential(self.conv2, self.mp, self.relu)
        self.layer3 = nn.Sequential(self.conv3, self.relu)

        self.layers = nn.ModuleList([self.layer1, self.layer2, self.layer3])

        # Fully connected head: linear -> batch norm -> ReLU, then the output.
        self.layer4 = nn.Sequential(self.fc1, self.bn_fc1, self.relu)
        self.classifier = nn.Linear(84, 1)

    def forward(self, x, x2=None, l=None, mix_layer=1000, flag_feature=False):
        """Run the network, optionally interpolating with a second input.

        Args:
            x: Primary input batch.
            x2: Optional second input batch.
            l: Mixing weight; the mix is ``l * h + (1 - l) * h2``.
            mix_layer: Index of the conv stage after which the mix happens;
                ``-1`` mixes the raw inputs.
            flag_feature: If true, also return the classifier's input.

        Returns:
            The output, or ``(output, features)`` when ``flag_feature``.
        """
        h, h2 = x, x2
        paired = h2 is not None

        # Mixing at -1 blends the inputs before any stage runs.
        if mix_layer == -1 and paired:
            h = l * h + (1. - l) * h2

        for depth, stage in enumerate(self.layers):
            h = stage(h)
            # The second stream is only advanced up to the mix layer.
            if depth <= mix_layer and paired:
                h2 = stage(h2)
                if depth == mix_layer:
                    h = l * h + (1. - l) * h2

        features = self.layer4(h.view(h.shape[0], -1))
        output = self.classifier(features)

        if flag_feature:
            return output, features
        return output


class CNNSTL(MyClassifier, nn.Module):
    """Four-stage CNN for 3-channel images with a two-layer dense head."""

    def __init__(self, dim):
        super().__init__()

        self.input_dim = dim

        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=6, kernel_size=3)
        self.mp = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.conv4 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5)
        self.fc1 = nn.Linear(32 * 8 * 8, 120)
        self.fc2 = nn.Linear(120, 84)

        # Conv stages: conv -> ReLU, with max-pooling on all but the second.
        self.layer1 = nn.Sequential(self.conv1, self.relu, self.mp)
        self.layer2 = nn.Sequential(self.conv2, self.relu)
        self.layer3 = nn.Sequential(self.conv3, self.relu, self.mp)
        self.layer4 = nn.Sequential(self.conv4, self.relu, self.mp)

        self.layers = nn.ModuleList(
            [self.layer1, self.layer2, self.layer3, self.layer4])

        # Dense head applied to the flattened conv features.
        self.layer5 = nn.Sequential(self.fc1, self.relu, self.fc2, self.relu)

        self.classifier = nn.Linear(84, 1)

    def forward(self, x):
        """Apply the conv stages, flatten per sample, and return the score."""
        h = x
        for stage in self.layers:
            h = stage(h)

        flat = h.view(h.shape[0], -1)
        return self.classifier(self.layer5(flat))


class MixCNNSTL(MyClassifier, nn.Module):
    """Four-stage CNN for 3-channel images with optional input mixing."""

    def __init__(self, dim):
        super().__init__()

        self.input_dim = dim

        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=6, kernel_size=3)
        self.mp = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.conv4 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5)
        self.fc1 = nn.Linear(32 * 8 * 8, 120)
        self.fc2 = nn.Linear(120, 84)

        # Conv stages: conv -> ReLU, with max-pooling on all but the second.
        self.layer1 = nn.Sequential(self.conv1, self.relu, self.mp)
        self.layer2 = nn.Sequential(self.conv2, self.relu)
        self.layer3 = nn.Sequential(self.conv3, self.relu, self.mp)
        self.layer4 = nn.Sequential(self.conv4, self.relu, self.mp)

        self.layers = nn.ModuleList(
            [self.layer1, self.layer2, self.layer3, self.layer4])

        # Dense head applied to the flattened conv features.
        self.layer5 = nn.Sequential(self.fc1, self.relu, self.fc2, self.relu)

        self.classifier = nn.Linear(84, 1)

    def forward(self, x, x2=None, l=None, mix_layer=1000, flag_feature=False):
        """Run the network, optionally interpolating with a second input.

        Args:
            x: Primary input batch.
            x2: Optional second input batch.
            l: Mixing weight; the mix is ``l * h + (1 - l) * h2``.
            mix_layer: Index of the conv stage after which the mix happens;
                ``-1`` mixes the raw inputs.
            flag_feature: If true, also return the classifier's input.

        Returns:
            The output, or ``(output, features)`` when ``flag_feature``.
        """
        h, h2 = x, x2
        paired = h2 is not None

        # Mixing at -1 blends the inputs before any stage runs.
        if mix_layer == -1 and paired:
            h = l * h + (1. - l) * h2

        for depth, stage in enumerate(self.layers):
            h = stage(h)
            # The second stream is only advanced up to the mix layer.
            if depth <= mix_layer and paired:
                h2 = stage(h2)
                if depth == mix_layer:
                    h = l * h + (1. - l) * h2

        features = self.layer5(h.view(h.shape[0], -1))
        output = self.classifier(features)

        if flag_feature:
            return output, features
        return output


class CNN7(MyClassifier, nn.Module):
    """Five conv/batch-norm/ReLU stages followed by two linear layers."""

    def __init__(self, dim):
        super().__init__()

        self.input_dim = dim

        self.af = nn.ReLU()

        self.conv1 = nn.Conv2d(3, 96, 5, padding=2)
        self.conv2 = nn.Conv2d(96, 192, 5, padding=2)
        self.conv3 = nn.Conv2d(192, 192, 3, padding=1)
        self.conv4 = nn.Conv2d(192, 192, 1)
        self.conv5 = nn.Conv2d(192, 10, 1)
        self.mp1 = nn.MaxPool2d(2)
        self.mp2 = nn.MaxPool2d(2)
        self.b1 = nn.BatchNorm2d(96)
        self.b2 = nn.BatchNorm2d(192)
        self.b3 = nn.BatchNorm2d(192)
        self.b4 = nn.BatchNorm2d(192)
        self.b5 = nn.BatchNorm2d(10)
        self.fc1 = nn.Linear(8 * 8 * 10, 100)
        self.fc2 = nn.Linear(100, 1)

        # The first two stages max-pool between the convolution and the norm.
        self.layer1 = nn.Sequential(self.conv1, self.mp1, self.b1, self.af)
        self.layer2 = nn.Sequential(self.conv2, self.mp2, self.b2, self.af)
        self.layer3 = nn.Sequential(self.conv3, self.b3, self.af)
        self.layer4 = nn.Sequential(self.conv4, self.b4, self.af)
        self.layer5 = nn.Sequential(self.conv5, self.b5, self.af)

        self.layers = nn.ModuleList(
            [self.layer1, self.layer2, self.layer3, self.layer4, self.layer5])

        self.layer6 = nn.Sequential(self.fc1, self.af)

    def forward(self, x):
        """Apply the conv stages, flatten per sample, and return the score."""
        h = x
        for stage in self.layers:
            h = stage(h)

        flat = h.view(h.shape[0], -1)
        return self.fc2(self.layer6(flat))


class MixCNN7(MyClassifier, nn.Module):
    """Five conv/batch-norm/ReLU stages with optional input mixing."""

    def __init__(self, dim):
        # Must name this class: super(MixCNNCIFAR, self) raises TypeError
        # because a MixCNN7 instance is not a MixCNNCIFAR.
        super(MixCNN7, self).__init__()

        self.input_dim = dim

        self.af = nn.ReLU()

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=96, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(in_channels=96, out_channels=192, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(in_channels=192, out_channels=192, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(in_channels=192, out_channels=192, kernel_size=1)
        self.conv5 = nn.Conv2d(in_channels=192, out_channels=10, kernel_size=1)
        self.mp1 = nn.MaxPool2d(kernel_size=2)
        self.mp2 = nn.MaxPool2d(kernel_size=2)
        self.b1 = nn.BatchNorm2d(96)
        self.b2 = nn.BatchNorm2d(192)
        self.b3 = nn.BatchNorm2d(192)
        self.b4 = nn.BatchNorm2d(192)
        self.b5 = nn.BatchNorm2d(10)
        self.fc1 = nn.Linear(8 * 8 * 10, 100)
        self.fc2 = nn.Linear(100, 1)

        # The first two stages max-pool between the convolution and the norm.
        self.layer1 = nn.Sequential(self.conv1, self.mp1, self.b1, self.af)
        self.layer2 = nn.Sequential(self.conv2, self.mp2, self.b2, self.af)
        self.layer3 = nn.Sequential(self.conv3, self.b3, self.af)
        self.layer4 = nn.Sequential(self.conv4, self.b4, self.af)
        self.layer5 = nn.Sequential(self.conv5, self.b5, self.af)

        self.layers = nn.ModuleList(
            [self.layer1, self.layer2, self.layer3, self.layer4, self.layer5])

        self.layer6 = nn.Sequential(self.fc1, self.af)

    def forward(self, x, x2=None, l=None, mix_layer=1000, flag_feature=False):
        """Run the network, optionally interpolating with a second input.

        Args:
            x: Primary input batch.
            x2: Optional second input batch.
            l: Mixing weight; the mix is ``l * h + (1 - l) * h2``.
            mix_layer: Index of the conv stage after which the mix happens;
                ``-1`` mixes the raw inputs.
            flag_feature: If true, also return the input to ``fc2``.

        Returns:
            The output, or ``(output, features)`` when ``flag_feature``.
        """
        h, h2 = x, x2
        # Mixing at -1 blends the inputs before any stage runs.
        if mix_layer == -1 and h2 is not None:
            h = l * h + (1. - l) * h2

        for i, layer_module in enumerate(self.layers):
            h = layer_module(h)
            # The second stream is only advanced up to the mix layer.
            if i <= mix_layer and h2 is not None:
                h2 = layer_module(h2)
                if i == mix_layer:
                    h = l * h + (1. - l) * h2

        # Dense head on the flattened conv features.
        h_ = h.view(h.size(0), -1)
        h_ = self.layer6(h_)
        h = self.fc2(h_)

        if flag_feature:
            return h, h_
        return h


class CNNCIFAR(MyClassifier, nn.Module):
    """Four ReLU convolutions followed by three fully connected layers."""

    def __init__(self, dim):
        super().__init__()

        self.af = F.relu
        self.input_dim = dim

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=96, kernel_size=3)
        self.conv2 = nn.Conv2d(
            in_channels=96, out_channels=96, kernel_size=3, stride=2)
        self.conv3 = nn.Conv2d(in_channels=96, out_channels=192, kernel_size=1)
        self.conv4 = nn.Conv2d(in_channels=192, out_channels=10, kernel_size=1)
        self.fc1 = nn.Linear(1960, 1000)
        self.fc2 = nn.Linear(1000, 1000)
        self.fc3 = nn.Linear(1000, 1)

    def forward(self, x):
        """Apply the convolutions, flatten per sample, and return the score."""
        h = x
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            h = self.af(conv(h))

        h = h.view(h.shape[0], -1)
        for dense in (self.fc1, self.fc2):
            h = self.af(dense(h))
        return self.fc3(h)


class MixCNNCIFAR(MyClassifier, nn.Module):
    """Four ReLU convolutions and a dense head, with optional input mixing."""

    def __init__(self, dim):
        super().__init__()

        self.num_classifier = 1

        self.af = nn.ReLU()
        self.input_dim = dim

        self.conv_list = [
            nn.Conv2d(in_channels=3, out_channels=96, kernel_size=3),
            nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=2),
            nn.Conv2d(in_channels=96, out_channels=192, kernel_size=1),
            nn.Conv2d(in_channels=192, out_channels=10, kernel_size=1),
        ]
        self.fc1 = nn.Linear(1960, 1000)
        self.fc2 = nn.Linear(1000, 1000)
        self.fc3 = nn.Linear(1000, 1)

        # Each stage is a convolution followed by the shared ReLU.
        self.layers = nn.ModuleList(
            [nn.Sequential(conv, self.af) for conv in self.conv_list])

        self.classifier1 = nn.Sequential(
            self.fc1, self.af, self.fc2, self.af)

    def forward(self, x, x2=None, l=None, mix_layer=1000, flag_feature=False):
        """Run the network, optionally interpolating with a second input.

        Args:
            x: Primary input batch.
            x2: Optional second input batch.
            l: Mixing weight; the mix is ``l * h + (1 - l) * h2``.
            mix_layer: Index of the conv stage after which the mix happens;
                ``-1`` mixes the raw inputs.
            flag_feature: If true, also return the input to ``fc3``.

        Returns:
            The output, or ``(output, features)`` when ``flag_feature``.
        """
        h, h2 = x, x2
        paired = h2 is not None

        # Mixing at -1 blends the inputs before any stage runs.
        if mix_layer == -1 and paired:
            h = l * h + (1. - l) * h2

        for depth, stage in enumerate(self.layers):
            h = stage(h)
            # The second stream is only advanced up to the mix layer.
            if depth <= mix_layer and paired:
                h2 = stage(h2)
                if depth == mix_layer:
                    h = l * h + (1. - l) * h2

        features = self.classifier1(h.view(h.shape[0], -1))
        output = self.fc3(features)

        if flag_feature:
            return output, features
        return output


class CNN(MyClassifier, nn.Module):
    """Nine conv/batch-norm/ReLU stages followed by three linear layers."""

    def __init__(self, dim):
        super().__init__()

        self.af = F.relu
        self.input_dim = dim
        self.num_classifier = 1

        # (in_channels, out_channels, kernel_size, padding, stride) per stage.
        conv_specs = (
            (3, 96, 3, 1, 1),
            (96, 96, 3, 1, 1),
            (96, 96, 3, 1, 2),
            (96, 192, 3, 1, 1),
            (192, 192, 3, 1, 1),
            (192, 192, 3, 1, 2),
            (192, 192, 3, 1, 1),
            (192, 192, 1, 0, 1),
            (192, 10, 1, 0, 1),
        )
        # Register conv1..conv9 first, then b1..b9, under those names.
        for idx, (c_in, c_out, kernel, pad, stride) in enumerate(
                conv_specs, start=1):
            setattr(self, 'conv{}'.format(idx),
                    nn.Conv2d(c_in, c_out, kernel, padding=pad, stride=stride))
        for idx, spec in enumerate(conv_specs, start=1):
            # Each batch norm matches its convolution's output channels.
            setattr(self, 'b{}'.format(idx), nn.BatchNorm2d(spec[1]))

        self.fc1 = nn.Linear(640, 1000)
        self.fc2 = nn.Linear(1000, 1000)
        self.fc3 = nn.Linear(1000, self.num_classifier)

    def forward(self, x):
        """Apply the conv stages, flatten per sample, and return the score."""
        h = x
        for idx in range(1, 10):
            conv = getattr(self, 'conv{}'.format(idx))
            norm = getattr(self, 'b{}'.format(idx))
            h = self.af(norm(conv(h)))

        h = h.view(h.shape[0], -1)
        for dense in (self.fc1, self.fc2):
            h = self.af(dense(h))
        return self.fc3(h)


class MixCNN(MyClassifier, nn.Module):
    """Nine conv/batch-norm/ReLU stages with optional input mixing."""

    def __init__(self, dim):
        super().__init__()

        self.af = nn.ReLU()
        self.input_dim = dim

        # (in_channels, out_channels, kernel_size, padding, stride) per stage.
        conv_specs = (
            (3, 96, 3, 1, 1),
            (96, 96, 3, 1, 1),
            (96, 96, 3, 1, 2),
            (96, 192, 3, 1, 1),
            (192, 192, 3, 1, 1),
            (192, 192, 3, 1, 2),
            (192, 192, 3, 1, 1),
            (192, 192, 1, 0, 1),
            (192, 10, 1, 0, 1),
        )
        self.conv_list = [
            nn.Conv2d(c_in, c_out, kernel, padding=pad, stride=stride)
            for c_in, c_out, kernel, pad, stride in conv_specs
        ]
        # Each batch norm matches its convolution's output channels.
        self.b_list = [nn.BatchNorm2d(spec[1]) for spec in conv_specs]
        self.fc1 = nn.Linear(640, 1000)
        self.fc2 = nn.Linear(1000, 1000)
        self.fc3 = nn.Linear(1000, 1)

        self.layers = nn.ModuleList([
            nn.Sequential(conv, norm, self.af)
            for conv, norm in zip(self.conv_list, self.b_list)
        ])

        self.classifier1 = nn.Sequential(
            self.fc1, self.af, self.fc2, self.af)

    def forward(self, x, x2=None, l=None, mix_layer=1000, flag_feature=False):
        """Run the network, optionally interpolating with a second input.

        Args:
            x: Primary input batch.
            x2: Optional second input batch.
            l: Mixing weight; the mix is ``l * h + (1 - l) * h2``.
            mix_layer: Index of the conv stage after which the mix happens;
                ``-1`` mixes the raw inputs.
            flag_feature: If true, also return the input to ``fc3``.

        Returns:
            The output, or ``(output, features)`` when ``flag_feature``.
        """
        h, h2 = x, x2
        paired = h2 is not None

        # Mixing at -1 blends the inputs before any stage runs.
        if mix_layer == -1 and paired:
            h = l * h + (1. - l) * h2

        for depth, stage in enumerate(self.layers):
            h = stage(h)
            # The second stream is only advanced up to the mix layer.
            if depth <= mix_layer and paired:
                h2 = stage(h2)
                if depth == mix_layer:
                    h = l * h + (1. - l) * h2

        features = self.classifier1(h.view(h.shape[0], -1))
        output = self.fc3(features)

        if flag_feature:
            return output, features
        return output


class TextClassificationModel(MyClassifier, nn.Module):
    """Sparse ``EmbeddingBag`` followed by one linear layer."""

    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        """Draw both weight matrices from U(-0.5, 0.5) and zero the bias."""
        bound = 0.5
        # Embedding first, then the linear layer, to keep the sampling order.
        for layer in (self.embedding, self.fc):
            layer.weight.data.uniform_(-bound, bound)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        """Embed ``text`` using bag ``offsets`` and return the linear output."""
        return self.fc(self.embedding(text, offsets))


class CNN_Text(MyClassifier, nn.Module):
    """Text CNN: parallel 2-D convolutions over embeddings, max-pooled."""

    def __init__(self,
                 pretrained_embedding=None,
                 freeze_embedding=False,
                 vocab_size=None,
                 embedding_dim=300,
                 n_filters=100,
                 filter_sizes=(3, 4, 5),
                 n_classes=1,
                 dropout=0.5,
                 pad_idx=0):
        """
        The constructor for the CNN_Text class.

        Args:
            pretrained_embedding (torch.Tensor): Pretrained embeddings with
                shape (vocab_size, embedding_dim)
            freeze_embedding (bool): Set to False to fine-tune pretrained
                vectors. Default: False
            vocab_size (int): Needs to be specified when pretrained word
                embeddings are not used.
            embedding_dim (int): Dimension of word vectors. Needs to be
                specified when pretrained word embeddings are not used.
                Default: 300
            n_filters (int): Number of filters for every filter size.
                Default: 100
            filter_sizes (Sequence[int]): Filter heights. Default: (3, 4, 5)
            n_classes (int): Number of classes. Default: 1
            dropout (float): Dropout rate. Default: 0.5
            pad_idx (int): Padding index, used only when embeddings are
                trained from scratch. Default: 0

        Raises:
            ValueError: If neither ``pretrained_embedding`` nor ``vocab_size``
                is given.
        """
        super(CNN_Text, self).__init__()

        # Embedding layer
        if pretrained_embedding is not None:
            self.vocab_size, self.embedding_dim = pretrained_embedding.shape
            self.embedding = nn.Embedding.from_pretrained(
                pretrained_embedding, freeze=freeze_embedding)
        else:
            if vocab_size is None:
                raise ValueError(
                    'vocab_size is required when pretrained_embedding '
                    'is not given')
            # Store the size on this path too, as the pretrained path does.
            self.vocab_size = vocab_size
            self.embedding_dim = embedding_dim
            self.embedding = nn.Embedding(num_embeddings=vocab_size,
                                          embedding_dim=self.embedding_dim,
                                          padding_idx=pad_idx,
                                          max_norm=5.0)

        # One convolution per filter size, spanning the full embedding width.
        self.convs = nn.ModuleList([
            nn.Conv2d(in_channels=1,
                      out_channels=n_filters,
                      kernel_size=(fs, self.embedding_dim))
            for fs in filter_sizes
        ])
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(len(filter_sizes) * n_filters, n_classes)

    def forward(self, text):
        """Return the output for a batch of token-id sequences.

        Args:
            text: Tensor of token ids, [batch size, sent len].

        Returns:
            Tensor of shape [batch size, n_classes].
        """
        # embedded = [batch size, sent len, emb dim]
        embedded = self.embedding(text)

        # embedded = [batch size, 1, sent len, emb dim]
        embedded = embedded.unsqueeze(1)

        # conved_n = [batch size, n_filters, sent len - filter_sizes[n] + 1]
        conved_n = [F.relu(conv(embedded).squeeze(3)) for conv in self.convs]

        # pooled_n = [batch size, n_filters]
        pooled_n = [
            F.max_pool1d(conved_i, conved_i.shape[2]).squeeze(2)
            for conved_i in conved_n
        ]

        # cat = [batch size, n_filters * len(filter_sizes)]
        cat = self.dropout(torch.cat(pooled_n, dim=1))

        return self.fc(cat)


class CNN1d_Text(MyClassifier, nn.Module):
    """Text classifier built from parallel 1-D convolutions over embeddings.

    One `nn.Conv1d` branch per entry of `filter_sizes` is applied to the
    embedded sequence; each branch is max-pooled over time, the pooled
    vectors are concatenated, passed through dropout and a linear layer.
    """

    def __init__(self,
                 pretrained_embedding=None,
                 freeze_embedding=False,
                 vocab_size=None,
                 embedding_dim=300,
                 n_filters=100,
                 filter_sizes=(3, 4, 5),
                 n_classes=1,
                 dropout=0.5,
                 pad_idx=0):
        """
        The constructor for CNN1d_Text class.

        Args:
            pretrained_embedding (torch.Tensor): Pretrained embeddings with
                shape (vocab_size, embedding_dim). When given, `vocab_size`,
                `embedding_dim` and `pad_idx` are ignored. Default: None
            freeze_embedding (bool): Set to False to fine-tune pretrained
                vectors. Only used with `pretrained_embedding`.
                Default: False
            vocab_size (int): Number of rows of the embedding table. Must be
                specified when pretrained word embeddings are not used.
            embedding_dim (int): Dimension of word vectors. Only used when
                pretrained word embeddings are not used. Default: 300
            n_filters (int): Number of output channels of every convolution
                branch. Default: 100
            filter_sizes (Sequence[int]): Kernel size of each convolution
                branch. Default: (3, 4, 5)
            n_classes (int): Number of outputs of the final linear layer.
                Default: 1
            dropout (float): Dropout rate. Default: 0.5
            pad_idx (int): Padding index of the freshly created embedding
                (not applied to pretrained embeddings). Default: 0

        Raises:
            ValueError: If neither `pretrained_embedding` nor `vocab_size`
                is provided.
        """
        super(CNN1d_Text, self).__init__()

        # Embedding layer
        if pretrained_embedding is not None:
            self.vocab_size, self.embedding_dim = pretrained_embedding.shape
            self.embedding = nn.Embedding.from_pretrained(
                pretrained_embedding, freeze=freeze_embedding)
        else:
            # Fail early with a clear message instead of an opaque error
            # raised from inside nn.Embedding.
            if vocab_size is None:
                raise ValueError(
                    "vocab_size must be specified when pretrained_embedding "
                    "is not provided")
            self.vocab_size = vocab_size
            self.embedding_dim = embedding_dim
            self.embedding = nn.Embedding(num_embeddings=vocab_size,
                                          embedding_dim=self.embedding_dim,
                                          padding_idx=pad_idx,
                                          max_norm=5.0)

        # One convolution per filter size; the embedding dimension acts as
        # the channel axis.
        self.convs = nn.ModuleList([
            nn.Conv1d(in_channels=self.embedding_dim,
                      out_channels=n_filters,
                      kernel_size=fs) for fs in filter_sizes
        ])

        self.fc = nn.Linear(len(filter_sizes) * n_filters, n_classes)

        self.dropout = nn.Dropout(dropout)

    def forward(self, text):
        """
        Compute the classifier output for a batch of token sequences.

        Args:
            text (torch.Tensor): Token indices, [batch size, sent len].

        Returns:
            torch.Tensor: Output of the final linear layer `self.fc`.
        """

        #text = [batch size, sent len]

        embedded = self.embedding(text)

        #embedded = [batch size, sent len, emb dim]

        # Conv1d expects channels before the sequence axis.
        embedded = embedded.permute(0, 2, 1)

        #embedded = [batch size, emb dim, sent len]

        conved = [F.relu(conv(embedded)) for conv in self.convs]

        #conved_n = [batch size, n_filters, sent len - filter_sizes[n] + 1]

        # Max-over-time pooling: the window covers the whole remaining axis.
        pooled = [
            F.max_pool1d(feature_map, feature_map.shape[2]).squeeze(2)
            for feature_map in conved
        ]

        #pooled_n = [batch size, n_filters]

        cat = self.dropout(torch.cat(pooled, dim=1))

        #cat = [batch size, n_filters * len(filter_sizes)]

        return self.fc(cat)


# class MixCNN1d_Text(MyClassifier, nn.Module):
#     def __init__(self,
#                  pretrained_embedding=None,
#                  freeze_embedding=False,
#                  vocab_size=None,
#                  embedding_dim=300,
#                  n_filters=100,
#                  filter_sizes=[3, 4, 5],
#                  n_classes=1,
#                  dropout=0.5,
#                  pad_idx=0):
#         """
#         The constructor for CNN_NLP class.

#         Args:
#             pretrained_embedding (torch.Tensor): Pretrained embeddings with
#                 shape (vocab_size, embedding_dim)
#             freeze_embedding (bool): Set to False to fine-tune pretraiend
#                 vectors. Default: False
#             vocab_size (int): Need to be specified when not pretrained word
#                 embeddings are not used.
#             embedding_dim (int): Dimension of word vectors. Need to be specified
#                 when pretrained word embeddings are not used. Default: 300
#             n_filters (List[int]): List of number of filters, has the same
#                 length as `filter_sizes`. Default: [100, 100, 100]
#             filter_sizes (List[int]): List of filter sizes. Default: [3, 4, 5]

#             n_classes (int): Number of classes. Default: 1
#             dropout (float): Dropout rate. Default: 0.5
#         """
#         super(MixCNN1d_Text, self).__init__()

#         # Embedding layer
#         if pretrained_embedding is not None:
#             self.vocab_size, self.embedding_dim = pretrained_embedding.shape
#             self.embedding = nn.Embedding.from_pretrained(
#                 pretrained_embedding, freeze=freeze_embedding)
#         else:
#             self.embedding_dim = embedding_dim
#             self.embedding = nn.Embedding(num_embeddings=vocab_size,
#                                           embedding_dim=self.embedding_dim,
#                                           padding_idx=pad_idx,
#                                           max_norm=5.0)

#         self.convs = nn.ModuleList([
#             nn.Conv1d(in_channels=self.embedding_dim,
#                       out_channels=n_filters,
#                       kernel_size=fs) for fs in filter_sizes
#         ])

#         self.fc = nn.Linear(len(filter_sizes) * n_filters, n_classes)

#         self.dropout = nn.Dropout(dropout)

#     def forward(self,
#                 text,
#                 text2=None,
#                 l=None,
#                 mix_layer=1000,
#                 flag_feature=False):

#         #text = [batch size, sent len]

#         embedded = self.embedding(text)

#         #embedded = [batch size, sent len, emb dim]

#         embedded = embedded.permute(0, 2, 1)

#         #embedded = [batch size, emb dim, sent len]

#         conved = [F.relu(conv(embedded)) for conv in self.convs]

#         #conved_n = [batch size, n_filters, sent len - filter_sizes[n] + 1]

#         pooled = [
#             F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved
#         ]

#         #pooled_n = [batch size, n_filters]

#         cat = self.dropout(torch.cat(pooled, dim=1))

#         #cat = [batch size, n_filters * len(filter_sizes)]

#         return self.fc(cat)


class MLP_Text(MyClassifier, nn.Module):
    """MLP text classifier working on adaptively pooled word embeddings."""

    def __init__(self,
                 pretrained_embedding=None,
                 freeze_embedding=True,
                 vocab_size=None,
                 embedding_dim=300,
                 hidden_size=300,
                 n_classes=1,
                 pad_idx=0):
        """
        The constructor for MLP_Text class.

        Args:
            pretrained_embedding (torch.Tensor): Pretrained embeddings with
                shape (vocab_size, embedding_dim). When given, `vocab_size`,
                `embedding_dim` and `pad_idx` are not used. Default: None
            freeze_embedding (bool): Passed as `freeze` to
                `nn.Embedding.from_pretrained`. Default: True
            vocab_size (int): Need to be specified when pretrained word
                embeddings are not used.
            embedding_dim (int): Dimension of word vectors, used when
                pretrained word embeddings are not used. Default: 300
            hidden_size (int): Base width of the MLP; the layers map
                16 * hidden_size -> 4 * hidden_size -> hidden_size.
                Default: 300
            n_classes (int): Number of outputs of the final linear layer.
                Default: 1
            pad_idx (int): Padding index of the freshly created embedding.
                Default: 0
        """
        super(MLP_Text, self).__init__()

        # Embedding layer
        if pretrained_embedding is None:
            self.embedding_dim = embedding_dim
            self.embedding = nn.Embedding(num_embeddings=vocab_size,
                                          embedding_dim=self.embedding_dim,
                                          padding_idx=pad_idx,
                                          max_norm=5.0)
        else:
            self.vocab_size, self.embedding_dim = pretrained_embedding.shape
            self.embedding = nn.Embedding.from_pretrained(
                pretrained_embedding, freeze=freeze_embedding)

        # Activation shared by both hidden layers.
        self.af = nn.Softsign()

        wide_dim = 16 * hidden_size
        mid_dim = 4 * hidden_size

        # Resamples the flattened embeddings to a fixed length of wide_dim.
        self.avgpool = nn.AdaptiveAvgPool1d(wide_dim)
        self.fc1 = nn.Linear(wide_dim, mid_dim)
        self.bn1 = nn.BatchNorm1d(mid_dim)
        self.fc2 = nn.Linear(mid_dim, hidden_size)
        self.bn2 = nn.BatchNorm1d(hidden_size)

        self.fc = nn.Linear(hidden_size, n_classes)

    def forward(self, text):
        """
        Compute the classifier output for a batch of token sequences.

        Args:
            text (torch.Tensor): Token indices, [batch size, sent len].

        Returns:
            torch.Tensor: Output of the final linear layer `self.fc`.
        """
        # vectors = [batch size, sent len, emb dim]
        # NOTE(review): detach() cuts the graph here, so no gradient reaches
        # the embedding from this forward pass even when
        # freeze_embedding=False -- confirm this is intended.
        vectors = self.embedding(text).detach()

        # Flatten every sample and add a leading singleton axis so the
        # adaptive pool acts on the flattened axis:
        # flat = [1, batch size, sent len * emb dim]
        n_samples = vectors.size(0)
        flat = vectors.view(1, n_samples, -1)

        # pooled = [batch size, 16 * hidden_size]
        pooled = self.avgpool(flat).squeeze(0)

        hidden = self.af(self.bn1(self.fc1(pooled)))
        hidden = self.af(self.bn2(self.fc2(hidden)))

        return self.fc(hidden)


class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions plus shortcut.

    The block outputs `expansion * planes` channels. The 3x3 convolution
    carries the stride.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        """
        Args:
            in_planes (int): Number of input channels.
            planes (int): Channel width of the two inner convolutions.
            stride (int): Stride of the 3x3 convolution and, when present,
                of the projection shortcut. Default: 1
        """
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # A projection is needed whenever the main path changes the spatial
        # size or the channel count; otherwise the shortcut is the identity
        # (an empty Sequential returns its input unchanged).
        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply the three conv/BN stages, add the shortcut, then ReLU."""
        reduced = F.relu(self.bn1(self.conv1(x)))
        spatial = F.relu(self.bn2(self.conv2(reduced)))
        out = self.bn3(self.conv3(spatial))
        out += self.shortcut(x)
        return F.relu(out)


class ResNet(MyClassifier, nn.Module):
    """ResNet backbone with a 3x3 stem, four layer groups and a linear head.

    The stem and the first layer group keep the spatial resolution; layer
    groups 2-4 each start with a stride-2 block. Global average pooling makes
    the classifier independent of the input resolution.
    """

    def __init__(self, block, num_blocks, num_classes=1, input_shape=None):
        """
        Args:
            block: Residual block class. It is called as
                `block(in_planes, planes, stride)` and must expose an
                `expansion` class attribute.
            num_blocks (Sequence[int]): Number of blocks in each of the four
                layer groups.
            num_classes (int): Number of outputs of the classifier.
                Default: 1
            input_shape (Tuple[int, int]): (height, width) of the expected
                input, only used to compute `self.feature_size`.
                Default: (128, 128)
        """
        super(ResNet, self).__init__()
        self.in_planes = 64
        self.num_classifier = num_classes

        if input_shape is None:
            input_shape = (128, 128)  # Default size if not specified

        # Initial convolutional layer (3 input channels, keeps resolution)
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # ResNet layers
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        # Flattened size of the layer4 output for `input_shape`. It is
        # informational only: the classifier below works on pooled features.
        self.feature_size = self._calculate_feature_size(
            input_shape, block.expansion)

        # Adaptive pooling to handle variable input sizes
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Final classifier
        self.classifier = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack `num_blocks` blocks; only the first one uses `stride`."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride))
            # The next block consumes the expanded output of this one.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def _calculate_feature_size(self, input_shape, expansion=None):
        """
        Return the flattened size of the layer4 output before pooling.

        Args:
            input_shape (Tuple[int, int]): (height, width) of the input.
            expansion (int): Channel expansion of the residual block.
                Defaults to `Bottleneck.expansion` so existing one-argument
                calls keep working.

        Returns:
            int: height * width * 512 * expansion of the final feature map.
        """
        if expansion is None:
            expansion = Bottleneck.expansion
        height, width = input_shape
        # Only layer2, layer3 and layer4 downsample (stride 2); conv1 and
        # layer1 use stride 1. Assuming the block applies its stride through
        # a 3x3 convolution with padding 1 (as Bottleneck does), a size n
        # becomes (n + 1) // 2 per stage.
        for _ in range(3):
            height = (height + 1) // 2
            width = (width + 1) // 2
        return height * width * 512 * expansion

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Image batch with 3 channels,
                [batch size, 3, height, width].

        Returns:
            torch.Tensor: Classifier output, [batch size, num_classes].
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        # [batch size, C, 1, 1] -> [batch size, C]
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out


def ResNet50(dim=None):
    """
    Build a ResNet with Bottleneck blocks in the [3, 4, 6, 3] layout.

    Args:
        dim (int): Total number of input values, treated as a flattened
            square 3-channel image; the side length is derived from it.
            When None, a 128 x 128 input is assumed. Default: None

    Returns:
        ResNet: The constructed model.
    """
    if dim is None:
        shape = (128, 128)  # default size
    else:
        # dim = channels * side * side  ->  side = sqrt(dim / channels)
        n_channels = 3
        side = int(np.sqrt(dim / n_channels))
        shape = (side, side)

    return ResNet(Bottleneck, [3, 4, 6, 3], input_shape=shape)