import argparse
import json
import os
import sys
import numpy as np
import tqdm
from tensorboardX import SummaryWriter
from sklearn.manifold import Isomap
import matplotlib.pyplot as plt

# Directory holding this script (symlinks resolved) and the repository root one level above it.
FILE_DIR = os.path.dirname(os.path.realpath(__file__))
ROOT_DIR = os.path.dirname(FILE_DIR)

# Experiment outputs and datasets live directly under the repository root.
RES_DIR, MAIN_DATA_DIR = (os.path.join(ROOT_DIR, sub_dir) for sub_dir in ("results", "data"))

# Put the script directory and the repository root on the import path so the
# project-local ``modules.*`` imports that follow resolve when run as a script.
sys.path.extend((FILE_DIR, ROOT_DIR))

from modules.velap.rl.model_encoder import StateEncoder
from modules.velap.dynamics.model_dynamics import DynamicsModel
from modules.velap.rl.td3_bc_encoder import TD3BCEncoder
from modules.utils import batch_to_torch
from modules.velap.dataset.data_rl_encoder import EncoderTrainData, EncoderEvalData


def _is_due(i_iter, every):
    """Return True on the first iteration and after every ``every``-th iteration."""
    return i_iter == 0 or (i_iter + 1) % every == 0


def _encode_context(model_policy, dataset_eval, c_id, device, batch_size_eval=128):
    """Encode every evaluation observation of context ``c_id`` with the policy's encoder.

    Args:
        model_policy: Policy whose ``model_enc`` produces the embedding.
        dataset_eval: Evaluation dataset providing ``sample_context_data``.
        c_id: Context index.
        device: Torch device string forwarded to ``batch_to_torch``.
        batch_size_eval: Number of observations encoded per forward pass.

    Returns:
        Tuple ``(z_eval, robot_state_eval)``: the stacked encoder outputs and the
        matching leading rows of the context's ``"prop"`` data.

    Raises:
        ValueError: If the context contains no observations.
    """
    import torch  # only needed to disable autograd bookkeeping during evaluation

    context_data = dataset_eval.sample_context_data(c_id=c_id)
    obs, prop = context_data["obs"], context_data["prop"]
    if len(obs) == 0:
        raise ValueError("evaluation context %d contains no observations" % c_id)

    z_chunks = []
    with torch.no_grad():
        # Cover every sample, including a final partial chunk, so that small
        # contexts still produce embeddings.
        for start in range(0, len(obs), batch_size_eval):
            stop = start + batch_size_eval
            batch_t = batch_to_torch({"obs": obs[start:stop], "prop": prop[start:stop]},
                                     device=device)
            z_t = model_policy.model_enc(batch_t["obs"])
            z_chunks.append(z_t.detach().cpu().numpy())

    z_eval = np.concatenate(z_chunks)
    return z_eval, prop[:len(z_eval)]


def _position_colors(robot_state, dataset_eval):
    """Map the first two ``prop`` columns to RGB colours (red ~ column 0, blue ~ column 1).

    Each column is min-max normalised with the dataset's ``min_x``/``max_x`` and
    ``min_y``/``max_y`` bounds (presumably the robot's x/y position; confirm),
    then clipped to [0, 0.99]. The green channel is always zero.
    """
    red = (robot_state[:, 0] - dataset_eval.min_x) / (dataset_eval.max_x - dataset_eval.min_x)
    blue = (robot_state[:, 1] - dataset_eval.min_y) / (dataset_eval.max_y - dataset_eval.min_y)
    green = np.zeros(len(red))
    return np.clip(np.stack([red, green, blue], axis=1), 0.0, 0.99)


