import argparse
import copy
import logging
import os
import time
import math
from shutil import copyfile
import json

import numpy as np
import scipy

import torch
import torch.nn as nn
import torch.nn.functional as F
from apex import amp

from lip_convnets import LipConvNet
from utils import *

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

def get_args():
    """Build the command-line parser and parse ``sys.argv``.

    The options describe the model configuration (used to locate the model
    directory), the dataset, the output directory prefix and the model name.

    Returns:
        argparse.Namespace: the parsed options.
    """
    # (flag, add_argument keyword options), registered in this order.
    option_table = [
        # Model specifications
        ('--batch-size', dict(default=4096, type=int)),
        ('--gamma', dict(default=0., type=float,
                         help='gamma for certificate regularization')),
        ('--conv-layer', dict(default='soc', type=str,
                              choices=['bcop', 'cayley', 'soc'],
                              help='BCOP, Cayley, SOC convolution')),
        ('--fast-train', dict(action='store_true',
                              help='make backward pass of SOC faster during training')),
        ('--init-channels', dict(default=32, type=int)),
        ('--activation', dict(default='maxmin',
                              choices=['maxmin', 'hh1', 'hh2'],
                              help='Activation function')),
        ('--pooling', dict(default='lip1', choices=['max', 'lip1'],
                           help='Pooling layer')),
        ('--num-layers', dict(default=5, type=int,
                              choices=[5, 10, 15, 20, 25, 30, 35, 40],
                              help='number of layers per block in the LipConvnet network')),
        ('--last-layer', dict(default='crc_full',
                              choices=['ortho', 'lln', 'crc_ortho', 'crc_full'],
                              help='last layer that maps features to logits')),
        ('--data-dir', dict(default='./cifar-data', type=str)),
        ('--dataset', dict(default='cifar10', type=str,
                           choices=['cifar10', 'cifar100'],
                           help='dataset to use for training')),
        ('--out-dir', dict(default='extract', type=str, help='Model directory')),
        ('--model-name', dict(default='last', type=str, help='Model name')),
    ]

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()

def load_metadata(model_dir, prefix):
    """Load the saved train/test feature and label arrays of a model.

    Reads four ``.npy`` files from ``<model_dir>/extracted``, each named
    ``<prefix>_{train,test}_{features,labels}.npy``.

    Args:
        model_dir: directory that contains the ``extracted`` sub-directory.
        prefix: file-name prefix shared by the four files.

    Returns:
        tuple: ``(train_features, train_labels, test_features, test_labels)``.
    """
    file_suffixes = (
        '_train_features.npy',
        '_train_labels.npy',
        '_test_features.npy',
        '_test_labels.npy',
    )
    return tuple(
        np.load(os.path.join(model_dir, 'extracted', prefix + suffix))
        for suffix in file_suffixes
    )

def compute_pdists(A):
    """Return the pairwise Euclidean distances between the rows of ``A``.

    Uses the expansion ``|a - b|^2 = |a|^2 + |b|^2 - 2 a.b`` on the Gram
    matrix ``A @ A.T``. Shapes and summary statistics of the intermediate
    results are printed as diagnostics.

    Args:
        A: 2-D array with one point per row.

    Returns:
        Square array whose ``[i, j]`` entry is the distance between rows
        ``i`` and ``j`` of ``A``. Entries are always finite and non-negative.
    """
    gram = A @ A.T
    print(gram.shape)
    sq_norms = gram.diagonal()
    print(sq_norms.shape)

    sq_dists = (-2 * gram) + sq_norms[:, None] + sq_norms[None, :]
    # Printed before clamping so that negative round-off remains visible.
    print(sq_dists.shape, sq_dists.min(), sq_dists.mean(), sq_dists.max())

    # Cancellation can leave tiny negative squared distances for (nearly)
    # identical rows; clamp them so the square root does not produce NaN.
    return np.sqrt(np.maximum(sq_dists, 0))

def main():
    """Compare input-space and feature-space pairwise distances.

    Loads the saved features of the model selected by the command-line
    options, draws a random subset of training samples, computes pairwise
    Euclidean distances between the raw inputs and between the features, and
    prints statistics of ``input distance - feature distance``. Also prints
    whether the stored labels match the labels read from the data loader.

    Raises:
        ValueError: on an incompatible option combination, or when the
            numbers of loaded inputs and features differ.
        Exception: if the dataset name is not recognised.
    """
    args = get_args()

    # The parser in this script may not define --opt-level, so read it
    # defensively instead of crashing with AttributeError for Cayley models.
    if args.conv_layer == 'cayley' and getattr(args, 'opt_level', None) == 'O2':
        raise ValueError('O2 optimization level is incompatible with Cayley Convolution')
    if args.fast_train and args.conv_layer != 'soc':
        raise ValueError('fast training is only compatible with SOC')

    # The model directory name encodes the model configuration.
    conv_tag = ('_fast' if args.fast_train else '_') + str(args.conv_layer)
    args.out_dir += (
        f'_{args.dataset}_{args.num_layers}{conv_tag}'
        f'_{args.activation}_{args.pooling}_cr{args.gamma}_{args.last_layer}'
    )

    train_loader, test_loader = get_loaders(args.data_dir, args.batch_size, args.dataset, shuffle=False)
    if args.dataset == 'cifar10':
        args.num_classes = 10
    elif args.dataset == 'cifar100':
        args.num_classes = 100
    else:
        raise Exception('Unknown dataset')

    # Load the features saved under the requested model name (default 'last').
    train_features, train_labels, test_features, test_labels = load_metadata(
        args.out_dir, args.model_name)

    train_images, train_labels_c = extract_inputs(train_loader)
    # The test split is extracted as well but not used by the comparison below.
    test_images, test_labels_c = extract_inputs(test_loader)

    images, features, labels, labels_c = train_images, train_features, train_labels, train_labels_c
    if len(images) != len(features):
        raise ValueError(
            'number of inputs (%d) and features (%d) differ' % (len(images), len(features)))

    # Sample from the split that is actually indexed, without replacement so
    # that duplicated rows do not contribute artificial zero distances.
    num_select = min(10000, len(images))
    idxs = np.random.choice(len(images), size=num_select, replace=False)
    images = images[idxs, :]
    features = features[idxs, :]
    labels = labels[idxs]
    labels_c = labels_c[idxs]

    # NOTE(review): compute_pdists expects one flat vector per row -- confirm
    # that extract_inputs returns flattened inputs.
    images_pdist = compute_pdists(images)
    features_pdist = compute_pdists(features)

    diff = (images_pdist - features_pdist)

    # Count of negative and zero differences, then min / mean / max.
    print(np.sum(diff < 0), np.sum(diff == 0), diff.min(), diff.mean(), diff.max())

    # Check that the stored labels agree with the labels from the loader.
    print(np.allclose(labels, labels_c))
        
    
# Run the analysis only when this file is executed as a script.
if __name__ == "__main__":
    main()


