import json
import os
import numpy as np
import sys
import argparse
import torch
import torch.nn.functional as F
import tqdm
from tensorboardX import SummaryWriter

# Directory holding this script; realpath resolves symlinks first.
FILE_DIR = os.path.split(os.path.realpath(__file__))[0]
# Project root is the parent directory of FILE_DIR.
ROOT_DIR = os.path.split(FILE_DIR)[0]
# Put the project root on the import path so the `modules` package used
# further down in this file can be imported.
sys.path.append(ROOT_DIR)
# Locations for experiment outputs and datasets, both under the root.
RES_DIR = os.path.join(ROOT_DIR, "results")
MAIN_DATA_DIR = os.path.join(ROOT_DIR, "data")

from modules.velap.action_sampler.model_action_sampler import ActionSampler
from modules.utils import batch_to_torch


def train():
    """Train the action-sampler VAE for the experiment named by ``args``.

    Configuration is read from the module-level ``args`` namespace (created
    by the ``__main__`` block). The function:

    1. reads ``dataset_train`` from ``<exp_dir>/encoder/params.json``,
    2. loads ``<exp_dir>/z_all.npy`` (states) and
       ``<data_dir>/actions_all.npy`` (actions) and concatenates each along
       its first axis,
    3. optimises an ``ActionSampler`` with Adam on
       ``mse(recon, action) + args.vae_beta * KL``,
    4. periodically writes TensorBoard summaries, saves a checkpoint and
       logs the spread of decoded actions.

    Side effects: sets ``args.dataset_train``, ``args.state_dim`` and
    ``args.action_dim``; writes event files to ``action_sampler/log`` and
    checkpoints to ``action_sampler/model/action_sampler`` inside the
    experiment directory.
    """
    exp_dir = os.path.join(RES_DIR, args.exp_name)
    model_dir = os.path.join(exp_dir, "action_sampler/model")
    log_dir = os.path.join(exp_dir, "action_sampler/log")
    checkpoint_path = os.path.join(model_dir, "action_sampler")

    # Size of the periodic evaluation probe: number of states drawn, and
    # number of actions decoded for each of those states.
    n_eval_states = 16
    n_eval_samples = 100

    # Load parameters
    with open(os.path.join(exp_dir, "encoder", "params.json")) as f:
        params = json.load(f)

    args.dataset_train = params["dataset_train"]
    data_dir = os.path.join(MAIN_DATA_DIR, args.dataset_train)

    # Dataset
    # NOTE: allow_pickle=True unpickles the file contents; only load files
    # produced by a trusted pipeline.
    z_all = np.load(os.path.join(exp_dir, "z_all.npy"), allow_pickle=True)
    z_all = np.concatenate(z_all)
    action_all = np.load(os.path.join(data_dir, "actions_all.npy"), allow_pickle=True)
    action_all = np.concatenate(action_all)
    args.state_dim = z_all[0].shape[-1]
    args.action_dim = action_all[0].shape[-1]
    traj_length = np.array([len(t) for t in z_all])

    def sample_batch(batch_size=64):
        """Draw ``batch_size`` random (state, action) pairs.

        A trajectory index is drawn uniformly, then a step index from
        ``[0, len(trajectory) - 2]`` (``randint``'s upper bound is
        exclusive), so the last step of a trajectory is never selected.
        Presumably this is because the final state has no associated
        action - TODO confirm against the dataset writer.

        Returns a dict with ``"action"`` and ``"state"`` numpy arrays.
        """
        traj_ids = np.random.randint(0, len(z_all), batch_size)
        step_ids = np.random.randint(0, traj_length[traj_ids] - 1, batch_size)
        picks = list(zip(traj_ids, step_ids))

        return {
            "action": np.array([action_all[t][s] for t, s in picks]),
            "state": np.array([z_all[t][s] for t, s in picks]),
        }

    # Create VAE model
    action_sampler = ActionSampler(args.state_dim,
                                   args.action_dim,
                                   args.vae_latent_dim,
                                   args.device).to(args.device)
    if args.load_pretrained:
        # NOTE(review): weights are read from encoder/model/model_action_sampler,
        # whereas this script saves to action_sampler/model/action_sampler.
        # Confirm that the encoder directory is the intended source.
        action_sampler.load_state_dict(
            torch.load(os.path.join(exp_dir, "encoder/model/model_action_sampler"),
                       map_location=args.device),
            strict=True)

    # Create optimizer
    optimizer = torch.optim.Adam(action_sampler.parameters(), lr=args.vae_lr)

    # Create summary writer
    writer = SummaryWriter(log_dir)

    # The writer buffers events; the finally clause guarantees they are
    # flushed and the file is closed even if training is interrupted.
    try:
        # Train loop
        for i_iter in tqdm.tqdm(range(args.n_iters)):

            batch = sample_batch(args.batch_size)
            batch_t = batch_to_torch(batch, args.device)

            # Variational Auto-Encoder Training
            recon, mean, std = action_sampler(batch_t["state"], batch_t["action"])
            recon_loss = F.mse_loss(recon, batch_t["action"])
            # KL divergence between N(mean, std^2) and N(0, 1), averaged
            # over all elements.
            KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
            vae_loss = recon_loss + args.vae_beta * KL_loss

            optimizer.zero_grad()
            vae_loss.backward()
            optimizer.step()

            # Make summary
            if (i_iter == 0) or (i_iter % args.summary_every == 0):
                writer.add_scalar("recon_loss", recon_loss, i_iter)
                writer.add_scalar("KL_loss", KL_loss, i_iter)
                writer.add_scalar("vae_loss", vae_loss, i_iter)
                writer.add_histogram("mean", mean, i_iter)
                writer.add_histogram("std", std, i_iter)
                writer.add_histogram("recon", recon, i_iter)
                writer.add_histogram("gt", batch_t["action"], i_iter)

            # Save model
            if (i_iter == 0) or (i_iter % args.save_every == 0):
                action_sampler.save(checkpoint_path)

            # Evaluate: decode several actions per state and log, for each
            # state, the per-dimension standard deviation summed over the
            # action dimensions (histogram tag "vars").
            if (i_iter == 0) or (i_iter % args.eval_every == 0):
                action_sampler.eval()

                # No gradients are needed here, so skip building the
                # autograd graph for the decoded samples.
                with torch.no_grad():
                    eval_batch = sample_batch(n_eval_states)
                    eval_batch_t = batch_to_torch(eval_batch, args.device)
                    z_t = torch.repeat_interleave(eval_batch_t["state"], n_eval_samples, 0)
                    a_t = action_sampler.decode(z_t)
                    a_t = torch.reshape(a_t, (n_eval_states, n_eval_samples, args.action_dim))
                    a_vars = torch.std(a_t, dim=1).sum(-1).detach().cpu().numpy()
                writer.add_histogram("vars", a_vars, i_iter)

                action_sampler.train()
    finally:
        writer.close()

    # Final checkpoint after the last iteration.
    action_sampler.save(checkpoint_path)


