import os
from PIL import Image
from tqdm import tqdm
import math
from functools import partial
from moviepy.editor import ImageSequenceClip
import numpy as np

import jax
import jax.numpy as jnp


def order_str_to_list(order, T, **kwargs):
    """Translate a named frame ordering into an array of frame indices.

    Args:
        order: name of the ordering. Only 'autoregressive' is implemented;
            'strided' and 'hierarchical' are recognised but not implemented.
        T: number of frames.
        **kwargs: accepted for forward compatibility; currently unused.

    Returns:
        np.ndarray of length T holding the frame indices in generation order.

    Raises:
        NotImplementedError: for the recognised but unimplemented orderings.
        ValueError: for any other ordering name.
    """
    if order == 'autoregressive':
        return np.arange(T)
    if order in ('strided', 'hierarchical'):
        # These used to fall through to `return order_lst` and crash with an
        # UnboundLocalError; fail with an explicit, descriptive error instead.
        raise NotImplementedError(f'Order not implemented: {order}')
    raise ValueError(f'Invalid order: {order}')


def masks_from_order(order, max_frames, t): # "t" is different from indices in "order"
    """Build observed/latent frame masks for step `t` of a frame ordering.

    For t < max_frames the first `max_frames` entries of `order` are all
    latent. Afterwards a window of `max_frames` entries ending at position
    `t` is used: all but its last entry are observed, the last is latent.

    Returns:
        (obs_mask, latent_mask, frame_indices), where both masks are float32
        vectors of length len(order) and frame_indices is the slice of
        `order` that was selected.
    """
    # exact ordering shouldn't matter since models are general order agnostic
    total = len(order)
    if t < max_frames:
        frame_indices = order[:max_frames]
        n_obs = 0
    else:
        frame_indices = order[t - max_frames + 1:t + 1]
        n_obs = max_frames - 1

    obs_mask = np.zeros((total,), dtype=np.float32)
    latent_mask = np.zeros((total,), dtype=np.float32)
    obs_mask[frame_indices[:n_obs]] = 1.
    latent_mask[frame_indices[n_obs:]] = 1.
    return obs_mask, latent_mask, frame_indices


def sample_some_indices(max_indices, T):
    """Draw between 1 and `max_indices` evenly spaced frame indices in [0, T).

    The count, the spacing (log-uniform) and the starting offset are all drawn
    from NumPy's global RNG. If any index falls outside [0, T) a warning is
    printed and a completely fresh draw is made.

    Returns:
        list of Python ints (duplicates are possible when the spacing is < 1).
    """
    while True:
        count = np.random.randint(low=1, high=max_indices + 1, size=())
        largest_scale = T / (count - 0.999)
        scale = np.exp(np.random.rand() * np.log(largest_scale))
        pos = np.random.rand() * (T - scale * (count - 1))
        indices = [int(pos + k * scale) for k in range(count)]

        if all(0 <= idx < T for idx in indices):
            return indices
        print('warning: sampled indices', indices, 'trying again')


