import math
import os
import sys
import torch
from torchvision.datasets import CIFAR10, ImageFolder
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from models import ModelWrapper

from models import *

# Dataset statistics, each a three-element tensor moved to the GPU at import
# time (so importing this module requires a working CUDA device).
# Presumably per-channel RGB mean / std used for input normalisation --
# TODO confirm against ModelWrapper.
# NOTE: load_model looks these up by name as ``<DATA_TYPE>_MEAN`` and
# ``<DATA_TYPE>_STD`` (with the data type upper-cased), so the names must keep
# that pattern.
CIFAR10_MEAN = torch.tensor([0.4914, 0.4822, 0.4465]).cuda()
CIFAR10_STD = torch.tensor([0.2023, 0.1994, 0.2010]).cuda()
# Not referenced in the visible part of this file; presumably the CIFAR-10
# image side length in pixels.
CIFAR10_SIZE = 32

IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).cuda()
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).cuda()
# Crop size handed to test_transforms_imagenet by load_data.
IMAGENET_SIZE = 224

def test_transforms_imagenet(size):
    """Return the preprocessing pipeline applied to ImageNet test images.

    The pipeline resizes with a target of ``int(size * 1.14)``, centre-crops
    to ``size`` and finally converts the image to a tensor.

    Args:
        size: Side length passed to the centre crop.

    Returns:
        A ``transforms.Compose`` chaining the three steps above.
    """
    # Resize slightly larger than the crop so the centre crop trims a border.
    resize_target = int(size * 1.14)
    steps = [
        transforms.Resize(resize_target),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)


def load_data(args):
    """Build the test-set ``DataLoader`` for the dataset selected by ``args``.

    Args:
        args: Namespace-like object providing ``data_type`` (``'CIFAR10'`` or
            ``'ImageNet'``), ``data_dir`` and ``batch_size``.

    Returns:
        A shuffling ``DataLoader`` over the chosen test set.

    Raises:
        NotImplementedError: If ``args.data_type`` is not a supported dataset.
    """
    if args.data_type == 'CIFAR10':
        testset = CIFAR10(args.data_dir, train=False,
                          transform=transforms.ToTensor(),
                          download=False)
    elif args.data_type == 'ImageNet':
        testset = ImageFolder(os.path.join(args.data_dir, 'ImageNet/test'),
                              transform=test_transforms_imagenet(IMAGENET_SIZE))
    else:
        raise NotImplementedError(
            'Unsupported data_type: {!r}'.format(args.data_type))

    # Every dataset that reaches this point is shuffled. The former
    # ``data_type == 'GTSRB'`` special case could never trigger because that
    # value is rejected above, so it has been dropped.
    return DataLoader(testset, batch_size=args.batch_size, shuffle=True)


def load_model(args):
    """Instantiate the model named by ``args``, load its weights, return it in eval mode.

    Two loading paths exist:

    * ImageNet with ``train_type == 'nat'``, or CIFAR10 with any model other
      than ``WideResNet``: the class is built with ``pretrained=True``.
    * Everything else: the class is built with no arguments and weights are
      read from ``<model_dir>/<data_type>/<model_type>/<train_type>/ckpt.pt``
      (data/model type lower-cased), with the state-dict keys remapped
      according to ``train_type``.

    In both paths the model is moved to the GPU and wrapped in
    ``torch.nn.DataParallel``. Unless ``'trades'`` occurs in ``train_type``
    the result is additionally wrapped in ``ModelWrapper`` together with the
    dataset's ``*_MEAN`` / ``*_STD`` constants.

    Args:
        args: Namespace-like object providing ``model_type``, ``data_type``,
            ``train_type`` and ``model_dir``.

    Returns:
        The model, switched to evaluation mode.

    Raises:
        NotImplementedError: If a checkpoint is requested for a dataset other
            than CIFAR10 or ImageNet.
    """
    # Resolve the model class by name from this module's namespace
    # (presumably supplied by ``from models import *`` -- TODO confirm).
    model_class = getattr(sys.modules[__name__], args.model_type)

    if (args.data_type == 'ImageNet' and args.train_type == 'nat') or  \
            (args.data_type == 'CIFAR10' and args.model_type != 'WideResNet') :
        base_model = model_class(pretrained=True)
        base_model = base_model.cuda()
        base_model = torch.nn.DataParallel(base_model)

    else:
        base_model = model_class()
        checkpoint_path = os.path.join(
            args.model_dir,
            args.data_type.lower(),
            args.model_type.lower(),
            args.train_type,
            'ckpt.pt'
        )
        checkpoint = torch.load(checkpoint_path)

        # The remapping below adapts each checkpoint's key layout to what the
        # DataParallel-wrapped model accepts. The layouts themselves are not
        # visible here, so the slicing is kept exactly as it was.
        if args.data_type == 'CIFAR10':
            if 'adv' in args.train_type or 'nat' in args.train_type:
                # Keep only keys starting with 'model' and drop the leading
                # 'model.' from each.
                checkpoint = {k[len('model.'):]:v for k, v in checkpoint.items() if k.startswith('model')}
            elif 'trades' in args.train_type:
                # Prefix with 'module.' and skip keys starting with 'sub_block'.
                checkpoint = {('module.'+k):v for k, v in checkpoint.items() if not k.startswith('sub_block')}

        elif args.data_type == 'ImageNet':
            if 'gaussian_0.25' in args.train_type:
                sd = checkpoint['state_dict']
                # Drops the first two characters of every key (presumably a
                # short index prefix such as '1.' -- TODO confirm).
                checkpoint = {k[2:]:v for k, v in sd.items()}
            elif 'gaussian' in args.train_type or 'fast' in args.train_type:
                checkpoint = checkpoint['state_dict']
            elif 'adv' in args.train_type:
                # The weights live under 'model' when present, else 'state_dict'.
                state_dict_path = 'model'
                if not ('model' in checkpoint):
                    state_dict_path = 'state_dict'

                sd = checkpoint[state_dict_path]
                # Unconditionally strip len('module.') characters, keep the
                # 'model*' entries without their 'model.' prefix, then put
                # 'module.' back for DataParallel.
                sd = {k[len('module.'):]:v for k, v in sd.items()}
                sd = {k[len('model.'):]:v for k,v in sd.items() if k.startswith('model')}
                checkpoint = {('module.'+k):v for k, v in sd.items()}
        else:
            raise NotImplementedError(
                'Unsupported data_type: {!r}'.format(args.data_type))

        base_model = base_model.cuda()
        base_model = torch.nn.DataParallel(base_model)
        base_model.load_state_dict(checkpoint)

    if 'trades' in args.train_type:
        model = base_model
    else:
        # Look the statistics up by name in the module namespace rather than
        # eval()-ing a string assembled from user-supplied arguments.
        module_namespace = globals()
        constant_prefix = args.data_type.upper()
        mean = module_namespace[constant_prefix + '_MEAN']
        std = module_namespace[constant_prefix + '_STD']
        model = ModelWrapper(base_model, mean, std)

    model.eval()
    return model

