import os
import pickle
import sys

import torch

from core.msa import MSAFlat
from envs.mnist_grid import MNISTHyperGrid


def to_tensor(x):
    """Build a float32 torch tensor from ``x`` and flatten it to one dimension."""
    as_float = torch.tensor(x, dtype=torch.float32)
    return as_float.flatten()


def _encode(msa, obs):
    """Return ``msa.encode(obs)`` computed without gradients, or ``obs`` if ``msa`` is None."""
    if msa is None:
        return obs
    with torch.no_grad():
        return msa.encode(obs)


def _ground(mdp, msa, obs, mask):
    """Return the entry of ``mdp.states`` with the highest grounding probability.

    ``obs`` is optionally encoded with ``msa`` first. Returns ``None`` when
    ``mdp.get_grounding_prob`` returns ``None`` (presumably: the observation
    could not be grounded).
    """
    probs = mdp.get_grounding_prob(_encode(msa, obs), mask)
    if probs is None:
        return None
    return mdp.states[probs.argmax()]


def _available_mask_at_goal(env):
    """Return ``env.available_mask`` evaluated with the env state set to the goal.

    Temporarily overwrites the private ``env._state`` with a copy of
    ``env._goal``; the saved state is restored even if the property raises.
    """
    saved_state = env._state.copy()
    env._state = env._goal.copy()
    try:
        return env.available_mask
    finally:
        env._state = saved_state


def _run_episode(env, mdp, msa):
    """Step ``env`` along abstract plans until the episode ends or planning fails.

    Each iteration grounds the current observation and the desired goal into
    abstract states, plans from the current state with ``mdp.traverse_graph``,
    and executes only the first action of the path to the goal state.
    Returns early (without acting) when either grounding fails, the goal state
    is not among the planned paths, or the path is empty.
    """
    done = False
    while not done:
        s_i = _ground(
            mdp,
            msa,
            to_tensor(env.observation["observation"].copy()),
            env.available_mask,
        )
        if s_i is None:
            # Nothing about the env changes between retries, so retrying would
            # loop forever; give up on this episode instead.
            break

        s_f = _ground(
            mdp,
            msa,
            to_tensor(env.observation["desired_goal"].copy()),
            _available_mask_at_goal(env),
        )
        if s_f is None:
            break

        paths = mdp.traverse_graph(s_i)
        if s_f not in paths:
            break
        _, path = paths[s_f]
        if len(path) == 0:
            break
        # Execute only the first step of the plan, then re-plan next iteration.
        _, first_action = path[0]
        env.step(first_action)
        done = env.truncated or env.terminated


def test(save_path, msa_folder, size, n=100):
    """Return the fraction of ``n`` goal-conditioned episodes that end terminated.

    Args:
        save_path: Directory containing ``mdp.pkl``, the pickled abstract MDP.
        msa_folder: Optional path passed to ``MSAFlat``; when not None,
            observations are encoded with that model (moved to CPU) before
            grounding. ``None`` grounds the raw flattened observations.
        size: ``dimensions`` argument forwarded to ``MNISTHyperGrid``.
        n: Number of counted episodes; must be positive.

    Returns:
        ``successes / n``, where an episode is a success if ``env.terminated``
        is true once it ends.

    Raises:
        ValueError: If ``n`` is not positive.
    """
    if n <= 0:
        raise ValueError(f"n must be positive, got {n}")

    env = MNISTHyperGrid(dimensions=size, eps=0.0, goal_conditioned=True)
    # NOTE: pickle can execute arbitrary code on load; only use trusted files.
    with open(os.path.join(save_path, "mdp.pkl"), "rb") as fh:
        mdp = pickle.load(fh)
    msa = MSAFlat(msa_folder).cpu() if msa_folder is not None else None

    successes = 0
    episodes = 0
    while episodes < n:
        env.reset()
        if env.terminated:
            # Already terminated right after reset: resample without counting.
            continue
        _run_episode(env, mdp, msa)
        if env.terminated:
            successes += 1
        episodes += 1

    return successes / n


if __name__ == "__main__":
    # Usage: python <script> SAVE_PATH [MSA_FOLDER]
    # Evaluates on a 6x6 grid and prints the success rate.
    save_path = sys.argv[1]
    msa_folder = sys.argv[2] if len(sys.argv) > 2 else None
    accuracy = test(save_path, msa_folder, (6, 6))
    print(f"Accuracy={accuracy:.2f}")