def sample_masks(batch, max_frames):
    """Randomly pick `max_frames` frames per example and split them into masks.

    For every batch row, frames are repeatedly drawn with
    `sample_some_indices` and assigned to either the observed or the latent
    set until exactly `max_frames` distinct frames are marked. The batch is
    then reduced to the marked frames with `gather`.

    Args:
        batch: dict of arrays; batch['video'] has leading dims (B, T).
        max_frames: number of frames to keep per example (must be <= T).

    Returns:
        (batch, frame_indices, obs_mask, latent_mask) where `batch` holds only
        the selected frames, `frame_indices` are their positions in the
        original sequence and both masks are boolean arrays.

    Raises:
        ValueError: if `max_frames` exceeds the number of frames, in which
            case the sampling loop could never terminate.
    """
    M = max_frames
    B, T = batch['video'].shape[:2]
    if M > T:
        raise ValueError(f'max_frames ({M}) exceeds number of frames ({T})')

    masks = {k: np.zeros((B, T, 1, 1, 1), dtype=np.float32)
             for k in ['obs', 'latent']}
    for obs_row, latent_row in zip(masks['obs'], masks['latent']):
        latent_row[sample_some_indices(max_indices=M, T=T)] = 1.
        while obs_row.sum() + latent_row.sum() < M:
            mask_i = np.random.randint(0, 2, size=())
            mask = [obs_row, latent_row][mask_i]
            # Convert to an array: a plain list cannot be indexed with the
            # boolean filter below.
            indices = np.asarray(sample_some_indices(max_indices=M, T=T))
            taken = (obs_row[indices] + latent_row[indices]).reshape(-1)
            indices = indices[taken == 0]
            # Reject draws that would mark more than M frames in total.
            if len(indices) > M - obs_row.sum() - latent_row.sum():
                continue
            mask[indices] = 1.

    represented_mask = (masks['obs'] + masks['latent']).clip(max=1)
    # gather needs max_frames and returns (new_batch, new_tensors, indices).
    batch, (obs_mask, latent_mask), frame_indices = gather(
        represented_mask, batch, (masks['obs'], masks['latent']), max_frames)
    return batch, frame_indices, obs_mask.astype(bool), latent_mask.astype(bool)


def gather(mask, batch, tensors, max_frames):
    """Keep only the frames selected by `mask` in `batch` and `tensors`.

    Args:
        mask: array whose first two dims are (B, T) and whose remaining dims
            collapse to one element per frame; entries equal to 1 mark frames
            to keep. Each row must mark exactly `max_frames` frames.
        batch: dict of arrays with leading dims (B, T).
        tensors: iterable of arrays with leading dims (B, T).
        max_frames: number of frames kept per row.

    Returns:
        (new_batch, new_tensors, indices): the reduced dict, the reduced
        tensors as a list, and an int32 (B, max_frames) array with the
        original positions of the kept frames.
    """
    # TODO action conditioning
    # TODO check if actions are correct
    # TODO action aggregation
    n_rows, n_frames = mask.shape[:2]
    flat_mask = mask.reshape(n_rows, n_frames)
    keep = max_frames

    indices = np.zeros_like(flat_mask[:, :keep], dtype=np.int32)
    new_batch = {key: np.zeros_like(value[:, :keep]) for key, value in batch.items()}
    new_tensors = [np.zeros_like(tensor[:, :keep]) for tensor in tensors]

    for row in range(n_rows):
        selected = flat_mask[row] == 1
        indices[row, :keep] = flat_mask[row].nonzero()[0].flatten()

        for key in new_batch:
            new_batch[key][row, :] = batch[key][row][selected]

        for dst, src in zip(new_tensors, tensors):
            dst[row, :] = src[row][selected]

    return new_batch, new_tensors, indices


def tokens_to_text(tokens, model, eos_id=2):
    """Decode token ids into strings.

    Args:
        tokens: array of shape (..., L). All leading dims are flattened, so
            the result is a flat list with one string per length-L row.
        model: object exposing `decode_tf(tokens)` whose result has a
            `.numpy()` method yielding utf-8 encoded byte strings.
        eos_id: unused; kept so existing callers that pass it keep working.

    Returns:
        list of decoded strings, in row-major order of the leading dims.
    """
    # tokens: ...L, returns in FLATTENED form
    tokens = jax.device_get(tokens)
    tokens = np.reshape(tokens, (-1, tokens.shape[-1]))

    decoded = model.decode_tf(tokens).numpy().tolist()
    return [raw.decode('utf-8') for raw in decoded]

