import torch
import torch.nn as nn
from .resnet_cifar import *
from .densenet_cifar import *
from .vgg_cifar import *
from .mobilenetv2_cifar import *
from .blocks import *

__all__ = ['OurNet']

class OurNet(nn.Module):
    """Group of ``num_branches + 1`` identically-configured student networks.

    Students are registered as submodules ``stu0`` ... ``stu{num_branches}``.
    The first ``num_branches`` of them are the "branches" whose outputs are
    stacked in :meth:`forward`; the last one (``stu{num_branches}``) is run
    separately and returned on its own.

    Args:
        dml_arch: Architecture name, matched by prefix against ``'ResNet'``,
            ``'DenseNet'``, ``'VGG'`` and ``'MobileNetV2'``.
        dataset: Dataset name. Only datasets whose name starts with
            ``'cifar'`` can currently be built.
        depth: Forwarded to the ResNet / VGG builders (not used for
            DenseNet or MobileNetV2).
        num_branches: Number of branch students; one extra student is built
            in addition to these.
        bottleneck: Forwarded to ``call_ResNet`` only.
        se: Forwarded to ``call_ResNet`` only.

    Raises:
        NotImplementedError: For a ``'ResNet*'`` architecture on an
            ``'imagenet*'`` dataset.
        ValueError: For any other unsupported architecture/dataset pair.
    """

    def __init__(self, dml_arch, dataset, depth, num_branches, bottleneck=False, se=False):
        super(OurNet, self).__init__()
        self.num_branches = num_branches
        if dataset == 'cifar10':
            num_classes = 10
        elif dataset == 'cifar100':
            num_classes = 100
        else:
            # Any other dataset name falls back to 1000 classes.
            num_classes = 1000
        # Stored as an attribute only; the student builders below receive
        # ``dataset`` rather than ``num_classes``.
        self.num_classes = num_classes

        # The architecture choice does not depend on the loop index, so it is
        # resolved (and validated) once instead of on every iteration.
        make_student = self._student_factory(dml_arch, dataset, depth, bottleneck, se)
        for i in range(self.num_branches + 1):
            # The builder is called once per student (presumably yielding a
            # distinct module instance each time).
            setattr(self, 'stu' + str(i), make_student())

    @staticmethod
    def _student_factory(dml_arch, dataset, depth, bottleneck, se):
        """Return a zero-argument callable that builds one student network.

        Raises:
            NotImplementedError: ``'ResNet*'`` on an ``'imagenet*'`` dataset.
            ValueError: Any other unsupported ``dml_arch`` / ``dataset`` pair.
        """
        if dataset.startswith('cifar'):
            if dml_arch.startswith('ResNet'):
                return lambda: call_ResNet(dataset=dataset, depth=depth, bottleneck=bottleneck, se=se, KD=True)
            if dml_arch.startswith('DenseNet'):
                return lambda: densenet40k12(dataset=dataset, KD=True)
            if dml_arch.startswith('VGG'):
                return lambda: call_VGG(dataset=dataset, depth=depth, KD=True)
            if dml_arch.startswith('MobileNetV2'):
                return lambda: call_MobileNetV2(dataset=dataset, KD=True)
        elif dml_arch.startswith('ResNet') and dataset.startswith('imagenet'):
            raise NotImplementedError
        raise ValueError('Select which model what you wanna train!')

    def forward(self, x):
        """Run the students on ``x``.

        Each student is expected to return a 2-tuple whose second element is
        used here (presumably ``(features, logits)`` when built with
        ``KD=True`` -- confirm against the builders).

        Returns:
            A 3-tuple ``(pro, stu, pro)``:

            * ``pro`` -- second outputs of ``stu0`` .. ``stu{num_branches-1}``
              stacked along a new trailing dimension (of size
              ``num_branches``; of size 1, holding ``stu0``'s output, when
              ``num_branches == 0``).
            * ``stu`` -- second output of ``stu{num_branches}``.
            * ``pro`` again -- the very same tensor object as the first item.
        """
        # stu0 always contributes, even when num_branches == 0 (in that case
        # it is also the student run again below for ``stu``).
        _, first = self.stu0(x)
        branch_outputs = [first]
        for i in range(1, self.num_branches):
            _, out = getattr(self, 'stu' + str(i))(x)
            branch_outputs.append(out)
        # A single stack replaces repeated unsqueeze + cat, which re-copied
        # the growing accumulator on every iteration.
        pro = torch.stack(branch_outputs, dim=-1)

        _, stu = getattr(self, 'stu' + str(self.num_branches))(x)

        # NOTE(review): the third element duplicates the first; kept because
        # callers presumably unpack three values -- confirm the intent.
        return pro, stu, pro

if __name__ == '__main__':
    # Smoke test: build a small VGG ensemble for CIFAR-100 and run one batch.
    model = OurNet(dml_arch='VGG', dataset='cifar100', depth=16, num_branches=3)
    print(model)
    x = torch.randn(2, 3, 32, 32)
    # forward() returns a 3-tuple (pro, stu, pro); unpacking into two names
    # raised "too many values to unpack".
    en, stu, _ = model(x)
    print(en.size())
