'''Evaluates a trained model on a test set and saves its predictions, labels and confidences.'''
import os
import copy
import sys
import argparse
import random
import timm
import time
import torch
import torchvision
import torch.nn as nn
import logging
import pandas as pd
import numpy as np
from tqdm import tqdm

from utils.model import load_model, load_simclr_model
from utils.eval_model import eval_model
from utils.data import SpecialCIFAR10, ImageNetV2Dataset

from robustness import defaults
from robustness.main import setup_args
from robustness.datasets import DATASETS
from robustness.data_augmentation import TEST_TRANSFORMS_IMAGENET


def get_args():
    '''Parse the command-line options of this script.

    The script-specific options are registered first, in the order listed
    below; the CONFIG, MODEL_LOADER, TRAINING and PGD argument groups of the
    ``robustness`` package are then appended through
    ``defaults.add_args_to_parser``.

    Returns:
        The namespace returned by ``parser.parse_args()``.
    '''
    # (flag, add_argument keyword arguments), kept in registration order so
    # the generated --help output is unchanged.
    script_options = [
        ('--out-dir-complement', dict(
            default=None, type=str,
            help='Folder name to be added to the out_dir. If None, will be automatically generated.')),

        ('--head_dataset', dict(
            default=None, type=str,
            help='If dataset for training head is not the same as the original dataset that the model was trained on, specify it here')),
        ('--head_data', dict(
            default=None, type=str,
            help='Path to dataset specified in head_dataset arg.')),

        ('--seed', dict(default=0, type=int, help='Seed')),
        ('--num_classes', dict(default=None, type=int, help='Number of classes')),
        ('--perturbed_imgs_path', dict(
            type=str, default=None,
            help='Path to perturbed images.')),
        ('--load_dataset', dict(default=None, type=str)),

        ('--bottom', dict(default=False, action='store_true')),
        ('--neuron', dict(default=None, type=int, help='Neurons to keep')),
        ('--nb_neurons', dict(
            default=50, type=int,
            help='Number of neurons to consider KEEPING')),
        ('--neurons_path', dict(
            type=str, default=None,
            help='Path to list of neurons to be removed.')),

        ('--standard_model', dict(
            default='robustness',
            choices=['robustness', 'simclr', 'timm'], type=str,
            help='If robustness, model was trained using the robustness package. If simclr, model was trained with simclr loss.')),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in script_options:
        parser.add_argument(flag, **options)

    # Argument groups provided by the robustness package.
    package_groups = (
        defaults.CONFIG_ARGS,
        defaults.MODEL_LOADER_ARGS,
        defaults.TRAINING_ARGS,
        defaults.PGD_ARGS,
    )
    for group in package_groups:
        parser = defaults.add_args_to_parser(group, parser)

    return parser.parse_args()

def create_logger(args, path):
    '''Set up the root logger to write to ``<path>/log.txt`` and to stdout.

    Both handlers and the logger itself are set to DEBUG level, and
    propagation is switched off on the returned logger.

    Args:
        args: unused; kept so existing callers keep working.
        path: existing directory in which ``log.txt`` is opened.

    Returns:
        The configured root logger (``logging.getLogger('')``).
    '''
    root = logging.getLogger('')
    root.setLevel(logging.DEBUG)

    # File sink first, console sink second (same attachment order as before).
    sinks = (
        logging.FileHandler(os.path.join(path, 'log.txt')),
        logging.StreamHandler(sys.stdout),
    )
    for sink in sinks:
        sink.setLevel(logging.DEBUG)
        root.addHandler(sink)

    root.propagate = False
    return root


def calculate_accuracy(args, testloader, model, device, num_classes):
    '''Run ``model`` over ``testloader`` and collect predictions and confidences.

    Args:
        args: parsed options; only ``args.standard_model`` is read. When it is
            ``'robustness'`` the model is called as ``model(inp, label=label)``
            and the first of its two return values is used; otherwise it is
            called as ``model(inp)``.
        testloader: iterable of ``(input, label)`` batches that supports
            ``len()`` (used for the progress bar).
        model: callable returning logits, or a tuple whose first item is the
            logits. Logits are indexed along dim 1 as the class dimension.
        device: device the inputs and labels are moved to.
        num_classes: if not None, per-class counters are kept for the class
            ids ``0 .. num_classes - 1``.

    Returns:
        A 7-tuple ``(preds, labels, confidences, max_confidences,
        margin_confidences, accum_per_class, correct_per_class)``. The first
        five are CPU tensors concatenated over all batches along dim 0:
        top-1 predicted class, labels, softmax output, highest softmax value,
        and the gap between the two highest softmax values. The last two are
        dicts mapping class id to the number of samples seen / correctly
        predicted, or ``None`` when ``num_classes`` is None.
    '''
    softmax = nn.Softmax(dim=1)

    accum_per_class = correct_per_class = None
    if num_classes is not None:
        accum_per_class = {class_id: 0 for class_id in range(num_classes)}
        correct_per_class = {class_id: 0 for class_id in range(num_classes)}

    # One tensor per batch; concatenated once at the end instead of creating
    # one Python-level tensor object per sample.
    pred_batches, label_batches = [], []
    confidence_batches, max_confidence_batches, margin_batches = [], [], []

    with torch.no_grad():
        for inp, label in tqdm(testloader, total=len(testloader)):
            inp = inp.to(device)
            label = label.to(device)

            if args.standard_model == 'robustness':
                output, _ = model(inp, label=label)
            else:
                output = model(inp)
            model_logits = output[0] if isinstance(output, tuple) else output

            # Top-1 class index per sample.
            preds = torch.topk(model_logits, 1, dim=1).indices[:, 0].detach().cpu()

            confidence = softmax(model_logits)
            # topk returns values sorted in descending order, so column 0 is
            # the highest confidence and column 1 the runner-up.
            top2_confidences = torch.topk(confidence, 2, dim=1).values
            max_confidences = top2_confidences[:, 0]
            margin_confidence = top2_confidences[:, 0] - top2_confidences[:, 1]

            label = label.detach().cpu()
            pred_batches.append(preds)
            label_batches.append(label)
            confidence_batches.append(confidence.detach().cpu())
            max_confidence_batches.append(max_confidences.detach().cpu())
            margin_batches.append(margin_confidence.detach().cpu())

            if num_classes is not None:
                for item_target, item_pred in zip(label.tolist(), preds.tolist()):
                    item_target = int(item_target)
                    accum_per_class[item_target] += 1
                    if item_target == int(item_pred):
                        correct_per_class[item_target] += 1

    all_preds = torch.cat(pred_batches)
    all_labels = torch.cat(label_batches)
    all_confidences = torch.cat(confidence_batches)
    all_max_confidences = torch.cat(max_confidence_batches)
    all_margin_confidences = torch.cat(margin_batches)
    return (all_preds, all_labels, all_confidences, all_max_confidences,
            all_margin_confidences, accum_per_class, correct_per_class)

def get_accuracy(args, testloader, model, device, path, logger=None):
    '''Evaluate ``model`` on ``testloader``, save the raw results and report accuracy.

    The tensors returned by ``calculate_accuracy`` are written with
    ``torch.save`` to ``preds.pt``, ``labels.pt``, ``confidences.pt``,
    ``max_confidences.pt`` and ``margin_confidences.pt`` inside ``path``.
    The per-class counters returned by ``calculate_accuracy`` are not used
    here.

    Args:
        args: parsed options; ``args.num_classes`` is forwarded to
            ``calculate_accuracy`` together with ``args`` itself.
        testloader: loader of ``(input, label)`` batches.
        model: model to evaluate.
        device: device the batches are moved to.
        path: existing directory receiving the ``.pt`` files.
        logger: if given, the accuracy line is emitted with ``logger.info``
            so it ends up wherever that logger writes; otherwise it is
            printed to stdout.

    Returns:
        The overall accuracy: number of predictions equal to their label
        divided by the number of predictions.
    '''
    (all_preds, all_labels, all_confidences, all_max_confidences,
     all_margin_conf, _accum_per_class, _correct_per_class) = calculate_accuracy(
        args, testloader, model, device, args.num_classes)

    outputs = (
        ('preds.pt', all_preds),
        ('labels.pt', all_labels),
        ('confidences.pt', all_confidences),
        ('max_confidences.pt', all_max_confidences),
        ('margin_confidences.pt', all_margin_conf),
    )
    for filename, tensor in outputs:
        torch.save(tensor, os.path.join(path, filename))

    overall_acc = torch.eq(all_preds, all_labels).sum() / all_preds.shape[0]
    message = f' - Accuracy {overall_acc}'
    # Prefer the logger so the value is recorded alongside the other log
    # lines instead of only appearing on stdout.
    if logger is not None:
        logger.info(message)
    else:
        print(message)
    return overall_acc

def _resolve_output_dir(args):
    '''Return the directory that receives the log file and result tensors.

    With ``--out-dir-complement`` the result is ``out_dir/<complement>``.
    Otherwise a sub-path is derived from ``args.resume``: its second-to-last
    '/'-separated component, prefixed by the two components before it when
    the path contains ``'linear_probe'``, then suffixed with either
    ``args.load_dataset`` or the components of ``args.perturbed_imgs_path``
    from index 3 onwards.
    '''
    if args.out_dir_complement is not None:
        return os.path.join(args.out_dir, args.out_dir_complement)

    resume_parts = args.resume.split('/')
    info_path = resume_parts[-2]
    if 'linear_probe' in args.resume:
        info_path = os.path.join(resume_parts[-4], resume_parts[-3]) + '/' + info_path
    # NOTE(review): both suffixes are appended without a path separator;
    # left as is because existing result directories depend on it.
    if args.load_dataset is not None:
        info_path += args.load_dataset
    elif args.perturbed_imgs_path is not None:
        info_path += '/'.join(args.perturbed_imgs_path.split('/')[3:])
    return os.path.join(args.out_dir, info_path)


def _load_timm_model(args, device):
    '''Create the timm architecture ``args.arch`` and load weights from ``args.resume``.'''
    model = timm.create_model(args.arch, num_classes=args.num_classes, pretrained=False)
    model_dict = model.state_dict()
    checkpoint_state = torch.load(args.resume)['state_dict']
    # Drop the 'model.' substring from checkpoint keys and keep only the
    # entries whose resulting name exists in the freshly created model.
    trained_state_dict = {
        k.replace('model.', ''): v
        for k, v in checkpoint_state.items()
        if k.replace('model.', '') in model_dict
    }
    # Loaded with load_state_dict's default strictness, so a checkpoint that
    # lacks a model parameter raises instead of being evaluated with
    # freshly initialised weights.
    model.load_state_dict(trained_state_dict)
    model.to(device)
    return model


def _select_probe_neurons(args):
    '''Return the numpy array of neuron indices fed to the linear-probe head.'''
    if args.neurons_path is not None:
        neuron = torch.load(args.neurons_path)
        # The key into the loaded object is the text after the last '_' in
        # the checkpoint's parent directory name.
        neuron_class = args.resume.split('/')[-2].split('_')[-1]
        neuron = np.array(neuron[neuron_class])
        if args.bottom: # bottom -- least percentage
            print('bottom')
            neuron = neuron.argsort()[:args.nb_neurons]
        else: # top -- highest percentage
            neuron = neuron.argsort()[::-1][:args.nb_neurons]
        print(neuron_class, neuron)
    else:
        print('Ranomdly selecting some neurons')
        # NOTE(review): the upper bound of randint is exclusive, so this
        # draws from 0..510 (with replacement); confirm 511 matches the
        # width of the features being indexed.
        neuron = np.random.randint(0, 511, args.nb_neurons)
        print(neuron, neuron.shape)
    return neuron


def main():
    '''Evaluate a checkpoint on a test set and store the results.

    Steps, in order: parse arguments, seed the torch / random / numpy RNGs,
    create the output directory and logger, build the default validation
    loader for ``args.dataset``, optionally replace it with perturbed
    images or CIFAR-10.1, load the model selected by ``--standard_model``,
    optionally mask the inputs of the linear-probe head, optionally switch
    to the ``--head_dataset`` loader, and finally call ``get_accuracy``.
    '''
    args = get_args()
    args = setup_args(args)

    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    path = _resolve_output_dir(args)
    print(path)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(path, exist_ok=True)
    logger = create_logger(args, path)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    dataset = DATASETS[args.dataset](data_path=args.data)
    _, testloader = dataset.make_loaders(args.workers, args.batch_size, only_val=True,
        shuffle_train=False, shuffle_val=False, data_aug=False, val_batch_size=args.batch_size)

    if args.perturbed_imgs_path is not None and args.load_dataset is None:
        if any(perturbation in args.perturbed_imgs_path for perturbation in ['noise', 'adv', 'blur']):
            perturbed_images = torch.load(os.path.join(args.perturbed_imgs_path, 'images.pt'))
            perturbed_targets = torch.load(os.path.join(args.perturbed_imgs_path, 'labels.pt'))

            logger.info('==> Perturbed Instances Accuracy')
            testset = SpecialCIFAR10(perturbed_images, perturbed_targets, root='data/')
            testloader = torch.utils.data.DataLoader(
                testset, batch_size=args.batch_size, shuffle=False, num_workers=2)
        else:
            # Previously ignored without notice.
            logger.warning(
                f'perturbed_imgs_path {args.perturbed_imgs_path!r} contains none of '
                'noise/adv/blur; test loader left unchanged')
    elif args.perturbed_imgs_path is not None and args.load_dataset is not None:
        if args.load_dataset == 'cifar10_1':
            print(f'=> Loading {args.load_dataset}')
            perturbed_images = np.load(os.path.join(args.perturbed_imgs_path, 'cifar10.1_v6_data.npy'))
            labels = np.load(os.path.join(args.perturbed_imgs_path, 'cifar10.1_v6_labels.npy'))
            # NOTE(review): `dataset` is rebound here, so this object is
            # also what load_model receives below for robustness models.
            dataset = SpecialCIFAR10(perturbed_images, labels)
            testloader = torch.utils.data.DataLoader(
                dataset, batch_size=args.batch_size, shuffle=False, num_workers=2)
        else:
            # Previously ignored without notice.
            logger.warning(
                f'Unsupported load_dataset {args.load_dataset!r}; test loader left unchanged')

    if args.standard_model == 'robustness': # robustness package
        model, result_model = load_model(
            args, args.arch, args.resume, dataset, testloader=None)
        logger.info(f'Eval model 1: {result_model}')
    elif args.standard_model == 'simclr':
        model = load_simclr_model(args, device)
    elif args.standard_model == 'timm':
        model = _load_timm_model(args, device)

    if 'linear_probe' in args.resume:
        neuron = _select_probe_neurons(args)
        # argsort()[::-1] yields a negatively-strided view; copy() makes it
        # acceptable to torch.from_numpy.
        valid_indexes = torch.from_numpy(neuron.copy())

        def mask_pre_hook(module, inputs):
            # Keep only the selected columns (dim 1) of the head's input.
            features = inputs[0]
            return (torch.index_select(features, dim=1, index=valid_indexes.to(features.device)),)

        model.model.linear.register_forward_pre_hook(mask_pre_hook)

    if args.head_dataset is not None:
        if args.head_dataset == 'imagenet_v2':
            print('Loading imagenet-v2 dataset...')
            dataset = ImageNetV2Dataset("matched-frequency", location=args.head_data, transform=TEST_TRANSFORMS_IMAGENET)
            testloader = torch.utils.data.DataLoader(
                dataset, batch_size=args.batch_size, shuffle=False, num_workers=2)
        else:
            dataset = DATASETS[args.head_dataset](data_path=args.head_data)
            _, testloader = dataset.make_loaders(args.workers, args.batch_size, only_val=True,
                shuffle_train=False, shuffle_val=False, data_aug=False, val_batch_size=args.batch_size)

    model.eval()
    logger.info('==> Overall Accuracy')
    get_accuracy(args, testloader, model, device, path, logger)

# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported.
if __name__=='__main__':
    main()
