import argparse
import os

import numpy as np
import torch
import yaml
from tqdm import tqdm

from core.explore import collect_data
from core.msa import MSAFlat, MSADatasetFromDataframe
from envs.mnist_grid import MNISTHyperGrid
from envs.vault import Vault


def train(env, msa_config, save_folder, n_sample, val_split, seed, name):
    """Collect transitions from ``env``, train an MSA model on them, and export encodings.

    Args:
        env: Environment instance handed to ``collect_data``.
        msa_config: Path to a YAML configuration file. ``batch_size`` is read
            here; the whole mapping is also passed to ``MSAFlat`` and ``msa.fit``.
        save_folder: Folder-name template. Its first ``"__"``-separated
            component is replaced by ``name`` and the result is placed under
            ``save/``.
        n_sample: Passed to ``collect_data`` as ``max_timestep``.
        val_split: Validation split ratio used to build the train/validation sets.
        seed: Seed applied to both the NumPy and PyTorch global RNGs.
        name: Run name substituted into ``save_folder``.

    Side effects:
        Creates the output folder if needed, lets ``msa.fit`` use it, and writes
        ``msa_data.tar.gz`` there. Despite the extension, that file is a
        gzip-compressed pickle written by ``to_pickle``, not a tar archive.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Replace the leading "__"-separated component with the run name:
    # "old__a__b" -> "<name>__a__b"; a template without "__" becomes just "<name>".
    save_folder = "__".join([name] + save_folder.split("__")[1:])
    save_folder = os.path.join("save", save_folder)

    # Load the config up front so a bad path / malformed YAML fails before the
    # (potentially slow) data collection; the `with` closes the file handle.
    with open(msa_config, "r") as config_file:
        config = yaml.safe_load(config_file)

    data = collect_data(env, max_timestep=n_sample)
    # NOTE(review): reads a private attribute of the collect_data result; the
    # column assignment and to_pickle below suggest it is a pandas DataFrame.
    data = data._data

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(save_folder, exist_ok=True)

    train_set = MSADatasetFromDataframe(df=data, validation=False, val_split=val_split)
    val_set = MSADatasetFromDataframe(df=data, validation=True, val_split=val_split)

    # train the model
    msa = MSAFlat(config)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=config["batch_size"], shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=config["batch_size"], shuffle=False)
    msa.fit(train_loader, val_loader, config, save_folder)

    # Convert the dataset: encode every (x, x_) pair with the trained model.
    # val_split=0.0 presumably yields every row; shuffle=False keeps loader
    # order, which the column assignment below relies on matching the
    # dataframe's row order -- TODO confirm MSADatasetFromDataframe preserves it.
    dataset = MSADatasetFromDataframe(df=data, validation=False, val_split=0.0)
    loader = torch.utils.data.DataLoader(dataset, batch_size=100, shuffle=False)
    state = []
    next_state = []
    # NOTE(review): msa is not switched to eval mode here; if MSAFlat has
    # dropout/batch-norm layers, confirm fit() leaves it in eval mode.
    with torch.inference_mode():
        for x, _, x_ in tqdm(loader):
            state.append(msa.encode(x))
            next_state.append(msa.encode(x_))
    data["state"] = torch.cat(state, dim=0).tolist()
    data["next_state"] = torch.cat(next_state, dim=0).tolist()
    data.to_pickle(os.path.join(save_folder, "msa_data.tar.gz"), compression="gzip")


if __name__ == "__main__":
    # The first positional argument of ArgumentParser is `prog`, so the text
    # must be passed as `description` to avoid replacing the program name in
    # the usage line.
    parser = argparse.ArgumentParser(description="Train MSA on the MNIST grid or maze environment.")
    parser.add_argument("--env", help="'mnist' or 'maze'", type=str, required=True, choices=["mnist", "maze"])
    parser.add_argument("--msa-config", help="Path to MSA configuration file", type=str, required=True)
    parser.add_argument("--n-sample", help="Number of samples", type=int, required=True)
    parser.add_argument("--grid-size", help="[MNIST] Size in each dimension", nargs="+", type=int)
    parser.add_argument("--eps", help="[MNIST] The stochasticity rate of the environment", type=float)
    parser.add_argument("--res", help="[Maze] Observation resolution", nargs="+", type=int)
    parser.add_argument("--add-portraits", help="[Maze] Add portraits to walls", action="store_true")
    parser.add_argument("--val-split", help="Validation split ratio", type=float, required=True)
    parser.add_argument("--save-folder", type=str, required=True)
    parser.add_argument("--name", type=str, required=True)
    parser.add_argument("--seed", type=int, required=True)
    args = parser.parse_args()

    # Environment-specific options are optional at the parser level, so check
    # them here and report a usage error instead of crashing later
    # (e.g. `None[0]` when --res is omitted for the maze).
    if args.env == "mnist" and not args.grid_size:
        parser.error("--grid-size is required when --env is 'mnist'")
    if args.env == "maze" and (args.res is None or len(args.res) < 2):
        parser.error("--res requires two values (width height) when --env is 'maze'")

    if args.env == "mnist":
        env = MNISTHyperGrid(args.grid_size, eps=args.eps)
    elif args.env == "maze":
        env = Vault(add_portraits=args.add_portraits,
                    obs_width=args.res[0],
                    obs_height=args.res[1])
    else:
        # Unreachable while `choices` restricts --env; kept as a safeguard.
        raise ValueError(f"Unknown environment {args.env}")

    train(env=env,
          msa_config=args.msa_config,
          save_folder=args.save_folder,
          n_sample=args.n_sample,
          val_split=args.val_split,
          seed=args.seed,
          name=args.name)
