import argparse
import os

import numpy as np
import torch
from gym_montezuma.envs.montezuma_env import MontezumasRevengeEnv

from core.build import build, _get_modified_factors
from core.data import Factor
from core.explore import collect_data


def main(save_folder, name, n_sample, clip_eps, seed, t_delta,
         r_delta, min_t, min_r, n_cluster, min_samples):
    """Collect data from Montezuma's Revenge and build an abstract MDP from it.

    Seeds NumPy and PyTorch, collects samples from a single-screen
    ``MontezumasRevengeEnv`` in ``"ram"`` observation mode, runs
    ``core.build.build`` on them and saves the resulting MDP as ``mdp.pkl``
    inside the output directory.

    Args:
        save_folder: Folder name made of ``"__"``-separated components. Its
            first component is replaced by ``name`` and the result is placed
            under the ``save`` directory.
        name: Replacement for the first component of ``save_folder``.
        n_sample: Passed to ``collect_data`` as both ``max_episode`` and
            ``max_timestep``.
        clip_eps: Forwarded to the environment constructor as ``clip_eps``.
        seed: Seed for NumPy, PyTorch and ``collect_data``.
        t_delta: Forwarded to ``build`` as ``transition_error_delta``.
        r_delta: Forwarded to ``build`` as ``reward_error_delta``.
        min_t: Forwarded to ``build`` as ``min_transition_error``.
        min_r: Forwarded to ``build`` as ``min_reward_error``.
        n_cluster: Forwarded to ``build`` as ``n_cluster_trials``.
        min_samples: Forwarded to ``build`` as ``min_samples``.

    Returns:
        None. The MDP is written to ``save/<folder>/mdp.pkl``.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Swap the leading "__"-separated component of the folder name for `name`,
    # keeping the remaining components untouched.
    save_folder = "__".join([name] + save_folder.split("__")[1:])
    save_folder = os.path.join("save", save_folder)
    # exist_ok avoids the check-then-create race when several runs start
    # concurrently with the same output directory.
    os.makedirs(save_folder, exist_ok=True)

    env = MontezumasRevengeEnv(single_screen=True, eps=0.0, clip_eps=clip_eps, observation_mode="ram")
    data = collect_data(env, max_episode=n_sample, max_timestep=n_sample,  # type: ignore
                        max_timestep_per_ep=200, verbose=True, seed=seed, n_jobs=4)

    # Named groups of observation indices. NOTE(review): with
    # observation_mode="ram" these are presumably Atari RAM byte offsets --
    # confirm against the environment before editing.
    factors = [
        Factor("room", [3]),
        Factor("score", [19, 20, 21]),
        Factor("player_status", [30]),
        Factor("player_x", [42]),
        Factor("player_y", [43]),
        Factor("key_monster_x", [44]),
        Factor("key_monster_y", [45]),
        Factor("enemy_skull_y", [46]),
        Factor("enemy_skull_x", [47]),
        Factor("player_dir", [52]),
        Factor("level", [57]),
        Factor("lives", [58]),
        Factor("inventory", [61]),
        Factor("room_state", [62])
    ]
    # NOTE(review): the meaning of the literal 27 is defined by
    # _get_modified_factors and is not visible here -- TODO confirm.
    changed_factors = _get_modified_factors(27, factors, data.data)
    for o in changed_factors:
        print(f"{o}: {[f.name for f in changed_factors[o]]}")

    print(f"n_samples={len(data.data)}, n_factors={len(factors)}")

    # build returns a 3-tuple; only the first element (the MDP) is used here.
    mdp, _, _ = build(data,
                      env.action_names,
                      factors=factors,
                      transition_error_delta=t_delta,
                      reward_error_delta=r_delta,
                      min_transition_error=min_t,
                      min_reward_error=min_r,
                      n_cluster_trials=n_cluster,
                      changed_factors=changed_factors,
                      min_samples=min_samples)
    mdp.save(os.path.join(save_folder, "mdp.pkl"))
    print("Refinement finished.")


# Command-line entry point: parse the options and forward them to main().
if __name__ == "__main__":
    # The text must be passed as `description=`; the first positional
    # parameter of ArgumentParser is `prog`, which would put the sentence
    # into the usage line in place of the program name.
    parser = argparse.ArgumentParser(description="Build abstract MDP on Montezuma's Revenge.")
    parser.add_argument("--n-sample", help="Number of samples", type=int, required=True)
    parser.add_argument("--clip-eps", help="The stochasticity rate of the environment", type=float, required=True)
    parser.add_argument("--save-folder", type=str, required=True)
    parser.add_argument("--name", type=str, required=True)
    parser.add_argument("--seed", type=int, required=True)
    # The four error thresholds below have no default, so they reach main()
    # as None when omitted on the command line.
    parser.add_argument("--t-delta", help="Minimum transition error delta for refinement", type=float)
    parser.add_argument("--r-delta", help="Minimum reward error delta for refinement", type=float)
    parser.add_argument("--min-t", help="Minimum transition error for a state", type=float)
    parser.add_argument("--min-r", help="Minimum reward error for a state", type=float)
    parser.add_argument("--n-cluster", help="Number of cluster trials", type=int, default=1)
    parser.add_argument("--min-samples", help="Minimum samples for a state", type=int, default=10)
    args = parser.parse_args()

    main(save_folder=args.save_folder,
         name=args.name,
         n_sample=args.n_sample,
         clip_eps=args.clip_eps,
         seed=args.seed,
         t_delta=args.t_delta,
         r_delta=args.r_delta,
         min_t=args.min_t,
         min_r=args.min_r,
         n_cluster=args.n_cluster,
         min_samples=args.min_samples)