def _plot_encoder_embedding(model_policy, dataset_eval, fig_path, device):
    """Save a figure comparing ``prop`` coordinates with the learned embedding.

    The top row scatters the first two ``prop`` columns. The bottom row scatters
    a 2-D Isomap projection of the encoder output. There is one column per
    evaluation context, and both rows share the same position-based colours.
    """
    n_contexts = dataset_eval.n_contexts
    # squeeze=False keeps ``ax`` 2-D even for a single context, so one indexing path suffices.
    fig, ax = plt.subplots(2, n_contexts, figsize=(4 * n_contexts, 8), squeeze=False)
    try:
        for c_id in range(n_contexts):
            z_eval, robot_state_eval = _encode_context(model_policy, dataset_eval, c_id, device)
            colors = _position_colors(robot_state_eval, dataset_eval)

            # Fit on at most 3000 points (presumably to bound Isomap cost), then project all.
            embedding = Isomap(n_neighbors=20, n_components=2)
            embedding.fit(z_eval[:3000])
            z_eval_iso = embedding.transform(z_eval)

            ax[0, c_id].scatter(robot_state_eval[:, 0], robot_state_eval[:, 1], c=colors)
            ax[1, c_id].scatter(z_eval_iso[:, 0], z_eval_iso[:, 1], c=colors)
        fig.savefig(fig_path, bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure, even if encoding or Isomap fails.
        plt.close(fig)


def train(config=None):
    """Jointly train the state encoder, dynamics model and TD3+BC policy.

    The policy is saved on the first iteration, every ``save_every`` iterations,
    and once more at the end. The metrics returned by ``train_rl_and_embedding``
    are logged to TensorBoard under ``encoder/<key>``. Every ``eval_every``
    iterations, an embedding figure is written to ``<exp>/encoder/figures/<iter>.png``.

    Args:
        config: Parsed command-line namespace (see the ``__main__`` block).
            Defaults to the module-level ``args``, so ``train()`` keeps working
            when the file is run as a script. As a side effect,
            ``config.action_dim`` is set from the training dataset.
    """
    if config is None:
        config = args  # module-level namespace created in the ``__main__`` block

    # Resolve experiment and data paths
    exp_dir = os.path.join(RES_DIR, config.exp_name)
    data_dir_train = os.path.join(MAIN_DATA_DIR, config.dataset_train)
    data_dir_eval = os.path.join(MAIN_DATA_DIR, config.dataset_eval)
    model_dir = os.path.join(exp_dir, "encoder", "model")
    log_dir = os.path.join(exp_dir, "encoder", "log")
    fig_dir = os.path.join(exp_dir, "encoder", "figures")
    # The script entry point also creates these; repeat here so train() works when called directly.
    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(fig_dir, exist_ok=True)

    # Load data
    dataset_train = EncoderTrainData(data_dir_train, n_step_prediction=config.n_step_prediction)
    dataset_eval = EncoderEvalData(data_dir_eval)
    config.action_dim = dataset_train.action_dim

    # Create encoder
    model_enc = StateEncoder(z_dim=config.z_dim,
                             channel_dim=config.channel_dim,
                             n_frames=config.n_frames,
                             use_batch_norm=config.enc_batch_norm).to(config.device)

    # Create dynamics model
    model_dyn = DynamicsModel(z_dim=config.z_dim,
                              action_dim=config.action_dim).to(config.device)

    # Create policy (owns the encoder and dynamics model passed in)
    model_policy = TD3BCEncoder(z_dim=config.z_dim,
                                action_dim=config.action_dim,
                                goal_dim=0,
                                w_rl_high=config.w_rl_high,
                                w_rl_low=config.w_rl_low,
                                w_bc_low=config.w_bc_low,
                                w_bc_exp_low=config.w_bc_exp_low,
                                w_bc_high=config.w_bc_high,
                                w_bc_exp_high=config.w_bc_exp_high,
                                w_dyn=config.w_dyn,
                                T_contr=config.T_contr,
                                model_enc=model_enc,
                                model_dyn=model_dyn,
                                lr_policy=config.lr_policy,
                                lr_critic=config.lr_critic,
                                n_qs=config.n_qs,
                                dyn_stop_grad=config.dyn_stop_grad,
                                dyn_loss_type=config.dyn_loss_type,
                                add_neg_noise_samples=config.add_neg_noise_samples,
                                lmda=config.lmda,
                                discount=config.discount)

    writer = SummaryWriter(log_dir)
    try:
        for i_iter in tqdm.tqdm(range(config.n_iters)):
            batch = dataset_train.sample_batch(batch_size=config.batch_size)
            batch_t = batch_to_torch(batch, config.device)

            # Train step
            metrics = model_policy.train_rl_and_embedding(batch_t)

            if _is_due(i_iter, config.save_every):
                model_policy.save(model_dir)

            if _is_due(i_iter, config.summary_every):
                for key, value in metrics.items():
                    writer.add_scalar("encoder/" + key, value, i_iter)

            if _is_due(i_iter, config.eval_every):
                model_enc.eval()
                try:
                    _plot_encoder_embedding(model_policy, dataset_eval,
                                            os.path.join(fig_dir, str(i_iter) + ".png"),
                                            config.device)
                finally:
                    # Restore training mode even if evaluation fails.
                    model_enc.train()
    finally:
        # Flush pending TensorBoard events and release the event file.
        writer.close()

    # Save final model
    model_policy.save(model_dir)


if __name__ == '__main__':
    # Parse arguments
    parser = argparse.ArgumentParser(
        description="Train the state encoder, dynamics model and TD3+BC policy.")

    # Experiment and data selection. NOTE(review): --env, --max_action and
    # --vae_latent_dim are not read by train() in this file; they are still
    # recorded in params.json.
    parser.add_argument('--exp_name', type=str, default="test",
                        help="experiment folder under results/")
    parser.add_argument('--env', type=str, default="sponge_env")
    parser.add_argument('--dataset_train', type=str, default="sponge_env",
                        help="training dataset folder under data/")
    parser.add_argument('--dataset_eval', type=str, default="sponge_env",
                        help="evaluation dataset folder under data/")
    parser.add_argument('--z_dim', type=int, default=32, help="embedding dimension")
    parser.add_argument('--lr_policy', type=float, default=3e-4)
    parser.add_argument('--lr_critic', type=float, default=3e-4)

    # Model and loss hyper-parameters
    parser.add_argument('--n_frames', type=int, default=3)
    parser.add_argument('--channel_dim', type=int, default=3, help="number of image channels")
    parser.add_argument('--max_action', type=float, default=1.0)
    parser.add_argument('--n_iters', type=int, default=int(5e4), help="number of training iterations")
    parser.add_argument('--discount', type=float, default=0.96)
    parser.add_argument('--n_step_prediction', type=int, default=3)
    parser.add_argument('--w_bc_low', type=float, default=0.001)
    parser.add_argument('--w_bc_exp_low', type=float, default=0.0)
    parser.add_argument('--w_bc_high', type=float, default=0.001)
    parser.add_argument('--w_bc_exp_high', type=float, default=0.5)
    parser.add_argument('--w_dyn', type=float, default=0.01)
    parser.add_argument('--w_rl_low', type=float, default=1.0)
    parser.add_argument('--w_rl_high', type=float, default=1.0)
    parser.add_argument('--T_contr', type=float, default=1.0)
    parser.add_argument('--lmda', type=float, default=1.0)
    parser.add_argument('--n_qs', type=int, default=3)

    # Runtime, logging and integer 0/1 flags forwarded to the models
    parser.add_argument('--device', type=str, default="cuda")
    parser.add_argument('--eval_every', type=int, default=1000, help="iterations between embedding figures")
    parser.add_argument('--summary_every', type=int, default=1000, help="iterations between TensorBoard logs")
    parser.add_argument('--save_every', type=int, default=1000, help="iterations between checkpoints")
    parser.add_argument('--enc_batch_norm', type=int, default=1, help="non-zero enables encoder batch norm")
    parser.add_argument('--dyn_stop_grad', type=int, default=0)
    parser.add_argument('--dyn_loss_type', type=str, default="contrastive")
    parser.add_argument('--vae_latent_dim', type=int, default=16)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--add_neg_noise_samples', type=int, default=1)

    # Kept at module level: train() falls back to this global namespace.
    args = parser.parse_args()

    # Create output folders
    encoder_dir = os.path.join(RES_DIR, args.exp_name, "encoder")
    for sub_dir in ("log", "model", "figures"):
        os.makedirs(os.path.join(encoder_dir, sub_dir), exist_ok=True)

    # Store parameters to json (named `params` to avoid shadowing the builtin `dict`)
    params = vars(args)
    with open(os.path.join(encoder_dir, "params.json"), 'w', encoding="utf-8") as json_file:
        json.dump(params, json_file, sort_keys=True, indent=2)

    # Train encoder
    train()