def per_block_map(fn, block_size, *args, axis=0, need_split=None, n_flatten_dims=1):
    """Apply `fn` to `args` in blocks along `axis` and stitch the results.

    Args flagged in `need_split` first have `n_flatten_dims` dims starting at
    `axis` merged into one, are then sliced into chunks of `block_size` along
    that merged dim and fed to `fn`; unflagged args are passed through whole.
    `fn` must return a sequence of arrays. Each output is fetched to host,
    concatenated along `axis` and reshaped back to the original dims.

    Returns:
        tuple with one array per output of `fn`.
    """
    if need_split is None:
        need_split = [True] * len(args)
    assert len(need_split) == len(args)

    # The first split argument defines the dims that get flattened/restored.
    first_split = min(pos for pos, flag in enumerate(need_split) if flag)
    orig_dims = args[first_split].shape[axis:axis + n_flatten_dims]
    total = np.prod(orig_dims)

    flat_args = [flatten(arg, axis, axis + n_flatten_dims) if flag else arg
                 for arg, flag in zip(args, need_split)]

    results = []
    for start in tqdm(list(range(0, total, block_size))):
        window = [slice(None, None)] * (axis + 1)
        window[axis] = slice(start, start + block_size)
        window = tuple(window)
        block_args = [arg[window] if flag else arg
                      for arg, flag in zip(flat_args, need_split)]
        results.append(jax.device_get(fn(*block_args)))

    merged = [np.concatenate(parts, axis=axis) for parts in zip(*results)]
    return tuple(reshape_range(m, axis, axis + 1, orig_dims) for m in merged)


def normalize(x):
    """Scale `x` to unit L2 norm along its last axis.

    The norm is floored at 1e-6 so that all-zero vectors map to zeros instead
    of NaN.
    """
    norm = jnp.linalg.norm(x, axis=-1, keepdims=True)
    # jnp.maximum is equivalent to a lower-bound-only clip and avoids the
    # deprecated `a_min` / `a_max` keyword arguments of jnp.clip.
    return x / jnp.maximum(norm, 1e-6)


def topk_sample(rng, logits, top_k=None, top_p=None):
    """Sample category indices from `logits` with optional top-k / top-p filtering.

    Args:
        rng: JAX PRNG key.
        logits: array of unnormalised log-probabilities; the last axis indexes
            categories.
        top_k: if given, logits smaller than the k-th largest are masked out.
        top_p: if given (0 < top_p < 1), logits below a cutoff derived from
            the cumulative probability of the ascending-sorted distribution
            are masked out.

    Returns:
        array of sampled indices with the shape of logits.shape[:-1].
    """
    if top_k is not None:
        k = min(top_k, logits.shape[-1])
        top_values, _ = jax.lax.top_k(logits, k)
        kth_largest = top_values[..., -1, None]
        lowest = jnp.finfo(logits.dtype).min
        logits = jnp.where(logits < kth_largest, lowest, logits)

    if top_p is not None:
        assert 0 < top_p < 1
        ascending = jax.lax.sort(logits, is_stable=False)
        cumulative = jnp.cumsum(jax.nn.softmax(ascending), -1)
        # First position (from the smallest logit up) where the accumulated
        # mass reaches 1 - top_p; its logit is the keep threshold.
        cutoff_pos = jnp.argmax(cumulative >= 1 - top_p, axis=-1)
        cutoff_logit = jnp.take_along_axis(ascending, cutoff_pos[..., None], axis=-1)
        assert cutoff_logit.shape == logits.shape[:-1] + (1,)
        keep = logits >= cutoff_logit
        logits = jnp.where(keep, logits, jnp.finfo(logits.dtype).min)

    return jax.random.categorical(rng, logits, axis=-1)

def block_flatten(x, block_size):
    """Fold non-overlapping spatial blocks into the channel axis.

    Maps (B, H, W, C) to (B, H // G, W // G, G * G * C) with G = block_size,
    where each output position holds one G x G block in row-major order.
    """
    batch, height, width, channels = x.shape
    g = block_size
    rows, cols = height // g, width // g

    blocks = jnp.reshape(x, (batch, rows, g, cols, g, channels))
    # Bring the two within-block axes next to each other.
    blocks = jnp.transpose(blocks, (0, 1, 3, 2, 4, 5))
    return jnp.reshape(blocks, (batch, rows, cols, g ** 2 * channels))


