import os
import pickle
import sys

import numpy as np
import torch

from core.msa import MSAFlat
from envs.vault import Vault


def eps_pos():
    """Sample a 3-element Gaussian perturbation whose middle entry is zero.

    Three standard-normal draws are scaled by 0.1, then index 1 is
    overwritten with 0 so only the first and last components carry noise.
    """
    noise = 0.1 * np.random.randn(3)
    noise[1] = 0
    return noise


def eps_dir():
    """Sample one scalar Gaussian perturbation (standard normal scaled by 0.05)."""
    draw = np.random.randn()
    return draw * 0.05


def to_tensor(x):
    """Copy ``x`` into a float32 torch tensor and flatten it to one dimension."""
    as_float = torch.tensor(x, dtype=torch.float32)
    return torch.reshape(as_float, (-1,))


def _grounding_probs(msa, mdp, obs, init):
    """Encode a copy of ``obs`` with ``msa`` and return the MDP grounding probabilities.

    Returns whatever ``mdp.get_grounding_prob`` returns, which may be ``None``.
    """
    flat_obs = to_tensor(obs.copy())
    with torch.no_grad():
        z = msa.encode(flat_obs)
    return mdp.get_grounding_prob(z, init)


def test(save_path, msa_folder, resolution, add_portraits, n: int = 100) -> float:
    """Run ``n`` goal-conditioned Vault episodes and return the success rate.

    At every step the current observation and the desired goal are each
    encoded with the MSA model and grounded to their most probable MDP state.
    The graph is then traversed from the current state, and only the first
    action of the path to the goal state is executed before re-planning.

    Args:
        save_path: Directory containing the pickled MDP as ``mdp.pkl``.
        msa_folder: Passed to ``MSAFlat``; the model is moved to CPU.
        resolution: Indexable pair; element 0 is used as the observation
            width and element 1 as the observation height.
        add_portraits: Forwarded to the ``Vault`` environment.
        n: Number of episodes to count. Resets whose reward is already above
            0.5 are discarded and do not count towards ``n``.

    Returns:
        ``success / n``, where an episode is a success when the reward
        exceeds 0.5.

    Raises:
        ZeroDivisionError: If ``n`` is 0 (no episode runs and the final
            division fails).
    """
    msa = MSAFlat(msa_folder).cpu()
    mdp_path = os.path.join(save_path, "mdp.pkl")
    # SECURITY: pickle can execute arbitrary code on load; only point this at
    # files produced by a trusted run.
    with open(mdp_path, "rb") as mdp_file:
        mdp = pickle.load(mdp_file)

    env = Vault(render_mode="rgb_array",
                max_option_steps=20,
                goal_conditioned=True,
                obs_width=resolution[0],
                obs_height=resolution[1],
                add_portraits=add_portraits)

    # NOTE: the mask passed to grounding is hard-coded to all ones for both
    # the current state and the goal; the environment-provided alternatives
    # (env.available_mask and env.info["goal_init"]) are intentionally unused.
    all_ones_init = (1, 1, 1, 1, 1)

    success = 0
    i = 0
    while i < n:
        env.reset()
        done = False
        terminated = False
        # Discard resets that already satisfy the goal; `i` is not advanced.
        if env.reward(env.info) > 0.5:
            continue

        while not done:
            si_probs = _grounding_probs(
                msa, mdp, env.observation["observation"], all_ones_init)
            if si_probs is None:
                break
            s_i = mdp.states[si_probs.argmax()]

            sf_probs = _grounding_probs(
                msa, mdp, env.observation["desired_goal"], all_ones_init)
            if sf_probs is None:
                break
            s_f = mdp.states[sf_probs.argmax()]

            paths = mdp.traverse_graph(s_i)
            if s_f not in paths:
                # No plan to the goal state; the episode still counts as a
                # success if the environment reports the goal as reached.
                if env.reward(env.info) > 0.5:
                    terminated = True
                break

            _, path = paths[s_f]
            if len(path) == 0:
                # Empty plan: nothing to execute, give up on this episode.
                break

            # Take only the first step of the plan, then re-plan.
            _, first_action = path[0]
            _, _, _, truncated, info = env.step(first_action)
            terminated = env.reward(info) > 0.5
            done = terminated or truncated

        if terminated:
            success += 1
        i += 1

    return success / n


if __name__ == "__main__":
    # Command line: SAVE_PATH [MSA_FOLDER]
    # SAVE_PATH is required; exit with a usage message rather than an
    # IndexError traceback when it is missing.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} SAVE_PATH [MSA_FOLDER]")
    save_path = sys.argv[1]
    # MSA_FOLDER is optional and defaults to None.
    msa_folder = sys.argv[2] if len(sys.argv) > 2 else None
    acc = test(save_path, msa_folder, resolution=(80, 60), add_portraits=False)
    print(f"Accuracy={acc:.2f}")
