'''
Overlay, in a single figure, the summed reconstructed signals of several individual digits, drawing each digit in a different color.
'''

import os, sys
import pickle
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

import argparse
sys.path.append('..')

from utils.argparse_utils import *

# matplotlib's qualitative 'tab10' palette; its order is
# blue, orange, green, red, purple, brown, pink, gray, olive, cyan.
COLORS = plt.get_cmap('tab10').colors
RED = COLORS[3]
# Index 0 is tab10's blue. This previously pointed at index 4, which is purple.
BLUE = COLORS[0]
# The color BLUE used to hold, kept under its correct name.
PURPLE = COLORS[4]

# Maps each accepted --model_type to
# (suffix embedded in the saved artefact filenames, checkpoint file name).
# Only the suffix is used by this script.
MODEL_TYPES = {
    'best': ('', 'best_model.pt'),
    'best_04': ('', 'best_model_04.pt'),
    'best_05': ('-best_05', 'best_model_05.pt'),
    'best_06': ('-best_06', 'best_model_06.pt'),
    'best_higher_kld': ('-best_model_higher_kld', 'best_model_higher_kld.pt'),
    'lowest_rec_loss': ('-lowest_rec_loss', 'lowest_rec_loss_model.pt'),
    'final': ('-final_model', 'final_model.pt'),
    'lowest_total_loss_with_final_kl_model': ('-lowest_total_loss_with_final_kl_model', 'lowest_total_loss_with_final_kl_model.pt'),
}


def fading_colormap_values(color, label_i, n_labels, n=256):
    """Return an (n, 4) RGBA table fading from transparent white to `color`.

    Args:
        color: RGB(A) sequence of floats in [0, 1]; only the first three
            components are used.
        label_i: position of this label in the plotting order (0-based).
        n_labels: total number of labels being overlaid.
        n: number of entries in the table.

    Returns:
        numpy array of shape (n, 4). RGB goes linearly from white (1, 1, 1)
        to `color`. Alpha goes linearly from 0 to a maximum. That maximum is
        ``min(1.5 * (n_labels - label_i) / n_labels, 1.0)``, so labels drawn
        later (and therefore on top) are more transparent and earlier layers
        stay visible.
    """
    table = np.ones((n, 4))
    for rgb_i, rgb_value in enumerate(color[:3]):
        table[:, rgb_i] = np.linspace(1, rgb_value, n)
    max_alpha = (n_labels - label_i) / n_labels * 1.5
    table[:, 3] = np.linspace(0, min(max_alpha, 1.0), n)
    return table


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_type', type=str, default='RR-avg_sqrt_power-1,7')
    parser.add_argument('--model_dir', type=str, required=True)
    parser.add_argument('--model_type', type=str, default='lowest_total_loss_with_final_kl_model',
                        choices=sorted(MODEL_TYPES))
    parser.add_argument('--split', type=str, default='test')
    parser.add_argument('--hash', type=str, required=True)
    # Each label must have its own per-label sum file; combined-label sums are not supported.
    parser.add_argument('--comma_sep_labels', type=str, default='7,1')

    args = parser.parse_args()

    print(args.hash, file=sys.stderr)
    print(args.comma_sep_labels, file=sys.stderr)

    # argparse `choices` guarantees the key exists.
    model_type_str, _ = MODEL_TYPES[args.model_type]

    labels = list(map(int, args.comma_sep_labels.split(',')))
    sum_dir = os.path.join(args.model_dir, args.hash, 'sum_of_images')

    images = {}
    for label in labels:
        in_path = os.path.join(sum_dir, 'rec_images_unrotated_by_inverse_of_frame%s-split=%s-input_type=%s-labels=%s.pkl' % (model_type_str, args.split, args.input_type, label))
        # NOTE: pickle executes arbitrary code on load; only read artefacts you produced.
        with open(in_path, 'rb') as f:
            images[label] = pickle.load(f)

    plt.figure(figsize=(6, 6))
    N = 256
    for label_i, label in enumerate(labels):
        # Wrap around so labels >= len(COLORS) reuse colors instead of raising IndexError.
        color = COLORS[label % len(COLORS)]
        colormap = ListedColormap(fading_colormap_values(color, label_i, len(labels), N))
        # assumes each pickle holds 3600 values (a 60x60 image) — TODO confirm
        plt.imshow(images[label].reshape(60, 60), cmap=colormap)

    plt.axis('off')
    plt.tight_layout()
    out_stem = os.path.join(sum_dir, '__overlap_of_rec_images_unrotated_by_inverse_of_frame%s-split=%s-input_type=%s-labels=%s' % (model_type_str, args.split, args.input_type, args.comma_sep_labels))
    for ext in ('png', 'pdf'):
        plt.savefig('%s.%s' % (out_stem, ext))
