import torch
import numpy as np
import psutil


from garage import wrap_experiment
from garage.envs import GymEnv, normalize
from garage.experiment.deterministic import set_seed
from garage.replay_buffer import PathBuffer
from garage.np.exploration_policies import EpsilonGreedyPolicy
from garage.torch.policies import (
    GaussianMLPPolicy,
    DiscreteQFArgmaxPolicy,
)
from garage.torch.value_functions import GaussianMLPValueFunction
from garage.torch.q_functions import ContinuousMLPQFunction
from garage.torch.algos import PPO, DQN
from garage.torch import set_gpu_mode
from garage.trainer import Trainer
from garage.sampler import (
    LocalSampler,
    RaySampler,
    FragmentWorker,
    DefaultWorker,
)

from environments.gym_env import ArgsGymEnv

from environments.wrappers import DiscreteWrapper
from learning.policies.multi_policy_wrapper import MultiPolicyWrapper
from learning.policies.heuristic_mixture_policies_wrapper import (
    HeuristicMixtureOfPoliciesWrapper,
)
from learning.policies.mixture_policies_wrapper import MixtureOfPoliciesWrapper
from learning.policies.env_partition_policy import EnvPartitionPolicy
from learning.algorithms import (
    SAC,
    HDnCSAC,
    DnCPPO,
    DnCCEM,
    MoPDnC,
    MoPDnCv2,
    DistralSAC,
    DistralSQLDiscrete,
)

from learning.policies.two_column_wrapper import (
    DiscreteTwoColumnWrapper,
)

from learning.policies import (
    NamedTanhGaussianMLPPolicy,
    DiscreteMLPQFunction,
    CategoricalMLPPolicy,
)
from learning.utils.path_buffers import DnCPathBuffer, HDnCPathBuffer, MoPPathBuffer
from learning.utils.visualizer import Visualizer
from experiment_utils import (
    init_wandb,
    get_mt_envs,
    get_policies_and_qfs,
)
import environments


### 8/5/22 ###


def run_dnc_cem_test(config, env_args, agent_args, cem_configs):
    """Launch a garage experiment that trains the ``DnCCEM`` algorithm.

    The experiment name is the underscore-joined sequence of the prefix
    ``dnc_cem``, the environment name, the run name (spaces replaced by
    dashes) and the seed.

    Args:
        config: Run configuration object. Attributes read here include
            ``env``, ``name``, ``seed``, ``policy_architecture``,
            ``Q_architecture``, ``hidden_sizes``, ``n_policies``,
            ``stagewise``, ``sampling_type``, ``lr``, ``kl_coeff``, ``gpu``,
            ``n_epochs``, ``batch_size``, the ``vis_*`` settings and the
            evaluation settings.
        env_args: Mapping of environment arguments; forwarded to
            ``get_mt_envs`` and merged into the wandb configuration.
        agent_args: Mapping of algorithm hyperparameters, looked up by key.
        cem_configs: Mapping forwarded to ``DnCCEM`` as ``cem_configs`` and
            merged into the wandb configuration.
    """
    experiment_name = "_".join(
        ["dnc_cem", config.env, config.name.replace(" ", "-"), str(config.seed)]
    )

    @wrap_experiment(
        archive_launch_repo=False, name=experiment_name, snapshot_mode="none"
    )
    def dnc_cem_test(
        ctxt=None, config=None, env_args=None, agent_args=None, cem_configs=None
    ):
        """Build environments, networks, sampler and ``DnCCEM``, then train.

        Args:
            ctxt: Experiment context supplied by ``wrap_experiment``; passed
                to ``Trainer`` as its snapshot configuration.
            config: Same object as in the enclosing function.
            env_args: Same mapping as in the enclosing function.
            agent_args: Same mapping as in the enclosing function.
            cem_configs: Same mapping as in the enclosing function.
        """
        # Only the fully separate policy / Q-function layout is accepted.
        assert config.policy_architecture == "separate"
        assert config.Q_architecture == "separate"

        set_seed(config.seed)
        trainer = Trainer(snapshot_config=ctxt)

        # HACK (flagged as such in the original code): the CEM settings are
        # merged into the environment arguments for wandb, and the snapshot
        # directory is read from the trainer's private snapshotter.
        wandb_args = {**env_args, **cem_configs}
        init_wandb(config, wandb_args, trainer._snapshotter.snapshot_dir)

        mt_setup = get_mt_envs(config.env, env_args)
        (
            base_env,
            env_spec,
            train_envs,
            test_envs,
            split_observation,
            policy_assigner,
        ) = mt_setup

        networks = get_policies_and_qfs(
            config, env_spec, policy_assigner, split_observation
        )
        policies, num_policy_parameters, qf1s, qf2s, num_Q_parameters = networks

        print(
            f"Num parameters in policy: {num_policy_parameters}, Num parameters in Q functions: {num_Q_parameters}"
        )

        # One additional named policy per local policy, handed to DnCCEM as
        # ``cem_policies``.
        cem_policies = []
        for idx in range(config.n_policies):
            cem_policies.append(
                NamedTanhGaussianMLPPolicy(
                    env_spec=env_spec,
                    hidden_sizes=config.hidden_sizes,
                    hidden_nonlinearity=torch.relu,
                    output_nonlinearity=None,
                    min_std=np.exp(-20.0),
                    max_std=np.exp(2.0),
                    name="LocalPolicy{}".format(idx),
                )
            )

        # The stage lookup is only used for stagewise runs.
        get_stage_id = None
        if config.stagewise:
            assert hasattr(base_env, "get_stage_id")
            get_stage_id = base_env.get_stage_id

        policy = MultiPolicyWrapper(policies, policy_assigner, split_observation)

        replay_buffers = DnCPathBuffer(
            num_buffers=config.n_policies,
            capacity_in_transitions=int(1e6),
            sampling_type=config.sampling_type,
        )

        # One worker per training env when a list of envs is given,
        # otherwise one per physical CPU core.
        if isinstance(train_envs, list):
            n_workers = len(train_envs)
        else:
            n_workers = psutil.cpu_count(logical=False)

        # This runner always samples locally.
        sampler = LocalSampler(
            agents=policy,
            envs=train_envs,
            max_episode_length=env_spec.max_episode_length,
            worker_class=DefaultWorker,
            n_workers=n_workers,
        )

        visualizer = Visualizer(
            freq=config.vis_freq,
            num_videos=config.vis_num,
            imsize=(config.vis_width, config.vis_height),
            fps=config.vis_fps,
            format=config.vis_format,
        )

        algo_kwargs = dict(
            cem_policies=cem_policies,
            cem_configs=cem_configs,
            env_spec=env_spec,
            policy=policy,
            policies=policies,
            qf1s=qf1s,
            qf2s=qf2s,
            policy_lr=config.lr,
            qf_lr=config.lr,
            sampler=sampler,
            visualizer=visualizer,
            get_stage_id=get_stage_id,
            preproc_obs=split_observation,
            initial_kl_coeff=config.kl_coeff,
            sampling_type=config.sampling_type,
            gradient_steps_per_itr=agent_args["gradient_steps_per_itr"],
            max_episode_length_eval=env_spec.max_episode_length,
            replay_buffers=replay_buffers,
            min_buffer_size=agent_args["min_buffer_size"],
            target_update_tau=agent_args["target_update_tau"],
            discount=agent_args["discount"],
            buffer_batch_size=agent_args["buffer_batch_size"],
            reward_scale=agent_args["reward_scale"],
            steps_per_epoch=agent_args["steps_per_epoch"],
            eval_env=test_envs,
            num_evaluation_episodes=config.num_evaluation_episodes,
            evaluation_frequency=config.evaluation_frequency,
        )
        algo = DnCCEM(**algo_kwargs)

        # GPU mode needs both an available CUDA device and a configured id.
        use_gpu = torch.cuda.is_available() and config.gpu is not None
        if use_gpu:
            set_gpu_mode(True, gpu_id=config.gpu)
        else:
            set_gpu_mode(False)
        algo.to()

        trainer.setup(algo=algo, env=train_envs)
        trainer.train(n_epochs=config.n_epochs, batch_size=config.batch_size)

    dnc_cem_test(
        config=config, agent_args=agent_args, env_args=env_args, cem_configs=cem_configs
    )


