import json
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

# Default directory scanned for *.json result files; used when
# get_controls_goals_targets() is called without a folder and by main().
INPUT_FOLDER = "/work/<name>/RESULTS_LLM_ATTACKS/"
# A goal is only kept when its number of recorded test entries ("steps")
# is strictly greater than this value.
STEPS_CUTOFF = 45

def read_json_files(folder):
    """
    Read all JSON files in the specified folder and return a list of dictionaries.

    Files are visited in sorted path order so the result is reproducible
    across runs and file systems (``Path.glob`` itself yields entries in
    arbitrary order). Each loaded dictionary is tagged with the name of the
    file it came from under the ``"fname"`` key.

    Args:
        folder: Directory (string or path-like) searched non-recursively
            for ``*.json`` files.

    Returns:
        A list with one dictionary per JSON file; empty if nothing matches.
    """
    json_dicts = []
    for file in sorted(Path(folder).glob("*.json")):
        # JSON text is UTF-8; do not depend on the platform's locale default.
        with file.open("r", encoding="utf-8") as f:
            json_dict = json.load(f)
        # The handle is already closed here; only the parsed data is needed.
        json_dict["fname"] = file.name
        json_dicts.append(json_dict)
    return json_dicts


def process_dict(test_dict):
    """
    Pull the (goal, loss, jb) triple out of a single test record.

    The goal is taken to be the first key of ``test_dict`` (in insertion
    order); the loss and jb values are the first elements of the
    ``"n_loss"`` and ``"n_passed"`` entries respectively.
    """
    first_key = list(test_dict)[0]
    first_loss = test_dict["n_loss"][0]
    first_passed = test_dict["n_passed"][0]
    return first_key, first_loss, first_passed


def get_controls_goals_targets(folder=None):
    """
    Process JSON files to extract controls, goals, and targets with more
    optimisation steps than STEPS_CUTOFF.

    Each JSON file must provide parallel ``"controls"``, ``"losses"`` and
    ``"tests"`` lists, and a ``"params"`` dict with ``"goals"`` and
    ``"targets"`` lists. Within one file, the test entries are grouped by
    goal; the number of entries for a goal is its step count. For every goal
    the stored control is the lowest-loss one among the entries flagged as
    jailbroken; if no entry is jailbroken, the first entry is retained and
    the goal is not kept.

    A summary line with the aggregate counts is printed before returning.

    Args:
        folder: Directory containing the JSON files. Defaults to
            INPUT_FOLDER when None.

    Returns:
        A list of ``(control, goal, target)`` tuples, one for each goal that
        both exceeds STEPS_CUTOFF steps and has a jailbroken control.

    Raises:
        ValueError: If a file's controls, losses and tests lists differ in
            length, or a kept goal is missing from ``params["goals"]``.
    """
    if folder is None:
        folder = INPUT_FOLDER

    controls_goals_targets = []
    n_goals_total, n_suf_steps_total, n_jb_total = 0, 0, 0

    for json_dict in read_json_files(folder):
        controls, losses, tests = json_dict["controls"], json_dict["losses"], json_dict["tests"]
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not (len(controls) == len(losses) == len(tests)):
            raise ValueError(
                f"{json_dict.get('fname', '<unknown>')}: controls, losses and tests "
                f"have different lengths ({len(controls)}, {len(losses)}, {len(tests)})"
            )

        goal_dicts = {}
        for control, test in zip(controls, tests):
            goal, loss, jb = process_dict(test)
            best = goal_dicts.get(goal)
            if best is None:
                goal_dicts[goal] = {"loss": loss, "control": control, "steps": 1, "jb": jb}
                continue

            best["steps"] += 1
            # A jailbroken entry always supersedes a stored non-jailbroken one;
            # between two jailbroken entries the lower loss wins. Without the
            # `not best["jb"]` clause, a non-jailbroken first entry with a low
            # loss would hide every later successful control.
            if jb and (not best["jb"] or loss < best["loss"]):
                best.update({"loss": loss, "control": control, "jb": jb})

        for goal, data in goal_dicts.items():
            n_goals_total += 1
            enough_steps = data["steps"] > STEPS_CUTOFF

            if enough_steps:
                n_suf_steps_total += 1

            if data["jb"]:
                n_jb_total += 1

            if enough_steps and data["jb"]:
                # Only kept goals need their target, so the linear lookup is
                # skipped for discarded goals.
                params = json_dict["params"]
                target = params["targets"][params["goals"].index(goal)]
                controls_goals_targets.append((data["control"], goal, target))

    print(f"total: {n_goals_total}, suf_steps: {n_suf_steps_total}, jb: {n_jb_total}, kept: {len(controls_goals_targets)}")
    return controls_goals_targets


def main():
    """
    Script entry point: process the JSON files in INPUT_FOLDER.

    The returned triples are not used here; the call is made for the summary
    line that get_controls_goals_targets prints.
    """
    get_controls_goals_targets(INPUT_FOLDER)


# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()


