from functools import partial
import jax.numpy as jnp
from flax import linen as nn
from flax.linen.initializers import zeros, constant, orthogonal
import jax
from typing import Sequence


class MPO_map(nn.Module):
    """Flax module producing three scalar terms ``(out_sft, out_po, out_g)``.

    Every learned term follows the same recipe: a ``SplitDense`` projection
    of the input (plus, when ``temporally_aware`` is set, a second
    ``SplitDense`` projection of ``input * time``) is passed through
    ``MonotonicActivation``, reduced by a bias-free ``nn.Dense(1)`` head and
    averaged.

    Attributes:
        num_hidden_units: Width handed to every ``SplitDense`` sub-layer.
        temporally_aware: Add the ``*_time`` projection of ``pi_in * time``
            to each learned term.
        parametrised_reward_model: Learn the g term; when False the g term
            is ``mean(log_sigmoid(pi))`` with no parameters.
        add_logsimoid_bias: Only read when ``parametrised_reward_model`` is
            True. Zero-initialises the g head and adds a one-initialised
            ``Dense(1)`` applied to ``log_sigmoid(pi)``.
        add_sft_bias: Only read when ``sft_term`` is True. Zero-initialises
            the SFT head and adds a one-initialised ``Dense(1)`` applied to
            ``log(clip(pi, 1e-7, 0.9999))``.
        add_dpo_bias: Zero-initialises the PO head and adds a
            one-initialised ``Dense(1)`` applied to
            ``log(p) - log(1 - p)`` with ``p = clip(pi, 1e-7, 0.9999)``.
        sft_term: When False no SFT layers are created and ``out_sft`` is
            the constant ``0.0``.
    """

    num_hidden_units: int = 128
    temporally_aware: bool = False
    parametrised_reward_model: bool = False
    add_logsimoid_bias: bool = False
    add_sft_bias: bool = False
    add_dpo_bias: bool = False
    sft_term: bool = True

    def setup(self):
        """Create the sub-layers selected by the configuration flags."""
        self.activation = MonotonicActivation()
        # One SplitDense branch is created per activation function.
        n_branches = len(self.activation.activations)

        def split_dense():
            return SplitDense(self.num_hidden_units, n_dense=n_branches)

        def scalar_head(fill=None):
            # fill=None keeps the default Dense kernel initialiser; otherwise
            # the kernel starts as the given constant.
            if fill is None:
                return nn.Dense(1, use_bias=False)
            return nn.Dense(1, kernel_init=constant(fill), use_bias=False)

        # PO term
        self.po_const = split_dense()
        self.dense_po = scalar_head(0 if self.add_dpo_bias else None)
        if self.add_dpo_bias:
            self.init_po = scalar_head(1)
        if self.temporally_aware:
            self.po_time = split_dense()

        # SFT term
        if self.sft_term:
            self.sft_const = split_dense()
            self.dense_sft = scalar_head(0 if self.add_sft_bias else None)
            if self.add_sft_bias:
                self.init_sft = scalar_head(1)
            if self.temporally_aware:
                self.sft_time = split_dense()

        # g term
        if self.parametrised_reward_model:
            self.g_const = split_dense()
            if self.temporally_aware:
                self.g_time = split_dense()
            self.dense_g = scalar_head(0 if self.add_logsimoid_bias else None)
            if self.add_logsimoid_bias:
                self.dense_logsigmoid = scalar_head(1)

    def __call__(self, pi_in, time=0):
        """Evaluate the three terms.

        Args:
            pi_in: Input value; it is wrapped as ``jnp.array([pi_in])``
                (presumably a scalar -- TODO confirm against callers).
            time: Multiplier applied to ``pi_in`` for the ``*_time`` layers;
                only has an effect when ``temporally_aware`` is True.

        Returns:
            Tuple ``(out_sft, out_po, out_g)``.
        """
        pi = jnp.array([pi_in])
        pi_time = jnp.array([pi_in * time])

        def hidden(const_layer, time_layer_name):
            # The time layer only exists when temporally_aware is set, so it
            # is looked up by name lazily.
            pre_activation = const_layer(pi)
            if self.temporally_aware:
                time_layer = getattr(self, time_layer_name)
                pre_activation = pre_activation + time_layer(pi_time)
            return self.activation(pre_activation)

        # SFT term
        out_sft = 0.0
        if self.sft_term:
            out_sft = self.dense_sft(hidden(self.sft_const, "sft_time")).mean()
            if self.add_sft_bias:
                log_pi = jnp.log(jnp.clip(pi, 1e-7, 0.9999))
                out_sft = out_sft + self.init_sft(log_pi).sum()

        # PO term
        out_po = self.dense_po(hidden(self.po_const, "po_time")).mean()
        if self.add_dpo_bias:
            clipped = jnp.clip(pi, 1e-7, 0.9999)
            logit = jnp.log(clipped) - jnp.log(1 - clipped)
            out_po = out_po + self.init_po(logit).sum()

        # g term
        if not self.parametrised_reward_model:
            out_g = nn.log_sigmoid(pi).mean()
        else:
            out_g = self.dense_g(hidden(self.g_const, "g_time")).mean()
            if self.add_logsimoid_bias:
                out_g = out_g + self.dense_logsigmoid(nn.log_sigmoid(pi)).mean()

        return out_sft, out_po, out_g


