import torch
import numpy as np
import torchvision.models as models

from allmodels import MNIST, load_model, load_mnist_data, load_cifar10_data, CIFAR10, load_imagenet_train, load_imagenet_test
from models import PytorchModel


def init_classifier(state):
    """Build the victim classifier together with the data used by an attack.

    ``state['dataset']`` selects both the data family and the victim model.
    The family is picked by substring match ('MNIST', 'CIFAR10', 'Imagenet');
    the exact name then selects one concrete victim inside that family.

    Args:
        state: configuration dict. Keys read here are ``'dataset'`` and
            ``'targeted'``, plus ``'victim_path'`` or ``'victim_architecture'``
            depending on the variant. The dict is also passed through to the
            project's data-loading helpers.

    Returns:
        ``(model_wrapper, gen_dataset, target_loader)``:
        ``model_wrapper`` wraps the victim (PytorchModel / TensorflowModel /
        SmoothModel); ``gen_dataset`` is the unshuffled split returned by the
        family's loader; ``target_loader`` is a shuffled test loader used for
        target-sample search when ``state['targeted']`` is truthy, else ``None``.

    Raises:
        ValueError: if the dataset family or the specific variant is unknown.
    """
    dataset = state['dataset']
    # Families are tested in this order. A name such as 'CIFAR100' still
    # matches 'CIFAR10' and is rejected by the variant check in the helper.
    if 'MNIST' in dataset:
        return _init_mnist(state)
    if 'CIFAR10' in dataset:
        return _init_cifar10(state)
    if 'Imagenet' in dataset:
        return _init_imagenet(state)
    raise ValueError(f"Unsupported dataset: {dataset!r}")


def _init_mnist(state):
    """Load MNIST data and build the MNIST victim named by ``state['dataset']``."""
    dataset = state['dataset']
    _, _, _, gen_dataset = load_mnist_data(state, mode='generator', shuffle_test=False)

    # Randomized test loader for target sample search (targeted attacks only).
    target_loader = None
    if state['targeted']:
        _, target_loader, _, _ = load_mnist_data(state, mode='generator', shuffle_test=True)

    if dataset == 'MNIST':
        model = torch.nn.DataParallel(MNIST())
        load_model(model, state['victim_path'])
        model_wrapper = PytorchModel(model, num_classes=10, im_mean=None, im_std=None)

    elif dataset == 'MNIST_deep_camma':
        from community.causal_robustness.Deep_Camma_Manager_predict import Deep_Camma_Manager_Predict
        deep_camma_manager = Deep_Camma_Manager_Predict(
            n_classes=10,
            batch_size=128,
            test_parameters={"model_save_path": state['victim_path']},
            m=1)
        model_wrapper = PytorchModel(deep_camma_manager, num_classes=10, im_mean=None, im_std=None)

    elif dataset in ('MNIST_rob_manifold', 'MNIST_trades', 'MNIST_madry'):
        from community.jalal_simple_models import MNISTClassifier as Model
        model = Model().cuda()
        # Weights are loaded before wrapping, so checkpoint keys carry no 'module.' prefix.
        model.load_state_dict(torch.load(state['victim_path']))
        model = torch.nn.DataParallel(model)
        model_wrapper = PytorchModel(model, num_classes=10, im_mean=None, im_std=None)

    else:
        raise ValueError(f"Unsupported MNIST variant: {dataset!r}")

    return model_wrapper, gen_dataset, target_loader


