"""
Copyright (c) ANONYMOUS
All rights reserved.

MIT License

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""

import os
import random
from functools import partial
from pathlib import Path

import chex
import haiku as hk
import jax
import jax.numpy as jnp
import ml_collections as mlc
import optax
from absl import app, flags
from ml_collections import config_flags

import wandb
from callback import (make_contribution_bias_callback,
                      make_policy_var_callback, performance_conveyor,
                      performance_multiple_conveyor, visualize_contribution, get_track_reward_feature)
from ccoa import agents, envs, networks
from ccoa.accumulator import ReplayBuffer, MultiReplayBuffer
from ccoa.contribution import (MDP_DP, causal, features, parallel, qnet,
                               reinforce, value)
from ccoa.envs.treasure_conveyor import ConveyorTreasure
from ccoa.experiment import Event, Experiment
from ccoa.networks.utils import relu_gate

# Handle to absl's global flag container; main() reads the config, logdir and
# log_level flags through it.
FLAGS = flags.FLAGS

# Import ml_collections flags defined in configs/
# Registers the `config` flag, pointing at configs/default.py unless overridden.
config_flags.DEFINE_config_file(
    name="config",
    default="configs/default.py",
    help_string="Training configuration.",
)

# Expose jax flags, allows for example to disable jit with --jax_disable_jit
jax.config.parse_flags_with_absl()

# Additional abseil flags
# NOTE(review): `checkpoint` is defined here but is not read anywhere in this
# file as shown -- confirm whether it is consumed elsewhere.
flags.DEFINE_bool("checkpoint", False, "Whether to save the checkpoint of trained models.")
# Forwarded by main() to run(), which computes the MDP when the level is > 2
# and passes the value on to Experiment together with per-callback levels.
flags.DEFINE_integer("log_level", 2, "Logging level.")
# Forwarded by main() to run(), which hands it to Experiment as `logdir`.
flags.DEFINE_string("logdir", "logs", "Directory where logs are saved.")


def load_from_cache_or_compute(file_name, fct):
    """Return the pickled result stored at ``file_name``, computing it on a cache miss.

    If ``file_name`` exists as a regular file, its content is unpickled and
    returned without calling ``fct``. Otherwise ``fct()`` is called, its result
    is pickled to ``file_name`` (parent directories are created as needed) and
    then returned.

    Args:
        file_name: Path of the cache file. It is taken literally, so characters
            such as ``[``, ``*`` or ``?`` are not interpreted as glob patterns.
        fct: Zero-argument callable producing the value to cache.

    Returns:
        The cached or freshly computed value.
    """
    import pickle

    cache_path = Path(file_name)

    if cache_path.is_file():
        # NOTE: unpickling can execute arbitrary code; only point this at cache
        # files that were written by this program.
        with open(cache_path, "rb") as cache_file:
            return pickle.load(cache_file)

    cache_path.parent.mkdir(parents=True, exist_ok=True)
    result = fct()
    # Context manager guarantees the data is flushed and the handle released
    # before the function returns.
    with open(cache_path, "wb") as cache_file:
        pickle.dump(result, cache_file)
    return result


def get_contribution(config, env, mdp):
    """Build the contribution module selected by ``config.contribution``.

    Supported values of ``config.contribution`` are ``"reinforce"``,
    ``"reinforce_gt"``, ``"qnet"``, ``"qnet_gt"``, ``"causal"``, ``"causal_gt"``
    and ``"parallel"``. For ``"parallel"`` the function recurses once per key
    listed in the comma-separated ``config.parallel_keys``, using the
    sub-config stored under that key.

    Args:
        config: Configuration object supporting attribute access and
            ``.get(key, default)``.
        env: Environment; ``num_actions``, ``observation_shape`` and
            ``reward_values`` are read from it.
        mdp: Passed to the tabular networks and to the ``*_gt`` contribution
            modules. May be ``None`` when none of those are selected.

    Returns:
        The constructed contribution object.

    Raises:
        ValueError: If ``config.contribution`` is unknown, or if
            ``config.hindsight_model_type`` is neither ``"mlp"`` nor
            ``"hypernet"`` when non-tabular models are built.
        AssertionError: If ``mdp`` is ``None`` although the selected
            configuration requires it.
    """

    if config.get("tabular_contribution", False):
        assert mdp is not None, "tabular_contribution requires an mdp"

        def action_value_model(x):
            return networks.ActionValueTabular(mdp=mdp)(x)

        def hindsight_model(x, y):
            return networks.HindsightTabular(
                mdp=mdp,
                reward_based=config.hindsight_feature_type == "reward_based",
                reward_values=env.reward_values,
            )(x, y)

        def value_model(x):
            return networks.ValueTabular(mdp=mdp)(x)

        # NOTE(review): no `feature_model` is defined in the tabular case, so
        # selecting hindsight_feature_type == "reward_predictor" together with
        # tabular_contribution fails with a NameError further down.

    else:

        def action_value_model(x):
            x = hk.Flatten(preserve_dims=1)(x)
            x = hk.nets.MLP(output_sizes=config.hidden_dim_qnet + (env.num_actions,))(x)
            return x

        use_binary_feature = config.get("use_binary_feature", False)

        def feature_model(x, action):
            # Operates on a single, unbatched 1-D input.
            assert len(x.shape) == 1

            # One gate matrix per action, initialised around the threshold.
            gate = hk.get_parameter(
                name="gate",
                shape=(env.num_actions, config.hidden_dim_features, *x.shape),
                init=hk.initializers.RandomNormal(mean=config.threshold_shift),
            )

            def hard_sigmoid_st(x):
                # Straight-through step: the forward value is the 0/1 indicator
                # (x > threshold_shift), the gradient is that of the identity.
                return x + jax.lax.stop_gradient((x > config.threshold_shift) * 1.0 - x)

            feature = hard_sigmoid_st(gate[action].reshape(config.hidden_dim_features, *x.shape)) @ x
            if use_binary_feature:
                feature = hard_sigmoid_st(feature)
            return feature

        def value_model(x):
            x = hk.Flatten(preserve_dims=1)(x)
            x = hk.nets.MLP(output_sizes=config.hidden_dim_value + (1,))(x)
            return x

        if config.hindsight_model_type == "mlp":

            def hindsight_model(x, y, z):
                # All three inputs are flattened and concatenated into one vector.
                x = jnp.concatenate([x.flatten(), y.flatten(), z.flatten()], axis=0)
                x = hk.nets.MLP(output_sizes=config.hidden_dim_hindsight + (env.num_actions,))(x)
                return x

        elif config.hindsight_model_type == "hypernet":

            def hindsight_model(observations, hindsight_objects, policy_logits):
                x = jnp.concatenate([observations.flatten(), hindsight_objects.flatten()], axis=0)
                z = jax.nn.relu(hk.Linear(256)(x))

                # presumably relu_gate halves the last dimension (the reshape
                # below only works out if it does) -- TODO confirm
                logits = relu_gate(hk.Linear(2 * env.num_actions)(z))
                gate = relu_gate(hk.Linear(2 * 2 * env.num_actions * env.num_actions)(z))

                policy_logit_features = jnp.concatenate([
                    policy_logits,
                    jnp.log(1 - jax.nn.softmax(policy_logits, axis=-1))
                ])

                # The hidden state produces a matrix that linearly mixes the
                # policy features into one logit offset per action.
                gate_proj = gate.reshape(env.num_actions, 2 * env.num_actions)
                gated_policy_logits = gate_proj @ policy_logit_features

                return logits + gated_policy_logits

        else:
            raise ValueError(
                'Hindsight model type "{}" undefined'.format(config.hindsight_model_type)
            )

    def build_value_module():
        # Value function shared by the "reinforce" (advantage) and "causal"
        # (baseline) configurations.
        return value.ValueFunction(
            model=value_model,
            optimizer=getattr(optax, config.optimizer_value)(config.lr_contrib),
            steps=config.steps_value,
            td_lambda=config.lambda_value,
        )

    def build_feature_module(**extra_kwargs):
        # Reward-feature module shared by "causal" and "causal_gt"; callers add
        # the keyword arguments that differ between the two.
        # NOTE(review): the "l2_reg_readout_feature" key is singular while its
        # siblings end in "_features" -- confirm against the config files.
        return features.RewardFeatures(
            num_actions=env.num_actions,
            backbone=feature_model,
            optimizer=getattr(optax, config.optimizer_features)(config.lr_features, weight_decay=0),
            steps=config.steps_features,
            reward_values=env.reward_values,
            l1_reg_activation=config.get("l1_reg_activation_features", 0.0),
            l2_reg_activation=config.get("l2_reg_activation_features", 0.0),
            l1_reg_params=config.get("l1_reg_params_features", 0.0),
            l2_reg_readout=config.get("l2_reg_readout_feature", 0.0),
            balance_loss=config.get("balance_loss_features", False),
            mask_zero_reward_loss=config.get("mask_zero_reward_loss_features", False),
            **extra_kwargs,
        )

    if config.contribution == "reinforce":
        if config.return_contribution == "advantage":
            value_module = build_value_module()
        else:
            value_module = None

        contribution = reinforce.ReinforceContribution(
            num_actions=env.num_actions,
            obs_shape=env.observation_shape,
            return_type=config.return_contribution,
            value_module=value_module,
        )
    elif config.contribution == "reinforce_gt":
        assert mdp is not None, '"reinforce_gt" requires an mdp'
        contribution = reinforce.ReinforceContributionGT(
            num_actions=env.num_actions,
            obs_shape=env.observation_shape,
            return_type=config.return_contribution,
            mdp=mdp,
        )
    elif config.contribution == "qnet":
        contribution = qnet.QNetContribution(
            num_actions=env.num_actions,
            obs_shape=env.observation_shape,
            action_value_model=action_value_model,
            action_value_optimizer=getattr(optax, config.optimizer_qnet)(config.lr_contrib),
            return_type=config.return_contribution,
            action_value_steps=config.steps_qnet,
            td_lambda=config.lambda_qnet,
        )
    elif config.contribution == "qnet_gt":
        assert mdp is not None, '"qnet_gt" requires an mdp'
        contribution = qnet.QNetContributionGT(
            num_actions=env.num_actions,
            obs_shape=env.observation_shape,
            return_type=config.return_contribution,
            mdp=mdp,
        )
    elif config.contribution == "causal":
        if config.hindsight_feature_type == "reward_predictor":
            feature_module = build_feature_module(
                use_mse=config.get("use_mse_feature", False),
            )
        else:
            feature_module = None

        if config.use_baseline:
            value_module = build_value_module()
        else:
            value_module = None

        contribution = causal.CausalContribution(
            num_actions=env.num_actions,
            obs_shape=env.observation_shape,
            return_type=config.return_contribution,
            hindsight_feature_type=config.hindsight_feature_type,
            reward_values=env.reward_values,
            reward_clusters=config.get("reward_clusters", None),
            g_trick=config.get("g_trick", config.hindsight_feature_type == "state_based"),
            use_baseline=config.use_baseline,
            hindsight_model=hindsight_model,
            hindsight_optimizer=getattr(optax, config.optimizer_hindsight)(config.lr_contrib),
            hindsight_steps=config.steps_hindsight,
            hindsight_loss_type=config.hindsight_loss_type,
            modulate_with_policy=config.get("policy_modulation", None),
            contribution_clip=config.get("contribution_clip", None),
            feature_module=feature_module,
            value_module=value_module,
            balance_loss=config.get("balance_loss", False),
            mask_zero_reward_loss=config.get("mask_zero_reward_loss", False),
            clip_contrastive=config.get("clip_contrastive", False),
            max_grad_norm=config.get("hindsight_max_grad_norm", None),
        )

    elif config.contribution == "causal_gt":

        if config.hindsight_feature_type == "reward_predictor":
            # Unlike the "causal" branch, `use_mse` is not forwarded here, so
            # RewardFeatures falls back to its own default.
            feature_module = build_feature_module()
        else:
            feature_module = None

        assert mdp is not None, '"causal_gt" requires an mdp'
        contribution = causal.CausalContributionGT(
            num_actions=env.num_actions,
            obs_shape=env.observation_shape,
            return_type=config.return_contribution,
            hindsight_feature_type=config.hindsight_feature_type,
            reward_values=env.reward_values,
            reward_clusters=config.get("reward_clusters", None),
            # Default mirrors the "causal" branch: "state_based" is a hindsight
            # feature type, so the test belongs on hindsight_feature_type and
            # not on return_contribution.
            g_trick=config.get("g_trick", config.hindsight_feature_type == "state_based"),
            use_baseline=config.use_baseline,
            mdp=mdp,
            feature_module=feature_module,
        )

    elif config.contribution == "parallel":
        # Each listed key names a sub-config describing one contribution module.
        contribution_dict = dict()
        for k in config.parallel_keys.split(","):
            contribution_dict[k] = get_contribution(getattr(config, k), env, mdp)

        contribution = parallel.ParallelContribution(
            contribution_dict, config.parallel_main_key, config.parallel_reset_before_update
        )
    else:
        raise ValueError('Contribution module "{}" undefined'.format(config.contribution))

    return contribution


def run(config, logger, logdir, log_level):
    """Build environment, agent, buffers and contribution module, then run the experiment.

    Args:
        config: Training configuration; wrapped in an ``mlc.ConfigDict``.
        logger: Logger object handed to ``Experiment`` (``main`` passes ``wandb``).
        logdir: Log directory handed to ``Experiment``.
        log_level: Logging level handed to ``Experiment``. A value above 2 also
            forces the MDP to be computed (or loaded from ``cache/``).

    When ``config.env_switch_episode`` is greater than 0, a second environment
    with negated ``reward_treasure`` is created and passed to ``Experiment`` as
    ``env_switched``; otherwise the same environment is used for both.
    """
    config = mlc.ConfigDict(config)

    switch_env = config.get("env_switch_episode", 0) > 0

    # HACK: remove reward_values and env_id from kwargs for env creation
    excluded_keys = {"reward_values", "env_id"}
    env_kwargs = {
        k: config.environment[k]
        for k in config.environment.keys()
        if k not in excluded_keys
    }
    env = envs.create(**env_kwargs)

    if switch_env:
        # Same environment, but with the sign of every treasure reward flipped.
        env_bis_kwargs = dict(env_kwargs)
        env_bis_kwargs["reward_treasure"] = [-r for r in env_bis_kwargs["reward_treasure"]]
        env_bis = envs.create(**env_bis_kwargs)
    else:
        env_bis = env

    # Both default to None so that later references are always bound, even when
    # the MDP is not computed but the environment switch is enabled.
    mdp = None
    mdp_bis = None
    if config.get("compute_mdp", False) or log_level > 2:
        mdp = load_from_cache_or_compute(
            os.path.join("cache", "mdp_" + config.environment["env_id"] + ".pkl"),
            partial(MDP_DP, env),
        )

        if switch_env:
            mdp_bis = load_from_cache_or_compute(
                os.path.join("cache", "mdp_bis_" + config.environment["env_id"] + ".pkl"),
                partial(MDP_DP, env_bis),
            )

    if config.get("tabular_agent", False):

        def policy_model(x):
            return networks.PolicyTabular(mdp=mdp)(x)

    else:

        def policy_model(x):
            chex.assert_equal(x.shape[1:], env.observation_shape)  # assumes a leading batch dim
            x = hk.Flatten(preserve_dims=1)(x)
            x = hk.nets.MLP(output_sizes=config.hidden_dim_agent + (env.num_actions,))(x)
            return x

    agent = agents.PolicyGradient(
        num_actions=env.num_actions,
        obs_shape=env.observation_shape,
        policy=policy_model,
        optimizer=getattr(optax, config.optimizer_agent)(config.lr_agent),
        loss_type=config.pg_loss,
        num_sample=config.pg_num_sample,
        entropy_reg=config.entropy_reg,
        max_grad_norm=config.get("pg_norm", None),
        epsilon=config.get("epsilon_exploration", 0),
        epsilon_at_eval=config.get("epsilon_exploration_eval", False),
    )

    buffer = ReplayBuffer(config.buffer_size)
    if config.get("multi_replay", False):
        offline_buffer = MultiReplayBuffer(config.offline_buffer_size, env.reward_values)
    else:
        offline_buffer = ReplayBuffer(config.offline_buffer_size)
    contribution = get_contribution(config, env, mdp)

    runner = Experiment(
        agent=agent,
        contribution=contribution,
        env=env,
        env_switched=env_bis,
        env_switch_episode=config.get("env_switch_episode", 0),
        buffer=buffer,
        offline_buffer=offline_buffer,
        num_episodes=config.num_episodes,
        max_trials=config.environment.length,
        batch_size=config.batch_size,
        offline_batch_size=config.offline_batch_size,
        burnin_episodes=config.get("burnin_episodes", 0),
        logger=logger,
        logdir=logdir,
        eval_interval_episodes=config.eval_interval_episodes,
        eval_batch_size=config.eval_batch_size,
        log_level=log_level,
    )
    # Draw a seed when none is configured and store it back into the config.
    if config.seed is None:
        config.seed = random.randint(0, 99999)
    runner_state = runner.reset(jax.random.PRNGKey(config.seed))

    first_state_only = config.get("policy_var_first_state_only", False)

    if switch_env:
        # One policy-variance callback per environment, told apart by prefix.
        runner.add_callback(
            Event.EVAL_EPISODE,
            make_policy_var_callback(mdp, first_state_only=first_state_only, prefix="first"),
            log_level=2,
        )
        runner.add_callback(
            Event.EVAL_EPISODE,
            make_policy_var_callback(mdp_bis, first_state_only=first_state_only, prefix="second"),
            log_level=2,
        )
    else:
        runner.add_callback(
            Event.EVAL_EPISODE,
            make_policy_var_callback(mdp, first_state_only=first_state_only, prefix=""),
            log_level=2,
        )

    if config.environment.name in (
        "conveyor",
        "conveyor_onehot",
        "conveyor_pixel",
        "treasure_conveyor",
        "multi_treasure_conveyor",
    ):
        runner.add_callback(Event.EVAL_EPISODE, performance_conveyor, log_level=1)

    if config.environment.name in (
        "multiple_conveyor",
        "multiple_conveyor_pixel",
        "multiple_conveyor_compressed",
    ):
        runner.add_callback(Event.EVAL_EPISODE, performance_multiple_conveyor, log_level=1)

    if config.contribution in ("causal", "causal_gt", "parallel"):
        runner.add_callback(Event.EVAL_EPISODE, visualize_contribution, log_level=3)
        runner.add_callback(
            Event.EVAL_EPISODE,
            make_contribution_bias_callback(mdp, first_state_only=first_state_only),
            log_level=3,
        )

    runner.run(runner_state)


def main(argv):
    """absl entry point: start a wandb run and launch training with the flag values."""
    del argv  # not used

    # Configuration loaded through the `config` flag.
    training_config = FLAGS.config

    # wandb serves both as experiment tracker and as the logger given to run().
    wandb.init(config=training_config)

    run(
        training_config,
        logger=wandb,
        logdir=FLAGS.logdir,
        log_level=FLAGS.log_level,
    )


# Script entry point: only when executed directly, hand main() to absl's runner.
if __name__ == "__main__":
    app.run(main)