class MonotonicActivation(nn.Module):
    """Applies a fixed bank of nine element-wise functions, one per input row.

    Row ``i`` of the input is passed through the ``i``-th function and the
    results are concatenated into a flat vector. The module has no
    parameters.
    """

    def setup(self):
        """Build the ordered list of activation functions."""

        def identity(pi):
            return pi

        def cube(pi):
            return pi**3

        def relu_square(pi):
            return nn.relu(pi) ** 2

        def relu_sqrt(pi):
            return nn.relu(pi) ** (1 / 2)

        def relu_cbrt(pi):
            return nn.relu(pi) ** (1 / 3)

        def scaled_exp(pi):
            # Input is capped at 1; output is 0 at pi=0 and 1 for pi>=1.
            return (jnp.exp(jnp.minimum(pi, 1)) - 1) / (jnp.exp(1) - 1)

        def scaled_log(pi):
            # Input is floored at 1e-4; output is 0 there and 1 at pi=1.
            shifted = jnp.log(jnp.maximum(pi, 1e-4)) - jnp.log(1e-4)
            return shifted / (jnp.log(1) - jnp.log(1e-4))

        def scaled_tanh(pi):
            # tanh centred on 0.5, rescaled to be 0 at pi=0 and 1 at pi=1.
            shifted = jnp.tanh((pi - 0.5) * 4) - jnp.tanh(-2)
            return shifted / (jnp.tanh(2) - jnp.tanh(-2))

        def scaled_logit(pi):
            # pi is capped at 100/101 (odds of 100) and the odds are floored
            # at 1e-4 before taking the log.
            odds = jnp.maximum(
                jnp.minimum(pi, 100 / 101) / (1 - jnp.minimum(pi, 100 / 101)),
                1e-4,
            )
            return (jnp.log(odds) - jnp.log(1e-4)) / (jnp.log(100) - jnp.log(1e-4))

        self.activations = [
            identity,
            cube,
            relu_square,
            relu_sqrt,
            relu_cbrt,
            scaled_exp,
            scaled_log,
            scaled_tanh,
            scaled_logit,
        ]

    @nn.compact
    def __call__(self, input):
        """Apply function ``i`` to ``input[i]`` and flatten the stacked result."""
        per_row = [fn(input[row]) for row, fn in enumerate(self.activations)]
        return jnp.array(per_row).flatten()


class SplitDense(nn.Module):
    """A bank of ``n_dense`` independent ``nn.Dense`` layers on one input.

    Attributes:
        num_hidden_units: Total width; each layer gets
            ``num_hidden_units // n_dense`` output features.
        kernel_init: Kernel initialiser shared by all layers.
        use_bias: Whether each layer has a bias.
        n_dense: Number of parallel layers.
    """

    num_hidden_units: int = 128
    kernel_init: nn.initializers.Initializer = orthogonal(jnp.sqrt(2))
    use_bias: bool = True
    n_dense: int = 9

    def setup(self):
        """Instantiate the parallel dense layers."""
        # Integer division: any remainder of num_hidden_units is dropped.
        branch_width = self.num_hidden_units // self.n_dense
        branches = []
        for _ in range(self.n_dense):
            branches.append(
                nn.Dense(
                    branch_width,
                    kernel_init=self.kernel_init,
                    use_bias=self.use_bias,
                )
            )
        self.mono_layers = branches

    @nn.compact
    def __call__(self, input):
        """Return the layer outputs stacked along a new leading axis."""
        branch_outputs = [layer(input) for layer in self.mono_layers]
        return jnp.array(branch_outputs)


