
import argparse
from collections import defaultdict
import glob
import itertools
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pickle
from PIL import Image, ImageOps
from pathlib import Path
import random
from scipy.special import softmax
from sklearn import mixture
import sys

# Command-line interface. Every option is registered from this table, in this order.
_CLI_OPTIONS = (
    ('--dataset', {'type': str, 'default': 'Synthetic'}),           # dataset folder next to this script
    ('--name', {'type': str, 'default': ''}),                       # run / class name used in input and output paths
    ('--use_error', {'default': False, 'action': 'store_true'}),    # hard 0/1 errors instead of 1 - confidence
    ('--weight', {'type': float, 'default': 0.025}),                # weight of the error column appended before clustering
    ('--max_groups', {'type': int, 'default': 10}),                 # maximum number of groups kept and visualized
)
parser = argparse.ArgumentParser(description = 'Runs Blindspots with automatic clustering')
for flag, settings in _CLI_OPTIONS:
    parser.add_argument(flag, **settings)
args = parser.parse_args()

# Unpack the parsed command-line options into module-level settings.
dataset, name = args.dataset, args.name
use_confidence = not args.use_error  # soft errors (1 - confidence) unless --use_error is given
error_weight = args.weight
max_groups = args.max_groups

# Cut-off at which an example is framed as a mistake in the image grids:
# soft errors are compared against 0.5, hard 0/1 errors against 1.
threshold = 0.5 if use_confidence else 1

# Each group grid is built from d1 * d2 = k sampled images (d1, d2 are passed to image_grid).
d1, d2 = 4, 5
k = d1 * d2

# Output locations: Synthetic runs write inside their own run folder, every other dataset
# under a shared 'blindspots' folder. The suffix records whether soft (confidence) errors were used.
out_dir = (
    './Outputs/{}/blindspots-{}'.format(name, use_confidence)
    if dataset == 'Synthetic'
    else './Outputs/blindspots/{}-{}'.format(name, use_confidence)
)
out_file = '{}/out.txt'.format(out_dir)

# Make the dataset-specific Config module importable from <script dir>/<dataset>/.
script_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(script_dir, './{}/'.format(dataset)))
if dataset == 'ImageNet':
    from Config import get_class_map
else:
    from Config import get_data_dir, get_out_features

# Shared helpers live in <script dir>/Common/.
sys.path.insert(0, os.path.join(script_dir, './Common/'))
from Blindspots import image_grid
from DataUtils import get_transform
if dataset == 'ImageNet':
    from Load import load_imagenet
else:
    from Load import load_standard

# All relative paths used from here on (out_dir, ./Outputs/...) resolve inside the dataset directory.
# NOTE(review): this chdir is relative to the current working directory, not to script_dir as the
# sys.path entries above are, so the script presumably must be launched from its own folder — confirm.
os.chdir(dataset)

# Start from an empty output directory. shutil.rmtree replaces `os.system('rm -rf ...')`:
# out_dir embeds the user-supplied --name, so the shell command allowed shell injection and
# split names containing spaces into several (wrong) paths to delete.
import shutil
shutil.rmtree(out_dir, ignore_errors = True)
Path(out_dir).mkdir(parents = True, exist_ok = True)

# Everything printed from here on goes to <out_dir>/out.txt instead of the console.
sys.stdout = open(out_file, 'w')

# Get the model errors: one value per example; a larger value means a worse prediction.
# With use_confidence the error is 1 - (model probability for the expected outcome); otherwise
# it is a hard mistake indicator. The matching `threshold` is chosen from use_confidence
# where the run settings are unpacked, so it is not reassigned here.
if dataset == 'ImageNet':
    # Load the class map once; it is needed both for loading the outputs and for the class index.
    class_map = get_class_map()
    out = load_imagenet(name, class_map)

    # Logit column of the class selected by `name` in the class map.
    index = int(class_map[name][1])
    logits = out['logits']

    if use_confidence:
        errors = 1 - softmax(logits, axis = 1)[:, index]
    else:
        errors = (np.argmax(logits, axis = 1) != index)