def block_unflatten(x, block_size):
    """Unfold channel-packed G x G blocks back into the spatial axes.

    Maps (B, H, W, C) to (B, H * G, W * G, C // G**2) with G = block_size.
    """
    batch, rows, cols, packed = x.shape
    g = block_size
    channels = packed // g ** 2

    blocks = jnp.reshape(x, (batch, rows, cols, g, g, channels))
    # Interleave block rows with grid rows and block cols with grid cols.
    blocks = jnp.transpose(blocks, (0, 1, 3, 2, 4, 5))
    return jnp.reshape(blocks, (batch, rows * g, cols * g, channels))


def quantize(x):
    """Map values in [-1, 1] to integer levels {0, ..., 255} (int32).

    The scaled value is truncated by the integer cast, not rounded.
    """
    # [-1, 1] -> {0, .., 255}
    unit = x * 0.5 + 0.5
    return (unit * 255).astype(jnp.int32)

    
def dequantize(x):
    """Map integer levels {0, ..., 255} to float32 values in [-1, 1]."""
    # {0, ..., 255} -> [-1, 1]
    unit = x.astype(jnp.float32) / 255.
    return 2 * unit - 1


def reset(variables):
    """Zero every leaf under variables['cache_index'], in place.

    `variables` is returned unchanged when it has no 'cache_index' entry;
    otherwise the entry is replaced by a tree of zeros with the same
    structure and the same (mutated) mapping is returned.
    """
    if 'cache_index' in variables:
        variables['cache_index'] = jax.tree_util.tree_map(
            jnp.zeros_like, variables['cache_index'])
    return variables


