import os
import sys
sys.path.insert(0, ".")

import random
import inspect
import argparse
import warnings

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

from typing import TYPE_CHECKING
if TYPE_CHECKING:  # This is a hack to make VS Code intellisense work
    # from tensorflow.python import keras
    from keras.api._v2 import keras
else:
    keras = tf.keras

from nxcl.config import load_config, ConfigDict

from missing import models, data, metrics

from matplotlib import pyplot as plt
import seaborn as sns

sns.set_theme(style="whitegrid")



# Command-line interface.
parser = argparse.ArgumentParser()
# The run directory is mandatory: the checkpoint path, the config path and the
# default image name are all derived from it. Without `required=True` a
# missing flag surfaced later as an AttributeError on `None`.
parser.add_argument("-d", "--output-dir", type=str, required=True)
# "dist" plots the predicted mean +/- sigma curves, "sample" plots imputed samples.
parser.add_argument("-t", "--type", choices=["dist", "sample"], default="sample")
# Indices (within the first test batch) of the examples to plot.
parser.add_argument("-i", "--sample-idxs", type=int, nargs="+", default=[0, 1, 2, 3])
# Destination image; defaults to z_impute/<run name>.<type>.png when omitted.
parser.add_argument("-o", "--image-output", type=str)
args = parser.parse_args()

sample_idxs = args.sample_idxs
output_dir = args.output_dir

# normpath strips trailing separators, so "runs/exp1/" still yields "exp1"
# instead of an empty name.
basename = os.path.basename(os.path.normpath(output_dir))

if args.image_output is None:
    args.image_output = os.path.join("z_impute", f"{basename}.{args.type}.png")


# Locate the run artefacts inside the output directory and load the config.
checkpoint = os.path.join(output_dir, "weights.h5")
config_file = os.path.join(output_dir, "config.yaml")
config: ConfigDict = load_config(config_file)

# Override the batch size that is later passed to the test iterator.
config.train.batch_size = 32

# Fall back to a fixed seed when the config does not provide one.
if config.test.get("seed") is None:
    config.test.seed = 42

# Seed every random number generator used by this script with the same value.
seed = int(config.test.seed)
os.environ["PYTHONHASHSEED"] = str(config.test.seed)
for seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
    seed_fn(seed)

# Create model: look the class up by name in `models` and instantiate it with
# the subset of `config.model` entries that its constructor accepts.
model_name = config.model.name
if not hasattr(models, model_name):
    raise ValueError(f"Unknown model: {config.model.name}")

# Output settings come from the dataset section unless the model section
# already specifies them.
config.model.setdefault("output_activation", config.dataset.output_activation)
config.model.setdefault("output_dims", config.dataset.output_dims)

model_class = getattr(models, model_name)
model_argnames = inspect.signature(model_class).parameters.keys()
model_kwargs = {
    name: value
    for name, value in config.model.items()
    if name in model_argnames
}
model: keras.Model = model_class(**model_kwargs)
model_config = model.get_config()

# Report every constructor argument that was not set in the config, together
# with the value the instantiated model reports for it.
unset_argnames = [k for k in model_argnames if k not in config.model]
for k in unset_argnames:
    warnings.warn(f"Using default model argument: {k} = {model_config[k]}")

# Create dataloader: pick the normalisation and the model-specific
# preprocessing, then combine them into a single preprocessing function.
normalize_fn = data.get_dataset_normalize_fn(config.dataset.name)

remove_feature8 = config.dataset.get("remove_feature8")
if not remove_feature8:
    # Use the model's own preprocessing factory if it has one, otherwise None.
    model_preprocess_fn = getattr(model, "data_preprocessing_fn", lambda: None)()
