import argparse
import os

import torch
import numpy as np

from core.build import build
from core.data import Factor, TransitionData
from core.explore import collect_data
from core.visualise import visualise_mnist_chain
from envs.mnist_grid import MNISTHyperGrid


def main(save_folder, name, n_sample, grid_size, eps, seed, t_delta,
         r_delta, min_t, min_r, n_cluster, min_samples):
    """Collect (or load cached) MNIST-grid transitions and build an abstract MDP.

    Args:
        save_folder: Output folder name. Its first "__"-separated component is
            replaced by ``name`` and the result is placed under ``save/``.
        name: Experiment name used as the folder's leading component.
        n_sample: Forwarded to ``collect_data`` as ``max_episode``.
        grid_size: Grid dimensions passed to ``MNISTHyperGrid``.
        eps: Environment stochasticity rate passed to ``MNISTHyperGrid``.
        seed: Seed for torch, numpy and ``collect_data``.
        t_delta: Forwarded to ``build`` as ``transition_error_delta``.
        r_delta: Forwarded to ``build`` as ``reward_error_delta``.
        min_t: Forwarded to ``build`` as ``min_transition_error``.
        min_r: Forwarded to ``build`` as ``min_reward_error``.
        n_cluster: Forwarded to ``build`` as ``n_cluster_trials``.
        min_samples: Forwarded to ``build`` as ``min_samples``.

    Side effects:
        Creates the save folder if needed. Writes ``data.tar.gz`` there when it
        is not already present. Calls ``visualise_mnist_chain(mdp, "chain")``
        and saves the MDP to ``mdp.pkl`` in the save folder.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Replace the first "__"-separated component of the folder name with
    # `name`, e.g. ("old__eps0.1", "exp") -> "save/exp__eps0.1".
    save_folder = "__".join([name] + save_folder.split("__")[1:])
    save_folder = os.path.join("save", save_folder)
    # exist_ok=True avoids the race between a separate exists() check and
    # makedirs() (e.g. two runs started concurrently).
    os.makedirs(save_folder, exist_ok=True)

    # Reuse previously collected transitions when the cache file exists;
    # otherwise collect them and write the cache for later runs.
    data_file = os.path.join(save_folder, "data.tar.gz")
    if os.path.exists(data_file):
        data = TransitionData(data_file)
    else:
        env = MNISTHyperGrid(dimensions=grid_size, eps=eps, use_initiation_vector=True)
        data = collect_data(env=env,
                            max_episode=n_sample,
                            max_timestep_per_ep=20,
                            verbose=True,
                            seed=seed,
                            n_jobs=5)
        data.to_pickle(data_file)

    # A single factor over observation columns 0..783.
    # NOTE(review): presumably a flattened 28x28 MNIST image; confirm.
    f1 = Factor("f1", list(range(784)))
    factors = [f1]

    mdp, _, _ = build(transition_data=data,
                      option_names=[0, 1],
                      factors=factors,
                      transition_error_delta=t_delta,
                      reward_error_delta=r_delta,
                      min_transition_error=min_t,
                      min_reward_error=min_r,
                      changed_factors=None,  # no factor is marked as changed
                      n_cluster_trials=n_cluster,
                      min_samples=min_samples)
    visualise_mnist_chain(mdp, "chain")
    mdp.save(os.path.join(save_folder, "mdp.pkl"))
    print("Refinement finished.")


def _build_parser():
    """Return the command-line argument parser for this script."""
    # The first positional parameter of ArgumentParser is ``prog`` (the name
    # shown in usage lines), so the description must be passed by keyword.
    parser = argparse.ArgumentParser(description="Build abstract MDP on MNIST chain.")
    parser.add_argument("--n-sample", help="Number of samples", type=int, required=True)
    parser.add_argument("--grid-size", help="Size in each dimension", nargs="+", type=int, required=True)
    parser.add_argument("--eps", help="The stochasticity rate of the environment", type=float, required=True)
    parser.add_argument("--save-folder", type=str, required=True,
                        help="Output folder under save/; its first '__'-separated "
                             "component is replaced by --name")
    parser.add_argument("--name", type=str, required=True,
                        help="Experiment name used as the save folder's leading component")
    parser.add_argument("--seed", type=int, required=True,
                        help="Random seed for torch, numpy and data collection")
    # NOTE(review): the four options below have no default, so they are
    # forwarded to build() as None when omitted. Confirm build() accepts None.
    parser.add_argument("--t-delta", help="Minimum transition error delta for refinement", type=float)
    parser.add_argument("--r-delta", help="Minimum reward error delta for refinement", type=float)
    parser.add_argument("--min-t", help="Minimum transition error for a state", type=float)
    parser.add_argument("--min-r", help="Minimum reward error for a state", type=float)
    parser.add_argument("--n-cluster", help="Number of cluster trials", type=int, default=1)
    parser.add_argument("--min-samples", help="Minimum samples for a state", type=int, default=10)
    return parser


if __name__ == "__main__":
    args = _build_parser().parse_args()

    main(save_folder=args.save_folder,
         name=args.name,
         n_sample=args.n_sample,
         grid_size=args.grid_size,
         eps=args.eps,
         seed=args.seed,
         t_delta=args.t_delta,
         r_delta=args.r_delta,
         min_t=args.min_t,
         min_r=args.min_r,
         n_cluster=args.n_cluster,
         min_samples=args.min_samples)