def sample_pred(model, variables, video, actions, seed=0):
    """Generate frames one at a time, conditioned on a sliding window of frames.

    The first `model.config.open_loop_ctx` frames of `video` are encoded and
    used as context. Each subsequent frame, up to `model.config.eval_seq_len`,
    is sampled token by token, conditioned on the previous
    `n_timesteps_train - 1` frames (ground-truth encodings or earlier samples).

    Args:
        model: project model exposing `config`, `vqvae.latent_shape`, `vq_fns`
            ('encode', 'lookup', 'decode'), `decode`, `apply` and `_step`.
        variables: model variables; the 'cache'/'cache_index' collections are
            updated in place while sampling.
        video: array whose leading axis equals the local device count, with
            batch on axis 1 and time on axis 2.
        actions: optional array sliced along axis 2 per generated frame, or None.
        seed: seed for the PRNG key.

    Returns:
        (samples, video): decoded samples fetched to host and clipped to
        [-1, 1], plus the input video (decoded too when its last dim is 16).
    """
    # One independent PRNG key per local device.
    rng = jax.random.PRNGKey(seed)
    rng = jax.random.split(rng, jax.local_device_count())

    assert video.shape[0] == jax.local_device_count()
    shape = (model.config.n_timesteps_train, *model.vqvae.latent_shape)
    seq_len = np.prod(shape)             # tokens in a full training window
    frame_seq_len = np.prod(shape[1:])   # tokens in a single frame
    B = video.shape[1]
    n_itrs = model.config.eval_seq_len - model.config.open_loop_ctx  # frames to generate
    
    # The context must be able to fill the n_timesteps_train - 1 conditioning slots.
    assert model.config.open_loop_ctx >= model.config.n_timesteps_train - 1
    cond_frames = video[:, :, :model.config.open_loop_ctx]
    _, cond_encodings = jax.pmap(model.vq_fns['encode'], axis_name='batch')(cond_frames)
    # Turn the context encodings into a list with one flattened entry per frame.
    cond_encodings = jax.pmap(lambda ce: [flatten(ce[:, t], start=1) 
                                          for t in range(ce.shape[1])])(cond_encodings)

    def _model_step(variables, sample, t, actions):
        # Look up code embeddings and run one model step, collecting the
        # updated 'cache' / 'cache_index' collections.
        embeddings = model.vq_fns['lookup'](sample)

        logits, cache = model.apply(variables, embeddings, t, 
                                    actions, method=model._step,
                                    mutable=['cache', 'cache_index'])
        if model.decode:
            # Only the logits at the last position are used in decode mode.
            logits = logits[:, -1]

        return logits, cache

    def _sample_step(logits, rng):
        # Draw one token per batch element and hand back a fresh key.
        rng, new_rng = jax.random.split(rng)
        s = jax.random.categorical(rng, logits, axis=-1)
        return s, new_rng

    # `samples` is a per-frame list; generated frames are appended to the context.
    samples = cond_encodings
    pbar = tqdm(total=n_itrs * frame_seq_len)
    for i in range(model.config.open_loop_ctx, model.config.eval_seq_len):
        t = jax.device_put_replicated(i, jax.local_devices())

        # Actions for the n_timesteps_train frames ending at frame i.
        acts = actions[:, :, i - model.config.n_timesteps_train + 1:i + 1] if actions is not None else None
        sample = jnp.zeros((B, frame_seq_len), dtype=jnp.int32)
        sample = jax.device_put_replicated(sample, jax.local_devices())
        # Zero the cache indices before generating a new frame.
        variables = reset(variables)
        for j in range(frame_seq_len):
            # In decode mode only the previous token is fed; for j == 0 the
            # index j - 1 wraps around to the last (still zero) token.
            sample_inp = sample[:, :, [j - 1]] if model.decode else sample
            if j == 0 or not model.decode:
                # Prepend the tokens of the last n_timesteps_train - 1 frames.
                cond = samples[-(model.config.n_timesteps_train - 1):]
                assert len(cond) == model.config.n_timesteps_train - 1, len(cond)
                cond = jax.pmap(lambda x: jnp.concatenate(x, axis=1))(cond)
                sample_inp = jax.pmap(lambda c, si: jnp.concatenate([c, si], axis=1))(cond, sample_inp)   
            logits, new_cache = jax.pmap(
                _model_step,
            )(variables=variables, sample=sample_inp, t=t, actions=acts)
            variables.update(new_cache)

            if not model.decode:
                # Pick token j of the last frame out of the full-window logits.
                logits = logits[:, :, j + seq_len - frame_seq_len]
            
            s, rng = jax.pmap(_sample_step)(logits, rng)

            sample = sample.at[:, :, j].set(s)

            pbar.update(1)
            
            #if os.environ.get('DEBUG') == '1':
            #    break
        samples.append(sample)
    pbar.close()

    # Restore the latent grid of every frame, then stack frames along time.
    samples = [jax.pmap(lambda s: reshape_range(s, 1, 2, shape[1:]))(s)
                for s in samples] # [NBHW, ... NBHW]
    samples = jax.pmap(partial(jnp.stack, axis=1))(samples) # NBTHW
    
    def _decode(samples):
        # samples: BTHW
        B, T = samples.shape[:2]
        samples = flatten(samples, 0, 2) # (BT)HW
        # Decode in chunks of 64 frames, each with a singleton time axis.
        samples = jnp.concatenate([model.vq_fns['decode'](samples[i:i+64][:, None])[:, 0]
                                    for i in range(0, samples.shape[0], 64)])
        samples = reshape_range(samples, 0, 1, (B, T))
        samples = jnp.clip(samples, -1, 1)
        return samples # BTHWC

    samples = jax.device_get(jax.pmap(_decode)(samples))
    # NOTE(review): a last dim of 16 presumably means `video` holds latent
    # codes rather than pixels -- confirm against callers.
    if video.shape[-1] == 16:
        video = jax.device_get(jax.pmap(_decode)(video))
    return samples, video
    