def save_dir(args):
    """Compose the output directory path for the run described by ``args``.

    The path is ``args.output_dir`` followed by the lower-cased data type,
    model type, train type, attack type and defense type, in that order.
    Nothing is created on disk; only the path string is returned.
    """
    components = [args.output_dir]
    for field in ('data_type', 'model_type', 'train_type',
                  'attack_type', 'defense_type'):
        components.append(getattr(args, field).lower())

    return os.path.join(*components)


def evaluate(model, image, label, return_full_preds=False):
    """Run ``model`` on ``image`` and count predictions that match ``label``.

    Args:
        model: Callable returning one row of class scores per input.
        image: Batch passed to ``model`` unchanged.
        label: Tensor of ground-truth class indices.
        return_full_preds: When true, also return the softmax over dim 1 of
            the model output as a NumPy array.

    Returns:
        The number of correct predictions (a NumPy integer), or the tuple
        ``(num_correct, softmax)`` when ``return_full_preds`` is set.
    """
    logits = model(image)
    _, predicted = torch.max(logits, 1)

    # The comparison is carried out on the CPU with NumPy arrays.
    hits = predicted.cpu().numpy() == label.cpu().numpy()
    num_correct = hits.sum()

    if not return_full_preds:
        return num_correct

    probabilities = torch.softmax(logits, dim=1)
    return num_correct, probabilities.detach().cpu().numpy()


def evaluate_rand(model, image, label, scale, num_samples_test, batch_size=50):
    """Classify each image by majority vote over Gaussian-noised copies.

    For every image, ``num_samples_test`` copies are made, noise drawn with
    ``torch.randn_like`` and multiplied by ``scale`` is added, and the copies
    are pushed through ``model`` in chunks of at most ``batch_size``. The
    class predicted most often becomes the prediction for that image.

    Args:
        model: Callable returning one row of class scores per input.
        image: Batch of images indexed along dim 0; each ``image[i]`` is
            repeated along a new leading dimension (assumes an
            ``(N, C, H, W)`` layout -- TODO confirm).
        label: Ground-truth class per image, indexed as ``label[i]``.
        scale: Multiplier applied to the standard-normal noise.
        num_samples_test: Number of noisy copies evaluated per image.
        batch_size: Maximum number of noisy copies per forward pass.

    Returns:
        ``(preds, num_correct)`` where ``preds`` is a 1-D tensor holding the
        voted class of every image and ``num_correct`` is how many of them
        equal the label. For an empty ``image`` batch ``preds`` is ``None``.

    Raises:
        ValueError: If ``num_samples_test`` or ``batch_size`` is below 1.
    """
    num_images = image.shape[0]
    if num_images == 0:
        # Historical result for an empty batch; kept for existing callers.
        return None, 0
    if num_samples_test < 1:
        # Previously this surfaced as an UnboundLocalError further down.
        raise ValueError('num_samples_test must be at least 1, got {!r}'
                         .format(num_samples_test))
    if batch_size < 1:
        raise ValueError('batch_size must be at least 1, got {!r}'
                         .format(batch_size))

    # The chunk count is the same for every image, so compute it once.
    num_batches = int(math.ceil(num_samples_test / batch_size))
    preds = []
    num_correct = 0

    for i in range(num_images):
        votes = []

        for j in range(num_batches):
            bstart = j * batch_size
            bend = min(bstart + batch_size, num_samples_test)

            image_noise = image[i].unsqueeze(0).repeat(bend - bstart, 1, 1, 1)
            image_noise = image_noise + scale * torch.randn_like(image_noise)

            output = model(image_noise)
            votes.append(torch.max(output, 1)[1])

        # minlength gives every class a bin even if it received no vote.
        count = torch.bincount(torch.cat(votes, dim=0),
                               minlength=output.shape[1])

        pred = torch.argmax(count)
        preds.append(pred)
        if pred == label[i]:
            num_correct += 1

    return torch.stack(preds), num_correct