else:
    if config.dataset.name != "mimic3_mortality":
        raise ValueError("remove_feature8 is only supported for mimic3_mortality dataset")
    print("==================== !!!!! Remove feature #8 !!!!! ====================")

    def model_preprocess_fn(inputs, label):
        """Drop column index 7 (feature #8) from the values and the mask.

        Note that in this branch the model's own `data_preprocessing_fn`
        is not consulted.
        """
        d, t, v, m, l = inputs
        v_kept = tf.concat((v[:, :7], v[:, 8:]), axis=1)
        m_kept = tf.concat((m[:, :7], m[:, 8:]), axis=1)
        return (d, t, v_kept, m_kept, l), label


valid_preprocess_fn = data.build_preprocess_fn(normalize_fn, model_preprocess_fn, class_weights=None)

# Test
def select_data(data, labels):
    """Return only the input part of a ``(data, labels)`` pair.

    NOTE(review): inside this function the `data` parameter shadows the
    `data` module imported at the top of the file.
    """
    return data

def select_labels(data, labels):
    """Return only the label part of a ``(data, labels)`` pair."""
    return labels

# Build the test iterator with the preprocessing assembled above.
test_iterator, _ = data.build_test_iterator(
    preprocess_fn=valid_preprocess_fn,
    batch_size=config.train.batch_size,
    dataset_name=config.dataset.name,
)

# Materialise the labels of every test batch as NumPy values.
labels_only = test_iterator.map(select_labels)
label_batches = list(tfds.as_numpy(labels_only))

# Take the first batch and run it through the model once before loading the
# checkpoint (presumably so the model's variables exist -- TODO confirm).
first_batch = next(iter(test_iterator))
x, y = first_batch
model(x, training=False)

print(f"Load weight from {checkpoint}")
model.load_weights(checkpoint)



# Unpack the first test batch and run the model with auxiliary outputs
# (importance log-weights, imputations, predictive mean and sigma).
statics, times, values, measurements, lengths = x
y_prob, aux = model(x, training=False, return_aux=True)

# Everything below is plain NumPy so that it can be fancy-indexed when plotting.
statics = statics.numpy()
times = times.numpy()
values = values.numpy()
measurements = measurements.numpy()
lengths = lengths.numpy()
y_prob = y_prob.numpy()

# `values` is already a NumPy array here, so its static shape gives plain ints.
shape = values.shape
n_batch, n_times, x_dim = shape[0], shape[1], shape[2]

# Bring the batch axis to the front of every auxiliary output.
log_w = tf.reshape(tf.transpose(aux["log_w"], perm=(2, 0, 1)), (n_batch, -1)).numpy()
x_mixed = tf.reshape(tf.transpose(aux["x_impute"], perm=(2, 0, 1, 3, 4)), (n_batch, -1, n_times, x_dim)).numpy()
x_mu = tf.transpose(aux["x_mu"], perm=(1, 0, 2, 3)).numpy()
x_sigma = tf.transpose(aux["x_sigma"], perm=(1, 0, 2, 3)).numpy()

# Weights relative to the largest weight of each example. Subtracting the
# maximum in log space is mathematically the same as exp(log_w) / max(exp(log_w))
# but cannot underflow to 0 / 0 when all log-weights are very negative.
w = np.exp(log_w - log_w.max(axis=1, keepdims=True))

if config.model.name == "SupNonMiwaeGRUdecoderModel":
    # Reshape with NumPy rather than tf.reshape: tf.reshape would return
    # EagerTensors, which do not accept the list indices used when plotting.
    target_shape = (n_batch, -1, n_times, x_dim)
    x_mu = x_mu.reshape(target_shape)
    x_sigma = x_sigma.reshape(target_shape)
    x_mixed = x_mixed.reshape(target_shape)

n_samples = len(sample_idxs)
n_features = values.shape[-1]




import h5py

