import os

import torch
import torch.nn as nn
from torchvision import models

from algs.utils.conv import *
from utils.func import model_load

def get_model(args):
    """Build or load the networks required by the current job.

    When ``'train'`` is part of ``args.job`` new classifiers are created with
    ``get_classifier``: optionally a noise model (when ``args.denoise`` is the
    string ``'True'``) and, depending on ``args.alg``, a ``'bias'`` and/or
    ``'final'`` model for each of those role names that appears in
    ``args.job``. Otherwise the result of
    ``model_load(args.ckpt_path, 'best')`` is used as the model entry.

    Args:
        args: Run configuration. Reads ``job``, ``log``, ``denoise``,
            ``out_path``, ``alg`` and ``ckpt_path``. ``args.job`` gets
            ``'_noise'`` appended when denoising is requested but no saved
            noise model exists, so the noise model is trained in this run.

    Returns:
        dict: ``{'noise': noise_model, 'model': nets}``. ``noise_model`` is
        ``{}`` unless denoising is enabled during training; ``nets`` maps
        ``'bias'`` / ``'final'`` to ``get_classifier`` results when training,
        or is whatever ``model_load`` returned otherwise.

    Raises:
        SystemExit: With status 1 if ``args.alg`` is not a known algorithm,
            or if no ``model_best.pt`` checkpoint exists outside training.
    """
    # Algorithms that only ever build the 'final' network.
    final_only_algs = ('vanilla', 'aflite', 'rubi')
    # Algorithms that may build a 'bias' network before the 'final' one.
    bias_and_final_algs = ('lff', 'mixin', 'rebias', 'repair', 'ours')

    # Named `nets` (not `models`) so the torchvision `models` module imported
    # at file level is not shadowed inside this function.
    nets = {}
    noise_model = {}

    if 'train' not in args.job:
        # Evaluation-style job: a previously trained checkpoint is mandatory.
        if not os.path.isfile(args.ckpt_path + 'model_best.pt'):
            args.log.info('No trained model... plz train first...')
            # Non-zero status so calling scripts can detect the failure; the
            # bare `exit()` helper is meant for interactive use and exits 0.
            raise SystemExit(1)
        args.log.info('Model load ...')
        nets = model_load(args.ckpt_path, 'best')
        return {'noise': noise_model, 'model': nets}

    args.log.info('Model Generate ...')

    # Noise part
    if args.denoise == 'True':
        if 'noise' in args.job:
            args.log.info('Train noise model ...')
            noise_model = get_classifier(args)
        elif os.path.isfile(args.out_path + 'model_noise_end.pt'):
            args.log.info("Load noise model ...")
            noise_model = model_load(args.out_path, 'noise_end')
        else:
            args.log.info('No noise model... Train start ...')
            # Tag the job so the noise model is trained as part of this run.
            args.job += '_noise'
            noise_model = get_classifier(args)

    # Decide which roles the selected algorithm is allowed to instantiate.
    if args.alg in bias_and_final_algs:
        roles = ('bias', 'final')
    elif args.alg in final_only_algs:
        roles = ('final',)
    else:
        print('Invalid algorithm ...')
        raise SystemExit(1)

    # Only build the roles explicitly requested in the job string
    # ('bias' is always created before 'final').
    for role in roles:
        if role in args.job:
            nets[role] = get_classifier(args)

    return {'noise': noise_model, 'model': nets}

# Maps an `args.arch` value to the name of the torchvision.models factory
# that builds it. Names are resolved lazily with getattr so that an
# architecture missing from the installed torchvision only fails when it is
# actually requested.
_TORCHVISION_FACTORIES = {
    'resnet18': 'resnet18',
    'resnet34': 'resnet34',
    'resnet50': 'resnet50',
    'resnet101': 'resnet101',
    'resnet152': 'resnet152',
    'alexnet': 'alexnet',
    'squeezenet': 'squeezenet1_0',
    'vgg16': 'vgg16',
    'densenet': 'densenet161',
    'inception': 'inception_v3',
    'googlenet': 'googlenet',
    'shufflenet': 'shufflenet_v2_x1_0',
    'mobilenetv2': 'mobilenet_v2',
    'mobilenetv3_large': 'mobilenet_v3_large',
    'mobilenetv3_small': 'mobilenet_v3_small',
    'resnext50': 'resnext50_32x4d',
    'wideresnet_50_2': 'wide_resnet50_2',
    'mnasnet': 'mnasnet1_0',
}