def _init_cifar10(state):
    """Load CIFAR-10 data and build the CIFAR-10 victim named by ``state['dataset']``."""
    dataset = state['dataset']
    _, _, _, gen_dataset = load_cifar10_data(state, mode='generator', shuffle_test=False)

    # Randomized test loader for target sample search (targeted attacks only).
    target_loader = None
    if state['targeted']:
        _, target_loader, _, _ = load_cifar10_data(state, mode='generator', shuffle_test=True)

    num_classes = 10

    if dataset == 'CIFAR10':
        model = torch.nn.DataParallel(CIFAR10())
        load_model(model, state['victim_path'])
        model_wrapper = PytorchModel(model, num_classes=num_classes, im_mean=None, im_std=None)

    # Several robust CIFAR models from RayS repository
    elif dataset == 'CIFAR10_trades':
        from community import wideresnet
        model = torch.nn.DataParallel(wideresnet.WideResNet().cuda())
        model.module.load_state_dict(torch.load(state['victim_path']))
        model_wrapper = PytorchModel(model, num_classes=num_classes, im_mean=None, im_std=None)

    elif dataset == 'CIFAR10_madry':
        import tensorflow.compat.v1 as tf
        from community import madry_wrn
        from community.general_tf_model import TensorflowModel
        tf.disable_eager_execution()
        tf.disable_v2_behavior()

        model = madry_wrn.Model(mode='eval')
        saver = tf.train.Saver()
        # The session is handed to the wrapper below, so it must stay open.
        sess = tf.Session()
        # For this variant victim_path is a checkpoint directory (latest_checkpoint takes a dir).
        saver.restore(sess, tf.train.latest_checkpoint(state['victim_path']))
        model_wrapper = TensorflowModel(model.pre_softmax, model.x_input,
                                        sess,
                                        num_classes=num_classes,
                                        im_mean=None,
                                        im_std=None)

    elif dataset == 'CIFAR10_interp':
        from community import wideresnet_interp
        model = wideresnet_interp.WideResNet(depth=28, num_classes=10, widen_factor=10).cuda()
        model = torch.nn.DataParallel(model)
        checkpoint = torch.load(state['victim_path'])
        model.load_state_dict(checkpoint['net'])
        model_wrapper = PytorchModel(model, num_classes=num_classes,
                                     im_mean=[0.5, 0.5, 0.5],
                                     im_std=[0.5, 0.5, 0.5])

    elif dataset == 'CIFAR10_fs':
        from community import fs_utils
        from community import wideresnet_fs
        basic_net = wideresnet_fs.WideResNet(depth=28, num_classes=10, widen_factor=10).cuda()
        model = torch.nn.DataParallel(fs_utils.Model_FS(basic_net))
        checkpoint = torch.load(state['victim_path'])
        model.load_state_dict(checkpoint['net'])
        model_wrapper = PytorchModel(model, num_classes=num_classes,
                                     im_mean=[0.5, 0.5, 0.5],
                                     im_std=[0.5, 0.5, 0.5])

    elif dataset == 'CIFAR10_sense':
        from community import wideresnet
        model = torch.nn.DataParallel(wideresnet.WideResNet().cuda())
        model.load_state_dict(torch.load(state['victim_path'])['state_dict'])
        model_wrapper = PytorchModel(model, num_classes=num_classes, im_mean=None, im_std=None)

    elif dataset in ('CIFAR10_smooth20', 'CIFAR10_smooth110'):
        from community.cohen_builder import get_architecture
        from community.cohen_smooth import SmoothModel
        arch = 'cifar_resnet20' if dataset == 'CIFAR10_smooth20' else 'cifar_resnet110'
        ckpt = torch.load(state['victim_path'])
        model = get_architecture(arch, 'cifar10')
        # NOTE(review): strict=False silently ignores missing/unexpected keys
        # (the ImageNet smoothed model loads strictly) — confirm this is intended.
        model.load_state_dict(ckpt['state_dict'], strict=False)
        # Model builder inserts a normalize layer for us.
        model_wrapper = SmoothModel(model,
                                    num_classes=num_classes,
                                    im_mean=None,
                                    im_std=None,
                                    sigma=0.25)

    # ------------------------------------------------------------------
    # Disabled variants, kept for reference. They are not runnable as
    # written: they use names that are not imported or defined here
    # (e.g. `args`, `load_cifar10_test_data`, `wideresnet_rst`,
    # `wideresnet_overfitting`, `preact_resnet`) and hard-coded paths.
    #
    # elif dataset == 'CIFAR10_adv':
    #     model = wideresnet.WideResNet().cuda()
    #     model = torch.nn.DataParallel(model)
    #     model.load_state_dict(torch.load('model/rob_cifar_madry.pt'))
    #     test_loader = load_cifar10_test_data(args.batch)
    #     model_wrapper = PytorchModel(model, num_classes=num_classes,
    #                                  im_mean=None, im_std=None)
    #
    # elif dataset == 'CIFAR10_rst':
    #     model = wideresnet_rst.WideResNet_RST()
    #     model = torch.nn.DataParallel(model).cuda()
    #     model.load_state_dict(torch.load('model/rst_adv.pt.ckpt')['state_dict'])
    #     model_wrapper = PytorchModel(model, num_classes=num_classes,
    #                                  im_mean=None, im_std=None)
    #
    # elif dataset == 'CIFAR10_mart':
    #     model = wideresnet_rst.WideResNet_RST().cuda()
    #     model = torch.nn.DataParallel(model)
    #     model.load_state_dict(torch.load('model/mart_unlabel.pt')['state_dict'])
    #     model_wrapper = PytorchModel(model, num_classes=num_classes,
    #                                  im_mean=None, im_std=None)
    #
    # elif dataset == 'CIFAR10_uat':
    #     import tensorflow_hub as hub
    #     import tensorflow as tf
    #     UAT_HUB_URL = ('./model/uat_model')
    #     model = hub.Module(UAT_HUB_URL)
    #     my_input = tf.placeholder(tf.float32, shape=[None, 32, 32, 3])
    #     my_logits = model(dict(x=my_input, decay_rate=0.1, prefix='default'))
    #     sess = tf.Session()
    #     sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
    #     model_wrapper = TensorflowModel(my_logits, my_input, sess,
    #                                     num_classes=num_classes,
    #                                     im_mean=[125.3/255, 123.0/255, 113.9/255],
    #                                     im_std=[63.0/255, 62.1/255, 66.7/255])
    #
    # elif dataset == 'CIFAR10_overfitting':
    #     model = wideresnet_overfitting.WideResNet(depth=34, num_classes=10, widen_factor=20).cuda()
    #     model = torch.nn.DataParallel(model)
    #     model.load_state_dict(torch.load('model/rob_cifar_overfitting.pth'))
    #     model_wrapper = PytorchModel(model, num_classes=num_classes,
    #                                  im_mean=[0.4914, 0.4822, 0.4465],
    #                                  im_std=[0.2471, 0.2435, 0.2616])
    #
    # elif dataset == 'CIFAR10_pretrain':
    #     model = wideresnet_overfitting.WideResNet(depth=28, num_classes=10, widen_factor=10).cuda()
    #     model = torch.nn.DataParallel(model)
    #     model.load_state_dict(torch.load('model/rob_cifar_pretrain.pt'))
    #     model_wrapper = PytorchModel(model, num_classes=num_classes,
    #                                  im_mean=[0.5, 0.5, 0.5],
    #                                  im_std=[0.5, 0.5, 0.5])
    #
    # elif dataset == 'CIFAR10_fast':
    #     model = preact_resnet.PreActResNet18().cuda()
    #     model.load_state_dict(torch.load('model/rob_cifar_fast_epoch30.pth'))
    #     model_wrapper = PytorchModel(model, num_classes=num_classes,
    #                                  im_mean=[0.4914, 0.4822, 0.4465],
    #                                  im_std=[0.2471, 0.2435, 0.2616])
    #
    # elif dataset == 'CIFAR10_compact':
    #     model = torch.nn.DataParallel(wideresnet_compact.wrn_28_10())
    #     ckpt = torch.load('model/rob_cifar_compact.pth.tar', map_location="cpu")["state_dict"]
    #     model.load_state_dict(ckpt)
    #     model.cuda()
    #     model_wrapper = PytorchModel(model, num_classes=num_classes,
    #                                  im_mean=None, im_std=None)
    #
    # elif dataset == 'CIFAR10_mma':
    #     from advertorch_examples.models import get_cifar10_wrn28_widen_factor
    #     model = get_cifar10_wrn28_widen_factor(4).cuda()
    #     model = torch.nn.DataParallel(model)
    #     model.module.load_state_dict(torch.load('model/rob_cifar_mma.pt')['model'])
    #     model_wrapper = PytorchModel(model, num_classes=num_classes,
    #                                  im_mean=None, im_std=None)
    #
    # elif dataset == 'CIFAR10_he':
    #     model = wideresnet_he.WideResNet(normalize = True).cuda()
    #     model = torch.nn.DataParallel(model)
    #     model.module.load_state_dict(torch.load('model/rob_cifar_pgdHE.pt'))
    #     model_wrapper = PytorchModel(model, num_classes=num_classes,
    #                                  im_mean=None, im_std=None)
    # ------------------------------------------------------------------

    else:
        raise ValueError(f"Unsupported CIFAR10 variant: {dataset!r}")

    return model_wrapper, gen_dataset, target_loader


