import argparse
import json
import numpy
import math
import sys
import pathlib
import dataclasses
import matplotlib.pyplot as plt

from recognizers.analysis.plot_cross_entropy_vs_num_edits import (
    load_labels,
    load_negative_score_lines,
    divide
)

from rau.tools.logging import read_log_file

def safe(x):
    """Return ``x`` unchanged, or ``math.nan`` if it is ``None``."""
    if x is None:
        return math.nan
    return x

def load_scores(fin):
    """Yield one score per JSON line of ``fin``.

    Each line must be a JSON object with a ``recognition_cross_entropy``
    entry; the yielded score is its first element divided by its second
    (presumably a total / count pair — confirm against the evaluation writer).
    """
    for raw_line in fin:
        record = json.loads(raw_line)
        cross_entropy = record['recognition_cross_entropy']
        yield cross_entropy[0] / cross_entropy[1]

def load_lengths(fin):
    """Yield the number of whitespace-separated tokens on each line of ``fin``."""
    # str.split() with no argument already ignores surrounding whitespace,
    # so the token count equals that of line.rstrip().split().
    yield from (len(line.split()) for line in fin)

def load_lengths_and_scores(strings_fin, scores_fin):
    """Yield ``(length, score)`` pairs from aligned string and score streams.

    ``length`` is the token count of a line of ``strings_fin`` (via
    ``load_lengths``) and ``score`` is the value ``load_scores`` yields for the
    corresponding line of ``scores_fin``.

    Raises:
        ValueError: if the two streams contain different numbers of lines.
    """
    # load_scores already parses each JSON line and yields a number, so the
    # value must not be passed through json.loads() again.
    for score, length in zip(load_scores(scores_fin), load_lengths(strings_fin), strict=True):
        yield length, score

def get_validation_ce(fin):
    """Return the best validation ``recognition_cross_entropy`` from a training log.

    ``fin`` is a path to a log readable by ``read_log_file``. The value comes
    from the first event whose ``type`` is ``'train'``. Returns ``None`` if no
    such event exists.
    """
    with open(fin) as log_file:
        for event in read_log_file(log_file):
            if event.type != 'train':
                continue
            best_scores = event.data['best_validation_scores']
            return best_scores['recognition_cross_entropy']
    return None

def get_test_accuracy(fin):
    """Return ``scores.recognition_accuracy`` from the JSON evaluation file at path ``fin``."""
    with open(fin) as eval_file:
        report = json.load(eval_file)
    return report['scores']['recognition_accuracy']

def get_smoothed_points(xs, ys, *, start=0, stop=500, step=10, window_size=10):
    """Smooth a scatter of ``(x, y)`` points with a sliding window.

    For each window centre ``c`` in ``range(start, stop + 1, step)``, collect
    every point whose x lies in the closed interval
    ``[c - window_size, c + window_size]``. Then compute the mean and the
    population standard deviation (``ddof=0``) of those points' y values.

    Args:
        xs: x coordinates (e.g. input lengths).
        ys: y values paired positionally with ``xs``; extra items in the
            longer sequence are ignored, as with ``zip``.
        start, stop, step: the window-centre grid (``stop`` is inclusive).
        window_size: half-width of each window.

    Returns:
        ``(centres, means, stds)``: ``centres`` is a list of ints; ``means``
        and ``stds`` are parallel lists. Windows containing no points get
        ``nan`` for both values.
    """
    # reshape keeps an (0, 2) shape when there are no points at all.
    pairs = numpy.array(list(zip(xs, ys)), dtype=float).reshape(-1, 2)
    point_xs, point_ys = pairs[:, 0], pairs[:, 1]

    centres = list(range(start, stop + 1, step))
    means, stds = [], []
    for centre in centres:
        in_window = (point_xs >= centre - window_size) & (point_xs <= centre + window_size)
        window = point_ys[in_window]
        if window.size == 0:
            # numpy would also return NaN here, but with a RuntimeWarning.
            means.append(math.nan)
            stds.append(math.nan)
        else:
            means.append(numpy.mean(window))
            stds.append(numpy.std(window))
    return centres, means, stds

