
import os, sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import argparse

sys.path.append('..')
from utils.argparse_utils import *


# Run-directory name templates, filled in with %-formatting by the script below.
# The plain autoencoder template takes (z, version); the VAE template takes
# (z, kl_lambda string, version).
AE_MODEL_NAME = (
    'cgvae_symm_simp_flex-AE'
    '-z=%d-x_lambda=400-data=20000-bs=200'
    '_v%d'
)
VAE_MODEL_NAME = (
    'cgvae_symm_simp_flex-VAE_min_loss_with_final_kl'
    '-z=%d-x_lambda=400-data=20000-bs=200'
    '-kl_lambda=%s'
    '_v%d'
)

# Number of repeated runs (versions 1..N) available per KL weight. Keys are the
# KL weights exactly as they are spelled on the command line / in directory names.
REPETITIONS_DICT = dict.fromkeys(
    ('0', '0.025', '0.05', '0.1', '0.25', '0.5'),
    3,
)

def kl_lambda_tick_positions(kl_lambdas):
    """Return the x-axis slot of every KL weight, ranked by numeric value.

    Parameters
    ----------
    kl_lambdas : sequence of str
        KL weights as given on the command line, in any order.

    Returns
    -------
    numpy.ndarray of int
        ``positions[i]`` is the rank (0 = smallest) of ``kl_lambdas[i]``, so
        the weights always appear in ascending order along the axis.

    Raises
    ------
    ValueError
        If an entry cannot be converted to ``float``.
    """
    values = np.asarray([float(kl_lambda) for kl_lambda in kl_lambdas])
    # argsort gives "which element goes in slot k"; inverting that permutation
    # gives "which slot does element i go in", which is what the plot needs.
    order = np.argsort(values, kind='stable')
    positions = np.empty(len(values), dtype=int)
    positions[order] = np.arange(len(values))
    return positions


