import os
import pickle
import sys
import numpy as np
import torch
import argparse
import tqdm
import matplotlib.pyplot as plt
import json
from sklearn.manifold import Isomap

# Repository layout: this script's directory sits directly below the repo root.
FILE_DIR = os.path.dirname(os.path.realpath(__file__))
ROOT_DIR = os.path.dirname(FILE_DIR)
RES_DIR, MAIN_DATA_DIR = (os.path.join(ROOT_DIR, sub) for sub in ("results", "data"))
# Make repo-root packages (e.g. `modules`, imported below) importable when run as a script.
sys.path.append(ROOT_DIR)

from modules.velap.rl.model_encoder import StateEncoder


@torch.no_grad()
def _encode_video(encoder, video, device):
    """Run ``encoder`` on one trajectory of frames and return its output tensor.

    ``video`` is transposed with axes (0, 3, 1, 2), i.e. it is assumed to be a
    channels-last (T, H, W, C) array -- TODO confirm against the dataset writer.
    Pixel values are cast to float32 and divided by 255 before encoding.
    Gradients are disabled: this script only performs inference, so no
    autograd graph needs to be kept alive.
    """
    frames = video.transpose(0, 3, 1, 2)  # channels-last -> channels-first
    frames_t = torch.from_numpy(frames.astype(np.float32)).to(device) / 255.0
    return encoder(frames_t)


def _as_saveable_array(items):
    """Convert ``items`` to an ndarray suitable for ``np.save``.

    Identical to ``np.array(items)`` for homogeneous input. For ragged input
    (e.g. trajectories/contexts of different lengths) NumPy >= 1.24 raises
    ``ValueError`` instead of building an object array, so fall back to a 1-D
    object array holding the original elements. Such files need
    ``allow_pickle=True`` to be loaded.
    """
    try:
        return np.array(items)
    except ValueError:
        ragged = np.empty(len(items), dtype=object)
        for k, item in enumerate(items):
            ragged[k] = item
        return ragged


def _unit_scale(values):
    """Min-max scale ``values`` to [0, 1].

    A constant input maps to zeros instead of the NaNs a 0/0 division would
    produce (matplotlib rejects NaN colours).
    """
    lo = np.min(values)
    span = np.max(values) - lo
    if span == 0:
        return np.zeros(len(values))
    return (values - lo) / span


def _state_colors(robot_state):
    """RGB colours: red from state column 0, blue from state column 1, green 0."""
    r = _unit_scale(robot_state[:, 0])
    b = _unit_scale(robot_state[:, 1])
    c = np.stack([r, np.zeros(len(r)), b], axis=1)
    return np.clip(c, 0.0, 0.99)


def _encode_train(encoder, data_dir, exp_dir, device, plot):
    """Encode the training observations and save latents plus step statistics.

    Files written to ``exp_dir``:
      * ``z_all.npy``: encoder outputs, nested [context][trajectory].
      * ``z_stats.npy``: stacked rows (mean |dz|, std |dz|, min z, max z),
        each taken over all steps/frames along axis 0.
      * ``transition_eucl_dist_percentiles.npy``: percentiles
        1, 2, 5, 10, 20, 50, 75, 95, 98 of ||z[t+1] - z[t]|| over all steps.
    """
    # Only the observations are needed for the train pass.
    state_video_all = np.load(os.path.join(data_dir, "obs_all.npy"), allow_pickle=True)

    z_all = []          # per context: list of per-trajectory latent arrays
    dists_all = []      # element-wise |z[t+1] - z[t]| per trajectory
    dist_eucl_all = []  # Euclidean ||z[t+1] - z[t]|| per trajectory
    for i in tqdm.tqdm(range(len(state_video_all))):
        z_context = []
        for video in state_video_all[i]:
            z_t = _encode_video(encoder, video, device)
            step = z_t[1:] - z_t[:-1]
            z_context.append(z_t.cpu().numpy())
            dists_all.append(torch.abs(step).cpu().numpy())
            dist_eucl_all.append(torch.norm(step, dim=1).cpu().numpy())
        z_all.append(z_context)

    # Flatten all frames of all trajectories; works for ragged lengths too.
    z_all_flat = np.concatenate([z for z_context in z_all for z in z_context])
    dists_all = np.concatenate(dists_all)
    dist_eucl_all = np.concatenate(dist_eucl_all)

    dist_stats = np.stack([np.mean(dists_all, axis=0),
                           np.std(dists_all, axis=0),
                           np.min(z_all_flat, axis=0),
                           np.max(z_all_flat, axis=0)], axis=0)

    np.save(os.path.join(exp_dir, "z_all.npy"), _as_saveable_array(z_all))
    np.save(os.path.join(exp_dir, "z_stats.npy"), dist_stats)

    percentiles = {f"p{q}": np.percentile(dist_eucl_all, q)
                   for q in (1, 2, 5, 10, 20, 50, 75, 95, 98)}
    np.save(os.path.join(exp_dir, "transition_eucl_dist_percentiles.npy"),
            list(percentiles.values()))

    if plot:
        plt.figure()
        plt.hist(dist_eucl_all, bins=100)
        for key in ("p2", "p5", "p10", "p75", "p95", "p98"):
            plt.axvline(x=percentiles[key], color='r')

        plt.figure()
        plt.hist(np.linalg.norm(z_all_flat, axis=-1), bins=100)
        plt.show()


