import argparse
import dataclasses
import json
import pathlib
import sys

import numpy

@dataclasses.dataclass
class Trial:
    """Result of a single evaluated training run."""

    # Value stored under scores -> recognition_accuracy in the trial's
    # eval/test.json file.
    accuracy: float
    # Directory the trial was read from.
    path: pathlib.Path

def read_trial(dirname):
    """Load the recognition accuracy recorded for one trial directory.

    Reads ``<dirname>/eval/test.json`` and extracts
    ``scores.recognition_accuracy`` from it.

    Returns a ``Trial`` holding that accuracy and ``dirname``, or ``None``
    when the JSON file does not exist. Any other failure (malformed JSON,
    missing keys) is not handled here and propagates to the caller.
    """
    results_path = dirname / 'eval' / 'test.json'
    try:
        with results_path.open() as handle:
            payload = json.load(handle)
            recognition_accuracy = payload['scores']['recognition_accuracy']
    except FileNotFoundError:
        # A trial without results is reported as absent rather than an error.
        return None
    else:
        return Trial(recognition_accuracy, dirname)

def read_trials(dirnames):
    """Read every trial directory in ``dirnames``, separating hits from misses.

    ``dirnames`` is iterated exactly once, so a generator is acceptable.

    Returns a pair ``(trials, missing_dirs)``: the ``Trial`` objects that
    could be loaded, and the directories for which ``read_trial`` returned
    ``None``, each in iteration order.
    """
    loaded = []
    absent = []
    for directory in dirnames:
        outcome = read_trial(directory)
        if outcome is None:
            absent.append(directory)
            continue
        loaded.append(outcome)
    return loaded, absent

def format_mean(score):
    """Render a ``(mean, stddev)`` pair as a LaTeX table cell.

    Both numbers are formatted with two decimal places and joined by
    ``$\\pm$``. ``None`` (no data available) yields an empty string so the
    cell is left blank.
    """
    if score is None:
        return ''
    mean, stddev = score
    return f'{mean:.2f} $\\pm$ {stddev:.2f}'

def format_option_name(option_str: str) -> str:
    """Turn a hyphenated option value into a capitalized column title.

    The text before the first hyphen is discarded and each remaining
    hyphen-separated word is capitalized, e.g. ``'mean-num-states'``
    becomes ``'Num States'``. A value without any hyphen yields ``''``.
    """
    _prefix, *words = option_str.split('-')
    return ' '.join(map(str.capitalize, words))

# Architectures reported in the table. main() iterates over this tuple to emit
# one column per architecture in each column group, and uses each name as a
# path component when locating a model's trial directories.
ARCHITECTURES = ('transformer', 'rnn', 'lstm')

def main():
    """Print a LaTeX table of recognition accuracies to standard output.

    Command-line arguments select the base directory, the number of
    languages and trials, the language class, the size measure, and the list
    of sizes (one table row per size).

    Each row has two groups of columns, one column per entry of
    ``ARCHITECTURES``:

    * "Inductive Bias": results read from ``validation-short``, where each
      language contributes the mean accuracy over its trials.
    * "Expressivity": results read from ``validation-long``, where each
      language contributes the maximum accuracy over its trials.

    A cell shows the mean and standard deviation (``numpy.mean`` /
    ``numpy.std``) of the per-language values. Languages with no readable
    trial are skipped, and a cell with no data at all is left empty. After
    the table, every trial directory whose results file was missing is
    listed on standard error.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--base-dir', type=pathlib.Path, required=True)
    parser.add_argument('--num-languages', type=int, required=True)
    parser.add_argument('--num-trials', type=int, required=True)
    parser.add_argument(
        '--language-class',
        choices=['regular', 'context-free', 'podfa', 'star-free'],
        required=True,
    )
    parser.add_argument(
        '--size-measure',
        choices=[
            'mean-num-states',
            'mean-alphabet-size',
            'mean-num-variables',
            'mean-binary-rules',
            'mean-lexical-rules',
        ],
        required=True,
    )
    parser.add_argument(
        '--size-options',
        type=int,
        nargs='+',
        required=True,
        help='List of sizes for the chosen size measure, e.g., 5, 10, 20 for mean number of states.',
    )
    options = parser.parse_args()

    # Trial directories without a results file, collected across every cell.
    absent_dirs = []

    def summarize_cell(size, architecture, validation_set, aggregate_scores):
        """Return (mean, std) over languages for one cell, or None if empty."""
        per_language = []
        for language_no in range(1, options.num_languages + 1):
            model_name = (
                f'random-{options.language_class}-{language_no}'
                f'-{options.size_measure}-{size}'
            )
            trials_dir = (
                options.base_dir / 'models' / model_name
                / architecture / 'rec' / validation_set
            )
            trials, not_found = read_trials(
                trials_dir / str(trial_no)
                for trial_no in range(1, options.num_trials + 1)
            )
            absent_dirs.extend(not_found)
            accuracies = numpy.array([trial.accuracy for trial in trials])
            # A language with no readable trial contributes nothing.
            if len(accuracies) > 0:
                per_language.append(aggregate_scores(accuracies))
        language_values = numpy.array(per_language)
        if len(language_values) == 0:
            return None
        return numpy.mean(language_values), numpy.std(language_values)

    # Table preamble and the two-level column header.
    print('\n'.join((
        r'\begin{tabular}{@{}lcccccc@{}}',
        r'\toprule',
        r'& \multicolumn{3}{c}{Inductive Bias} & \multicolumn{3}{c}{Expressivity} \\',
        r'\cmidrule(lr){2-4} \cmidrule(lr){5-7}',
    )))
    print(format_option_name(options.size_measure))
    print('\n'.join((
        r'& \transformerAbbrev{} & \rnnAbbrev{} & \lstmAbbrev{} & \transformerAbbrev{} & \rnnAbbrev{} & \lstmAbbrev{} \\',
        r'\midrule',
    )))

    # (validation set, per-language aggregation) for each column group, in
    # the order the groups appear in the header.
    column_groups = (
        ('validation-short', numpy.mean),  # Inductive bias
        ('validation-long', numpy.max),    # Expressivity
    )
    for size in options.size_options:
        print(size, end='')
        for validation_set, aggregate_scores in column_groups:
            for architecture in ARCHITECTURES:
                cell = format_mean(
                    summarize_cell(size, architecture, validation_set, aggregate_scores)
                )
                print(f' & {cell}', end='')
        print(r' \\')

    print('\n'.join((
        r'\bottomrule',
        r'\end{tabular}',
    )))

    # Report missing results on stderr so they never end up inside the table.
    for absent_dir in absent_dirs:
        print(f'missing: {absent_dir}', file=sys.stderr)

# Run the table generator only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == '__main__':
    main()