def run_mop_dnc_test(
    config,
    env_args,
    agent_args,
):
    """Launch a garage experiment that trains the ``MoPDnC`` algorithm.

    The experiment name is the underscore-joined sequence of the prefix
    ``dnc_mop``, the environment name, the run name (spaces replaced by
    dashes) and the seed.

    Args:
        config: Run configuration object. Besides the common settings (env,
            name, seed, architectures, lr, visualisation and evaluation
            options) this runner reads ``mixture_probs``,
            ``policy_sampling_freq``, ``train_samples_by_task``,
            ``regularize_representation`` and ``ray``.
        env_args: Environment arguments; forwarded to ``init_wandb`` and
            ``get_mt_envs``.
        agent_args: Mapping of algorithm hyperparameters, looked up by key.
    """
    experiment_name = "_".join(
        ["dnc_mop", config.env, config.name.replace(" ", "-"), str(config.seed)]
    )

    @wrap_experiment(
        archive_launch_repo=False, name=experiment_name, snapshot_mode="none"
    )
    def mop_dnc_test(
        ctxt=None,
        config=None,
        env_args=None,
        agent_args=None,
    ):
        """Build environments, networks, sampler and ``MoPDnC``, then train.

        Args:
            ctxt: Experiment context supplied by ``wrap_experiment``; passed
                to ``Trainer`` as its snapshot configuration.
            config: Same object as in the enclosing function.
            env_args: Same value as in the enclosing function.
            agent_args: Same mapping as in the enclosing function.
        """
        # Only the fully separate policy / Q-function layout is accepted.
        assert config.policy_architecture == "separate"
        assert config.Q_architecture == "separate"

        set_seed(config.seed)
        trainer = Trainer(snapshot_config=ctxt)
        init_wandb(config, env_args, trainer._snapshotter.snapshot_dir)

        mt_setup = get_mt_envs(config.env, env_args)
        (
            base_env,
            env_spec,
            train_envs,
            test_envs,
            split_observation,
            policy_assigner,
        ) = mt_setup

        networks = get_policies_and_qfs(
            config, env_spec, policy_assigner, split_observation
        )
        policies, num_policy_parameters, qf1s, qf2s, num_Q_parameters = networks

        print(
            f"Num parameters in policy: {num_policy_parameters}, Num parameters in Q functions: {num_Q_parameters}"
        )

        # Stagewise runs carry an extra middle axis of length
        # ``base_env.num_stages`` in the mixture-probability array.
        if config.stagewise:
            assert hasattr(base_env, "get_stage_id")
            get_stage_id = base_env.get_stage_id
            mixture_shape = (
                config.n_policies,
                base_env.num_stages,
                config.n_policies,
            )
        else:
            get_stage_id = None
            mixture_shape = (config.n_policies, config.n_policies)
        mixture_probs = np.array(config.mixture_probs).reshape(mixture_shape)

        policy = MixtureOfPoliciesWrapper(
            policies,
            policy_assigner,
            mixture_probs,
            config.policy_sampling_freq,
            split_observation,
            config.train_samples_by_task,
            config.stagewise,
            get_stage_id,
        )

        replay_buffers = DnCPathBuffer(
            num_buffers=config.n_policies,
            capacity_in_transitions=int(1e6),
            sampling_type=config.sampling_type,
        )

        # One worker per training env when a list of envs is given,
        # otherwise one per physical CPU core.
        if isinstance(train_envs, list):
            n_workers = len(train_envs)
        else:
            n_workers = psutil.cpu_count(logical=False)

        if config.ray:
            sampler_cls = RaySampler
        else:
            sampler_cls = LocalSampler
        sampler = sampler_cls(
            agents=policy,
            envs=train_envs,
            max_episode_length=env_spec.max_episode_length,
            worker_class=DefaultWorker,
            n_workers=n_workers,
        )

        visualizer = Visualizer(
            freq=config.vis_freq,
            num_videos=config.vis_num,
            imsize=(config.vis_width, config.vis_height),
            fps=config.vis_fps,
            format=config.vis_format,
        )

        algo_kwargs = dict(
            env_spec=env_spec,
            policy=policy,
            policies=policies,
            qf1s=qf1s,
            qf2s=qf2s,
            policy_lr=config.lr,
            qf_lr=config.lr,
            sampler=sampler,
            visualizer=visualizer,
            get_stage_id=get_stage_id,
            preproc_obs=split_observation,
            initial_kl_coeff=config.kl_coeff,
            sampling_type=config.sampling_type,
            gradient_steps_per_itr=agent_args["gradient_steps_per_itr"],
            max_episode_length_eval=env_spec.max_episode_length,
            replay_buffers=replay_buffers,
            min_buffer_size=agent_args["min_buffer_size"],
            target_update_tau=agent_args["target_update_tau"],
            discount=agent_args["discount"],
            buffer_batch_size=agent_args["buffer_batch_size"],
            reward_scale=agent_args["reward_scale"],
            steps_per_epoch=agent_args["steps_per_epoch"],
            eval_env=test_envs,
            num_evaluation_episodes=config.num_evaluation_episodes,
            evaluation_frequency=config.evaluation_frequency,
            regularize_representation=config.regularize_representation,
            min_task_probs=[0] * config.n_policies,
            # The warmup value is set one past the total number of epochs.
            mixture_warmup=config.n_epochs + 1,
            train_samples_by_task=config.train_samples_by_task,
        )
        algo = MoPDnC(**algo_kwargs)

        # GPU mode needs both an available CUDA device and a configured id.
        use_gpu = torch.cuda.is_available() and config.gpu is not None
        if use_gpu:
            set_gpu_mode(True, gpu_id=config.gpu)
        else:
            set_gpu_mode(False)
        algo.to()

        trainer.setup(algo=algo, env=train_envs)
        trainer.train(n_epochs=config.n_epochs, batch_size=config.batch_size)

    mop_dnc_test(config=config, env_args=env_args, agent_args=agent_args)


