import argparse
import dataclasses
import json
import pathlib
import sys
import torch
import numpy
import math

from recognizers.analysis.print_summary_table import (
    read_trials
)

def read_automaton_stats(dirname, measure):
    """Load ``language.pt`` from *dirname* and return ``(size, depth)``.

    Args:
        dirname: Directory (a path object supporting ``/``) that contains
            the ``language.pt`` file.
        measure: Which size statistic to compute from the stored language
            object: ``'alphabet_size'``, ``'num_states'`` or
            ``'num_transitions'``.

    Returns:
        A ``(size, depth)`` tuple, where ``depth`` is the ``'depth'`` entry
        of the loaded file. ``(None, None)`` is returned when the file does
        not exist or when the stored language object does not provide the
        requested statistic; both cases are reported on stderr.

    Raises:
        NotImplementedError: If *measure* is not one of the supported names.
    """
    language_file_name = dirname / 'language.pt'
    try:
        # NOTE: weights_only=False lets torch.load unpickle arbitrary Python
        # objects, so this must only be pointed at trusted files.
        data = torch.load(
            language_file_name,
            weights_only=False,
            map_location=torch.device('cpu')
        )
    except FileNotFoundError:
        print(f'missing: {language_file_name}', file=sys.stderr)
        return None, None
    try:
        depth = data['depth']
        if measure not in ('alphabet_size', 'num_states', 'num_transitions'):
            raise NotImplementedError(f'unknown size measure: {measure!r}')
        language = data['language']
        if measure == 'alphabet_size':
            size = language.alphabet_size()
        elif measure == 'num_states':
            size = language.num_states()
        else:
            # Count the transitions without assuming the result has a len().
            size = sum(1 for _ in language.transitions())
    except AttributeError as error:
        # The stored object lacks the requested method; treat it as missing
        # data, but say so instead of failing silently.
        print(
            f'cannot compute {measure} for {language_file_name}: {error}',
            file=sys.stderr
        )
        return None, None
    return size, depth

def get_axis_label(s: str) -> str:
    """Return *s* with underscores turned into spaces and each word title-cased."""
    words = s.split("_")
    spaced = " ".join(words)
    return spaced.title()

def main():
    """Write a pgfplots scatter plot of dot-depth against automaton size.

    For each language ``random-star-free-1`` .. ``random-star-free-N`` under
    ``<base-dir>/languages``, the chosen size measure and the depth are read
    with ``read_automaton_stats``. The points are written to a ``.dat`` file
    and a ``.tex`` file (a ``tikzpicture``) under ``<base-dir>/figures``,
    both named after ``--tex-output``.

    Languages for which no statistics are available are skipped (with a note
    on stderr) so that the ``.dat`` file contains only numeric rows.

    Note that the generated TeX always refers to the data file as
    ``figures/dot-depth/<name>.dat``, independent of where it was written.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument('--base-dir', type=pathlib.Path, required=True)
    parser.add_argument('--num-languages', type=int, required=True)
    parser.add_argument('--size-measure', choices = ['alphabet_size', 'num_states', 'num_transitions'], default = 'alphabet_size')
    parser.add_argument('--tex-output', type=pathlib.Path , required=True)
    args = parser.parse_args()

    base_dir = args.base_dir
    num_languages = args.num_languages
    measure = args.size_measure

    x_label = get_axis_label(measure)

    points = []
    for language_no in range(1, num_languages + 1):
        language_dir = base_dir / 'languages' / f'random-star-free-{language_no}'
        size, depth = read_automaton_stats(language_dir, measure)
        if size is None or depth is None:
            # A "None None" row is not numeric data and would break the plot.
            print(f'skipping language {language_no}: no statistics available', file=sys.stderr)
            continue
        points.append((size, depth))

    output_dir = base_dir / 'figures'
    output_dir.mkdir(exist_ok=True)
    dat_output = output_dir / args.tex_output.with_suffix('.dat')
    tex_output = output_dir / args.tex_output.with_suffix('.tex')
    # --tex-output may contain sub-directories (e.g. "dot-depth/foo"); make
    # sure the directory the files are written to exists.
    dat_output.parent.mkdir(parents=True, exist_ok=True)
    print(f'writing {dat_output}')
    with dat_output.open('w', newline='\n') as fout:
        for xi, yi in points:
            print(f'{xi} {yi}', file=fout)
    print(f'writing {tex_output}')
    with tex_output.open('w') as fout:
        fout.write(
            r'''\begin{tikzpicture}
        \begin{axis}[
            axis lines=left,
            xlabel={''')
        fout.write(x_label)
        fout.write(r'''},
            xmin=0,
            enlarge x limits=0.1,
            ylabel={Dot-depth},
            ymin=0,
            enlarge y limits=0.1,
        ]
            \addplot[
                mark=*,
                mark options={scale=0.5},
                only marks
            ] table {figures/dot-depth/''')
        fout.write(dat_output.name)
        fout.write(r'''};
        \end{axis}
    \end{tikzpicture}
    ''')

# Run the command-line entry point only when executed as a script,
# not when this module is imported.
if __name__ == '__main__':
    main()