def sample(model, variables, video, actions, seed=0, sample_greedy=False, return_logits=False):
    """Sample latent tokens block by block and decode them to frames.

    The video is processed in `video.shape[2] // n_timesteps_train` blocks of
    `n_timesteps_train` frames. Within a block, tokens are produced one at a
    time; the first tokens are taken from the encodings of the first
    `model.config.open_loop_ctx` ground-truth frames, the rest are sampled.

    Args:
        model: project model exposing `config`, `vqvae.latent_shape`, `vq_fns`
            ('encode', 'lookup', 'decode'), `decode`, `apply`, `_step` and,
            optionally, `init_memory`.
        variables: model variables; the 'cache'/'cache_index' collections are
            updated in place while sampling.
        video: array whose leading axis equals the local device count, with
            batch on axis 1 and time on axis 2.
        actions: optional array sliced along axis 2 per block, or None.
        seed: seed for the PRNG key.
        sample_greedy: take the argmax instead of sampling.
        return_logits: return the per-step logits instead of decoded frames.

    Returns:
        The list of per-step logits if `return_logits` is set, otherwise
        (samples, video) with samples decoded and clipped to [-1, 1].
    """
    has_memory = hasattr(model, 'init_memory')
    if has_memory:
        memory = model.init_memory(video.shape[1])

    # One independent PRNG key per local device.
    rng = jax.random.PRNGKey(seed)
    rng = jax.random.split(rng, jax.local_device_count())

    assert video.shape[0] == jax.local_device_count()
    shape = (model.config.n_timesteps_train, *model.vqvae.latent_shape)
    seq_len = np.prod(shape)
    B = video.shape[1]
    T = model.config.n_timesteps_train
    n_itrs = video.shape[2] // T

    if model.config.open_loop_ctx > 0:
        cond_frames = video[:, :, :model.config.open_loop_ctx]
        _, cond_encodings = jax.pmap(model.vq_fns['encode'], axis_name='batch')(cond_frames)
        cond_encodings = jax.pmap(partial(flatten, start=1))(cond_encodings)
        cond_latents = cond_encodings.shape[2]
    else:
        cond_latents = 0

    # Number of conditioning tokens already copied into the samples.
    n_cond_used = 0

    if has_memory:
        def _model_step(variables, sample, t, actions, memory):
            embeddings = model.vq_fns['lookup'](sample)

            logits, memory, cache = model.apply(variables, embeddings, t,
                                                actions, memory, method=model._step,
                                                mutable=['cache', 'cache_index'])
            if model.decode:
                logits = jnp.squeeze(logits, axis=1)

            return logits, memory, cache
    else:
        def _model_step(variables, sample, t, actions):
            embeddings = model.vq_fns['lookup'](sample)

            logits, cache = model.apply(variables, embeddings, t,
                                        actions, method=model._step,
                                        mutable=['cache', 'cache_index'])
            if model.decode:
                logits = jnp.squeeze(logits, axis=1)

            return logits, cache

    def _sample_step(logits, rng):
        rng, new_rng = jax.random.split(rng)
        if sample_greedy:
            s = jnp.argmax(logits, axis=-1)
        else:
            s = jax.random.categorical(rng, logits, axis=-1)
        return s, new_rng

    all_logits, samples = [], []
    pbar = tqdm(total=n_itrs * seq_len)
    for i in range(n_itrs):
        t = jax.device_put_replicated(i, jax.local_devices())

        acts = actions[:, :, i * T:(i + 1) * T] if actions is not None else None
        sample = jnp.zeros((B, seq_len), dtype=jnp.int32)
        sample = jax.device_put_replicated(sample, jax.local_devices())
        # Zero the cache indices before generating a new block.
        variables = reset(variables)
        for j in range(seq_len):
            # In decode mode only the previous token is fed; for j == 0 the
            # index j - 1 wraps around to the last (still zero) token.
            sample_inp = sample[:, :, [j - 1]] if model.decode else sample
            if has_memory:
                # The memory-aware step always returns three values. The cache
                # is needed after every step, while the memory is only carried
                # over once the block is complete.
                logits, new_memory, new_cache = jax.pmap(
                    _model_step
                )(variables=variables, sample=sample_inp, t=t, actions=acts, memory=memory)
                if j == seq_len - 1:
                    memory = new_memory
            else:
                logits, new_cache = jax.pmap(
                    _model_step,
                )(variables=variables, sample=sample_inp, t=t, actions=acts)
            variables.update(new_cache)

            if not model.decode:
                logits = logits[:, :, j]

            if return_logits:
                all_logits.append(logits)

            if n_cond_used < cond_latents:
                # Teacher-force the conditioning tokens before sampling.
                s = cond_encodings[:, :, n_cond_used]
                n_cond_used += 1
            else:
                s, rng = jax.pmap(_sample_step)(logits, rng)

            sample = sample.at[:, :, j].set(s)

            pbar.update(1)
        samples.append(sample)
    pbar.close()

    # Restore the (T', H, W) latent layout of every block, then join blocks in time.
    samples = [jax.pmap(lambda s: reshape_range(s, 1, 2, shape))(s)
               for s in samples] # [NBT'HW, ... NBT'HW]
    samples = jax.pmap(partial(jnp.concatenate, axis=1))(samples) # NBTHW

    def _decode(samples):
        # samples: BTHW
        B, T = samples.shape[:2]
        samples = flatten(samples, 0, 2) # (BT)HW
        # Decode in chunks of 64 frames, each with a singleton time axis.
        samples = jnp.concatenate([model.vq_fns['decode'](samples[i:i+64][:, None])[:, 0]
                                   for i in range(0, samples.shape[0], 64)])
        samples = reshape_range(samples, 0, 1, (B, T))
        samples = jnp.clip(samples, -1, 1)
        return samples # BTHWC

    samples = jax.device_get(jax.pmap(_decode)(samples))
    # NOTE(review): axis 3 of size 16 presumably means `video` holds latent
    # codes rather than pixels -- confirm against callers.
    if video.shape[3] == 16:
        video = jax.device_get(jax.pmap(_decode)(video))
    return all_logits if return_logits else (samples, video)


