import functools
from typing import Any, Dict, Sequence

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.core import FrozenDict
from flax.training import train_state
from utils import Batch


class Encoder(nn.Module):
    """MLP mapping inputs to ``emb_dim``-sized embeddings.

    Layout: Dense(hid_dim) -> ReLU -> Dense(hid_dim) -> LayerNorm -> tanh
    -> Dense(emb_dim). The final projection is linear (no activation).
    """
    hid_dim: int = 64
    emb_dim: int = 64

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # First hidden layer with ReLU.
        features = nn.relu(nn.Dense(self.hid_dim)(x))
        # Second hidden layer, normalised before the tanh squashing.
        features = nn.Dense(self.hid_dim)(features)
        features = nn.tanh(nn.LayerNorm()(features))
        # Linear read-out into the embedding space.
        return nn.Dense(self.emb_dim)(features)


class ContrastiveEncoder:
    """Contrastively trained observation encoder plus a kNN-similarity reward.

    Wraps an ``Encoder`` and its Flax ``TrainState`` (Adam optimiser).
    ``update`` takes one gradient step of a contrastive loss over
    anchor/positive pairs. ``compute_reward`` scores query observations by the
    cosine similarity to the k-th most similar embedding in each of
    ``ensemble_num`` memory groups of ``memory_size`` entries.

    Both jitted methods mark ``self`` as static. Every attribute they read is
    therefore baked into the compiled function, and mutating those attributes
    after the first call is not seen by the cached compilation. ``self.state``
    is instead passed in as an explicit argument.
    """

    def __init__(self,
                 obs_dim: int,
                 emb_dim: int = 16,
                 ensemble_num: int = 4,
                 memory_size: int = 200,
                 seed: int = 0,
                 lr: float = 3e-4,
                 temperature: float = 0.1,
                 k: int = 10,
                 contrast_batch_size: int = 512):
        """Build the encoder and initialise its parameters and optimiser.

        Args:
            obs_dim: size of the last axis of the observations fed to the
                encoder (used for the dummy init input).
            emb_dim: embedding size produced by the encoder.
            ensemble_num: number of memory groups compared against in
                ``reward``.
            memory_size: entries per memory group in ``reward``.
            seed: PRNG seed for parameter initialisation.
            lr: Adam learning rate.
            temperature: divisor applied to cosine similarities in the loss.
            k: which neighbour (k-th most similar) ``reward`` uses.
            contrast_batch_size: N, the number of anchor/positive pairs per
                ``train`` call. Rows [0, N) are anchors and rows [N, 2N)
                are positives.
        """
        self.k = k
        rng = jax.random.PRNGKey(seed)
        self.ensemble_num = ensemble_num
        self.memory_size = memory_size
        self.temperature = temperature
        self.encoder = Encoder(emb_dim=emb_dim)
        dummy_obs = jnp.ones((1, obs_dim), dtype=jnp.float32)
        params = self.encoder.init(rng, dummy_obs)["params"]
        self.state = train_state.TrainState.create(apply_fn=self.encoder.apply,
                                                   params=params,
                                                   tx=optax.adam(lr))
        # Row i of the (N, N) similarity matrix has its positive at column i.
        self.mask = jnp.arange(contrast_batch_size)
        # (E, E): 0 on the diagonal (own group), 1 elsewhere (other groups).
        self.ensemble_mask = jnp.ones(ensemble_num) - jnp.eye(ensemble_num)
        # (N, 1) gather indices for anchors [0, N) and positives [N, 2N).
        self.anchor_idx = jnp.arange(contrast_batch_size).reshape(-1, 1)
        self.positive_idx = jnp.arange(contrast_batch_size, 2*contrast_batch_size).reshape(-1, 1)

    @functools.partial(jax.jit, static_argnames=("self",))
    def reward(self, encoder_state, Xs, Ys):
        """Return the negated k-th-neighbour cosine similarity per query.

        Args:
            encoder_state: ``TrainState`` whose ``params`` drive the encoder.
            Xs: query observations. The nested vmap requires their embeddings
                to be rank 3. Presumably the shape is (ensemble_num, B,
                obs_dim); TODO confirm against caller.
            Ys: memory observations, rank 2 (N, obs_dim). The reshape below
                requires N == ensemble_num * memory_size, read as
                ``ensemble_num`` consecutive groups of ``memory_size`` rows.

        Returns:
            Array of shape (ensemble_num, -1), presumably (ensemble_num, B).
            It is the average of an "own group" term and an "other groups"
            term (the latter averaged over the other groups). Each term is a
            negated top-k similarity.
        """
        cosine_fn = jax.vmap(jax.vmap(optax.cosine_similarity, in_axes=(0, None)),
                             in_axes=(0, None))

        # Assuming the shapes above: (E, B, D) and (E*M, D).
        emb_Xs = self.encoder.apply({"params": encoder_state.params}, Xs)
        emb_Ys = self.encoder.apply({"params": encoder_state.params}, Ys)

        # (E, B, E*M) -> (E, B, E, M). The last two axes are
        # (memory group, entry within group).
        similarities = cosine_fn(emb_Xs,
                                 emb_Ys).reshape(self.ensemble_num, -1,
                                                 self.ensemble_num,
                                                 self.memory_size)

        # (E, B, E, k). approx_max_k may return an approximate top-k
        # (see its recall_target argument).
        top_k_similarities, _ = jax.lax.approx_max_k(similarities, k=self.k)

        # (E, B, E): last of the k returned values (the k-th largest when
        # the output is sorted descending).
        top_k_similarity = top_k_similarities[..., -1]

        def mask_fn(sim, mask):
            # mask is 0 at the query's own group and 1 at every other group.
            self_sim = sim * (1 - mask)
            other_sim = sim * mask
            return self_sim, other_sim

        # The outer vmap pairs query group i with ensemble_mask[i]. The inner
        # vmap broadcasts that row over the B queries. Both outputs are
        # (E, B, E).
        self_similarity, other_similarity = jax.vmap(
            jax.vmap(mask_fn, in_axes=(0, None)), in_axes=(0, 0))(
                top_k_similarity, self.ensemble_mask)

        # (E, B, E) -> (E, B)
        self_reward = -self_similarity.sum(-1)
        # With a single member there are no other groups: ensemble_mask is
        # all zeros, so the masked sum is 0. Dividing by ensemble_num - 1 == 0
        # would turn every reward into NaN, so clamp the divisor to 1.
        num_other = max(self.ensemble_num - 1, 1)
        other_reward = -other_similarity.sum(-1) / num_other
        reward = (self_reward + other_reward) / 2
        return reward

    def compute_reward(self, Xs, Ys):
        """Evaluate ``reward`` with the current encoder parameters."""
        return self.reward(self.state, Xs, Ys)

    @functools.partial(jax.jit, static_argnames=("self",))
    def train(self,
              observations: jnp.ndarray,
              encoder_state: train_state.TrainState):
        """Take one contrastive gradient step.

        Rows [0, N) of ``observations`` are anchors and rows [N, 2N) are
        their positives, with N = ``contrast_batch_size``. Anchor i is paired
        with row N + i, and every other positive acts as a negative for it.
        Rows beyond 2N are ignored.

        Returns:
            (new_encoder_state, log_info). log_info holds the mean loss, the
            mean cosine similarity over all anchor/positive combinations, and
            the mean similarity of matching pairs. Both similarities are
            un-scaled by temperature.

        Raises:
            ValueError: if fewer than 2N observations are given. Without this
                check the fixed gather indices would silently read past the
                batch instead of failing.
        """
        # Shapes are static under jit, so this check runs at trace time.
        num_pairs = self.anchor_idx.shape[0]
        if observations.shape[0] < 2 * num_pairs:
            raise ValueError(
                f"train() needs at least {2 * num_pairs} observations "
                f"({num_pairs} anchors followed by {num_pairs} positives), "
                f"got {observations.shape[0]}")

        def loss_fn(params: FrozenDict):
            embeddings = self.encoder.apply({"params": params}, observations)

            # (N, emb_dim) each: embeddings[:N] and embeddings[N:2N].
            emb_anchor = jnp.take_along_axis(embeddings, self.anchor_idx, axis=0)
            emb_positive = jnp.take_along_axis(embeddings, self.positive_idx, axis=0)
            # (N, N): row i holds anchor i against every positive.
            cos_sim = jax.vmap(optax.cosine_similarity, in_axes=(0, None))(
                emb_anchor, emb_positive) / self.temperature

            def infoNCE(similarity, i):
                pos_sim = similarity[i]
                # The positive is masked out of the logsumexp, so the
                # denominator holds only negatives. Textbook InfoNCE
                # includes the positive there.
                neg_sim = similarity.at[i].set(-9e12)
                loss = -pos_sim + jax.scipy.special.logsumexp(neg_sim)
                return loss, pos_sim
            loss, pos_sim = jax.vmap(infoNCE, in_axes=(0, 0))(cos_sim, self.mask)
            loss = loss.mean()
            return loss, {
                "encoder_loss": loss,
                "cos_sim": cos_sim.mean() * self.temperature,
                "pos_sim": pos_sim.mean() * self.temperature
            }

        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (_, log_info), grads = grad_fn(encoder_state.params)
        new_encoder_state = encoder_state.apply_gradients(grads=grads)
        return new_encoder_state, log_info

    def update(self, observations):
        """Run ``train`` on the stored state, keep the result, return logs."""
        self.state, log_info = self.train(observations, self.state)
        return log_info
