
import os, sys
import gzip
import pickle
import argparse
import json
from tqdm import tqdm

import numpy as np
import torch
from torch.utils.data import DataLoader
import e3nn
from e3nn import o3
import matplotlib.pyplot as plt

from scipy.stats import spearmanr

sys.path.append('..')

from utils import get_wigner_D_from_rot_matrix, rotate_signal
from utils.argparse_utils import *
from utils.protein import *

from loss_functions import *

import umap

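'''
Given a trained model's UMAP embedding of the invariants, compare datapoints (and their reconstructions) that share the same label but fall in different UMAP clusters.
The goal is to pinpoint which aspects of the data the model has picked up on.

For each cluster bounding box passed via --umap_bounds, print the indices of a random sample of datapoints with the requested --label that lie inside that box.

Residue-name -> label-index mapping, for reference only. The script looks labels up in `aa_to_ind` at runtime. That name is not defined in this file and presumably comes from one of the wildcard imports above.

aa_to_ind = {'CYS': 2, 'ILE': 8, 'GLN': 12, 'VAL': 6, 'LYS': 13,
             'PRO': 4, 'GLY': 0, 'THR': 5, 'PHE': 16, 'GLU': 14,
             'HIS': 15, 'MET': 11, 'ASP': 7, 'LEU': 9, 'ARG': 17,
             'TRP': 19, 'ALA': 1, 'ASN': 10, 'TYR': 18, 'SER': 3}

'''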


# Map --model_type to the suffix used in the inference / UMAP result filenames.
# NOTE(review): 'best_04' maps to '' (the same as 'best'), whereas 'best_05' and
# 'best_06' get '-best_05' / '-best_06'. This is kept as-is so existing result
# files still resolve; confirm whether '-best_04' was intended.
MODEL_TYPE_SUFFIXES = {
    'best': '',
    'best_04': '',
    'best_05': '-best_05',
    'best_06': '-best_06',
    'best_higher_kld': '-best_model_higher_kld',
    'lowest_rec_loss': '-lowest_rec_loss',
    'final': '-final_model',
    'no_training': '-no_training',
    'lowest_total_loss_with_final_kl_model': '-lowest_total_loss_with_final_kl_model',
}


def parse_umap_bounds(spec):
    """Parse a ``--umap_bounds`` string into per-cluster bounding boxes.

    Syntax: ``[x_low,x_high]_[y_low,y_high]``, with multiple clusters
    separated by ``?``, e.g. ``[-9.0,-5.0]_[4.0,8.0]?[-6.0,-3.0]_[14.0,17.0]``.

    Args:
        spec: the raw command-line string.

    Returns:
        A list with one entry per cluster. Each entry is a list of two
        ``(low, high)`` float tuples: the x-interval, then the y-interval.

    Raises:
        ValueError: if a value is not a number, or a cluster is not made of
            exactly two intervals of exactly two numbers each.
    """
    clusters = []
    for cluster_spec in spec.split('?'):
        intervals = [
            tuple(float(v) for v in interval.strip().strip('[').strip(']').split(','))
            for interval in cluster_spec.split('_')
        ]
        if len(intervals) != 2 or any(len(iv) != 2 for iv in intervals):
            raise ValueError(
                'malformed cluster bounds %r; expected [x_low,x_high]_[y_low,y_high]' % cluster_spec)
        clusters.append(intervals)
    return clusters


def select_cluster_indices(points_N2, labels_N, label_id, x_bounds, y_bounds):
    """Return indices of points strictly inside a 2-D box that carry ``label_id``.

    Args:
        points_N2: array indexed as ``[:, 0]`` (x) and ``[:, 1]`` (y).
        labels_N: per-point labels, aligned with ``points_N2``.
        label_id: the label value to keep.
        x_bounds, y_bounds: ``(low, high)`` pairs; both bounds are exclusive.

    Returns:
        1-D integer array of matching row indices, in increasing order.
    """
    x, y = points_N2[:, 0], points_N2[:, 1]
    (x_low, x_high), (y_low, y_high) = x_bounds, y_bounds
    in_box = (x > x_low) & (x < x_high) & (y > y_low) & (y < y_high)
    return np.flatnonzero(in_box & (labels_N == label_id))


def sample_indices(idxs, n_samples, seed):
    """Reproducibly draw up to ``n_samples`` *distinct* entries of ``idxs``.

    Sampling is without replacement, so no index is reported twice. When
    fewer than ``n_samples`` candidates exist, all of them are returned in a
    shuffled order. An empty ``idxs`` yields an empty array instead of the
    ValueError ``Generator.choice`` would raise.
    """
    idxs = np.asarray(idxs)
    n = min(n_samples, idxs.shape[0])
    if n <= 0:
        return idxs[:0]
    rng = np.random.default_rng(seed)
    return rng.choice(idxs, size=n, replace=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, default='../runs/toy_aminoacids/local_equiv_fibers')
    # e.g. for syntax PHE = [-9.0,-5.0]_[4.0,8.0]?[-6.0,-3.0]_[14.0,17.0] ; TYR = [-7.5,-5.0]_[-1.0,2.5]?[-4.0,-1.0]_[15.0,18.0]
    parser.add_argument('--umap_bounds', type=str, required=True,
                        help='cluster boxes as [x_low,x_high]_[y_low,y_high], clusters separated by "?"')
    parser.add_argument('--label', type=str, required=True)
    parser.add_argument('--split', type=str, default='test')
    parser.add_argument('--model_type', type=str, default='lowest_total_loss_with_final_kl_model',
                        choices=sorted(MODEL_TYPE_SUFFIXES))
    parser.add_argument('--hash', type=str, required=True)
    parser.add_argument('--n_samples', type=int, default=10)
    parser.add_argument('--seed', type=int, default=1000005)

    args = parser.parse_args()

    # NOTE(review): aa_to_ind is not defined in this file; it presumably comes
    # from one of the wildcard imports (e.g. utils.protein) — confirm.
    if args.label not in aa_to_ind:
        parser.error('unknown --label %r; expected one of %s' % (args.label, sorted(aa_to_ind)))
    label_id = aa_to_ind[args.label]

    # Parse the bounds before touching the result files so a typo fails fast.
    try:
        clusters = parse_umap_bounds(args.umap_bounds)
    except ValueError as err:
        parser.error(str(err))
    print(clusters)

    model_type_str = MODEL_TYPE_SUFFIXES[args.model_type]
    results_dir = os.path.join(args.model_dir, args.hash, 'results_arrays')

    # Only the labels are needed to pick indices. The archive is also expected
    # to contain invariants_ND, learned_frames_N9, rotations_N9, images_NF and
    # rec_images_NF for follow-up inspection.
    with np.load(os.path.join(results_dir, 'inference%s-split=%s.npz' % (model_type_str, args.split))) as arrays:
        labels_N = arrays['labels_N']
    N = labels_N.shape[0]

    # 2-D UMAP embedding of the invariants, one row per datapoint.
    lower_dim_invariants_N2 = np.load(os.path.join(results_dir, 'umap_invariants%s-split=%s.npy' % (model_type_str, args.split)))
    if lower_dim_invariants_N2.shape[0] != N:
        raise ValueError('UMAP embedding has %d rows but there are %d labels; were they computed on the same split?'
                         % (lower_dim_invariants_N2.shape[0], N))

    # Report, per cluster box, a reproducible sample of same-label datapoints.
    for i, (x_bounds, y_bounds) in enumerate(clusters, start=1):
        idxs_in_cluster = select_cluster_indices(lower_dim_invariants_N2, labels_N, label_id, x_bounds, y_bounds)
        if idxs_in_cluster.size < args.n_samples:
            print('Warning: cluster #%d contains only %d point(s) with label %s.'
                  % (i, idxs_in_cluster.size, args.label), file=sys.stderr)
        idxs_to_show = sample_indices(idxs_in_cluster, args.n_samples, args.seed)
        # tolist() prints plain ints (NumPy >= 2 would otherwise show np.int64(...)).
        print('Cluster #{}: {}'.format(i, idxs_to_show.tolist()))


