from pathlib import Path
import pandas as pd
from collections import namedtuple
import json
import numpy as np
from numpy.testing import assert_equal

import pickle
from anpl.sandbox import import_module_from_string, timeout

Log = namedtuple("Log", ["task", "user", "robot", "code", "entry"])


def _subdirs(path):
    """Return the sub-directories of ``path`` sorted by name; stray files are skipped."""
    return sorted(p for p in path.iterdir() if p.is_dir())


def _load_program(task_dir, robot):
    """Return ``(codes, entry)`` for one task directory, or ``(None, None)`` if empty.

    Robot "A" directories hold at most one pickled ANPL object: its source is
    produced by ``to_python(for_user=False)`` and its entry point is read from
    its ``entry`` attribute. Robot "B" directories hold at most one ``.py``
    file whose entry point is ``"main"``.

    Raises:
        ValueError: if the directory holds more than one candidate file, or
            ``robot`` is neither "A" nor "B".
    """
    if robot == "A":
        anpl_files = sorted(task_dir.glob("*.pkl"))
        if len(anpl_files) > 1:
            raise ValueError(
                f"expected at most one .pkl file in {task_dir}, found {len(anpl_files)}"
            )
        if not anpl_files:
            return None, None
        # NOTE(review): pickle.load can execute arbitrary code; only use this
        # on trusted, locally produced logs.
        with open(anpl_files[0], "rb") as f:
            anpl = pickle.load(f)
        return anpl.to_python(for_user=False), anpl.entry

    if robot == "B":
        py_files = sorted(task_dir.glob("*.py"))
        if len(py_files) > 1:
            raise ValueError(
                f"expected at most one .py file in {task_dir}, found {len(py_files)}"
            )
        if not py_files:
            return None, None
        # Python source files are UTF-8 by default (PEP 3120).
        return py_files[0].read_text(encoding="utf-8"), "main"

    # Previously an unknown robot silently reused the previous iteration's code.
    raise ValueError(f"unknown robot {robot!r} in {task_dir}")


def read_progs(path):
    """Collect one ``Log`` per task directory under ``path``.

    Expected layout: ``path/<date>/<user>/<prefix><id>_<robot>/``. The first
    four characters of the directory stem's prefix are dropped before the
    integer conversion (presumably the literal ``"task"``; confirm against the
    data). Directories are visited in sorted order so results are reproducible.

    Args:
        path: ``pathlib.Path`` of the data root.

    Returns:
        list of ``Log(task, user, robot, code, entry)``. ``code`` and ``entry``
        are ``None`` when the task directory contains no program file.

    Raises:
        ValueError: on a malformed task directory (see ``_load_program``), or
            if the directory stem does not split into exactly two parts on "_".
    """
    logs = []
    for date_dir in _subdirs(path):
        for user_dir in _subdirs(date_dir):
            for task_dir in _subdirs(user_dir):
                task_id, robot = task_dir.stem.split("_")
                codes, entry = _load_program(task_dir, robot)
                logs.append(Log(int(task_id[4:]), user_dir.stem, robot, codes, entry))
    return logs

def eq_outs(m1, m2):
    """Report whether two outputs are equal under ``numpy.testing.assert_equal``.

    Any exception raised by the comparison (a mismatch, or values that cannot
    be compared) is treated as "not equal".
    """
    try:
        assert_equal(m1, m2)
    except Exception:
        return False
    else:
        return True

def check(codes, entry, excepted_ios):
    """Return True iff the ``entry`` function in ``codes`` reproduces every example.

    Args:
        codes: Python source text, loaded via ``import_module_from_string``.
        entry: name of the function in that module to call.
        excepted_ios: iterable of dicts with "input" and "output" keys (the
            name is a misspelling of "expected", kept for caller
            compatibility). Both values are converted with ``np.array``.

    Returns:
        False as soon as the module fails to load, the entry point is missing,
        or one example's result differs from its output (per ``eq_outs``).
        Returns True otherwise, including for an empty example list.
    """
    try:
        module = import_module_from_string(codes)
    except Exception as e:
        # A user program that fails to load (syntax error, top-level exception)
        # fails this check instead of aborting the whole evaluation run.
        print(e)
        return False

    entry_point = getattr(module, entry, None)
    if entry_point is None:
        print(f"entry point {entry!r} not found")
        return False

    # Wrap once instead of on every example. This assumes the wrapper returned
    # by timeout() can be called repeatedly, as a decorator normally is.
    f = timeout(timeout=1)(entry_point)

    for v in excepted_ios:
        inp, out = np.array(v["input"]), np.array(v["output"])
        try:
            real_out = f(inp)
        except Exception as e:
            # A crash or timeout counts as a wrong answer (None never equals out).
            print(e)
            real_out = None
        if not eq_outs(out, real_out):
            return False
    return True

