from depth.data_depth import *
import time
from common.util import save_status, load_data, load_model, assign_to_device, print_current_time, get_prediction_transformers, extraction_transformers
from depth.utils import *
from depth.depth_parser import parse_args

import os


def depth_features_by_class(depth, method, X_train, X_test, X_test_adv, y_train, c, U=None):
    """Score natural and adversarial test points by their depth w.r.t. class ``c``.

    The reference sample is the subset of ``X_train`` whose ``y_train`` row has
    its argmax (along axis 1) equal to ``c``.

    Args:
        depth: object exposing the depth routine ``AI_IRW(X=..., X_test=..., U=...)``.
        method: depth method name; only ``"int_w_halfs_pace"`` is supported.
        X_train: training points; rows are filtered by class.
        X_test: natural test points.
        X_test_adv: adversarial test points.
        y_train: 2-D label/score array aligned with ``X_train``.
        c: class index used to select the reference rows.
        U: optional projection directions forwarded unchanged to the depth routine.

    Returns:
        ``(labels, scores)``. ``labels`` has 0.0 for every natural point followed
        by 1.0 for every adversarial point. ``scores`` holds the depth values in
        the same order (natural first, then adversarial).

    Raises:
        AssertionError: if ``method`` is not a supported name.
    """
    in_class = np.argmax(y_train, axis=1) == c
    reference = X_train[np.where(in_class)]
    assert method in ["int_w_halfs_pace"]
    print("Choosen depths", method)
    if method != "int_w_halfs_pace":
        # Only reachable when asserts are stripped (python -O).
        raise NotImplementedError
    scorer = depth.AI_IRW
    # Adversarial points are scored first, exactly as before, in case the
    # depth routine is stateful.
    adv_scores = scorer(X=reference, X_test=X_test_adv, U=U)
    nat_scores = scorer(X=reference, X_test=X_test, U=U)
    scores = np.concatenate((nat_scores, adv_scores))
    n_nat, n_adv = X_test.shape[0], X_test_adv.shape[0]
    targets = np.concatenate((np.zeros(n_nat, ), np.ones(n_adv, )))
    return targets, scores


def main(args):
    """Compute per-class depth scores for natural and adversarial test samples.

    Steps:
    1. Load data and the adversarial set.
    2. Run the model to obtain logits.
    3. For every class ``c``, score all natural and adversarial test logits by
       their depth w.r.t. the training logits of class ``c``. Per-class scores
       are saved under ``detectors/results/depth/``, and existing per-class
       files are reused instead of recomputed.
    4. Save, for each test sample, the depth w.r.t. its predicted class.

    Args:
        args: parsed CLI namespace (see ``depth.depth_parser.parse_args``).
    """
    print("Depth scores")

    if args.attack.startswith(('pgd', 'fgsm', 'bim')):
        all_attack = args.loss + "_" + args.attack + "_" + str(args.epsilon)
    else:
        all_attack = "_" + args.attack

    depth_res_dir = 'detectors/results/depth/'
    os.makedirs(depth_res_dir, exist_ok=True)

    X_train, y_train, X_test, y_test = load_data(dataset_name=args.dataset_name, transformer=args.transformer, data_dir=args.data_dir)

    # NOTE(review): args.adv_path is passed unmodified (a dataset-joined path
    # used to be computed here but never used); confirm load_adv resolves the
    # per-dataset sub-directory itself.
    X_test_adv = load_adv(args.adv_path, dataset_name=args.dataset_name, attack=all_attack, transformers=args.transformer)

    # Exact comparison: the former `in 'cifar10'` was a substring test that
    # also matched e.g. 'cifar' or '10'.
    # NOTE(review): every non-cifar10 dataset (including 'tiny') gets 100
    # classes; confirm this matches the model's output size.
    if args.dataset_name == 'cifar10':
        n_classes = 10
    else:
        n_classes = 100

    device = assign_to_device(args.device)
    checkpoints_dir = '{}/{}/'.format(args.checkpoints_dir, args.dataset_name)
    if args.dataset_name == 'tiny':
        model = load_model(dataset_name=args.dataset_name, checkpoints_dir=checkpoints_dir,
                           device=device, transformer=args.transformer, checkpoint_name='model_tiny.pt')
    else:
        model = load_model(dataset_name=args.dataset_name, checkpoints_dir=checkpoints_dir,
                           device=device, transformer=args.transformer, ckpt_n=args.checkpoint_number,
                           model_type=args.model_type)

    model.eval()

    print("Starting")
    t0 = time.perf_counter()

    K = args.K  # number of projection directions (10/20 times the dimension)
    depth = DataDepth(K)
    logits_train, logits_test, logits_adv, y_train = get_prediction_transformers(args, model, X_train, X_test, X_test_adv)

    # Projection directions are cached per (dataset, logit dimension) so
    # repeated runs reuse the same random directions.
    _, dim = logits_train.shape
    u_dir = 'detectors/results/depth/{}/'.format(args.dataset_name)
    path = 'detectors/results/depth/{}/U_{}.npy'.format(args.dataset_name, dim)
    if os.path.exists(path):
        U = np.load(path)
    else:
        os.makedirs(u_dir, exist_ok=True)
        U = sampled_sphere(K, dim)
        np.save(path, U)

    pred_labels_test_nat = np.argmax(logits_test, axis=1)
    pred_labels_test_adv = np.argmax(logits_adv, axis=1)
    # Natural samples first, then adversarial: the same order used for the
    # depth scores returned by depth_features_by_class.
    pred_labels_test_all = np.concatenate((pred_labels_test_nat, pred_labels_test_adv))
    n_samples = len(pred_labels_test_all)
    all_depth = np.zeros((n_samples, n_classes))

    for c in range(n_classes):
        file_check = depth_res_dir + '/probs_' + args.dataset_name + all_attack + '_' + str(c) + '.npy'

        if os.path.exists(file_check):
            print('Already exists:', file_check)
            # Reuse the cached scores. Previously column c was left at zero,
            # which corrupted predicted_depth for samples predicted as class c.
            # Assumes the file holds the array written by save_status below
            # for the same tag — TODO confirm save_status uses np.save.
            all_depth[:, c] = np.load(file_check)
        else:
            print_current_time(s='Class {} start'.format(c))
            _, res = depth_features_by_class(depth, args.depth_method, logits_train, logits_test, logits_adv, y_train, c, U=U)
            save_status(probs=res, attack=args.dataset_name + all_attack + '_' + str(c), path=depth_res_dir)
            all_depth[:, c] = res
            print_current_time(s='Class {} end'.format(c))

    # Depth of each sample w.r.t. its own predicted class.
    predicted_depth = all_depth[np.arange(n_samples), pred_labels_test_all]

    dir_pred = depth_res_dir + '/prediction'
    save_status(probs=predicted_depth, attack=args.dataset_name + all_attack + '_predicted', path=dir_pred)
    t1 = time.perf_counter() - t0
    print("Time elapsed: ", t1)




if __name__ == "__main__":
    print("Starting")

    start = time.perf_counter()
    args = parse_args()
    # Checkpoint number forwarded by main() to load_model (as ckpt_n) for
    # datasets other than 'tiny'; unknown datasets fall back to 0.
    checkpoint_by_dataset = {'cifar10': 70000, 'cifar100': 8000}
    args.checkpoint_number = checkpoint_by_dataset.get(args.dataset_name, 0)
    print(args.checkpoint_number)
    main(args)
    elapsed = time.perf_counter() - start
    print("Time elapsed: {} minutes".format(elapsed / 60))