def _robustness_to_torchvision_state_dict(robust_sd):
    """Convert a MadryLab `robustness` state dict to plain torchvision keys.

    Only ``module.model.*`` entries are kept, with that prefix removed.
    This drops the normalizer module that `robustness` bakes into the
    checkpoint and the ``attacker.model.*`` copy of the same weights.
    """
    prefix = 'module.model.'
    return {k[len(prefix):]: v for k, v in robust_sd.items() if k.startswith(prefix)}


def _init_imagenet(state):
    """Load ImageNet test data and build the ImageNet victim named by ``state['dataset']``."""
    dataset = state['dataset']
    # Don't normalize since we want to attack un-normalized images over [0, 1].
    # Model will perform normalization, see models.PytorchModel
    _, gen_dataset = load_imagenet_test(state, normalize=False, shuffle_test=False)

    # Randomized test loader for target sample search (targeted attacks only).
    target_loader = None
    if state['targeted']:
        target_loader, _ = load_imagenet_test(state, normalize=False, shuffle_test=True)

    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    if dataset == 'Imagenet':
        print(f"Using {state['victim_architecture']} for victim model.")
        model = models.__dict__[state["victim_architecture"]](pretrained=True)
        model = torch.nn.DataParallel(model)
        model_wrapper = PytorchModel(model, num_classes=1000,
                                     im_mean=imagenet_mean, im_std=imagenet_std)

    elif dataset in ('Imagenet_madry8', 'Imagenet_madry4'):
        # Based on https://github.com/MadryLab/robustness/blob/79d371fd799885ea5fe5553c2b749f41de1a2c4e/robustness/model_utils.py#L53
        model = models.__dict__["resnet50"](pretrained=False)
        ckpt = torch.load(state['victim_path'])
        model.load_state_dict(_robustness_to_torchvision_state_dict(ckpt['model']))
        model = torch.nn.DataParallel(model)
        model_wrapper = PytorchModel(model, num_classes=1000,
                                     im_mean=imagenet_mean, im_std=imagenet_std)

    elif dataset == 'Imagenet_smooth50':
        from community.cohen_builder import get_architecture
        from community.cohen_smooth import SmoothModel
        ckpt = torch.load(state['victim_path'])
        model = get_architecture('resnet50', 'imagenet')
        model.load_state_dict(ckpt['state_dict'])
        # Model builder inserts a normalize layer for us.
        model_wrapper = SmoothModel(model,
                                    num_classes=1000,
                                    im_mean=None,
                                    im_std=None,
                                    sigma=0.5)

    else:
        raise ValueError(f"Unsupported Imagenet variant: {dataset!r}")

    return model_wrapper, gen_dataset, target_loader