
import os, sys
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import umap
import argparse

sys.path.append('..')
from utils.argparse_utils import *

# Training-set sizes for which runs exist, in ascending order.
_SUPPORTED_DATA_QUANTITIES = (0, 400, 1000, 2000, 5000, 20000)

# Batch size paired with each training-set size; used to fill in the `bs=`
# field of the run names. Every entry is one hundredth of its data quantity.
DATA_QUANTITY_TO_BATCH_SIZE_DICT = {
    quantity: quantity // 100 for quantity in _SUPPORTED_DATA_QUANTITIES
}


def rotate_latent_space(emb_N2, labels_N, anchor_label=None, coeff=0.85):
    """Center a 2D embedding and rotate it so one class lands on the diagonal.

    The embedding is translated so that the center of its (squared-up)
    bounding box sits at the origin, then rotated about the origin so that
    the centroid of the anchor class points toward the position
    ``low + coeff * (high - low)`` on both axes of the square box.

    Args:
        emb_N2: array-like of shape (N, 2) with the 2D coordinates. It is
            not modified; a new array is returned.
        labels_N: array-like of N integer class labels, aligned with the
            rows of ``emb_N2``.
        anchor_label: label of the class whose centroid fixes the
            orientation. Defaults to ``aa_to_ind['TRP']`` from
            ``utils.protein`` (presumably tryptophan), which is only
            imported when this argument is left as ``None``.
        coeff: fractional position along each side of the square bounding
            box that the anchor centroid is rotated toward.

    Returns:
        A new float array of shape (N, 2) with the centered, rotated
        coordinates.

    Raises:
        ValueError: if no row of ``labels_N`` equals the anchor label.
    """
    emb_N2 = np.asarray(emb_N2, dtype=float)
    labels_N = np.asarray(labels_N)

    if anchor_label is None:
        # Deferred so that callers supplying their own anchor do not need
        # the project's protein utilities on the import path.
        from utils.protein import aa_to_ind
        anchor_label = aa_to_ind['TRP']

    anchor_mask_N = labels_N == anchor_label
    if not np.any(anchor_mask_N):
        raise ValueError('No points with anchor label %r; cannot orient the embedding' % (anchor_label,))

    # get bounding box limits
    low_2 = emb_N2.min(axis=0)
    high_2 = emb_N2.max(axis=0)

    # make box a square: the shorter side is padded symmetrically, so the
    # center is unchanged and the half side is the larger half-length
    c_sq_2 = (low_2 + high_2) / 2
    half_side = np.max(high_2 - low_2) / 2

    # subtract center from all coords (out of place: the input is untouched)
    centered_N2 = emb_N2 - c_sq_2

    # after centering, the square spans [-half_side, half_side] on both
    # axes; each limit is shifted by the center component of its own axis
    sq_low, sq_high = -half_side, half_side

    # get coords of point I want to move (centroid of the anchor class)
    orig_coords_2 = centered_N2[anchor_mask_N].mean(axis=0)

    # get coords I want to move point to
    target = sq_low + coeff * (sq_high - sq_low)
    new_coords_2 = np.array([target, target])

    # signed angle from orig_coords_2 to new_coords_2: atan2(cross, dot)
    cross = orig_coords_2[0] * new_coords_2[1] - orig_coords_2[1] * new_coords_2[0]
    dot = orig_coords_2[0] * new_coords_2[0] + orig_coords_2[1] * new_coords_2[1]
    theta = np.arctan2(cross, dot)

    cos_theta = np.cos(theta)
    sin_theta = np.sin(theta)
    R_22 = np.array([[cos_theta, -sin_theta],
                     [sin_theta, cos_theta]])

    # rows are points, so apply R to each row via the transpose
    return centered_N2 @ R_22.T
    