def run_hmop_dnc_v2_test(
    config,
    env_args,
    agent_args,
):
    """Launch a garage experiment that trains the ``MoPDnCv2`` algorithm.

    The experiment name is the underscore-joined sequence of the prefix
    ``hmop_dnc_v2``, the environment name, the run name (spaces replaced by
    dashes) and the seed.

    Args:
        config: Run configuration object. Besides the common settings (env,
            name, seed, architectures, lr, visualisation and evaluation
            options) this runner reads ``mixture_probs``, ``Qfilter``,
            ``policy_sampling_freq``, ``train_samples_by_task``,
            ``regularize_representation``, ``mixture_warmup`` and ``ray``.
        env_args: Environment arguments; forwarded to ``init_wandb`` and
            ``get_mt_envs``.
        agent_args: Mapping of algorithm hyperparameters, looked up by key.
    """
    experiment_name = "_".join(
        ["hmop_dnc_v2", config.env, config.name.replace(" ", "-"), str(config.seed)]
    )

    @wrap_experiment(
        archive_launch_repo=False, name=experiment_name, snapshot_mode="none"
    )
    def hmop_dnc_v2_test(
        ctxt=None,
        config=None,
        env_args=None,
        agent_args=None,
    ):
        """Build environments, networks, sampler and ``MoPDnCv2``, then train.

        Args:
            ctxt: Experiment context supplied by ``wrap_experiment``; passed
                to ``Trainer`` as its snapshot configuration.
            config: Same object as in the enclosing function.
            env_args: Same value as in the enclosing function.
            agent_args: Same mapping as in the enclosing function.
        """
        # Only the fully separate policy / Q-function layout is accepted.
        assert config.policy_architecture == "separate"
        assert config.Q_architecture == "separate"

        set_seed(config.seed)
        trainer = Trainer(snapshot_config=ctxt)
        init_wandb(config, env_args, trainer._snapshotter.snapshot_dir)

        mt_setup = get_mt_envs(config.env, env_args)
        (
            base_env,
            env_spec,
            train_envs,
            test_envs,
            split_observation,
            policy_assigner,
        ) = mt_setup

        networks = get_policies_and_qfs(
            config, env_spec, policy_assigner, split_observation
        )
        policies, num_policy_parameters, qf1s, qf2s, num_Q_parameters = networks

        print(
            f"Num parameters in policy: {num_policy_parameters}, Num parameters in Q functions: {num_Q_parameters}"
        )

        # The stage lookup is only used for stagewise runs.
        get_stage_id = None
        if config.stagewise:
            assert hasattr(base_env, "get_stage_id")
            get_stage_id = base_env.get_stage_id

        # ``mixture_probs`` must hold either one entry per policy or a single
        # entry that is broadcast to every policy.
        n_probs = len(config.mixture_probs)
        assert n_probs == config.n_policies or n_probs == 1
        if n_probs == 1:
            mixture_probs = config.mixture_probs * np.ones(config.n_policies)
        else:
            mixture_probs = np.array(config.mixture_probs)

        # Per-policy Qi functions; the original comment describes them as the
        # score function for HMoP. The first set is built completely before
        # the second so the construction order matches the original.
        score_qf1s = []
        for _ in range(config.n_policies):
            score_qf1s.append(
                ContinuousMLPQFunction(
                    env_spec=env_spec,
                    hidden_sizes=config.hidden_sizes,
                    hidden_nonlinearity=torch.relu,
                )
            )

        score_qf2s = []
        for _ in range(config.n_policies):
            score_qf2s.append(
                ContinuousMLPQFunction(
                    env_spec=env_spec,
                    hidden_sizes=config.hidden_sizes,
                    hidden_nonlinearity=torch.relu,
                )
            )

        # The wrapper receives an all-ones array shaped like the mixture
        # probabilities; the probabilities themselves go to the algorithm as
        # ``min_task_probs``.
        policy = HeuristicMixtureOfPoliciesWrapper(
            policies,
            config.policy_architecture,
            score_qf1s,
            score_qf2s,
            policy_assigner,
            np.ones_like(mixture_probs),
            config.Qfilter,
            sampling_freq=config.policy_sampling_freq,
            split_observation=split_observation,
            label_by_task=config.train_samples_by_task,
        )

        replay_buffers = DnCPathBuffer(
            num_buffers=config.n_policies,
            capacity_in_transitions=int(1e6),
            sampling_type=config.sampling_type,
        )

        # The policy buffers split the same transition budget evenly
        # across policies.
        policy_replay_buffers = MoPPathBuffer(
            num_buffers=config.n_policies,
            capacity_in_transitions=int(1e6) // config.n_policies,
        )

        # One worker per training env when a list of envs is given,
        # otherwise one per physical CPU core.
        if isinstance(train_envs, list):
            n_workers = len(train_envs)
        else:
            n_workers = psutil.cpu_count(logical=False)
        print("n_workers: ", n_workers, ", ray sampler: ", config.ray)

        if config.ray:
            sampler_cls = RaySampler
        else:
            sampler_cls = LocalSampler
        sampler = sampler_cls(
            agents=policy,
            envs=train_envs,
            max_episode_length=env_spec.max_episode_length,
            worker_class=DefaultWorker,
            n_workers=n_workers,
        )

        visualizer = Visualizer(
            freq=config.vis_freq,
            num_videos=config.vis_num,
            imsize=(config.vis_width, config.vis_height),
            fps=config.vis_fps,
            format=config.vis_format,
        )

        algo_kwargs = dict(
            env_spec=env_spec,
            policy=policy,
            policies=policies,
            qf1s=qf1s,
            qf2s=qf2s,
            Qi1s=score_qf1s,
            Qi2s=score_qf2s,
            policy_lr=config.lr,
            qf_lr=config.lr,
            sampler=sampler,
            visualizer=visualizer,
            get_stage_id=get_stage_id,
            preproc_obs=split_observation,
            initial_kl_coeff=config.kl_coeff,
            sampling_type=config.sampling_type,
            gradient_steps_per_itr=agent_args["gradient_steps_per_itr"],
            max_episode_length_eval=env_spec.max_episode_length,
            replay_buffers=replay_buffers,
            policy_replay_buffers=policy_replay_buffers,
            min_buffer_size=agent_args["min_buffer_size"],
            target_update_tau=agent_args["target_update_tau"],
            discount=agent_args["discount"],
            buffer_batch_size=agent_args["buffer_batch_size"],
            reward_scale=agent_args["reward_scale"],
            steps_per_epoch=agent_args["steps_per_epoch"],
            eval_env=test_envs,
            num_evaluation_episodes=config.num_evaluation_episodes,
            evaluation_frequency=config.evaluation_frequency,
            regularize_representation=config.regularize_representation,
            mixture_warmup=config.mixture_warmup,
            min_task_probs=mixture_probs,
        )
        algo = MoPDnCv2(**algo_kwargs)

        # GPU mode needs both an available CUDA device and a configured id.
        use_gpu = torch.cuda.is_available() and config.gpu is not None
        if use_gpu:
            set_gpu_mode(True, gpu_id=config.gpu)
        else:
            set_gpu_mode(False)
        algo.to()

        trainer.setup(algo=algo, env=train_envs)
        trainer.train(n_epochs=config.n_epochs, batch_size=config.batch_size)

    hmop_dnc_v2_test(config=config, env_args=env_args, agent_args=agent_args)


### 8/5/22 ###

### Injune ###


