import argparse
import os

import torch
import numpy as np

from core.build import build
from core.data import Factor, TransitionData
from core.msa import MSAFlat


def main(save_folder, name, n_sample, seed, t_delta, r_delta,
         min_t, min_r, n_cluster, min_samples, use_msa_for_independence):
    """Refine an abstract MDP from stored MSA transition data and save it.

    The first ``"__"``-separated component of ``save_folder`` is replaced by
    ``name``, and the resulting folder is created under ``save/``. Transition
    data is read from
    ``save/{name}__nsample_{n_sample}__seed_{seed}/msa_data.tar.gz``, and the
    refined MDP is written to ``<output folder>/mdp.pkl``.

    Args:
        save_folder: Output folder name. Only the parts after its first
            ``"__"`` separator are kept; ``name`` replaces the first part.
        name: Experiment name. It prefixes both the input and output folders.
        n_sample: Number of samples. It selects the input folder and sizes the
            option-availability lists.
        seed: Seed for the torch and numpy RNGs. It also selects the input
            folder.
        t_delta: Forwarded to ``build`` as ``transition_error_delta``.
        r_delta: Forwarded to ``build`` as ``reward_error_delta``.
        min_t: Forwarded to ``build`` as ``min_transition_error``.
        min_r: Forwarded to ``build`` as ``min_reward_error``.
        n_cluster: Forwarded to ``build`` as ``n_cluster_trials``.
        min_samples: Forwarded to ``build`` as ``min_samples``.
        use_msa_for_independence: If true, load ``MSAFlat`` from the input
            folder, move it to CPU and pass it to ``build`` as ``msa``.
            Otherwise ``msa=None`` is passed.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    save_folder = "__".join([name] + save_folder.split("__")[1:])
    save_folder = os.path.join("save", save_folder)
    # exist_ok=True avoids the race between an exists() check and makedirs().
    os.makedirs(save_folder, exist_ok=True)

    msa_folder = os.path.join("save", f"{name}__nsample_{n_sample}__seed_{seed}")
    data = TransitionData(os.path.join(msa_folder, "msa_data.tar.gz"))

    option_names = ["gold", "mid", "c", "ac", "d"]
    # Modify available options here: every option is marked available for
    # every sample, both before and after the transition. The tuple is derived
    # from option_names so its length always matches the option count. Tuples
    # are immutable, so one instance can be shared across the list entries.
    all_available = (1,) * len(option_names)
    data._data.options_available = [all_available] * n_sample
    data._data.next_options_available = [all_available] * n_sample

    factors = [Factor("f1", [0, 1])]
    changed_factors = None
    msa = MSAFlat(msa_folder).cpu() if use_msa_for_independence else None

    mdp, _, _ = build(transition_data=data,
                      option_names=option_names,
                      factors=factors,
                      transition_error_delta=t_delta,
                      reward_error_delta=r_delta,
                      min_transition_error=min_t,
                      min_reward_error=min_r,
                      changed_factors=changed_factors,
                      n_cluster_trials=n_cluster,
                      min_samples=min_samples,
                      msa=msa)
    mdp.save(os.path.join(save_folder, "mdp.pkl"))
    print("Refinement finished.")


if __name__ == "__main__":
    # The first positional argument of ArgumentParser is `prog` (the program
    # name shown in the usage line), not the description. Pass the help text
    # as `description=` so the usage line and the description both render
    # correctly.
    parser = argparse.ArgumentParser(description="Build abstract MDP on MNIST grid.")
    parser.add_argument("--n-sample", help="Number of samples", type=int, required=True)
    # NOTE: --res and --add-portraits are parsed but not forwarded to main().
    parser.add_argument("--res", help="Observation resolution", nargs="+", type=int, required=True)
    parser.add_argument("--add-portraits", action="store_true")
    parser.add_argument("--save-folder", type=str, required=True)
    parser.add_argument("--name", type=str, required=True)
    parser.add_argument("--seed", type=int, required=True)
    # The four thresholds below default to None when omitted. They are passed
    # to build() unchanged; confirm that build() accepts None for them.
    parser.add_argument("--t-delta", help="Minimum transition error delta for refinement", type=float)
    parser.add_argument("--r-delta", help="Minimum reward error delta for refinement", type=float)
    parser.add_argument("--min-t", help="Minimum transition error for a state", type=float)
    parser.add_argument("--min-r", help="Minimum reward error for a state", type=float)
    parser.add_argument("--n-cluster", help="Number of cluster trials", type=int, default=1)
    parser.add_argument("--min-samples", help="Minimum samples for a state", type=int, default=10)
    parser.add_argument("--use-msa-ind", help="Whether to use MSA for independence test", action="store_true")
    args = parser.parse_args()

    main(save_folder=args.save_folder,
         name=args.name,
         n_sample=args.n_sample,
         seed=args.seed,
         t_delta=args.t_delta,
         r_delta=args.r_delta,
         min_t=args.min_t,
         min_r=args.min_r,
         n_cluster=args.n_cluster,
         min_samples=args.min_samples,
         use_msa_for_independence=args.use_msa_ind)
