"""
Based on PureJaxRL Implementation of PPO
"""

import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal, xavier_normal, normal, kaiming_normal
from typing import Sequence, NamedTuple, Any, Dict, Union
from flax.training.train_state import TrainState
import distrax
import hydra
import flax
from omegaconf import DictConfig, OmegaConf
from safetensors.flax import save_file
from flax.traverse_util import flatten_dict
from flax import traverse_util
from jaxmarl.wrappers.baselines import SMAXLogWrapper
from jaxmarl.environments.smax import map_name_to_scenario, HeuristicEnemySMAX

import wandb
import functools
import matplotlib.pyplot as plt
import os

from gnn_module.gnn import End2EndGCN

class ScannedRNN(nn.Module):
    """GRU unrolled over the leading (time) axis of its inputs via ``nn.scan``.

    ``__call__(carry, (ins, resets))`` takes a carry of shape
    ``(batch, hidden_size)``. ``ins`` and ``resets`` are scanned along axis 0,
    so each step sees ``ins`` of shape ``(batch, features)`` and one reset flag
    per batch row. Parameters are shared across time steps
    (``variable_broadcast="params"``).
    """

    @functools.partial(
        nn.scan,
        variable_broadcast="params",
        in_axes=0,
        out_axes=0,
        split_rngs={"params": False},
    )
    @nn.compact
    def __call__(self, carry, x):
        """Apply one GRU step, restarting the carry for rows flagged in ``resets``.

        Returns ``(new_carry, output)``. ``nn.scan`` stacks the outputs over time.
        """
        rnn_state = carry
        ins, resets = x
        # Rows whose episode ended get a freshly initialized carry.
        rnn_state = jnp.where(
            resets[:, np.newaxis],
            self.initialize_carry(*rnn_state.shape),
            rnn_state,
        )
        # Size the cell from the carry, not the input. The GRU update mixes the
        # carry with the candidate state, so the cell width must equal the carry
        # width (set by whoever called ``initialize_carry``). Using the input
        # width only worked when the two happened to coincide.
        new_rnn_state, y = nn.GRUCell(features=rnn_state.shape[1])(rnn_state, ins)
        return new_rnn_state, y

    @staticmethod
    def initialize_carry(batch_size, hidden_size):
        """Return the initial GRU carry of shape ``(batch_size, hidden_size)``."""
        # Use a dummy key since the default state init fn is just zeros.
        cell = nn.GRUCell(features=hidden_size)
        return cell.initialize_carry(jax.random.PRNGKey(0), (batch_size, hidden_size))


def get_initializer(config):
    """Return the kernel initializer named by ``config["INITIALIZER"]``.

    Supported names: ``"orthogonal"``, ``"xavier_normal"``, ``"normal_0.01"``,
    ``"normal_0.05"`` and ``"kaiming_normal"``.

    Raises:
        ValueError: if the configured name is not one of the supported values.
            An unknown name used to fall through and return ``None``, which
            only surfaced later when a layer tried to use it.
    """
    name = config["INITIALIZER"]
    factories = {
        "orthogonal": orthogonal,
        "xavier_normal": xavier_normal,
        "normal_0.01": lambda: normal(0.01),
        "normal_0.05": lambda: normal(0.05),
        "kaiming_normal": kaiming_normal,
    }
    try:
        factory = factories[name]
    except KeyError:
        raise ValueError(
            f"Unknown INITIALIZER {name!r}; expected one of {sorted(factories)}"
        ) from None
    return factory()


