from utils import print_memory_usage, tovec
from data import mnist, fashion_mnist, emnist_letters, cifar10, omniglot
from net import StackedInvariantSubspaceModule, chunked_inference
from classifiers import NearestNeighborClassifier, LogisticRegressionCV, Evaluator, OmniglotEvaluator

from scipy.optimize import linear_sum_assignment
from sklearn.cluster import KMeans

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import optuna

import numpy as np
import argparse
import os
import json

parser = argparse.ArgumentParser(description='Hyperparameter optimization')

# experiment parameters
parser.add_argument('--name', type=str, required=True, help='name of experiment')
parser.add_argument('--gpu', type=int, default=0, help='which gpu')
parser.add_argument('--n_trials', type=int, default=100, help='number of hparams trials')
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--create_study', action='store_true', help='create a database to store trials')
parser.add_argument('--verbose', type=int, default=0, help='verbose')

# data parameters
parser.add_argument('--data', type=str, default='mnist', choices=['mnist','fashion_mnist','emnist_letters','cifar10','omniglot'])

# network parameters
parser.add_argument('--n_layers', type=int, default=1, help='number of layers')
parser.add_argument('--input_zca', action='store_true', help='apply zca to the inputs')
parser.add_argument('--output_pool', type=int, default=-1, help='avg pool to the specified output size (-1) for no pooling')
parser.add_argument('--preserve_size', action='store_true', help='preserve network size at each layer')

# per-layer list options accept either one value (used for every layer) or one value per layer
parser.add_argument('--layer_pool', nargs="+", type=str, default=['none'], choices=['none','avg','max'], help='pool after every layer')
parser.add_argument('--layer_standardize', action='store_true',  help='standardize feature maps after every layer')
parser.add_argument('--layer_zca', action='store_true', help='zca after every layer')

# subspace parameters
parser.add_argument('--n_epochs', type=int, default=1, help='number of training epochs for each layer')
parser.add_argument('--batch_size', type=int, default=128, help='batch size for fitting and inference')
parser.add_argument('--chunks_per_update', type=int, default=1, help='number of batches per update')
parser.add_argument('--tie_parameters', action='store_true', help='use same params in every layer')

parser.add_argument('--rescale_features', action='store_true', help='rescale subspace feature maps')
parser.add_argument('--signed_features_if_1D', action='store_true', help='use signed features (only applicable if subspace_dim=1)')
parser.add_argument('--random_features', action='store_true', help='use random features')

parser.add_argument('--max_n_subspaces', nargs="+", type=int, default=[64], help='max n_subspaces for each layer')
parser.add_argument('--max_subspace_dim', nargs="+", type=int, default=[32], help='max subspace_dim for each layer')
parser.add_argument('--max_kernel_size', nargs="+", type=int, default=[28], help='max kernel_size for each layer')
parser.add_argument('--min_n_subspaces', nargs="+", type=int, default=[2], help='min n_subspaces for each layer')

# evaluation parameters
parser.add_argument('--n_labels', type=int, default=-1, help='labelled examples per class used for scoring (-1 for all)')
parser.add_argument('--n_classes', type=int, default=10, help='number of classes (and of k-means clusters)')
parser.add_argument('--label_seed', type=int, default=0, help='which disjoint block of n_labels examples per class to score on')