def format_class_name(language_class):
    if language_class == 'regular':
        return 'Regular'
    elif language_class == 'podfa':
        return 'PODFA'
    elif language_class == 'context-free':
        return 'Context-free'
    elif language_class == 'star-free':
        return 'Star-free'
    else:
        raise NotImplementedError


# Model architectures, in a fixed positional order.
ARCHITECTURES = (
    'transformer',
    'lstm',
    'rnn',
)
# Short display labels, positionally parallel to ARCHITECTURES.
ARCHITECTURE_ABBREV = (
    'Tf',
    'LSTM',
    'RNN',
)
# Training-setup name; corresponds to the 'validation-long' level of model directory paths.
SETUP = 'validation-long'
# Parameter budgets searched when picking each architecture's best trial.
PARAMETER_BUDGETS = (
    64_000,
    128_000,
    256_000,
    512_000,
)

# Plot series in output order: (architecture, .dat column suffix,
# TikZ path-name suffix, colour). Data is selected by architecture name,
# so each column's label always matches its contents.
_SERIES = (
    ('transformer', 'tf', 'Tf', 'blue'),
    ('rnn', 'rnn', 'rnn', 'red'),
    ('lstm', 'lstm', 'lstm', 'green'),
)

# Hard-coded path prefix the generated .tex file uses to reference the .dat table.
_TEX_DAT_DIR = '02-Paper/figures/ce-vs-length/'


def _best_model(base_dir, language, architecture, num_trials):
    """Return the model directory with the highest test accuracy for one architecture.

    Searches every budget in PARAMETER_BUDGETS and trials ``1..num_trials``.
    Trials whose ``eval/test.json`` is missing or malformed are skipped. When
    accuracies tie, the first candidate in search order wins.

    Raises:
        LookupError: if no trial could be read.
    """
    candidates = []
    for parameter_budget in PARAMETER_BUDGETS:
        for trial in range(1, num_trials + 1):
            model = (base_dir / 'models' / f'{parameter_budget}' / language
                     / architecture / 'rec' / SETUP / f'{trial}')
            try:
                accuracy = get_test_accuracy(model / 'eval' / 'test.json')
            except (OSError, ValueError, KeyError, TypeError):
                # Missing or incomplete trials are expected; skip them.
                continue
            candidates.append((model, accuracy))
    if not candidates:
        raise LookupError(f'no evaluated models for {language}/{architecture}')
    best_model, _ = max(candidates, key=lambda candidate: candidate[1])
    return best_model


def _load_language(base_dir, language, num_trials):
    """Load test-string lengths and per-architecture scores for one language.

    Returns:
        ``(lengths, rows)``: ``lengths[i]`` is the token count of test string
        ``i``, and ``rows[i]`` holds that string's scores, one per
        architecture in ARCHITECTURES order.

    Raises:
        OSError, ValueError, LookupError, ZeroDivisionError: if an input is
            missing or malformed, or if score and string counts disagree
            (ValueError).
    """
    strings_file = base_dir / 'languages' / language / 'datasets' / 'test' / 'main.tok'
    with open(strings_file, 'r') as fin:
        lengths = list(load_lengths(fin))

    per_architecture = []
    for architecture in ARCHITECTURES:
        model = _best_model(base_dir, language, architecture, num_trials)
        with open(model / 'eval' / 'test.jsonl', 'r') as scores_fin:
            per_architecture.append(list(load_scores(scores_fin)))

    # Strict checks: a count mismatch would otherwise silently misalign
    # lengths and scores for this and every later language.
    rows = list(zip(*per_architecture, strict=True))
    if len(rows) != len(lengths):
        raise ValueError(f'{len(lengths)} test strings but {len(rows)} scores')
    return lengths, rows


def _band(mean, std):
    """Return ``(mean, mean + std, mean - std)``, using NaN for missing values."""
    if mean is None or std is None:
        return safe(mean), math.nan, math.nan
    return mean, mean + std, mean - std


def _write_dat(dat_output, xs, smoothed):
    """Write the whitespace-separated table that the generated TikZ figure reads."""
    header = ' '.join(
        ['length'] + [f'mean{c} meanplusstd{c} meanminusstd{c}' for _, c, _, _ in _SERIES])
    with dat_output.open('w') as fout:
        print(header, file=fout)
        for i, xi in enumerate(xs):
            fields = [f'{xi}']
            for architecture, _, _, _ in _SERIES:
                _, means, stds = smoothed[architecture]
                fields.extend(f'{value}' for value in _band(means[i], stds[i]))
            fout.write(' '.join(fields) + '\n')


