import os
import json
from optparse import OptionParser
from collections import defaultdict, Counter

import seaborn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from classification.export_training_data import simplify_name


# Values in the order they are processed and plotted. Each entry is passed
# through simplify_name() to obtain the key used for value_counts.json and for
# the per-value experiment directories.
_VALUE_ORDER = [
    'Performance Values',
    'Performance',
    'Accuracy',
    'State-of-the-art',
    'Building On Past Work',
    'Building on recent work',
    'Building on classic work',
    'Generalization Values',
    'Generalization',
    'Flexibility/Extensibility',
    'Avoiding train/test discrepancy',
    'Efficiency Values',
    'Efficiency',
    'Low cost',
    'Data efficiency',
    'Label efficiency (reduced need for labeled data)',
    'Fast',
    'Reduced training time',
    'Memory efficiency',
    'Energy efficiency',
    'Quantitative evidence (e.g. experiments)',
    'Novelty',
    'Understanding (for researchers)',
    'Applies to real world',
    'Formal description/analysis',
    'Simplicity',
    'Identifying limitations',
    'Robustness',
    'Unifying ideas or integrating components',
    'Effectiveness',
    'Theoretical guarantees',
    'Scientific methodology',
    'Used in practice/Popular',
    'Approximation',
    'Large scale',
    'Scales up',
    'Successful',
    'Qualitative evidence (e.g. examples)',
    'Generality',
    'Facilitating use (e.g. sharing code)',
    'Improvement',
    'Useful',
    'Parallelizability / distributed',
    'Practical',
    'Promising',
    'Preciseness',
    'Easy to implement',
    'Requires few resources',
    'Exactness',
    'Realistic output',
    'Progress',
    'Interpretable (to users)',
    'Beneficence',
    'Optimal',
    'Automatic',
    'Security',
    'Human-like mechanism',
    'Concreteness',
    'Learning from humans',
    'Controllability (of model owner)',
    'Valid assumptions',
    'Important',
    'Easy to work with',
    'Principled',
    'Stable',
    'Deferral to humans',
    'Critique',
    'Impressive',
    'Realistic world model',
    'Reproducibility',
    'Privacy',
    'Powerful',
    'Diverse output',
    'User influence',
    'Non-maleficence',
    'Safety',
    'Not socially biased',
    'Fairness',
    'Respect for Law and public interest',
    'Explicability',
    'Autonomy (power to decide)',
    'Respect for Persons',
    'Justice',
]

# Display labels, aligned index-by-index with _VALUE_ORDER. This list is longer
# than _VALUE_ORDER; the trailing entries have no counterpart there and are
# never used.
_LABEL_ORDER = [
    'Performance Values',
    'Performance',
    'Accuracy',
    'State-of-The-Art',
    'Building On Past Work',
    'Building on Recent Work',
    'Building on Classic Work',
    'Generalization Values',
    'Generalization',
    'Flexibility/Extensibility',
    'Avoiding Train/Test Discrepancy',
    'Efficiency Values',
    'Efficiency',
    'Low Cost',
    'Data Efficiency',
    'Label Efficiency',
    'Fast',
    'Reduced Training Time',
    'Memory Efficiency',
    'Energy Efficiency',
    'Quantitative Evidence',
    'Novelty',
    'Understanding (For Researchers)',
    'Applies to real world',
    'Formal Description/Analysis',
    'Simplicity',
    'Identifying Limitations',
    'Robustness',
    'Unifying Ideas',
    'Effectiveness',
    'Theoretical Guarantees',
    'Scientific Methodology',
    'Used in practice/Popular',
    'Approximation',
    'Large Scale',
    'Scales Up',
    'Successful',
    'Qualitative Evidence',
    'Generality',
    'Facilitating Use',
    'Improvement',
    'Useful',
    'Parallelizability / Distributed',
    'Practical',
    'Promising',
    'Preciseness',
    'Easy To Implement',
    'Requires Few Resources',
    'Exactness',
    'Realistic Output',
    'Progress',
    'Interpretable (To Users)',
    'Beneficence',
    'Optimal',
    'Automatic',
    'Security',
    'Human-Like Mechanism',
    'Concreteness',
    'Learning From Humans',
    'Controllability (Of Model Owner)',
    'Valid Assumptions',
    'Important',
    'Easy To Work With',
    'Principled',
    'Stable',
    'Deferral To Humans',
    'Critique',
    'Impressive',
    'Realistic World Model',
    'Reproducibility',
    'Privacy',
    'Powerful',
    'Diverse Output',
    'User Influence',
    'Non-Maleficence',
    'Safety',
    'Not Socially Biased',
    'Fairness',
    'Respect For Law And Public Interest',
    'Explicability',
    'Autonomy (Power To Decide)',
    'Respect For Persons',
    'Justice',
    'Critiqueability',
    'Collective Influence',
    'Transparent (To Users)',
]

