import argparse
import os

import torch
import numpy as np

from core.build import build
from core.data import Factor, TransitionData
from core.explore import collect_data
from core.msa import MSAFlat
from envs.mnist_grid import MNISTHyperGrid


def main(save_folder, name, n_sample, grid_size, eps, seed, t_delta,
         r_delta, min_t, min_r, n_cluster, min_samples,
         msa_encoding, use_msa_for_independence):
    """Build an abstract MDP from MNIST hyper-grid transitions and save it.

    The output folder is derived from ``save_folder``: its first
    ``"__"``-separated component is replaced by ``name`` and the result is
    placed under ``save/``. The built MDP is written there as ``mdp.pkl``.

    Args:
        save_folder: Template folder name whose first ``"__"`` component is
            replaced by ``name``.
        name: Experiment name; also used to locate MSA data when
            ``msa_encoding`` is set.
        n_sample: Number of episodes to collect (raw mode); also part of the
            MSA folder name (MSA mode).
        grid_size: ``dimensions`` argument for ``MNISTHyperGrid``.
        eps: ``eps`` (stochasticity) argument for ``MNISTHyperGrid``.
        seed: Seed for torch and numpy and for data collection; also part of
            the MSA folder name.
        t_delta, r_delta, min_t, min_r, n_cluster, min_samples: Forwarded to
            ``build`` as its refinement hyper-parameters.
        msa_encoding: If true, load ``msa_data.tar.gz`` from
            ``save/{name}__nsample_{n_sample}__seed_{seed}`` instead of
            using raw pixel transition data.
        use_msa_for_independence: Only in MSA mode: also load an ``MSAFlat``
            model from the MSA folder and pass it to ``build``.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    save_folder = os.path.join(
        "save", "__".join([name] + save_folder.split("__")[1:]))
    # exist_ok avoids the race in a separate exists()-then-makedirs() check.
    os.makedirs(save_folder, exist_ok=True)

    msa = None
    if msa_encoding:
        msa_folder = os.path.join("save", f"{name}__nsample_{n_sample}__seed_{seed}")
        data = TransitionData(os.path.join(msa_folder, "msa_data.tar.gz"))
        # In MSA mode each factor is a single column of the encoded data.
        factors = [Factor("f1", [0]), Factor("f2", [1])]
        changed_factors = None
        if use_msa_for_independence:
            msa = MSAFlat(msa_folder).cpu()
    else:
        data_file = os.path.join(save_folder, "data.tar.gz")
        if os.path.exists(data_file):
            # Reuse previously collected transitions.
            data = TransitionData(data_file)
        else:
            env = MNISTHyperGrid(dimensions=grid_size, eps=eps)
            data = collect_data(env=env,
                                max_episode=n_sample,
                                max_timestep_per_ep=20,
                                verbose=True,
                                seed=seed,
                                n_jobs=5)
            data.to_pickle(data_file)

        # Presumably one flattened 28x28 MNIST image per factor — confirm
        # against MNISTHyperGrid's observation layout.
        n_pixels = 784
        f1 = Factor("f1", list(range(n_pixels)))
        f2 = Factor("f2", list(range(n_pixels, 2 * n_pixels)))
        factors = [f1, f2]
        # Options 0 and 1 only change f1; options 2 and 3 only change f2.
        changed_factors = {
            0: [f1],
            1: [f1],
            2: [f2],
            3: [f2],
        }

    mdp, _, _ = build(transition_data=data,
                      option_names=[0, 1, 2, 3],
                      factors=factors,
                      transition_error_delta=t_delta,
                      reward_error_delta=r_delta,
                      min_transition_error=min_t,
                      min_reward_error=min_r,
                      changed_factors=changed_factors,
                      n_cluster_trials=n_cluster,
                      min_samples=min_samples,
                      msa=msa)
    mdp.save(os.path.join(save_folder, "mdp.pkl"))
    print("Refinement finished.")


def build_parser():
    """Return the command-line parser for this script.

    Bug fix: the summary text used to be passed positionally, which makes it
    argparse's ``prog`` (the program name shown in usage) rather than the
    ``description``.
    """
    parser = argparse.ArgumentParser(description="Build abstract MDP on MNIST grid.")
    parser.add_argument("--n-sample", help="Number of samples", type=int, required=True)
    parser.add_argument("--grid-size", help="Size in each dimension", nargs="+", type=int, required=True)
    parser.add_argument("--eps", help="The stochasticity rate of the environment", type=float, required=True)
    parser.add_argument("--save-folder", type=str, required=True)
    parser.add_argument("--name", type=str, required=True)
    parser.add_argument("--seed", type=int, required=True)
    # NOTE(review): the next four options are optional with no default, so
    # they arrive as None and are forwarded to build() — confirm build()
    # accepts None for them.
    parser.add_argument("--t-delta", help="Minimum transition error delta for refinement", type=float)
    parser.add_argument("--r-delta", help="Minimum reward error delta for refinement", type=float)
    parser.add_argument("--min-t", help="Minimum transition error for a state", type=float)
    parser.add_argument("--min-r", help="Minimum reward error for a state", type=float)
    parser.add_argument("--n-cluster", help="Number of cluster trials", type=int, default=1)
    parser.add_argument("--min-samples", help="Minimum samples for a state", type=int, default=10)
    parser.add_argument("--msa-encoding", help="Use MSA encodings", action="store_true")
    parser.add_argument("--use-msa-ind", help="Whether to use MSA for independence test", action="store_true")
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()

    main(save_folder=args.save_folder,
         name=args.name,
         n_sample=args.n_sample,
         grid_size=args.grid_size,
         eps=args.eps,
         seed=args.seed,
         t_delta=args.t_delta,
         r_delta=args.r_delta,
         min_t=args.min_t,
         min_r=args.min_r,
         n_cluster=args.n_cluster,
         min_samples=args.min_samples,
         msa_encoding=args.msa_encoding,
         use_msa_for_independence=args.use_msa_ind)
