import argparse
import json
import os
import sys
import tqdm
from tensorboardX import SummaryWriter

# Path layout: this script's parent directory is treated as the project root,
# under which the "results" and "data" folders live.
FILE_DIR = os.path.split(os.path.realpath(__file__))[0]
ROOT_DIR = os.path.split(FILE_DIR)[0]
RES_DIR, MAIN_DATA_DIR = (os.path.join(ROOT_DIR, sub) for sub in ("results", "data"))
# Put the project root on the import path so the `modules.*` imports below resolve.
sys.path.append(ROOT_DIR)

from modules.velap.dataset.data_rl_high import PolicyHighData
from modules.velap.rl.td3_bc import TD3BC
from modules.utils import batch_to_torch


def train():
    """Train the high-level TD3+BC policy for the experiment named in ``args``.

    Reads the module-level ``args`` namespace populated by the ``__main__`` block
    and adds to it:

    * values recorded by the encoder stage in ``<exp_dir>/encoder/params.json``
      (``discount``, ``max_action``, ``dataset_train``);
    * the dataset dimensions (``state_dim``, ``goal_dim``, ``action_dim``).

    It then runs ``args.n_iters`` training steps. Checkpoints go to
    ``<exp_dir>/policy_high/model`` and scalar metrics to
    ``<exp_dir>/policy_high/log``.
    """
    exp_dir = os.path.join(RES_DIR, args.exp_name)

    # Copy the settings recorded by the encoder stage onto args.
    encoder_params_path = os.path.join(exp_dir, "encoder", "params.json")
    with open(encoder_params_path, encoding="utf-8") as f:
        params = json.load(f)
    args.discount = params["discount"]
    args.max_action = params["max_action"]
    args.dataset_train = params["dataset_train"]

    data_dir = os.path.join(MAIN_DATA_DIR, args.dataset_train)
    policy_dir = os.path.join(exp_dir, "policy_high")
    model_dir = os.path.join(policy_dir, "model")
    log_dir = os.path.join(policy_dir, "log")
    pretrain_model_dir = os.path.join(exp_dir, "encoder", "model")

    # Create the dataset and record its dimensions on args.
    dataset = PolicyHighData(data_dir, exp_dir, goal_conditioned=args.goal_conditioned)
    args.action_dim = dataset.action_dim
    args.state_dim = dataset.state_dim
    args.goal_dim = dataset.goal_dim

    # Create the policy.
    model_policy = TD3BC(z_dim=args.state_dim,
                         goal_dim=args.goal_dim,
                         action_dim=args.action_dim,
                         max_action=args.max_action,
                         discount=args.discount,
                         lr_policy=args.lr_policy,
                         lr_critic=args.lr_critic,
                         w_bc=args.w_bc,
                         w_bc_exp=args.w_bc_exp)

    # Optionally warm-start from the "_high" weights saved by the joint-training
    # (encoder) stage.
    if args.load_pretrained:
        model_policy.load(pretrain_model_dir, type="_high")

    writer = SummaryWriter(log_dir)
    try:
        # Main policy-training loop.
        for i_iter in tqdm.tqdm(range(args.n_iters)):
            batch = dataset.sample_batch(batch_size=args.batch_size)
            batch_t = batch_to_torch(batch, args.device)

            # One training step; returns a mapping of metric name -> value.
            metrics = model_policy.train_rl(batch_t)

            # Checkpoint on the first step and every `save_every` steps:
            # overwrite the default checkpoint and also keep a numbered copy.
            if i_iter == 0 or (i_iter + 1) % args.save_every == 0:
                model_policy.save(model_dir)
                model_policy.save(model_dir, "model_%s" % str(i_iter))

            # Log scalars on the first step and every `summary_every` steps.
            # Metric values are read with `.item()`, so they presumably are
            # 0-d tensors; confirm against TD3BC.train_rl.
            if i_iter == 0 or (i_iter + 1) % args.summary_every == 0:
                for key, value in metrics.items():
                    writer.add_scalar(key, value.item(), i_iter)
    finally:
        # Flush and release the event file even if training is interrupted.
        writer.close()

    # Save the final model. The index is computed from n_iters rather than the
    # loop variable, which is unbound when n_iters == 0.
    model_policy.save(model_dir)
    if args.n_iters > 0:
        model_policy.save(model_dir, "model_%s" % str(args.n_iters - 1))


if __name__ == '__main__':
    # Parse arguments. `args` is intentionally a module-level global:
    # train() reads it directly instead of taking it as a parameter.
    parser = argparse.ArgumentParser()

    parser.add_argument('--exp_name', type=str, default="spiral_env_0")
    parser.add_argument('--n_iters', type=int, default=int(25e4))
    # NOTE(review): weight_decay is recorded in params.json but is not passed
    # to TD3BC in train().
    parser.add_argument('--weight_decay', type=float, default=0)
    parser.add_argument('--device', type=str, default="cuda")
    parser.add_argument('--save_every', type=int, default=50000)
    parser.add_argument('--summary_every', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--lr_critic', type=float, default=3e-4)
    parser.add_argument('--lr_policy', type=float, default=3e-4)
    parser.add_argument('--w_bc', type=float, default=0.001)
    parser.add_argument('--w_bc_exp', type=float, default=0.5)
    # Integer flags (0/1), interpreted as booleans.
    parser.add_argument('--goal_conditioned', type=int, default=0)
    parser.add_argument('--load_pretrained', type=int, default=1)

    args = parser.parse_args()

    # Create the output folders for this stage.
    policy_dir = os.path.join(RES_DIR, args.exp_name, "policy_high")
    for sub_dir in ("log", "model", "figures"):
        os.makedirs(os.path.join(policy_dir, sub_dir), exist_ok=True)

    # Store the command-line parameters as JSON. This happens before train(),
    # so values train() adds to args later (discount, max_action,
    # dataset_train, dims) are not included.
    run_params = vars(args)
    with open(os.path.join(policy_dir, "params.json"), 'w', encoding="utf-8") as json_file:
        json.dump(run_params, json_file, sort_keys=True, indent=2)

    # Train the high-level policy.
    train()