def param_tuner(params, single=True, floor=0):
    """Clamp the kernels of the known layers in ``params`` and return it.

    Const layers and dense heads are processed first: every leaf whose path
    contains ``"kernel"`` is clamped from below (at ``floor`` when ``single``
    is True, at 0 otherwise). Each ``*_time`` layer is then mapped together
    with its matching, already clamped ``*_const`` layer, so that time
    kernels are bounded below by the negated const kernels and time biases
    are zeroed. This presumably keeps ``const + t * time`` non-negative for
    ``t`` in ``[0, 1]`` -- confirm against the training code.

    Args:
        params: Mutable mapping from layer name to that layer's parameter
            tree. Layers not present are skipped. It is modified in place.
        single: True for a single parameter set. False selects the tuners
            that ``vmap`` over three leading axes of every leaf.
        floor: Lower bound for const kernels; only used when ``single`` is
            True.

    Returns:
        The same ``params`` object, after modification.

    Raises:
        KeyError: If a ``*_time`` layer is present without its ``*_const``
            counterpart.
    """
    # The const tuners take (path, leaf); the time tuners take
    # (path, const_leaf, time_leaf) and therefore need two trees.
    if single:
        tuner_const = partial(param_tuner_const_single, floor=floor)
        tuner_time = param_tuner_time_single
    else:
        tuner_const = param_tuner_const
        tuner_time = param_tuner_time

    const_layers = [
        "sft_const",
        "po_const",
        "g_const",
        "dense_sft",
        "dense_po",
        "dense_g",
    ]
    # Each time layer is constrained relative to the const layer of its term.
    time_to_const = {
        "sft_time": "sft_const",
        "po_time": "po_const",
        "g_time": "g_const",
    }

    for layer in const_layers:
        if layer in params:
            params[layer] = jax.tree_util.tree_map_with_path(tuner_const, params[layer])
    for time_layer, const_layer in time_to_const.items():
        if time_layer in params:
            params[time_layer] = jax.tree_util.tree_map_with_path(
                tuner_time, params[const_layer], params[time_layer]
            )
    return params


def param_tuner_time(path, value_const, value_time):
    """Clamp a time leaf against its const leaf, mapped over three axes.

    Args:
        path: Sequence of path entries exposing a ``.key`` attribute.
        value_const: Const-layer leaf; needs at least three axes.
        value_time: Time-layer leaf with matching leading axes.

    Returns:
        ``maximum(value_time, -value_const)`` when ``"kernel"`` is one of
        the path keys, otherwise ``value_time * 0``.
    """
    path_names = [entry.key for entry in path]

    if "kernel" in path_names:

        def elementwise(const_leaf, time_leaf):
            return jnp.maximum(time_leaf, -const_leaf)

    else:

        def elementwise(const_leaf, time_leaf):
            return time_leaf * 0

    # Wrap in vmap three times, one per leading axis.
    batched = elementwise
    for _ in range(3):
        batched = jax.vmap(batched)
    return batched(value_const, value_time)


def param_tuner_time_single(path, value_const, value_time):
    """Clamp one time leaf against its const leaf (no batching).

    Args:
        path: Sequence of path entries exposing a ``.key`` attribute.
        value_const: Const-layer leaf.
        value_time: Time-layer leaf.

    Returns:
        ``maximum(value_time, -relu(value_const))`` when ``"kernel"`` is one
        of the path keys, otherwise ``value_time * 0``.
    """
    path_names = [entry.key for entry in path]
    if "kernel" in path_names:
        lower_bound = -nn.relu(value_const)
        return jnp.maximum(value_time, lower_bound)
    return value_time * 0


def param_tuner_const(path, values):
    """Apply ``relu`` to kernel leaves, mapped over three leading axes.

    Args:
        path: Sequence of path entries exposing a ``.key`` attribute.
        values: Leaf array; needs at least three axes.

    Returns:
        ``relu(values)`` when ``"kernel"`` is one of the path keys,
        otherwise the values unchanged.
    """
    path_names = [entry.key for entry in path]

    if "kernel" in path_names:

        def elementwise(leaf):
            return nn.relu(leaf)

    else:

        def elementwise(leaf):
            return leaf

    # Wrap in vmap three times, one per leading axis.
    batched = elementwise
    for _ in range(3):
        batched = jax.vmap(batched)
    return batched(values)


def param_tuner_const_single(path, values, floor=0):
    """Clamp one kernel leaf from below (no batching).

    Args:
        path: Sequence of path entries exposing a ``.key`` attribute.
        values: Leaf array.
        floor: Lower bound applied to kernel leaves.

    Returns:
        ``maximum(values, floor)`` when ``"kernel"`` is one of the path
        keys, otherwise ``values`` untouched.
    """
    path_names = [entry.key for entry in path]
    if "kernel" not in path_names:
        return values
    return jnp.maximum(values, floor)
