import json
import os
import sys
import argparse
import numpy as np
import torch
import tqdm
from tensorboardX import SummaryWriter

# Absolute directory containing this script (symlinks resolved by realpath).
FILE_DIR = os.path.dirname(os.path.realpath(__file__))
# Parent of FILE_DIR; the results/ and data/ trees are resolved relative to it.
ROOT_DIR = os.path.dirname(FILE_DIR)
# Experiment outputs live under <ROOT_DIR>/results/<exp_name>/.
RES_DIR = os.path.join(ROOT_DIR, "results")
# Datasets live under <ROOT_DIR>/data/<dataset>/.
MAIN_DATA_DIR = os.path.join(ROOT_DIR, "data")
# Make both directories importable so the `modules.*` imports that follow can
# be resolved when this file is run directly. NOTE: the paths are appended,
# so entries already on sys.path take precedence over them.
sys.path.append(FILE_DIR)
sys.path.append(ROOT_DIR)

from modules.velap.dynamics.model_dynamics import DynamicsModel
from modules.utils import contr_loss
import torch.nn.functional as F


def _valid_sampling_ids(z_all, n_step_prediction):
    """Index the trajectories that are long enough for an n-step rollout window.

    A trajectory is usable when it has strictly more than ``n_step_prediction``
    entries, so a start index plus ``n_step_prediction`` further entries fit
    inside it. Contexts may hold different numbers of trajectories (a ragged
    layout), which a single 2-D ``np.array`` of lengths cannot represent.

    Args:
        z_all: Sequence of contexts; each context is a sequence of
            trajectories that support ``len``.
        n_step_prediction: Number of prediction steps per sampled window.

    Returns:
        Tuple ``(traj_length, valid_context_ids, valid_trj_ids)``:
        ``traj_length[c][t]`` is the length of trajectory ``t`` in context
        ``c``; ``valid_context_ids`` is an int array of the contexts that have
        at least one usable trajectory; ``valid_trj_ids[c]`` is an int array of
        the usable trajectory indices of context ``c`` (possibly empty).
    """
    traj_length = [np.array([len(t) for t in context], dtype=int) for context in z_all]
    valid_trj_ids = [np.flatnonzero(lengths > n_step_prediction) for lengths in traj_length]
    valid_context_ids = np.array([c for c, ids in enumerate(valid_trj_ids) if ids.size > 0], dtype=int)
    return traj_length, valid_context_ids, valid_trj_ids


