import argparse
import dataclasses
import json
import pathlib
import sys

import numpy

@dataclasses.dataclass()
class Trial:
    """Outcome of one training trial.

    Instances are built by ``read_trial`` from the trial directory's
    ``eval/test.json`` file.
    """

    # Value of ``scores.recognition_accuracy`` in the trial's test.json.
    accuracy: float
    # Trial directory the accuracy was read from.
    path: pathlib.Path

def read_trial(dirname):
    """Load the test-set recognition accuracy of a single trial.

    Reads ``<dirname>/eval/test.json`` and extracts
    ``scores.recognition_accuracy``.

    Args:
        dirname: Trial directory, as a ``pathlib.Path`` or anything
            ``pathlib.Path`` accepts (e.g. a ``str``).

    Returns:
        A ``Trial`` holding the accuracy and the directory (as a
        ``pathlib.Path``), or ``None`` when ``eval/test.json`` does not exist.

    Raises:
        json.JSONDecodeError: if the file exists but is not valid JSON.
        KeyError: if the JSON lacks ``scores.recognition_accuracy``.
    """
    # Normalise so string paths work with the ``/`` operator and Trial.path
    # matches its declared pathlib.Path type.
    dirname = pathlib.Path(dirname)
    json_file_name = dirname / 'eval' / 'test.json'
    # Only opening the file can raise FileNotFoundError; keep the try narrow
    # so unrelated errors are not mistaken for a missing trial.
    try:
        # JSON is UTF-8; do not depend on the platform's locale encoding.
        fin = json_file_name.open(encoding='utf-8')
    except FileNotFoundError:
        return None
    with fin:
        data = json.load(fin)
    return Trial(data['scores']['recognition_accuracy'], dirname)

def read_trials(dirnames):
    """Load every trial in ``dirnames``, separating found and missing ones.

    Args:
        dirnames: Iterable of trial directories; each is passed to
            ``read_trial``.

    Returns:
        ``(trials, missing_dirs)``.
        - ``trials`` lists the ``Trial`` objects that loaded, in input order.
        - ``missing_dirs`` lists the directories for which ``read_trial``
          returned ``None``, in input order.
    """
    loaded, absent = [], []
    for directory in dirnames:
        result = read_trial(directory)
        if result is None:
            absent.append(directory)
        else:
            loaded.append(result)
    return loaded, absent

def get_max_over_trials(
        base_dir,
        language_class,
        language_no,
        num_trials,
        architecture
):
    """Return the best accuracy for one language/architecture pair.

    For every budget in ``PARAMETER_BUDGETS``, trials ``1..num_trials`` are
    read from::

        base_dir/models/<budget>/random-<language_class>-<language_no>/
            <architecture>/rec/validation-long/<trial>

    The maximum is taken over all trials of all parameter budgets.

    Returns:
        ``(best, missing_dirs)``.
        - ``best`` is the ``numpy.max`` of all loaded accuracies, or ``None``
          if no trial could be read.
        - ``missing_dirs`` lists every trial directory that had no results.
    """
    language_dir = f'random-{language_class}-{language_no}'
    all_missing = []
    accuracies = []
    for budget in PARAMETER_BUDGETS:
        validation_dir = (
            base_dir / 'models' / str(budget) / language_dir
            / architecture / 'rec' / 'validation-long'
        )
        found, missing = read_trials(
            validation_dir / str(trial_no) for trial_no in range(1, num_trials + 1)
        )
        all_missing += missing
        accuracies += [trial.accuracy for trial in found]
    if not accuracies:
        return None, all_missing
    return numpy.max(numpy.array(accuracies)), all_missing

def normalize_score(x):
    """Map an accuracy to the opacity of its square in the figure.

    Accuracies below 0.6 become fully transparent (0.0). Any other value
    ``x`` becomes ``0.6 + 0.4 * x``.

    NOTE(review): for ``x`` in [0.6, 1] this yields only [0.84, 1.0], so there
    is a jump from 0.0 to 0.84 at the threshold. Confirm this is intended
    rather than a rescaling of [0.6, 1] onto [0, 1].
    """
    # ``x < 0.6`` (not ``x >= 0.6``) keeps NaN on the scaling branch.
    return 0.0 if x < 0.6 else 0.6 + 0.4 * x

# Architectures compared in every cell. In main(), their scores fill a
# cell's blue, red and green squares, in this order.
ARCHITECTURES = (
    'transformer',
    'rnn',
    'lstm',
)
# Parameter budgets searched under base_dir/models/; the best accuracy over
# all of them (and all trials) is reported.
PARAMETER_BUDGETS = (
    128000,
    256000,
    512000,
)
# Opacity used for an architecture when none of its trials could be read.
DEFAULT_OPACITY = 0.5
# Language classes, used in directory names as random-<class>-<number>.
LANGUAGE_CLASSES = (
    'podfa',
    'star-free',
    'regular',
    'context-free',
)
# Vertical position (cm) of the first row of cells for each class.
# Indexed in parallel with LANGUAGE_CLASSES.
Y_SHIFTS = (
    -3.3,
    -5.8,
    -8.3,
    -10.8,
)

