from argparse import ArgumentParser
import sys
from pathlib import Path

import numpy as np
from nuq import NuqClassifier

sys.path.append('.')
from image_uncertainty.cifar.cifar_evaluate import described_plot
from experiments.imagenet_discrete import dump_ues
from scipy.special import softmax

# Command-line options for fitting NUQ on precomputed ImageNet embeddings.
parser = ArgumentParser()

# Boolean switch (False unless given); when set, only every 20th training
# embedding is kept before fitting.
parser.add_argument(
    '--subsample',
    action='store_true',
    help='set to subsample train embeddings 1 to 20',
)

# Directory holding the *.npy embedding/target files read below.
parser.add_argument('--base-dir', default='/home/mephody_bro/imagenet_embeddings_full', type=str)

# Suffix selecting which ood_embeddings_<name>.npy file is used as OOD data.
parser.add_argument('--ood-name', default='imagenet_o', type=str)

args = parser.parse_args()
base_dir = Path(args.base_dir)



# Load precomputed embeddings and targets from --base-dir. File names are
# fixed except for the OOD set, which is selected by --ood-name.
x_train = np.load(base_dir / 'train_embeddings.npy')
y_train = np.load(base_dir / 'train_targets.npy')
x_val = np.load(base_dir / 'val_embeddings.npy')
# NOTE(review): y_val is loaded but not used anywhere in this script.
y_val = np.load(base_dir / 'val_targets_resnet50.npy')
x_ood = np.load(base_dir / f'ood_embeddings_{args.ood_name}.npy')

# Embeddings and targets are indexed together below (and passed together to
# the classifier), so they must be aligned sample-for-sample.
if len(x_train) != len(y_train):
    raise ValueError(
        'train embeddings and targets differ in length: '
        f'{len(x_train)} != {len(y_train)}'
    )

if args.subsample:
    # Keep every 20th training sample. Advanced (integer-array) indexing
    # returns a copy, so the full-size arrays can be freed once rebound;
    # a basic slice like x_train[::20] would be a view keeping them alive.
    idx = np.arange(0, len(x_train), 20)
    x_train = x_train[idx]
    y_train = y_train[idx]

print(x_train.shape, len(x_val), len(x_ood))


# Fit NUQ on the (optionally subsampled) training embeddings, then score the
# in-distribution validation set and the OOD set.
nuq = NuqClassifier(strategy='isj', tune_bandwidth=True, n_neighbors=20)
nuq.fit(X=x_train, y=y_train)
ues_test = nuq.predict_uncertainty(x_val)
ues_ood = nuq.predict_uncertainty(x_ood)

# predict_uncertainty returns a mapping keyed by uncertainty type; plot and
# dump each type, pairing validation scores with the matching OOD scores.
for ue_type, test_scores in ues_test.items():
    print(ue_type)
    ood_scores = ues_ood[ue_type]
    # NOTE(review): the 'spectral' method label looks copied from another
    # experiment even though this run is NUQ -- confirm what described_plot
    # expects in this position.
    described_plot(
        test_scores,
        ood_scores,
        args.ood_name,
        'spectral',
        title_extras=f'NUQ_{ue_type}',
    )
    dump_ues(test_scores, ood_scores, f'nuq_{ue_type}', 'imagenet', args.ood_name)
    

