import uuid
import io
import tensorflow as tf
import numpy as np
import pathlib
import copy
import functools
from tensorflow.keras.mixed_precision import experimental as prec
from tensorflow.keras import models
from datetime import datetime
from io import StringIO
import sys


class Capturing(list):
    """Context manager that records everything printed to ``sys.stdout``.

    The manager itself is a ``list``: on exit it is extended with the captured
    text split into lines (``str.splitlines``). The previous ``sys.stdout`` is
    put back on exit, including when the ``with`` body raises; exceptions are
    not suppressed.
    """

    def __enter__(self):
        self._stdout = sys.stdout
        self._stringio = StringIO()
        sys.stdout = self._stringio
        return self

    def __exit__(self, *exc_info):
        captured_text = self._stringio.getvalue()
        del self._stringio  # only the split lines are kept; drop the buffer
        self.extend(captured_text.splitlines())
        sys.stdout = self._stdout


class _AttrDictKeyError(KeyError, AttributeError):
    """Missing-attribute error of ``AttrDict``: both a ``KeyError`` and an ``AttributeError``."""


class AttrDict(dict):
    """``dict`` whose items can also be read and written as attributes.

    ``d.key`` returns ``d['key']`` and ``d.key = v`` stores ``d['key'] = v``.
    Reading a missing attribute raises an error that is a ``KeyError`` (as
    before, so existing ``except KeyError`` handlers keep working) and also an
    ``AttributeError``. The latter is what ``hasattr``, ``getattr`` with a
    default, ``copy.deepcopy`` and ``pickle`` expect when probing for optional
    attributes such as ``__deepcopy__``. Previously those raised ``KeyError``.
    """

    __setattr__ = dict.__setitem__

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails, so dict methods win.
        try:
            return dict.__getitem__(self, name)
        except KeyError:
            raise _AttrDictKeyError(name) from None


def flatten(x):
    """Merge the first two axes of ``x`` into one; the remaining axes are kept.

    A tensor of shape ``[A, B, *rest]`` becomes ``[A * B, *rest]``.
    """
    trailing_dims = list(x.shape[2:])
    return tf.reshape(x, [-1, *trailing_dims])


def distortion_de(tau, risk_param):
    """Per-quantile weight of a step-shaped risk distortion.

    ``tau`` is first clipped to ``[0, 1]``. For ``risk_param = alpha >= 0``
    the weight is ``1 / alpha`` where ``tau < alpha`` and ``0`` elsewhere, the
    derivative of the CVaR distortion ``min(tau / alpha, 1)``. It is then
    clipped to ``[0, 5]``, which caps the weight for ``alpha < 0.2``. For
    ``risk_param < 0`` the mirrored weight is returned, i.e. the same rule
    applied to ``1 - tau`` with ``-risk_param``.

    Implemented with TF ops only, so it also works inside ``tf.function``. The
    previous version called ``tau.numpy()`` and was eager-only.

    :param tau: tensor of quantile levels.
    :param risk_param: Python number; must be non-zero (``0`` raises
        ``ZeroDivisionError``).
    :return: float32 tensor with the shape of ``tau``.
    """
    tau = tf.clip_by_value(tau, clip_value_min=0, clip_value_max=1)
    if risk_param < 0:
        # Weight the upper tail by mirroring the quantile levels.
        return distortion_de(1 - tau, -risk_param)
    # 1/alpha is computed in Python (float64) and then cast, matching the
    # earlier NumPy-based implementation bit-for-bit.
    level_weight = tf.cast(1. / risk_param, tf.float32)
    inside = tf.cast(tau < risk_param, tf.float32)
    return tf.clip_by_value(inside * level_weight, clip_value_min=0, clip_value_max=5)