# Opening of the generated tikzpicture.
# - Styles and colours.
# - Nested boxes, one per language class.
# - The \coloredrect / \placetriplet macros used for each cell.
_TIKZ_PREAMBLE = r'''\begin{tikzpicture}[every node/.style={font=\large}]

\tikzset{
  box/.style={draw, rounded corners, very thick},
  rect/.style={draw, minimum width=0.25cm, minimum height=0.3cm, inner sep=0pt}
}

\definecolor{contextfree}{RGB}{180,240,180}
\definecolor{regular}{RGB}{100,150,200}
\definecolor{dotdepth}{RGB}{250,180,200}
\definecolor{podfa}{RGB}{150,100,200}

\definecolor{fillgreen}{RGB}{67, 217, 107}
\definecolor{fillred}{RGB}{222, 87, 82}
\definecolor{fillblue}{RGB}{15, 98, 214}

\node[box, draw=contextfree, minimum width=14cm, rounded corners=28pt, minimum height=12.5cm, anchor=north, line width=1.5pt] (cf) {};
\node[box, draw=regular, minimum width=13.5cm, minimum height=9.2cm, anchor=north, yshift=-0.7cm, line width=1.5pt, rounded corners=23pt] (reg) at (cf.north) {};
\node[box, draw=dotdepth, minimum width=13cm, minimum height=6cm, anchor=north, yshift=-1.4cm, line width=1.5pt, rounded corners=17pt] (starfree) at (cf.north) {};
\node[box, draw=podfa, minimum width=12.5cm, minimum height=2.8cm, anchor=north, yshift=-2.1cm, line width=1.5pt, rounded corners=14pt] (podfa) at (cf.north) {};

\node at (cf.north) [below] {\textbf{Context-free}};
\node at (reg.north) [below] {\textbf{Regular}};
\node at (starfree.north) [below] {\textbf{Star-free}};
\node at (podfa.north) [below] {\textbf{PODFA}};

\newcommand{\coloredrect}[4]{
  \begin{scope}[shift={#1}]
    \path[draw=none, fill=fillblue, opacity=#2]   (0,0) rectangle (0.3,0.3);
    \path[draw=none, fill=fillred,  opacity=#3]   (0.3,0) rectangle (0.6,0.3);
    \path[draw=none, fill=fillgreen,opacity=#4]   (0.6,0) rectangle (0.9,0.3);
  \end{scope}
}

\newcommand{\placetriplet}[5]{
  \coloredrect{(#1,#2)}{#3}{#4}{#5}
}'''

# Closing of the generated tikzpicture.
_TIKZ_EPILOGUE = '\n\n\\end{tikzpicture}'


def _format_cell(x_pos, y_pos, blue, red, green):
    """Return one placetriplet macro call, preceded by a newline.

    Args:
        x_pos: Horizontal position of the cell, in cm.
        y_pos: Vertical position of the cell, in cm.
        blue, red, green: Opacities of the cell's three squares, in the order
            the coloredrect macro draws them.

    All values are rounded to two decimals to keep the TeX source compact.
    """
    fields = (
        str(round(x_pos, 2)) + 'cm',
        str(round(y_pos, 2)) + 'cm',
        str(round(blue, 2)),
        str(round(red, 2)),
        str(round(green, 2)),
    )
    return '\n\\placetriplet' + ''.join('{' + field + '}' for field in fields)


def main():
    """Write a TikZ figure summarising the best accuracy per language.

    Command-line arguments:
        --base-dir: Experiment root. Results are read from
            ``base_dir/models/...``, and the figure is written under
            ``base_dir/figures/`` (created if missing).
        --num-languages: Number of languages per class (numbered from 1).
        --num-trials: Number of trials per parameter budget (numbered from 1).
        --tex-output: Output file name; its suffix is replaced by ``.tex``.

    Each language becomes one cell with three squares, one per entry of
    ARCHITECTURES. Each square's opacity is the normalised best accuracy, or
    DEFAULT_OPACITY when no trials were found. Missing trial directories are
    reported on stderr.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--base-dir', type=pathlib.Path, required=True)
    parser.add_argument('--num-languages', type=int, required=True)
    parser.add_argument('--num-trials', type=int, required=True)
    parser.add_argument('--tex-output', type=pathlib.Path, required=True)
    args = parser.parse_args()

    base_dir = args.base_dir
    num_languages = args.num_languages
    num_trials = args.num_trials

    # Cell layout.
    # - Cells are 1.1cm apart horizontally, starting at x = -5.3cm.
    # - Rows are 0.6cm apart, starting at Y_SHIFTS[class].
    # NOTE(review): classes are 2.5cm apart in Y_SHIFTS, so more than about
    # 4 rows (40 languages) per class will probably run into the next class.
    # TODO confirm.
    languages_per_row = 10

    output_dir = base_dir / 'figures'
    output_dir.mkdir(exist_ok=True)
    tex_output = output_dir / args.tex_output.with_suffix('.tex')
    print(f'writing {tex_output}')
    with tex_output.open('w', encoding='utf-8') as fout:
        fout.write(_TIKZ_PREAMBLE)
        for class_index, language_class in enumerate(LANGUAGE_CLASSES):
            for n in range(num_languages):
                row, column = divmod(n, languages_per_row)
                scores = []
                for architecture in ARCHITECTURES:
                    score, missing_dirs = get_max_over_trials(
                        base_dir, language_class, n + 1, num_trials, architecture)
                    scores.append(
                        normalize_score(score) if score is not None else DEFAULT_OPACITY)
                    for missing_dir in missing_dirs:
                        print(f'missing: {missing_dir}', file=sys.stderr)
                x_pos = -5.3 + column * 1.1
                y_pos = Y_SHIFTS[class_index] - row * 0.6
                fout.write(_format_cell(x_pos, y_pos, scores[0], scores[1], scores[2]))
        fout.write(_TIKZ_EPILOGUE)

# Generate the figure only when run as a script, not when imported.
if __name__ == '__main__':
    main()