def train():
    """Train the latent dynamics model on pre-computed encoder latents.

    Configuration comes from the module-level ``args`` namespace (created in
    the ``__main__`` section). This function overwrites/extends ``args`` with
    settings read from the encoder run's ``params.json`` (``z_dim``, ``lmda``,
    ``dataset``, ``dyn_loss_type``, ``n_step_prediction``, ``T_contr``) and
    with ``action_dim`` inferred from the action data.

    Reads, with ``exp_dir = <RES_DIR>/<exp_name>``:
        * ``<exp_dir>/encoder/params.json``: encoder settings (see above).
        * ``<exp_dir>/z_all.npy``: latents indexed ``[context][trajectory][step]``.
        * ``<exp_dir>/z_stats.npy``: its first entry scales the noise of the
          noisy negatives.
        * ``<MAIN_DATA_DIR>/<dataset>/actions_all.npy``: actions indexed like
          the latents.
        * ``<exp_dir>/encoder/model/model_dynamics``: initial weights, only
          when ``args.load_pretrained`` is truthy.

    Writes:
        * ``<exp_dir>/dynamics/model/model_dynamics_<iter>`` at iteration 0 and
          every ``args.save_every`` iterations.
        * ``<exp_dir>/dynamics/model/model_dynamics`` after the last iteration.
        * TensorBoard scalars under ``<exp_dir>/dynamics/log``.

    Raises:
        ValueError: if no trajectory is long enough to sample a window of
            ``n_step_prediction + 1`` latents.
        NotImplementedError: if ``dyn_loss_type`` is neither ``"mse"`` nor
            ``"contrastive"``.
    """
    exp_dir = os.path.join(RES_DIR, args.exp_name)

    model_dir = os.path.join(exp_dir, "dynamics", "model")
    log_dir = os.path.join(exp_dir, "dynamics", "log")

    # Load parameters: the dynamics model has to match the encoder run.
    with open(os.path.join(exp_dir, "encoder", "params.json")) as f:
        params = json.load(f)

    args.z_dim = params["z_dim"]
    args.lmda = params["lmda"]
    args.dataset = params["dataset_train"]
    args.dyn_loss_type = params["dyn_loss_type"]
    args.n_step_prediction = params["n_step_prediction"]
    args.T_contr = params["T_contr"]

    # Load data. allow_pickle is required because the files presumably store
    # nested (object) arrays; NOTE: only load files produced by this project.
    data_dir = os.path.join(MAIN_DATA_DIR, args.dataset)
    z_all = np.load(os.path.join(exp_dir, "z_all.npy"), allow_pickle=True)
    actions_all = np.load(os.path.join(data_dir, "actions_all.npy"), allow_pickle=True)
    args.action_dim = actions_all[0][0].shape[-1]
    mean_dist = np.load(os.path.join(exp_dir, "z_stats.npy"))[0]

    # Create model
    model_dyn = DynamicsModel(z_dim=args.z_dim,
                              action_dim=args.action_dim).to(args.device)

    if args.load_pretrained:
        model_dyn.load_state_dict(
            torch.load(os.path.join(exp_dir, "encoder", "model", "model_dynamics"),
                       map_location=args.device),
            strict=True)

    # Set optimizer
    optimizer = torch.optim.Adam(model_dyn.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # Only trajectories with more than n_step_prediction latents can host a
    # full window of n_step_prediction + 1 latents.
    traj_length, valid_context_ids, valid_trj_ids = _valid_sampling_ids(z_all, args.n_step_prediction)
    if valid_context_ids.size == 0:
        raise ValueError("No trajectory in z_all.npy is longer than n_step_prediction=%d; "
                         "cannot sample training batches." % args.n_step_prediction)

    def sample_batch():
        """Sample one batch of windows of n_step_prediction + 1 consecutive latents.

        Returns:
            z_n: latents of each window, shape (batch, n_step_prediction + 1, ...).
            action_n: the n_step_prediction actions of each window,
                shape (batch, n_step_prediction, ...).
            z_n_neg: ``z_n`` plus Gaussian noise scaled by ``mean_dist`` times a
                random integer factor in [1, 4] drawn per step.
            z_neg: one latent per sample, from a randomly chosen window
                (possibly of the same trajectory) of the same context.
        """
        z_n = []
        z_n_neg = []
        z_neg = []
        action_n = []

        c_ids = np.random.choice(valid_context_ids, size=args.batch_size)
        for c_id in c_ids:
            t_id = np.random.choice(valid_trj_ids[c_id])
            # Start index chosen so that s_id + n_step_prediction stays in range.
            s_id = np.random.randint(0, traj_length[c_id][t_id] - args.n_step_prediction)

            z_n_s = []
            z_n_neg_s = []
            actions_n_s = []
            for i_h in range(1 + args.n_step_prediction):
                z_h = z_all[c_id][t_id][s_id + i_h]
                z_n_s.append(z_h.copy())

                # `z_h + noise` allocates a new array, so no extra copy is needed.
                z_n_neg_s.append(z_h + np.random.randn(args.z_dim) * mean_dist * np.random.randint(1, 5))

                # One action per transition: n_step_prediction + 1 latents use
                # n_step_prediction actions.
                if i_h < args.n_step_prediction:
                    actions_n_s.append(actions_all[c_id][t_id][s_id + i_h].copy())

            z_n.append(z_n_s)
            z_n_neg.append(z_n_neg_s)
            action_n.append(actions_n_s)

            # Sample negative example from same context
            t_id_neg = np.random.choice(valid_trj_ids[c_id])
            s_id_neg = np.random.randint(0, traj_length[c_id][t_id_neg] - args.n_step_prediction)
            z_neg.append(z_all[c_id][t_id_neg][s_id_neg].copy())

        return np.array(z_n), np.array(action_n), np.array(z_n_neg), np.array(z_neg)

    def to_tensor(array):
        """Convert a NumPy array to a float32 tensor on the training device."""
        return torch.from_numpy(array.astype(np.float32)).to(args.device)

    # Weight lmda**k for the k-th predicted step; it does not change between iterations.
    lmda = [args.lmda ** k for k in range(args.n_step_prediction)]

    writer = SummaryWriter(log_dir)
    try:
        for i_iter in tqdm.tqdm(range(args.n_iters)):

            # Sample batch
            z_n_step, action_n_step, z_n_step_neg, z_neg = (to_tensor(a) for a in sample_batch())

            # Roll the model out from the first latent; its output is added to
            # the current latent at every step.
            z_tmp = z_n_step[:, 0]
            z_pred = []
            for i in range(args.n_step_prediction):
                z_tmp = z_tmp + model_dyn(z_tmp, action_n_step[:, i])
                z_pred.append(z_tmp)
            # Put the prediction-step axis first for pred, ground truth and noisy copies.
            z_pred = torch.stack(z_pred, 0)
            z_gt = z_n_step[:, 1:].transpose(1, 0)
            z_n_step_neg = z_n_step_neg[:, 1:].transpose(1, 0)

            if args.dyn_loss_type == "mse":
                dyn_loss = [lmda[i] * F.mse_loss(z_pred[i], z_gt[i]) for i in range(args.n_step_prediction)]
            elif args.dyn_loss_type == "contrastive":
                bs = z_gt[0].shape[0]
                # Negatives: the batch's ground truth in shuffled order, noisy
                # copies of the ground truth, and random same-context latents.
                dyn_loss = [lmda[i] * contr_loss(anchor=z_gt[i],
                                                 pos=z_pred[i],
                                                 neg=[z_gt[i][torch.randperm(bs).to(args.device)],
                                                      z_n_step_neg[i],
                                                      z_neg], T=args.T_contr)
                            for i in range(args.n_step_prediction)]
            else:
                raise NotImplementedError("Unknown dyn_loss_type: %r" % (args.dyn_loss_type,))
            dyn_loss = torch.stack(dyn_loss).mean()

            optimizer.zero_grad()
            dyn_loss.backward()
            optimizer.step()

            # Save model
            if (i_iter == 0) or (not i_iter % args.save_every):
                torch.save(model_dyn.state_dict(), os.path.join(model_dir, "model_dynamics_%s" % str(i_iter)))

            # Make summary
            if (i_iter == 0) or (not i_iter % args.summary_every):
                writer.add_scalar("dynamics/loss", dyn_loss.item(), i_iter)
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()

    # Save final model
    torch.save(model_dyn.state_dict(), os.path.join(model_dir, "model_dynamics"))


if __name__ == '__main__':
    # Parse arguments. Encoder-derived settings (z_dim, lmda, dataset,
    # dyn_loss_type, n_step_prediction, T_contr) are not CLI options: train()
    # copies them from <RES_DIR>/<exp_name>/encoder/params.json.
    parser = argparse.ArgumentParser()

    # Experiment sub-directory under RES_DIR.
    parser.add_argument('--exp_name', type=str, default="bench/spiral_env_0")

    parser.add_argument('--n_iters', type=int, default=50000)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=1e-5)
    parser.add_argument('--weight_decay', type=float, default=0)
    parser.add_argument('--device', type=str, default="cuda")
    # Checkpoint / TensorBoard intervals, in iterations.
    parser.add_argument('--save_every', type=int, default=10000)
    parser.add_argument('--summary_every', type=int, default=1000)
    # Nonzero: initialise from <exp_name>/encoder/model/model_dynamics.
    parser.add_argument('--load_pretrained', type=int, default=1)

    # Kept as the module-level `args`, which train() reads.
    args = parser.parse_args()

    exp_dir = os.path.join(RES_DIR, args.exp_name)

    # Create folder
    os.makedirs(os.path.join(exp_dir, "dynamics", "log"), exist_ok=True)
    os.makedirs(os.path.join(exp_dir, "dynamics", "model"), exist_ok=True)

    # Store the CLI configuration. This runs before train() adds the
    # encoder-derived fields to args, so only CLI options are recorded.
    cli_params = vars(args)
    with open(os.path.join(exp_dir, "dynamics", "params.json"), 'w') as json_file:
        json.dump(cli_params, json_file, sort_keys=True, indent=2)

    train()
