
import os, sys
import gzip, pickle
import json
from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sklearn
import torch
from torch.utils.data import DataLoader
import e3nn
from e3nn import o3
import argparse

sys.path.append('..')
from utils import NeighborhoodsDatasetWithConditioningAndIds__fibers
from projections import ZernickeRadialFunctions, RadialSphericalTensor, MultiChannelRadialSphericalTensor
from cgnet_fibers import ClebschGordanVAE_symmetric_simple_flexible

DATA_ARGS = ['rmax', 'lmax', 'n_channels', 'rcut', 'rst_normalization', 'get_H', 'get_SASA', 'get_charge']

def dict_to_device(adict, device):
    """Cast every value of `adict` to float and move it to `device`.

    The dictionary is updated in place (one entry at a time) and the same
    object is returned for convenience.
    """
    for name, value in list(adict.items()):
        adict[name] = value.float().to(device)
    return adict

def make_vector(x):
    """Flatten a dict of tensors into a single 2D tensor.

    Entries are visited in sorted key order; each one is reshaped to
    (size of its first dimension, -1) and the results are concatenated
    along the last dimension.
    """
    flattened = [x[key].reshape((x[key].shape[0], -1)) for key in sorted(x.keys())]
    return torch.cat(flattened, dim=-1)

def latent_space_prediction(train_invariants_ND, train_labels_N, eval_invariants_ND, eval_labels_N, valid_invariants_ND=None, valid_labels_N=None, classifier='LC', optimize_hyps=False):
    """Fit a classifier on latent invariants and report its performance on held-out data.

    Parameters
    ----------
    train_invariants_ND, train_labels_N:
        Features and labels used to fit the classifier.
    eval_invariants_ND, eval_labels_N:
        Features and labels the fitted classifier is scored on.
    valid_invariants_ND, valid_labels_N:
        Optional validation data. Only forwarded to the 'LC' classifier's
        `fit` (as `x_valid_MF` / `y_valid_M`); ignored for 'KNN'.
    classifier:
        'LC' (project `MultiClassLinearClassifier`) or 'KNN'
        (scikit-learn `KNeighborsClassifier`).
    optimize_hyps:
        If True, wrap the estimator in a `GridSearchCV` over a small
        hard-coded hyperparameter grid.

    Returns
    -------
    dict
        `sklearn.metrics.classification_report(..., output_dict=True)` computed
        from the argmax of the predicted class probabilities.

    Raises
    ------
    NotImplementedError
        If `classifier` is neither 'LC' nor 'KNN'.
    """
    # Reject unsupported names up front, before any import or fitting work is done.
    if classifier not in ('LC', 'KNN'):
        raise NotImplementedError("Unsupported classifier %r; expected 'LC' or 'KNN'." % (classifier,))

    from sklearn.metrics import classification_report

    # Import each estimator lazily so that the KNN path does not depend on the
    # project-local `classifiers` module (and vice versa).
    if classifier == 'LC':
        from classifiers import MultiClassLinearClassifier

        n_features = train_invariants_ND.shape[1]
        n_classes = len(set(train_labels_N))
        estimator = MultiClassLinearClassifier(n_features, n_classes, batch_size=512, verbose=True)
        hyperparams = {'lr': [0.1, 0.01, 0.001]}
    else:
        from sklearn.neighbors import KNeighborsClassifier

        estimator = KNeighborsClassifier()
        hyperparams = {'n_neighbors': [5, 10, 20]}

    if optimize_hyps:
        from sklearn.model_selection import GridSearchCV

        model = GridSearchCV(estimator, hyperparams)
    else:
        model = estimator

    if classifier == 'KNN':
        model = model.fit(train_invariants_ND, train_labels_N)
    else:
        model = model.fit(train_invariants_ND, train_labels_N, x_valid_MF=valid_invariants_ND, y_valid_M=valid_labels_N)

    # The predicted class is taken to be the column index of the highest probability,
    # which matches the labels when they are the contiguous indices 0..n_classes-1.
    predictions = model.predict_proba(eval_invariants_ND)
    onehot_predictions = np.argmax(predictions, axis=1)

    return classification_report(eval_labels_N, onehot_predictions, output_dict=True)


