
import os, sys
import gzip, pickle
import json
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import pylab
import torch
import e3nn
from e3nn import o3
import umap
import argparse

sys.path.append('..')
from projections import ZernickeRadialFunctions, RadialSphericalTensor, MultiChannelRadialSphericalTensor

# Maps each space-padded byte-string atom name (presumably fixed-width
# PDB-style atom-name fields -- TODO confirm) to its bare name,
# e.g. b' CA ' -> 'CA'.  Insertion order follows the tuple below.
ATOM_TYPES_STRINGNAMES_DICT = {
    padded_name: padded_name.decode('ascii').strip()
    for padded_name in (b' CA ', b' N  ', b' C  ', b' O  ')
}

# Literal percent sign, spliced into %-formatted / str.format'ed labels.
PERC = '%'

# Suffix inserted into the results/figure filenames for each --model_type choice.
# NOTE(review): 'best' and 'best_04' both map to '' and therefore read/write
# the same files -- confirm this is intended.
_MODEL_TYPE_SUFFIXES = {
    'best': '',
    'best_04': '',
    'best_05': '-best_05',
    'best_06': '-best_06',
    'best_higher_kld': '-best_model_higher_kld',
    'lowest_rec_loss': '-lowest_rec_loss',
    'final': '-final_model',
    'no_training': '-no_training',
    'lowest_total_loss_with_final_kl_model': '-lowest_total_loss_with_final_kl_model',
}

# (marker scaling, alpha) used for the scatter plots of each data split.
_SPLIT_PLOT_STYLES = {
    'valid': (0.5, 0.3),
    'test': (1.0, 0.6),
}


def _split_plot_style(split):
    """Return ``(marker_scaling, alpha)`` for the given data split.

    Splits other than 'valid'/'test' fall back to the 'valid' style (smaller,
    more transparent markers) instead of leaving the constants undefined.
    """
    return _SPLIT_PLOT_STYLES.get(split, _SPLIT_PLOT_STYLES['valid'])


def _categorical_colors(labels, palette):
    """Color each label by the rank of its value among the sorted unique labels.

    Args:
        labels: sequence of hashable, sortable labels (one per data point).
        palette: indexable sequence of colors; entry ``i`` is used for the
            ``i``-th smallest distinct label.

    Returns:
        ``(sorted_labels, label_to_index, colors)`` where ``colors[k]`` is the
        palette color of ``labels[k]``.

    Raises:
        ValueError: if there are more distinct labels than palette colors.
    """
    sorted_labels = sorted(set(labels))
    if len(sorted_labels) > len(palette):
        raise ValueError('Found %d distinct labels %s but only %d colors are available.'
                         % (len(sorted_labels), sorted_labels, len(palette)))
    label_to_index = {label: i for i, label in enumerate(sorted_labels)}
    colors = [palette[label_to_index[label]] for label in labels]
    return sorted_labels, label_to_index, colors


def _legend_handles(sorted_labels, label_to_index, palette, name_dict, markersize):
    """Build one circular-marker legend handle per label.

    Labels missing from ``name_dict`` are shown under their raw value.
    """
    from matplotlib.lines import Line2D

    handles = []
    for label in sorted_labels:
        color = palette[label_to_index[label]]
        handles.append(Line2D([0], [0], marker='o', markersize=markersize, ls='',
                              color=color, markerfacecolor=color,
                              label=str(name_dict.get(label, label))))
    return handles


def _feature_degrees(irreps):
    """Return the degree ``l`` of every feature, in the irreps' layout order.

    ``irreps`` is an iterable of ``(mul, ir)`` pairs (what iterating an e3nn
    ``o3.Irreps`` yields) with ``ir.l`` the degree.  Each ``mul x ir`` chunk
    occupies ``mul * (2l + 1)`` consecutive features, so following the
    layout order stays correct even when the irreps are not sorted by ``l``.
    """
    chunks = [np.full(mul * (2 * ir.l + 1), ir.l, dtype=np.float64) for mul, ir in irreps]
    if not chunks:
        return np.zeros(0, dtype=np.float64)
    return np.concatenate(chunks)


def _total_power(images_NF, degrees_F):
    """Return per-sample power ``sum_f x_f**2 / (2 l_f + 1)``.

    Raises:
        ValueError: if the feature dimension does not match ``degrees_F``.
    """
    images_NF = np.asarray(images_NF)
    degrees_F = np.asarray(degrees_F, dtype=np.float64)
    if images_NF.shape[-1] != degrees_F.shape[0]:
        raise ValueError('Images have %d features but irreps describe %d.'
                         % (images_NF.shape[-1], degrees_F.shape[0]))
    return np.einsum('bf,bf,f->b', images_NF, images_NF, 1.0 / (2 * degrees_F + 1))


def _percent(count, total):
    """Return ``count / total`` as a percentage, or NaN when ``total`` is zero."""
    if total == 0:
        return float('nan')
    return (count / total) * 100


