from utils import print_memory_usage, tovec
from data import mnist, fashion_mnist, emnist_letters, cifar10, omniglot
from net import StackedInvariantSubspaceModule, chunked_inference
from classifiers import NearestNeighborClassifier, LogisticRegressionCV, Evaluator, OmniglotEvaluator

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import optuna

import numpy as np
import argparse
import os
import json

# Command-line interface for the hyperparameter search.
parser = argparse.ArgumentParser(
    description='Hyperparameter optimization',
)

# --- experiment -------------------------------------------------------------
parser.add_argument(
    '--name',
    required=True,
    type=str,
    help='name of experiment',
)
parser.add_argument(
    '--gpu',
    default=0,
    type=int,
    help='which gpu',
)
parser.add_argument(
    '--n_trials',
    default=100,
    type=int,
    help='number of hparams trials',
)
parser.add_argument(
    '--seed',
    default=0,
    type=int,
)
parser.add_argument(
    '--create_study',
    action='store_true',
    help='create a database to store trials',
)
parser.add_argument(
    '--verbose',
    default=0,
    type=int,
    help='verbose',
)

# --- data -------------------------------------------------------------------
parser.add_argument(
    '--data',
    default='mnist',
    type=str,
    choices=['mnist', 'fashion_mnist', 'emnist_letters', 'cifar10', 'omniglot'],
)

# --- network ----------------------------------------------------------------
parser.add_argument(
    '--n_layers',
    default=1,
    type=int,
    help='number of layers',
)
parser.add_argument(
    '--input_zca',
    action='store_true',
    help='apply zca to the inputs',
)
parser.add_argument(
    '--output_pool',
    default=-1,
    type=int,
    help='avg pool to the specified output size (-1) for no pooling',
)
parser.add_argument(
    '--preserve_size',
    action='store_true',
    help='preserve network size at each layer',
)

# Per-layer post-processing; a single --layer_pool value is later reused for
# every layer by Objective.__init__.
parser.add_argument(
    '--layer_pool',
    default=['none'],
    type=str,
    nargs="+",
    choices=['none', 'avg', 'max'],
    help='pool after every layer',
)
parser.add_argument(
    '--layer_standardize',
    action='store_true',
    help='standardize feature maps after every layer',
)
parser.add_argument(
    '--layer_zca',
    action='store_true',
    help='zca after every layer',
)

# --- subspace training ------------------------------------------------------
parser.add_argument(
    '--n_epochs',
    default=1,
    type=int,
)
parser.add_argument(
    '--batch_size',
    default=128,
    type=int,
)
parser.add_argument(
    '--chunks_per_update',
    default=1,
    type=int,
    help='number of batches per update',
)
parser.add_argument(
    '--tie_parameters',
    action='store_true',
    help='use same params in every layer',
)

parser.add_argument(
    '--rescale_features',
    action='store_true',
    help='rescale subspace feature maps',
)
parser.add_argument(
    '--signed_features_if_1D',
    action='store_true',
    help='use signed features (only applicable if subspace_dim=1)',
)
parser.add_argument(
    '--random_features',
    action='store_true',
    help='use random features',
)

# Search-space bounds; one value per layer, or a single value reused for all
# layers by Objective.__init__.
parser.add_argument(
    '--max_n_subspaces',
    default=[64],
    type=int,
    nargs="+",
    help='max n_subspaces for each layer',
)
parser.add_argument(
    '--max_subspace_dim',
    default=[32],
    type=int,
    nargs="+",
    help='max subspaces_dim for each layer',
)
parser.add_argument(
    '--max_kernel_size',
    default=[28],
    type=int,
    nargs="+",
    help='max kernel_size for each layer',
)
parser.add_argument(
    '--min_n_subspaces',
    default=[2],
    type=int,
    nargs="+",
    help='min n_subspaces for each layer',
)

# --- evaluation -------------------------------------------------------------
# A value of -1 in --ks is translated to 'ALL' by Objective.__init__.
parser.add_argument(
    '--ks',
    default=[10, 30, 100, 300, 1000, 3000],
    type=int,
    nargs="+",
)
parser.add_argument(
    '--nevals',
    default=[1],
    type=int,
    nargs="+",
)
parser.add_argument(
    '--clf',
    default='logistic_regression',
    type=str,
    choices=['logistic_regression', 'nearest_neighbor', 'omniglot'],
)