def run_sac_discrete_test(
    config,
    env_args,
    agent_args,
):
    """Launch a garage experiment that trains ``SAC`` on a discretised env.

    The experiment name is the underscore-joined sequence of the prefix
    ``sac_discrete``, the environment name, the run name (spaces replaced by
    dashes) and the seed.

    Args:
        config: Run configuration object. Attributes read here include
            ``env``, ``name``, ``seed``, ``hidden_sizes``, ``ray``,
            ``experiment``, ``lr``, ``gpu``, ``n_epochs``, ``batch_size``,
            ``num_evaluation_episodes`` and the ``vis_*`` settings.
        env_args: Environment arguments; forwarded to ``init_wandb`` and
            ``ArgsGymEnv``.
        agent_args: Mapping of algorithm hyperparameters, looked up by key.
    """
    experiment_name = "_".join(
        ["sac_discrete", config.env, config.name.replace(" ", "-"), str(config.seed)]
    )

    @wrap_experiment(
        archive_launch_repo=False, name=experiment_name, snapshot_mode="none"
    )
    def sac_discrete_test(
        ctxt=None,
        config=None,
        env_args=None,
        agent_args=None,
    ):
        """Set up the environment and the ``SAC`` algorithm, then train.

        Args:
            ctxt (garage.experiment.ExperimentContext): The experiment
                configuration used by Trainer to create the snapshotter.
            config: Run configuration; ``config.seed`` seeds the random
                number generators.
            env_args: Arguments used to construct the environment.
            agent_args: Mapping of algorithm hyperparameters.
        """
        set_seed(config.seed)
        trainer = Trainer(snapshot_config=ctxt)
        init_wandb(config, env_args, trainer._snapshotter.snapshot_dir)

        # The action count and the discrete-to-continuous mapping both come
        # from the unwrapped environment.
        raw_env = ArgsGymEnv(config.env, env_args)
        base_env = DiscreteWrapper(
            env=raw_env,
            n_actions=raw_env.get_num_discrete_actions(),
            disc2cont=raw_env.disc2cont,
        )
        env_spec = base_env.spec

        # Use the environment's own train/test sets when it provides them,
        # otherwise fall back to the normalized environment itself.
        train_envs = (
            base_env.get_train_envs()
            if hasattr(base_env, "get_train_envs")
            else normalize(base_env)
        )
        test_envs = (
            base_env.get_test_envs()
            if hasattr(base_env, "get_test_envs")
            else normalize(base_env)
        )

        policy = CategoricalMLPPolicy(
            env_spec=env_spec,
            output_dim=env_spec.action_space.flat_dim,
            hidden_sizes=config.hidden_sizes,
            hidden_nonlinearity=torch.relu,
            output_nonlinearity=None,
        )

        # Two identically configured Q-functions, built in this order.
        qf1 = DiscreteMLPQFunction(
            env_spec=env_spec,
            hidden_sizes=config.hidden_sizes,
            hidden_nonlinearity=torch.relu,
        )
        qf2 = DiscreteMLPQFunction(
            env_spec=env_spec,
            hidden_sizes=config.hidden_sizes,
            hidden_nonlinearity=torch.relu,
        )

        replay_buffer = PathBuffer(capacity_in_transitions=int(1e6))

        # One worker per training env when a list of envs is given,
        # otherwise one per physical CPU core.
        if isinstance(train_envs, list):
            n_workers = len(train_envs)
        else:
            n_workers = psutil.cpu_count(logical=False)

        if config.ray:
            sampler_cls = RaySampler
        else:
            sampler_cls = LocalSampler
        sampler = sampler_cls(
            agents=policy,
            envs=train_envs,
            max_episode_length=env_spec.max_episode_length,
            worker_class=DefaultWorker,
            n_workers=n_workers,
        )

        visualizer = Visualizer(
            freq=config.vis_freq,
            num_videos=config.vis_num,
            imsize=(config.vis_width, config.vis_height),
            fps=config.vis_fps,
            format=config.vis_format,
        )

        # Only "mtsac" experiments take the task count from the environment
        # (defaulting to 1 when it has none); every other experiment uses 1.
        if config.experiment == "mtsac":
            num_tasks = getattr(base_env, "num_tasks", 1)
        else:
            num_tasks = 1

        algo_kwargs = dict(
            env_spec=env_spec,
            policy=policy,
            qf1=qf1,
            qf2=qf2,
            sampler=sampler,
            visualizer=visualizer,
            num_tasks=num_tasks,
            gradient_steps_per_itr=agent_args["gradient_steps_per_itr"],
            max_episode_length_eval=env_spec.max_episode_length,
            replay_buffer=replay_buffer,
            min_buffer_size=agent_args["min_buffer_size"],
            target_update_tau=agent_args["target_update_tau"],
            discount=agent_args["discount"],
            buffer_batch_size=agent_args["buffer_batch_size"],
            reward_scale=agent_args["reward_scale"],
            steps_per_epoch=agent_args["steps_per_epoch"],
            policy_lr=config.lr,
            qf_lr=config.lr,
            eval_env=test_envs,
            num_evaluation_episodes=config.num_evaluation_episodes,
        )
        algo = SAC(**algo_kwargs)

        # GPU mode needs both an available CUDA device and a configured id.
        use_gpu = torch.cuda.is_available() and config.gpu is not None
        if use_gpu:
            set_gpu_mode(True, gpu_id=config.gpu)
        else:
            set_gpu_mode(False)
        algo.to()

        trainer.setup(algo=algo, env=train_envs)
        trainer.train(n_epochs=config.n_epochs, batch_size=config.batch_size)

    sac_discrete_test(config=config, env_args=env_args, agent_args=agent_args)