if __name__ == '__main__':
    # Script entry point: parse the command line, prepare the output
    # folders, record the configuration and start training.

    # Parse arguments
    parser = argparse.ArgumentParser()

    parser.add_argument('--exp_name', type=str, default="bench/spiral_env_0")
    parser.add_argument('--vae_latent_dim', type=int, default=16)
    parser.add_argument('--n_iters', type=int, default=int(1e5))
    parser.add_argument('--summary_every', type=int, default=1000)
    parser.add_argument('--save_every', type=int, default=1000)
    parser.add_argument('--eval_every', type=int, default=1000)

    parser.add_argument('--vae_lr', type=float, default=1e-3)
    parser.add_argument('--vae_beta', type=float, default=0.0001)

    parser.add_argument('--load_pretrained', type=int, default=0)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--device', type=str, default="cuda")

    # `args` is deliberately a module-level name: train() reads it as a global.
    args = parser.parse_args()

    # Create folder
    out_dir = os.path.join(RES_DIR, args.exp_name)
    os.makedirs(os.path.join(out_dir, "action_sampler/log"), exist_ok=True)
    os.makedirs(os.path.join(out_dir, "action_sampler/model"), exist_ok=True)

    # Store parameter to json
    # vars() returns the namespace's live attribute dict; it is dumped here,
    # before train() adds further attributes to `args`. The name avoids
    # shadowing the builtin `dict`.
    cli_params = vars(args)
    with open(os.path.join(out_dir, "action_sampler/params.json"), 'w') as json_file:
        json.dump(cli_params, json_file, sort_keys=True, indent=2)

    train()
