from evaluation_suite import *

import importlib
import os
from results import results

def get_suites():
    """Discover the evaluation suites defined next to this file.

    Imports every ``eval-*.py`` module found in this file's directory and
    collects each class in it that subclasses ``EvaluationSuite`` (the base
    class itself is excluded).

    The modules are imported by bare name, so the directory has to be
    importable (it is on ``sys.path`` when this file is run as a script).

    Returns:
        dict: maps each suite class's ``name`` attribute to the class.
        Suite classes without a ``name`` attribute are reported on stdout
        and left out.
    """
    suites = {}
    # abspath() first: if __file__ is a bare relative filename, dirname()
    # alone would return "" and os.listdir("") raises FileNotFoundError.
    suite_dir = os.path.dirname(os.path.abspath(__file__))
    # Sorted so that, should two suites share a name, the winner does not
    # depend on the arbitrary order os.listdir() happens to return.
    for filename in sorted(os.listdir(suite_dir)):
        if not (filename.startswith("eval-") and filename.endswith(".py")):
            continue
        # Drop only the trailing extension; str.replace() would also remove
        # any ".py" occurring inside the stem.
        modname = filename[:-len(".py")]
        mod = importlib.import_module(modname)
        for attr_name in dir(mod):
            candidate = getattr(mod, attr_name)
            if not isinstance(candidate, type):
                continue
            if candidate is EvaluationSuite or not issubclass(candidate, EvaluationSuite):
                continue
            if hasattr(candidate, "name"):
                suites[candidate.name] = candidate
            else:
                print("Cannot find name for suite " + attr_name)
    return suites

def get_suite(name):
    """Look up a single evaluation suite class by its registered name.

    Args:
        name: key to look up in the mapping returned by ``get_suites()``.

    Returns:
        The matching suite class, or ``None`` (after printing a
        "not found" message) when no suite is registered under ``name``.
    """
    available = get_suites()
    try:
        return available[name]
    except KeyError:
        print("Suite " + str(name) + " not found")
        return None

@dataclass
class EvaluationTask:
    """One evaluation run as requested on the command line.

    Built by ``parse_args()``; the script entry point then resolves
    ``suite`` from a name to a suite class via ``get_suite()``.
    """

    # Suite name (str) as produced by parse_args(); the script entry point
    # overwrites it with the class returned by get_suite(), which is None
    # for an unknown name. Kept as a string annotation so the class can be
    # defined without EvaluationSuite having to resolve at import time.
    suite: "str | type[EvaluationSuite] | None"
    # Model identifier, taken verbatim from the second positional argument.
    model: str
    # Decoder name (defaults to "argmax" in parse_args()).
    decoder: str
    # Value of the `shots=` option (defaults to 0 in parse_args()).
    shots: int
    # Value of the `size=` option (defaults to "mini" in parse_args()).
    size: str
    # Value of the `num_workers=` option (defaults to 3 in parse_args()).
    num_workers: int
    # Every other key=value option, plus the "diff" / "queries" flags.
    kwargs: dict

