import argparse
import dataclasses
import json
import pathlib
import sys

import numpy

@dataclasses.dataclass
class Trial:
    """Result of one evaluated trial, as constructed by ``read_trial``."""

    # Value stored under scores -> recognition_accuracy in the trial's
    # eval/test.json file.
    accuracy: float
    # Trial directory the accuracy was read from.
    path: pathlib.Path

def read_trial(dirname):
    json_file_name = dirname / 'eval' / 'test.json'
    try:
        with json_file_name.open() as fin:
            accuracy = json.load(fin)['scores']['recognition_accuracy']
    except FileNotFoundError:
        return None
    return Trial(accuracy, dirname)

def read_trials(dirnames):
    """Read every trial directory in ``dirnames``.

    ``dirnames`` may be any iterable (including a generator); it is
    consumed exactly once, in order.

    Returns a pair ``(trials, missing_dirs)``: the ``Trial`` objects that
    could be read, and the directories for which ``read_trial`` returned
    ``None``.
    """
    found = []
    absent = []
    for candidate in dirnames:
        outcome = read_trial(candidate)
        if outcome is None:
            absent.append(candidate)
            continue
        found.append(outcome)
    return found, absent

def format_mean(score):
    """Format a ``(mean, stddev)`` pair as a LaTeX math-mode table cell.

    The mean is printed with two decimals and the standard deviation is
    attached as a ``\\pm`` subscript. ``None`` yields an empty string so
    that the cell is left blank.
    """
    if score is None:
        return ''
    average, spread = score
    return f'${average:.2f}_{{\\pm {spread:.2f}}}$'

def format_language(language_class):
    """Return the abbreviation used in the tables for a language class.

    Raises ``NotImplementedError`` for any class that has no abbreviation.
    """
    abbreviations = (
        ('podfa', 'PODFA'),
        ('star-free', 'SF'),
        ('regular', 'R'),
        ('context-free', 'CF'),
    )
    # Compare with == (not a dict lookup) so unhashable or unusual inputs
    # still end in NotImplementedError.
    for name, abbreviation in abbreviations:
        if language_class == name:
            return abbreviation
    raise NotImplementedError

# Architecture names; main() uses each one as a path component of the trial
# directories and iterates them in this order when printing table columns.
ARCHITECTURES = ('transformer', 'rnn', 'lstm')
# Parameter budgets; main() uses each one as the directory name directly
# under <base-dir>/models and iterates over all of them for every table cell.
PARAMETER_BUDGETS = (128000, 256000, 512000)

