
import argparse
import os
import pickle
import shutil
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import torch

# --- Command-line interface --------------------------------------------------
parser = argparse.ArgumentParser(description = 'Runs Barlow')
parser.add_argument('--dataset', type = str, default = 'Synthetic',
                    help = "sub-directory holding the dataset's Config.py; "
                           "'Synthetic' and 'ImageNet' are handled specially")
parser.add_argument('--name', type = str, default = '',
                    help = 'run/class name: used as the label of every file '
                           'and in the output directory name')
parser.add_argument('--not_imagenet_model', default = True, action = 'store_false',
                    help = 'load the fine-tuned ResNet-18 from '
                           './Outputs/adv-tune/trial0/model.pt and call '
                           'sample_failure_explanation with use_imagenet=False')
parser.add_argument('--skip_feature_vis', default = False, action = 'store_true',
                    help = 'forwarded to sample_failure_explanation')
args = parser.parse_args()

dataset = args.dataset
name = args.name
# The flag is inverted: use_imagenet is True unless --not_imagenet_model is given.
use_imagenet = args.not_imagenet_model

# Fail fast: the fine-tuned-model path calls get_out_features(), which Config
# only provides for non-ImageNet datasets. Checking here avoids wiping the
# previous outputs and then crashing later with a NameError.
if dataset == 'ImageNet' and not use_imagenet:
    parser.error('--not_imagenet_model is not supported with --dataset ImageNet')

# Output paths are relative; the script later chdirs into the dataset directory.
if dataset == 'Synthetic':
    out_dir = './Outputs/{}/barlow-{}'.format(name, use_imagenet)
else:
    out_dir = './Outputs/barlow/{}-{}'.format(name, use_imagenet)
out_file = '{}/out.txt'.format(out_dir)   # stdout is redirected here
meta_file = '{}/tmp.csv'.format(out_dir)  # temporary per-file metadata CSV

# Both search-path entries are resolved relative to this script, not the cwd.
_script_dir = os.path.dirname(os.path.abspath(__file__))

# Dataset-specific settings come from ./<dataset>/Config.py. This entry must be
# on sys.path before Config is imported, and before Common/ is prepended.
sys.path.insert(0, os.path.join(_script_dir, './{}/'.format(dataset)))
if dataset != 'ImageNet':
    from Config import get_data_dir, get_out_features
else:
    from Config import get_class_map

# Shared helpers (Barlow, Load) live in ./Common/, which is now searched first.
sys.path.insert(0, os.path.join(_script_dir, './Common/'))
from Barlow import sample_failure_explanation
if dataset != 'ImageNet':
    from Load import load_standard
else:
    from Load import load_imagenet
    
# Work from inside the dataset directory: out_dir, ./Outputs/... and the
# checkpoint path used later are all relative to it. Resolve the directory
# against this script's location, as the sys.path entries are, so the script
# also works when launched from another cwd.
os.chdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), dataset))

# Start from an empty output directory. shutil.rmtree avoids passing the
# user-supplied --name (part of out_dir) through a shell, where spaces or
# metacharacters would break or inject into `rm -rf`. Errors are ignored,
# matching the original best-effort `rm -rf` whose exit status was unchecked.
shutil.rmtree(out_dir, ignore_errors = True)
Path(out_dir).mkdir(parents = True, exist_ok = True)

# Everything printed from here on goes to <out_dir>/out.txt. This handle is
# not explicitly closed anywhere in the script.
sys.stdout = open(out_file, 'w')

# Build one prediction per file and write the metadata CSV consumed by
# sample_failure_explanation.
if dataset == 'ImageNet':
    out = load_imagenet(name, get_class_map())

    # Each row of logits is mapped to the class name at its argmax position.
    # Assumes the logit columns follow get_class_map()'s iteration order — TODO confirm.
    classes = list(get_class_map())
    logits = out['logits']
    preds = [classes[np.argmax(row)] for row in logits]

else:
    if dataset == 'Synthetic':
        out = load_standard('object', '{}/{}'.format(get_data_dir(), name), get_out_features(), base_dir = './Outputs/{}'.format(name), fold = 'test')
    else:
        out = load_standard(name, get_data_dir(), get_out_features())

    # Threshold each entry of out['probs'] at 0.5 into a text label
    # (presumably one positive-class probability per file — verify in Load).
    negative = 'No {}'.format(name)
    preds = [name if p >= 0.5 else negative for p in out['probs']]

# Every file is given the label `name`.
labels = [name] * len(preds)

df = pd.DataFrame.from_dict({'Filenames': out['files'], 'Predictions': preds, 'Labels': labels})
df.to_csv(meta_file, index = False)

# With the default flags (use_imagenet True) no model is built and None is
# passed on to sample_failure_explanation.
robust_model = None
if not use_imagenet:
    from robustness.datasets import ImageNet
    from robustness.attacker import AttackerModel
    from robustness.imagenet_models.resnet import resnet18

    # ResNet-18 with its final layer resized to this dataset's output count,
    # then overwritten by the locally fine-tuned weights (the 'adv-tune' path
    # suggests adversarial fine-tuning — verify).
    # NOTE: get_out_features is only imported for non-ImageNet datasets.
    model = resnet18(pretrained = True)
    model.fc = torch.nn.Linear(in_features = 512, out_features = get_out_features())
    # Load onto the CPU first, so a checkpoint saved on a GPU index that is not
    # present here still loads; the model is moved to CUDA below.
    model.load_state_dict(torch.load('./Outputs/adv-tune/trial0/model.pt', map_location = 'cpu'))
    model.eval()

    # Wrap in robustness' AttackerModel. The ImageNet dataset object is rooted
    # at the cwd; presumably it only supplies normalization here — TODO confirm.
    robust_model = AttackerModel(model, ImageNet('./'))
    robust_model.cuda()

# Run the failure-explanation pipeline on the metadata CSV and pickle its result.
out = sample_failure_explanation(meta_file, name, use_imagenet = use_imagenet, model = robust_model, out_dir = out_dir, skip_feature_vis = args.skip_feature_vis)

with open('{}/map.pkl'.format(out_dir), 'wb') as f:
    pickle.dump(out, f)

# Clean up the temporary metadata CSV. os.remove avoids passing the
# user-supplied --name (part of meta_file) through a shell. A missing file is
# ignored, as the original `rm` call's failure was.
try:
    os.remove(meta_file)
except FileNotFoundError:
    pass
