import os
import json


def parse_seq(lines):
    """Remove newline characters from each line and wrap the list in a dict.

    :param lines: iterable of strings, e.g. the result of ``file.readlines()``.
    :return: ``{'example_seq': cleaned}``, where ``cleaned`` holds each input
        string with every newline character removed, in the original order.
    """
    cleaned = []
    for raw in lines:
        cleaned.append(raw.replace('\n', ''))
    return dict(example_seq=cleaned)


def parse_results(lines):
    """Parse the first three metric lines of a textual-evaluation results file.

    Line 0, 1 and 2 are read as BLEU, cosine similarity and word overlap. From
    each line only the text after the last colon+tab separator is kept, and it
    is converted to ``float``. Any further lines are ignored.

    :param lines: sequence of strings, e.g. ``file.readlines()`` of a results file.
    :return: dict with float values under ``'bleu'``, ``'cosinus'`` and ``'w_overlap'``.
    :raises IndexError: if fewer than three lines are given.
    :raises ValueError: if a value after the separator is not a number.
    """
    metric_names = ('bleu', 'cosinus', 'w_overlap')
    # NOTE(review): 'style_accuracy' (lines[3]) and 'ppl' (lines[-1]) used to be
    # parsed here as well but were disabled. That disabled code split on '\t:'
    # rather than ':\t'; confirm the file format before re-enabling them.
    return {name: float(lines[index].split(':\t')[-1].strip())
            for index, name in enumerate(metric_names)}


if __name__ == '__main__':
    # Models to collect, as (name, checkpoint). Only the name is used to build paths.
    model_name = [('style_emb', 'checkpoint-100000')]
    # Disentanglement variants ("desant") evaluated for each model.
    desant = ['baseline', "no_reny", 'reny_1_3', 'reny_1_5']
    lambdas = [0.1, 0.5, 0.75, 1, 5, 10, 15, 20]
    results_dir = 'results_textual_ev'
    # Results-file suffix -> key under which its parsed metrics are stored.
    # NOTE(review): older versions of this script stored 'True.txt' under
    # 'results_false' and 'False.txt' under 'results_true'. Each suffix now
    # maps to its matching key; re-check any JSON produced before this change.
    suffix_to_key = (('True.txt', 'results_true'), ('False.txt', 'results_false'))

    all_dict = dict()
    for model, _checkpoint in model_name:
        method_dic = dict()
        for method in desant:
            # One {lambda: metrics} dict per results key.
            per_key = {key: dict() for _suffix, key in suffix_to_key}
            for lambda_ in lambdas:
                prefix = os.path.join(
                    results_dir,
                    'results_{}_lambda_{}_gender_{}_'.format(model, lambda_, method))
                for suffix, key in suffix_to_key:
                    path = prefix + suffix
                    print('Opening {}'.format(path))
                    with open(path, 'r', encoding='utf-8') as file:
                        per_key[key][lambda_] = parse_results(file.readlines())

            method_dic[method] = {'results_false': per_key['results_false'],
                                  'results_true': per_key['results_true'],
                                  # Never populated by this script. The empty dict is
                                  # kept so the output JSON schema is unchanged.
                                  'desantanglement': dict()}

        all_dict[model] = method_dic

    # json.dump turns the numeric lambda keys into strings (e.g. "0.1", "1").
    with open('data_gender_sentences.json', 'w', encoding='utf-8') as fp:
        json.dump(all_dict, fp)

# "{}_lambda_{}_{}"
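# Checkpoints recorded per model (checkpoint, model name):
# checkpoint-80000  dae
# checkpoint-500000 multi_dec
# checkpoint-540000 style_emb
# NOTE(review): model_name in the main block uses checkpoint-100000 for
# style_emb, which differs from the checkpoint listed here. Confirm which is current.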


# with open('sentences_True.txt', 'r') as file:
#    lines = file.readlines()

# parse_seq(lines)