def run_distral_sac_test(
    config,
    env_args,
    agent_args,
):
    """Launch a garage experiment that trains the ``DistralSAC`` algorithm.

    The experiment name is the underscore-joined sequence of the prefix
    ``distral_sac``, the environment name, the run name (spaces replaced by
    dashes) and the seed.

    Args:
        config: Run configuration object. Attributes read here include
            ``env``, ``name``, ``seed``, ``hidden_sizes``, ``n_policies``,
            ``partition``, ``stagewise``, ``two_column``, ``sampling_type``,
            ``ray``, ``lr``, ``kl_coeff``, ``gpu``, ``n_epochs``,
            ``batch_size``, ``num_evaluation_episodes`` and the ``vis_*``
            settings.
        env_args: Environment arguments; forwarded to ``init_wandb`` and
            ``ArgsGymEnv``.
        agent_args: Mapping of algorithm hyperparameters, looked up by key.
    """
    experiment_name = "_".join(
        ["distral_sac", config.env, config.name.replace(" ", "-"), str(config.seed)]
    )

    @wrap_experiment(
        archive_launch_repo=False, name=experiment_name, snapshot_mode="none"
    )
    def distral_sac_test(
        ctxt=None,
        config=None,
        env_args=None,
        agent_args=None,
    ):
        """Set up the environment and the ``DistralSAC`` agent, then train.

        Args:
            ctxt (garage.experiment.ExperimentContext): The experiment
                configuration used by Trainer to create the snapshotter.
            config: Run configuration; ``config.seed`` seeds the random
                number generators.
            env_args: Arguments used to construct the environment.
            agent_args: Mapping of algorithm hyperparameters.
        """
        set_seed(config.seed)
        trainer = Trainer(snapshot_config=ctxt)
        init_wandb(config, env_args, trainer._snapshotter.snapshot_dir)

        base_env = ArgsGymEnv(config.env, env_args)
        env_spec = base_env.spec

        # Use the environment's own train/test sets when it provides them,
        # otherwise fall back to the normalized environment itself.
        train_envs = (
            base_env.get_train_envs()
            if hasattr(base_env, "get_train_envs")
            else normalize(base_env)
        )
        test_envs = (
            base_env.get_test_envs()
            if hasattr(base_env, "get_test_envs")
            else normalize(base_env)
        )

        # The central policy is created before the local ones.
        central_policy = NamedTanhGaussianMLPPolicy(
            env_spec=env_spec,
            hidden_sizes=config.hidden_sizes,
            hidden_nonlinearity=torch.relu,
            output_nonlinearity=None,
            min_std=np.exp(-20.0),
            max_std=np.exp(2.0),
            name="CentralPolicy",
        )

        policies = []
        for idx in range(config.n_policies):
            policies.append(
                NamedTanhGaussianMLPPolicy(
                    env_spec=env_spec,
                    hidden_sizes=config.hidden_sizes,
                    hidden_nonlinearity=torch.relu,
                    output_nonlinearity=None,
                    min_std=np.exp(-20.0),
                    max_std=np.exp(2.0),
                    name="LocalPolicy{}".format(idx),
                )
            )

        # Disabled in the original: wrapping each local policy in a
        # TwoColumnWrapper when ``config.two_column`` is set.
        # if config.two_column:
        #     kl_coeffs = (
        #         config.kl_coeffs if isinstance(config.kl_coeffs, list)
        #         else [config.kl_coeffs] * config.n_policies
        #     )
        #     policies = [
        #         TwoColumnWrapper(
        #             policy,
        #             central_policy,
        #             alpha=1,
        #             beta=kl_coeff,
        #         )
        #         for policy, kl_coeff in zip(policies, kl_coeffs)
        #     ]

        # NOTE(review): ``get_partition`` is not among the imports visible at
        # the top of this module; presumably it is defined or imported
        # elsewhere in the file -- confirm.
        policy_assigner = get_partition(
            config.partition, config.env, base_env, config.n_policies
        )
        split_observation = getattr(base_env, "split_observation", None)

        # The stage lookup is only used for stagewise runs.
        get_stage_id = None
        if config.stagewise:
            assert hasattr(base_env, "get_stage_id")
            get_stage_id = base_env.get_stage_id

        policy = MultiPolicyWrapper(policies, policy_assigner, split_observation)

        # Central Q-functions exist only in two-column mode.
        central_qf1 = None
        central_qf2 = None
        if config.two_column:
            central_qf1 = ContinuousMLPQFunction(
                env_spec=env_spec,
                hidden_sizes=config.hidden_sizes,
                hidden_nonlinearity=torch.relu,
            )
            central_qf2 = ContinuousMLPQFunction(
                env_spec=env_spec,
                hidden_sizes=config.hidden_sizes,
                hidden_nonlinearity=torch.relu,
            )

        # The first set of per-policy Q-functions is built completely before
        # the second so the construction order matches the original.
        qf1s = []
        for _ in range(config.n_policies):
            qf1s.append(
                ContinuousMLPQFunction(
                    env_spec=env_spec,
                    hidden_sizes=config.hidden_sizes,
                    hidden_nonlinearity=torch.relu,
                )
            )

        qf2s = []
        for _ in range(config.n_policies):
            qf2s.append(
                ContinuousMLPQFunction(
                    env_spec=env_spec,
                    hidden_sizes=config.hidden_sizes,
                    hidden_nonlinearity=torch.relu,
                )
            )

        replay_buffers = DnCPathBuffer(
            num_buffers=config.n_policies,
            capacity_in_transitions=int(1e6),
            sampling_type=config.sampling_type,
        )

        # One worker per training env when a list of envs is given,
        # otherwise one per physical CPU core.
        if isinstance(train_envs, list):
            n_workers = len(train_envs)
        else:
            n_workers = psutil.cpu_count(logical=False)

        if config.ray:
            sampler_cls = RaySampler
        else:
            sampler_cls = LocalSampler
        sampler = sampler_cls(
            agents=policy,
            envs=train_envs,
            max_episode_length=env_spec.max_episode_length,
            worker_class=DefaultWorker,
            n_workers=n_workers,
        )

        visualizer = Visualizer(
            freq=config.vis_freq,
            num_videos=config.vis_num,
            imsize=(config.vis_width, config.vis_height),
            fps=config.vis_fps,
            format=config.vis_format,
        )

        agent_kwargs = dict(
            env_spec=env_spec,
            central_policy=central_policy,
            policy=policy,
            policies=policies,
            central_qf1=central_qf1,
            central_qf2=central_qf2,
            qf1s=qf1s,
            qf2s=qf2s,
            policy_lr=config.lr,
            qf_lr=config.lr,
            sampler=sampler,
            visualizer=visualizer,
            get_stage_id=get_stage_id,
            preproc_obs=split_observation,
            two_column=config.two_column,
            initial_kl_coeff=config.kl_coeff,
            gradient_steps_per_itr=agent_args["gradient_steps_per_itr"],
            max_episode_length_eval=env_spec.max_episode_length,
            replay_buffers=replay_buffers,
            min_buffer_size=agent_args["min_buffer_size"],
            target_update_tau=agent_args["target_update_tau"],
            discount=agent_args["discount"],
            buffer_batch_size=agent_args["buffer_batch_size"],
            reward_scale=agent_args["reward_scale"],
            steps_per_epoch=agent_args["steps_per_epoch"],
            eval_env=test_envs,
            num_evaluation_episodes=config.num_evaluation_episodes,
        )
        agent = DistralSAC(**agent_kwargs)

        # GPU mode needs both an available CUDA device and a configured id.
        use_gpu = torch.cuda.is_available() and config.gpu is not None
        if use_gpu:
            set_gpu_mode(True, gpu_id=config.gpu)
        else:
            set_gpu_mode(False)
        agent.to()

        trainer.setup(algo=agent, env=train_envs)
        trainer.train(n_epochs=config.n_epochs, batch_size=config.batch_size)

    distral_sac_test(config=config, env_args=env_args, agent_args=agent_args)


