import json
import os
import statistics
import numpy as np


def load_concatenated_json(text: str) -> list:
    """Parse a string holding zero or more JSON documents written back to back.

    The documents may be separated by any amount of whitespace and may be
    pretty-printed or contain nested objects.

    Args:
        text: Raw file contents.

    Returns:
        The decoded documents, in file order.

    Raises:
        json.JSONDecodeError: If the text contains anything that is not a
            valid JSON document.
    """
    decoder = json.JSONDecoder()
    records = []
    pos, end = 0, len(text)
    while True:
        # raw_decode does not skip leading whitespace itself.
        while pos < end and text[pos].isspace():
            pos += 1
        if pos >= end:
            return records
        record, pos = decoder.raw_decode(text, pos)
        records.append(record)


def proportion_stderr(total: float, n: int) -> float:
    """Return the standard error of the mean of ``n`` 0/1 outcomes.

    ``total`` is treated as the number of ones among the ``n`` outcomes. It is
    rounded to the nearest integer (float products such as ``0.29 * 100`` land
    just below the true count, so truncation would lose a success) and clamped
    to ``[0, n]``. The result is the closed form of
    ``statistics.stdev([1] * k + [0] * (n - k)) / sqrt(n)``.

    Args:
        total: Sum of the per-episode outcomes.
        n: Number of episodes.

    Returns:
        The sample standard deviation divided by ``sqrt(n)``.

    Raises:
        ValueError: If ``n`` is smaller than 2, where the sample standard
            deviation is undefined.
    """
    if n < 2:
        raise ValueError(
            f'standard error needs at least 2 episodes, got {n}')
    successes = min(max(round(total), 0), n)
    return float(np.sqrt(successes * (n - successes) / (n - 1)) / n)


def aggregate_metrics(records: list) -> dict:
    """Combine per-run metric records into one episode-weighted summary.

    Every key of the first record is aggregated: ``episode_count`` is summed,
    and every other key becomes the mean weighted by ``episode_count`` plus a
    ``<key>_stderr`` entry computed by ``proportion_stderr``.

    Args:
        records: Dicts that each contain ``episode_count`` and the same
            numeric metric keys.

    Returns:
        The aggregated metrics, keyed in the order of the first record.

    Raises:
        ValueError: If ``records`` is empty or covers fewer than 2 episodes.
    """
    if not records:
        raise ValueError('no metric records to aggregate')

    n = sum(record['episode_count'] for record in records)
    aggregated = {}
    for key in records[0]:
        if key == 'episode_count':
            aggregated[key] = n
            continue
        total = sum(record[key] * record['episode_count'] for record in records)
        stderr = proportion_stderr(total, n)
        aggregated[key] = total / n
        aggregated[key + '_stderr'] = stderr
    return aggregated


if __name__ == '__main__':
    os.makedirs('metrics', exist_ok=True)

    # Only records evaluated with this many picks are aggregated.
    n_picks = 1

    # Output file -> result files whose records are pooled into it.
    # Each result file is listed exactly once: repeating a path would count
    # its episodes several times, inflating episode_count and understating
    # the standard error.
    outputs_inputs = {
        'metrics/trusting_1pick.json': [
            'evaluate/trusting/results/average_metrics_per_pick.json',
        ],
        'metrics/noglob_1pick.json': [
            'evaluate/noglobalobjective/results/average_metrics_per_pick.json',
        ],
        'metrics/helios_1pick.json': [
            'evaluate/helios-0_300/results/average_metrics_per_pick.json',
            'evaluate/helios-300_600/results/average_metrics_per_pick.json',
            'evaluate/helios-600_900/results/average_metrics_per_pick.json',
            'evaluate/helios-900_1200/results/average_metrics_per_pick.json',
        ],
    }

    # Load and aggregate everything before writing, so a missing or broken
    # input aborts the run without leaving partial output behind.
    aggregated = {}
    for output, paths in outputs_inputs.items():
        records = []
        for path in paths:
            with open(path, 'r', encoding='utf-8') as f:
                records += [
                    record for record in load_concatenated_json(f.read())
                    if record['max_picks'] == n_picks
                ]
        aggregated[output] = aggregate_metrics(records)

    for output, metrics in aggregated.items():
        with open(output, 'w', encoding='utf-8') as f:
            json.dump(metrics, f, indent=4)

    large_metrics = ['find_objects', 'pick_objects', 'find_receps', 'place_objects', 'overall_success']
    for output, metrics in aggregated.items():
        print(os.path.basename(output).split('.')[0], '\t', ', '.join(f'{key}: {metrics[key] * 100:.1f}%' for key in large_metrics))
    # small_metrics = ['0_failed_find_object', '1_failed_pick_object', '2_failed_find_recep', '3_failed_place_goal', '4_failed_place_rest', '5_failed place_stable', '6_collision', '7_success']
    # for output, metrics in aggregated.items():
    #     print(os.path.basename(output).split('.')[0], '\t', ', '.join(f'{key}: {metrics[key] * 100:.2f}%' for key in small_metrics))

    # Report the standard error of each headline metric.
    for output, metrics in aggregated.items():
        print(os.path.basename(output).split('.')[0], '\t', ', '.join(f'{key}: {metrics[key + "_stderr"] * 100:.1f}%' for key in large_metrics))