def _series_tex(column, path_name, color, table):
    """Return TikZ lines drawing one architecture's mean curve and +/- 1 std band."""
    lines = []
    for side, prefix in (('A', 'meanplusstd'), ('B', 'meanminusstd')):
        lines += [
            r'            \addplot[',
            f'                name path={side}-{path_name},',
            f'                color={color},',
            '                opacity=0.2,',
            '                forget plot',
            f'            ] table [x=length, y={prefix}{column}] {table};',
        ]
    lines += [
        rf'            \addplot[color={color}, opacity=0.2, forget plot] fill between [of=A-{path_name} and B-{path_name}];',
        r'            \addplot[',
        f'                color={color}',
        f'            ] table [x=length, y=mean{column}] {table};',
    ]
    return lines


def _write_tex(tex_output, dat_name, language_class):
    """Write the pgfplots figure that plots the .dat table named ``dat_name``."""
    table = '{' + _TEX_DAT_DIR + dat_name + '}'
    lines = [
        r'\begin{tikzpicture}',
        r'    \begin{axis}[',
        r'        legend style={anchor=north east, font=\tiny, opacity=0.6, text opacity=1, draw=gray},',
        '        axis lines=left,',
        '        every axis plot/.append style={-},',
        '        xlabel={Input Length},',
        '            enlarge x limits=0.1,',
        '            xmin=0,',
        '            xtick={0, 100, 200, 300, 400, 500},',
        '            ylabel={Cross Entropy},',
        '            title={' + format_class_name(language_class) + '},',
        '            ymin=0,',
        '            enlarge y limits=0.1]',
    ]
    for _, column, path_name, color in _SERIES:
        lines += _series_tex(column, path_name, color, table)
    lines += [r'        \end{axis}', r'    \end{tikzpicture} ']
    with tex_output.open('w') as fout:
        fout.write('\n'.join(lines))


def main():
    """Plot smoothed test scores against input length for each architecture.

    For every language of the requested class:
      * Pick each architecture's best model, i.e. the one with the highest
        test accuracy across PARAMETER_BUDGETS and trials.
      * Pair that model's per-string scores (from ``load_scores``) with the
        test-string lengths.

    Then write a ``.dat`` table and a pgfplots ``.tex`` figure under
    ``<base-dir>/figures``. Languages with missing or inconsistent inputs are
    skipped, with a warning on stderr.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--language-class', choices=['regular', 'podfa', 'context-free', 'star-free'], required=True)
    parser.add_argument('--base-dir', type=pathlib.Path, required=True)
    parser.add_argument('--tex-output', type=pathlib.Path, required=True)
    parser.add_argument('--num-languages', type=int, required=True)
    parser.add_argument('--num-trials', type=int, required=True)
    args = parser.parse_args()

    base_dir = args.base_dir
    language_class = args.language_class

    lengths, scores = [], []
    for language_no in range(1, args.num_languages + 1):
        language = f'random-{language_class}-{language_no}'
        try:
            language_lengths, language_scores = _load_language(base_dir, language, args.num_trials)
        except (OSError, ValueError, LookupError, ZeroDivisionError) as error:
            print(f'skipping {language}: {error}', file=sys.stderr)
            continue
        lengths += language_lengths
        scores += language_scores

    # Select score columns by architecture name, not by hard-coded position,
    # so each output column holds the architecture its label names.
    smoothed = {
        architecture: get_smoothed_points(lengths, [row[index] for row in scores])
        for index, architecture in enumerate(ARCHITECTURES)
    }
    xs = smoothed[ARCHITECTURES[0]][0]

    output_dir = base_dir / 'figures'
    output_dir.mkdir(exist_ok=True)
    dat_output = output_dir / args.tex_output.with_suffix('.dat')
    tex_output = output_dir / args.tex_output.with_suffix('.tex')
    print(f'writing {dat_output}')
    _write_dat(dat_output, xs, smoothed)
    print(f'writing {tex_output}')
    _write_tex(tex_output, dat_output.name, language_class)

if __name__ == '__main__':
    main()