def run_distral_sql_discrete_test(
    config,
    env_args,
    agent_args,
):
    """Launch a garage experiment that trains ``DistralSQLDiscrete``.

    The experiment name is the underscore-joined sequence of the prefix
    ``distral_sql_discrete``, the environment name, the run name (spaces
    replaced by dashes) and the seed.

    Args:
        config: Run configuration object. Attributes read here include
            ``env``, ``name``, ``seed``, ``hidden_sizes``, ``n_policies``,
            ``partition``, ``stagewise``, ``two_column``, ``kl_coeffs``
            (two-column mode only), ``kl_coeff``, ``sampling_type``, ``ray``,
            ``lr``, ``gpu``, ``n_epochs``, ``batch_size``,
            ``num_evaluation_episodes`` and the ``vis_*`` settings.
        env_args: Environment arguments; forwarded to ``init_wandb`` and
            ``ArgsGymEnv``.
        agent_args: Mapping of algorithm hyperparameters, looked up by key.
    """
    name_parts = [
        "distral_sql_discrete",
        config.env,
        config.name.replace(" ", "-"),
        str(config.seed),
    ]
    experiment_name = "_".join(name_parts)

    @wrap_experiment(
        archive_launch_repo=False, name=experiment_name, snapshot_mode="none"
    )
    def distral_sql_discrete_test(
        ctxt=None,
        config=None,
        env_args=None,
        agent_args=None,
    ):
        """Set up the environment and ``DistralSQLDiscrete``, then train.

        Args:
            ctxt (garage.experiment.ExperimentContext): The experiment
                configuration used by Trainer to create the snapshotter.
            config: Run configuration; ``config.seed`` seeds the random
                number generators.
            env_args: Arguments used to construct the environment.
            agent_args: Mapping of algorithm hyperparameters.
        """
        set_seed(config.seed)
        trainer = Trainer(snapshot_config=ctxt)
        init_wandb(config, env_args, trainer._snapshotter.snapshot_dir)

        base_env = ArgsGymEnv(config.env, env_args)
        env_spec = base_env.spec

        # Use the environment's own train/test sets when it provides them,
        # otherwise fall back to the normalized environment itself.
        train_envs = (
            base_env.get_train_envs()
            if hasattr(base_env, "get_train_envs")
            else normalize(base_env)
        )
        test_envs = (
            base_env.get_test_envs()
            if hasattr(base_env, "get_test_envs")
            else normalize(base_env)
        )

        # The central policy is created before the local ones.
        central_policy = CategoricalMLPPolicy(
            env_spec=env_spec,
            output_dim=env_spec.action_space.flat_dim,
            hidden_sizes=config.hidden_sizes,
            hidden_nonlinearity=torch.relu,
            output_nonlinearity=None,
            name="CentralPolicy",
        )

        policies = []
        for idx in range(config.n_policies):
            policies.append(
                CategoricalMLPPolicy(
                    env_spec=env_spec,
                    output_dim=env_spec.action_space.flat_dim,
                    hidden_sizes=config.hidden_sizes,
                    hidden_nonlinearity=torch.relu,
                    output_nonlinearity=None,
                    name="LocalPolicy{}".format(idx),
                )
            )

        # Disabled in the original: explicit per-policy Q-functions.
        # qfs = [
        #     DiscreteMLPQFunction(
        #         env_spec=env_spec,
        #         output_dim=env_spec.action_space.flat_dim,
        #         hidden_sizes=config.hidden_sizes,
        #         hidden_nonlinearity=torch.relu,
        #         output_nonlinearity=None,
        #     )
        #     for i in range(config.n_policies)
        # ]

        if config.two_column:
            # A scalar ``kl_coeffs`` is repeated once per policy; a list is
            # used as given.
            if isinstance(config.kl_coeffs, list):
                kl_coeffs = config.kl_coeffs
            else:
                kl_coeffs = [config.kl_coeffs] * config.n_policies

            # Index-based pairing is deliberate: a coefficient list shorter
            # than ``n_policies`` raises instead of silently truncating.
            wrapped_policies = []
            for idx in range(config.n_policies):
                wrapped_policies.append(
                    DiscreteTwoColumnWrapper(
                        central_policy=central_policy,
                        policy=policies[idx],
                        alpha=1,
                        beta=kl_coeffs[idx],
                    )
                )
            policies = wrapped_policies

        # NOTE(review): ``get_partition`` is not among the imports visible at
        # the top of this module; presumably it is defined or imported
        # elsewhere in the file -- confirm.
        policy_assigner = get_partition(
            config.partition, config.env, base_env, config.n_policies
        )

        split_observation = getattr(base_env, "split_observation", None)

        # The stage lookup is only used for stagewise runs.
        get_stage_id = None
        if config.stagewise:
            assert hasattr(base_env, "get_stage_id")
            get_stage_id = base_env.get_stage_id

        policy = MultiPolicyWrapper(policies, policy_assigner, split_observation)

        replay_buffers = DnCPathBuffer(
            num_buffers=config.n_policies,
            capacity_in_transitions=int(1e6),
            sampling_type=config.sampling_type,
        )

        # One worker per training env when a list of envs is given,
        # otherwise one per physical CPU core.
        if isinstance(train_envs, list):
            n_workers = len(train_envs)
        else:
            n_workers = psutil.cpu_count(logical=False)

        if config.ray:
            sampler_cls = RaySampler
        else:
            sampler_cls = LocalSampler
        sampler = sampler_cls(
            agents=policy,
            envs=train_envs,
            max_episode_length=env_spec.max_episode_length,
            worker_class=DefaultWorker,
            n_workers=n_workers,
        )

        visualizer = Visualizer(
            freq=config.vis_freq,
            num_videos=config.vis_num,
            imsize=(config.vis_width, config.vis_height),
            fps=config.vis_fps,
            format=config.vis_format,
        )

        agent_kwargs = dict(
            env_spec=env_spec,
            central_policy=central_policy,
            policy=policy,
            policies=policies,
            central_lr=config.lr,
            model_lr=config.lr,
            sampler=sampler,
            visualizer=visualizer,
            get_stage_id=get_stage_id,
            preproc_obs=split_observation,
            two_column=config.two_column,
            initial_kl_coeff=config.kl_coeff,
            gradient_steps_per_itr=agent_args["gradient_steps_per_itr"],
            max_episode_length_eval=env_spec.max_episode_length,
            replay_buffers=replay_buffers,
            min_buffer_size=agent_args["min_buffer_size"],
            target_update_tau=agent_args["target_update_tau"],
            discount=agent_args["discount"],
            buffer_batch_size=agent_args["buffer_batch_size"],
            reward_scale=agent_args["reward_scale"],
            steps_per_epoch=agent_args["steps_per_epoch"],
            eval_env=test_envs,
            num_evaluation_episodes=config.num_evaluation_episodes,
        )
        agent = DistralSQLDiscrete(**agent_kwargs)

        # GPU mode needs both an available CUDA device and a configured id.
        use_gpu = torch.cuda.is_available() and config.gpu is not None
        if use_gpu:
            set_gpu_mode(True, gpu_id=config.gpu)
        else:
            set_gpu_mode(False)
        agent.to()

        trainer.setup(algo=agent, env=train_envs)
        trainer.train(n_epochs=config.n_epochs, batch_size=config.batch_size)

    distral_sql_discrete_test(config=config, env_args=env_args, agent_args=agent_args)


### Injune ###


def run_dnc_ppo_test(
    env_name="InvertedPendulum-v2",
    seed=1,
    n_policies=1,
    kl_coeff=1.0,
    track_centroid=False,
    partition=None,
    name=None,
):
    """Build and launch a DnC-PPO experiment on a single Gym environment.

    The experiment is registered with ``wrap_experiment`` under the name
    ``dnc_ppo_<env_name>[_<name>]_<seed>`` and trained for 100 epochs with a
    batch size of 10000.

    Args:
        env_name (str): Gym environment id handed to ``GymEnv``.
        seed (int): Seed forwarded to ``set_seed``.
        n_policies (int): Number of local policies and value functions.
        kl_coeff (float): Forwarded to ``DnCPPO`` as ``kl_coeff``.
        track_centroid (bool): When true, an additional "Centroid" policy is
            built and handed to ``DnCPPO``; otherwise the centroid is None.
        partition (str): Partition mode forwarded to ``get_partition``. The
            default ``None`` is not a mode ``get_partition`` accepts, so a
            mode has to be supplied by the caller.
        name (str | None): Optional run label. Spaces are replaced by hyphens
            and the label is inserted into the experiment name; when None the
            label segment is left out.
    """
    # ``name`` defaults to None, so only add the label segment when one was
    # actually given instead of calling ``.replace`` on None.
    name_parts = ["dnc_ppo", env_name]
    if name is not None:
        name_parts.append(name.replace(" ", "-"))
    name_parts.append(str(seed))
    function_name = "_".join(name_parts)

    @wrap_experiment(
        archive_launch_repo=False, name=function_name, snapshot_mode="none"
    )
    def dnc_ppo_test(
        ctxt=None,
        env_name=env_name,
        seed=seed,
        n_policies=n_policies,
        kl_coeff=kl_coeff,
        track_centroid=track_centroid,
        partition=partition,
        name=name,
    ):
        """Train DnC-PPO on ``env_name``.

        Args:
            ctxt (garage.experiment.ExperimentContext): The experiment
                configuration used by Trainer to create the snapshotter.
            env_name (str): Gym environment id.
            seed (int): Used to seed the random number generator to produce
                determinism.
            n_policies (int): Number of local policies / value functions.
            kl_coeff (float): Forwarded to ``DnCPPO``.
            track_centroid (bool): Whether a centroid policy is created.
            partition (str): Partition mode for ``get_partition``.
            name (str | None): Run label; not used inside the experiment body.
        """
        set_seed(seed)

        env = GymEnv(env_name)

        trainer = Trainer(ctxt)

        policies = [
            GaussianMLPPolicy(
                env.spec,
                name="LocalPolicy{}".format(i),
                hidden_sizes=[64, 64],
                hidden_nonlinearity=torch.tanh,
                output_nonlinearity=None,
            )
            for i in range(n_policies)
        ]

        # The centroid policy only exists when centroid tracking is requested.
        centroid = None
        if track_centroid:
            centroid = GaussianMLPPolicy(
                env.spec,
                name="Centroid",
                hidden_sizes=[64, 64],
                hidden_nonlinearity=torch.tanh,
                output_nonlinearity=None,
            )

        policy_assigner = get_partition(partition, env_name, env, n_policies)

        # Single agent handed to the sampler; it wraps the local policies
        # together with the partition-based assigner.
        policy = MultiPolicyWrapper(policies, policy_assigner)

        value_functions = [
            GaussianMLPValueFunction(
                env_spec=env.spec,
                name="LocalValue{}".format(i),
                hidden_sizes=[32, 32],
                hidden_nonlinearity=torch.tanh,
                output_nonlinearity=None,
            )
            for i in range(n_policies)
        ]

        sampler = LocalSampler(
            agents=policy, envs=env, max_episode_length=env.spec.max_episode_length
        )

        algo = DnCPPO(
            env_spec=env.spec,
            policy=policy,
            policies=policies,
            centroid=centroid,
            value_functions=value_functions,
            sampler=sampler,
            kl_coeff=kl_coeff,
            track_centroid=track_centroid,
        )

        trainer.setup(algo, env)
        trainer.train(n_epochs=100, batch_size=10000)

    dnc_ppo_test(
        env_name=env_name,
        seed=seed,
        n_policies=n_policies,
        kl_coeff=kl_coeff,
        track_centroid=track_centroid,
        partition=partition,
        name=name,
    )