def quantile_regression_loss(input, target, tau, weight):
    """Weighted quantile Huber loss over all (prediction, target) pairs.

    Each prediction is compared with every target. The Keras ``Huber`` loss of
    each pair is scaled by ``|tau - 1{prediction > target}|`` (0.5 on ties)
    and by the target's weight. The result is summed over targets and then
    averaged over everything else.

    :param input: predicted quantiles (batch x T)
    :param target: target samples (batch x T)
    :param tau: quantile levels of the predictions; presumably (batch x T) — TODO confirm
    :param weight: per-target weights; presumably (batch x T) — TODO confirm
    :return: scalar loss tensor
    """
    pred = tf.expand_dims(input, axis=-1)
    targ = tf.expand_dims(target, axis=-2)
    levels = tf.expand_dims(tau, axis=-1)
    target_weight = tf.expand_dims(weight, axis=-2)

    reps = levels.shape[-2]
    pred_grid = tf.tile(pred, [1, 1, reps])
    target_grid = tf.tile(targ, [1, reps, 1])
    tau_grid = tf.tile(levels, [1, 1, reps])

    huber = tf.keras.losses.Huber(reduction=tf.keras.losses.Reduction.NONE)
    # Keras' Huber averages the last axis; a trailing unit axis keeps per-pair values.
    pair_loss = huber(tf.expand_dims(pred_grid, axis=-1), tf.expand_dims(target_grid, axis=-1))
    # 1 where prediction > target, 0 where below, 0.5 on exact ties.
    above = tf.math.sign(pred_grid - target_grid) / 2. + 0.5
    weighted = tf.math.abs(tau_grid - above) * pair_loss * target_weight
    return tf.reduce_mean(tf.reduce_sum(weighted, axis=-1))



@tf.function
def quantile_huber_loss(T_theta, Theta, tau_quantiles):
    """Compute the quantile Huber loss.

    Source ORAAC. Every current quantile is compared with every target
    quantile in an all-pairs ``[batch_size, N, N]`` grid. The element-wise
    Keras ``Huber`` loss of each pair is weighted by
    ``|tau - 1{T_theta - Theta < 0}|`` and the mean over all pairs is returned.

    Parameters
    ----------
    T_theta: tf.Tensor
        Target quantiles of shape [batch_size x num_quantiles].
    Theta: tf.Tensor
        Current quantiles of shape [batch_size x num_quantiles] (static rank 2).
    tau_quantiles: tf.Tensor
        Quantile levels of shape [num_quantiles] (``[1 x num_quantiles]`` is
        also accepted).

    Returns
    -------
    loss: tf.Tensor
        Scalar quantile Huber loss.
    """
    _, num_quantiles = Theta.shape  # unpacking also enforces a static rank of 2
    Theta_ = tf.expand_dims(Theta, axis=2)  # [batch_size, N, 1]
    T_theta_ = tf.expand_dims(T_theta, axis=1)  # [batch_size, 1, N]
    # Reshape rather than expand_dims so both [N] and [1, N] become [1, N, 1].
    tau = tf.reshape(tau_quantiles, [1, -1, 1])
    error = T_theta_ - Theta_  # all minus all: [batch_size, N, N]

    # Compare in error's dtype and cast to tau's dtype instead of hard-coding
    # float32, so float16/float64 inputs also work.
    below = tf.cast(tf.math.less(error, 0.), tau.dtype)
    quantile_loss = tf.abs(tau - below)  # [batch_size, N, N]

    huber = tf.keras.losses.Huber(reduction=tf.keras.losses.Reduction.NONE)
    # Tile both operands to [batch_size, N, N] and append a unit axis, which
    # Keras' Huber averages away, leaving the per-pair loss.
    huber_loss_ = huber(tf.expand_dims(tf.tile(Theta_, [1, 1, num_quantiles]), axis=-1),
                        tf.expand_dims(tf.tile(T_theta_, [1, num_quantiles, 1]), axis=-1))

    return tf.reduce_mean(quantile_loss * huber_loss_)


class Adam(tf.Module):
    """Adam optimizer applied jointly to the variables of several modules.

    Gradients of a scalar loss are taken from the caller's tape. They are
    optionally clipped by global norm and then applied with
    ``tf.optimizers.Adam``.

    NOTE(review): ``wd`` and ``wdpattern`` are stored but never used by this
    class, so no weight decay is applied. ``tf.Module.__init__`` is not called
    either; confirm nothing relies on ``name_scope``.
    """

    def __init__(self, name, modules, lr, clip=None, wd=None, wdpattern=r'.*'):
        self._name = name
        self._modules = modules
        self._clip = clip
        self._wd = wd
        self._wdpattern = wdpattern
        self._opt = tf.optimizers.Adam(lr)

    @property
    def variables(self):
        """State variables of the wrapped optimizer (``self._opt.variables()``)."""
        return self._opt.variables()

    def __call__(self, tape, loss):
        """Take one optimisation step.

        :param tape: gradient tape that recorded the computation of ``loss``.
        :param loss: scalar loss tensor (asserted to have rank 0).
        :return: global norm of the gradients, measured before clipping.
        """
        self._variables = tf.nest.flatten([module.variables for module in self._modules])
        assert len(loss.shape) == 0, loss.shape
        gradients = tape.gradient(loss, self._variables)
        grad_norm = tf.linalg.global_norm(gradients)
        if self._clip:
            gradients, _ = tf.clip_by_global_norm(gradients, self._clip, grad_norm)
        self._opt.apply_gradients(zip(gradients, self._variables))
        return grad_norm