# evaluation
def permutation_invariant_error(p, y, n_classes):
    """Permutation-invariant error: 1 - accuracy under the best cluster-to-class matching.

    Predicted cluster ids in ``p`` are matched one-to-one to the labels in ``y``
    with the Hungarian algorithm so that the number of agreeing points is maximal.

    Args:
        p: tensor of predicted cluster ids (flattened to 1-D). Values outside
            ``0..n_classes-1`` never count as a match.
        y: 1-D tensor of true labels; must contain exactly the values
            ``0..n_classes-1``.
        n_classes: number of classes / clusters.

    Returns:
        ``(err, P)``. ``err`` is a 0-dim float tensor holding the fraction of points
        misassigned under the best matching. ``P`` is a numpy array where ``P[i]`` is
        the cluster id matched to class ``i``.

    Raises:
        AssertionError: if ``y`` does not contain exactly the labels ``0..n_classes-1``.
    """
    # checks: y must use exactly the labels 0..n_classes-1
    unq = y.unique()
    for c in range(len(unq)):
        assert c in unq
    assert len(unq) == n_classes

    p = p.reshape(-1).long()
    y = y.reshape(-1).long().to(p.device)
    n_points = p.shape[0]

    # counts[i, j] = number of points with label i that are predicted as j.
    # Out-of-range predictions are masked out, so they mismatch every j.
    valid = (p >= 0) & (p < n_classes)
    counts = torch.bincount(y[valid] * n_classes + p[valid], minlength=n_classes * n_classes)
    counts = counts.reshape(n_classes, n_classes)
    class_sizes = torch.bincount(y, minlength=n_classes)

    # C[i, j] = fraction of all points that have label i but are NOT predicted as j
    C = (class_sizes[:, None] - counts).float() / n_points

    # cheapest one-to-one matching of classes (rows) to clusters (columns)
    P = linear_sum_assignment(C.detach().cpu().numpy())[1]

    # summing the matched costs gives the overall error rate (1 - accuracy)
    rows = torch.arange(n_classes, device=C.device)
    cols = torch.as_tensor(P, device=C.device)
    return C[rows, cols].sum(), P

def sample_labels(labels, nlabels, n_classes, label_seed):
    """Select the indices of the labelled examples used for scoring.

    Args:
        labels: 1-D tensor of integer class labels.
        nlabels: number of examples to take per class, or -1 to take every example.
        n_classes: classes ``0..n_classes-1`` to sample from.
        label_seed: which block of ``nlabels`` examples to take from each class.
            Block ``label_seed`` covers positions ``[label_seed*nlabels,
            (label_seed+1)*nlabels)`` among that class's examples in dataset order,
            so different seeds select disjoint sets.

    Returns:
        1-D int64 tensor of indices into ``labels``, grouped by class
        (all of ``arange(len(labels))`` when ``nlabels == -1``).

    Raises:
        ValueError: if ``label_seed`` is negative or a class has too few examples
            for the requested block.
    """
    if nlabels == -1:
        return torch.arange(len(labels))

    start, stop = label_seed * nlabels, (label_seed + 1) * nlabels
    ixs = torch.zeros(nlabels * n_classes, dtype=torch.long)
    for c in range(n_classes):
        class_ixs = torch.where(labels == c)[0]
        if start < 0 or len(class_ixs) < stop:
            raise ValueError(
                'class {} has {} examples; cannot take examples [{}, {}) '
                '(nlabels={}, label_seed={})'.format(
                    c, len(class_ixs), start, stop, nlabels, label_seed))
        ixs[c * nlabels:(c + 1) * nlabels] = class_ixs[start:stop]
    return ixs

