import argparse
import dataclasses
import json
import pathlib
import sys
import torch
import numpy
import math

from recognizers.analysis.print_summary_table import (
    read_trials
)

def read_automaton_stat(dirname, measure):
    """Read one size statistic from the saved language in *dirname*.

    Loads ``dirname / 'language.pt'`` with ``torch.load`` onto the CPU and
    returns the value selected by *measure*:

    * ``'alphabet_size'``, ``'num_states'``, ``'num_variables'`` -- the result
      of calling the method of that name on ``data['language']``;
    * ``'num_transitions'``, ``'num_rules'`` -- the number of items yielded by
      ``data['language'].transitions()`` / ``.rules()``;
    * ``'dot_depth'`` -- ``data['depth']``.

    Returns ``None`` when the file does not exist (a ``missing: ...`` line is
    written to stderr) or when looking up the statistic raises
    ``AttributeError`` (e.g. the stored language object lacks that method).

    Raises ``NotImplementedError`` for any other *measure*; this check only
    happens after the file has been loaded successfully.
    """
    language_file_name = dirname / 'language.pt'
    try:
        # NOTE: weights_only=False unpickles arbitrary objects, so the file
        # must come from a trusted source.
        data = torch.load(
            language_file_name,
            weights_only=False,
            map_location=torch.device('cpu'),
        )
        try:
            if measure == 'alphabet_size':
                return data['language'].alphabet_size()
            elif measure == 'num_states':
                return data['language'].num_states()
            elif measure == 'num_variables':
                return data['language'].num_variables()
            elif measure == 'num_transitions':
                return len(list(data['language'].transitions()))
            elif measure == 'num_rules':
                return len(list(data['language'].rules()))
            elif measure == 'dot_depth':
                return data['depth']
            else:
                raise NotImplementedError
        except AttributeError:
            # The loaded object does not support this statistic.
            return None
    except FileNotFoundError:
        print(f'missing: {language_file_name}', file=sys.stderr)
        return None

def get_max_over_trials(
        base_dir,
        language_class,
        language_no,
        num_trials,
        architecture
    ):
    """Return the best accuracy for one language/architecture pair.

    For every entry of ``PARAMETER_BUDGETS`` the trial directories
    ``<base_dir>/models/<budget>/random-<class>-<no>/<architecture>/rec/validation-long/<i>``
    (``i`` from 1 to *num_trials*) are handed to ``read_trials``; the
    ``accuracy`` attribute of every returned trial is pooled across budgets.

    Returns a pair ``(best, missing_dirs)`` where ``best`` is the maximum of
    the pooled accuracies (``None`` if no trial was read) and ``missing_dirs``
    collects the missing directories reported by ``read_trials``.
    """
    language_name = f'random-{language_class}-{language_no}'
    all_missing = []
    accuracies = []
    for budget in PARAMETER_BUDGETS:
        run_dir = (
            base_dir / 'models' / f'{budget}' / language_name
            / architecture / 'rec' / 'validation-long'
        )
        trial_dirs = (run_dir / str(trial_no) for trial_no in range(1, num_trials + 1))
        trials, missing = read_trials(trial_dirs)
        all_missing.extend(missing)
        accuracies.extend([trial.accuracy for trial in trials])
    accuracies = numpy.array(accuracies)
    if len(accuracies) == 0:
        return None, all_missing
    return numpy.max(accuracies), all_missing


def get_x_label(s: str) -> str:
    """Turn a snake_case identifier into a Title Case axis label.

    Every underscore becomes a space and the result is passed through
    ``str.title()``, e.g. ``'num_states'`` -> ``'Num States'``.
    """
    words = s.split("_")
    spaced = " ".join(words)
    return spaced.title()

