import argparse
import json
import numpy
import math
import pathlib
import itertools
import torch
import matplotlib.pyplot as plt

from rau.tools.logging import LogParseError, LogEvent, read_log_file

from ourproject.analysis.plot_cross_entropy_vs_num_edits import (
    load_labels,
    load_negative_score_lines,
    divide
)

from rau.tools.logging import read_log_file

def get_test_accuracy(fin):
    """Return the recognition accuracy stored in an evaluation JSON file.

    :param fin: Path (str or path-like) of a JSON file whose top-level object
        has a ``scores`` mapping.
    :return: The value of ``scores['recognition_accuracy']``.
    :raises KeyError: If either key is missing from the file.
    """
    with open(fin) as handle:
        payload = json.load(handle)
    scores = payload['scores']
    return scores['recognition_accuracy']

def _find_best_model(models_dir, language, architecture, setup, losses, runs):
    """Return ``(model_dir, loss)`` of the run with the highest test accuracy.

    Every ``models_dir / language / architecture / loss / setup / run``
    directory is scored via its ``eval/test.json``. When several runs tie,
    the first one in ``losses`` x ``runs`` order wins (``max`` returns the
    first maximal element).
    """
    candidates = []
    for loss in losses:
        for run in runs:
            model = models_dir / language / architecture / loss / setup / run
            acc = get_test_accuracy(model / 'eval' / 'test.json')
            candidates.append((model, acc, loss))
    model, _, loss = max(candidates, key=lambda candidate: candidate[1])
    return model, loss


def _format_loss_coefficients(log_file, loss):
    """Return the table-cell text for the auxiliary loss coefficients of a run.

    ``loss`` is a '+'-separated loss spec such as ``'rec+lm+ns'``. Only the
    coefficients of the auxiliary losses actually present in the spec are
    read and reported, rounded to 3 decimals, e.g. ``'LM:0.1, NS:0.5'``.
    Returns ``''`` when the spec has no auxiliary loss or the log has no
    ``training_info`` event.
    """
    used_loss_fns = loss.split('+')
    use_lm = 'lm' in used_loss_fns
    use_ns = 'ns' in used_loss_fns
    with open(log_file) as f:
        for event in read_log_file(f):
            if event.type != 'training_info':
                continue
            parts = []
            # Only touch the coefficients that are in use, so an unused
            # (possibly missing/None) coefficient cannot crash the report.
            if use_lm:
                lm_coef = round(event.data['language_modeling_loss_coefficient'], 3)
                parts.append(f'LM:{lm_coef}')
            if use_ns:
                ns_coef = round(event.data['next_symbols_loss_coefficient'], 3)
                parts.append(f'NS:{ns_coef}')
            # Use the first training_info event only; repeated events would
            # otherwise duplicate the cell text.
            return ', '.join(parts)
    return ''


def main():
    """Write a markdown table of loss coefficients of the best model per cell.

    For each language (row) and architecture (column), the run with the
    highest test recognition accuracy over all loss configurations and runs
    under ``--models`` is selected. The auxiliary loss coefficients recorded
    in its ``logs/main.log`` are then written to ``--output``; rows are also
    echoed to stdout.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--models', type=pathlib.Path, default='./models')
    # Required: the table is always written to this file.
    parser.add_argument('--output', type=pathlib.Path, required=True)
    args = parser.parse_args()

    architectures = ['transformer', 'rnn', 'lstm']
    setup = 'validation-long'
    losses = ['rec', 'rec+lm', 'rec+ns', 'rec+lm+ns']
    runs = [str(i) for i in range(1, 11)]
    languages = [
        'even-pairs', 'repeat-01', 'parity', 'cycle-navigation',
        'modular-arithmetic-simple', 'dyck-2-3', 'first', 'majority',
        'stack-manipulation', 'marked-reversal', 'marked-copy',
        'missing-duplicate-string', 'odds-first', 'binary-addition',
        'binary-multiplication', 'compute-sqrt', 'bucket-sort',
    ]

    output_file = args.output
    print("|Language|Tf|RNN|LSTM|")
    with output_file.open('w') as fout:
        fout.write("| Language | Tf | RNN | LSTM |\n")
        fout.write("| ----------------- | ------- | ------- | ------- |\n")
        for language in languages:
            cells = []
            for architecture in architectures:
                model, loss = _find_best_model(
                    args.models, language, architecture, setup, losses, runs
                )
                cells.append(
                    _format_loss_coefficients(model / 'logs' / 'main.log', loss)
                )
            line = f'| {language} | ' + ''.join(f'{cell} | ' for cell in cells) + '\n'
            # `line` already ends with a newline; don't let print add another.
            print(line, end='')
            fout.write(line)


# Script entry point: run main() only when executed directly, not on import.
if __name__ == '__main__':
    main()