def load_ios(directory="train_io"):
    """Load per-task I/O examples from a directory of JSON files.

    Each regular file in ``directory`` must have an integer stem (the task id)
    and JSON content. Sub-directories are ignored.

    Args:
        directory: folder to scan (str or Path). Defaults to ``"train_io"``,
            relative to the current working directory.

    Returns:
        dict mapping int task id to the parsed JSON document.

    Raises:
        ValueError: if a file stem is not an integer.
    """
    ios_dict = {}
    for file in Path(directory).iterdir():
        if not file.is_file():
            continue
        task_id = int(file.stem)
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale.
        with open(file, "r", encoding="utf-8") as f:
            ios_dict[task_id] = json.load(f)
    return ios_dict

if __name__ == "__main__":
    # Ground-truth examples keyed by task id, plus every logged program.
    ios_dict = load_ios()
    progs = read_progs(Path("../anpl_data"))

    # Hand-curated task ids per robot, used to split counts into "one" columns.
    A_one_shot = {5, 6, 11, 20, 21, 26, 29, 34, 37, 44, 48, 51, 61, 69, 74, 77, 80, 82, 83, 86, 87, 89, 90, 94, 102, 107, 108, 110, 115, 119, 120, 125, 128, 134, 139, 141, 149, 151, 154, 163, 170, 171, 178, 188, 192, 193, 206, 209, 216, 222, 223, 226, 228, 240, 241, 257, 265, 266, 273, 275, 279, 281, 286, 288, 290, 302, 304, 309, 310, 311, 316, 321, 325, 326, 328, 333, 335, 336, 337, 341, 342, 343, 345, 346, 350, 358, 370, 374, 379, 383, 387, 390, 394, 398}
    B_one_shot = {128, 257, 135, 9, 139, 141, 398, 271, 273, 275, 149, 23, 151, 26, 154, 28, 29, 31, 34, 290, 291, 165, 166, 41, 170, 171, 44, 49, 178, 51, 180, 52, 310, 311, 184, 313, 314, 315, 316, 187, 185, 193, 321, 69, 77, 336, 209, 82, 80, 86, 87, 93, 222, 226, 99, 230, 108, 110, 240, 241, 370, 115, 371, 374, 119, 379, 380}

    print(len(A_one_shot))
    print(len(B_one_shot))

    def is_one_shot(log):
        """Check log.task against robot A's set; any other robot uses B's set."""
        one_shot_ids = A_one_shot if log.robot == "A" else B_one_shot
        return log.task in one_shot_ids

    # Column order here fixes the CSV column order.
    counter_columns = [
        "A_train", "B_train", "A_one_train", "B_one_train",
        "A_test", "B_test", "A_one_test", "B_one_test",
    ]
    data = {}
    for task in range(400):
        data[task] = {"task": task, "user": None, **dict.fromkeys(counter_columns, 0)}

    for prog in progs:
        row = data[prog.task]
        row["user"] = prog.user
        if not prog.code:
            continue

        robot = prog.robot
        one_shot_flag = is_one_shot(prog)
        # "test" counts every submitted program...
        row[f"{robot}_test"] += 1
        if one_shot_flag:
            row[f"{robot}_one_test"] += 1

        # ...while "train" counts only those that pass all training examples.
        if not check(prog.code, prog.entry, ios_dict[prog.task]):
            continue
        row[f"{robot}_train"] += 1
        if one_shot_flag:
            row[f"{robot}_one_train"] += 1

    df = pd.DataFrame.from_records(list(data.values()))
    df.to_csv("train_io.csv", index=False)