# Only values with at least this many entries in value_counts.json get a bar.
_MIN_PLOT_COUNT = 20

# Palette positions that are drawn grey instead of green.
# NOTE(review): the palette is presumably consumed per plotted row, i.e. after
# values below _MIN_PLOT_COUNT have been dropped, so these are not necessarily
# positions in _VALUE_ORDER (where the fourth group header, 'Efficiency
# Values', sits at index 11, not 10). Confirm against the rendered figure.
_GREY_ROWS = (0, 4, 7, 10)


def _experiment_file(value, filename):
    """Return the path of *filename* inside the experiment directory of *value*."""
    return os.path.join('data', 'classification', 'exp', value,
                        'partition_t300_s42', 'linear_f1_binarize_n1_l1', filename)


def _load_f1s(value_counts, min_count):
    """Read the test F1 of every value whose count is at least *min_count*.

    Prints one "value count f1" line per value read and returns a dict
    mapping value -> F1. A missing results file raises (no fallback).
    """
    f1s_by_value = {}
    for value, count in value_counts.items():
        if count < min_count:
            continue
        with open(_experiment_file(value, 'results.test.json')) as f:
            f1 = json.load(f)['f1']
        f1s_by_value[value] = f1
        print(value, count, f1)
    return f1s_by_value


def _load_paper_indices(infile, first_year, last_year):
    """Group the line numbers of a .jsonlist file by year and paper.

    Every line must be a JSON object with 'year' and 'paper' keys. Returns
    {year: {paper: [line index, ...]}} with an entry for every year in
    [first_year, last_year], including years that have no papers.

    Raises KeyError if a line carries a year outside that range.
    """
    indices_by_paper_by_year = {year: defaultdict(list)
                                for year in range(first_year, last_year + 1)}
    with open(infile) as f:
        # stream the file; only the line numbers are kept
        for i, line in enumerate(f):
            record = json.loads(line)
            indices_by_paper_by_year[record['year']][record['paper']].append(i)
    return indices_by_paper_by_year


def _count_papers_per_value(indices_by_paper_by_year, values, pred_filename):
    """Count, per year and value, the papers in which the value is predicted.

    For each value the prediction table *pred_filename* is read from that
    value's experiment directory (values without such a file are skipped).
    The prediction of a line is the argmax over the table's columns, and a
    paper is counted when at least one of its lines has a nonzero prediction.
    Row i of the table is assumed to correspond to line i of the parsed
    corpus file -- TODO confirm.

    Returns a defaultdict(Counter) indexed as counts[year][value].
    """
    counts = defaultdict(Counter)
    for value in values:
        print(value)
        pred_file = _experiment_file(value, pred_filename)
        if not os.path.exists(pred_file):
            continue
        df = pd.read_csv(pred_file, header=0, index_col=0)
        preds = np.argmax(np.array(df.values), axis=1)
        for year, papers in indices_by_paper_by_year.items():
            for indices in papers.values():
                # argmax results are >= 0, so "any nonzero" == "sum > 0"
                if np.any(preds[indices]):
                    counts[year][value] += 1
    return counts


