import jsonlines
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
def pretty(d, indent=0):
    """Print a nested mapping as a tab-indented tree, keys in sorted order.

    Every key is printed on its own line, prefixed by ``indent`` tabs. A value
    that is a ``dict`` (or subclass such as ``defaultdict``) is expanded
    recursively one level deeper; any other value is formatted with an
    f-string and printed one level below its key.

    Args:
        d: Mapping to print. Its keys are sorted, so they must be mutually
            comparable.
        indent: Number of leading tab characters for the top-level keys.
    """
    key_pad = '\t' * indent
    value_pad = '\t' * (indent + 1)
    for key in sorted(d):
        value = d[key]
        print(key_pad + str(key))
        if not isinstance(value, dict):
            # Leaf value: print it under its key and move on.
            print(value_pad + f"{value}")
            continue
        pretty(value, indent + 1)

# One accumulator per "<LOSS>-<LOSS_SPACE>" configuration, in a fixed order:
#   RES - raw metric values, one list per metric name
#   SUM - per-metric summary tuples, one dict per configuration
RES = {config: defaultdict(list)
       for config in ('ERM-positives', 'FRM-positives',
                      'ERM-negatives', 'FRM-negatives')}
SUM = {config: defaultdict(dict) for config in RES}

# Per-metric table keyed metric name -> configuration -> rounded summary.
DSUM = {}
# JSON-lines log the results are read from.
filename = 'logs_functional_data/res_clean_OS'
# Number of log records observed for each SEED value.
count_seen_seed = defaultdict(int)
# Name of the "negatives" loss space; also the suffix of negative-set metrics.
neg_str = 'negatives'
# Read the log once so the seed census and the aggregation below operate on
# exactly the same records.
with jsonlines.open(filename, 'r') as reader:
    records = list(reader)

# Per seed: how many records exist (count_seen_seed) and which
# "<LOSS>-<LOSS_SPACE>" configurations those records cover.
configs_by_seed = defaultdict(set)
for record in records:
    count_seen_seed[record['SEED']] += 1
    configs_by_seed[record['SEED']].add(f"{record['LOSS']}-{record['LOSS_SPACE']}")

# A seed is usable only if it has results for every configuration tracked in
# RES. Counting records alone is not enough: a duplicated record for one
# configuration would let a seed missing another configuration slip through.
required_configs = set(RES)
for record in records:
    if not required_configs <= configs_by_seed[record['SEED']]:
        continue
    config = f"{record['LOSS']}-{record['LOSS_SPACE']}"
    # Collect both the plain metrics and their "_negatives" counterparts.
    for app in ['', f'_{neg_str}']:
        for k in [f'train_ERM{app}', f'train_FRM{app}', f'test_ERM{app}', f'test_FRM{app}']:
            RES[config][k].append(record[k])

# Summarise every metric of every configuration as
# (mean, std / sqrt(number of values)), then print the whole table.
# NOTE(review): np.std defaults to ddof=0 (population std); the usual
# standard error of the mean uses ddof=1 -- confirm which is intended.
for loss, metrics in RES.items():
    for k, values in metrics.items():
        spread = np.std(values) / np.sqrt(len(values))
        SUM[loss][k] = (np.mean(values), spread)
pretty(SUM)

# Pivot the results into metric -> configuration -> (mean, std / sqrt(n)),
# each rounded to 3 decimals. Metric names are taken from the FRM-negatives
# configuration.
reference_metrics = RES[f'FRM-{neg_str}']
if not reference_metrics:
    # Nothing survived the seed filter; without this check the lookups at the
    # end fail with a bare KeyError on 'test_ERM'.
    raise SystemExit(f"No results for FRM-{neg_str} in {filename}; nothing to report.")

for k in reference_metrics:
    DSUM[k] = {}
    for loss, metrics in RES.items():
        # .get rather than [] so a missing metric does not insert an empty
        # list into RES's defaultdict (the summary is nan either way).
        values = metrics.get(k, [])
        DSUM[k][loss] = (np.round(np.mean(values), 3),
                         np.round(np.std(values) / np.sqrt(len(values)), 3))

print("Test ERM positives: ")
pretty(DSUM['test_ERM'])
print(f"Test ERM {neg_str}: ")
pretty(DSUM[f'test_ERM_{neg_str}'])
