
import os, sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import argparse

sys.path.append('..')
from utils.argparse_utils import *


# Training-set size (number of residues) -> batch size used for that run.
# Every batch size is 1% of its data quantity; 0 is the untrained baseline.
DATA_QUANTITY_TO_BATCH_SIZE_DICT = {
    quantity: quantity // 100
    for quantity in (0, 400, 1000, 2000, 5000, 20000)
}

# Figure geometry: golden-ratio aspect ratio at a fixed height (inches).
_GOLDEN_RATIO = (1 + 5.0**0.5) / 2
_FIG_HEIGHT = 4
_FIG_WIDTH = _FIG_HEIGHT * _GOLDEN_RATIO
# Horizontal distance between the AE and VAE markers drawn at the same x tick.
_DODGE = 0.12

# NOTE(review): the VAE label states beta = 0.025, but only the z=2 VAE run
# names include '-kl_lambda=0.025'; confirm the label is right for other z.
_IS_VAE_LEGEND_NAME = {
    True: 'VAE ($\\beta = 0.025$)',
    False: 'AE',
}


def _parse_args():
    """Parse and validate the command-line options of this script."""
    parser = argparse.ArgumentParser(
        description='Plot latent-space classification accuracy and test losses of '
                    'AE and VAE runs against the amount of training data.')
    parser.add_argument('--model_dir', type=str, default='../runs/toy_aminoacids/local_equiv_fibers',
                        help='directory holding one sub-directory per trained run')
    # argparse passes string defaults through `type`, so this default becomes a list of ints.
    parser.add_argument('--data_quantities', type=comma_sep_int_list, default='0,400,1000,2000,5000,20000',
                        help='comma-separated training-set sizes; each must be a key of '
                             'DATA_QUANTITY_TO_BATCH_SIZE_DICT')
    parser.add_argument('--z', type=int, default=2,
                        help='value of z in the run directory names')
    parser.add_argument('--repetitions', type=int, default=3,
                        help='number of repeated runs (_v1, _v2, ...) per data quantity')
    parser.add_argument('--classifier', type=str, default='KNN', choices=['LC', 'KNN'],
                        help='latent-space classifier whose report is read')
    args = parser.parse_args()

    if args.repetitions < 1:
        parser.error('--repetitions must be at least 1')
    unknown = [q for q in args.data_quantities if q not in DATA_QUANTITY_TO_BATCH_SIZE_DICT]
    if unknown:
        parser.error('no batch size known for data quantities %s (known: %s)'
                     % (unknown, sorted(DATA_QUANTITY_TO_BATCH_SIZE_DICT)))
    return args


def _classifier_prefix(classifier):
    """Return the report-filename fragment identifying ``classifier`` ('LC' or 'KNN')."""
    return {'LC': '', 'KNN': 'KNN_'}[classifier]


def _model_name_template(is_vae, z):
    """Return ``(run-name template, model-type suffix)`` for the AE or the VAE runs.

    The template takes four ``%d`` substitutions: z, data quantity, batch size and
    the 1-based repetition index. The suffix is appended to the names of the
    output files that are read for each run.
    """
    if not is_vae:
        return 'cgvae_symm_simp_flex-AE-z=%d-x_lambda=400-data=%d-bs=%d_v%d', '-lowest_rec_loss'
    if z == 2:
        # Only the z=2 VAE run names carry the kl_lambda tag.
        name = 'cgvae_symm_simp_flex-VAE_min_loss_with_final_kl-z=%d-x_lambda=400-data=%d-bs=%d-kl_lambda=0.025_v%d'
    else:
        name = 'cgvae_symm_simp_flex-VAE_min_loss_with_final_kl-z=%d-x_lambda=400-data=%d-bs=%d_v%d'
    return name, '-lowest_total_loss_with_final_kl_model'


def _read_run_outputs(model_dir, model_hash, classifier_str, model_type_str):
    """Load one run's latent-space classification report and test-split loss table.

    Raises:
        FileNotFoundError: if either CSV file does not exist.
    """
    # NB: 'classificaton' (sic) is the filename actually written on disk.
    report = pd.read_csv(os.path.join(
        model_dir, model_hash, 'latent_space_classification',
        '__residue_type_cross_val_on_test_%sclassificaton_on_latent_space%s.csv' % (classifier_str, model_type_str)))
    losses_df = pd.read_csv(os.path.join(
        model_dir, model_hash, 'loss_distributions',
        'losses%s-split=test.csv' % (model_type_str)))
    return report, losses_df