def run_ppo_test(env="InvertedPendulum-v2", seed=1, n_epochs=100, name=None):
    """Build and launch a plain PPO experiment on a single Gym environment.

    The experiment is registered with ``wrap_experiment`` under the name
    ``ppo_<env>[_<name>]_<seed>``.

    Args:
        env (str): Gym environment id handed to ``GymEnv``.
        seed (int): Seed forwarded to ``set_seed``.
        n_epochs (int): Number of training epochs.
        name (str | None): Optional run label. Spaces are replaced by hyphens
            and the label is inserted into the experiment name; when None the
            label segment is left out.
    """
    # ``name`` defaults to None, so only add the label segment when one was
    # actually given instead of calling ``.replace`` on None.
    name_parts = ["ppo", env]
    if name is not None:
        name_parts.append(name.replace(" ", "-"))
    name_parts.append(str(seed))
    function_name = "_".join(name_parts)

    @wrap_experiment(
        archive_launch_repo=False, name=function_name, snapshot_mode="none"
    )
    def ppo_test(ctxt=None, env=env, seed=seed, n_epochs=n_epochs, name=name):
        """Train PPO on the environment identified by ``env``.

        Args:
            ctxt (garage.experiment.ExperimentContext): The experiment
                configuration used by Trainer to create the snapshotter.
            env (str): Gym environment id.
            seed (int): Used to seed the random number generator to produce
                determinism.
            n_epochs (int): Number of training epochs.
            name (str | None): Run label; not used inside the experiment body.
        """
        set_seed(seed)
        # From here on ``env`` refers to the environment object, not its id.
        env = GymEnv(env)

        trainer = Trainer(ctxt)

        policy = GaussianMLPPolicy(
            env.spec,
            hidden_sizes=[128, 128],
            hidden_nonlinearity=torch.tanh,
            output_nonlinearity=None,
        )

        value_function = GaussianMLPValueFunction(
            env_spec=env.spec,
            hidden_sizes=[128, 128],
            hidden_nonlinearity=torch.tanh,
            output_nonlinearity=None,
        )

        sampler = LocalSampler(
            agents=policy, envs=env, max_episode_length=env.spec.max_episode_length
        )

        algo = PPO(
            env_spec=env.spec,
            policy=policy,
            value_function=value_function,
            sampler=sampler,
            gae_lambda=0.95,
        )

        trainer.setup(algo, env)
        trainer.train(n_epochs=n_epochs, batch_size=2048)

    ppo_test(env=env, seed=seed, n_epochs=n_epochs, name=name)


def run_dqn_test(env="CartPole-v0", seed=1, n_epochs=1000, name=None, gpu=None):
    """Build and launch a DQN experiment on a single Gym environment.

    The experiment is registered with ``wrap_experiment`` under the name
    ``dqn_<env>[_<name>]_<seed>``.

    Args:
        env (str): Gym environment id handed to ``GymEnv``.
        seed (int): Seed forwarded to ``set_seed``.
        n_epochs (int): Number of training epochs.
        name (str | None): Optional run label. Spaces are replaced by hyphens
            and the label is inserted into the experiment name; when None the
            label segment is left out.
        gpu (int | None): GPU id for ``set_gpu_mode``. GPU mode is enabled
            only when CUDA is available and this is not None.
    """
    # ``name`` defaults to None, so only add the label segment when one was
    # actually given instead of calling ``.replace`` on None.
    name_parts = ["dqn", env]
    if name is not None:
        name_parts.append(name.replace(" ", "-"))
    name_parts.append(str(seed))
    function_name = "_".join(name_parts)

    @wrap_experiment(
        archive_launch_repo=False, name=function_name, snapshot_mode="none"
    )
    def dqn_test(ctxt=None, env=env, seed=seed, n_epochs=n_epochs, name=name, gpu=gpu):
        """Train DQN on the environment identified by ``env``.

        Args:
            ctxt (garage.experiment.ExperimentContext): The experiment
                configuration used by Trainer to create the snapshotter.
            env (str): Gym environment id.
            seed (int): Used to seed the random number generator to produce
                determinism.
            n_epochs (int): Number of training epochs.
            name (str | None): Run label; not used inside the experiment body.
            gpu (int | None): GPU id, or None to stay on the CPU.
        """
        set_seed(seed)
        trainer = Trainer(ctxt)

        steps_per_epoch = 10
        sampler_batch_size = 512
        # Total number of sampled timesteps; passed to the epsilon-greedy
        # policy as ``total_timesteps``.
        num_timesteps = n_epochs * steps_per_epoch * sampler_batch_size
        # From here on ``env`` refers to the environment object, not its id.
        env = GymEnv(env)
        try:
            replay_buffer = PathBuffer(capacity_in_transitions=int(1e6))
            # TODO: revisit the hidden sizes (flagged as uncertain by the
            # original author).
            qf = DiscreteMLPQFunction(
                env_spec=env.spec,
                output_dim=env.spec.action_space.flat_dim,
                hidden_sizes=(8, 5),
            )
            policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)
            exploration_policy = EpsilonGreedyPolicy(
                env_spec=env.spec,
                policy=policy,
                total_timesteps=num_timesteps,
                max_epsilon=1.0,
                min_epsilon=0.01,
                decay_ratio=0.4,
            )
            sampler = LocalSampler(
                agents=exploration_policy,
                envs=env,
                max_episode_length=env.spec.max_episode_length,
                worker_class=FragmentWorker,
            )
            algo = DQN(
                env_spec=env.spec,
                policy=policy,
                qf=qf,
                exploration_policy=exploration_policy,
                replay_buffer=replay_buffer,
                sampler=sampler,
                steps_per_epoch=steps_per_epoch,
                qf_lr=5e-5,
                discount=0.9,
                min_buffer_size=int(1e4),
                n_train_steps=500,
                target_update_freq=30,
                buffer_batch_size=64,
            )

            if torch.cuda.is_available() and gpu is not None:
                set_gpu_mode(True, gpu_id=gpu)
            else:
                set_gpu_mode(False)
            algo.to()

            trainer.setup(algo, env)
            trainer.train(n_epochs=n_epochs, batch_size=sampler_batch_size)
        finally:
            # Release the environment even when setup or training fails.
            env.close()

    # ``gpu`` is passed explicitly rather than relying on the default captured
    # when ``dqn_test`` was defined.
    dqn_test(env=env, seed=seed, n_epochs=n_epochs, name=name, gpu=gpu)


