import os
import sys
sys.path.insert(0, './')

import torch
import torch.nn as nn
import pdb

from arch.Preprocess import DataNormalizeLayer
from arch.CNN import CNN
from arch.VGG import VGG11, VGG13, VGG16, VGG19
from arch.ResNet import ResNet18, ResNet34, ResNet50, ResNet101
from arch.WideResNet import WideResNet16, WideResNet22, WideResNet28
from arch.Swin import Swin_T, Swin_S, Swin_B
from arch.DNN import DNN_small


# Per-dataset statistics handed to DataNormalizeLayer by parse_model as its
# `bias` and `scale` arguments.
# NOTE(review): presumably per-channel mean ('bias') and std ('scale') --
# confirm against DataNormalizeLayer.
# Every dataset gets its own dict and its own lists, so mutating one entry
# never leaks into a dataset that happens to share the same numbers.
normalize_dict = {
    name: {'bias': list(bias), 'scale': list(scale)}
    for name, bias, scale in (
        ('mnist', (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ('svhn', (0.4380, 0.4440, 0.4730), (0.1751, 0.1771, 0.1744)),
        ('cifar10', (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ('syn_cifar10', (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ('cifar100', (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
        ('cifar100_superclass', (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
        ('imagenet', (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ('imagenet100', (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ('tiny_imagenet', (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        # Identity transform: zero bias, unit scale.
        ('purchase100', (0.0,), (1.0,)),
    )
}

# Number of output classes for each supported dataset. parse_model uses the
# value as the `num_classes` argument of whichever network it builds.
num_classes_dict = dict(
    mnist=10,
    cifar10=10,
    syn_cifar10=10,
    cifar100=100,
    cifar100_superclass=20,
    svhn=10,
    imagenet100=100,
    tiny_imagenet=200,
    imagenet=1000,
    adults=2,
    purchase100=100,
)

# Input dimensionality for the datasets that need one. parse_model forwards
# the value as `input_dim` to DNN_small; datasets missing here resolve to None.
input_dim_dict = dict(purchase100=600, adults=104)

def parse_model(dataset, arch, normalize=True, **kwargs):
    """Build the network named by ``arch``, sized for ``dataset``.

    Args:
        dataset: Dataset name; must be a key of ``num_classes_dict``.
        arch: Architecture name, matched case-insensitively. Recognized
            names are the ResNet / WideResNet / Swin / VGG variants, their
            ``*_pretrain`` forms where available, ``'cnn'``, ``'mlp'`` and
            ``'mlp_small'``.
        normalize: When true, a ``DataNormalizeLayer`` built from
            ``normalize_dict[dataset]`` is placed in front of the network.
            The layer is omitted for the ``'mlp'`` architecture and for
            datasets that have no entry in ``normalize_dict``.
        **kwargs: Accepted for call-site compatibility; not used.

    Returns:
        An ``nn.Sequential`` holding the optional normalization layer
        followed by the network.

    Raises:
        AssertionError: If ``dataset`` is not a supported dataset.
        ValueError: If ``arch`` is not a recognized architecture.
    """
    # num_classes_dict is the single source of truth for supported datasets.
    assert dataset in num_classes_dict, 'Dataset not included!'

    num_classes = num_classes_dict[dataset]
    input_dim = input_dim_dict.get(dataset, None)
    arch_key = arch.lower()

    # Only the full-resolution ImageNet variants enable the `imagenet` flag of
    # the backbones; tiny_imagenet deliberately does not.
    imagenet = dataset in ['imagenet', 'imagenet100']

    # Keyword-argument sets shared by groups of architectures.
    image_args = {'num_classes': num_classes, 'imagenet': imagenet}
    resnet_pretrain_args = dict(image_args, pretrain=True)
    swin_scratch_args = dict(image_args, pretrained=False)
    swin_pretrain_args = dict(image_args, pretrained=True)
    class_args = {'num_classes': num_classes}
    mlp_args = {'input_dim': input_dim, 'num_classes': num_classes}

    # Architecture name -> (constructor, keyword arguments).
    registry = {
        'resnet': (ResNet18, image_args),
        'resnet18': (ResNet18, image_args),
        'resnet34': (ResNet34, image_args),
        'resnet50': (ResNet50, image_args),
        'resnet101': (ResNet101, image_args),
        'resnet_pretrain': (ResNet18, resnet_pretrain_args),
        'resnet18_pretrain': (ResNet18, resnet_pretrain_args),
        'resnet34_pretrain': (ResNet34, resnet_pretrain_args),
        'resnet50_pretrain': (ResNet50, resnet_pretrain_args),
        'resnet101_pretrain': (ResNet101, resnet_pretrain_args),
        'wideresnet': (WideResNet28, image_args),
        'wideresnet28': (WideResNet28, image_args),
        'wideresnet22': (WideResNet22, image_args),
        'wideresnet16': (WideResNet16, image_args),
        'swin_t': (Swin_T, swin_scratch_args),
        'swin_s': (Swin_S, swin_scratch_args),
        'swin_b': (Swin_B, swin_scratch_args),
        'swin_t_pretrain': (Swin_T, swin_pretrain_args),
        'swin_s_pretrain': (Swin_S, swin_pretrain_args),
        'swin_b_pretrain': (Swin_B, swin_pretrain_args),
        'cnn': (CNN, class_args),
        'vgg': (VGG11, class_args),
        'vgg11': (VGG11, class_args),
        'vgg13': (VGG13, class_args),
        'vgg16': (VGG16, class_args),
        'vgg19': (VGG19, class_args),
        'mlp': (DNN_small, mlp_args),
        'mlp_small': (DNN_small, mlp_args),
    }

    if arch_key not in registry:
        raise ValueError('Unrecognized architecture: %s' % arch)
    constructor, constructor_args = registry[arch_key]
    net = constructor(**constructor_args)

    # 'mlp' never gets a normalization layer, while its alias 'mlp_small' does
    # when statistics exist. That asymmetry is kept on purpose: it decides
    # whether the network sits at index 0 or 1 of the returned Sequential.
    # Datasets without statistics (e.g. 'adults') are built without the layer
    # instead of failing with a KeyError.
    wants_normalization = (
        normalize
        and arch_key not in ['mlp']
        and dataset in normalize_dict
    )
    if not wants_normalization:
        return nn.Sequential(net)

    stats = normalize_dict[dataset]
    normalize_layer = DataNormalizeLayer(bias=stats['bias'], scale=stats['scale'])
    return nn.Sequential(normalize_layer, net)