"""
Copyright (c) ANONYMOUS
All rights reserved.

MIT License

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""

import chex
import haiku as hk
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import optax
from flax import struct

from ccoa.utils.utils import flatcat


@struct.dataclass
class RewardFeaturesState:
    """Parameters and optimizer state of a ``RewardFeatures`` model.

    Created by ``RewardFeatures.reset`` and returned (updated) by ``RewardFeatures.update``.
    """

    backbone: hk.Params  # parameters of the feature backbone network
    optim: optax.OptState  # optimizer state for the [backbone, readout] parameters jointly
    readout: hk.Params  # stacked per-action readout parameters (leading axis indexes the action)


class RewardFeatures:
    """
    Learns reward-predictive features of (observation, action) pairs.

    A backbone network maps an (observation, action) pair to a feature vector, and a
    separate linear readout per action predicts the reward from those features. In
    classification mode (``use_mse=False``) the readout emits one logit per entry of
    ``reward_values`` and is trained with cross-entropy; in regression mode
    (``use_mse=True``) it emits a single scalar trained with an L2 loss.

    Args:
        num_actions: number of discrete actions; one readout is created per action.
        backbone: function building an hk.Module that maps (observation, action) -> feature
            without batching; it is wrapped with ``hk.transform``.
        optimizer: optax optimizer applied jointly to ``[backbone_params, readout_params]``.
        steps: number of gradient steps performed per call to ``update``.
        reward_values: sequence of the discrete reward values that can occur; the index of a
            reward in this sequence is its class label.
        l1_reg_activation: weight of the L1 penalty on backbone features (disabled if <= 0).
        l2_reg_activation: weight of the L2 penalty on backbone features (disabled if <= 0).
        l1_reg_params: step size of the sign-based L1 shrinkage applied to the backbone
            parameters after every optimizer step (disabled if <= 0).
        l2_reg_readout: multiplicative decay applied to the readout parameters after every
            optimizer step (disabled if <= 0).
        balance_loss: if True, re-weight per-example losses by the inverse frequency of their
            reward value within the sampled batch.
        mask_zero_reward_loss: if True, only examples with a non-zero reward contribute to
            the loss.
        use_mse: if True, regress the reward with an L2 loss instead of classifying it.
    """

    def __init__(
        self,
        num_actions,
        backbone,
        optimizer,
        steps,
        reward_values,
        l1_reg_activation,
        l2_reg_activation,
        l1_reg_params,
        l2_reg_readout,
        balance_loss,
        mask_zero_reward_loss,
        use_mse,
    ) -> None:
        self.num_actions = num_actions
        self.backbone = hk.transform(backbone)
        # Regression predicts one scalar. Classification predicts one logit per reward class:
        # labels are indices into ``reward_values`` (see ``loss_fn``), so the width must be
        # ``len(reward_values)``, not the number of actions (each action already has its own
        # readout, see ``reset``).
        num_outputs = 1 if use_mse else len(reward_values)
        self.readout = hk.without_apply_rng(hk.transform(lambda x: hk.Linear(num_outputs)(x)))
        self.optimizer = optimizer
        self.steps = steps
        self.reward_values = reward_values
        self.l1_reg_activation = l1_reg_activation
        self.l2_reg_activation = l2_reg_activation
        self.l1_reg_params = l1_reg_params
        self.l2_reg_readout = l2_reg_readout
        self.balance_loss = balance_loss
        self.mask_zero_reward_loss = mask_zero_reward_loss
        self.use_mse = use_mse

    def __call__(self, rng: chex.PRNGKey, state: RewardFeaturesState, observation, action):
        """Return the binarized backbone features: 1. where a feature is > 0, else 0."""
        return (self.backbone.apply(state.backbone, rng, observation, action) > 0) * 1.

    def reset(self, rng: chex.PRNGKey, dummy_observation: jnp.ndarray):
        """
        Initialize backbone, per-action readouts and optimizer state.

        Args:
            rng: PRNG key used for parameter initialization.
            dummy_observation: example (unbatched) observation used to infer shapes; the
                backbone is traced with action ``0``.

        Returns:
            A fresh ``RewardFeaturesState``.
        """
        rng_backbone, rng_readout = jax.random.split(rng)
        params_backbone = self.backbone.init(rng_backbone, dummy_observation, 0)
        dummy_features = self.backbone.apply(params_backbone, rng_backbone, dummy_observation, 0)

        # One readout per action: vmap the init so every parameter gets a leading action axis
        rngs_readout = jax.random.split(rng_readout, self.num_actions)
        params_readout = jax.vmap(self.readout.init, in_axes=(0, None))(
            rngs_readout, dummy_features
        )

        optim = self.optimizer.init([params_backbone, params_readout])

        return RewardFeaturesState(backbone=params_backbone, optim=optim, readout=params_readout)

    def update(
        self, rng: chex.PRNGKey, state: RewardFeaturesState, batch_sampler
    ) -> "tuple[RewardFeaturesState, dict]":
        """
        Run ``self.steps`` gradient steps on batches drawn from ``batch_sampler``.

        Args:
            rng: PRNG key; split into one key per step.
            state: current parameters and optimizer state.
            batch_sampler: callable mapping a PRNG key to a batch exposing ``observations``,
                ``rewards`` and ``actions`` arrays with leading (batch_size, num_steps) axes.

        Returns:
            ``(state, metrics_summary)``: the updated state and the metrics of the FIRST
            update step, with every key suffixed by ``"_start"``.
        """

        def batch_loss(rng, params, observations, rewards, actions):
            batch_size, num_steps = observations.shape[0], observations.shape[1]
            rngs = jax.random.split(rng, batch_size * num_steps).reshape(
                batch_size, num_steps, -1
            )
            # vmap over both batch_size and steps in trajectory
            loss_fn_batched = jax.vmap(
                jax.vmap(self.loss_fn, in_axes=(0, None, None, 0, 0, 0)),
                in_axes=(0, None, None, 0, 0, 0),
            )

            loss, metrics = loss_fn_batched(
                rngs, params[0], params[1], observations, rewards, actions
            )
            metrics = jtu.tree_map(lambda x: jnp.mean(x, axis=(0, 1)), metrics)

            for value in self.reward_values:
                num_class = metrics["num_class_{}".format(value)]
                has_class = num_class > 0
                # Double-where: the inner where keeps the untaken branch finite (no 0/0),
                # so absent classes never produce NaNs (e.g. under jax_debug_nans).
                metrics["acc_{}".format(value)] = jnp.where(
                    has_class,
                    metrics["true_pos_{}".format(value)] / jnp.where(has_class, num_class, 1.0),
                    0,
                )

            if self.balance_loss:
                # Re-weigh the loss by the frequency of unique values to alleviate imbalance
                _, unique_inv, unique_counts = jnp.unique(
                    rewards.reshape(-1), return_inverse=True, return_counts=True, size=len(self.reward_values)
                )
                loss = loss / (unique_counts / unique_counts.sum())[unique_inv].reshape(loss.shape)

            if self.mask_zero_reward_loss:
                # Only consider non-zero rewards
                reward_mask = rewards != 0

                # If the whole batch has no non-zero reward, avoid dividing by zero
                normalizer = jnp.maximum(jnp.sum(reward_mask), 1)

                loss = jnp.sum(loss * reward_mask, axis=(0, 1)) / normalizer
            else:
                loss = jnp.mean(loss, axis=(0, 1))

            return loss, metrics

        def update_step(carry, rng_t):
            params, optim = carry
            rng_loss, rng_sample = jax.random.split(rng_t)

            # Sample a batch of trajectories from the replay buffer
            batch_trajectory = batch_sampler(rng_sample)
            observations = batch_trajectory.observations
            rewards = batch_trajectory.rewards
            actions = batch_trajectory.actions

            # Compute loss
            (loss, metrics), grads = jax.value_and_grad(batch_loss, argnums=1, has_aux=True)(
                rng_loss, params, observations, rewards, actions
            )

            # Update params
            params_update, optim = self.optimizer.update(grads, optim, params)
            next_params = optax.apply_updates(params, params_update)

            params_backbone, params_readout = next_params[0], next_params[1]

            if self.l1_reg_params > 0:
                # Sign-based shrinkage: a subgradient step on l1_reg_params * |p|
                params_backbone = jtu.tree_map(
                    lambda p: p - self.l1_reg_params * jnp.sign(p), params_backbone
                )

            if self.l2_reg_readout > 0:
                # Multiplicative weight decay on the readout
                params_readout = jtu.tree_map(
                    lambda p: p - self.l2_reg_readout * p, params_readout
                )

            metrics = {"loss": loss, "gradnorm": optax.global_norm(grads),
                       "readout_norm": jnp.sqrt(jnp.sum(flatcat(params_readout) ** 2)),
                       "backbone_norm": jnp.sqrt(jnp.sum(flatcat(params_backbone) ** 2)),
                       **metrics}

            return [[params_backbone, params_readout], optim], metrics

        carry, metrics = jax.lax.scan(
            f=update_step,
            init=[[state.backbone, state.readout], state.optim],
            xs=jax.random.split(rng, self.steps),
        )

        # ``scan`` stacks metrics over steps; only report the FIRST step (suffix "_start").
        # Final-step metrics would be ``v[-1]`` if ever needed.
        metrics_summary = {k + "_start": v[0] for k, v in metrics.items()}

        params, optim = carry
        state = RewardFeaturesState(backbone=params[0], optim=optim, readout=params[1])

        return state, metrics_summary

    def loss_fn(
        self,
        rng: chex.PRNGKey,
        params_backbone: hk.Params,
        params_readout: hk.Params,
        observation: jnp.ndarray,
        reward: jnp.ndarray,
        action: jnp.ndarray,
    ):
        """
        Loss is defined for a single example (no batching over batch_size or num_steps).

        Returns:
            ``(loss, metrics)``: the scalar loss (including activation regularizers) and a
            dict of per-example diagnostics. For each reward value ``v``:
            ``num_class_v`` is 1. if this example's label is ``v`` (else 0.),
            ``nonzero_v`` is the mean of ``features > 0`` over the last axis for such
            examples, and ``true_pos_v`` is the correct-prediction indicator
            (classification) or the L2 loss (regression) for such examples.
        """
        metrics = dict()
        features = self.backbone.apply(params_backbone, rng, observation, action)

        # Select readout params according to action
        params_readout_action = jtu.tree_map(lambda p: p[action], params_readout)
        logit = self.readout.apply(params_readout_action, features)

        # Class label = index of the reward in ``reward_values``. NOTE: a reward that is not
        # in ``reward_values`` maps to label 0, because ``jnp.nonzero`` pads with 0 by default.
        label = jnp.nonzero(reward == jnp.array(self.reward_values), size=1)[0].squeeze()

        if self.use_mse:
            loss = optax.l2_loss(logit.squeeze(), reward)
            metrics["loss_l2"] = loss
        else:
            # Cross-entropy with the reward-class index as integer label
            loss = optax.softmax_cross_entropy_with_integer_labels(logit, label)
            metrics["loss_xent"] = loss
            prediction = jnp.argmax(logit, axis=-1)

        fraction_active = jnp.mean(features > 0, axis=-1)
        for i, value in enumerate(self.reward_values):
            is_class = label == i
            metrics["nonzero_{}".format(value)] = (fraction_active * is_class * 1.).sum()
            if self.use_mse:
                # Regression has no "correct" prediction; report the per-class L2 loss instead
                metrics["true_pos_{}".format(value)] = loss * is_class
            else:
                metrics["true_pos_{}".format(value)] = ((prediction == i) * is_class * 1.).sum()
            metrics["num_class_{}".format(value)] = (is_class * 1.).sum()

        if self.l2_reg_activation > 0:
            loss_l2_actv = jnp.mean(features**2)
            loss += self.l2_reg_activation * loss_l2_actv
            metrics["loss_l2_actv"] = loss_l2_actv

        if self.l1_reg_activation > 0:
            loss_l1_actv = jnp.mean(jnp.abs(features))
            loss += self.l1_reg_activation * loss_l1_actv
            metrics["loss_l1_actv"] = loss_l1_actv

        return loss, metrics