class ActorCriticRNN(nn.Module):
    """Recurrent actor with a feed-forward critic.

    The first ``raw_obs_dim`` observation features are treated as the raw
    observation and the remainder as an appended GNN embedding. Only the raw
    part feeds the recurrent actor. The critic is not recurrent and sees the
    full (raw + GNN) feature vector.
    """

    action_dim: int  # number of discrete actions (width of the policy logits)
    config: Dict
    raw_obs_dim: int = 132  # leading observation features that form the raw observation

    @nn.compact
    def __call__(self, hidden, x):
        """Run the actor over a time-major sequence and the critic per step.

        Args:
            hidden: GRU carry for the batch.
            x: ``(obs, dones, avail_actions)``. ``obs`` is indexed as
                ``(time, batch, features)``, ``dones`` resets the GRU carry, and
                ``avail_actions`` is treated as a 0/1 mask over actions.

        Returns:
            ``(hidden, action_logits, value)``, with the critic's trailing unit
            axis squeezed away.
        """
        obs, dones, avail_actions = x
        raw_obs = obs[:, :, : self.raw_obs_dim]
        gnn_embedding = obs[:, :, self.raw_obs_dim :]
        initializer = get_initializer(self.config)

        # Actor: encode the raw observation, then run the GRU over time.
        # NOTE: module creation order is kept unchanged so parameter names
        # (Dense_0, ScannedRNN_0, Dense_1, ...) stay stable.
        embedding = nn.Dense(
            self.config["FC_DIM_SIZE"], kernel_init=initializer, bias_init=constant(0.0)
        )(raw_obs)
        embedding = nn.relu(embedding)

        rnn_in = (embedding, dones)
        hidden, embedding = ScannedRNN()(hidden, rnn_in)

        actor_mean = nn.Dense(
            self.config["GRU_HIDDEN_DIM"], kernel_init=initializer, bias_init=constant(0.0)
        )(embedding)
        actor_mean = nn.relu(actor_mean)
        actor_mean = nn.Dense(
            self.action_dim, kernel_init=initializer, bias_init=constant(0.0)
        )(actor_mean)
        # Subtract a huge constant from unavailable actions' logits so they
        # receive (effectively) zero probability.
        unavail_actions = 1 - avail_actions
        action_logits = actor_mean - (unavail_actions * 1e10)

        # Critic: raw observation plus GNN embedding (i.e. the full obs vector).
        critic_input = jnp.concatenate([raw_obs, gnn_embedding], axis=-1)
        critic = nn.Dense(
            self.config["FC_DIM_SIZE"], kernel_init=initializer, bias_init=constant(0.0)
        )(critic_input)
        critic = nn.relu(critic)
        critic = nn.Dense(1, kernel_init=initializer, bias_init=constant(0.0))(
            critic
        )

        return hidden, action_logits, jnp.squeeze(critic, axis=-1)


class AgentModel(nn.Module):
    """GNN feature extractor followed by the recurrent actor-critic head.

    The first 140 observation features are encoded by ``End2EndGCN``. The
    encoder output is appended to the full observation, and the widened
    observation is passed to ``ActorCriticRNN``.
    """

    action_dim: int
    num_all_agents: int
    config: Any
    temperature: float
    c_out: int = 16
    obs_encoder_dim: int = 32

    def setup(self):
        """Build the GNN encoder and the recurrent actor-critic head."""
        # Attribute names define the parameter tree ("gnn", "agent_rnn").
        self.gnn = End2EndGCN(
            c_out=self.c_out,
            num_all_agents=self.num_all_agents,
            obs_encoder_dim=self.obs_encoder_dim,
            temperature=self.temperature,
        )
        self.agent_rnn = ActorCriticRNN(self.action_dim, self.config)

    def __call__(self, hidden, x):
        """Return ``(hidden, action_logits, critic)`` as produced by ``ActorCriticRNN``."""
        observations, done_flags, action_mask = x
        graph_features = self.gnn(observations[:, :, :140])
        widened = jnp.concatenate([observations, graph_features], axis=-1)
        return self.agent_rnn(hidden, (widened, done_flags, action_mask))