def run_hdnc_sac_test(
    env_name="InvertedPendulum-v2",
    seed=1,
    n_epochs=1000,
    n_policies=4,
    kl_coeff=0.01,
    name=None,
    gpu=None,
    env_args=None,
):
    """Build and launch an HDnC-SAC experiment on a single environment.

    The experiment is registered with ``wrap_experiment`` under the name
    ``hdnc_sac_<env_name>[_<name>]_<seed>``.

    Args:
        env_name (str): Environment id handed to ``ArgsGymEnv``.
        seed (int): Seed forwarded to ``set_seed``.
        n_epochs (int): Number of training epochs.
        n_policies (int): Number of low-level policies, Q-function pairs and
            replay buffers; also the output size of the high-level Q-function.
        kl_coeff (float): Forwarded to ``HDnCSAC`` as ``kl_coeff``.
        name (str | None): Optional run label. Spaces are replaced by hyphens
            and the label is inserted into the experiment name; when None the
            label segment is left out.
        gpu (int | None): GPU id for ``set_gpu_mode``. GPU mode is enabled
            only when CUDA is available and this is not None.
        env_args (dict | None): Extra arguments handed to ``ArgsGymEnv``;
            None is treated as an empty dict.
    """
    # A fresh dict per call avoids sharing one mutable default across calls.
    if env_args is None:
        env_args = {}

    # ``name`` defaults to None, so only add the label segment when one was
    # actually given instead of calling ``.replace`` on None.
    name_parts = ["hdnc_sac", env_name]
    if name is not None:
        name_parts.append(name.replace(" ", "-"))
    name_parts.append(str(seed))
    function_name = "_".join(name_parts)

    @wrap_experiment(
        archive_launch_repo=False, name=function_name, snapshot_mode="none"
    )
    def hdnc_sac_test(
        ctxt=None,
        env_name=env_name,
        seed=seed,
        n_epochs=n_epochs,
        kl_coeff=kl_coeff,
        n_policies=n_policies,
        name=name,
        gpu=gpu,
        env_args=env_args,
    ):
        """Train HDnC-SAC on ``env_name``.

        Args:
            ctxt (garage.experiment.ExperimentContext): The experiment
                configuration used by Trainer to create the snapshotter.
            env_name (str): Environment id.
            seed (int): Used to seed the random number generator.
            n_epochs (int): Number of training epochs.
            kl_coeff (float): Forwarded to ``HDnCSAC``.
            n_policies (int): Number of low-level policies.
            name (str | None): Run label; not used inside the experiment body.
            gpu (int | None): GPU id, or None to stay on the CPU.
            env_args (dict): Extra arguments for ``ArgsGymEnv``.
        """
        set_seed(seed)
        trainer = Trainer(snapshot_config=ctxt)
        env = normalize(ArgsGymEnv(env_name, env_args))

        # High-level policy: argmax over a discrete Q-function with one output
        # per low-level policy.
        hl_qf = DiscreteMLPQFunction(
            env_spec=env.spec,
            output_dim=n_policies,
            hidden_sizes=[32, 32],
            hidden_nonlinearity=torch.relu,
        )
        hl_policy = DiscreteQFArgmaxPolicy(qf=hl_qf, env_spec=env.spec)
        ll_policies = [
            NamedTanhGaussianMLPPolicy(
                env_spec=env.spec,
                hidden_sizes=[256, 256],
                hidden_nonlinearity=torch.relu,
                output_nonlinearity=None,
                min_std=np.exp(-20.0),
                max_std=np.exp(2.0),
                name="LocalPolicy{}".format(i),
            )
            for i in range(n_policies)
        ]

        # Single agent handed to the sampler; it wraps the low-level policies
        # together with the high-level policy.
        policy = MultiPolicyWrapper(ll_policies, hl_policy)

        # Two independent Q-functions per low-level policy.
        qf1s = [
            ContinuousMLPQFunction(
                env_spec=env.spec,
                hidden_sizes=[256, 256],
                hidden_nonlinearity=torch.relu,
            )
            for _ in range(n_policies)
        ]

        qf2s = [
            ContinuousMLPQFunction(
                env_spec=env.spec,
                hidden_sizes=[256, 256],
                hidden_nonlinearity=torch.relu,
            )
            for _ in range(n_policies)
        ]

        replay_buffers = HDnCPathBuffer(
            num_buffers=n_policies, capacity_in_transitions=int(1e6)
        )

        sampler = LocalSampler(
            agents=policy,
            envs=env,
            max_episode_length=env.spec.max_episode_length,
            worker_class=DefaultWorker,
        )

        hdnc_sac = HDnCSAC(
            env_spec=env.spec,
            policy=policy,
            hl_policy=hl_policy,
            ll_policies=ll_policies,
            hl_qf=hl_qf,
            qf1s=qf1s,
            qf2s=qf2s,
            sampler=sampler,
            kl_coeff=kl_coeff,
            gradient_steps_per_itr=1000,
            max_episode_length_eval=env.spec.max_episode_length,
            replay_buffers=replay_buffers,
            min_buffer_size=1e4,
            target_update_tau=5e-3,
            discount=0.99,
            buffer_batch_size=256,
            reward_scale=1.0,
            steps_per_epoch=1,
            eval_env=env,
        )
        if torch.cuda.is_available() and gpu is not None:
            set_gpu_mode(True, gpu_id=gpu)
        else:
            set_gpu_mode(False)
        hdnc_sac.to()

        trainer.setup(algo=hdnc_sac, env=env)
        trainer.train(n_epochs=n_epochs, batch_size=1000)

    # ``gpu`` is passed explicitly rather than relying on the default captured
    # when ``hdnc_sac_test`` was defined.
    hdnc_sac_test(
        env_name=env_name,
        seed=seed,
        n_epochs=n_epochs,
        kl_coeff=kl_coeff,
        n_policies=n_policies,
        name=name,
        gpu=gpu,
        env_args=env_args,
    )


def get_partition(partition, env_name, env, n_policies):
    """Build the ``EnvPartitionPolicy`` selected by ``partition``.

    Args:
        partition (str): Partition mode. ``"random"`` builds a random
            partition; ``"goal_quadrant"``, ``"obstacle_id"``,
            ``"obstacle_orientation"``, ``"goal_cluster"`` and ``"task_id"``
            build a fixed partition from the matching attribute of ``env``.
        env_name (str): Unused; kept so existing call sites keep working.
        env: Environment providing ``spec`` and, for the fixed modes, the
            attribute that supplies the partition.
        n_policies (int): Number of partitions.

    Returns:
        EnvPartitionPolicy: The policy assigner for the requested mode.

    Raises:
        AssertionError: If ``n_policies`` is not a multiple of the value the
            mode requires, or ``env`` lacks the attribute the mode needs.
        NotImplementedError: If ``partition`` is not a supported mode
            (including ``None``).
    """
    # mode -> (required divisor of n_policies or None, env attribute that
    # supplies the partition)
    fixed_modes = {
        "goal_quadrant": (4, "get_goal_quadrant_partition"),
        "obstacle_id": (4, "get_obstacle_id_partition"),
        "obstacle_orientation": (2, "get_obstacle_orientation_partition"),
        "goal_cluster": (2, "get_goal_cluster"),
        "task_id": (None, "get_task_id"),
    }

    if partition == "random":
        return EnvPartitionPolicy(
            env_spec=env.spec, mode="random", num_partitions=n_policies
        )

    # The isinstance guard keeps unhashable values on the NotImplementedError
    # path instead of failing inside the dict lookup.
    if not isinstance(partition, str) or partition not in fixed_modes:
        raise NotImplementedError(
            "Unsupported partition mode {!r}; expected one of {}".format(
                partition, ["random"] + list(fixed_modes)
            )
        )

    divisor, accessor = fixed_modes[partition]

    # Raised explicitly (rather than via ``assert``) so the checks still run
    # under ``python -O``; AssertionError is kept as the exception type so the
    # contract seen by callers does not change.
    if divisor is not None and n_policies % divisor != 0:
        raise AssertionError(
            "partition {!r} requires n_policies to be a multiple of {}, "
            "got {}".format(partition, divisor, n_policies)
        )
    if not hasattr(env, accessor):
        raise AssertionError(
            "partition {!r} requires the environment to provide {!r}".format(
                partition, accessor
            )
        )

    return EnvPartitionPolicy(
        env_spec=env.spec,
        mode="fixed",
        num_partitions=n_policies,
        partitions=getattr(env, accessor),
    )