def load_report_and_losses(model_dir, model_name, model_type_str, classifier_str, validation_for_accuracy):
    """Read the classification report and the test-split losses of one run.

    Parameters
    ----------
    model_dir : str
        Directory containing one sub-directory per trained model.
    model_name : str
        Name of the run's sub-directory.
    model_type_str : str
        Checkpoint suffix embedded in the CSV file names.
    classifier_str : str
        Classifier infix of the cross-validation report file name.
    validation_for_accuracy : bool
        If true, read the validation-split report instead of the
        cross-validation-on-test report.

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        The classification report and the losses table.

    Raises
    ------
    FileNotFoundError
        If either CSV file does not exist.
    """
    run_dir = os.path.join(model_dir, model_name)
    if validation_for_accuracy:
        report_name = 'classificaton_on_latent_space_default_classes%s-split=valid.csv' % (model_type_str)
    else:
        report_name = '__residue_type_cross_val_on_test_%sclassificaton_on_latent_space%s.csv' % (classifier_str, model_type_str)
    report = pd.read_csv(os.path.join(run_dir, 'latent_space_classification', report_name))
    losses_df = pd.read_csv(os.path.join(run_dir, 'loss_distributions', 'losses%s-split=test.csv' % (model_type_str)))
    return report, losses_df


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, default='../runs/toy_aminoacids/local_equiv_fibers')
    parser.add_argument('--kl_lambdas', type=comma_sep_str_list, default='0,0.025,0.05,0.1,0.25,0.5')
    parser.add_argument('--z', type=int, default=2)
    parser.add_argument('--validation_for_accuracy', type=str_to_bool, default=False)
    parser.add_argument('--classifier', type=str, default='KNN', choices=['LC', 'KNN'])

    args = parser.parse_args()

    # collect csv files of classification reports, present accuracies

    # Infix used in the report file name and in the output figure names.
    classifier_str = {'LC': '', 'KNN': 'KNN_'}[args.classifier]

    ae_model_type_str = '-lowest_rec_loss'
    vae_model_type_str = '-lowest_total_loss_with_final_kl_model'

    accuracies = []
    losses = []
    for kl_lambda in args.kl_lambdas:
        rep_accuracies = []
        rep_losses = []
        for rep in range(REPETITIONS_DICT[kl_lambda]):
            # A KL weight of '0' selects the plain autoencoder runs.
            is_ae = kl_lambda == '0'
            if is_ae:
                model_name = AE_MODEL_NAME % (args.z, rep + 1)
                curr_model_type_str = ae_model_type_str
            else:
                model_name = VAE_MODEL_NAME % (args.z, kl_lambda, rep + 1)
                curr_model_type_str = vae_model_type_str

            try:
                report, curr_losses_df = load_report_and_losses(
                    args.model_dir, model_name, curr_model_type_str,
                    classifier_str, args.validation_for_accuracy)
            except FileNotFoundError:
                if not is_ae:
                    raise
                # The autoencoder files may have been written with the VAE
                # suffix instead; retry once with it and tell the user.
                curr_model_type_str = vae_model_type_str
                report, curr_losses_df = load_report_and_losses(
                    args.model_dir, model_name, curr_model_type_str,
                    classifier_str, args.validation_for_accuracy)
                print('Warning, using model_type_str: %s' % (curr_model_type_str), file=sys.stderr)

            # The report is expected to repeat one overall accuracy on every
            # row; only the first value is used, so refuse inconsistent files.
            accuracy_values = report['accuracy'].values
            if len(accuracy_values) == 0 or not np.all(accuracy_values == accuracy_values[0]):
                raise ValueError('Inconsistent or missing accuracy values in the classification report of %s' % (model_name))

            rep_accuracies.append(accuracy_values[0])
            rep_losses.append(curr_losses_df.values)

        accuracies.append(rep_accuracies)
        losses.append(rep_losses)

    # One entry per KL weight for every metric. The loss columns are selected
    # by position -- assumes the CSV column order is (MSE, cosine, PearsonR);
    # TODO confirm against the script that writes the losses files.
    stuff_to_plot = {
        'Classification accuracy': accuracies,
        'MSE': [np.vstack(rep_losses)[:, 0] for rep_losses in losses],
        'Cosine loss': [np.vstack(rep_losses)[:, 1] for rep_losses in losses],
        'PearsonR of total power': [np.vstack(rep_losses)[:, 2] for rep_losses in losses]
    }

    print(stuff_to_plot)

    xticks_locations = kl_lambda_tick_positions(args.kl_lambdas)

    GOLDEN_RATIO = (1 + 5.0**0.5) / 2
    HEIGHT = 4
    WIDTH = HEIGHT * GOLDEN_RATIO
    COLORS = plt.get_cmap('tab20').colors

    for metric_i, (metric, groups) in enumerate(stuff_to_plot.items()):
        color = COLORS[metric_i * 2]
        alpha = 1
        metric_lower = metric.lower()

        # Symmetric error bar spanning [min, max] of each group: centre it on
        # the mid-range and use half the range as the error.
        errorbar_means = [np.mean([np.min(group), np.max(group)]) for group in groups]
        errorbar_symm_values = [np.max(group) - mean for group, mean in zip(groups, errorbar_means)]

        # Repeat each tick position once per plotted point of its group (this
        # equals the number of repetitions when each losses CSV has one row).
        expanded_xticks_locations = np.repeat(xticks_locations, [np.size(group) for group in groups])

        plt.figure(figsize=(WIDTH, HEIGHT))
        if 'cosine' in metric_lower:
            plt.axhline(1.0, ls='--', color='dimgrey', alpha=0.6, label='Random')
            plt.ylim((-0.05, 1.1))
        elif 'accuracy' in metric_lower:
            # NOTE(review): 0.05 is presumably chance level for 20 classes -- confirm.
            plt.axhline(0.05, ls='--', color='dimgrey', alpha=0.6, label='Random guess')
            plt.ylim((-0.05, 1.05))

        plt.scatter(expanded_xticks_locations, np.hstack(groups), color=color, alpha=alpha)
        plt.errorbar(np.array(xticks_locations), errorbar_means, yerr=errorbar_symm_values, color=color, alpha=alpha, fmt='none')

        plt.xticks(xticks_locations, args.kl_lambdas)

        plt.xlabel('$\\beta$', fontdict={'fontsize': 13})
        plt.ylabel(metric, fontdict={'fontsize': 10.5})

        # Only the metrics with a labelled reference line need a legend.
        if 'cosine' in metric_lower or 'accuracy' in metric_lower:
            plt.legend(loc='upper right')

        plt.tight_layout()
        figure_stem = 'kld_ablation_%s%s_plot-z=%d-validation_for_accuracy=%s' % (classifier_str, '_'.join(metric.split()), args.z, args.validation_for_accuracy)
        plt.savefig(figure_stem + '.png')
        plt.savefig(figure_stem + '.pdf')
        plt.close()
        
            



        