def get_smoothed_points(xs, ys, min_val, max_val, window_size):
    """Smooth a scatter of ``(x, y)`` points with overlapping windows.

    Window centres are ``range(min_val, max_val, window_size)`` (the end is
    exclusive, so *max_val* itself is never a centre and the result is empty
    when ``min_val >= max_val``).  Each window collects the points whose x
    lies within ``window_size`` of the centre, inclusive on both sides, and
    ignores points whose y is ``None``.

    Returns ``(centres, means, stds)``: three lists of equal length holding
    the window centres and the ``numpy.mean`` / ``numpy.std`` of the y values
    in each window, with ``None`` in both lists for windows without any
    usable y value.
    """
    samples = list(zip(xs, ys))
    centres = list(range(min_val, max_val, window_size))
    means, stds = [], []
    for centre in centres:
        lower, upper = centre - window_size, centre + window_size
        in_window = [(px, py) for px, py in samples if lower <= px <= upper]
        usable = [py for _, py in in_window if py is not None]
        if usable:
            means.append(numpy.mean(usable))
            stds.append(numpy.std(usable))
        else:
            means.append(None)
            stds.append(None)
    return centres, means, stds

def safe(x):
    """Return *x* unchanged, or ``math.nan`` when *x* is ``None``."""
    if x is None:
        return math.nan
    return x

# Architecture directory names.  main() evaluates them in this order and
# reads the per-language results by position (0 = transformer, 1 = rnn,
# 2 = lstm).
ARCHITECTURES = ('transformer', 'rnn', 'lstm')
# Names of the per-budget directories under <base-dir>/models/ whose trials
# get_max_over_trials() pools together (presumably parameter counts --
# confirm against the training scripts).
PARAMETER_BUDGETS = (64000, 128000, 256000, 512000)

# Directory, as referenced from the LaTeX document, from which the generated
# figure loads its .dat table; emitted verbatim into every table reference.
_TEX_TABLE_DIR = '02-Paper/figures/accuracy-vs-size/'

# One entry per architecture, in the same order as ARCHITECTURES:
# (pgfplots path label, .dat column suffix, plot colour).
_PLOT_SERIES = (
    ('Tf', 'tf', 'blue'),
    ('rnn', 'rnn', 'red'),
    ('lstm', 'lstm', 'green'),
)


def _write_dat(fout, xs, smoothed):
    """Write the whitespace-separated table read by the generated figure.

    *smoothed* holds one ``(means, stds)`` pair per entry of ``_PLOT_SERIES``.
    Each row contains the window centre followed, per series, by the mean,
    mean + std and mean - std; unavailable values are written as ``nan``.
    """
    columns = ' '.join(
        f'mean{suffix} meanplusstd{suffix} meanminusstd{suffix}'
        for _, suffix, _ in _PLOT_SERIES
    )
    print(f'size {columns}', file=fout)
    for row_no, centre in enumerate(xs):
        cells = [centre]
        for means, stds in smoothed:
            mean, std = means[row_no], stds[row_no]
            has_band = mean is not None and std is not None
            cells.append(safe(mean))
            cells.append(safe(mean + std if has_band else None))
            cells.append(safe(mean - std if has_band else None))
        print(' '.join(f'{cell}' for cell in cells), file=fout)


def _series_tex(label, suffix, color, table_path):
    """Return the pgfplots commands for one series.

    The commands draw the two invisible band edges (mean +/- std), the shaded
    area between them, and the mean line.  The text starts with a newline and
    ends without one so that consecutive series can simply be concatenated.
    """
    return rf'''
        \addplot[
            name path=A-{label},
            color={color},
            opacity=0.2,
            forget plot
        ] table [x=size, y=meanplusstd{suffix}] {{{table_path}}};
        \addplot[
            name path=B-{label},
            color={color},
            opacity=0.2,
            forget plot
        ] table [x=size, y=meanminusstd{suffix}] {{{table_path}}};
        \addplot[color={color}, opacity=0.2, forget plot] fill between [of=A-{label} and B-{label}];
        \addplot[
            color={color}
        ] table [x=size, y=mean{suffix}] {{{table_path}}};'''


def _write_tex(fout, x_label, dat_name):
    """Write the tikzpicture that plots every series from the .dat table."""
    fout.write(
        r'''\begin{tikzpicture}
    \begin{axis}[
        legend style={anchor=north east, font=\tiny, opacity=0.6, text opacity=1, draw=gray},
        axis lines=left,
        every axis plot/.append style={-},
        xlabel={''')
    fout.write(x_label)
    fout.write(r'''},
        enlarge x limits=0.1,
        ylabel={Accuracy},
        enlarge y limits=0.1]''')
    table_path = _TEX_TABLE_DIR + dat_name
    for label, suffix, color in _PLOT_SERIES:
        fout.write(_series_tex(label, suffix, color, table_path))
    fout.write(r'''
    \end{axis}
\end{tikzpicture} ''')


