import json
import os
import statistics
import numpy as np

def bernoulli_stderr(total, n):
    """Return the standard error of the mean of ``n`` binary (0/1) outcomes.

    ``total`` is the episode-weighted sum of a success-rate metric, i.e. the
    number of successful episodes. It arrives as a float that may carry
    round-off (``0.29 * 100 == 28.999999999999996``), so it is rounded to the
    nearest integer rather than truncated, then clamped to ``[0, n]``.

    The result equals ``statistics.stdev([1] * k + [0] * (n - k)) / sqrt(n)``
    (``k`` being the rounded success count). It is computed in closed form
    instead of materializing an ``n``-element list.

    Args:
        total: Weighted sum of the metric over all episodes.
        n: Total number of episodes.

    Returns:
        The standard error as a float, or ``nan`` when ``n < 2`` (the sample
        standard deviation is undefined for fewer than two samples).
    """
    if n < 2:
        return float('nan')
    successes = min(max(round(total), 0), n)
    # Sample variance (n - 1 denominator) of k ones and n - k zeros.
    variance = successes * (n - successes) / (n * (n - 1))
    return (variance / n) ** 0.5


def aggregate_metrics(runs):
    """Pool several evaluation runs into one episode-weighted summary.

    Args:
        runs: Non-empty list of metric dicts, each with an ``episode_count``
            entry. The keys of the first dict decide which metrics are pooled.
            Every other run must contain those keys too.

    Returns:
        A dict containing the summed ``episode_count``. For every other key it
        holds the episode-weighted mean and a ``<key>_stderr`` entry from
        ``bernoulli_stderr``, which treats the metric as a 0/1 success rate.

    Raises:
        ValueError: if ``runs`` is empty.
    """
    if not runs:
        raise ValueError('aggregate_metrics() needs at least one run')
    n = sum(run['episode_count'] for run in runs)
    pooled = {}
    for key in runs[0]:
        if key == 'episode_count':
            pooled[key] = n
            continue
        total = sum(run[key] * run['episode_count'] for run in runs)
        pooled[key] = total / n
        pooled[key + '_stderr'] = bernoulli_stderr(total, n)
    return pooled


def _load_json(path):
    """Read and return the JSON document at ``path``, closing the file."""
    with open(path, encoding='utf-8') as f:
        return json.load(f)


if __name__ == '__main__':
    os.makedirs('metrics', exist_ok=True)
    # Maps each output summary file to the per-run metric files it pools.
    # NOTE(review): the gtsem_trusting, trusting and noglob entries list the
    # same file four times. Each copy is counted as separate episodes, so
    # their episode_count is 4x the real value and their *_stderr values are
    # roughly halved. Confirm this duplication is intentional.
    outputs_inputs = {
        'metrics/gtsem_trusting.json': [
            'evaluate/gtsem_trusting/results/average_metrics.json',
            'evaluate/gtsem_trusting/results/average_metrics.json',
            'evaluate/gtsem_trusting/results/average_metrics.json',
            'evaluate/gtsem_trusting/results/average_metrics.json'
        ],

        'metrics/gtsem_helios.json': [
            'evaluate/gtsem_helios-0_300/results/average_metrics.json',
            'evaluate/gtsem_helios-300_600/results/average_metrics.json',
            'evaluate/gtsem_helios-600_900/results/average_metrics.json',
            'evaluate/gtsem_helios-900_1200/results/average_metrics.json'
        ],

        'metrics/trusting.json': [
            'evaluate/trusting/results/average_metrics.json',
            'evaluate/trusting/results/average_metrics.json',
            'evaluate/trusting/results/average_metrics.json',
            'evaluate/trusting/results/average_metrics.json'
        ],
        'metrics/noglob.json': [
            'evaluate/noglobalobjective/results/average_metrics.json',
            'evaluate/noglobalobjective/results/average_metrics.json',
            'evaluate/noglobalobjective/results/average_metrics.json',
            'evaluate/noglobalobjective/results/average_metrics.json'
        ],
        'metrics/helios.json': [
            'evaluate/helios-0_300/results/average_metrics.json',
            'evaluate/helios-300_600/results/average_metrics.json',
            'evaluate/helios-600_900/results/average_metrics.json',
            'evaluate/helios-900_1200/results/average_metrics.json'
        ],

    }

    # Load every input before aggregating or writing anything, so a missing
    # file aborts the run before any output is produced.
    runs_by_output = {
        output: [_load_json(path) for path in paths]
        for output, paths in outputs_inputs.items()
    }

    summaries = {}
    for output, runs in runs_by_output.items():
        summaries[output] = aggregate_metrics(runs)
        print(summaries[output]['episode_count'])

    for output, summary in summaries.items():
        with open(output, 'w', encoding='utf-8') as f:
            json.dump(summary, f, indent=4)

    large_metrics = ['find_objects', 'pick_objects', 'find_receps', 'place_objects', 'overall_success',]
    # Pooled mean of each headline metric, printed as a percentage.
    for output, summary in summaries.items():
        name = os.path.basename(output).split('.')[0]
        print(name, '\t', ', '.join(f'{key}: {summary[key] * 100:.1f}%' for key in large_metrics))
    # small_metrics = ['0_failed_find_object', '1_failed_pick_object', '2_failed_find_recep', '3_failed_place_goal', '4_failed_place_rest', '5_failed place_stable', '6_collision', '7_success']
    # for output, summary in summaries.items():
    #     print(os.path.basename(output).split('.')[0], '\t', ', '.join(f'{key}: {summary[key] * 100:.1f}%' for key in small_metrics))

    # Standard error of each headline metric, printed as a percentage.
    for output, summary in summaries.items():
        name = os.path.basename(output).split('.')[0]
        print(name, '\t', ', '.join(f'{key}: {summary[key + "_stderr"] * 100:.1f}%' for key in large_metrics))