
import argparse
import numpy as np
import os
from pathlib import Path
import pickle
import sys
import torch

# Command-line interface.
#   --dataset  name of the dataset folder to work in (default 'Synthetic')
#   --name     run / class identifier used to build the output paths (default '')
# An `augment=True` default is also registered on the namespace; it is not
# read anywhere else in this script.
parser = argparse.ArgumentParser(description='Runs Spotlight')
parser.add_argument('--dataset', type=str, default='Synthetic')
parser.add_argument('--name', type=str, default='')
parser.set_defaults(augment=True)
args = parser.parse_args()

dataset, name = args.dataset, args.name

# Synthetic runs nest the results under the run name
# (./Outputs/<name>/spotlight); every other dataset uses
# ./Outputs/spotlight/<name>. Both paths are relative, so they resolve
# against whatever the working directory is when they are used.
out_dir_template = (
    './Outputs/{}/spotlight' if dataset == 'Synthetic' else './Outputs/spotlight/{}'
)
out_dir = out_dir_template.format(name)
out_file = '{}/out.txt'.format(out_dir)
      
# Directory that contains this script; the project modules are located
# relative to it rather than to the current working directory.
script_dir = os.path.dirname(os.path.abspath(__file__))

# Put the dataset's own folder at the front of sys.path so that `Config` is
# looked up there first. Which helpers are imported depends on the dataset.
sys.path.insert(0, os.path.join(script_dir, './{}/'.format(dataset)))
if dataset != 'ImageNet':
    from Config import get_data_dir, get_out_features
else:
    from Config import get_class_map

# The shared helpers are imported after ./Common/ has been pushed in front of
# the dataset folder on sys.path.
sys.path.insert(0, os.path.join(script_dir, './Common/'))
from DataUtils import get_transform
from Spotlight import run_spotlight
if dataset != 'ImageNet':
    from Load import load_standard
else:
    from Load import load_imagenet
    
# Local import: shutil is only needed for the directory reset below.
import shutil

# Work from inside the dataset folder; out_dir and out_file are relative paths
# and resolve against it from here on.
# NOTE(review): `dataset` is resolved against the current working directory,
# so the script presumably has to be launched from the folder that holds the
# dataset directories -- confirm.
os.chdir(dataset)

# Start from an empty output directory. shutil.rmtree replaces the former
# `rm -rf` shell call: the user-supplied --name is no longer interpolated into
# a shell command (spaces, globs and metacharacters are taken literally), and
# no external `rm` binary is required. ignore_errors=True keeps the old
# best-effort behaviour (a missing directory is not an error).
# NOTE(review): with the default --name '' on a non-Synthetic dataset, out_dir
# is './Outputs/spotlight/', so this wipes the results of every earlier run.
shutil.rmtree(out_dir, ignore_errors=True)
Path(out_dir).mkdir(parents=True, exist_ok=True)

# From here on everything printed goes to the run's log file. The handle is
# deliberately left open for the remainder of the script.
sys.stdout = open(out_file, 'w')
   
# Load the saved model outputs for the chosen dataset, pick the matching
# per-example loss, and build `errors`: one 0/1 entry per example, where 1
# marks an example counted as a mistake.
if dataset != 'ImageNet':
    if dataset == 'Synthetic':
        out = load_standard(
            'object',
            '{}/{}'.format(get_data_dir(), name),
            get_out_features(),
            base_dir='./Outputs/{}'.format(name),
            fold='test',
        )
    else:
        out = load_standard(name, get_data_dir(), get_out_features())

    criterion = torch.nn.BCEWithLogitsLoss(reduction='none')

    # An example is an error unless its probability is at least 0.5.
    # NOTE(review): this presumably relies on every loaded example being a
    # positive for the task -- confirm against the loader.
    errors = np.array([int(not prob >= 0.5) for prob in out['probs']])

    # No image transform is passed on for these datasets.
    transform = None
else:
    out = load_imagenet(name, get_class_map())
    criterion = torch.nn.CrossEntropyLoss(reduction='none')

    # Predicted class = the class-map entry at the position of the largest
    # logit in each row; a prediction different from `name` is an error.
    classes = list(get_class_map())
    logits = out['logits']
    preds = [classes[np.argmax(row)] for row in logits]
    errors = np.array([1 * (pred != name) for pred in preds])

    transform = get_transform(mode='resize-crop')

# Report the overall error rate, followed by three blank lines that separate
# it from whatever is printed afterwards.
print('General Error Rate: ', np.round(np.mean(errors), 2))
print()
print()
print()

# Standardize every representation dimension across the examples
# (zero mean, unit standard deviation along dim 0).
reps = torch.from_numpy(out['reps']).float()
reps = reps - torch.mean(reps, dim=0)
spread = torch.std(reps, dim=0)
# A constant dimension has zero spread, and a single-example input has an
# undefined (NaN) one. Dividing by either would fill the column with NaN/inf,
# so those dimensions use a divisor of 1 and simply stay at zero.
reps = reps / torch.where(spread > 0, spread, torch.ones_like(spread))

# Per-example losses (the criterion was built with reduction='none'); they
# are handed to Spotlight as a float tensor.
losses = criterion(
    torch.from_numpy(out['logits']),
    torch.from_numpy(out['labels']),
).detach().float()

spotlight_out = run_spotlight(
    reps,
    losses,
    out['files'],
    errors,
    transform=transform,
    out_dir=out_dir,
)

# Persist whatever run_spotlight returned next to the run's other outputs.
with open('{}/map.pkl'.format(out_dir), 'wb') as f:
    pickle.dump(spotlight_out, f)