class CombinedActorCriticRNN(nn.Module):
    """Actor-critic that routes agents to one of two ``AgentModel`` instances.

    The last observation feature is read as an integer agent id. Ids <= 2 go
    to ``actor_critic_group_1`` and ids >= 3 to ``actor_critic_group_2``.
    The two groups' outputs are then merged into one hidden state, one
    categorical policy and one value estimate.

    NOTE(review): as shown here, the four helper methods at the bottom are
    stubs whose bodies are ``pass`` (they return ``None``). ``AgentModel``
    unpacks its ``x`` argument, so ``__call__`` cannot run until the helpers
    are filled in.
    """
    action_dim: int  # number of discrete actions
    num_all_agents: int  # forwarded to each AgentModel's GNN
    config: Any
    temperature: float  # forwarded to each AgentModel's GNN
    obs_encoder_dim: int
    c_out: int = 8  # GNN output channels (AgentModel's own default is 16)

    def setup(self):
        """Initialize two separate models for different agent groups."""
        # Same hyper-parameters, but two distinct submodules -> independent weights.
        self.actor_critic_group_1 = AgentModel(self.action_dim, self.num_all_agents, self.config,
                                               self.temperature, self.c_out, self.obs_encoder_dim)
        self.actor_critic_group_2 = AgentModel(self.action_dim, self.num_all_agents, self.config,
                                               self.temperature, self.c_out, self.obs_encoder_dim)

    # NOTE(review): no submodules are created inline in __call__ (they all come
    # from setup()), so @nn.compact is not strictly needed here.
    @nn.compact
    def __call__(self, hidden, inputs):
        """Process agent groups separately and combine results.

        Args:
            hidden: RNN carry, passed unchanged to both groups.
            inputs: ``(obs, dones, avail_actions)``. ``obs`` is indexed as
                ``(time, batch, features)``, and its last feature is the agent id.

        Returns:
            ``(new_hidden, pi, combined_critic)``, where ``pi`` is a
            ``distrax.Categorical`` over the combined logits.
        """
        obs, dones, avail_actions = inputs

        # The last observation feature carries the agent id; strip it before
        # the observation reaches the networks.
        agent_ids = obs[:, :, -1].astype(jnp.int32)
        obs = obs[:, :, :-1]

        # Integer ids make the two masks complementary: every entry belongs to
        # exactly one group.
        mask_1 = (agent_ids <= 2).astype(jnp.float32)
        mask_2 = (agent_ids >= 3).astype(jnp.float32)

        # Process first group of agents
        x_1 = self._mask_and_process(obs, avail_actions, dones, hidden, mask_1)
        hidden_1, action_logits_1, critic_1 = self.actor_critic_group_1(hidden, x_1)

        # Process second group of agents
        x_2 = self._mask_and_process(obs, avail_actions, dones, hidden, mask_2)
        hidden_2, action_logits_2, critic_2 = self.actor_critic_group_2(hidden, x_2)

        # Merge per-group hidden states, logits and values using the group masks.
        new_hidden = self._combine_hidden_states(hidden_1, hidden_2, mask_1, mask_2)
        combined_logits = self._combine_logits(action_logits_1, action_logits_2, mask_1, mask_2)
        combined_critic = self._combine_critic(critic_1, critic_2, mask_1, mask_2)

        # Create policy distribution
        pi = distrax.Categorical(logits=combined_logits)

        return new_hidden, pi, combined_critic

    def _mask_and_process(self, obs, avail_actions, dones, hidden, mask):
        """Helper function to apply masking and process agents.

        Expected to return the ``(obs, dones, avail_actions)`` tuple consumed by
        ``AgentModel``. NOTE(review): stub in this file; it currently returns ``None``.
        """
        # Details omitted
        pass

    def _combine_hidden_states(self, hidden_1, hidden_2, mask_1, mask_2):
        """Combine hidden states from both groups.

        NOTE(review): stub in this file; it currently returns ``None``.
        """
        # Details omitted
        pass

    def _combine_logits(self, logits_1, logits_2, mask_1, mask_2):
        """Combine action logits from both groups.

        NOTE(review): stub in this file; it currently returns ``None``.
        """
        # Details omitted
        pass

    def _combine_critic(self, critic_1, critic_2, mask_1, mask_2):
        """Combine critic values from both groups.

        NOTE(review): stub in this file; it currently returns ``None``.
        """
        # Details omitted
        pass


class Transition(NamedTuple):
    """One rollout step for every actor, stacked along the actor axis.

    Instances are built positionally, so field order is part of the contract.
    """

    global_done: jnp.ndarray  # episode-level done ("__all__") tiled over agents
    done: jnp.ndarray  # per-actor done flag carried in from the previous step
    action: jnp.ndarray
    value: jnp.ndarray
    reward: jnp.ndarray
    log_prob: jnp.ndarray
    obs: jnp.ndarray
    info: Dict[str, jnp.ndarray]  # per-key info arrays returned by env.step
    avail_actions: jnp.ndarray


def batchify(x: dict, agent_list, num_actors):
    """Stack the per-agent arrays of ``x`` and reshape to ``(num_actors, -1)``.

    Arrays are stacked along a new leading axis in ``agent_list`` order before
    the reshape, so rows for the first agent come first.
    """
    ordered = tuple(map(x.__getitem__, agent_list))
    return jnp.reshape(jnp.stack(ordered), (num_actors, -1))