else:
    if dataset == 'Synthetic':
        out = load_standard('object', '{}/{}'.format(get_data_dir(), name), get_out_features(), base_dir = './Outputs/{}'.format(name), fold = 'test')
    else:
        out = load_standard(name, get_data_dir(), get_out_features())

    # NOTE(review): out['probs'] presumably holds the probability of the expected outcome
    # for each example — confirm against Load.load_standard.
    if use_confidence:
        errors = [1 - v for v in out['probs']]
    else:
        errors = [0 if v >= 0.5 else 1 for v in out['probs']]
errors = np.array(errors)

# Setup the data transform: ImageNet images go through DataUtils' 'resize-crop' transform
# before tiling; other datasets use the RGB-converted images unchanged.
transform = get_transform(mode = 'resize-crop') if dataset == 'ImageNet' else None
    
# Load the SCVIS representation: one row per example; its first two columns are the
# coordinates used for plotting.
if dataset == 'Synthetic':
    search_string = './Outputs/{}/scvis/*.tsv'.format(name)
else:
    search_string = './Outputs/scvis/{}/*.tsv'.format(name)
candidates = glob.glob(search_string)
if not candidates:
    # Previously this surfaced as an opaque IndexError from `[...][0]`.
    raise FileNotFoundError('No SCVIS embedding (*.tsv) matches {}'.format(search_string))
# Several files may match; take the one with the shortest path (ties keep glob order,
# exactly like the former stable sort followed by [0]).
tmp = min(candidates, key = len)
embedding = pd.read_csv(tmp, sep = '\t', index_col = 0).values

###
# Clustering
###

# Modified from: https://scikit-learn.org/stable/auto_examples/mixture/plot_gmm_selection.html#sphx-glr-auto-examples-mixture-plot-gmm-selection-py

# Min-max scale every embedding dimension to [0, 1]. Work on a float copy so the in-place
# division cannot fail on an integer-typed embedding.
X = np.array(embedding, dtype = float)
X -= np.min(X, axis = 0)
span = np.max(X, axis = 0)
span[span == 0] = 1.0  # a constant column would otherwise compute 0 / 0 and fill X with NaN
X /= span
# Optionally append the per-example error, scaled by error_weight, as an extra clustering feature.
if error_weight != 0.0:
    X = np.concatenate((X, error_weight * errors.reshape(-1, 1)), axis = 1)

# Fit GMMs with 1..32 components and keep the one with the lowest BIC.
# GaussianMixture requires at least as many samples as components, so the range is capped
# for very small inputs (unchanged whenever there are >= 32 samples).
lowest_bic = np.inf  # np.infty was removed in NumPy 2.0
bic = []
n_components_range = range(1, min(32, X.shape[0]) + 1)
cv_types = ['full']
for cv_type in cv_types:
    for n_components in n_components_range:
        gmm = mixture.GaussianMixture(n_components = n_components, covariance_type = cv_type)
        gmm.fit(X)
        bic.append(gmm.bic(X))
        if bic[-1] < lowest_bic:
            lowest_bic = bic[-1]
            best_gmm = gmm

# Summary figure (<out_dir>/out.png) with three stacked panels: BIC of every fitted model,
# the selected GMM's cluster labels on the embedding, and the per-example errors on the embedding.
# `bic` holds one value per (cv_type, n_components) in fitting order, i.e. cv_type-major.
bic = np.array(bic)
color_iter = itertools.cycle(['navy', 'turquoise', 'cornflowerblue', 'darkorange'])
clf = best_gmm
bars = []

# Panel 1: one bar series per covariance type, one bar per candidate component count.
plt.figure(figsize=(10, 15))
spl = plt.subplot(3, 1, 1)
for i, (cv_type, color) in enumerate(zip(cv_types, color_iter)):
    # Horizontal offset per covariance type; the constants appear to come from the scikit-learn
    # example this code is modified from (series i = 0 is drawn 0.4 left of the tick).
    xpos = np.array(n_components_range) + 0.2 * (i - 2)
    bars.append(
        plt.bar(
            xpos,
            bic[i * len(n_components_range) : (i + 1) * len(n_components_range)],  # this cv_type's slice
            width=0.2,
            color=color,
        )
    )