class Objective:
    """Optuna objective: sample a network configuration, fit it, and score it.

    Calling an instance with an Optuna trial samples hyperparameters via
    `suggest_params`, builds and fits a `StackedInvariantSubspaceModule` on the
    loaded data, evaluates the resulting features with the classifier selected
    by `args.clf`, and returns the error (1 - score) to be minimized.

    After each call the fitted network is kept in `self.net` and the scores in
    `self.scores` so that `save_best_model` can persist them.
    """

    def __init__(self, args):
        """Normalize the parsed arguments and load the selected dataset.

        Args:
            args: parsed command-line namespace. It is modified in place:
                single-valued per-layer options are repeated `n_layers` times
                and any -1 entry of `ks` is replaced by 'ALL'.

        Raises:
            ValueError: if `args.data` does not name a supported dataset.
        """
        # Broadcast single-valued per-layer options to one entry per layer.
        for option in ('max_n_subspaces', 'max_subspace_dim', 'max_kernel_size',
                       'min_n_subspaces', 'layer_pool'):
            values = getattr(args, option)
            if len(values) == 1:
                setattr(args, option, values * args.n_layers)

        args.ks = [k if k != -1 else 'ALL' for k in args.ks]

        self.args = args

        # data
        self.x, self.labels, self.train_idx, self.test_idx = self._load_data(args.data)

    @staticmethod
    def _load_data(name):
        """Load dataset `name` and return (x, labels, train_idx, test_idx).

        The first `n_train` samples form the training split and the remaining
        samples up to `n_total` form the split used for scoring.
        """
        # name -> (loader, loader kwargs, n_total, n_train)
        datasets = {
            'mnist': (mnist, {}, 60000, 50000),
            'fashion_mnist': (fashion_mnist, {}, 60000, 50000),
            'emnist_letters': (emnist_letters, {}, 124800, 110000),
            'cifar10': (cifar10, {'color': True}, 50000, 40000),
            'omniglot': (omniglot, {'affine_correct': True}, 19280, 15000),
        }
        if name not in datasets:
            raise ValueError('unknown dataset: {!r}'.format(name))

        loader, kwargs, n_total, n_train = datasets[name]
        x, labels = loader(**kwargs)[:2]
        x, labels = x[:n_total], labels[:n_total]
        return x, labels, torch.arange(n_train), torch.arange(n_train, n_total)

    def __call__(self, trial):
        """Run one trial and return its error (1 - mean score).

        Args:
            trial: Optuna trial used to sample the hyperparameters.

        Returns:
            The error to minimize.

        Raises:
            ValueError: if `args.clf` does not name a supported classifier.
        """
        verbose = self.args.verbose

        # get params
        params = self.suggest_params(trial)

        # init (use the stored args rather than a module-level global so the
        # class also works when imported from another module)
        self.net = StackedInvariantSubspaceModule(params, verbose=verbose)

        if verbose:
            print_memory_usage('init network complete')
            print(self.net)

        # fit (on all loaded samples, i.e. both index splits)
        loader = DataLoader(dataset=self.x, batch_size=self.args.batch_size)
        self.net.fit(loader)

        if verbose:
            print_memory_usage('fit network complete')
            print(loader)

        # inference
        y = chunked_inference(self.net, self.x, chunk_size=self.args.batch_size, output_device='cpu')
        # Center along dim 0, then scale by the std over all entries.
        y -= y.mean(0)
        y /= y.std()

        if verbose:
            print_memory_usage('inference complete')
            print('y.shape: {}'.format(y.shape))

        # move to cpu to get everything off the gpu
        self.net = self.net.cpu()

        # evaluate (fall back to the CPU when no CUDA device is available)
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if self.args.clf in ('nearest_neighbor', 'logistic_regression'):
            evaluator = Evaluator(clf=self.args.clf, nlabels_per_class=self.args.ks,
                                  ntrials_per_config=self.args.nevals, verbose=verbose)
            scores = evaluator.evaluate(
                y[self.train_idx].to(device), self.labels[self.train_idx].to(device),
                y[self.test_idx].to(device), self.labels[self.test_idx].to(device))[0]
            self.scores = scores
            err = 1 - np.mean(scores)

        elif self.args.clf == 'omniglot':
            evaluator = OmniglotEvaluator(n_way=20, n_shot=1, n_trials=1000, within_class=False)
            # NOTE(review): unlike the branch above, the labels are not moved
            # to `device` here - confirm OmniglotEvaluator expects that.
            score = evaluator.evaluate(y[self.test_idx].to(device), self.labels[self.test_idx])[0]
            self.scores = [score]
            err = 1 - score

        else:
            raise ValueError('unknown classifier: {!r}'.format(self.args.clf))

        if verbose:
            print_memory_usage('scoring complete')
        return err

    def suggest_params(self, trial):
        """Sample a network configuration from `trial`.

        Args:
            trial: object providing `suggest_int(name, low, high, step=...)`
                and `suggest_float(name, low, high)`.

        Returns:
            Dict of per-component settings keyed by 'input_norm',
            'output_pool', 'layer<N>', 'layer<N>_pool', 'layer<N>_standardize'
            and 'layer<N>_zca' (only the enabled components are present).
            With `tie_parameters`, every layer shares layer1's dict objects.
        """
        params = {}

        def output_shape(params, layer=0):
            """Shape (sample dims of self.x) after layers 1..`layer`."""
            shape = [int(s) for s in self.x.shape[1:]]
            for key, val in params.items():
                # Only 'layer<N>' entries are subspace layers; suffixed keys
                # ('_pool', '_zca', '_standardize') do not change the shape
                # (pooling uses stride 1 with kernel_size//2 padding).
                if not (key.startswith('layer') and key[5:].isdigit()):
                    continue
                if int(key[5:]) <= layer:
                    shrink = val['kernel_size'] - 1 - 2 * val['padding']
                    shape[0] = val['n_subspaces']
                    shape[1] -= shrink
                    shape[2] -= shrink
            return shape

        # input norm
        if self.args.input_zca:
            zca_kernel = trial.suggest_int('input_norm-kernel_size', 1, 11, step=2)
            params['input_norm'] = {
                'kernel_size': zca_kernel,
                'n_components': trial.suggest_int('input_norm-n_components', 0, zca_kernel**2 - 1),
            }

        # output pool
        if self.args.output_pool != -1:
            params['output_pool'] = {'output_size': self.args.output_pool}

        # layers
        for i in range(self.args.n_layers):
            layer = 'layer{}'.format(i + 1)
            layer_params = params[layer] = {}

            # training params
            layer_params['warmup_iter'] = 10
            layer_params['chunks_per_update'] = self.args.chunks_per_update

            # kernel size, padding (bounded by the spatial size entering this layer)
            max_kernel_size = min(self.args.max_kernel_size[i], output_shape(params, layer=i)[1])
            if self.args.tie_parameters:
                layer_params['kernel_size'] = trial.suggest_int(layer + '-kernel_size', 3, max_kernel_size, step=2)
                k = layer_params['kernel_size']
                n_extra = self.args.n_layers - 1
                # Lower bound on the padding so that the same kernel can be
                # applied in every tied layer.
                # NOTE(review): 28 looks like a hard-coded spatial input size
                # (MNIST) - confirm before using tied layers on other data.
                min_pad_size = max(0, int(np.ceil((k + n_extra * (k - 1) - 28) / max(1, 2 * n_extra))))
                max_pad_size = k // 2
                layer_params['padding'] = trial.suggest_int(layer + '-padding', min_pad_size, max_pad_size)

            elif self.args.preserve_size:
                max_kernel_size -= ((max_kernel_size + 1) % 2)  # subtract 1 to make odd if needed
                layer_params['kernel_size'] = trial.suggest_int(layer + '-kernel_size', 3, max_kernel_size, step=2)
                # "same" padding keeps the spatial size unchanged
                layer_params['padding'] = layer_params['kernel_size'] // 2

            else:
                layer_params['kernel_size'] = trial.suggest_int(layer + '-kernel_size', 1, max_kernel_size)
                layer_params['padding'] = trial.suggest_int(layer + '-padding', 0, layer_params['kernel_size'] // 2)

            # n_subspaces, subspace_dim, p
            # subspace_dim is capped by the patch dimension: channels * kernel_size^2
            max_subspace_dim = output_shape(params, layer=i)[0] * layer_params['kernel_size']**2
            max_subspace_dim = min(self.args.max_subspace_dim[i], max_subspace_dim)

            layer_params['subspace_dim'] = trial.suggest_int(layer + '-subspace_dim', 1, max_subspace_dim)
            layer_params['n_subspaces'] = trial.suggest_int(
                layer + '-n_subspaces', self.args.min_n_subspaces[i], self.args.max_n_subspaces[i])
            layer_params['p'] = trial.suggest_float(layer + '-p', 0.0, 1.0)

            # fixed params
            layer_params['stride'] = 1
            layer_params['n_epochs'] = self.args.n_epochs
            layer_params['rescale_features'] = self.args.rescale_features
            layer_params['signed_features'] = self.args.signed_features_if_1D and (self.args.max_subspace_dim[i] == 1)
            layer_params['verbose'] = self.args.verbose
            layer_params['random_features'] = self.args.random_features

            # pool
            if self.args.layer_pool[i] != 'none':
                max_pool_size = min(11, output_shape(params, layer=i + 1)[1])
                max_pool_size -= ((max_pool_size + 1) % 2)  # subtract 1 to make odd if needed
                pool_kernel = trial.suggest_int(layer + '-pool', 1, max_pool_size, step=2)
                params[layer + '_pool'] = {
                    'pool_type': self.args.layer_pool[i],
                    'kernel_size': pool_kernel,
                    'stride': 1,
                    'padding': pool_kernel // 2}

            # standardize
            if self.args.layer_standardize:
                params[layer + '_standardize'] = {'center': True}

            # zca
            if self.args.layer_zca:
                # NOTE(review): the upper bound uses the squared channel count
                # entering this layer - confirm this is the intended range.
                params[layer + '_zca'] = {
                    'kernel_size': 1,
                    'n_components': trial.suggest_int(
                        layer + '-zca-n_components', 0, output_shape(params, layer=i)[0]**2)}

            if self.args.tie_parameters:
                # Only layer1 is sampled; the remaining layers reuse its
                # settings (the same dict objects), then sampling stops.
                for j in range(1, self.args.n_layers):
                    tied = 'layer{}'.format(j + 1)
                    params[tied] = params['layer1']

                    if self.args.layer_pool[0] != 'none':
                        params[tied + '_pool'] = params['layer1_pool']

                    if self.args.layer_standardize:
                        params[tied + '_standardize'] = params['layer1_standardize']

                    if self.args.layer_zca:
                        params[tied + '_zca'] = params['layer1_zca']
                break

        return params

    def save_best_model(self, study, trial):
        """Optuna callback: persist the network and scores of a new best trial.

        Writes 'bestnet.pt' and 'bestscores.txt' to the working directory when
        the trial that just finished is the study's best trial.
        """
        # Trial numbers identify a trial within a study, so comparing them is
        # enough to tell whether the finished trial is the current best.
        if study.best_trial.number == trial.number:
            torch.save(self.net, 'bestnet.pt')
            np.savetxt('bestscores.txt', self.scores)


if __name__ == '__main__':
    # Script entry point: parse the CLI, record the configuration, then run
    # the Optuna search against a per-experiment SQLite database.
    args = parser.parse_args()
    if args.verbose:
        print(f'running experiment, args: {args}')

    # setup logging: dump the arguments as parsed (before Objective rewrites them)
    with open('args.json', 'w') as fp:
        json.dump(vars(args), fp)

    # setup CUDA params
    os.environ.update({
        'CUDA_VISIBLE_DEVICES': str(args.gpu),
        'PYTORCH_CUDA_ALLOC_CONF': 'max_split_size_mb:128',
    })

    # Building the objective loads the dataset.
    objective = Objective(args)

    # create/load study
    storage_url = f'sqlite:///{args.name}.db'
    if not args.create_study:
        study = optuna.load_study(study_name=args.name, storage=storage_url)
    else:
        study = optuna.create_study(
            study_name=args.name,
            storage=storage_url,
            direction='minimize',
        )

    # run; the callback saves the network whenever a trial becomes the best
    study.optimize(
        objective,
        n_trials=args.n_trials,
        callbacks=[objective.save_best_model],
    )