def main():
    """Plot best accuracy against language size for every architecture.

    Command-line arguments:

    * ``--base-dir``: root directory containing ``languages/`` and
      ``models/``; output goes to ``<base-dir>/figures/``.
    * ``--num-languages``: languages are numbered 1..N.
    * ``--num-trials``: trials per model are numbered 1..N.
    * ``--language-class``: selects the ``random-<class>-<no>`` directories.
    * ``--size-measure``: statistic used for the x-axis (default
      ``alphabet_size``).
    * ``--tex-output``: output file name; the ``.dat`` and ``.tex`` files are
      this path with the suffix replaced, placed under the figures directory.

    For each language the size statistic and, per architecture, the maximum
    accuracy over all trials are collected.  The points are smoothed with
    ``get_smoothed_points`` (window size 3) and written as a data table plus
    a pgfplots figure showing the mean and a +/- one standard deviation band
    for each architecture.

    Languages whose size statistic is unavailable (``read_automaton_stat``
    returned ``None``) cannot be placed on the x-axis and are left out of the
    plot; their missing trial directories are still reported on stderr.  The
    script exits with an error message if no language has a statistic.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--base-dir', type=pathlib.Path, required=True)
    parser.add_argument('--num-languages', type=int, required=True)
    parser.add_argument('--num-trials', type=int, required=True)
    parser.add_argument('--language-class', choices = ['regular', 'podfa', 'context-free', 'star-free'], required=True)
    parser.add_argument('--size-measure', choices = ['alphabet_size', 'num_states', 'num_transitions', 'num_rules', 'num_variables', 'dot_depth'], default = 'alphabet_size')
    parser.add_argument('--tex-output', type=pathlib.Path , required=True)
    args = parser.parse_args()

    base_dir = args.base_dir
    language_class = args.language_class
    measure = args.size_measure
    x_label = get_x_label(measure)

    sizes = []
    accuracies = []  # one row per plotted language, columns ordered like ARCHITECTURES
    for language_no in range(1, args.num_languages + 1):
        stat = read_automaton_stat(base_dir / 'languages' / f'random-{language_class}-{language_no}', measure)
        per_architecture = []
        for architecture in ARCHITECTURES:
            accuracy, missing_dirs = get_max_over_trials(base_dir, language_class, language_no, args.num_trials, architecture)
            per_architecture.append(accuracy)
            for missing_dir in missing_dirs:
                print(f'missing: {missing_dir}', file=sys.stderr)
        if stat is None:
            # Without a size there is no x coordinate; keeping the point would
            # make min()/max() and the window comparisons raise TypeError.
            print(f'skipping language {language_no}: no {measure} available', file=sys.stderr)
            continue
        sizes.append(stat)
        accuracies.append(per_architecture)

    if not sizes:
        sys.exit(f'no language provides {measure}; nothing to plot')

    min_val, max_val = min(sizes), max(sizes)
    xs = []
    smoothed = []
    for column in range(len(_PLOT_SERIES)):
        # The window centres depend only on min_val/max_val, so xs is the same
        # for every series.
        xs, means, stds = get_smoothed_points(
            sizes, [row[column] for row in accuracies], min_val, max_val, window_size=3
        )
        smoothed.append((means, stds))

    output_dir = base_dir / 'figures'
    output_dir.mkdir(exist_ok=True)
    dat_output = output_dir / args.tex_output.with_suffix('.dat')
    tex_output = output_dir / args.tex_output.with_suffix('.tex')
    # --tex-output may contain sub-directories (or be absolute); make sure the
    # directory that will actually receive the files exists.
    dat_output.parent.mkdir(parents=True, exist_ok=True)
    print(f'writing {dat_output}')
    with dat_output.open('w', newline='\n') as fout:
        _write_dat(fout, xs, smoothed)
    print(f'writing {tex_output}')
    with tex_output.open('w') as fout:
        _write_tex(fout, x_label, dat_output.name)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == '__main__':
    main()