import numpy as np

import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt

import seaborn as sns

# Load the per-sample selection arrays and class labels saved for one model
# run. Both .npy files are read from the current working directory.
model_name = "Adaptive_SingleLayer_Fix2_Reg00001"
selections_name = f"{model_name}_selections.npy"
labels_name = f"{model_name}_labels.npy"

selections, labels = np.load(selections_name), np.load(labels_name)

# Regroup the elements after the first three axes into a (50, 3) grid.
# Keep index 0 of the size-3 axis, then drop every length-1 axis.
# NOTE(review): assumes `selections` is 4-D with 150 trailing elements.
# NOTE(review): squeeze() would also drop the sample axis if only one sample
# was saved, which breaks the averaging below. Confirm N > 1.
selections = (
    selections
    .reshape(selections.shape[0], selections.shape[1], selections.shape[2], 50, 3)[..., 0]
    .squeeze()
)

# Average the squeezed selections over axis 1 (presumably time; TODO confirm).
# Then average each consecutive pair of the remaining 50 columns, giving
# 25 columns. Presumably these are one per joint, matching the 25 joint
# names used for the plot labels.
# NOTE(review): pairing columns (0,1), (2,3), ... assumes the two values per
# joint are interleaved. If the 50 columns are instead two consecutive blocks
# of 25 (e.g. one per body), reshape(N, 2, 25).mean(1) is the per-joint
# average. Confirm the layout.
average_selections1 = np.mean(selections, axis=1)
average_selections2 = np.mean(
    np.reshape(average_selections1, (average_selections1.shape[0], 25, 2)),
    axis=-1,
)

average_selections = average_selections2

# Average the per-sample selection vectors within each action class.
# `normalized[c, f]` is the mean of feature f over all samples labelled c.
num_classes = 60
num_samples, num_features = average_selections.shape

# Assumes `labels` holds one integer-valued class id per sample, possibly
# stored as floats or as an (N, 1) array (TODO confirm). As in the original
# loop, only the first `num_samples` labels are used.
raw_class_ids = np.asarray(labels[:num_samples]).reshape(num_samples)
class_ids = raw_class_ids.astype(np.intp)
if not np.array_equal(class_ids, raw_class_ids):
    raise ValueError("labels must be integer class ids")
if class_ids.size and (class_ids.min() < 0 or class_ids.max() >= num_classes):
    # Negative ids would otherwise silently wrap around to the last classes.
    raise ValueError(f"labels must lie in [0, {num_classes})")

# Vectorised per-class sums and counts. np.add.at accumulates repeated
# indices correctly, unlike a plain fancy-indexed `+=`.
selections_per_class = np.zeros((num_classes, num_features))
np.add.at(selections_per_class, class_ids, average_selections)
num_class_instances = np.bincount(class_ids, minlength=num_classes).astype(float)

num_class_instances = np.expand_dims(num_class_instances, axis=-1)

# A class with no samples gives 0/0 = NaN. seaborn's heatmap masks NaN cells,
# so such classes render as blank rows. The expected warning is silenced.
with np.errstate(divide="ignore", invalid="ignore"):
    normalized = selections_per_class / num_class_instances

# Tick labels for the heatmaps, defined once so both plots share them.
# One name per joint column of `normalized`:
feature_names = ['base of the spine', 'middle of the spine', 'neck', 'head', 'left shoulder', 'left elbow', 'left wrist', 'left hand', 'right shoulder', 'right elbow', 'right wrist', 'right hand', 'left hip', 'left knee', 'left ankle', 'left foot', 'right hip', 'right knee', 'right ankle', 'right foot', 'spine', 'tip of the left hand', 'left thumb', 'tip of the right hand', 'right thumb']
# One name per class row of `normalized`:
label_names = ["drink", "eat", "brushing teeth", "brushing hair", "drop", "pickup", "throw", "sitting down", "standing up", "clapping", "reading", "writing", "tear up paper", "wear jacket", "take off jacket", "wear a shoe", "take off a shoe", "wear on glasses", "take off glasses", "put on a hat/cap", "take off a hat/cap", "cheer up", "hand waving", "kicking something", "reach into pocket", "hopping", "jump up", "answer phone", "playing with phone", "typing on a keyboard", "pointing to something with finger", "taking a selfie", "check time", "rub two hands together", "nod head/bow", "shake head", "wipe face", "salute", "put the palms together", "cross hands in front", "sneeze", "staggering", "falling", "touch head", "touch chest", "touch back", "touch neck", "nausea or vomiting condition", "use a fan", "punching other person", "kicking other person", "pushing other person", "pat on back of other person", "point finger at the other person", "hugging other person", "giving something to other person", "touch other person's pocket", "handshaking", "walking towards each other", "walking apart from each other"]


def _save_heatmap(data, xticklabels, yticklabels, figsize, stem):
    """Draw `data` as a seaborn heatmap and save it as `<stem>.pdf` and `<stem>.png`.

    Args:
        data: 2-D array of shape (len(yticklabels), len(xticklabels)).
        xticklabels: Column labels.
        yticklabels: Row labels.
        figsize: Matplotlib figure size in inches, as (width, height).
        stem: Output path without extension.

    Returns:
        The matplotlib Figure that was saved. It has already been closed.

    Raises:
        ValueError: If `data` does not match the tick-label lengths.
    """
    expected = (len(yticklabels), len(xticklabels))
    if np.shape(data) != expected:
        raise ValueError(
            f"heatmap data has shape {np.shape(data)}, expected {expected} "
            "to match the tick labels"
        )
    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(data, xticklabels=xticklabels, yticklabels=yticklabels, ax=ax)
    for ext in ("pdf", "png"):
        fig.savefig(f"{stem}.{ext}", bbox_inches='tight')
    # Release the figure so repeated plots don't accumulate open figures.
    plt.close(fig)
    return fig


# Rows are classes and columns are joints.
sns_plot = _save_heatmap(normalized, feature_names, label_names, (8, 10), "heatmap")
# Transposed view on a wider canvas: rows are joints and columns are classes.
sns_plot = _save_heatmap(
    np.transpose(normalized), label_names, feature_names, (12, 8), "heatmap_transpose"
)