def _plot_eval(state_video_all, prop_state_all, z_iso_all_context):
    """Per context (one column): first frame, state scatter, Isomap scatter, last frame."""
    n_ctx = len(z_iso_all_context)
    if n_ctx == 0:
        return  # plt.subplots(4, 0) would raise
    # squeeze=False keeps `ax` 2-D, so one code path serves 1 or many contexts.
    _, ax = plt.subplots(4, n_ctx, figsize=(4 * n_ctx, 8), squeeze=False)
    for i_context, z_iso in enumerate(z_iso_all_context):
        robot_state_eval = np.concatenate(prop_state_all[i_context])
        c = _state_colors(robot_state_eval)
        videos = state_video_all[i_context]

        ax[0][i_context].imshow(videos[0][0][:, :, 0:3])  # first frame
        ax[1][i_context].scatter(robot_state_eval[:, 0], robot_state_eval[:, 1], c=c)
        ax[2][i_context].scatter(z_iso[:, 0], z_iso[:, 1], c=c)
        ax[3][i_context].imshow(videos[-1][-1][:, :, 0:3])  # last frame
    plt.show()


def _encode_eval(encoder, data_dir_eval, exp_dir, device, plot, n_fit=500):
    """Encode the eval observations and fit one 2-D Isomap per context.

    Files written to ``exp_dir``: ``z_all_eval.npy`` (latents per context),
    ``z_iso_all_eval.npy`` (Isomap projection per context) and
    ``embeddings_eval.sav`` (pickled list of the fitted Isomap objects).
    Each Isomap is fit on at most ``n_fit`` distinct randomly chosen frames
    (needs more than ``n_neighbors`` = 10 frames per context).
    """
    state_video_all = np.load(os.path.join(data_dir_eval, "obs_all.npy"), allow_pickle=True)
    prop_state_all = np.load(os.path.join(data_dir_eval, "proprioception_all.npy"), allow_pickle=True)

    z_all_context = []
    z_iso_all_context = []
    embeddings_eval = []
    for videos in state_video_all:
        z_ctx = np.concatenate([_encode_video(encoder, video, device).cpu().numpy()
                                for video in tqdm.tqdm(videos)])
        z_all_context.append(z_ctx)

        # Sample without replacement: duplicate points would only add
        # zero-distance edges to Isomap's neighbour graph.
        n_frames_ctx = z_ctx.shape[0]
        fit_idx = np.random.choice(n_frames_ctx, size=min(n_fit, n_frames_ctx), replace=False)
        embedding = Isomap(n_neighbors=10, n_components=2)
        embedding.fit(z_ctx[fit_idx])
        embeddings_eval.append(embedding)
        z_iso_all_context.append(embedding.transform(z_ctx))

    np.save(os.path.join(exp_dir, "z_all_eval.npy"), _as_saveable_array(z_all_context))
    np.save(os.path.join(exp_dir, "z_iso_all_eval.npy"), _as_saveable_array(z_iso_all_context))
    with open(os.path.join(exp_dir, "embeddings_eval.sav"), 'wb') as f:
        pickle.dump(embeddings_eval, f)

    if plot:
        _plot_eval(state_video_all, prop_state_all, z_iso_all_context)


def main():
    """Encode an experiment's datasets with its trained encoder and save results.

    Reads the module-level ``args`` namespace (bound in the ``__main__`` block):
    ``exp_name``, ``eval_only``, ``device`` and ``plot``. Unless ``eval_only``
    is set, the training set is encoded first (latents and step statistics).
    The eval set is always encoded and embedded with Isomap. All outputs go to
    ``RES_DIR/<exp_name>``.
    """
    exp_dir = os.path.join(RES_DIR, args.exp_name)
    model_dir = os.path.join(exp_dir, "encoder/model")

    # Encoder hyper-parameters and dataset names recorded at training time.
    with open(os.path.join(exp_dir, "encoder", "params.json")) as f:
        params = json.load(f)

    data_dir = os.path.join(MAIN_DATA_DIR, params["dataset_train"])
    data_dir_eval = os.path.join(MAIN_DATA_DIR, params["dataset_eval"])

    encoder = StateEncoder(z_dim=params["z_dim"],
                           channel_dim=params["channel_dim"],
                           n_frames=params["n_frames"],
                           use_batch_norm=params["enc_batch_norm"])
    encoder.load_state_dict(
        torch.load(os.path.join(model_dir, "model_encoder"), map_location=args.device),
        strict=True)
    encoder.to(args.device)
    encoder.eval()  # inference mode (e.g. batch norm uses running statistics)

    if not args.eval_only:
        _encode_train(encoder, data_dir, exp_dir, args.device, args.plot)
    _encode_eval(encoder, data_dir_eval, exp_dir, args.device, args.plot)


if __name__ == '__main__':

    # Parse arguments. `main` reads the resulting module-level `args`.
    parser = argparse.ArgumentParser(
        description="Encode train/eval observations with a trained StateEncoder "
                    "and save latent statistics and Isomap embeddings.")

    parser.add_argument('--exp_name', type=str, default="spiral_env_0",
                        help="experiment folder under results/")
    parser.add_argument('--eval_only', type=int, default=0,
                        help="nonzero: skip the training set, only encode the eval set")
    # Fall back to CPU so the default works on machines without CUDA
    # (torch.load(map_location="cuda") fails there).
    parser.add_argument('--device', type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="torch device (default: cuda if available, else cpu)")
    parser.add_argument('--plot', type=int, default=1,
                        help="nonzero: show matplotlib figures")

    args = parser.parse_args()

    main()