def _replace_output_layer(net, num_labels):
    """Swap the final prediction layer of ``net`` for one with ``num_labels`` outputs.

    Networks exposing an ``fc`` attribute get a fresh ``nn.Linear`` there.
    Networks exposing ``classifier`` instead get the last ``nn.Linear`` (or
    ``nn.Conv2d``) inside it replaced.

    Raises:
        AttributeError: If no replaceable output layer can be located.
    """
    if hasattr(net, 'fc'):
        # The first parameter of the existing head is its weight; for
        # nn.Linear that is (out_features, in_features), so dim 1 is the
        # incoming feature width.
        in_features = list(net.fc.parameters())[0].shape[1]
        net.fc = nn.Linear(in_features, num_labels)
        # NOTE(review): auxiliary heads (e.g. on inception / googlenet), if
        # present, are left untouched here - confirm that is intended.
        return

    head = getattr(net, 'classifier', None)
    if isinstance(head, nn.Linear):
        net.classifier = nn.Linear(head.in_features, num_labels)
        return

    if isinstance(head, nn.Sequential):
        # Walk backwards: the output layer is the last Linear/Conv2d, which
        # may be followed by activation or pooling modules.
        for idx in range(len(head) - 1, -1, -1):
            layer = head[idx]
            if isinstance(layer, nn.Linear):
                head[idx] = nn.Linear(layer.in_features, num_labels)
                return
            if isinstance(layer, nn.Conv2d):
                head[idx] = nn.Conv2d(
                    layer.in_channels,
                    num_labels,
                    kernel_size=layer.kernel_size,
                    stride=layer.stride,
                    padding=layer.padding,
                )
                # NOTE(review): some torchvision versions appear to reshape
                # the conv classifier output using this attribute - keep it
                # in sync when it exists.
                if hasattr(net, 'num_classes'):
                    net.num_classes = num_labels
                return

    raise AttributeError(
        "Cannot locate an output layer ('fc' or 'classifier') on "
        + type(net).__name__
    )


def get_classifier(args):
    """Create a network for ``args.arch`` with its optimizer and LR scheduler.

    The network's output layer is replaced so that it produces
    ``args.num_labels`` outputs, then the network is moved to ``args.device``.

    Args:
        args: Run configuration. Reads ``arch``, ``img_dim`` (conv0 / conv1
            only), ``num_labels``, ``device``, ``lr``, ``momentum``,
            ``weight_decay``, ``lr_decay_step`` and ``lr_decay``.

    Returns:
        dict: ``{'net': net, 'opt': opt, 'scheduler': scheduler}`` where
        ``opt`` is an SGD optimizer over ``net.parameters()`` and
        ``scheduler`` is a ``StepLR`` attached to it.

    Raises:
        SystemExit: With status 1 if ``args.arch`` is not recognised.
        AttributeError: If the built network has no replaceable output layer.
    """
    if args.arch == 'conv0':
        net = CNN_0(args.img_dim)
    elif args.arch == 'conv1':
        net = CNN_1(args.img_dim)
    elif args.arch in _TORCHVISION_FACTORIES:
        net = getattr(models, _TORCHVISION_FACTORIES[args.arch])()
    else:
        print('Invalid architecture ...')
        # Non-zero status so calling scripts can detect the failure; the
        # bare `exit()` helper is meant for interactive use and exits 0.
        raise SystemExit(1)

    _replace_output_layer(net, args.num_labels)

    net = net.to(args.device)
    # Use the explicitly imported `torch` rather than an alias that is only
    # available if a wildcard import happens to provide it.
    opt = torch.optim.SGD(
        net.parameters(),
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.StepLR(
        opt, step_size=args.lr_decay_step, gamma=args.lr_decay
    )

    return {'net': net, 'opt': opt, 'scheduler': scheduler}