import torch
import torch.utils.data
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.sampler import SubsetRandomSampler

import torchvision
import torchvision.transforms as transforms

import ConfigSpace as CS
import ConfigSpace.hyperparameters as CSH

from hpbandster.core.worker import Worker

import logging
logging.basicConfig(level=logging.DEBUG)
import numpy as np
import models

from goodfellow_backprop import goodfellow_backprop

import os
import time

# Root directory passed to the torchvision dataset constructors (where the
# datasets are looked up and, when download is enabled, stored).
DATADIR = '/mnt/scratch/xiaoxiang/haozhe/data'
# When True, the worker appends progress information to a file under 'bohb_debug/'.
DEBUG = True

class AverageMeter:
    """Keep the most recent value of a metric together with its running weighted mean.

    Attributes (all reset to 0 by ``reset``):
        val:   value passed to the latest ``update`` call.
        sum:   weighted sum of all values seen so far.
        count: total weight (number of observations) seen so far.
        avg:   ``sum / count`` after the latest ``update``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything recorded so far."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Recomputed from the totals so the mean never drifts from sum/count.
        self.avg = self.sum / self.count

class PyTorchWorker(Worker):
        def __init__(self, **kwargs):
            """Create a worker with default settings; data and model are attached later.

            All keyword arguments are forwarded unchanged to the ``Worker`` base class.
            """
            super().__init__(**kwargs)

            # Experiment settings, overridden through the set_* methods.
            self.arch = ''
            self.num_classes = 10
            self.mode = 'valid_acc'
            self.alpha = 0.05
            self.gpu = 0
            self.debug_file = 'debug_log'

            # Filled in by build_model().
            self.model = None

            # Filled in by load_dataset().
            self.train_dataset = None
            self.test_dataset = None
            self.train_sampler = None
            self.valid_sampler = None

            # Filled in by compute().
            self.train_loader = None
            self.validation_loader = None
            self.test_loader = None


        def load_dataset(self, dataset):
                if dataset == 'cifar10':
                    print('=> loading cifar10 data...')
                    normalize = transforms.Normalize(mean=[0.491, 0.482, 0.447], std=[0.247, 0.243, 0.262])

                    train_dataset = torchvision.datasets.CIFAR10(
                        root=DATADIR,
                        train=True,
                        download=True,
                        transform=transforms.Compose([
                            transforms.RandomCrop(32, padding=4),
                            transforms.RandomHorizontalFlip(),
                            transforms.ToTensor(),
                            normalize,
                        ]))

                    test_dataset = torchvision.datasets.CIFAR10(
                        root=DATADIR,
                        train=False,
                        download=True,
                        transform=transforms.Compose([
                            transforms.ToTensor(),
                            normalize,
                        ]))
                elif dataset == 'cifar100':
                    print('=> loading cifar100 data...')
                    normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])

                    train_dataset = torchvision.datasets.CIFAR100(
                        root=DATADIR,
                        train=True,
                        download=True,
                        transform=transforms.Compose([
                            transforms.RandomCrop(32, padding=4),
                            transforms.RandomHorizontalFlip(),
                            transforms.ToTensor(),
                            normalize,
                        ]))
                    test_dataset = torchvision.datasets.CIFAR100(
                        root=DATADIR,
                        train=False,
                        download=True,
                        transform=transforms.Compose([
                            transforms.ToTensor(),
                            normalize,
                        ]))
                elif dataset == 'fashion':
                    print('=> loading fashion mnist...')
                    transform = transforms.Compose([
                                transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
                    train_dataset = torchvision.datasets.FashionMNIST(DATADIR, download=True, train=True, transform=transform)
                    test_dataset = torchvision.datasets.FashionMNIST(DATADIR, download=True, train=False, transform=transform)

                elif dataset == 'mnist':
                    print('=> loading mnist...')
                    kwargs = {'num_workers': 1, 'pin_memory': True}
                    train_dataset = torchvision.datasets.MNIST(DATADIR, train=True, download=True,
                                       transform=transforms.Compose([
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.1307,), (0.3081,))]))
                    test_dataset = torchvision.datasets.MNIST(DATADIR, train=False, transform=transforms.Compose([
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.1307,), (0.3081,))]))
                elif dataset == 'SVHN':
                    print('=> loading SVHN...')

                    def target_transform(target):
                        return int(target[0]) - 1

                    kwargs = {'num_workers': 1, 'pin_memory': True}
                    train_dataset = torchvision.datasets.SVHN(
                            root=DATADIR, split='train', download=False,
                            transform=transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                            ]),
                            target_transform=None )


                    test_dataset = torchvision.datasets.SVHN(
                            root=DATADIR, split='test', download=False,
                            transform=transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),]))
                else:
                    raise RuntimeError


                num_train = len(train_dataset)
                indices = list(range(num_train))
                split = int(np.floor(0.2 * num_train))

                np.random.seed(10)
                np.random.shuffle(indices)

                train_idx, valid_idx = indices[:split], indices[split:]
                self.valid_sampler = SubsetRandomSampler(valid_idx[:split])
                self.train_sampler = SubsetRandomSampler(train_idx)

                self.train_dataset = train_dataset
                self.test_dataset = test_dataset


        def set_gpu(self, gpu):
            """Store ``gpu`` in ``self.gpu``."""
            self.gpu = gpu

        def set_model(self, arch, num_classes):
            self.arch = arch
            self.num_classes = num_classes

        def set_alpha(self, alpha):
            """Set the factor that multiplies the accumulated gradient norm in the '*_loss_agn' objectives."""
            self.alpha = alpha

        def set_mode(self, mode):
            """Set the objective name that ``compute`` uses to pick its returned loss (e.g. 'valid_acc')."""
            self.mode = mode

        def set_debug_file(self, debug_file):
            """Set the name of the file under 'bohb_debug/' that debug output is appended to."""
            self.debug_file = debug_file

        def build_model(self):
            """Instantiate ``self.arch`` from the ``models`` module and count its parameters.

            Architectures whose name contains 'conv' are constructed without
            arguments; all others receive ``num_classes=self.num_classes``.
            The model is stored in ``self.model`` and the number of trainable
            parameters in ``self.number_of_parameters``.

            Raises:
                KeyError: if ``models`` has no entry named ``self.arch``.
            """
            arch = self.arch
            print('=> Building model...')
            factory = models.__dict__[arch]
            if 'conv' in arch:
                self.model = factory()
            else:
                self.model = factory(num_classes=self.num_classes)

            trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
            print('total num of trainable params', trainable)
            self.number_of_parameters = trainable

            if DEBUG:
                log_path = 'bohb_debug/' + self.debug_file
                # Opening in append mode creates the file but not its directory.
                os.makedirs(os.path.dirname(log_path), exist_ok=True)
                with open(log_path, 'a') as f:
                    f.write('A new model has been built\n')

        def adjust_learning_rate(self, optimizer, epoch, model_type, config_lr):
            """For resnet, the lr starts from 0.1, and is divided by 10 at 80 and 120 epochs"""
            if model_type == 1:
                if epoch < 80:
                    lr = config_lr
                elif epoch < 120:
                    lr = config_lr * 0.1
                else:
                    lr = config_lr * 0.01
            elif model_type == 4:
                if epoch < 4:
                    lr = config_lr
                elif epoch < 10:
                    lr = config_lr * 0.4
                elif epoch < 16:
                    lr = config_lr * 0.2
                elif epoch < 24:
                    lr = config_lr * 0.1
                else:
                    lr = config_lr * 0.02

            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        def get_lr(self, optimizer):
            for param_group in optimizer.param_groups:
                        return param_group['lr']

        def compute(self, config, budget, working_directory, *args, **kwargs):
            """Train a freshly built model for ``int(budget)`` epochs and report the objective.

            ``load_dataset`` and ``set_model`` must have been called first.

            Args:
                config: mapping providing 'batch_size', 'lr', 'sgd_momentum' and 'sgd_decay'.
                budget: number of training epochs (truncated with ``int``).
                working_directory: part of the worker interface; not used here.

            Returns:
                dict with two entries:
                    'loss': the quantity selected by ``self.mode``:
                        '<split>_acc'                  -> 1 - accuracy
                        '<split>_loss'                 -> loss
                        '<split>_loss_agn'             -> loss + norm_sum * alpha
                        '<split>_loss_<grads>_tuple'   -> (loss, norm_sum, alpha)
                      with <split> in valid/train/test and <grads> in
                      agn/allgrads/itrgrads.
                    'info': accuracies, losses, gradient-norm statistics and the
                      parameter count.

            Raises:
                ValueError: if ``self.mode`` is not one of the objectives above
                    (raised before any training is done).
            """
            # Validate first: an unsupported mode used to surface only after the
            # whole training run, as an UnboundLocalError.
            split, objective = self._parse_mode(self.mode)
            num_epochs = int(budget)
            batch_size = config['batch_size']

            def make_loader(dataset, **loader_kwargs):
                return torch.utils.data.DataLoader(dataset, batch_size=batch_size, **loader_kwargs)

            self.train_loader = make_loader(self.train_dataset, sampler=self.train_sampler, num_workers=1)
            # Separate loader over the same samples, iterated in full inside the
            # training loop when an 'itrgrads' mode is active.
            self.copy_train_loader = make_loader(self.train_dataset, sampler=self.train_sampler, num_workers=1)
            self.validation_loader = make_loader(self.train_dataset, sampler=self.valid_sampler, num_workers=1)
            self.test_loader = make_loader(self.test_dataset, shuffle=False, num_workers=2)

            self.build_model()
            model = self.model.cuda()
            model_type = 1 if 'resnet' in self.arch else 4

            optimizer = torch.optim.SGD(model.parameters(), lr=config['lr'],
                                        momentum=config['sgd_momentum'],
                                        weight_decay=config['sgd_decay'])
            criterion = nn.CrossEntropyLoss().cuda()

            norm_sum = 0
            total_norm_time = 0
            for epoch in range(num_epochs):
                self.adjust_learning_rate(optimizer, epoch, model_type, config['lr'])
                if DEBUG:
                    self._debug_write(''.join([
                        str(config) + '\n',
                        'epoch:{epoch}\tlr:{lr}'.format(epoch=epoch, lr=self.get_lr(optimizer)) + '\n',
                        'len of train_loader {l}\n'.format(l=len(self.train_loader)),
                        'batch size %s' % str(self.train_loader.batch_size) + '\n\n\n',
                    ]))

                model.train()
                for x, y in self.train_loader:
                    x, y = x.cuda(), y.cuda()
                    optimizer.zero_grad()
                    output, _, _ = model(x)
                    loss = criterion(output, y)
                    loss.backward()
                    optimizer.step()
                    if 'itrgrads' in self.mode:
                        # Norm over the whole training subset after every single step.
                        avg_norm, norm_time = self._goodfellow_grad_norm(self.copy_train_loader, model, criterion)
                        model.train()  # the helper leaves the model in eval mode
                        norm_sum += avg_norm
                        total_norm_time += norm_time

                if 'agn' in self.mode:
                    avg_norm, norm_time = self._goodfellow_grad_norm(self.train_loader, model, criterion)
                    norm_sum += avg_norm
                    total_norm_time += norm_time
                    if DEBUG:
                        self._debug_write('norm_sum is {norm_sum}'.format(norm_sum=norm_sum))
                elif 'allgrads' in self.mode:
                    avg_norm, norm_time = self._per_sample_grad_norm(self.train_loader, model)
                    norm_sum += avg_norm
                    total_norm_time += norm_time

            train_accuracy, train_loss = self.evaluate_accuracy(model, self.train_loader, criterion)
            validation_accuracy, validation_loss = self.evaluate_accuracy(model, self.validation_loader, criterion)
            test_accuracy, test_loss = self.evaluate_accuracy(model, self.test_loader, criterion)

            if DEBUG:
                self._debug_write(
                    'accuracy\t{train}\t{val}\t{test}\n'.format(
                        train=train_accuracy, val=validation_accuracy, test=test_accuracy)
                    + 'loss\t{train}\t{val}\t{test}\n'.format(
                        train=train_loss, val=validation_loss, test=test_loss))

            accuracy = {'valid': validation_accuracy, 'train': train_accuracy, 'test': test_accuracy}[split]
            split_loss = {'valid': validation_loss, 'train': train_loss, 'test': test_loss}[split]
            if objective == 'acc':
                ret = 1 - accuracy
            elif objective == 'loss':
                ret = split_loss
            elif objective == 'loss_agn':
                ret = split_loss + norm_sum * self.alpha
            else:  # one of the '*_tuple' objectives
                ret = (split_loss, norm_sum, self.alpha)

            return {
                'loss': ret,  # remember: HpBandSter always minimizes!
                'info': {
                    'mode': self.mode,
                    'alpha': self.alpha,
                    'test accuracy': test_accuracy,
                    'train accuracy': train_accuracy,
                    'validation accuracy': validation_accuracy,
                    'test loss': test_loss,
                    'train_loss': train_loss,
                    'validation loss': validation_loss,
                    'agn norm': norm_sum,
                    'total norm time': total_norm_time,
                    # A budget below 1 trains for zero epochs; avoid dividing by it.
                    'avg norm time': total_norm_time / num_epochs if num_epochs else 0.0,
                    'number of parameters': self.number_of_parameters,
                },
            }

        @staticmethod
        def _parse_mode(mode):
            """Split an objective name into (data split, objective kind) or raise ValueError."""
            split, _, objective = mode.partition('_')
            objectives = ('acc', 'loss', 'loss_agn', 'loss_agn_tuple',
                          'loss_allgrads_tuple', 'loss_itrgrads_tuple')
            if split not in ('valid', 'train', 'test') or objective not in objectives:
                raise ValueError('unsupported mode {!r}'.format(mode))
            return split, objective

        def _debug_write(self, text):
            """Append ``text`` to the debug file, creating its directory when missing."""
            log_path = 'bohb_debug/' + self.debug_file
            os.makedirs(os.path.dirname(log_path), exist_ok=True)
            with open(log_path, 'a') as f:
                f.write(text)

        @staticmethod
        def _goodfellow_grad_norm(loader, model, criterion):
            """Average the L2 norms of the entries of ``gradients[-2]`` from goodfellow_backprop.

            Switches ``model`` to eval mode and leaves it there.
            Returns ``(average norm, elapsed seconds)``.
            """
            started = time.time()
            model.eval()
            meter = AverageMeter()
            for inputs, targets in loader:
                inputs, targets = inputs.cuda(), targets.cuda()
                output, activations, linear_combs = model.forward(inputs)
                loss = criterion(output, targets)
                linear_grads = torch.autograd.grad(loss, linear_combs)
                gradients = goodfellow_backprop(activations, linear_grads)
                for sample_grad in gradients[-2]:
                    meter.update(torch.norm(sample_grad, 2).item(), 1)
            return meter.avg, time.time() - started

        @staticmethod
        def _per_sample_grad_norm(loader, model):
            """Average the L2 norms of every (sample, parameter tensor) gradient.

            Each sample is back-propagated on its own. Only the scalar norms are
            kept (not the gradient tensors) and they are accumulated
            parameter-major, the same order as iterating stacked gradients.
            Switches ``model`` to eval mode and leaves it there.
            Returns ``(average norm, elapsed seconds)``.
            """
            started = time.time()
            model.eval()
            meter = AverageMeter()
            params = list(model.parameters())
            for inputs, targets in loader:
                inputs, targets = inputs.cuda(), targets.cuda()
                output, _, _ = model.forward(inputs)
                each_loss = F.cross_entropy(output, targets, reduction='none')

                norms = []  # norms[n][p]: sample n, parameter tensor p
                for sample_loss in each_loss:
                    model.zero_grad()
                    # The graph is shared by all samples of the batch.
                    sample_loss.backward(retain_graph=True)
                    norms.append([torch.norm(p.grad, 2).item() for p in params])

                for per_param in zip(*norms):
                    for value in per_param:
                        meter.update(value, 1)
            return meter.avg, time.time() - started

        def evaluate_accuracy(self, model, data_loader, criterion):
            """Evaluate ``model`` on ``data_loader`` without tracking gradients.

            Puts the model into eval mode and moves every batch to the GPU.

            Returns:
                (accuracy, loss): the number of correct predictions divided by
                ``len(data_loader.sampler)``, and the batch losses averaged
                with each batch weighted by its size.
            """
            model.eval()
            loss_meter = AverageMeter()
            num_correct = 0
            with torch.no_grad():
                for inputs, targets in data_loader:
                    inputs = inputs.cuda()
                    targets = targets.cuda()
                    logits, _, _ = model(inputs)
                    batch_loss = criterion(logits, targets)
                    loss_meter.update(batch_loss.item(), inputs.size(0))
                    # Index of the largest output per sample is the predicted class.
                    _, predicted = logits.max(1, keepdim=True)
                    num_correct += predicted.eq(targets.view_as(predicted)).sum().item()
            accuracy = num_correct / len(data_loader.sampler)
            return (accuracy, loss_meter.avg)


        @staticmethod
        def get_configspace(resnet110):
            """Build (and print) the search space over the SGD hyperparameters.

            Hyperparameters:
                lr:           float in [1e-7, 0.5], log scale, default 0.1
                sgd_momentum: float in [0.0, 0.99], default 0.6
                sgd_decay:    float in [5e-7, 0.05], default 0.0005
                batch_size:   int in [32, 512] when ``resnet110`` is truthy,
                              otherwise [32, 1000]; default 100

            :return: ConfigurationSpace object
            """
            max_batch_size = 512 if resnet110 else 1000
            hyperparameters = [
                CSH.UniformFloatHyperparameter('lr', lower=1e-7, upper=0.5, default_value=0.1, log=True),
                CSH.UniformFloatHyperparameter('sgd_momentum', lower=0.0, upper=0.99, default_value=0.6, log=False),
                CSH.UniformFloatHyperparameter('sgd_decay', lower=5e-7, upper=0.05, default_value=0.0005, log=False),
                CSH.UniformIntegerHyperparameter('batch_size', lower=32, upper=max_batch_size, default_value=100, log=False),
            ]

            cs = CS.ConfigurationSpace()
            cs.add_hyperparameters(hyperparameters)

            print(cs)

            return cs



if __name__ == "__main__":
        # Local smoke test: two epochs on Fashion-MNIST with one sampled configuration.
        worker = PyTorchWorker(run_id='0')
        worker.load_dataset('fashion')
        # build_model() takes no arguments and is called by compute() itself;
        # the architecture and class count are selected through set_model().
        worker.set_model('conv_mnist', 10)
        cs = worker.get_configspace(False)

        config = cs.sample_configuration().get_dictionary()
        print(config)
        res = worker.compute(config=config, budget=2, working_directory='.')
        print(res)