def unbatchify(x: jnp.ndarray, agent_list, num_envs, num_actors):
    """Split a batched array back into a dict keyed by agent.

    ``x`` is reshaped to ``(num_actors, num_envs, -1)``, and slice ``i`` of the
    leading axis is assigned to ``agent_list[i]``. Despite its name,
    ``num_actors`` is the number of leading slices, i.e. the agent count.
    """
    grouped = x.reshape((num_actors, num_envs, -1))
    per_agent = {}
    for idx, agent in enumerate(agent_list):
        per_agent[agent] = grouped[idx]
    return per_agent


def make_train(config):
    """Build a jittable IPPO training function for a SMAX map.

    Mutates ``config`` in place:
    - adds the derived keys ``NUM_ACTORS``, ``NUM_UPDATES`` and ``MINIBATCH_SIZE``;
    - when ``SCALE_CLIP_EPS`` is set, divides ``CLIP_EPS`` by the number of agents.

    Returns:
        ``train(rng)``, which runs ``NUM_UPDATES`` PPO updates and returns
        ``{"runner_state": (runner_state, update_steps)}``. Here ``runner_state``
        is ``(train_state, env_state, last_obs, last_done, hstate, rng)``.
        Per-update metrics are sent to wandb through a host callback.
    """
    scenario = map_name_to_scenario(config["MAP_NAME"])
    env = HeuristicEnemySMAX(scenario=scenario, **config["ENV_KWARGS"])
    config["NUM_ACTORS"] = env.num_agents * config["NUM_ENVS"]
    config["NUM_UPDATES"] = (
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    config["MINIBATCH_SIZE"] = (
        config["NUM_ACTORS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
    )
    config["CLIP_EPS"] = (
        config["CLIP_EPS"] / env.num_agents
        if config["SCALE_CLIP_EPS"]
        else config["CLIP_EPS"]
    )

    env = SMAXLogWrapper(env)

    def linear_schedule(count):
        # `count` is the optimizer step. Each PPO update performs
        # NUM_MINIBATCHES * UPDATE_EPOCHS optimizer steps, so the LR decays
        # linearly per update.
        frac = (
            1.0
            - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
            / config["NUM_UPDATES"]
        )
        return config["LR"] * frac

    def train(rng):
        # INIT NETWORK
        num_actions = env.action_space(env.agents[0]).n
        network = CombinedActorCriticRNN(
            num_actions,
            num_all_agents=len(env.all_agents),
            config=config,
            c_out=config["C_OUT"],
            temperature=config["TEMPERATURE"],
            obs_encoder_dim=config["OBS_ENCODER_DIM"],
        )
        rng, _rng = jax.random.split(rng)
        # NOTE(review): the dummy observation is 6 features wider than the
        # env's declared observation space. Presumably the env emits extra
        # features (e.g. the agent id read by CombinedActorCriticRNN); confirm.
        obs_dim = env.observation_space(env.agents[0]).shape[0] + 6
        init_x = (
            jnp.zeros((1, config["NUM_ACTORS"], obs_dim)),
            jnp.zeros((1, config["NUM_ACTORS"])),
            jnp.zeros((1, config["NUM_ACTORS"], num_actions)),
        )
        # The initial carry is deterministic, so it is built once. It is used
        # both for parameter init and as the rollout's starting hidden state.
        init_hstate = ScannedRNN.initialize_carry(
            config["NUM_ACTORS"], config["GRU_HIDDEN_DIM"]
        )
        network_params = network.init(_rng, init_hstate, init_x)

        learning_rate = linear_schedule if config["ANNEAL_LR"] else config["LR"]
        tx = optax.chain(
            optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
            optax.adam(learning_rate=learning_rate, eps=1e-5),
        )
        train_state = TrainState.create(
            apply_fn=network.apply,
            params=network_params,
            tx=tx,
        )

        # INIT ENV
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
        obsv, env_state = jax.vmap(env.reset, in_axes=(0,))(reset_rng)

        # TRAIN LOOP
        def _update_step(update_runner_state, unused):
            # COLLECT TRAJECTORIES
            runner_state, update_steps = update_runner_state

            def _env_step(runner_state, unused):
                train_state, env_state, last_obs, last_done, hstate, rng = runner_state

                # SELECT ACTION
                rng, _rng = jax.random.split(rng)
                avail_actions = jax.vmap(env.get_avail_actions)(env_state.env_state)
                avail_actions = jax.lax.stop_gradient(
                    batchify(avail_actions, env.agents, config["NUM_ACTORS"])
                )
                obs_batch = batchify(last_obs, env.agents, config["NUM_ACTORS"])
                # Add a length-1 leading time axis for the scanned RNN.
                ac_in = (
                    obs_batch[np.newaxis, :],
                    last_done[np.newaxis, :],
                    avail_actions,
                )
                hstate, pi, value = network.apply(train_state.params, hstate, ac_in)
                action = pi.sample(seed=_rng)
                log_prob = pi.log_prob(action)
                env_act = unbatchify(
                    action, env.agents, config["NUM_ENVS"], env.num_agents
                )
                env_act = {k: v.squeeze() for k, v in env_act.items()}

                # STEP ENV
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config["NUM_ENVS"])
                obsv, env_state, reward, done, info = jax.vmap(
                    env.step, in_axes=(0, 0, 0)
                )(rng_step, env_state, env_act)
                info = jax.tree_util.tree_map(
                    lambda x: x.reshape((config["NUM_ACTORS"])), info
                )
                done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze()
                transition = Transition(
                    jnp.tile(done["__all__"], env.num_agents),
                    last_done,
                    action.squeeze(),
                    value.squeeze(),
                    batchify(reward, env.agents, config["NUM_ACTORS"]).squeeze(),
                    log_prob.squeeze(),
                    obs_batch,
                    info,
                    avail_actions,
                )
                runner_state = (train_state, env_state, obsv, done_batch, hstate, rng)
                return runner_state, transition

            # Hidden state at the start of the rollout. The loss re-runs the
            # RNN from here over the whole trajectory.
            initial_hstate = runner_state[-2]
            runner_state, traj_batch = jax.lax.scan(
                _env_step, runner_state, None, config["NUM_STEPS"]
            )

            # CALCULATE ADVANTAGE
            train_state, env_state, last_obs, last_done, hstate, rng = runner_state
            last_obs_batch = batchify(last_obs, env.agents, config["NUM_ACTORS"])
            # Only the bootstrap value is used, so every action is marked available.
            avail_actions = jnp.ones((config["NUM_ACTORS"], num_actions))
            ac_in = (
                last_obs_batch[np.newaxis, :],
                last_done[np.newaxis, :],
                avail_actions,
            )
            _, _, last_val = network.apply(train_state.params, hstate, ac_in)
            last_val = last_val.squeeze()

            def _calculate_gae(traj_batch, last_val):
                def _get_advantages(gae_and_next_value, transition):
                    gae, next_value = gae_and_next_value
                    done, value, reward = (
                        transition.global_done,
                        transition.value,
                        transition.reward,
                    )
                    delta = reward + config["GAMMA"] * next_value * (1 - done) - value
                    gae = (
                        delta
                        + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
                    )
                    return (gae, value), gae

                _, advantages = jax.lax.scan(
                    _get_advantages,
                    (jnp.zeros_like(last_val), last_val),
                    traj_batch,
                    reverse=True,
                    unroll=16,
                )
                return advantages, advantages + traj_batch.value

            advantages, targets = _calculate_gae(traj_batch, last_val)

            # UPDATE NETWORK
            def _update_epoch(update_state, unused):

                def _update_minbatch(train_state, batch_info):
                    init_hstate, traj_batch, advantages, targets = batch_info

                    def _loss_fn(params, init_hstate, traj_batch, gae, targets):
                        # RERUN NETWORK
                        # init_hstate is (1, minibatch_actors, hidden). Index the
                        # fake leading axis instead of squeeze() so a
                        # single-actor minibatch keeps its batch axis.
                        _, pi, value = network.apply(
                            params,
                            init_hstate[0],
                            (traj_batch.obs, traj_batch.done, traj_batch.avail_actions),
                        )
                        log_prob = pi.log_prob(traj_batch.action)

                        # CALCULATE VALUE LOSS
                        value_pred_clipped = traj_batch.value + (
                            value - traj_batch.value
                        ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
                        value_losses = jnp.square(value - targets)
                        value_losses_clipped = jnp.square(value_pred_clipped - targets)
                        value_loss = 0.5 * jnp.maximum(
                            value_losses, value_losses_clipped
                        ).mean()

                        # CALCULATE ACTOR LOSS
                        logratio = log_prob - traj_batch.log_prob
                        ratio = jnp.exp(logratio)
                        gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                        loss_actor1 = ratio * gae
                        loss_actor2 = (
                            jnp.clip(
                                ratio,
                                1.0 - config["CLIP_EPS"],
                                1.0 + config["CLIP_EPS"],
                            )
                            * gae
                        )
                        loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
                        loss_actor = loss_actor.mean()
                        entropy = pi.entropy().mean()

                        # debug
                        approx_kl = ((ratio - 1) - logratio).mean()
                        clip_frac = jnp.mean(jnp.abs(ratio - 1) > config["CLIP_EPS"])

                        total_loss = (
                            loss_actor
                            + config["VF_COEF"] * value_loss
                            - config["ENT_COEF"] * entropy
                        )
                        return total_loss, (value_loss, loss_actor, entropy, ratio, approx_kl, clip_frac)

                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    total_loss, grads = grad_fn(
                        train_state.params, init_hstate, traj_batch, advantages, targets
                    )
                    train_state = train_state.apply_gradients(grads=grads)
                    return train_state, total_loss

                (
                    train_state,
                    init_hstate,
                    traj_batch,
                    advantages,
                    targets,
                    rng,
                ) = update_state
                rng, _rng = jax.random.split(rng)

                # Add a fake leading axis so the hidden state can be shuffled
                # and split along axis 1 exactly like the (time, actor, ...) data.
                init_hstate = jnp.reshape(
                    init_hstate, (1, config["NUM_ACTORS"], -1)
                )
                batch = (
                    init_hstate,
                    traj_batch,
                    advantages.squeeze(),
                    targets.squeeze(),
                )
                # Shuffle actors (axis 1); time order is preserved for the RNN.
                permutation = jax.random.permutation(_rng, config["NUM_ACTORS"])
                shuffled_batch = jax.tree_util.tree_map(
                    lambda x: jnp.take(x, permutation, axis=1), batch
                )
                # (T, A, ...) -> (NUM_MINIBATCHES, T, A / NUM_MINIBATCHES, ...)
                minibatches = jax.tree_util.tree_map(
                    lambda x: jnp.swapaxes(
                        jnp.reshape(
                            x,
                            [x.shape[0], config["NUM_MINIBATCHES"], -1]
                            + list(x.shape[2:]),
                        ),
                        1,
                        0,
                    ),
                    shuffled_batch,
                )

                train_state, total_loss = jax.lax.scan(
                    _update_minbatch, train_state, minibatches
                )
                update_state = (
                    train_state,
                    init_hstate[0],
                    traj_batch,
                    advantages,
                    targets,
                    rng,
                )
                return update_state, total_loss

            update_state = (
                train_state,
                initial_hstate,
                traj_batch,
                advantages,
                targets,
                rng,
            )
            update_state, loss_info = jax.lax.scan(
                _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
            )
            train_state = update_state[0]
            metric = jax.tree_util.tree_map(
                lambda x: x.reshape(
                    (config["NUM_STEPS"], config["NUM_ENVS"], env.num_agents)
                ),
                traj_batch.info,
            )
            # Ratio on the first minibatch of the first epoch, taken before any
            # gradient step of this update (expected to be ~1).
            ratio_0 = loss_info[1][3].at[0, 0].get().mean()
            loss_info = jax.tree_util.tree_map(lambda x: x.mean(), loss_info)
            metric["loss"] = {
                "total_loss": loss_info[0],
                "value_loss": loss_info[1][0],
                "actor_loss": loss_info[1][1],
                "entropy": loss_info[1][2],
                "ratio": loss_info[1][3],
                "ratio_0": ratio_0,
                "approx_kl": loss_info[1][4],
                "clip_frac": loss_info[1][5],
            }

            rng = update_state[-1]

            def callback(metric):
                wandb.log(
                    {
                        # the metrics have an agent dimension, but this is identical
                        # for all agents so index into the 0th item of that dimension.
                        "returns": metric["returned_episode_returns"][:, :, 0][
                            metric["returned_episode"][:, :, 0]
                        ].mean(),
                        "win_rate": metric["returned_won_episode"][:, :, 0][
                            metric["returned_episode"][:, :, 0]
                        ].mean(),
                        "env_step": metric["update_steps"]
                        * config["NUM_ENVS"]
                        * config["NUM_STEPS"],
                        **metric["loss"],
                    }
                )

            metric["update_steps"] = update_steps
            jax.experimental.io_callback(callback, None, metric)
            update_steps = update_steps + 1
            runner_state = (train_state, env_state, last_obs, last_done, hstate, rng)
            return (runner_state, update_steps), metric

        rng, _rng = jax.random.split(rng)
        runner_state = (
            train_state,
            env_state,
            obsv,
            jnp.zeros((config["NUM_ACTORS"]), dtype=bool),
            init_hstate,
            _rng,
        )
        runner_state, metric = jax.lax.scan(
            _update_step, (runner_state, 0), None, config["NUM_UPDATES"]
        )
        return {"runner_state": runner_state}

    return train



def single_run(config):
    """Train once with ``config``, log to wandb and optionally save the parameters.

    When ``config["SAVE_PATH"]`` is set (not ``None``), the final parameters are
    written to ``<SAVE_PATH>/<project>/<run name>/model.safetensors`` and
    uploaded to wandb as a ``checkpoint`` artifact.
    """
    run = wandb.init(
        entity=config["ENTITY"],
        project=config["PROJECT"],
        tags=["IPPO", "RNN"],
        config=config,
        mode=config["WANDB_MODE"],
    )
    rng = jax.random.PRNGKey(config["SEED"])
    train_jit = jax.jit(make_train(config), device=jax.devices()[0])
    out = train_jit(rng)

    save_path = config.get("SAVE_PATH")
    if save_path is not None:
        # out["runner_state"] is (runner_state, update_steps), and
        # runner_state[0] is the TrainState.
        params = out["runner_state"][0][0].params
        save_dir = os.path.join(save_path, run.project, run.name)
        os.makedirs(save_dir, exist_ok=True)
        model_path = os.path.join(save_dir, "model.safetensors")
        # safetensors stores a flat {name: array} mapping, so nested parameter
        # keys are joined with ",".
        save_file(flatten_dict(params, sep=","), model_path)
        print(f"Parameters of first batch saved in {model_path}")

        # upload this to wandb as an artifact
        artifact = wandb.Artifact(f"{run.name}-checkpoint", type="checkpoint")
        artifact.add_file(model_path)
        artifact.save()

    run.finish()



def tune(default_config):
    """Run a wandb grid sweep over ``SEED`` in 0-9, using ``default_config`` as the base.

    If ``default_config["SWEEP_ID"]`` is set, agents join that existing sweep.
    Otherwise a new sweep is registered. Previously a new sweep was always
    registered, leaving an unused one behind whenever SWEEP_ID was given.
    """
    import copy

    def wrapped_make_train():
        wandb.init(project=default_config["PROJECT"])

        # Start from the defaults and overlay the values chosen by the sweep.
        config = copy.deepcopy(default_config)
        for k, v in dict(wandb.config).items():
            config[k] = v

        print("running experiment with params:", config)

        rng = jax.random.PRNGKey(config["SEED"])
        train_vjit = jax.jit(make_train(config), device=jax.devices()[0])
        train_vjit(rng)

    config_id = default_config["TUNED_CONFIG_ID"]
    sweep_config = {
        "name": f"tuned_ippo_smax_config_{config_id}",
        "method": "grid",
        "metric": {
            "name": "win_rate",
            "goal": "maximize",
        },
        "parameters": {
            "SEED": {"values": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]},
        },
    }

    wandb.login()
    sweep_id = default_config.get("SWEEP_ID")
    if sweep_id is None:
        sweep_id = wandb.sweep(
            sweep_config, entity=default_config["ENTITY"], project=default_config["PROJECT"]
        )

    wandb.agent(
        sweep_id,
        wrapped_make_train,
        entity=default_config["ENTITY"],
        project=default_config["PROJECT"],
        count=300,
    )



@hydra.main(version_base=None, config_path="config")
def main(config):
    """Hydra entry point: convert the config to a plain dict and dispatch.

    A truthy ``HYP_TUNE`` launches the wandb seed sweep (``tune``). Otherwise a
    single training run is started (``single_run``).
    """
    plain_config = OmegaConf.to_container(config)
    print("Config:\n", OmegaConf.to_yaml(plain_config))
    entry_point = tune if plain_config["HYP_TUNE"] else single_run
    entry_point(plain_config)


if __name__ == "__main__":
    main()