def main():
    """Print two LaTeX tables summarising recognition accuracy.

    Command-line arguments:
      --base-dir       root directory containing the ``models`` directory
      --num-languages  number of languages per language class (numbered
                       from 1)
      --num-trials     number of trials per configuration (numbered from 1)

    Trial results are read from
    ``<base-dir>/models/<budget>/random-<class>-<n>/<architecture>/rec/<validation set>/<trial>/eval/test.json``.

    The first table reports, per language class and architecture, the mean
    and standard deviation over languages of a per-language aggregate of
    the trial accuracies (mean over trials for ``validation-short``, max
    over trials for ``validation-long``). The second table reports the
    percentage of languages whose best trial reached an accuracy of
    exactly 1.0. Every cell shows the maximum over ``PARAMETER_BUDGETS``;
    a cell with no readable trials is left blank.

    Trial directories without results are listed on stderr at the end,
    each one once.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--base-dir', type=pathlib.Path, required=True)
    parser.add_argument('--num-languages', type=int, required=True)
    parser.add_argument('--num-trials', type=int, required=True)
    args = parser.parse_args()

    base_dir = args.base_dir
    num_languages = args.num_languages
    num_trials = args.num_trials

    language_classes = ('podfa', 'star-free', 'regular', 'context-free')
    # Column groups, in order: "Inductive Bias" then "Expressivity".
    validation_sets = ('validation-short', 'validation-long')
    # Per-language aggregate over trials used by the first table.
    aggregate_by_validation_set = {
        'validation-short': numpy.mean,
        'validation-long': numpy.max,
    }

    missing_dirs = []
    print(f'num languages: {num_languages}, num trials: {num_trials}')

    def generate_trial_accuracies(
        language_class,
        architecture,
        parameter_budget,
        validation_set
    ):
        """Yield the list of trial accuracies of each language that has any.

        Languages without a single readable trial are skipped; their trial
        directories are recorded in ``missing_dirs``.
        """
        for language_no in range(1, num_languages + 1):
            trials_dir = (
                base_dir
                / 'models'
                / f'{parameter_budget}'
                / f'random-{language_class}-{language_no}'
                / architecture
                / 'rec'
                / validation_set
            )
            trials, trial_missing_dirs = read_trials(
                trials_dir / str(i) for i in range(1, num_trials + 1)
            )
            missing_dirs.extend(trial_missing_dirs)
            if trials:
                yield [t.accuracy for t in trials]

    def get_mean_over_trials(
        language_class,
        architecture,
        parameter_budget,
        validation_set,
        aggregate_scores
    ):
        """Return (mean, stddev) over languages of the aggregated trial scores.

        ``aggregate_scores`` reduces one language's array of trial
        accuracies to a single number. Returns ``None`` when no language
        has a readable trial.
        """
        scores = numpy.array([
            aggregate_scores(numpy.array(accuracies))
            for accuracies in generate_trial_accuracies(
                language_class,
                architecture,
                parameter_budget,
                validation_set
            )
        ])
        if len(scores) == 0:
            return None
        return numpy.mean(scores), numpy.std(scores)

    def get_perc_perfect_accuracy(
        language_class,
        architecture,
        parameter_budget,
        validation_set
    ):
        """Return (count, fraction) of languages whose best trial scored 1.0.

        The fraction is relative to ``num_languages``, so languages with
        no readable trial count as not perfect. When no language has a
        readable trial at all the fraction is ``None``, so the caller can
        leave the cell blank instead of reporting a misleading 0% (and so
        ``num_languages == 0`` does not divide by zero).
        """
        best_scores = numpy.array([
            max(accuracies)
            for accuracies in generate_trial_accuracies(
                language_class,
                architecture,
                parameter_budget,
                validation_set
            )
        ])
        if len(best_scores) == 0:
            return 0, None
        num_perfect_scores = (best_scores == 1.0).sum()
        return num_perfect_scores, num_perfect_scores / num_languages

    def print_rows(get_score, format_score):
        """Print one table row per language class.

        ``get_score(language_class, architecture, parameter_budget,
        validation_set)`` returns a comparable score or ``None``;
        ``format_score`` renders the best score over all budgets.
        """
        for language_class in language_classes:
            print(format_language(language_class), end='')
            for validation_set in validation_sets:
                for architecture in ARCHITECTURES:
                    # Start from an empty list for every cell: scores left
                    # over from a previous cell would leak another
                    # architecture's or validation set's results into this
                    # cell's maximum.
                    scores_by_budget = []
                    for parameter_budget in PARAMETER_BUDGETS:
                        score = get_score(
                            language_class,
                            architecture,
                            parameter_budget,
                            validation_set
                        )
                        if score is not None:
                            scores_by_budget.append(score)
                    if scores_by_budget:
                        best = max(scores_by_budget)
                        print(f' & {format_score(best)}', end='')
                    else:
                        print(' &', end='')
            print(r' \\')

    def mean_score(language_class, architecture, parameter_budget, validation_set):
        return get_mean_over_trials(
            language_class,
            architecture,
            parameter_budget,
            validation_set,
            aggregate_by_validation_set[validation_set]
        )

    def perfect_fraction(language_class, architecture, parameter_budget, validation_set):
        _, fraction = get_perc_perfect_accuracy(
            language_class,
            architecture,
            parameter_budget,
            validation_set
        )
        return fraction

    # Table 1: mean +- stddev over languages. (mean, stddev) tuples compare
    # by mean first, so max() picks the budget with the highest mean.
    print(r'''
\begin{tabular}{@{}lcccccc@{}}
\toprule
& \multicolumn{3}{c}{Inductive Bias} & \multicolumn{3}{c}{Expressivity} \\
\cmidrule(lr){2-4} \cmidrule(lr){5-7}
Language & \transformerAbbrev{} & \rnnAbbrev{} & \lstmAbbrev{} & \transformerAbbrev{} & \rnnAbbrev{} & \lstmAbbrev{} \\
\midrule
'''.strip())
    print_rows(mean_score, format_mean)
    print(r'''
\bottomrule
\end{tabular}
'''.strip())

    # Table 2: percentage of languages with a perfect best trial.
    print(r'''
    \begin{tabular}{@{}lcccccc@{}}
    \toprule
    & \multicolumn{3}{c}{Inductive Bias} & \multicolumn{3}{c}{Expressivity} \\
    \cmidrule(lr){2-4} \cmidrule(lr){5-7}
    Language & \transformerAbbrev{} & \rnnAbbrev{} & \lstmAbbrev{} & \transformerAbbrev{} & \rnnAbbrev{} & \lstmAbbrev{} \\
    \midrule
    '''.strip())
    print_rows(perfect_fraction, lambda score: f'${score*100:.2f}\\%$')
    print(r'''
    \bottomrule
    \end{tabular}
    '''.strip())

    # Every trial directory is visited once per table, so de-duplicate
    # (keeping first-seen order) before reporting.
    for missing_dir in dict.fromkeys(missing_dirs):
        print(f'missing: {missing_dir}', file=sys.stderr)

# Generate the tables only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
