
import argparse
from domino import embed, DominoSlicer
import matplotlib.pyplot as plt
import meerkat as mk
import numpy as np
import os
from pathlib import Path
import pickle
from PIL import Image, ImageOps
import random
from scipy.special import softmax
import sys

# Command-line configuration.
#   --dataset    selects the dataset folder and the matching loaders
#   --name       identifies the run; used to build the output directory and
#                passed on to the data loaders
#   --skip_plot  when given, the per-group image grids are not rendered
parser = argparse.ArgumentParser(description='Runs Domino')
parser.add_argument('--dataset', default='Synthetic', type=str)
parser.add_argument('--name', default='', type=str)
parser.add_argument('--skip_plot', action='store_true', default=False)
args = parser.parse_args()

dataset, name = args.dataset, args.name

# Domino hyperparameters; both log-likelihood terms share the same weight.
weight = 10
y_log_likelihood_weight = y_hat_log_likelihood_weight = weight
n_mixture_components = 25
n_slices = 10

# d1 and d2 are handed to image_grid when plotting; k = d1 * d2 is the number
# of images sampled per group.
d1, d2 = 4, 5
k = d1 * d2

# Resolve the output locations.  Synthetic runs write to
# ./Outputs/<name>/domino, every other dataset to ./Outputs/domino/<name>.
out_dir = ('./Outputs/{}/domino' if dataset == 'Synthetic' else './Outputs/domino/{}').format(name)
out_file = '{}/out.txt'.format(out_dir)

# Directory containing this script; the helper modules are located relative to it.
script_dir = os.path.dirname(os.path.abspath(__file__))
is_imagenet = dataset == 'ImageNet'

# Put the dataset's folder on the import path before importing Config
# (presumably Config lives in that folder -- confirm against the repo layout).
sys.path.insert(0, os.path.join(script_dir, './{}/'.format(dataset)))
if is_imagenet:
    from Config import get_class_map
else:
    from Config import get_data_dir, get_out_features

# ./Common/ is inserted last, so it is searched ahead of the dataset folder.
sys.path.insert(0, os.path.join(script_dir, './Common/'))
from Blindspots import image_grid
from DataUtils import get_transform
if is_imagenet:
    from Load import load_imagenet
else:
    from Load import load_standard

# Setup outputs
# All relative paths below (out_dir, out_file) are resolved inside the
# dataset's folder.
os.chdir(dataset)

# Start from a clean output directory.  The removal is done in-process rather
# than through a shell command so that a run name containing spaces or shell
# metacharacters cannot delete unrelated paths or execute arbitrary commands.
import shutil
if os.path.islink(out_dir) or os.path.isfile(out_dir):
    # A stray file or symlink at the target path is removed as a single entry.
    os.remove(out_dir)
else:
    # Best effort, like the previous 'rm -rf': a missing directory is fine.
    shutil.rmtree(out_dir, ignore_errors = True)
Path(out_dir).mkdir(parents = True, exist_ok = True)

# From here on everything printed by the script goes to out_file.
sys.stdout = open(out_file, 'w')

# Get the model errors.
# Produces three aligned per-example sequences:
#   probs  - the probability used as the model's prediction score
#   errors - 1 where the model is counted as wrong, 0 otherwise
#   labels - the target handed to Domino
if dataset != 'ImageNet':
    if dataset == 'Synthetic':
        out = load_standard(
            'object',
            '{}/{}'.format(get_data_dir(), name),
            get_out_features(),
            base_dir = './Outputs/{}'.format(name),
            fold = 'test')
    else:
        out = load_standard(name, get_data_dir(), get_out_features())

    # An example counts as an error when its recorded probability is below 0.5.
    probs = out['probs']
    errors = 1 * (probs < 0.5)
    labels = out['labels']
else:
    out = load_imagenet(name, get_class_map())

    # Column of the logits that corresponds to the class named by --name.
    index = int(get_class_map()[name][1])
    logits = out['logits']
    probs = softmax(logits, axis = 1)[:, index]
    # An example is an error when the top-scoring class is not that class.
    predicted = np.argmax(logits, axis = 1)
    errors = 1 * (predicted != index)
    # Binary target: 1 where the stored label equals that class.
    labels = 1 * (out['labels'] == index)

# Later code indexes errors with lists/arrays of positions, so make it an ndarray.
errors = np.array(errors)

