import json
import os
import shutil
import re
import pickle

from collections import defaultdict
import numpy as np
import pandas as pd

# arguments
# `models` and `settings` are parallel lists: entry i of one is paired with
# entry i of the other when the script runs.
models = ['bert-base-uncased'] * 2 + ['roberta-base'] * 2
settings = ['inclusive', 'approximate'] * 2

# directory whose sub-directories are the individual experiment runs
model_output_dir = './model_output'
# key looked up in each trainer_state.json log entry
target_metric = 'eval_groundtruth_f1_factual'

def get_factual_performance_from_checkpoint(exp_dir, metric_name=None):
    """Return the metric logged at the final step of the latest checkpoint.

    The latest checkpoint is the entry of ``exp_dir`` whose name contains
    "checkpoint" and carries the highest step number. Its
    ``trainer_state.json`` is read, and the ``log_history`` entry recorded at
    ``global_step`` that contains the metric is returned.

    Args:
        exp_dir: Path of one experiment directory.
        metric_name: Key to look up in the log entries. Defaults to the
            module-level ``target_metric``.

    Returns:
        The metric value, or ``None`` when ``exp_dir`` holds no checkpoint or
        no log entry at ``global_step`` contains the metric.

    Raises:
        OSError: If ``exp_dir`` or the checkpoint's ``trainer_state.json``
            cannot be read.
        KeyError: If the trainer state lacks ``global_step`` or
            ``log_history``.
    """
    if metric_name is None:
        metric_name = target_metric

    # Filter first, then test for emptiness: a directory holding only
    # unrelated files must yield None rather than an IndexError.
    checkpoints = [name for name in os.listdir(exp_dir) if "checkpoint" in name]
    if not checkpoints:
        return None

    def _checkpoint_order(name):
        # Order by the numeric step so that "checkpoint-1000" ranks after
        # "checkpoint-500"; plain string sorting gets this wrong.
        match = re.search(r'checkpoint\D*(\d+)', name)
        return (int(match.group(1)) if match else -1, name)

    last_checkpoint = max(checkpoints, key=_checkpoint_order)
    trainer_state_path = os.path.join(exp_dir, last_checkpoint, 'trainer_state.json')
    with open(trainer_state_path, 'r') as fp:
        trainer_state = json.load(fp)

    final_step = trainer_state['global_step']

    # Several entries can share a step (e.g. a training-loss entry and an
    # evaluation entry), so require both the step and the metric to match.
    for entry in reversed(trainer_state['log_history']):
        if entry.get('step') == final_step and metric_name in entry:
            return entry[metric_name]
    return None

def get_factual_performance(model, setting):
    """Aggregate the target metric over all runs of one model/setting into a CSV.

    Experiment runs are the sub-directories of ``model_output_dir`` whose
    name contains both ``model`` and ``setting``. The shot count ``k`` is
    parsed from the ``__<k>-shot`` fragment of the directory name, and runs
    are grouped by ``k``. Directories without that fragment, and runs for
    which no metric could be read, are skipped.

    The result is written to ``<target-metric>_<model>_<setting>.csv`` in the
    current working directory (underscores in the metric name are replaced
    by dashes). It has a ``k`` column and a ``mean`` column formatted as
    ``"mean (std)"`` with two decimals, sorted by ``k``.

    Args:
        model: Model name that must appear in the experiment directory name.
        setting: Setting name that must appear in the experiment directory name.

    Returns:
        None. The only output is the CSV file.
    """
    # Raw string: '\d' in a normal string literal is an invalid escape sequence.
    shot_pattern = re.compile(r'__(\d+)-shot')

    # get factual performance for each experiment, grouped by shot count
    factual_performance_map = defaultdict(list)
    for name in os.listdir(model_output_dir):
        subdir = os.path.join(model_output_dir, name)
        # Match on the directory name only, so that the text of
        # model_output_dir itself can never satisfy the filter.
        if not os.path.isdir(subdir) or model not in name or setting not in name:
            continue

        shot_match = shot_pattern.search(name)
        if shot_match is None:
            # Not an "<...>__<k>-shot" run directory; nothing to group it under.
            continue

        perf = get_factual_performance_from_checkpoint(subdir)
        # Compare against None explicitly: a score of 0.0 is a valid result.
        if perf is not None:
            factual_performance_map[int(shot_match.group(1))].append(perf)

    # mean and (population) std per shot count, sorted by k
    factual_performance_list = sorted(
        ((k, np.mean(values), np.std(values))
         for k, values in factual_performance_map.items()),
        key=lambda row: row[0],
    )

    safe_target_metric = target_metric.replace('_', '-')

    # csv
    df = pd.DataFrame.from_dict({
        'k': [k for k, _, _ in factual_performance_list],
        'mean': [f'{mean:.2f} ({std:.2f})' for _, mean, std in factual_performance_list],
    })
    df.to_csv(f'{safe_target_metric}_{model}_{setting}.csv')

if __name__ == '__main__':
    # Produce one CSV per (model, setting) pair; the two lists are walked in lockstep.
    for model_and_setting in zip(models, settings):
        get_factual_performance(*model_and_setting)