def _sum_over_years(counts_by_value_by_year, n_papers_per_year, values, first_year, last_year):
    """Total the per-year counts over [first_year, last_year].

    Returns (Counter mapping value -> number of papers, total number of papers).
    """
    value_sums = Counter()
    total_papers = 0
    for year in range(first_year, last_year + 1):
        total_papers += n_papers_per_year[year]
        for value in values:
            value_sums[value] += counts_by_value_by_year[year][value]
    return value_sums, total_papers


def _plot_yearly_proportion(ax, counts_by_value_by_year, n_papers_per_year, value,
                            first_year, last_year, label):
    """Draw the per-year proportion of papers with *value*, with a +/- 2 std band.

    The band uses the binomial standard error sqrt(p * (1 - p) / n). A year
    without papers produces NaN (a gap in the curve) instead of raising
    ZeroDivisionError.
    """
    years = range(first_year, last_year + 1)
    ns = np.array([n_papers_per_year[year] for year in years], dtype=float)
    hits = np.array([counts_by_value_by_year[year][value] for year in years], dtype=float)
    with np.errstate(divide='ignore', invalid='ignore'):
        props = hits / ns
        std = np.sqrt(props * (1 - props) / ns)
    ax.plot(years, props, label=label)
    ax.fill_between(years, props - 2 * std, props + 2 * std, alpha=0.1)