def add_border(video, color, width=0.025):
    """Paint a coloured frame around every image of `video`, in place.

    Args:
        video: mutable array indexed as (B, T, H, W, C) with at least three
            channels; values are expected in [0, 1].
        color: sequence whose first three entries are written to channels
            0, 1 and 2 of the border.
        width: border thickness as a fraction of the size of axis 3,
            rounded up to whole pixels.

    Returns:
        None; `video` is modified in place.
    """
    # video: BTHWC in [0, 1]
    # Round up so narrow frames still get a one pixel border.
    S = math.ceil(video.shape[3] * width)
    if S <= 0:
        # A zero thickness would make the `-S:` slices cover the whole frame.
        return

    for channel in range(3):
        value = color[channel]
        video[:, :, :S, :, channel] = value    # top
        video[:, :, -S:, :, channel] = value   # bottom
        video[:, :, :, :S, channel] = value    # left
        video[:, :, :, -S:, channel] = value   # right

    
def pad(x, max_len, axis, pre_pad=True):
    """Zero-pad `x` along `axis` up to length `max_len`.

    Args:
        x: a JAX array, a NumPy array, or a dict / tuple / list of such
            (nested containers are handled recursively).
        max_len: target length along `axis`; must not be smaller than the
            current length.
        axis: axis to pad.
        pre_pad: pad before the data if True, after it otherwise. The flag is
            applied to every array inside nested containers.

    Returns:
        An object of the same structure as `x` with every array padded.

    Raises:
        AssertionError: if an array is already longer than `max_len`.
        NotImplementedError: for unsupported input types.
    """
    if isinstance(x, (jnp.ndarray, np.ndarray)):
        # Pad with the library matching the array type so JAX arrays stay
        # JAX arrays and NumPy arrays stay NumPy arrays.
        xp = jnp if isinstance(x, jnp.ndarray) else np
        assert x.shape[axis] <= max_len, f'{x.shape[axis]} > {max_len}'
        to_pad = max_len - x.shape[axis]
        padding = [(0, 0)] * len(x.shape)
        padding[axis] = (to_pad, 0) if pre_pad else (0, to_pad)
        return xp.pad(x, padding)
    if isinstance(x, dict):
        return {k: pad(v, max_len, axis, pre_pad) for k, v in x.items()}
    if isinstance(x, tuple):
        return tuple(pad(x_i, max_len, axis, pre_pad) for x_i in x)
    if isinstance(x, list):
        return [pad(x_i, max_len, axis, pre_pad) for x_i in x]
    raise NotImplementedError