plt.xticks(n_components_range)
plt.ylim([bic.min() * 1.01 - 0.01 * bic.max(), bic.max()])
plt.title('BIC score per model')
# Position of the '*' marking the lowest-BIC model: argmin mod len(range) is its component-count
# index, floor(argmin / len(range)) its cv_type index; + 0.65 places it over that bar.
xpos = (
    np.mod(bic.argmin(), len(n_components_range))
    + 0.65
    + 0.2 * np.floor(bic.argmin() / len(n_components_range))
)
plt.text(xpos, bic.min() * 0.97 + 0.03 * bic.max(), '*', fontsize=14)
spl.set_xlabel('Number of components')
spl.legend([b[0] for b in bars], cv_types)

# Panel 2: cluster label of every example under the selected GMM, drawn on embedding columns 0 and 1.
# Y_ is also the grouping used to build the visualized groups.
splot = plt.subplot(3, 1, 2)
Y_ = clf.predict(X)
plt.scatter(embedding[:, 0], embedding[:, 1], c = Y_, s = 2)

plt.xticks(())
plt.yticks(())
plt.title(
    f'Selected GMM: {best_gmm.covariance_type} model, '
    f'{best_gmm.n_components} components'
)

# Panel 3: the same points coloured by their error value.
# NOTE(review): titled 'Original data' although it shows the errors, not the raw inputs.
splot = plt.subplot(3, 1, 3)
plt.scatter(embedding[:, 0], embedding[:, 1], c = errors, s = 2)
plt.title('Original data')

plt.subplots_adjust(hspace = 0.35, bottom = 0.02)

plt.savefig('{}/out.png'.format(out_dir))
plt.close()

###
# Visualization
###

# Group example indices by their assigned GMM component; keys appear in order of first
# occurrence in Y_.
cluster_map = defaultdict(list)
for idx, label in enumerate(Y_):
    cluster_map[label].append(idx)

# Score every group as size * rate ** 2, i.e. (number of errors) * (error rate), which favours
# groups that are both large and error-dense. Rank highest first; the sort is stable for ties.
scores = [
    (label, len(members) * np.mean(errors[members]) ** 2)
    for label, members in cluster_map.items()
]
scores.sort(key = lambda pair: -pair[1])

# Log each group: label, score, number of errors (sum of error values), error rate.
print('Scores:')
for label, score in scores:
    members = cluster_map[label]
    rate = np.mean(errors[members])
    print(label, score, len(members) * rate, rate)
print()
    
# Keep only the max_groups highest-scoring groups, in ranked order. A non-positive
# max_groups never triggers the cut-off, so every group is kept in that case.
ranked = scores if max_groups < 1 else scores[:max_groups]
new_map = {label: cluster_map[label] for label, _ in ranked}
cluster_map = new_map

# Persist the ranked {cluster label: example indices} mapping next to the other outputs.
with open('{}/map.pkl'.format(out_dir), 'wb') as f:
    pickle.dump(cluster_map, f)

# Visualize the groups: save one image grid per kept group as <out_dir>/group_<rank>.png,
# following the iteration order of cluster_map (highest-scoring group first).
# Up to k members are sampled at random and ordered by increasing error. Members whose error
# reaches `threshold` get a red frame (5 px red + 3 px black); the others get a plain 8 px
# black frame, so every tile has the same total border.
for i, key in enumerate(cluster_map):
    choices = cluster_map[key]
    chosen = np.array(random.sample(choices, min(k, len(choices))))
    chosen = chosen[np.argsort(errors[chosen])]
    imgs = []
    for j in chosen:
        # Context manager releases the file handle; convert() returns an independent copy,
        # so the image stays usable after the file is closed.
        with Image.open(out['files'][j]) as raw:
            img = raw.convert('RGB')
        if transform is not None:
            img = transform(img)
        if errors[j] >= threshold:
            img = ImageOps.expand(img, border = 5, fill = (255, 0, 0))
            img = ImageOps.expand(img, border = 3, fill = 0)
        else:
            img = ImageOps.expand(img, border = 8, fill = 0)
        imgs.append(img)
    grid = image_grid(imgs, d1, d2)
    plt.imshow(grid)
    plt.axis('off')
    plt.savefig('{}/group_{}.png'.format(out_dir, i))
    plt.close()
    