def count_episodes(directory, episode_length=None):
    """Count ``*.npz`` episode files and estimate the total number of steps.

    :param directory: folder holding the episode files (``pathlib.Path`` or str).
    :param episode_length: assumed number of steps per episode. If ``None``,
        the step count is the sum over files of the integer after the last
        ``-`` in the file stem, minus one per file.
    :return: ``(num_episodes, num_steps)``.
    :raises ValueError: if ``episode_length`` is ``None`` and a file stem does
        not end in ``-<integer>``.
    """
    filenames = list(pathlib.Path(directory).glob('*.npz'))
    episodes = len(filenames)
    if episode_length is None:
        # Only parse filename suffixes when they are actually needed, so
        # unrelated .npz names do not crash the fixed-length estimate.
        steps = sum(int(name.stem.rsplit('-', 1)[-1]) - 1 for name in filenames)
        return episodes, steps
    return episodes, int(episodes * episode_length)


def load_episodes(directory, rescan, length=None, balance=False, seed=0, load_episodes=1000):
    """Endlessly yield episodes read from the ``*.npz`` files in ``directory``.

    Each round re-lists ``directory`` and loads up to ``load_episodes``
    randomly chosen files (without replacement) into an in-memory cache. It
    then yields ``rescan`` episodes drawn from that cache with replacement.
    Each yielded episode is built from a deep copy of the cached one, so
    mutating it never affects the cache.

    :param directory: folder holding the episode files; ``~`` is expanded.
    :param rescan: number of episodes yielded per cache refresh.
    :param length: if truthy, every episode is cropped to a random window of
        ``length`` steps or zero-padded at the end to ``length`` steps.
    :param balance: when cropping, draw the window start as
        ``min(randint(0, total), available)``. This gives extra probability to
        the window ending at the final step; otherwise the start is uniform.
    :param seed: seed of the ``np.random.RandomState`` driving all sampling.
    :param load_episodes: maximum number of files cached per round. ``None``
        caches about 1/20 of the files present (at least one).
    :return: generator of dicts mapping keys to numpy arrays (``'image_128'``
        is dropped), plus a ``'mask'`` array: 1 for real steps, 0 for padding.
    :raises ValueError: if no episode can be loaded from ``directory``.
    """
    directory = pathlib.Path(directory).expanduser()
    rng = np.random.RandomState(seed)
    if load_episodes is None:
        # This must be resolved before any min(): min(n, None) raises TypeError.
        load_episodes = max(1, len(list(directory.glob('*.npz'))) // 20)

    while True:
        # Re-list every round so new episodes are picked up, and so deleted
        # files cannot make the sample size exceed the population.
        filenames = list(directory.glob('*.npz'))
        cache = _load_episode_cache(filenames, min(len(filenames), load_episodes), rng)
        if not cache:
            raise ValueError(f'No episodes could be loaded from {directory}')
        keys = list(cache.keys())
        for pick in rng.choice(len(keys), rescan):
            episode = copy.deepcopy(cache[keys[pick]])
            yield _fit_episode_length(episode, length, balance, rng)


def _load_episode_cache(filenames, count, rng):
    """Load ``count`` randomly chosen files into a ``{path: episode}`` dict, skipping unreadable ones."""
    cache = {}
    if count <= 0:
        return cache
    for filename in rng.choice(filenames, count, replace=False):
        try:
            with filename.open('rb') as f:
                data = np.load(f)
                episode = {k: data[k] for k in data.keys() if k != 'image_128'}
        except Exception as e:  # best effort: skip a bad file instead of stopping the generator
            print(f'Could not load episode: {e}')
            continue
        cache[filename] = episode
    return cache


def _fit_episode_length(episode, length, balance, rng):
    """Crop or zero-pad ``episode`` to ``length`` steps and attach a ``'mask'`` entry."""
    if not length:
        episode['mask'] = np.ones_like(episode['reward'])
        return episode
    total = len(next(iter(episode.values())))
    available = total - length
    if available < 0:
        # Pad with zeros of each array's own dtype; float64 padding used to
        # silently promote e.g. uint8 images to float64.
        padding = -available
        for key, value in list(episode.items()):
            filler = np.zeros((padding,) + value.shape[1:], dtype=value.dtype)
            episode[key] = np.concatenate([value, filler], axis=0)
        mask = np.ones(length)
        mask[total:] = 0.0
        episode['mask'] = mask
        return episode
    if available > 0:
        if balance:
            start = min(rng.randint(0, total), available)
        else:
            start = int(rng.randint(0, available))
        episode = {key: value[start:start + length] for key, value in episode.items()}
        episode['mask'] = np.ones(length)
        return episode
    episode['mask'] = np.ones_like(episode['reward'])
    return episode


def preprocess_raw(obs, config):
    """Normalise a raw observation dict and cast every entry to the compute dtype.

    ``'image'`` (and ``'image_128'`` when present) is cast, divided by 255 and
    shifted by -0.5. ``'reward'`` is passed through the transform named by
    ``config['clip_rewards']`` (``'none'`` or ``'tanh'``). Finally every value
    is cast to the mixed-precision compute dtype. The work runs on the CPU and
    a shallow copy is returned; ``obs`` itself is not modified.
    """
    compute_dtype = prec.global_policy().compute_dtype
    result = obs.copy()

    def to_centered_unit(pixels):
        return tf.cast(pixels, compute_dtype) / 255.0 - 0.5

    with tf.device('cpu:0'):
        result['image'] = to_centered_unit(result['image'])
        if 'image_128' in result:
            result['image_128'] = to_centered_unit(result['image_128'])
        reward_fn = {'none': lambda r: r, 'tanh': tf.tanh}[config['clip_rewards']]
        result['reward'] = reward_fn(result['reward'])
        for key in list(result):
            result[key] = tf.cast(result[key], compute_dtype)
    return result


def preprocess_latent(batch):
    """Cast every entry of ``batch`` to the mixed-precision compute dtype.

    Source LOMPO.

    The casts run on the CPU. ``batch`` is left untouched; a shallow copy
    (``batch.copy()``) holding the cast tensors is returned.
    """
    compute_dtype = prec.global_policy().compute_dtype
    result = batch.copy()
    with tf.device('cpu:0'):
        for key, value in batch.items():
            result[key] = tf.cast(value, compute_dtype)
    return result


def load_dataset(directory, config):
    """Build a batched, preprocessed and prefetched ``tf.data`` pipeline of episodes.

    One episode is loaded up front only to infer per-key dtypes and shapes;
    the leading (time) axis is left unknown. The pipeline then streams episodes
    of ``config['latent_batch_length']`` steps, refreshing its cache every
    ``config['num_train_step_latent_model']`` episodes. These are batched by
    ``config['latent_batch_size']`` (dropping the remainder), mapped through
    ``preprocess_raw`` and prefetched.
    """
    probe = next(load_episodes(directory, 1000, load_episodes=1))
    output_types = {}
    output_shapes = {}
    for key, array in probe.items():
        output_types[key] = array.dtype
        output_shapes[key] = (None,) + array.shape[1:]

    def episode_source():
        return load_episodes(directory, config['num_train_step_latent_model'],
                             config['latent_batch_length'])

    return (tf.data.Dataset.from_generator(episode_source, output_types, output_shapes)
            .batch(config['latent_batch_size'], drop_remainder=True)
            .map(functools.partial(preprocess_raw, config=config))
            .prefetch(10))


def episode_itterator(datadir):
    """Cycle forever over the ``*.npz`` paths in ``datadir``.

    The folder is re-listed at the start of every pass, so files added during
    a pass appear in the next one. ``datadir`` must provide ``.glob`` (e.g.
    ``pathlib.Path``). If the folder holds no ``.npz`` files the generator
    keeps re-listing without yielding.
    """
    while True:
        snapshot = [*datadir.glob('*.npz')]
        for episode_path in snapshot:
            yield episode_path


def static_scan(fn, inputs, start, reverse=False):
    """Python-unrolled scan of ``fn`` over the leading axis of ``inputs``.

    ``fn(carry, inp)`` is called once per leading index, last to first when
    ``reverse`` is set. ``inp`` is every leaf of ``inputs`` indexed at that
    position, and ``carry`` starts as ``start``. The carries returned at each
    step are flattened and stacked along a new axis 0, always in input order
    (also when ``reverse``). They are returned packed into the nested
    structure of ``start``.

    An empty leading axis is not supported (``tf.stack`` would receive empty
    lists).
    """
    last = start
    outputs = [[] for _ in tf.nest.flatten(start)]
    indices = range(len(tf.nest.flatten(inputs)[0]))
    if reverse:
        indices = reversed(indices)
    for index in indices:
        # The lambda is applied immediately, so binding ``index`` late is safe.
        inp = tf.nest.map_structure(lambda x: x[index], inputs)
        last = fn(last, inp)
        for bucket, value in zip(outputs, tf.nest.flatten(last)):
            bucket.append(value)
    if reverse:
        outputs = [bucket[::-1] for bucket in outputs]
    outputs = [tf.stack(bucket, 0) for bucket in outputs]
    return tf.nest.pack_sequence_as(start, outputs)


def parse_layers(layers, in_dim, non_linearity, normalized=False):
    """Build a Keras ``Sequential`` stack of ``Dense`` layers.

    References
    ----------
    Code from https://github.com/sebascuri/rllib.git.

    :param layers: ``None`` (no layers), a single int width, or an iterable of widths.
    :param in_dim: input width; returned unchanged when no layers are added.
    :param non_linearity: activation name, lower-cased and used for every layer.
    :param normalized: accepted for API compatibility but currently ignored.
    :return: ``(model, out_dim)`` where ``out_dim`` is the width of the last layer.
    """
    if isinstance(layers, int):
        layers = [layers]
    widths = [] if layers is None else layers

    activation = non_linearity.lower()
    model = models.Sequential()
    out_dim = in_dim
    for width in widths:
        model.add(tf.keras.layers.Dense(width, activation=activation))
        out_dim = width
    return model, out_dim


def save_episodes(directory, episodes):
    """Write each episode dict to its own compressed ``.npz`` file.

    Files are named ``<timestamp>-<uuid4 hex>-<len(episode['reward'])>.npz``,
    and all episodes of one call share the timestamp. ``directory`` (with
    ``~`` expanded) and its parents are created if missing. Each archive is
    compressed in memory first, so the target file is only opened once
    compression has succeeded.

    :param directory: destination folder (str or ``pathlib.Path``).
    :param episodes: iterable of dicts of arrays; each needs a ``'reward'`` entry.
    """
    directory = pathlib.Path(directory).expanduser()
    directory.mkdir(parents=True, exist_ok=True)
    # ``datetime`` here is the class (``from datetime import datetime``);
    # ``datetime.datetime.now()`` raised AttributeError on every call.
    timestamp = datetime.now().strftime('%Y%m%dT%H%M%S')
    for episode in episodes:
        identifier = uuid.uuid4().hex
        length = len(episode['reward'])
        filename = directory / f'{timestamp}-{identifier}-{length}.npz'
        with io.BytesIO() as buffer:
            np.savez_compressed(buffer, **episode)
            payload = buffer.getvalue()
        with filename.open('wb') as f:
            f.write(payload)


def summarize_episode(episode, config, writer, prefix):
    """Print an episode's return and log its return and length as scalar summaries.

    :param episode: mapping with a ``'reward'`` numpy array.
    :param config: currently unused; kept for interface compatibility.
    :param writer: summary writer made the default while logging.
    :param prefix: metric namespace; summaries are named
        ``sim/<prefix>/return`` and ``sim/<prefix>/length``.

    No ``step`` is passed to ``tf.summary.scalar``, so a default summary step
    must already be set by the caller.
    """
    rewards = episode['reward']
    ret = rewards.sum()
    print(f'{ret:.1f}')
    metrics = [
        (f'{prefix}/return', float(ret)),
        # One fewer than the number of stored reward entries.
        (f'{prefix}/length', len(rewards) - 1),
    ]
    with writer.as_default():
        for name, value in metrics:
            tf.summary.scalar('sim/' + name, value)





