import sacrebleu
import os
import json
import numpy as np
import glob
from collections.abc import Iterable

def is_non_str_iterable(obj):
    """Return True when *obj* is iterable but is not a ``str``.

    Strings are iterable, but here they are treated as atomic values, so
    they are rejected explicitly before the generic iterable check.
    """
    if isinstance(obj, str):
        return False
    return isinstance(obj, Iterable)

def _sacreformat(refs, preds):
    """Reshape references and predictions into the layout sacrebleu expects.

    Args:
        refs: A single reference string, a list of reference strings (one
            per prediction), or a list of lists of reference strings (several
            references per prediction). The outer list corresponds to preds.
        preds: A single prediction string, a list of prediction strings, or
            a list of single-element lists of prediction strings.

    Returns:
        A ``(refs, preds)`` tuple where ``refs`` is a list of tuples (one
        tuple per reference stream, each with one entry per prediction) and
        ``preds`` is a flat sequence of strings.

    Raises:
        TypeError: If ``refs`` or ``preds`` is neither a string nor iterable.
        AssertionError: If the first prediction is a collection that does not
            hold exactly one candidate.
    """
    # Sacrebleu expects (List[str], List[List[str]])
    #   e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])

    # Note [ref1_stream] is the first reference for each pred.
    # So lists are size N and (M, N) for N preds and M possible refs for each pred.
    # This is a different order of dimensions than one might expect.

    # We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds.
    # Must become List[List[str]] with the inner list corresponding to preds.
    if isinstance(refs, str):
        # A bare string is one reference for one prediction. Calling list() on
        # it would explode it into single characters, so wrap it instead.
        refs = [refs]
    elif not is_non_str_iterable(refs):
        raise TypeError(
            f"refs must be a str or an iterable of references, got {type(refs).__name__}"
        )
    if not is_non_str_iterable(refs[0]):
        refs = [[ref] for ref in refs]
    # Transpose from (preds, refs-per-pred) to (refs-per-pred, preds).
    refs = list(zip(*refs))
    # Note the number of refs in each ref list must match the number of preds.

    # We expect preds to be List[str] or List[List[str]]. Must become List[str].
    if isinstance(preds, str):
        # Same reasoning as for refs: treat a bare string as a single prediction.
        preds = [preds]
    elif not is_non_str_iterable(preds):
        raise TypeError(
            f"preds must be a str or an iterable of predictions, got {type(preds).__name__}"
        )
    if is_non_str_iterable(preds[0]):
        # Raised explicitly (rather than via ``assert``) so the check still
        # runs under ``python -O``; the exception type is unchanged.
        if len(preds[0]) != 1:
            raise AssertionError(f"Pred must be a str, was {preds[0]}")
        preds = [pred[0] for pred in preds]

    return refs, preds

# Benchmarks to process. Each one is read from its own directory under the
# evaluation root and produces one JSON file in ``structured_bootstrap/``.
datasets = ["boolq", "winogrande", "arc_easy", "math_genie"]


def _pick_first(candidates, what):
    """Return the lexicographically first candidate, failing loudly if none exist.

    Sorting makes the choice deterministic; ``os.listdir`` and ``glob.glob``
    return entries in arbitrary, filesystem-dependent order.
    """
    if not candidates:
        raise FileNotFoundError(f"no {what} found")
    return sorted(candidates)[0]


for dataset in datasets:
    print(dataset)
    base = "eval_confirm/mistral/structured/" + dataset
    # math_genie is scored with BLEU; every other benchmark with accuracy.
    use_bleu = "math_genie" in dataset
    # Maps run number -> (reported scores, bootstrap score lists), with one
    # entry per model variant in the order of the list iterated below.
    bootstrap_info = {}
    for run in range(1, 3):
        true_scores = []
        bootstrap_scores = []
        for t in ["Base", "Orca", "Platypus", "Genie", "Alpaca", "Khan", "Stax", "Wiki"]:
            # "Base" and "Orca" live at run-independent paths; every other
            # variant has one directory per run.
            if t == "Base":
                patch_path = os.path.join(base, "merged")
            elif t == "Orca":
                patch_path = os.path.join(base, "Orca", "merged")
            else:
                patch_path = os.path.join(base, str(run), str(t), "merged")

            # Descend into the (single expected) sub-directory, ignoring the
            # checkpoint folder Jupyter may have left behind.
            entries = [
                entry for entry in os.listdir(patch_path)
                if entry != ".ipynb_checkpoints"
            ]
            patch_path = os.path.join(
                patch_path, _pick_first(entries, f"result directory in {patch_path}")
            )

            # Aggregate score as reported in the results file.
            score_patch_path = _pick_first(
                glob.glob(os.path.join(patch_path, "results*.json")),
                f"results*.json in {patch_path}",
            )
            with open(score_patch_path, "r") as f:
                patch_json = json.load(f)
            if use_bleu:
                true_scores.append(patch_json["results"]["math_genie"]["bleu,none"])
            else:
                true_scores.append(patch_json["results"][dataset]["acc,none"])

            # Per-example scores, one JSON object per line of the samples file.
            bootstrap_patch_path = _pick_first(
                glob.glob(os.path.join(patch_path, "samples*.jsonl")),
                f"samples*.jsonl in {patch_path}",
            )
            scores = []
            with open(bootstrap_patch_path, "r") as file:
                for line in file:
                    example = json.loads(line)
                    if use_bleu:
                        # example["bleu"] is read as a (reference, prediction)
                        # pair; BLEU is computed on that single pair.
                        ref = [example["bleu"][0]]
                        pred = [example["bleu"][1]]
                        refs, preds = _sacreformat(ref, pred)
                        scores.append(sacrebleu.corpus_bleu(preds, refs).score)
                    else:
                        scores.append(example["acc"])

            # 20 bootstrap resamples (with replacement, same size as the data),
            # keeping the mean of each; float() yields plain Python floats.
            bootstrap_list = [
                float(np.mean(np.random.choice(scores, len(scores), replace=True)))
                for _ in range(20)
            ]
            bootstrap_scores.append(bootstrap_list)

        bootstrap_info[run] = (true_scores, bootstrap_scores)

    # Name the output after the dataset explicitly instead of relying on a
    # variable leaked from the inner loop, and make sure the directory exists.
    os.makedirs("structured_bootstrap", exist_ok=True)
    with open(f"structured_bootstrap/{dataset}.json", "w") as handle:
        json.dump(bootstrap_info, handle)