class Objective:
    """Optuna objective that fits a StackedInvariantSubspaceModule and scores it by clustering.

    Each call does the following:

    - samples hyperparameters from an Optuna ``trial``;
    - fits the network on all examples in ``self.x`` (labels are not used for fitting);
    - clusters the network output into ``args.n_classes`` groups with k-means;
    - returns the permutation-invariant clustering error on the examples selected
      by ``sample_labels`` (lower is better).
    """

    # If True, the last layer is replaced by one with ``n_classes`` 1-D subspaces and the
    # argmax over its outputs is used as the cluster assignment instead of k-means.
    # Kept on the class so that ``suggest_params`` and ``__call__`` always agree.
    cluster_final = False

    def __init__(self, args):
        """Load the data and prepare the per-layer search-space bounds.

        Args:
            args: parsed command-line arguments (see the module-level ``parser``).
                Per-layer list options given with a single value are broadcast to
                ``args.n_layers`` entries; this mutates ``args`` in place.

        Raises:
            ValueError: if ``args.data`` is not a supported dataset.
        """
        for name in ('max_n_subspaces', 'max_subspace_dim', 'max_kernel_size',
                     'min_n_subspaces', 'layer_pool'):
            values = getattr(args, name)
            if len(values) == 1:
                setattr(args, name, values * args.n_layers)
        self.args = args

        # dataset name -> (loader, number of leading examples to keep).
        # The loaders are wrapped in lambdas so that only the requested dataset is loaded.
        datasets = {
            'mnist': (lambda: mnist(), 60000),
            'fashion_mnist': (lambda: fashion_mnist(), 60000),
            'emnist_letters': (lambda: emnist_letters(), 124800),
            'cifar10': (lambda: cifar10(color=True), 50000),
            'omniglot': (lambda: omniglot(affine_correct=True), 19280),
        }
        if args.data not in datasets:
            raise ValueError('unsupported dataset: {}'.format(args.data))
        load, n_keep = datasets[args.data]
        x, labels = load()[:2]
        self.x = x[:n_keep]
        self.labels = labels[:n_keep]

        # indices of the examples whose labels are used for scoring
        self.label_indices = sample_labels(self.labels, args.n_labels, args.n_classes, args.label_seed)

    def __call__(self, trial):
        """Fit one sampled network and return its clustering error (a float, lower is better).

        Side effects: sets ``self.net``, which is moved to the CPU and carries the
        fitted k-means as ``self.net.km`` (``None`` when ``cluster_final``). Also sets
        ``self.scores`` to ``[1 - err]``.
        """
        args = self.args
        params = self.suggest_params(trial)

        # init
        self.net = StackedInvariantSubspaceModule(params, verbose=args.verbose)
        if args.verbose:
            print_memory_usage('init network complete')
            print(self.net)

        # fit on all examples (unsupervised)
        loader = DataLoader(dataset=self.x, batch_size=args.batch_size)
        self.net.fit(loader)
        if args.verbose:
            print_memory_usage('fit network complete')
            print(loader)

        # cluster
        km = None
        if self.cluster_final:
            y = chunked_inference(self.net, self.x[self.label_indices],
                                  chunk_size=args.batch_size, output_device='cpu')
            assert y.shape[2] == y.shape[3] == 1
            assert y.shape[1] == args.n_classes
            y = y.max(dim=1)[1].reshape(-1)
        else:
            y = chunked_inference(self.net, self.x, chunk_size=args.batch_size, output_device='cpu')
            # seeded so that repeated runs of the same trial cluster identically
            km = KMeans(args.n_classes, random_state=args.seed)
            km.fit(tovec(y).cpu().numpy())
            y = torch.from_numpy(km.labels_)[self.label_indices]

        if args.verbose:
            print_memory_usage('inference complete')
            print('y.shape: {}'.format(y.shape))

        # keep the clustering with the network, and get everything off the gpu
        self.net.km = km
        self.net = self.net.cpu()

        # evaluate against the true number of classes (not a hard-coded 10)
        err = float(permutation_invariant_error(y, self.labels[self.label_indices], args.n_classes)[0])
        self.scores = [1 - err]

        if args.verbose:
            print(self.label_indices)
            print_memory_usage('scoring complete')
        return err

    def suggest_params(self, trial):
        """Sample one network configuration from ``trial``.

        Returns:
            dict mapping module names to their settings, in this insertion order:
            optional 'input_norm', optional 'output_pool', then for each layer k,
            'layer<k>' followed by optional 'layer<k>_pool', 'layer<k>_standardize'
            and 'layer<k>_zca'.
        """
        args = self.args
        params = {}

        def output_shape(params, layer=0):
            """Return [channels, height, width] after applying layers 1..``layer``.

            Only 'layer<k>' entries with k <= ``layer`` change the shape (stride 1 is
            assumed). '_pool', '_zca' and '_standardize' entries are ignored.
            """
            shape = np.array(self.x.shape[1:])
            for key, val in params.items():
                if not key.startswith('layer') or key.endswith(('_pool', '_zca', '_standardize')):
                    continue
                # parse the whole number so that 'layer12' is layer 12, not layer 1
                if int(key[len('layer'):]) <= layer:
                    shape[0] = val['n_subspaces']
                    shape[[1, 2]] -= val['kernel_size'] - 1 - 2 * val['padding']
            return [int(s) for s in shape]

        # input norm
        if args.input_zca:
            kernel_size = trial.suggest_int('input_norm-kernel_size', 1, 11, step=2)
            params['input_norm'] = {
                'kernel_size': kernel_size,
                'n_components': trial.suggest_int('input_norm-n_components', 0, kernel_size**2 - 1),
            }

        # output pool
        if args.output_pool != -1:
            params['output_pool'] = {'output_size': args.output_pool}

        # layers
        for i in range(args.n_layers):
            layer = 'layer{}'.format(i + 1)

            # training params
            params[layer] = {
                'warmup_iter': 10,
                'chunks_per_update': args.chunks_per_update,
            }

            # kernel size, padding
            max_kernel_size = min(args.max_kernel_size[i], output_shape(params, layer=i)[1])
            if args.tie_parameters:
                k = trial.suggest_int(layer + '-kernel_size', 3, max_kernel_size, step=2)
                # Lower bound on the padding, computed from the real input size.
                # NOTE(review): presumably meant to keep n_layers tied layers from shrinking
                # the feature map to nothing — confirm the formula.
                n = args.n_layers
                input_size = output_shape(params, layer=0)[1]
                min_pad_size = max(0, int(np.ceil((k + (n - 1) * (k - 1) - input_size) / max(1, 2 * (n - 1)))))
                params[layer]['kernel_size'] = k
                params[layer]['padding'] = trial.suggest_int(layer + '-padding', min_pad_size, k // 2)
            elif args.preserve_size:
                max_kernel_size -= (max_kernel_size + 1) % 2  # largest odd size <= max
                k = trial.suggest_int(layer + '-kernel_size', 3, max_kernel_size, step=2)
                params[layer]['kernel_size'] = k
                params[layer]['padding'] = k // 2
            else:
                k = trial.suggest_int(layer + '-kernel_size', 1, max_kernel_size)
                params[layer]['kernel_size'] = k
                params[layer]['padding'] = trial.suggest_int(layer + '-padding', 0, k // 2)

            # n_subspaces, subspace_dim, p; subspace_dim is capped by the patch dimension
            max_subspace_dim = min(args.max_subspace_dim[i],
                                   output_shape(params, layer=i)[0] * params[layer]['kernel_size']**2)
            params[layer]['subspace_dim'] = trial.suggest_int(layer + '-subspace_dim', 1, max_subspace_dim)
            params[layer]['n_subspaces'] = trial.suggest_int(
                layer + '-n_subspaces', args.min_n_subspaces[i], args.max_n_subspaces[i])
            params[layer]['p'] = trial.suggest_float(layer + '-p', 0.0, 1.0)

            # fixed params
            params[layer].update({
                'stride': 1,
                'n_epochs': args.n_epochs,
                'rescale_features': args.rescale_features,
                'signed_features': args.signed_features_if_1D and (args.max_subspace_dim[i] == 1),
                'verbose': args.verbose,
                'random_features': args.random_features,
            })

            # pool (odd kernel, stride 1, 'same' padding)
            if args.layer_pool[i] != 'none':
                max_pool_size = min(11, output_shape(params, layer=i + 1)[1])
                max_pool_size -= (max_pool_size + 1) % 2  # subtract 1 to make odd if needed
                pool_size = trial.suggest_int(layer + '-pool', 1, max_pool_size, step=2)
                params[layer + '_pool'] = {
                    'pool_type': args.layer_pool[i],
                    'kernel_size': pool_size,
                    'stride': 1,
                    'padding': pool_size // 2,
                }

            # standardize
            if args.layer_standardize:
                params[layer + '_standardize'] = {'center': True}

            # zca (a plain dict; the old double braces built a set and raised TypeError)
            if args.layer_zca:
                # NOTE(review): the bound uses the channel count *before* this layer
                # (layer=i), squared; a ZCA applied after the layer would see
                # output_shape(params, layer=i + 1)[0] channels — confirm intent.
                params[layer + '_zca'] = {
                    'kernel_size': 1,
                    'n_components': trial.suggest_int(
                        layer + '-zca-n_components', 0, output_shape(params, layer=i)[0]**2),
                }

            if args.tie_parameters:
                # only layer1 is sampled; every other layer shares its settings
                for j in range(2, args.n_layers + 1):
                    tied = 'layer{}'.format(j)
                    params[tied] = params['layer1']
                    for suffix in ('_pool', '_standardize', '_zca'):
                        if 'layer1' + suffix in params:
                            params[tied + suffix] = params['layer1' + suffix]
                break

        # final layer: n_classes 1-D subspaces whose kernel spans the whole feature map
        if self.cluster_final:
            i = args.n_layers - 1
            layer = 'layer{}'.format(i + 1)
            params[layer] = {}
            params[layer].update({
                'warmup_iter': 10,
                'chunks_per_update': args.chunks_per_update,
                'kernel_size': output_shape(params, layer=i)[1],
                'padding': 0,
                'subspace_dim': 1,
                'n_subspaces': args.n_classes,
                'p': trial.suggest_float(layer + '-p', 0.0, 1.0),
                'stride': 1,
                'n_epochs': args.n_epochs,
                'rescale_features': args.rescale_features,
                'signed_features': args.signed_features_if_1D and (args.max_subspace_dim[i] == 1),
                'verbose': args.verbose,
                'random_features': args.random_features,
            })

        return params

    def save_best_model(self, study, trial):
        """Optuna callback: save the network and scores when ``trial`` is the best trial so far.

        Writes 'bestnet.pt' (the whole network object, including the k-means stored on
        it by ``__call__``) and 'bestscores.txt' to the working directory.
        """
        try:
            best_number = study.best_trial.number
        except ValueError:
            # no completed trial yet (e.g. every trial so far failed)
            return
        # compare trial numbers rather than FrozenTrial objects
        if best_number == trial.number:
            torch.save(self.net, 'bestnet.pt')
            np.savetxt('bestscores.txt', self.scores)
    

if __name__ == '__main__':
    # Note: `args` stays a module-level name on purpose (the script runs at top level).
    args = parser.parse_args()
    if args.verbose:
        print('running experiment, args: {}'.format(args))

    # record the exact configuration of this run
    with open('args.json', 'w') as args_file:
        json.dump(vars(args), args_file)

    # CUDA settings.
    # NOTE(review): these only take effect if CUDA has not been initialised yet
    # (e.g. by an import) — confirm.
    os.environ.update({
        'CUDA_VISIBLE_DEVICES': str(args.gpu),
        'PYTORCH_CUDA_ALLOC_CONF': 'max_split_size_mb:128',
        'CUDA_LAUNCH_BLOCKING': '1',
    })

    objective = Objective(args)

    # one sqlite database per experiment name: create it once with --create_study, then reuse it
    storage_url = 'sqlite:///{}.db'.format(args.name)
    if not args.create_study:
        study = optuna.load_study(study_name=args.name, storage=storage_url)
    else:
        study = optuna.create_study(study_name=args.name, storage=storage_url, direction='minimize')

    # run; the callback saves the network of the best trial seen so far
    study.optimize(objective, n_trials=args.n_trials, callbacks=[objective.save_best_model])
