import sys
import time
import json
import wandb
from pathlib import Path

def get_run_by_name(name):
    """Look up a W&B run in the ``open_instruct`` project by its display name.

    Args:
        name: Value passed as the ``display_name`` filter to the W&B public API.

    Returns:
        The first run returned by the query, or ``None`` when the query yields
        no runs. Callers test the result for falsiness to detect "not found".
    """
    api = wandb.Api()
    runs = api.runs(path="open_instruct", filters={"display_name": name})
    # Indexing an empty result with [0] raises IndexError; return None instead
    # so the caller's "run not found" handling is actually reachable.
    return next(iter(runs), None)


def get_all_tydiqa_metric(metrics_file):
    """Load TyDiQA results and keep only each entry's ``f1`` value.

    Args:
        metrics_file: ``pathlib.Path`` to a JSON object whose values are dicts
            containing an ``"f1"`` key (keys are presumably languages — confirm).

    Returns:
        Dict mapping ``"tydiqa_<key>"`` to the corresponding ``"f1"`` value.
    """
    with metrics_file.open() as handle:
        per_key = json.load(handle)

    result = {}
    for key, entry in per_key.items():
        result[f"tydiqa_{key}"] = entry['f1']
    return result

def get_all_bbh_metric(metrics_file):
    """Load BBH results, prefixing every top-level key with ``bbh_``.

    Args:
        metrics_file: ``pathlib.Path`` to a JSON object of metric name -> value.

    Returns:
        Dict mapping ``"bbh_<key>"`` to the original value, in file order.
    """
    with metrics_file.open() as handle:
        raw = json.load(handle)
    return dict((f"bbh_{key}", value) for key, value in raw.items())

def get_all_gsm_metric(metrics_file):
    """Load GSM results, prefixing every top-level key with ``gsm_``.

    Args:
        metrics_file: ``pathlib.Path`` to a JSON object of metric name -> value.

    Returns:
        Dict mapping ``"gsm_<key>"`` to the original value, in file order.
    """
    with metrics_file.open() as handle:
        raw = json.load(handle)

    prefixed = {}
    for key in raw:
        prefixed[f"gsm_{key}"] = raw[key]
    return prefixed


def upload_metrics(exp_name, ws_path):
    """Push locally computed eval metrics for one experiment to its W&B run.

    For each eval task (``tydiqa``, ``bbh``, ``gsm``) this reads
    ``<ws_path>/exp_results/<exp_name>/<task>/metrics.json``. Tasks whose file
    is missing are skipped. Each file is flattened with the matching
    ``get_all_<task>_metric`` helper. Every score not yet in the run summary is
    written. Scores already present must match the local value to within half
    a unit in the third decimal place.

    Args:
        exp_name: Experiment directory name; also used as the W&B display name.
        ws_path: ``pathlib.Path`` of the workspace root.

    Raises:
        ValueError: if a score already stored in the run summary differs from
            the local value by more than the tolerance.
    """
    run = get_run_by_name(exp_name)

    if not run:
        print(f"Run {exp_name} not found")
        return
    if run.state != 'finished':
        print(f"Run {run.name} not finished (state={run.state})")
        return

    print(f"{run.name} found")

    # Each eval task mapped to the helper that flattens its metrics.json.
    parsers = {
        'tydiqa': get_all_tydiqa_metric,
        'bbh': get_all_bbh_metric,
        'gsm': get_all_gsm_metric,
    }
    # "Same to 3 decimals" as an absolute tolerance. Comparing round(x, 3)
    # values instead would reject near-identical scores that straddle a
    # rounding boundary (e.g. 0.12349 vs 0.12351).
    tolerance = 5e-4

    run_summary = run.summary
    for eval_task, parse in parsers.items():
        print(f"checking {eval_task}")
        metrics_file = ws_path / 'exp_results' / exp_name / eval_task / 'metrics.json'
        print(metrics_file)
        if not metrics_file.exists():
            continue

        uploaded_any = False
        for score_name, score in parse(metrics_file).items():
            if score_name not in run_summary:
                print(f"uploading {score_name}")
                run_summary[score_name] = score
                uploaded_any = True
            elif abs(score - run_summary[score_name]) > tolerance:
                raise ValueError(f"Score {score_name} already exists with diff value {run_summary[score_name]}. Current value={score}")

        # Persist after each task so tasks already processed are kept even
        # if a later task raises; skip the API call when nothing changed.
        if uploaded_any:
            run.update()


def main():
    """Upload metrics for the experiments named on the command line.

    With no arguments, every directory under ``/mnt/workspace/exp_results``
    whose name contains ``human_mix`` is processed, in sorted name order.
    Otherwise each command-line argument is treated as an experiment name.
    Experiments are processed one at a time, with a 5 second pause between
    consecutive ones.
    """
    ws_path = Path('/mnt/workspace')
    if len(sys.argv) < 2:
        exp_results_path = ws_path / 'exp_results'
        # iterdir() yields entries in arbitrary order; sort for reproducibility.
        exp_names = sorted(
            entry.name
            for entry in exp_results_path.iterdir()
            if entry.is_dir() and 'human_mix' in entry.name
        )
    else:
        exp_names = sys.argv[1:]

    for index, exp_name in enumerate(exp_names):
        if index:
            # Pause between experiments (presumably to go easy on the W&B API);
            # there is no need to wait after the last one.
            time.sleep(5)
        print(f"uploading {exp_name}")
        upload_metrics(exp_name, ws_path)


if __name__ == "__main__":
    # Run the uploader only when executed as a script, not when imported.
    main()