if __name__ == '__main__':
    # Script: for a fixed latent size and data quantity, load the inference
    # arrays of one AE run and five VAE runs (increasing KL weight), project
    # the invariants to 2D, orient them consistently and save scatter plots
    # (one figure per run, one 2x3 grid, and a standalone legend).
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, default='../runs/toy_aminoacids/local_equiv_fibers')
    parser.add_argument('--z', type=int, default=2)
    # Only quantities with a known batch size can be mapped to a run name, so
    # reject anything else as a usage error instead of a KeyError later on.
    parser.add_argument('--data_quantity', type=int, default=20000,
                        choices=sorted(DATA_QUANTITY_TO_BATCH_SIZE_DICT))
    parser.add_argument('--split', type=str, default='test')
    args = parser.parse_args()

    batch_size = DATA_QUANTITY_TO_BATCH_SIZE_DICT[args.data_quantity]
    run_params = (args.z, args.data_quantity, batch_size)

    # The first entry ('0') is the plain autoencoder run; the rest are VAE
    # runs whose name carries the KL weight.
    KL_LAMBDAS = ['0', '0.025', '0.05', '0.1', '0.25', '0.5']
    HASHES = ['cgvae_symm_simp_flex-AE-z=%d-x_lambda=400-data=%d-bs=%d_v1' % run_params] + [
        'cgvae_symm_simp_flex-VAE_min_loss_with_final_kl-z=%d-x_lambda=400-data=%d-bs=%d-kl_lambda=%s_v1'
        % (run_params + (kl_lambda,))
        for kl_lambda in KL_LAMBDAS[1:]
    ]
    MODEL_TYPES = ['lowest_rec_loss'] + ['lowest_total_loss_with_final_kl_model'] * (len(KL_LAMBDAS) - 1)

    NROWS = 2
    NCOLS = 3
    assert len(HASHES) == len(KL_LAMBDAS)
    assert len(HASHES) == len(MODEL_TYPES)
    assert (NROWS * NCOLS) == len(HASHES)

    # Suffix inserted into the results file names for each model type.
    MODEL_TYPE_TO_SUFFIX = {
        'best': '',
        'best_04': '',
        'best_05': '-best_05',
        'best_06': '-best_06',
        'best_higher_kld': '-best_model_higher_kld',
        'lowest_rec_loss': '-lowest_rec_loss',
        'final': '-final_model',
        'no_training': '-no_training',
        'lowest_total_loss_with_final_kl_model': '-lowest_total_loss_with_final_kl_model',
    }

    # plotting constants
    MARKER_SCALING = 0.4
    COLORS_20 = plt.get_cmap('tab20').colors
    ALPHA = 0.6
    marker_area = (mpl.rcParams['lines.markersize'] * MARKER_SCALING) ** 2

    # get umap projections for each hash
    umaps_list, colors_list = [], []
    for run_hash, model_type in zip(HASHES, MODEL_TYPES):

        if 'data=0' in run_hash:
            # runs with no training data only have untrained-model results
            model_type_str = '-no_training'
        elif model_type in MODEL_TYPE_TO_SUFFIX:
            model_type_str = MODEL_TYPE_TO_SUFFIX[model_type]
        else:
            raise ValueError('Unknown model type: %s' % (model_type))

        results_dir = os.path.join(args.model_dir, run_hash, 'results_arrays')

        try:
            arrays = np.load(os.path.join(results_dir, 'inference%s-split=%s.npz' % (model_type_str, args.split)))
        except OSError:
            if model_type != 'lowest_rec_loss':
                raise
            # fall back to the results of the lowest-total-loss checkpoint
            model_type_str = '-lowest_total_loss_with_final_kl_model'
            arrays = np.load(os.path.join(results_dir, 'inference%s-split=%s.npz' % (model_type_str, args.split)))
            print('Warning, using model_type_str: %s' % (model_type_str), file=sys.stderr)

        # read what is needed, then release the underlying .npz file handle
        with arrays:
            invariants_ND = arrays['invariants_ND']
            labels_N = arrays['labels_N']

        umap_path = os.path.join(results_dir, 'umap_invariants%s-split=%s.npy' % (model_type_str, args.split))
        try:
            lower_dim_invariants_N2 = np.load(umap_path)
        except (OSError, ValueError, EOFError):
            # best-effort cache: a missing or unreadable file means recompute
            print('Computing umaps for hash: {}'.format(run_hash))
            if invariants_ND.shape[1] == 2:
                # already two-dimensional, no projection needed
                lower_dim_invariants_N2 = invariants_ND
            else:
                lower_dim_invariants_N2 = umap.UMAP(random_state=42).fit_transform(invariants_ND)
            np.save(umap_path, lower_dim_invariants_N2)

        lower_dim_invariants_N2 = rotate_latent_space(lower_dim_invariants_N2, labels_N)

        umaps_list.append(lower_dim_invariants_N2)
        # RGBA per point: the label's tab20 colour plus a shared alpha
        colors_list.append([COLORS_20[label] + (ALPHA,) for label in labels_N])


    # plot legend only
    from matplotlib.lines import Line2D
    from utils.protein import ind_to_aa
    # labels_N is left over from the last run of the loop above; any run works
    # as long as all classes are represented in it
    legend_elements = [
        Line2D([0], [0], marker='o', markersize=9.0, ls='', color=COLORS_20[label],
               markerfacecolor=COLORS_20[label], label='%s' % (ind_to_aa[label]))
        for label in sorted(set(labels_N))
    ]
    figLegend = plt.figure(figsize=(1.3, 6.2))
    figLegend.legend(handles=legend_elements, loc='center', prop={'size': 15})
    figLegend.savefig('AAs_legend.png')
    figLegend.savefig('AAs_legend.pdf')
    plt.close(figLegend)


    # one standalone figure per KL weight
    for umaps, colors, kl_lambda in zip(umaps_list, colors_list, KL_LAMBDAS):
        plt.figure()
        plt.scatter(umaps[:, 0], umaps[:, 1], c=colors, edgecolors='none', s=marker_area)
        plt.axis('off')
        plt.tight_layout()
        plt.savefig('AAs_kld_ablation_umap-kl_lambda=%s-data=%d-z=%d.png' % (kl_lambda, args.data_quantity, args.z))
        plt.savefig('AAs_kld_ablation_umap-kl_lambda=%s-data=%d-z=%d.pdf' % (kl_lambda, args.data_quantity, args.z))
        plt.close()


    # make figure with six subplots (square panels)
    height = 5.0
    width = height

    fig, axs = plt.subplots(figsize=(NCOLS*width, NROWS*height), ncols=NCOLS, nrows=NROWS)
    axs = axs.flatten()
    for ax, umaps, colors, kl_lambda in zip(axs, umaps_list, colors_list, KL_LAMBDAS):
        ax.scatter(umaps[:, 0], umaps[:, 1], c=colors, edgecolors='none', s=marker_area)
        ax.set_title(kl_lambda, fontdict={'fontsize': 16, 'fontweight': 'bold'})
        ax.axis('off')
    plt.subplots_adjust(left=0.05,
                        bottom=0.05,
                        right=0.95,
                        top=0.95,
                        wspace=0.2,
                        hspace=0.2)
    plt.savefig('AAs_kld_ablation_umaps-data=%d-z=%d.png' % (args.data_quantity, args.z))
    plt.savefig('AAs_kld_ablation_umaps-data=%d-z=%d.pdf' % (args.data_quantity, args.z))
    plt.close(fig)