# Setup the data transforms.
# t_main is attached to the DataPanel's 'input' column; t_vis (ImageNet only)
# is applied to images before they are drawn in the group grids.
uses_imagenet = dataset == 'ImageNet'
t_main = get_transform(mode = 'imagenet' if uses_imagenet else 'normalize')
t_vis = get_transform(mode = 'resize-crop') if uses_imagenet else None
    
# Setup the meerkat dataset for Domino: one row per image file, plus the
# transformed input, the prediction probability and the target.
dp = mk.DataPanel({'img': mk.ImageColumn.from_filepaths(out['files'])})
dp['input'] = dp['img'].to_lambda(t_main)
for column, values in (('prob', probs), ('target', labels)):
    dp[column] = values

# Run Domino.
# The embedding step reads the 'img' column; its output is referenced below
# under the column name 'clip(img)'.
dp = embed(dp, input_col = 'img', encoder = 'clip', device = 'cuda:0')

slicer_settings = {
    'y_log_likelihood_weight': y_log_likelihood_weight,
    'y_hat_log_likelihood_weight': y_hat_log_likelihood_weight,
    'n_mixture_components': n_mixture_components,
    'n_slices': n_slices,
}
domino = DominoSlicer(**slicer_settings)

# fit and predict take exactly the same inputs, so spell them out once.
slicer_inputs = {
    'data': dp,
    'embeddings': 'clip(img)',
    'targets': 'target',
    'pred_probs': 'prob',
}
domino.fit(**slicer_inputs)
dp['domino_slices'] = domino.predict(**slicer_inputs)

# Create the map from "group" to "set of points" (recorded as indices).
# Every slice id gets an entry, even if no point ends up assigned to it.
cluster_map = {i: [] for i in range(n_slices)}

# Each point goes to the slice with the largest value in its Domino output.
for i, v in enumerate(dp['domino_slices']):
    cluster_map[int(np.argmax(v))].append(i)

# Score each of those groups.
# group_stats[i] = (number of errors, error rate) so the values are computed
# once and reused for both ranking and reporting.
group_stats = {}
scores = []
for i, indices in cluster_map.items():
    # A slice with no points has no errors.  Without this guard the mean of an
    # empty selection is NaN, and NaN sort keys leave the ranking undefined.
    error_rate = np.mean(errors[indices]) if indices else 0.0
    group_stats[i] = (len(indices) * error_rate, error_rate)
    score = len(indices) * error_rate ** 2 # Equivalent to 'number of errors * error rate'
    scores.append((i, score))

# Highest score first; the sort is stable, so ties keep slice-id order.
scores = sorted(scores, key = lambda pair: pair[1], reverse = True)

print('Scores:')
for i, score in scores:
    n_errors, error_rate = group_stats[i]
    print(i, score, n_errors, error_rate)
print()

# Re-order that map so that iterating it visits groups from highest to lowest
# score (dicts preserve insertion order).
cluster_map = {index: cluster_map[index] for index, _ in scores}

with open('{}/map.pkl'.format(out_dir), 'wb') as f:
    pickle.dump(cluster_map, f)
    
# Visualize the groups: one image grid per group, saved as group_<rank>.png
# where <rank> is the group's position in cluster_map.
if not args.skip_plot:
    for i, key in enumerate(cluster_map):
        choices = cluster_map[key]
        if not choices:
            # Nothing to draw for a group with no points.  (An empty sample
            # would also become a float array, which cannot be used to index
            # errors.)  The rank i is still consumed, so file names keep
            # matching the order of cluster_map.
            continue

        # Sample at most k points, then order them so that correctly handled
        # images (errors == 0) come before the errors.
        chosen = np.array(random.sample(choices, min(k, len(choices))))
        chosen = chosen[np.argsort(errors[chosen])]

        imgs = []
        for j in chosen:
            # convert() returns a new, fully loaded image, so the file handle
            # can be released right away.
            with Image.open(out['files'][j]) as raw:
                img = raw.convert('RGB')
            if t_vis is not None:
                img = t_vis(img)
            # Both cases add 8 pixels per side so all tiles end up the same
            # size; errors get a red inner frame inside the black one.
            if errors[j] == 1:
                img = ImageOps.expand(img, border = 5, fill = (255, 0, 0))
                img = ImageOps.expand(img, border = 3, fill = 0)
            else:
                img = ImageOps.expand(img, border = 8, fill = 0)
            imgs.append(img)

        grid = image_grid(imgs, d1, d2)
        plt.imshow(grid)
        plt.axis('off')
        plt.savefig('{}/group_{}.png'.format(out_dir, i))
        plt.close()
    