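'''Fine-tunes a new linear head on a selected subset of the features feeding ``model.model.linear``, with the rest of the model frozen.'''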
import copy
import os
import sys
import argparse
import random
import time
import torch
import torch.nn as nn
import logging
import numpy as np
from tqdm import tqdm

from utils.model import load_model
from utils.eval_model import eval_model
from utils.data import SpecialCIFAR10

from robustness import defaults
from robustness.main import setup_args
from robustness.datasets import DATASETS
from robustness.train import train_model
from robustness.tools import helpers

def get_args():
    """Build the command-line parser and parse ``sys.argv``.

    Adds this script's own options, then the robustness CONFIG, MODEL_LOADER,
    TRAINING and PGD argument groups.

    Returns:
        argparse.Namespace with all parsed options.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--seed', default=0, type=int, help='Seed')
    parser.add_argument('--bottom', default=False, action='store_true',
        help='With --neurons_path, keep the lowest-scoring neurons instead of the highest-scoring ones.')
    parser.add_argument('--neuron', default=None, type=int, nargs='+',
        help='Indexes of the neurons to KEEP (takes precedence over --neurons_path)')
    parser.add_argument('--nb_neurons', default=50, type=int, help='Number of neurons to consider KEEPING')
    parser.add_argument('--neurons_path', type=str, default=None,
        help='Path to saved per-class neuron scores; the top (or --bottom) --nb_neurons neurons are kept.')
    parser.add_argument('--num_classes', default=None, type=int, help='Number of classes')
    parser.add_argument('--perturbed_imgs_path', type=str, default=None,
        help='Path to perturbed images.')
    parser.add_argument('--load_dataset', default=None, type=str)

    # Shared option groups from the robustness library.
    for arg_group in (defaults.CONFIG_ARGS, defaults.MODEL_LOADER_ARGS,
                      defaults.TRAINING_ARGS, defaults.PGD_ARGS):
        parser = defaults.add_args_to_parser(arg_group, parser)
    return parser.parse_args()

def create_logger(args, path):
    """Configure the root logger to write DEBUG+ records to ``<path>/log.txt`` and stdout.

    ``args`` is accepted for call-site compatibility and is not used.
    Repeated calls do not attach a second handler for the same log file or
    for stdout, so messages are not emitted twice.

    Returns:
        The root ``logging.Logger``.
    """
    logger = logging.getLogger('')
    logger.setLevel(logging.DEBUG)

    # FileHandler stores an absolute path in ``baseFilename``; compare against that.
    log_file = os.path.abspath(os.path.join(path, 'log.txt'))
    has_file_handler = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == log_file
        for h in logger.handlers
    )
    if not has_file_handler:
        fh = logging.FileHandler(log_file)
        fh.setLevel(logging.DEBUG)
        logger.addHandler(fh)

    # Exact type check: FileHandler (and other subclasses) are also StreamHandlers.
    has_stdout_handler = any(
        type(h) is logging.StreamHandler and h.stream is sys.stdout
        for h in logger.handlers
    )
    if not has_stdout_handler:
        handler = logging.StreamHandler(sys.stdout)
        handler.setLevel(logging.DEBUG)
        logger.addHandler(handler)

    logger.propagate = False
    return logger

def disable_gradients(model):
    """Freeze ``model`` in place and return it.

    Every parameter stops requiring gradients, and every submodule is switched
    to eval mode. Eval mode keeps BatchNorm running statistics from updating
    while a new head is fine-tuned on top.
    """
    model.requires_grad_(False)
    model.eval()  # recursive: sets training=False on all submodules
    return model

def _select_neurons(args, in_features):
    """Return the indexes of the ``model.model.linear`` input features to KEEP.

    The first matching rule wins:
      1. ``--neuron``: the explicit list given on the command line.
      2. ``--neurons_path``: a ``torch.load``-able mapping indexed by a class
         key, where the key is the last ``_``-separated token of the
         ``--out_dir`` basename. Keeps the ``--nb_neurons`` highest-scoring
         entries, or the lowest-scoring ones with ``--bottom``.
      3. Otherwise: ``--nb_neurons`` distinct indexes drawn from
         ``[0, in_features)``.

    Returns:
        A contiguous 1-D ``np.int64`` array.
    """
    if args.neuron is not None:
        # argparse (nargs='+') yields a plain list; downstream code needs an ndarray.
        return np.asarray(args.neuron, dtype=np.int64)

    if args.neurons_path is not None:
        # NOTE(review): torch.load unpickles the file -- only load trusted paths.
        scores_by_class = torch.load(args.neurons_path)
        # normpath drops a trailing separator that would otherwise yield ''.
        neuron_class = os.path.basename(os.path.normpath(args.out_dir)).split('_')[-1]
        order = np.array(scores_by_class[neuron_class]).argsort()
        if args.bottom:  # bottom -- least percentage
            print('bottom')
            selected = order[:args.nb_neurons]
        else:  # top -- highest percentage
            print('top')
            selected = order[::-1][:args.nb_neurons]
        print(neuron_class, selected)
        # The reversed slice has a negative stride, which torch.from_numpy rejects.
        return np.ascontiguousarray(selected, dtype=np.int64)

    print('Ranomdly selecting some neurons')
    # Distinct indexes over the full feature width: every feature is eligible
    # and none is counted twice.
    selected = np.random.choice(in_features, args.nb_neurons, replace=False).astype(np.int64)
    print(selected, selected.shape)
    return selected


def _attach_masked_linear(model, keep_indexes, num_classes):
    """Replace ``model.model.linear`` with a fresh ``nn.Linear`` that sees only ``keep_indexes``.

    A forward pre-hook selects the kept columns (dim 1) of the layer's input.
    The new head therefore has ``len(keep_indexes)`` inputs and ``num_classes``
    outputs. It is created on the device of the model's existing parameters.

    Returns:
        The new head module.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    head = nn.Linear(len(keep_indexes), num_classes).to(device)
    index = torch.from_numpy(keep_indexes)

    def mask_pre_hook(module, inputs):
        # Returning a tuple replaces the positional inputs passed to the layer.
        return (torch.index_select(inputs[0], dim=1, index=index.to(inputs[0].device)),)

    head.register_forward_pre_hook(mask_pre_hook)
    model.model.linear = head
    return head


def main():
    """Train a new linear head on a subset of a frozen model's features.

    The steps are:
      1. Parse the arguments and seed the RNGs.
      2. Create ``--out_dir`` and its log file.
      3. Build non-shuffled, non-augmented loaders.
      4. Load and freeze the model, then pick the neurons to keep.
      5. Swap in a masked ``nn.Linear`` head.
      6. Hand off to ``robustness.train.train_model``.
    """
    args = get_args()
    args = setup_args(args)

    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    path = args.out_dir
    print(path)
    os.makedirs(path, exist_ok=True)
    create_logger(args, path)

    dataset = DATASETS[args.dataset](data_path=args.data)
    train_loader, val_loader = dataset.make_loaders(args.workers, args.batch_size,
        shuffle_train=False, shuffle_val=False, data_aug=False, val_batch_size=args.batch_size)
    loaders = (helpers.DataPrefetcher(train_loader), helpers.DataPrefetcher(val_loader))

    # Keep the second return value so it can be passed to train_model below.
    # NOTE(review): assumes load_model returns (model, checkpoint) -- confirm in utils.model.
    model, checkpoint = load_model(
        args, args.arch, args.resume, dataset, testloader=None)
    if hasattr(model, 'module'):  # presumably a DataParallel wrapper
        model = model.module

    model = disable_gradients(model)
    model.eval()

    in_features = model.model.linear.in_features
    neuron = _select_neurons(args, in_features)
    if neuron.size == 0:
        raise ValueError('No neurons selected to keep.')
    if neuron.min() < 0 or neuron.max() >= in_features:
        raise ValueError(
            f'Neuron indexes must lie in [0, {in_features}); '
            f'got range {neuron.min()}..{neuron.max()}.')

    num_classes = args.num_classes
    if num_classes is None:
        # Fall back to the class count stored on the robustness dataset object.
        num_classes = getattr(dataset, 'num_classes', None)
    if num_classes is None:
        raise ValueError('--num_classes must be given for this dataset.')

    _attach_masked_linear(model, neuron, num_classes)

    if not args.resume_optimizer:
        checkpoint = None
    train_model(args, model, loaders, store=None, checkpoint=checkpoint)


if __name__ == '__main__':
    # Script entry point: run only when executed directly, not on import.
    main()