def _save_png_and_pdf(fig, path_stem, **savefig_kwargs):
    """Save ``fig`` to ``path_stem + '.png'`` and then ``path_stem + '.pdf'``."""
    for extension in ('.png', '.pdf'):
        fig.savefig(path_stem + extension, **savefig_kwargs)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--experiment_dir', type=str, required=True) # '../runs/neighborhoods/local_equiv_fibers' + hash
    parser.add_argument('--model_type', type=str, default='lowest_total_loss_with_final_kl_model',
                        choices=list(_MODEL_TYPE_SUFFIXES))
    parser.add_argument('--split', type=str, default='test')
    parser.add_argument('--statistics_dict_file', type=str, default='../data/neighborhoods/data/casp12_testing_neighborhoods-yes_residue-no_sidechain-H,E,C-r=12.5-0_to_215_proteins-STATISTICS_DICT.gz')
    args = parser.parse_args()

    # plotting constants
    MARKER_SCALING, ALPHA = _split_plot_style(args.split)
    GOLDEN_RATIO = (1 + 5.0**0.5) / 2
    HEIGHT = 5.0
    WIDTH = HEIGHT * GOLDEN_RATIO
    MARKER_SIZE = (mpl.rcParams['lines.markersize'] * MARKER_SCALING)**2

    COLORS_10 = plt.get_cmap('tab10').colors
    COLORS_AA_LABEL = [COLORS_10[0], COLORS_10[1], COLORS_10[2]]
    COLORS_SEC_STRUCT = [COLORS_10[0], COLORS_10[7], COLORS_10[6]]

    model_type_str = _MODEL_TYPE_SUFFIXES[args.model_type]
    file_suffix = '%s-split=%s' % (model_type_str, args.split)
    results_dir = os.path.join(args.experiment_dir, 'results_arrays')
    viz_dir = os.path.join(args.experiment_dir, 'latent_space_viz')
    os.makedirs(viz_dir, exist_ok=True)  # savefig does not create missing directories

    with open(os.path.join(args.experiment_dir, 'hparams.json'), 'r') as f:
        hparams = json.load(f)

    # Rebuild the irreps describing the layout of the input images.
    OnRadialFunctions = ZernickeRadialFunctions(hparams['rcut'], hparams['rmax']+1, hparams['lmax'], complex_sph = False)
    rst = RadialSphericalTensor(hparams['rmax']+1, OnRadialFunctions, hparams['lmax'], 1, 1)
    mul_rst = MultiChannelRadialSphericalTensor(rst, hparams['n_channels'])
    data_irreps = o3.Irreps(str(mul_rst))

    # NpzFile keeps the archive open until closed; copy the arrays out first.
    with np.load(os.path.join(results_dir, 'inference%s.npz' % file_suffix)) as arrays:
        invariants_ND = arrays['invariants_ND']
        data_ids_N = arrays['data_ids_N']
        images_NF = arrays['images_NF']

    # 2D embedding of the invariants, cached next to the inference results.
    umap_path = os.path.join(results_dir, 'umap_invariants%s.npy' % file_suffix)
    if os.path.exists(umap_path):
        lower_dim_invariants_N2 = np.load(umap_path)
    else:
        print('Computing umaps for experiment: {}'.format(args.experiment_dir))
        if invariants_ND.shape[1] == 2:
            lower_dim_invariants_N2 = invariants_ND
        else:
            lower_dim_invariants_N2 = umap.UMAP(random_state=42).fit_transform(invariants_ND)
        np.save(umap_path, lower_dim_invariants_N2)

    # saved statistics
    # NOTE: pickle can execute arbitrary code -- only load trusted statistics files.
    with gzip.open(args.statistics_dict_file, 'rb') as f:
        statistics_dict = pickle.load(f)

    # residue type (aa label) colors; field 0 of the '_'-separated data id is used as the label
    aa_label_N = [data_id.split('_')[0] for data_id in data_ids_N]
    aa_label_sorted_set, aa_label_idx_dict, aa_label_colors_N = _categorical_colors(aa_label_N, COLORS_AA_LABEL)
    aa_label_name_dict = {
        'H': 'HIS',
        'E': 'GLU',
        'C': 'CYS'
    }
    aa_label_legend_elements = _legend_handles(aa_label_sorted_set, aa_label_idx_dict, COLORS_AA_LABEL,
                                               aa_label_name_dict, markersize=9.0)

    # secondary structure colors; field 5 of the data id is used as the label
    sec_struct_N = [data_id.split('_')[5] for data_id in data_ids_N]
    sec_struct_sorted_set, sec_struct_idx_dict, sec_struct_colors_N = _categorical_colors(sec_struct_N, COLORS_SEC_STRUCT)
    sec_struct_name_dict = {
        'H': '$\\alpha$-helix',
        'E': '$\\beta$-sheet',
        'L': 'loop'
    }
    sec_struct_legend_elements = _legend_handles(sec_struct_sorted_set, sec_struct_idx_dict, COLORS_SEC_STRUCT,
                                                 sec_struct_name_dict, markersize=12.0)

    # total power of original data
    orig_powers_N = _total_power(images_NF, _feature_degrees(data_irreps))

    # number of atoms; neighborhoods with zero atoms are excluded from every plot
    num_atoms_dict = statistics_dict['num_atoms']
    num_atoms_N = [num_atoms_dict[data_id] for data_id in data_ids_N]
    NONZERO_ATOMS_MASK_N = np.array(num_atoms_N) != 0
    num_excluded = int(np.sum(np.logical_not(NONZERO_ATOMS_MASK_N)))
    print('Excluding %d/%d (%.5f%s) neighborhoods because they give zero number of atoms.' % (num_excluded, len(num_atoms_N), _percent(num_excluded, len(num_atoms_N)), PERC))

    # average sasa
    avg_sasa_dict = statistics_dict['avg_sasa']
    avg_sasa_N = [avg_sasa_dict[data_id] for data_id in data_ids_N]

    # percentage of atoms of a given type; NaN for zero-atom neighborhoods
    # (which are masked out of the plots) instead of dividing by zero
    num_atoms_of_type_dict = statistics_dict['num_atoms_of_type']
    prop_atoms_of_type_N = {
        atom_type: [_percent(counts[data_id], num_atoms_dict[data_id]) for data_id in data_ids_N]
        for atom_type, counts in num_atoms_of_type_dict.items()
    }

    x_N = lower_dim_invariants_N2[:, 0][NONZERO_ATOMS_MASK_N]
    y_N = lower_dim_invariants_N2[:, 1][NONZERO_ATOMS_MASK_N]

    # individual plots with categorical labels
    categorical_specs = [
        (aa_label_colors_N, aa_label_legend_elements, 'residue_type'),
        (sec_struct_colors_N, sec_struct_legend_elements, 'secondary_structure'),
    ]
    for colors, legend_handles, id_for_filename in categorical_specs:
        fig = plt.figure(figsize=(WIDTH, HEIGHT))
        plt.scatter(x_N, y_N, c=np.array(colors)[NONZERO_ATOMS_MASK_N], alpha=ALPHA, edgecolors='none', s=MARKER_SIZE)
        plt.axis('off')
        plt.tight_layout()
        _save_png_and_pdf(fig, os.path.join(viz_dir, '__pretty_umap__%s%s' % (id_for_filename, file_suffix)), bbox_inches='tight')
        plt.close(fig)

        # save legend separately
        fig_legend = pylab.figure(figsize = (6.0, 0.6))
        fig_legend.legend(handles=legend_handles, loc='center', ncol=len(legend_handles), prop={'size': 18})
        _save_png_and_pdf(fig_legend, os.path.join(viz_dir, '__legend__%s' % (id_for_filename)))
        plt.close(fig_legend)

    # individual plots with sequential/continuous labels:
    # (values, colorbar label, filename id, (vmin, vmax))
    continuous_specs = [
        (orig_powers_N, 'Total power', 'input_total_power', (None, None)),
        (num_atoms_N, 'Number of\natoms', 'num_atoms', (None, None)),
        (avg_sasa_N, 'Average\nSASA', 'avg_sasa', (None, None)),
        (prop_atoms_of_type_N[b' CA '], '{} of C$\\alpha$ atoms'.format(PERC), 'prop_of_CA', (16.01, 31.99)),
        (prop_atoms_of_type_N[b' N  '], '{} of N atoms'.format(PERC), 'prop_of_N', (16.01, 31.99)),
        (prop_atoms_of_type_N[b' C  '], '{} of C atoms'.format(PERC), 'prop_of_C', (16.01, 31.99)),
        (prop_atoms_of_type_N[b' O  '], '{} of O atoms'.format(PERC), 'prop_of_O', (16.01, 31.99)),
    ]
    for values, colorbar_label, id_for_filename, (vmin, vmax) in continuous_specs:
        fig = plt.figure(figsize=(WIDTH, HEIGHT))
        mpb = plt.scatter(x_N, y_N, c=np.array(values)[NONZERO_ATOMS_MASK_N], cmap='viridis', vmin=vmin, vmax=vmax, alpha=ALPHA, edgecolors='none', s=MARKER_SIZE)
        plt.axis('off')
        plt.tight_layout()
        _save_png_and_pdf(fig, os.path.join(viz_dir, '__pretty_umap__%s%s' % (id_for_filename, file_suffix)), bbox_inches='tight')

        # save colorbar separately: draw it next to a throwaway axis, then drop the axis
        fig_colorbar, ax = plt.subplots(figsize=(WIDTH, HEIGHT))
        clb = plt.colorbar(mpb, ax=ax)
        clb.ax.set_title(colorbar_label, fontsize=16)
        clb.ax.tick_params(labelsize=14)
        ax.remove()
        plt.tight_layout()
        _save_png_and_pdf(fig_colorbar, os.path.join(viz_dir, '__colorbar__%s%s' % (id_for_filename, file_suffix)), bbox_inches='tight')

        # close only after the colorbar was built from the scatter's mappable
        plt.close(fig_colorbar)
        plt.close(fig)