if __name__ == '__main__':
    # Maps each supported --model_type to
    # (suffix used in the names of result files, checkpoint file name inside --experiment_dir).
    # 'no_training' has no checkpoint: the model is used with its initial weights.
    # NOTE(review): 'best' and 'best_04' share the empty suffix, so their cached
    # inference arrays and reports use the same file names - confirm this is intended.
    MODEL_TYPES = {
        'best': ('', 'best_model.pt'),
        'best_04': ('', 'best_model_04.pt'),
        'best_05': ('-best_05', 'best_model_05.pt'),
        'best_06': ('-best_06', 'best_model_06.pt'),
        'best_higher_kld': ('-best_model_higher_kld', 'best_model_higher_kld.pt'),
        'lowest_rec_loss': ('-lowest_rec_loss', 'lowest_rec_loss_model.pt'),
        'final': ('-final_model', 'final_model.pt'),
        'no_training': ('-no_training', None),
        'lowest_total_loss_with_final_kl_model': ('-lowest_total_loss_with_final_kl_model', 'lowest_total_loss_with_final_kl_model.pt'),
    }

    # Maps each supported --classifier to the prefix used in report file names.
    CLASSIFIER_PREFIXES = {'LC': '', 'KNN': 'KNN_'}

    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='../data/neighborhoods/data')
    parser.add_argument('--w3j_filepath', type=str, default='../cg_coefficients/w3j_matrices-lmax=14-version=0.5.0.pkl')
    parser.add_argument('--experiment_dir', type=str, required=True) # '../runs/neighborhoods/local_equiv_fibers' + hash
    # `choices` makes an unsupported value fail immediately with a usage error
    # instead of surfacing later as a NameError on an unset variable.
    parser.add_argument('--model_type', type=str, default='lowest_total_loss_with_final_kl_model', choices=list(MODEL_TYPES))
    parser.add_argument('--classifier', type=str, default='KNN', choices=list(CLASSIFIER_PREFIXES))
    parser.add_argument('--seed', type=int, default=100000000)

    args = parser.parse_args()

    model_type_str, model_name = MODEL_TYPES[args.model_type]
    classifier_str = CLASSIFIER_PREFIXES[args.classifier]
    ## do inference and save (1) invariants (2) data_ids (which contain aa type and secondary structure)

    def results_path(split):
        """Return the path of the cached inference arrays (.npz) for one split."""
        return os.path.join(args.experiment_dir, 'results_arrays/latent_space_only_inference%s-split=%s.npz' % (model_type_str, split))

    def load_results_arrays():
        """Open the cached inference arrays of the train, valid and test splits (in that order)."""
        return tuple(np.load(results_path(split)) for split in ('train', 'valid', 'test'))

    try: # load arrays if they've already been computed
        train_arrays, valid_arrays, test_arrays = load_results_arrays()
    except FileNotFoundError:
        # At least one cached file is missing: run the model on all splits and (re)write the cache.
        # Only a missing file triggers this; any other error (or Ctrl-C) is propagated.

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print('Running on %s.' % device)

        with open(os.path.join(args.experiment_dir, 'hparams.json'), 'r') as f:
            hparams = json.load(f)

        # data preparation and loading stuff
        rng = torch.Generator().manual_seed(hparams['seed'])

        # The last three DATA_ARGS entries (get_H, get_SASA, get_charge) are not used in the
        # data identifiers built here. A local name is used so the module-level constant is
        # left untouched.
        data_args = DATA_ARGS[:-3]

        def build_data_id(n_neigh):
            """Return the sorted, '-'-joined 'name=value' identifier used in the data file names."""
            fields = ['%s=%s' % (arg, hparams[arg]) for arg in data_args]
            fields.append('n_neigh=%s' % (n_neigh))
            return '-'.join(sorted(fields))

        train_data_id = build_data_id(hparams['n_train_neigh'])
        valid_data_id = build_data_id(hparams['n_valid_neigh'])
        test_data_id = build_data_id(hparams['n_test_neigh'])

        normalize_str = '-normalize=%s' % (hparams['normalize']) if hparams['normalize'] is not None else ''

        def load_split(kind, split, data_id, suffix=''):
            """Load '<data_dir>/<kind>-<split><suffix>-complex_sph=False-<data_id>.npy'."""
            return np.load('%s/%s-%s%s-complex_sph=False-%s.npy' % (args.data_dir, kind, split, suffix, data_id))

        # note: the validation split is called 'val' in the data file names
        train_data = torch.tensor(load_split('projections', 'train', train_data_id, normalize_str))
        valid_data = torch.tensor(load_split('projections', 'val', valid_data_id, normalize_str))
        test_data = torch.tensor(load_split('projections', 'test', test_data_id, normalize_str))

        train_labels = torch.tensor(load_split('aa_labels', 'train', train_data_id))
        valid_labels = torch.tensor(load_split('aa_labels', 'val', valid_data_id))
        test_labels = torch.tensor(load_split('aa_labels', 'test', test_data_id))

        try:
            train_frames = torch.tensor(load_split('frames', 'train', train_data_id)).view(-1, 3, 3)
            valid_frames = torch.tensor(load_split('frames', 'val', valid_data_id)).view(-1, 3, 3)
            test_frames = torch.tensor(load_split('frames', 'test', test_data_id)).view(-1, 3, 3)
        except FileNotFoundError:
            # dummy frames for compatibility (used when the frames files do not exist)
            train_frames = torch.zeros((train_labels.shape[0], 3, 3))
            valid_frames = torch.zeros((valid_labels.shape[0], 3, 3))
            test_frames = torch.zeros((test_labels.shape[0], 3, 3))

        def stringify(data_id):
            """Join the utf-8 decoded fields of one data id with '_'."""
            return '_'.join(field.decode('utf-8') for field in data_id)

        train_ids = np.array([stringify(data_id) for data_id in load_split('data_ids', 'train', train_data_id)])
        valid_ids = np.array([stringify(data_id) for data_id in load_split('data_ids', 'val', valid_data_id)])
        test_ids = np.array([stringify(data_id) for data_id in load_split('data_ids', 'test', test_data_id)])


        # filter by desired lmax and channels
        OnRadialFunctions = ZernickeRadialFunctions(hparams['rcut'], hparams['rmax']+1, hparams['lmax'], complex_sph = False)
        rst = RadialSphericalTensor(hparams['rmax']+1, OnRadialFunctions, hparams['lmax'], 1, 1)
        mul_rst = MultiChannelRadialSphericalTensor(rst, hparams['n_channels'])
        data_irreps = o3.Irreps(str(mul_rst))

        print('Data irreps: {}'.format(data_irreps), file=sys.stderr)
        print('Data dim: {}'.format(data_irreps.dim), file=sys.stderr)

        train_dataset = NeighborhoodsDatasetWithConditioningAndIds__fibers(train_data, data_irreps, train_labels, train_frames, train_ids)
        valid_dataset = NeighborhoodsDatasetWithConditioningAndIds__fibers(valid_data, data_irreps, valid_labels, valid_frames, valid_ids)
        test_dataset = NeighborhoodsDatasetWithConditioningAndIds__fibers(test_data, data_irreps, test_labels, test_frames, test_ids)

        print('%d training neighborhoods' % (len(train_dataset)), file=sys.stderr)
        print('%d validation neighborhoods' % (len(valid_dataset)), file=sys.stderr)

        # NOTE(review): drop_last=True discards the final incomplete batch of every split, so
        # up to batch_size-1 neighborhoods per split are missing from the saved arrays - confirm
        # this is intended for evaluation.
        dataloaders = {
            'train': DataLoader(train_dataset, batch_size=hparams['batch_size'], generator=rng, shuffle=True, drop_last=True),
            'valid': DataLoader(valid_dataset, batch_size=hparams['batch_size'], generator=rng, shuffle=False, drop_last=True),
            'test': DataLoader(test_dataset, batch_size=100, generator=rng, shuffle=False, drop_last=True)
        }

        # get w3j matrices
        # NOTE(review): pickle.load can execute arbitrary code; only use trusted --w3j_filepath files.
        with gzip.open(args.w3j_filepath, 'rb') as f:
            w3j_matrices = pickle.load(f)

        # only the matrices whose three indices are within net_lmax are converted to tensors
        for key in w3j_matrices:
            if key[0] <= hparams['net_lmax'] and key[1] <= hparams['net_lmax'] and key[2] <= hparams['net_lmax']:
                w3j_matrices[key] = torch.tensor(w3j_matrices[key]).float().to(device)
                w3j_matrices[key].requires_grad = False

        if hparams['model'] != 'cgvae_symmetric_simple_flexible':
            raise NotImplementedError()

        for key, value in hparams.items():
            print('{}\t{}'.format(key, value))
        print()

        # hparams files without this entry fall back to False
        hparams.setdefault('norm_balanced', False)

        cgvae = ClebschGordanVAE_symmetric_simple_flexible(
            data_irreps,
            hparams['latent_dim'],
            hparams['net_lmax'],
            hparams['n_cg_blocks'],
            list(map(int, hparams['ch_size_list'].split(','))),
            hparams['ls_nonlin_rule_list'].split(','),
            hparams['ch_nonlin_rule_list'].split(','),
            hparams['do_initial_linear_projection'],
            hparams['ch_initial_linear_projection'],
            w3j_matrices,
            device,
            lmax_list=list(map(int, hparams['lmax_list'].split(','))),
            use_additive_skip_connections=hparams['use_additive_skip_connections'],
            use_batch_norm=hparams['use_batch_norm'],
            norm_type=hparams['norm_type'], # None, layer, signal
            normalization=hparams['normalization'], # norm, component -> only considered if norm_type is not none
            norm_balanced=hparams['norm_balanced'],
            norm_affine=hparams['norm_affine'], # None, {True, False} -> for layer_norm, {unique, per_l, per_feature} -> for signal_norm
            norm_nonlinearity=hparams['norm_nonlinearity'], # None (identity), identity, relu, swish, sigmoid -> only for layer_norm
            norm_location=hparams['norm_location'], # first, between, last
            linearity_first=hparams['linearity_first'], # currently only works with this being false
            filter_symmetric=hparams['filter_symmetric'], # whether to exclude duplicate pairs of l's from the tensor product nonlinearity
            x_rec_loss_fn=hparams['x_rec_loss_fn'], # mse, mse_normalized, cosine
            do_final_signal_norm=hparams['do_final_signal_norm'],
            learn_frame=hparams['learn_frame'],
            is_vae=hparams['is_vae']).to(device)

        if args.model_type != 'no_training':
            cgvae.load_state_dict(torch.load(os.path.join(args.experiment_dir, model_name), map_location=torch.device(device)))
        else:
            print('Using untrained model.', file=sys.stderr)

        cgvae = cgvae.float()
        cgvae.eval()

        os.makedirs(os.path.join(args.experiment_dir, 'results_arrays'), exist_ok=True)

        # Inference only: no gradients are needed, so do not build the autograd graph.
        with torch.no_grad():
            for split in ('train', 'valid', 'test'):
                invariants, data_ids = [], []
                for X, _X_rec, _y, _rot, data_id in tqdm(dataloaders[split]):
                    X = dict_to_device(X, device)

                    (z_mean, z_log_var), scalar_features, learned_frame = cgvae.encode(X)

                    # the posterior mean is used as the latent representation (no sampling)
                    invariants.append(z_mean.detach().cpu().numpy())
                    data_ids.append(data_id)

                invariants_ND = np.vstack(invariants)
                data_ids_N = np.hstack(data_ids) # hstack because we assume the data_ids are represented in their stringified version. otherwise we would have to use vstack

                np.savez(results_path(split),
                         invariants_ND = invariants_ND,
                         data_ids_N = data_ids_N)

        # reload the arrays for compatibility with the case in which they have already been computed (the "try" part of this "try-except" block)
        train_arrays, valid_arrays, test_arrays = load_results_arrays()



    ## use same code as in the evaluation script to do classification

    def make_array_idx(items_N):
        """Encode items as integer indices.

        Each distinct item is assigned its rank in the sorted set of distinct
        items, so two inputs containing the same set of distinct items get
        the same mapping.

        Returns the list of indices (one per input item, in input order) and
        the item -> index dictionary.
        """
        items_idx_dict = {item: rank for rank, item in enumerate(sorted(set(items_N)))}
        items_idxs_N = [items_idx_dict[item] for item in items_N]
        return items_idxs_N, items_idx_dict

    from sklearn.cluster import KMeans
    from sklearn.metrics import homogeneity_score, completeness_score, silhouette_score

    def purity_score(labels_true, labels_pred):
        """Return the purity of a clustering with respect to the true labels.

        Computed from the contingency matrix of `labels_true` vs `labels_pred`:
        the sum of the per-column maxima divided by the total count.
        """
        counts = sklearn.metrics.cluster.contingency_matrix(labels_true, labels_pred)
        largest_per_column = np.amax(counts, axis=0)
        return np.sum(largest_per_column) / np.sum(counts)


    # residue type
    # (taken from the first '_'-separated field of each data id)
    train_aa_labels_N = np.array([data_id.split('_')[0] for data_id in train_arrays['data_ids_N']])
    valid_aa_labels_N = np.array([data_id.split('_')[0] for data_id in valid_arrays['data_ids_N']])
    test_aa_labels_N = np.array([data_id.split('_')[0] for data_id in test_arrays['data_ids_N']])
    train_aa_labels_idxs_N, train_aa_labels_idx_dict = make_array_idx(train_aa_labels_N)

    # Use ONE label -> index mapping for all splits. Indexing every split independently
    # gives inconsistent indices as soon as a split is missing one of the classes, which
    # silently corrupts the classification report.
    # The mapping starts from the training labels (so training indices stay contiguous,
    # 0..n_classes-1); labels that only occur in valid/test are appended after them.
    aa_labels_idx_dict = dict(train_aa_labels_idx_dict)
    for label in sorted(set(valid_aa_labels_N) | set(test_aa_labels_N)):
        aa_labels_idx_dict.setdefault(label, len(aa_labels_idx_dict))
    valid_aa_labels_idx_dict = test_aa_labels_idx_dict = aa_labels_idx_dict

    valid_aa_labels_idxs_N = [aa_labels_idx_dict[label] for label in valid_aa_labels_N]
    test_aa_labels_idxs_N = [aa_labels_idx_dict[label] for label in test_aa_labels_N]
    
    # print('Performing residue type classification on validation data...')
    # report = latent_space_prediction(train_arrays['invariants_ND'], train_aa_labels_idxs_N, valid_arrays['invariants_ND'], valid_aa_labels_idxs_N, classifier=args.classifier, optimize_hyps=False)
    # pd.DataFrame(report).to_csv(os.path.join(args.experiment_dir, 'latent_space_classification/__%sclassificaton_on_latent_space_aa_labels%s-split=valid.csv' % (classifier_str, model_type_str)))

    # n_clusters = len(list(set(list(valid_aa_labels_N))))
    # kmeans = KMeans(n_clusters=n_clusters, random_state=args.seed, verbose=0)
    # labels_pred_N = kmeans.fit_predict(valid_arrays['invariants_ND'])

    # homogeneity = homogeneity_score(valid_aa_labels_idxs_N, labels_pred_N)
    # completeness = completeness_score(valid_aa_labels_idxs_N, labels_pred_N)
    # purity = purity_score(valid_aa_labels_idxs_N, labels_pred_N)
    # silhouette = silhouette_score(valid_arrays['invariants_ND'], labels_pred_N)

    # table = {
    #     'Homogeneity': [homogeneity],
    #     'Completeness': [completeness],
    #     'Purity': [purity],
    #     'Silhouette': [silhouette]
    # }

    # pd.DataFrame(table).to_csv(os.path.join(args.experiment_dir, 'latent_space_classification/__quality_of_clustering_metrics_aa_labels%s-split=valid.csv' % (model_type_str)), index=None)


    print('Performing residue type classification on testing data...')
    # DataFrame.to_csv does not create missing directories: make sure the output
    # directory exists before spending time on fitting the classifier.
    os.makedirs(os.path.join(args.experiment_dir, 'latent_space_classification'), exist_ok=True)
    report = latent_space_prediction(train_arrays['invariants_ND'], train_aa_labels_idxs_N, test_arrays['invariants_ND'], test_aa_labels_idxs_N, valid_invariants_ND=valid_arrays['invariants_ND'], valid_labels_N=valid_aa_labels_idxs_N, classifier=args.classifier, optimize_hyps=False)
    pd.DataFrame(report).to_csv(os.path.join(args.experiment_dir, 'latent_space_classification/__%sclassificaton_on_latent_space_aa_labels%s-split=test.csv' % (classifier_str, model_type_str)))

    # n_clusters = len(list(set(list(test_aa_labels_N))))
    # kmeans = KMeans(n_clusters=n_clusters, random_state=args.seed, verbose=0)
    # labels_pred_N = kmeans.fit_predict(test_arrays['invariants_ND'])

    # homogeneity = homogeneity_score(test_aa_labels_idxs_N, labels_pred_N)
    # completeness = completeness_score(test_aa_labels_idxs_N, labels_pred_N)
    # purity = purity_score(test_aa_labels_idxs_N, labels_pred_N)
    # silhouette = silhouette_score(test_arrays['invariants_ND'], labels_pred_N)

    # table = {
    #     'Homogeneity': [homogeneity],
    #     'Completeness': [completeness],
    #     'Purity': [purity],
    #     'Silhouette': [silhouette]
    # }

    # pd.DataFrame(table).to_csv(os.path.join(args.experiment_dir, 'latent_space_classification/__quality_of_clustering_metrics_aa_labels%s-split=test.csv' % (model_type_str)), index=None)


    # secondary structure
    # (taken from the sixth '_'-separated field of each data id)
    train_sec_struc_N = np.array([data_id.split('_')[5] for data_id in train_arrays['data_ids_N']])
    valid_sec_struc_N = np.array([data_id.split('_')[5] for data_id in valid_arrays['data_ids_N']])
    test_sec_struc_N = np.array([data_id.split('_')[5] for data_id in test_arrays['data_ids_N']])
    train_sec_struc_idxs_N, train_sec_struc_idx_dict = make_array_idx(train_sec_struc_N)

    # Use ONE label -> index mapping for all splits. Indexing every split independently
    # gives inconsistent indices as soon as a split is missing one of the classes, which
    # silently corrupts the classification report.
    # The mapping starts from the training labels (so training indices stay contiguous,
    # 0..n_classes-1); labels that only occur in valid/test are appended after them.
    sec_struc_idx_dict = dict(train_sec_struc_idx_dict)
    for label in sorted(set(valid_sec_struc_N) | set(test_sec_struc_N)):
        sec_struc_idx_dict.setdefault(label, len(sec_struc_idx_dict))
    valid_sec_struc_idx_dict = test_sec_struc_idx_dict = sec_struc_idx_dict

    valid_sec_struc_idxs_N = [sec_struc_idx_dict[label] for label in valid_sec_struc_N]
    test_sec_struc_idxs_N = [sec_struc_idx_dict[label] for label in test_sec_struc_N]
    
    # print('Performing secondary structure classification on validation data...')
    # report = latent_space_prediction(train_arrays['invariants_ND'], train_sec_struc_idxs_N, valid_arrays['invariants_ND'], valid_sec_struc_idxs_N, classifier=args.classifier, optimize_hyps=False)
    # pd.DataFrame(report).to_csv(os.path.join(args.experiment_dir, 'latent_space_classification/__%sclassificaton_on_latent_space_sec_struc%s-split=valid.csv' % (classifier_str, model_type_str)))

    # n_clusters = len(list(set(list(valid_sec_struc_N))))
    # kmeans = KMeans(n_clusters=n_clusters, random_state=args.seed, verbose=0)
    # labels_pred_N = kmeans.fit_predict(valid_arrays['invariants_ND'])

    # homogeneity = homogeneity_score(valid_sec_struc_idxs_N, labels_pred_N)
    # completeness = completeness_score(valid_sec_struc_idxs_N, labels_pred_N)
    # purity = purity_score(valid_sec_struc_idxs_N, labels_pred_N)
    # silhouette = silhouette_score(valid_arrays['invariants_ND'], labels_pred_N)

    # table = {
    #     'Homogeneity': [homogeneity],
    #     'Completeness': [completeness],
    #     'Purity': [purity],
    #     'Silhouette': [silhouette]
    # }

    # pd.DataFrame(table).to_csv(os.path.join(args.experiment_dir, 'latent_space_classification/quality_of_clustering_metrics_sec_struc%s-split=valid.csv' % (model_type_str)), index=None)


    print('Performing secondary structure classification on testing data...')
    # DataFrame.to_csv does not create missing directories: make sure the output
    # directory exists before spending time on fitting the classifier.
    os.makedirs(os.path.join(args.experiment_dir, 'latent_space_classification'), exist_ok=True)
    report = latent_space_prediction(train_arrays['invariants_ND'], train_sec_struc_idxs_N, test_arrays['invariants_ND'], test_sec_struc_idxs_N, valid_invariants_ND=valid_arrays['invariants_ND'], valid_labels_N=valid_sec_struc_idxs_N, classifier=args.classifier, optimize_hyps=False)
    pd.DataFrame(report).to_csv(os.path.join(args.experiment_dir, 'latent_space_classification/__%sclassificaton_on_latent_space_sec_struc%s-split=test.csv' % (classifier_str, model_type_str)))

    # n_clusters = len(list(set(list(test_sec_struc_N))))
    # kmeans = KMeans(n_clusters=n_clusters, random_state=args.seed, verbose=0)
    # labels_pred_N = kmeans.fit_predict(test_arrays['invariants_ND'])

    # homogeneity = homogeneity_score(test_sec_struc_idxs_N, labels_pred_N)
    # completeness = completeness_score(test_sec_struc_idxs_N, labels_pred_N)
    # purity = purity_score(test_sec_struc_idxs_N, labels_pred_N)
    # silhouette = silhouette_score(test_arrays['invariants_ND'], labels_pred_N)

    # table = {
    #     'Homogeneity': [homogeneity],
    #     'Completeness': [completeness],
    #     'Purity': [purity],
    #     'Silhouette': [silhouette]
    # }

    # pd.DataFrame(table).to_csv(os.path.join(args.experiment_dir, 'latent_space_classification/__quality_of_clustering_metrics_sec_struc%s-split=test.csv' % (model_type_str)), index=None)