def _collect_results(args, classifier_str):
    """Gather per-run accuracies and test losses of the AE and VAE runs.

    Returns:
        dict mapping ``is_vae`` (False, True) to a dict from metric name to a list
        with one entry per data quantity (in ``args.data_quantities`` order); each
        entry holds that quantity's values pooled over all repetitions.

    Raises:
        FileNotFoundError: if a run's outputs are missing (after the AE fallback).
        ValueError: if a classification report does not hold one repeated accuracy.
    """
    stuff_to_plot = {}
    for is_vae in (False, True):
        model_name, model_type_str = _model_name_template(is_vae, args.z)

        accuracies = []
        losses = []
        for data_quantity in args.data_quantities:
            batch_size = DATA_QUANTITY_TO_BATCH_SIZE_DICT[data_quantity]
            rep_accuracies = []
            rep_losses = []
            for rep in range(args.repetitions):
                model_hash = model_name % (args.z, data_quantity, batch_size, rep + 1)
                if data_quantity == 0:
                    # Untrained baseline: the kl_lambda tag is stripped from its run name.
                    curr_model_type_str = '-no_training'
                    model_hash = model_hash.replace('-kl_lambda=0.025', '')
                else:
                    curr_model_type_str = model_type_str

                try:
                    report, curr_losses_df = _read_run_outputs(
                        args.model_dir, model_hash, classifier_str, curr_model_type_str)
                except FileNotFoundError:
                    if curr_model_type_str != '-lowest_rec_loss':
                        raise
                    # AE runs lacking '-lowest_rec_loss' outputs fall back to the
                    # '-lowest_total_loss_with_final_kl_model' ones.
                    curr_model_type_str = '-lowest_total_loss_with_final_kl_model'
                    report, curr_losses_df = _read_run_outputs(
                        args.model_dir, model_hash, classifier_str, curr_model_type_str)
                    print('Warning, using model_type_str: %s' % (curr_model_type_str), file=sys.stderr)

                accuracy_values = report['accuracy'].values
                # Every row of the report must carry the same overall accuracy.
                if accuracy_values.size == 0 or not np.all(accuracy_values == accuracy_values[0]):
                    raise ValueError('Expected a single, repeated accuracy value in the '
                                     'classification report of run %s' % model_hash)
                rep_accuracies.append(accuracy_values[0])
                rep_losses.append(curr_losses_df.values)

            accuracies.append(rep_accuracies)
            losses.append(rep_losses)

        # Pool the repetitions' loss rows; columns 0, 1 and 2 of the losses CSV are
        # plotted as MSE, cosine loss and PearsonR of total power respectively.
        pooled_losses = [np.vstack(rep_losses) for rep_losses in losses]
        stuff_to_plot[is_vae] = {
            'Classification accuracy': accuracies,
            'MSE': [pooled[:, 0] for pooled in pooled_losses],
            'Cosine loss': [pooled[:, 1] for pooled in pooled_losses],
            'PearsonR of total power': [pooled[:, 2] for pooled in pooled_losses],
        }
    return stuff_to_plot


def _x_positions(data_quantities):
    """Return the x-axis slot of each data quantity, increasing with the quantity.

    A single ``argsort`` returns the inverse permutation, which differs from the
    ranks for most unsorted inputs; applying it twice yields each element's rank.
    """
    return np.argsort(np.argsort(data_quantities, kind='stable'), kind='stable')


def _plot_metric(metric, metric_i, stuff_to_plot, data_quantities, classifier_str, z):
    """Plot ``metric`` for the AE and VAE against data quantity; save as PNG and PDF.

    Every run value is drawn as a point, and a bar spans each quantity's min-max range.
    """
    colors = plt.get_cmap('tab20').colors
    xticks_locations = _x_positions(data_quantities)
    # Match metric names case-insensitively: they are capitalised ('Cosine loss').
    metric_lc = metric.lower()

    plt.figure(figsize=(_FIG_WIDTH, _FIG_HEIGHT))
    for is_vae_i, is_vae in enumerate((False, True)):
        # tab20 lists each hue as two adjacent shades, so AE/VAE of a metric share a hue.
        color = colors[metric_i * 2 + is_vae_i]
        offset = _DODGE * is_vae_i - _DODGE / 2
        per_quantity = stuff_to_plot[is_vae][metric]

        # Error bar: symmetric around the midpoint of each quantity's min and max.
        errorbar_means = [np.mean([np.min(values), np.max(values)]) for values in per_quantity]
        errorbar_symm_values = [np.max(values) - mean for values, mean in zip(per_quantity, errorbar_means)]

        # One x location per plotted value, derived from the data rather than
        # assuming exactly `repetitions` values per quantity.
        scatter_x = np.repeat(xticks_locations, [np.size(values) for values in per_quantity]) + offset
        plt.scatter(scatter_x, np.hstack(per_quantity), marker='o', color=color, alpha=1,
                    label=_IS_VAE_LEGEND_NAME[is_vae])
        plt.errorbar(xticks_locations + offset, errorbar_means, yerr=errorbar_symm_values,
                     color=color, alpha=1, fmt='none')

    if 'cosine' in metric_lc:
        plt.axhline(1.0, ls='--', color='dimgrey', alpha=0.6, label='Random')
        plt.ylim((-0.05, 1.1))
    elif 'accuracy' in metric_lc:
        # NOTE(review): 0.05 is presumably chance level for 20 residue types — confirm.
        plt.axhline(0.05, ls='--', color='dimgrey', alpha=0.6, label='Random guess')
        plt.ylim((-0.05, 1.05))
    elif 'mse' in metric_lc:
        plt.ylim((-0.00075, 0.0175))

    plt.xticks(xticks_locations, data_quantities)
    plt.xlabel('Number of residues used to train H-(V)AE', fontdict={'fontsize': 10.5})
    plt.ylabel(metric, fontdict={'fontsize': 10.5})

    if 'cosine' in metric_lc or 'mse' in metric_lc:
        plt.legend(loc='upper right')
    else:
        plt.legend(loc='lower right')

    plt.tight_layout()
    stem = 'data_ablation_%s%s_plot-z=%d' % (classifier_str, '_'.join(metric.split()), z)
    plt.savefig(stem + '.png')
    plt.savefig(stem + '.pdf')
    plt.close()


def main():
    """Collect the data-ablation results and write one figure per metric."""
    args = _parse_args()
    classifier_str = _classifier_prefix(args.classifier)

    stuff_to_plot = _collect_results(args, classifier_str)
    print(stuff_to_plot)

    # AE and VAE results share the same metric names, in the same order.
    for metric_i, metric in enumerate(stuff_to_plot[False]):
        _plot_metric(metric, metric_i, stuff_to_plot, args.data_quantities, classifier_str, args.z)


if __name__ == '__main__':
    main()
        
            



        