def main():
    """Plot classifier F1 scores and estimated value prevalence in ICML/NeurIPS.

    Reads, relative to the working directory:
      * data/classification/value_counts.json
      * data/classification/exp/<value>/.../results.test.json
      * data/classification/exp/<value>/.../pred.probs.{icml,neurips}.csv
      * data/{icml,neurips}/parsed.jsonlist

    Writes into --outdir: f1s.pdf, performance_plots.pdf,
    neurips_value_counts.json and icml_value_counts.json.
    """
    usage = "%prog"
    parser = OptionParser(usage=usage)
    parser.add_option('--min-count', type=int, default=10,
                      help='Min count: default=%default')
    parser.add_option('--outdir', type=str, default='plots',
                      help='Output directory: default=%default')

    options, _ = parser.parse_args()

    min_count = options.min_count
    outdir = options.outdir

    print("Loading counts")
    with open(os.path.join('data', 'classification', 'value_counts.json')) as f:
        value_counts = json.load(f)

    print("Loading f1s")
    f1s_by_value = _load_f1s(value_counts, min_count)

    simple_order = [simplify_name(val) for val in _VALUE_ORDER]

    f1_colors = ['mediumseagreen'] * len(_VALUE_ORDER)
    for row in _GREY_ROWS:
        f1_colors[row] = 'grey'

    # years shown in the bar chart; both venues must cover this range
    plot_start = 2008
    plot_end = 2020

    print("Loading ICML data and predictions")
    icml_start = 2008
    icml_end = 2020
    icml_indices_by_paper_by_year = _load_paper_indices(
        os.path.join('data', 'icml', 'parsed.jsonlist'), icml_start, icml_end)
    icml_n_papers_per_year = {year: len(papers)
                              for year, papers in icml_indices_by_paper_by_year.items()}
    icml_pred_paper_count_by_value_by_year = _count_papers_per_value(
        icml_indices_by_paper_by_year, simple_order, 'pred.probs.icml.csv')
    icml_value_sums, icml_total_papers = _sum_over_years(
        icml_pred_paper_count_by_value_by_year, icml_n_papers_per_year,
        simple_order, plot_start, plot_end)

    print("Loading NeurIPS data and predictions")
    neurips_start = 1987
    neurips_end = 2020
    neurips_indices_by_paper_by_year = _load_paper_indices(
        os.path.join('data', 'neurips', 'parsed.jsonlist'), neurips_start, neurips_end)
    neurips_n_papers_per_year = {year: len(papers)
                                 for year, papers in neurips_indices_by_paper_by_year.items()}
    neurips_pred_paper_count_by_value_by_year = _count_papers_per_value(
        neurips_indices_by_paper_by_year, simple_order, 'pred.probs.neurips.csv')
    neurips_value_sums, neurips_total_papers = _sum_over_years(
        neurips_pred_paper_count_by_value_by_year, neurips_n_papers_per_year,
        simple_order, plot_start, plot_end)

    print("Making plots")

    combined_value_sums = Counter()
    combined_total_papers = icml_total_papers + neurips_total_papers
    for value in simple_order:
        combined_value_sums[value] = icml_value_sums[value] + neurips_value_sums[value]
        print(value, combined_value_sums[value] / combined_total_papers)

    # keep only the values that are frequent enough, paired with their labels
    kept = [(value, label) for value, label in zip(simple_order, _LABEL_ORDER)
            if value_counts[value] >= _MIN_PLOT_COUNT]
    kept_values = [value for value, _ in kept]

    plot_df = pd.DataFrame({
        'Value': [label for _, label in kept],
        # values without a trained classifier are shown with an F1 of 0
        'F1': [f1s_by_value.get(value, 0) for value in kept_values],
        '% ICML papers (2008-2020)': [100 * icml_value_sums[value] / icml_total_papers
                                      for value in kept_values],
        '% NeurIPS papers (2008-2020)': [100 * neurips_value_sums[value] / neurips_total_papers
                                         for value in kept_values],
        'Est. prop. of papers': [combined_value_sums[value] / combined_total_papers
                                 for value in kept_values],
    })

    fig, axes = plt.subplots(ncols=2, figsize=(10, 14), sharey=True)
    fig.subplots_adjust(wspace=0.1)
    seaborn.barplot(y='Value', x='Est. prop. of papers', data=plot_df, ax=axes[0], palette=f1_colors)
    axes[0].set_title('NeurIPS and ICML (2008-2020)')
    axes[0].set_xlim(0, 1.)
    axes[0].set_ylabel('')
    seaborn.barplot(y='Value', x='F1', data=plot_df, ax=axes[1], palette=f1_colors)
    axes[1].set_xlim(0, 1)
    axes[1].set_title('F1 scores')
    axes[1].set_ylabel('')

    os.makedirs(outdir, exist_ok=True)
    fig.savefig(os.path.join(outdir, 'f1s.pdf'), bbox_inches='tight')
    plt.close(fig)

    fig, axes = plt.subplots(nrows=3, figsize=(5, 8))
    fig.subplots_adjust(hspace=0.3)
    for ax, target in zip(axes, ['Performance', 'Accuracy', 'State-of-the-art']):
        # the counters are keyed by simplified names, so look the target up
        # the same way the keys were produced
        key = simplify_name(target)
        _plot_yearly_proportion(ax, neurips_pred_paper_count_by_value_by_year,
                                neurips_n_papers_per_year, key,
                                neurips_start, neurips_end, 'NeurIPS')
        _plot_yearly_proportion(ax, icml_pred_paper_count_by_value_by_year,
                                icml_n_papers_per_year, key,
                                icml_start, icml_end, 'ICML')

        ax.set_title(target)
        ax.set_ylim(0, 1.)
        ax.set_ylabel('Proportion of papers')
        ax.legend(loc='upper left')

    fig.savefig(os.path.join(outdir, 'performance_plots.pdf'), bbox_inches='tight')
    plt.close(fig)

    with open(os.path.join(outdir, 'neurips_value_counts.json'), 'w') as f:
        json.dump({'paper_counts': neurips_n_papers_per_year,
                   'value_counts': neurips_pred_paper_count_by_value_by_year}, f, indent=2)

    with open(os.path.join(outdir, 'icml_value_counts.json'), 'w') as f:
        json.dump({'paper_counts': icml_n_papers_per_year,
                   'value_counts': icml_pred_paper_count_by_value_by_year}, f, indent=2)


# Run main() only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