def flatten(x, start=0, end=None):
    """Merge dims [start, end) of `x` into a single dim.

    Negative `start` / `end` count from the back; `end=None` means up to and
    including the last dim.
    """
    rank = len(x.shape)
    lo = start + rank if start < 0 else start

    if end is None:
        hi = rank
    else:
        hi = end + rank if end < 0 else end

    merged = np.prod(x.shape[lo:hi])
    return reshape_range(x, lo, hi, (merged,))

    
def reshape_range(x, i, j, shape):
    """Replace dims [i, j) of `x` with `shape`, leaving the other dims intact.

    Negative `i` / `j` count from the back; `j=None` means up to and including
    the last dim. The selected range must be non-empty.
    """
    new_dims = tuple(shape)

    old_shape = x.shape
    rank = len(old_shape)
    lo = i + rank if i < 0 else i

    if j is None:
        hi = rank
    else:
        hi = j + rank if j < 0 else j

    assert 0 <= lo < hi <= rank

    return jnp.reshape(x, old_shape[:lo] + new_dims + old_shape[hi:])


def save_image_grid(images, fname=None, nrow=None):
    """Tile a batch of images into one uint8 grid and optionally save it.

    Args:
        images: array of shape (B, H, W, C); values are multiplied by 255 and
            cast to uint8.
        fname: if given, the grid is written to this path with PIL.
        nrow: number of images placed side by side in each grid row;
            defaults to ceil(sqrt(B)).

    Returns:
        The grid as a uint8 array of shape (rows, cols, C), with images
        separated by one pixel of zero padding.
    """
    n_images, height, width, channels = images.shape
    pixels = (images * 255).astype('uint8')

    if nrow is None:
        nrow = math.ceil(math.sqrt(n_images))
    ncol = math.ceil(n_images / nrow)
    padding = 1
    cell_h, cell_w = padding + height, padding + width

    image_grid = np.zeros((cell_h * ncol + padding,
                           cell_w * nrow + padding, channels), dtype='uint8')
    for idx in range(n_images):
        grid_row, grid_col = divmod(idx, nrow)
        top, left = cell_h * grid_row, cell_w * grid_col
        image_grid[top:top + height, left:left + width] = pixels[idx]

    if fname is not None:
        Image.fromarray(image_grid).save(fname)
        print('saved image to', fname)

    return image_grid
 

def save_video_grid(video, fname=None, nrow=None, fps=10):
    """Tile a batch of videos into one uint8 grid video and optionally save it.

    Args:
        video: array of shape (B, T, H, W, C); values are multiplied by 255
            and cast to uint8.
        fname: if given, the grid is written to this path as a GIF.
        nrow: number of videos placed side by side in each grid row;
            defaults to ceil(sqrt(B)).
        fps: frame rate used when writing the GIF.

    Returns:
        The grid as a uint8 array of shape (T, rows, cols, C), with videos
        separated by one pixel of zero padding.
    """
    n_videos, n_frames, height, width, channels = video.shape
    pixels = (video * 255).astype('uint8')

    if nrow is None:
        nrow = math.ceil(math.sqrt(n_videos))
    ncol = math.ceil(n_videos / nrow)
    padding = 1
    cell_h, cell_w = padding + height, padding + width

    video_grid = np.zeros((n_frames, cell_h * ncol + padding,
                           cell_w * nrow + padding, channels), dtype='uint8')
    for idx in range(n_videos):
        grid_row, grid_col = divmod(idx, nrow)
        top, left = cell_h * grid_row, cell_w * grid_col
        video_grid[:, top:top + height, left:left + width] = pixels[idx]

    if fname is not None:
        clip = ImageSequenceClip(list(video_grid), fps=fps)
        clip.write_gif(fname, fps=fps)
        print('saved videos to', fname)

    return video_grid # THWC, uint8