# Read the decay parameters of the encoder and the interpolator straight from
# the checkpoint. The context manager closes the HDF5 file once the datasets
# have been copied into memory (`[()]` reads a whole dataset as a NumPy array).
with h5py.File(checkpoint, "r") as weights:
    encoder_input_decay_kernel = tf.constant(weights["encoder/sup_not_miwae_model/encoder/grud_cell/input_decay_kernel:0"][()])
    encoder_input_decay_bias   = tf.constant(weights["encoder/sup_not_miwae_model/encoder/grud_cell/input_decay_bias:0"][()])
    interp_input_decay_kernel  = tf.constant(weights["interpolator/sup_not_miwae_model/interpolator/decay_cell/decay_kernel:0"][()])
    interp_input_decay_bias    = tf.constant(weights["interpolator/sup_not_miwae_model/interpolator/decay_cell/decay_bias:0"][()])


def exp_relu(x):
    """Return ``exp(-relu(x))`` element-wise."""
    rectified = tf.nn.relu(x)
    return tf.exp(-rectified)

def exp_softplus(x):
    """Return ``exp(-softplus(x))`` element-wise."""
    smoothed = tf.nn.softplus(x)
    return tf.exp(-smoothed)


# 400 evenly spaced values from 0 to 40, stored as a column so that they
# combine with the decay kernels/biases by broadcasting
# (assumes those are per-feature vectors -- TODO confirm).
dt = np.linspace(0, 40, 400)[:, np.newaxis]

# Decay factor of each module evaluated over the whole `dt` grid.
encoder_decay_input = encoder_input_decay_kernel * dt + encoder_input_decay_bias
interp_decay_input = interp_input_decay_kernel * dt + interp_input_decay_bias
encoder_gamma = exp_relu(encoder_decay_input)
interp_gamma = exp_relu(interp_decay_input)


# One row per feature; one column per requested sample plus a final column
# holding the decay curves.
fig, axes = plt.subplots(nrows=n_features, ncols=(n_samples + 1), figsize=(9 * (n_samples + 1), 3 * n_features), squeeze=False)

for i in range(n_features):
    axes[i, 0].set_ylabel(f"Feature {i + 1}")

for c, idx in enumerate(sample_idxs):
    axes[0, c].set_title(f"Sample {idx}")
    # Only the first `lengths[idx]` time steps of the example are used.
    n_steps = lengths[idx]
    tmax = times[idx, n_steps - 1]
    # These depend on the sample only, so compute them once per column.
    ls = list(range(n_steps))
    t_all = times[idx, ls]
    for i in range(n_features):
        ax = axes[i, c]
        ax.set_xlim(0, tmax)
        # Time steps at which feature i has a truthy measurement indicator.
        ks = [k for k in ls if measurements[idx, k, i]]
        t_obs = times[idx, ks]
        if args.type == "dist":
            # One mean curve with a +/- sigma band per component; the opacity
            # grows with the component's relative weight.
            for j in range(x_mu.shape[1]):
                mu = x_mu[idx, j, ls, i]
                sigma = x_sigma[idx, j, ls, i]
                ax.plot(t_all, mu, color="b", alpha=(w[idx, j] * 0.8 + 0.2))
                ax.fill_between(t_all, mu - sigma, mu + sigma, color="b", alpha=(w[idx, j] * 0.3 + 0.1))
        elif args.type == "sample":
            for j in range(x_mixed.shape[1]):
                ax.plot(t_all, x_mixed[idx, j, ls, i], color="b")
        # Overlay the observed values as red crosses.
        ax.scatter(t_obs, values[idx, ks, i], marker="x", color="r", linewidths=3)

for i in range(n_features):
    gamma_ax = axes[i, n_samples]
    # Raw string: "\g" is not a valid escape sequence in a normal literal.
    gamma_ax.set_title(r"$\gamma$")
    gamma_ax.plot(dt, encoder_gamma[:, i], label="encoder")
    gamma_ax.plot(dt, interp_gamma[:, i], label="interpolator")
    gamma_ax.legend()
    gamma_ax.set_ylim(-0.1, 1.1)

fig.tight_layout()

# Create the destination directory if needed; savefig does not create it.
image_dir = os.path.dirname(args.image_output)
if image_dir:
    os.makedirs(image_dir, exist_ok=True)
fig.savefig(args.image_output)