def parse_args():
    """Parse ``sys.argv`` into an ``EvaluationTask``.

    Command line::

        evaluator.py suite_name model [decoder=argmax] [shots=0] [size=mini] [num_workers=3]
        e.g. evaluator.py date_understanding@cot gpt2 decoder=beam_var shots=0

    Besides the four named options, the following are accepted after the
    two positional arguments:

    * any other ``key=value`` pair, stored as a string in ``kwargs[key]``;
    * ``diff <file>``, stored as ``kwargs["diff"]``;
    * the bare flag ``queries``, stored as ``kwargs["queries"] = True``.

    With fewer than two positional arguments, ``list`` prints the available
    suite names and exits with status 0; anything else prints the usage
    line and exits with status 1.

    Returns:
        EvaluationTask: the task, with ``suite`` still holding the suite
        *name* given on the command line.

    Raises:
        SystemExit: on the list/usage paths above, on an unknown decoder,
            on ``diff`` without a file, and on an unrecognised argument.
        ValueError: if ``shots`` or ``num_workers`` is not an integer.
    """
    # Local import: at file level `sys` is only in scope through a
    # star-import, which this function should not have to rely on.
    import sys

    usage = "Usage: evaluator.py suite_name model [decoder=argmax] [shots=0] [size=mini] [num_workers=3]"
    valid_decoders = ("beam_var", "var", "best_k", "argmax", "beam_search", "bsseq")

    args = sys.argv
    if len(args) < 3:
        if "list" in args:
            print("Available suites:")
            for suite in get_suites():
                print(suite)
            sys.exit(0)

        print(usage)
        sys.exit(1)
    print(args)

    suite_name = args[1]
    model = args[2]
    decoder = "argmax"
    shots = 0
    size = "mini"
    num_workers = 3
    kwargs = {}

    # An explicit iterator lets `diff` consume the argument that follows it.
    remaining = iter(args[3:])
    for arg in remaining:
        if "=" in arg:
            # Split on the first "=" only, so values may contain "=" too.
            key, _, value = arg.partition("=")
            if key == "decoder":
                # Explicit check rather than `assert`, which is stripped
                # when Python runs with -O.
                if value not in valid_decoders:
                    print("Invalid decoder " + value + ", expected one of: " + ", ".join(valid_decoders))
                    sys.exit(1)
                decoder = value
            elif key == "shots":
                shots = int(value)
            elif key == "size":
                size = value
            elif key == "num_workers":
                num_workers = int(value)
            else:
                kwargs[key] = value
        elif arg == "diff":
            # The next argument is taken verbatim as the diff file.
            diff_file = next(remaining, None)
            if diff_file is None:
                print("Missing file name after diff")
                sys.exit(1)
            kwargs["diff"] = diff_file
        elif arg == "queries":
            kwargs["queries"] = True
        else:
            print("Invalid argument " + arg)
            sys.exit(1)

    return EvaluationTask(suite_name, model, decoder, shots, size, num_workers, kwargs)

if __name__ == "__main__":
    # Script entry point. Two modes:
    #   evaluator.py rm <hash> [<hash> ...]   move matching result files to the trash
    #   evaluator.py suite_name model [...]   run an evaluation suite (see parse_args)

    # Imported here so the entry point does not depend on names that only
    # reach this file through the star-import at the top.
    import subprocess
    import sys

    # The length check keeps a bare `evaluator.py` from raising IndexError;
    # it then falls through to parse_args(), which prints the usage line.
    if len(sys.argv) > 1 and sys.argv[1] == "rm":
        hashes_to_delete = set(sys.argv[2:])
        print(hashes_to_delete)
        # <parent of this file's directory>/evaluation
        evaluation_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "evaluation")
        os.chdir(evaluation_dir)

        def to_delete(h):
            return h in hashes_to_delete

        # NOTE(review): results() presumably keeps the rows whose hash
        # satisfies the predicate -- confirm against results.results.
        result_df_to_trash = results(hash=to_delete)
        hashes_in_df = list(result_df_to_trash["hash"])
        # Explicit check instead of `assert`: it must still guard the
        # deletion under -O, and the table is printed rather than being
        # formatted as "None" inside the message.
        if len(hashes_in_df) != len(set(hashes_in_df)):
            print("Duplicate hashes for results:")
            print(result_df_to_trash)
            sys.exit(1)

        def trash(path):
            # Argument list instead of a shell string, so paths containing
            # spaces or shell metacharacters are passed through untouched.
            print(">", "trash " + path)
            try:
                subprocess.run(["trash", path], check=False)
            except FileNotFoundError:
                # Stay best-effort, as the shell-based call was.
                print("trash command not found, skipping " + path)

        for row_index in range(len(hashes_in_df)):
            s = result_df_to_trash.iloc[row_index]
            # Each result is stored as a .json/.csv pair sharing one stem.
            basepath = os.path.splitext(s.result.path)[0]
            trash(basepath + ".json")
            trash(basepath + ".csv")

        sys.exit(0)

    task = parse_args()
    # Keep the requested name: task.suite is about to be replaced by the
    # suite class, or by None when the name is unknown.
    suite_name = task.suite
    task.suite = get_suite(suite_name)
    if task.suite is None:
        print("Suite {} not found".format(str(suite_name)))
        sys.exit(1)

    config = {
        "model": task.model,
        "suite": task.suite,
        "shots": task.shots,
        "decoder": task.decoder,
        # Defaults first, so options given on the command line override them.
        "kwargs": {**{
            "max_length": 512,
            # only computes top1 condition distribution (save some cost and time)
            "top1_distribution": True,
        }, **task.kwargs},
        "size": task.size,
        "num_workers": task.num_workers
    }

    task.suite(size=task.size).main(**config)