#!/usr/bin/env python3
"""
from garage examples torch ppo/trpo pendulum
"""
import abc
import torch
import numpy as np
import ray
import os
import wandb
import time
import psutil

from garage import wrap_experiment
from garage.envs import normalize
from garage.experiment.deterministic import set_seed
from garage.replay_buffer import PathBuffer
from garage.torch.policies import (
    TanhGaussianMLPPolicy,
)
from garage.torch.q_functions import (
    ContinuousMLPQFunction,
    DiscreteMLPQFunction,
)
from garage.torch import set_gpu_mode
from garage.trainer import Trainer
from garage.sampler import (
    LocalSampler,
    RaySampler,
    FragmentWorker,
    VecWorker,
    DefaultWorker,
)

from environments.gym_env import ArgsGymEnv
from environments.wrappers import DiscreteWrapper

from learning.policies.multi_policy_wrapper import MultiPolicyWrapper
from learning.policies.env_partition_policy import EnvPartitionPolicy
from learning.policies.two_column_wrapper import TwoColumnWrapper, DiscreteTwoColumnWrapper
from learning.algorithms import (
    SAC,
    DnCSAC,
    DnCCEM,
    DistralSAC,
    DistralSQLDiscrete,
)
from learning.policies import (
    NamedTanhGaussianMLPPolicy,
    CategoricalMLPPolicy,
)
from learning.utils.path_buffers import DnCPathBuffer
from learning.utils.visualizer import Visualizer
import environments


class Experiment(abc.ABC):
    """Base class that assembles and runs one training experiment.

    ``run`` invokes the ``set_*`` / ``setup_*`` hooks in a fixed order and
    finally ``train_agent``.  Subclasses must implement ``setup_models``
    (which has to set ``self.policy``, read by ``setup_sampler``) and
    ``setup_agent`` (which has to set ``self.agent``, read by ``set_gpu`` and
    ``train_agent``).

    Class attributes:
        NAME: Prefix of the experiment name built by ``function_name``.
        SAMPLER_WORKER: Worker class handed to the sampler.
    """

    NAME = ""
    SAMPLER_WORKER = DefaultWorker

    def __init__(self, config, env_args, agent_args):
        """Store the run configuration.

        Args:
            config: Object with run settings, read through attribute access
                (``config.env``, ``config.seed``, ``config.wandb``, ...).
            env_args: Mapping passed to ``ArgsGymEnv`` and logged to wandb.
            agent_args: Mapping of algorithm hyper-parameters; only stored
                here, subclasses read it in ``setup_agent``.
        """
        self.config = config
        self.env_args = env_args
        self.agent_args = agent_args

    @property
    def function_name(self):
        """Name given to the wrapped experiment function.

        Joins ``NAME``, the environment name, the run name (spaces replaced
        by dashes) and the seed with underscores.
        """
        return "_".join((
            self.NAME,
            self.config.env,
            self.config.name.replace(" ", "-"),
            str(self.config.seed),
        ))

    @property
    def wandb_args(self):
        """Extra key/value pairs merged into the wandb run config."""
        return self.env_args

    def set_seed(self):
        """Seed the run with ``config.seed`` via garage's ``set_seed``."""
        set_seed(self.config.seed)

    def setup_trainer(self, ctxt):
        """Create the garage ``Trainer`` from the snapshot context."""
        self.trainer = Trainer(snapshot_config=ctxt)

    def setup_wandb(self):
        """Initialise the wandb run.

        Must be called after ``setup_trainer``: the wandb directory is taken
        from the trainer's snapshot directory.
        """
        # Keys listed here are dropped from the logged config.
        exclude = ["device"]

        if not self.config.wandb:
            # wandb is still initialised, but switched to its "dryrun" mode.
            os.environ["WANDB_MODE"] = "dryrun"

        # Entries from wandb_args override same-named config attributes.
        all_configs = {
            **{k: v for k, v in self.config.__dict__.items() if k not in exclude},
            **{k: v for k, v in self.wandb_args.items() if k not in exclude},
        }
        wandb.init(
            # NOTE: unlike function_name, the run name starts with
            # config.experiment rather than NAME.
            name="_".join((
                self.config.experiment,
                self.config.env,
                self.config.name.replace(" ", "-"),
                str(self.config.seed),
            )),
            project=self.config.wandb_project,
            config=all_configs,
            dir=self.trainer._snapshotter.snapshot_dir,
            entity=self.config.wandb_entity,
            notes=self.config.notes,
        )

    def setup_envs(self):
        """Create the base environment and derive train/test environments.

        Sets ``base_env``, ``env_spec``, ``train_envs``, ``test_envs``,
        ``split_observation`` and ``get_stage_id``.
        """
        self.base_env = ArgsGymEnv(self.config.env, self.env_args)
        self.base_env = self.wrap_env(self.base_env)
        self.env_spec = self.base_env.spec

        # Environments may provide their own train/test sets; otherwise the
        # normalized base environment is used for both.
        if hasattr(self.base_env, "get_train_envs"):
            self.train_envs = self.base_env.get_train_envs()
        else:
            self.train_envs = normalize(self.base_env)

        if hasattr(self.base_env, "get_test_envs"):
            self.test_envs = self.base_env.get_test_envs()
        else:
            self.test_envs = normalize(self.base_env)

        # Optional observation pre-processing hook; None when the
        # environment does not define one.
        self.split_observation = getattr(
            self.base_env, "split_observation", None
        )
        if self.config.stagewise:
            assert hasattr(self.base_env, "get_stage_id"), (
                "config.stagewise requires an environment with get_stage_id"
            )
            self.get_stage_id = self.base_env.get_stage_id
        else:
            self.get_stage_id = None

    def wrap_env(self, env):
        """Hook for subclasses to wrap the base environment; identity here."""
        return env

    @abc.abstractmethod
    def setup_models(self):
        """Create the networks; must set ``self.policy``."""
        raise NotImplementedError()

    def setup_replay_buffer(self):
        """Create one ``DnCPathBuffer`` holding ``config.n_policies`` buffers."""
        self.replay_buffers = DnCPathBuffer(
            num_buffers=self.config.n_policies,
            capacity_in_transitions=int(1e6),
            sampling_type=self.config.sampling_type,
        )

    def setup_sampler(self):
        """Create the sampler (Ray-based when ``config.ray`` is truthy).

        One worker per training environment is used when ``train_envs`` is a
        list; otherwise the number of physical CPU cores is used.
        """
        if isinstance(self.train_envs, list):
            n_workers = len(self.train_envs)
        else:
            # psutil.cpu_count(logical=False) returns None when the physical
            # core count cannot be determined; fall back to the logical count
            # and finally to a single worker instead of passing None on.
            n_workers = (
                psutil.cpu_count(logical=False)
                or psutil.cpu_count(logical=True)
                or 1
            )
        sampler_cls = RaySampler if self.config.ray else LocalSampler
        self.sampler = sampler_cls(
            agents=self.policy,
            envs=self.train_envs,
            max_episode_length=self.env_spec.max_episode_length,
            worker_class=self.SAMPLER_WORKER,
            n_workers=n_workers,
        )

    def setup_visualizer(self):
        """Create the ``Visualizer`` from the ``config.vis_*`` settings."""
        self.visualizer = Visualizer(
            freq=self.config.vis_freq,
            num_videos=self.config.vis_num,
            imsize=(self.config.vis_width, self.config.vis_height),
            fps=self.config.vis_fps,
            format=self.config.vis_format,
        )

    @abc.abstractmethod
    def setup_agent(self):
        """Create the learning algorithm; must set ``self.agent``."""
        raise NotImplementedError()

    def set_gpu(self):
        """Enable GPU mode when CUDA is available and ``config.gpu`` is set.

        Afterwards ``self.agent.to()`` is called without arguments
        (presumably it moves the agent's networks to the device selected by
        ``set_gpu_mode`` -- confirm against the agent implementation).
        """
        if torch.cuda.is_available() and self.config.gpu is not None:
            set_gpu_mode(True, gpu_id=self.config.gpu)
        else:
            set_gpu_mode(False)
        self.agent.to()

    def train_agent(self):
        """Register agent and environments with the trainer and train."""
        self.trainer.setup(algo=self.agent, env=self.train_envs)
        self.trainer.train(
            n_epochs=self.config.n_epochs, batch_size=self.config.batch_size
        )

    def run(self, snapshot_mode="none"):
        """Run the whole experiment inside garage's ``wrap_experiment``.

        Args:
            snapshot_mode: Forwarded to ``wrap_experiment``.
        """
        @wrap_experiment(
            archive_launch_repo=False,
            name=self.function_name,
            snapshot_mode=snapshot_mode,
        )
        def experiment(ctxt=None):
            # The order matters: later steps read attributes set by earlier
            # ones (e.g. wandb needs the trainer, the sampler needs the
            # policy, the agent needs sampler/visualizer/replay buffers).
            self.set_seed()
            self.setup_trainer(ctxt)
            self.setup_wandb()
            self.setup_envs()
            self.setup_models()
            self.setup_replay_buffer()
            self.setup_sampler()
            self.setup_visualizer()
            self.setup_agent()
            self.set_gpu()
            self.train_agent()
        experiment()


class DnCCEMExperiment(Experiment):
    """Experiment that trains the ``DnCCEM`` agent.

    Besides building the models and the agent, the overridden hooks store
    ``time.time()`` stamps (``tic1`` .. ``tic6``) around the setup stages;
    ``train_agent`` prints the differences before training starts.
    """

    NAME = "dnc_cem"

    def __init__(self, config, env_args, agent_args, cem_configs):
        """Store the base configuration plus the CEM-specific settings.

        Args:
            config: See ``Experiment``.
            env_args: See ``Experiment``.
            agent_args: See ``Experiment``.
            cem_configs: Mapping forwarded to ``DnCCEM`` and logged to wandb.
        """
        super().__init__(config, env_args, agent_args)
        self.cem_configs = cem_configs

    @property
    def wandb_args(self):
        """Environment arguments overlaid with the CEM settings."""
        merged = {}
        merged.update(self.env_args)
        merged.update(self.cem_configs)
        return merged

    def setup_models(self):
        """Build local policies, CEM policies, the wrapper and Q-functions.

        The construction order (policies, CEM policies, partition, wrapper,
        qf1s, qf2s) is kept fixed.
        """
        self.tic1 = time.time()
        count = self.config.n_policies

        def make_policy(index):
            # One tanh-Gaussian MLP policy named after its index.
            return NamedTanhGaussianMLPPolicy(
                env_spec=self.env_spec,
                hidden_sizes=self.config.hidden_sizes,
                hidden_nonlinearity=torch.relu,
                output_nonlinearity=None,
                min_std=np.exp(-20.0),
                max_std=np.exp(2.0),
                name="LocalPolicy{}".format(index),
            )

        def make_q_function():
            return ContinuousMLPQFunction(
                env_spec=self.env_spec,
                hidden_sizes=self.config.hidden_sizes,
                hidden_nonlinearity=torch.relu,
            )

        self.policies = [make_policy(k) for k in range(count)]
        # NOTE(review): the CEM policies reuse the "LocalPolicy{}" names of
        # the local policies -- confirm this is intended.
        self.cem_policies = [make_policy(k) for k in range(count)]

        assigner = get_partition(
            self.config.partition,
            self.config.env,
            self.base_env,
            self.config.n_policies,
        )
        self.policy = MultiPolicyWrapper(
            self.policies, assigner, self.split_observation
        )

        self.qf1s = [make_q_function() for _ in range(count)]
        self.qf2s = [make_q_function() for _ in range(count)]

        self.tic2 = time.time()

    def setup_replay_buffer(self):
        """Create the replay buffers and stamp ``tic3``."""
        super().setup_replay_buffer()
        self.tic3 = time.time()

    def setup_sampler(self):
        """Create the sampler and stamp ``tic4``."""
        super().setup_sampler()
        self.tic4 = time.time()

    def setup_agent(self):
        """Instantiate ``DnCCEM`` and stamp ``tic5``."""
        cfg = self.config
        spec = self.env_spec
        # Hyper-parameters copied verbatim from agent_args (KeyError if one
        # is missing).
        from_agent_args = {
            key: self.agent_args[key]
            for key in (
                "gradient_steps_per_itr",
                "min_buffer_size",
                "target_update_tau",
                "discount",
                "buffer_batch_size",
                "reward_scale",
                "steps_per_epoch",
            )
        }
        self.agent = DnCCEM(
            env_spec=spec,
            cem_policies=self.cem_policies,
            cem_configs=self.cem_configs,
            policy=self.policy,
            policies=self.policies,
            qf1s=self.qf1s,
            qf2s=self.qf2s,
            replay_buffers=self.replay_buffers,
            sampler=self.sampler,
            visualizer=self.visualizer,
            eval_env=self.test_envs,
            get_stage_id=self.get_stage_id,
            preproc_obs=self.split_observation,
            policy_lr=cfg.lr,
            qf_lr=cfg.lr,
            initial_kl_coeff=cfg.kl_coeff,
            sampling_type=cfg.sampling_type,
            num_evaluation_episodes=cfg.num_evaluation_episodes,
            evaluation_frequency=cfg.evaluation_frequency,
            max_episode_length_eval=spec.max_episode_length,
            **from_agent_args,
        )
        self.tic5 = time.time()

    def set_gpu(self):
        """Select the device and stamp ``tic6``."""
        super().set_gpu()
        self.tic6 = time.time()

    def train_agent(self):
        """Print the duration of each setup stage, then train."""
        stamps = (
            self.tic1, self.tic2, self.tic3, self.tic4, self.tic5, self.tic6
        )
        # Differences between consecutive stamps, in stage order.
        elapsed = [later - earlier for earlier, later in zip(stamps, stamps[1:])]
        template = "Initializing models: {}, Making replay buffers: {}, Making samplers: {}, Making DnC trainer: {}, Putting to GPU: {}"
        print(template.format(*elapsed))
        super().train_agent()


class DnCSACExperiment(Experiment):
    """Experiment that trains the ``DnCSAC`` agent.

    Builds ``config.n_policies`` local policies, one pair of Q-functions per
    policy, and a ``MultiPolicyWrapper`` that is used as the sampling policy.
    """

    NAME = "dnc_sac"

    def setup_models(self):
        """Build local policies, the wrapper policy and the Q-functions.

        The construction order (policies, partition, wrapper, qf1s, qf2s) is
        kept fixed.
        """
        count = self.config.n_policies

        def make_q_function():
            return ContinuousMLPQFunction(
                env_spec=self.env_spec,
                hidden_sizes=self.config.hidden_sizes,
                hidden_nonlinearity=torch.relu,
            )

        local_policies = []
        for index in range(count):
            local_policies.append(
                NamedTanhGaussianMLPPolicy(
                    env_spec=self.env_spec,
                    hidden_sizes=self.config.hidden_sizes,
                    hidden_nonlinearity=torch.relu,
                    output_nonlinearity=None,
                    min_std=np.exp(-20.0),
                    max_std=np.exp(2.0),
                    name="LocalPolicy{}".format(index),
                )
            )
        self.policies = local_policies

        assigner = get_partition(
            self.config.partition,
            self.config.env,
            self.base_env,
            self.config.n_policies,
        )
        self.policy = MultiPolicyWrapper(
            self.policies, assigner, self.split_observation
        )

        self.qf1s = [make_q_function() for _ in range(count)]
        self.qf2s = [make_q_function() for _ in range(count)]

    def setup_agent(self):
        """Instantiate ``DnCSAC`` from the prepared components."""
        cfg = self.config
        spec = self.env_spec
        # Hyper-parameters copied verbatim from agent_args (KeyError if one
        # is missing).
        from_agent_args = {
            key: self.agent_args[key]
            for key in (
                "gradient_steps_per_itr",
                "min_buffer_size",
                "target_update_tau",
                "discount",
                "buffer_batch_size",
                "reward_scale",
                "steps_per_epoch",
            )
        }
        self.agent = DnCSAC(
            env_spec=spec,
            policy=self.policy,
            policies=self.policies,
            qf1s=self.qf1s,
            qf2s=self.qf2s,
            replay_buffers=self.replay_buffers,
            sampler=self.sampler,
            visualizer=self.visualizer,
            eval_env=self.test_envs,
            get_stage_id=self.get_stage_id,
            preproc_obs=self.split_observation,
            policy_lr=cfg.lr,
            qf_lr=cfg.lr,
            initial_kl_coeff=cfg.kl_coeff,
            sampling_type=cfg.sampling_type,
            num_evaluation_episodes=cfg.num_evaluation_episodes,
            max_episode_length_eval=spec.max_episode_length,
            **from_agent_args,
        )


class SACExperiment(Experiment):
    """Experiment that trains a single-policy ``SAC`` agent.

    Uses ``FragmentWorker`` for sampling and a single ``PathBuffer`` (stored
    as ``self.replay_buffer``) instead of the base class's ``DnCPathBuffer``.
    """

    NAME = "sac"
    SAMPLER_WORKER = FragmentWorker

    def setup_models(self):
        """Build the tanh-Gaussian policy and the two Q-functions."""
        spec = self.env_spec
        widths = self.config.hidden_sizes

        self.policy = TanhGaussianMLPPolicy(
            env_spec=spec,
            hidden_sizes=widths,
            hidden_nonlinearity=torch.relu,
            output_nonlinearity=None,
            min_std=np.exp(-20.0),
            max_std=np.exp(2.0),
        )

        # Two Q-functions with identical architecture, built in order.
        q_kwargs = dict(
            env_spec=spec,
            hidden_sizes=widths,
            hidden_nonlinearity=torch.relu,
        )
        self.qf1 = ContinuousMLPQFunction(**q_kwargs)
        self.qf2 = ContinuousMLPQFunction(**q_kwargs)

    def setup_replay_buffer(self):
        """Create a single ``PathBuffer`` with room for 1e6 transitions."""
        capacity = int(1e6)
        self.replay_buffer = PathBuffer(capacity_in_transitions=capacity)

    def setup_agent(self):
        """Instantiate ``SAC`` from the prepared components."""
        cfg = self.config
        spec = self.env_spec

        # Only the "mtsac" experiment takes the task count from the
        # environment; every other experiment uses a single task.
        num_tasks = 1
        if cfg.experiment == "mtsac":
            num_tasks = getattr(self.base_env, "num_tasks", 1)

        # Hyper-parameters copied verbatim from agent_args (KeyError if one
        # is missing).
        from_agent_args = {
            key: self.agent_args[key]
            for key in (
                "gradient_steps_per_itr",
                "min_buffer_size",
                "target_update_tau",
                "discount",
                "buffer_batch_size",
                "reward_scale",
                "steps_per_epoch",
            )
        }
        self.agent = SAC(
            env_spec=spec,
            policy=self.policy,
            qf1=self.qf1,
            qf2=self.qf2,
            replay_buffer=self.replay_buffer,
            sampler=self.sampler,
            visualizer=self.visualizer,
            eval_env=self.test_envs,
            num_tasks=num_tasks,
            policy_lr=cfg.lr,
            qf_lr=cfg.lr,
            grad_clip=cfg.grad_clip,
            num_evaluation_episodes=cfg.num_evaluation_episodes,
            max_episode_length_eval=spec.max_episode_length,
            **from_agent_args,
        )


class SACDiscreteExperiment(SACExperiment):
    """SAC experiment on a discretised action space.

    Wraps the base environment in ``DiscreteWrapper`` and replaces the
    networks with a categorical policy and discrete Q-functions.  Replay
    buffer and agent construction are inherited from ``SACExperiment``.
    """

    NAME = "sac_discrete"
    SAMPLER_WORKER = FragmentWorker

    def wrap_env(self, env):
        """Wrap ``env`` in a ``DiscreteWrapper``.

        The number of discrete actions and the discrete-to-continuous mapping
        are both taken from ``env`` itself.
        """
        action_count = env.get_num_discrete_actions()
        to_continuous = env.disc2cont
        return DiscreteWrapper(
            env=env,
            n_actions=action_count,
            disc2cont=to_continuous,
        )

    def setup_models(self):
        """Build the categorical policy and the two discrete Q-functions."""
        spec = self.env_spec
        widths = self.config.hidden_sizes

        self.policy = CategoricalMLPPolicy(
            env_spec=spec,
            output_dim=spec.action_space.flat_dim,
            hidden_sizes=widths,
            hidden_nonlinearity=torch.relu,
            output_nonlinearity=None,
        )

        # Two Q-functions with identical architecture, built in order.
        q_kwargs = dict(
            env_spec=spec,
            hidden_sizes=widths,
            hidden_nonlinearity=torch.relu,
        )
        self.qf1 = DiscreteMLPQFunction(**q_kwargs)
        self.qf2 = DiscreteMLPQFunction(**q_kwargs)


class DistralSACExperiment(Experiment):
    """Experiment that trains the ``DistralSAC`` agent.

    Builds one central policy, ``config.n_policies`` local policies with a
    pair of Q-functions each, and (only when ``config.two_column`` is truthy)
    a pair of central Q-functions.
    """

    NAME = "distral_sac"

    def setup_models(self):
        """Build central/local policies, the wrapper and all Q-functions.

        The construction order (central policy, local policies, partition,
        wrapper, central Q-functions, qf1s, qf2s) is kept fixed.
        """
        count = self.config.n_policies

        def make_policy(label):
            return NamedTanhGaussianMLPPolicy(
                env_spec=self.env_spec,
                hidden_sizes=self.config.hidden_sizes,
                hidden_nonlinearity=torch.relu,
                output_nonlinearity=None,
                min_std=np.exp(-20.0),
                max_std=np.exp(2.0),
                name=label,
            )

        def make_q_function():
            return ContinuousMLPQFunction(
                env_spec=self.env_spec,
                hidden_sizes=self.config.hidden_sizes,
                hidden_nonlinearity=torch.relu,
            )

        self.central_policy = make_policy("CentralPolicy")
        self.policies = [
            make_policy("LocalPolicy{}".format(index))
            for index in range(count)
        ]

        # Disabled: wrapping every local policy in a TwoColumnWrapper.
        # if self.config.two_column:
        #     kl_coeffs = (
        #         self.config.kl_coeffs if isinstance(self.config.kl_coeffs, list)
        #         else [self.config.kl_coeffs] * self.config.n_policies
        #     )
        #     self.policies = [
        #         TwoColumnWrapper(
        #             policy,
        #             self.central_policy,
        #             alpha=1,
        #             beta=kl_coeff,
        #         )
        #         for policy, kl_coeff in zip(self.policies, kl_coeffs)
        #     ]

        assigner = get_partition(
            self.config.partition,
            self.config.env,
            self.base_env,
            self.config.n_policies,
        )
        self.policy = MultiPolicyWrapper(
            self.policies, assigner, self.split_observation
        )

        # Central Q-functions exist only in two-column mode.
        if not self.config.two_column:
            self.central_qf1 = None
            self.central_qf2 = None
        else:
            self.central_qf1 = make_q_function()
            self.central_qf2 = make_q_function()

        self.qf1s = [make_q_function() for _ in range(count)]
        self.qf2s = [make_q_function() for _ in range(count)]

    def setup_agent(self):
        """Instantiate ``DistralSAC`` from the prepared components."""
        cfg = self.config
        spec = self.env_spec
        # Hyper-parameters copied verbatim from agent_args (KeyError if one
        # is missing).
        from_agent_args = {
            key: self.agent_args[key]
            for key in (
                "gradient_steps_per_itr",
                "min_buffer_size",
                "target_update_tau",
                "discount",
                "buffer_batch_size",
                "reward_scale",
                "steps_per_epoch",
            )
        }
        self.agent = DistralSAC(
            env_spec=spec,
            central_policy=self.central_policy,
            policy=self.policy,
            policies=self.policies,
            central_qf1=self.central_qf1,
            central_qf2=self.central_qf2,
            qf1s=self.qf1s,
            qf2s=self.qf2s,
            replay_buffers=self.replay_buffers,
            sampler=self.sampler,
            visualizer=self.visualizer,
            eval_env=self.test_envs,
            get_stage_id=self.get_stage_id,
            preproc_obs=self.split_observation,
            policy_lr=cfg.lr,
            qf_lr=cfg.lr,
            two_column=cfg.two_column,
            initial_kl_coeff=cfg.kl_coeff,
            num_evaluation_episodes=cfg.num_evaluation_episodes,
            max_episode_length_eval=spec.max_episode_length,
            **from_agent_args,
        )


class DistralSQLDiscreteExperiment(Experiment):
    """Experiment that trains the ``DistralSQLDiscrete`` agent.

    Builds one central and ``config.n_policies`` local categorical policies;
    in two-column mode every local policy is wrapped together with the
    central policy in a ``DiscreteTwoColumnWrapper``.
    """

    NAME = "distral_sql_discrete"

    def setup_models(self):
        """Build the central/local policies and the wrapper policy.

        The construction order (central policy, local policies, optional
        two-column wrappers, partition, wrapper) is kept fixed.
        """
        count = self.config.n_policies

        def make_policy(label):
            return CategoricalMLPPolicy(
                env_spec=self.env_spec,
                output_dim=self.env_spec.action_space.flat_dim,
                hidden_sizes=self.config.hidden_sizes,
                hidden_nonlinearity=torch.relu,
                output_nonlinearity=None,
                name=label,
            )

        # Policies are equiv to softmax(Q)
        # i.e., Q(.|s) = Policy(s).logits
        self.central_policy = make_policy("CentralPolicy")
        self.policies = [
            make_policy("LocalPolicy{}".format(index))
            for index in range(count)
        ]

        # Disabled: separate discrete Q-functions per policy.
        # self.qfs = [
        #     DiscreteMLPQFunction(
        #         env_spec=self.env_spec,
        #         output_dim=self.env_spec.action_space.flat_dim,
        #         hidden_sizes=self.config.hidden_sizes,
        #         hidden_nonlinearity=torch.relu,
        #         output_nonlinearity=None,
        #     )
        #     for i in range(self.config.n_policies)
        # ]

        if self.config.two_column:
            # NOTE(review): this reads config.kl_coeffs (plural) while
            # setup_agent reads config.kl_coeff -- confirm both exist.
            betas = self.config.kl_coeffs
            if not isinstance(betas, list):
                # A single coefficient is shared by every policy.
                betas = [betas] * self.config.n_policies
            wrapped = []
            for index in range(self.config.n_policies):
                wrapped.append(
                    DiscreteTwoColumnWrapper(
                        central_policy=self.central_policy,
                        policy=self.policies[index],
                        alpha=1,
                        beta=betas[index],
                    )
                )
            self.policies = wrapped

        assigner = get_partition(
            self.config.partition,
            self.config.env,
            self.base_env,
            self.config.n_policies,
        )
        self.policy = MultiPolicyWrapper(
            self.policies, assigner, self.split_observation
        )

    def setup_agent(self):
        """Instantiate ``DistralSQLDiscrete`` from the prepared components."""
        cfg = self.config
        spec = self.env_spec
        # Hyper-parameters copied verbatim from agent_args (KeyError if one
        # is missing).
        from_agent_args = {
            key: self.agent_args[key]
            for key in (
                "gradient_steps_per_itr",
                "min_buffer_size",
                "target_update_tau",
                "discount",
                "buffer_batch_size",
                "reward_scale",
                "steps_per_epoch",
            )
        }
        self.agent = DistralSQLDiscrete(
            env_spec=spec,
            central_policy=self.central_policy,
            policy=self.policy,
            policies=self.policies,
            replay_buffers=self.replay_buffers,
            sampler=self.sampler,
            visualizer=self.visualizer,
            eval_env=self.test_envs,
            get_stage_id=self.get_stage_id,
            preproc_obs=self.split_observation,
            central_lr=cfg.lr,
            model_lr=cfg.lr,
            two_column=cfg.two_column,
            initial_kl_coeff=cfg.kl_coeff,
            num_evaluation_episodes=cfg.num_evaluation_episodes,
            max_episode_length_eval=spec.max_episode_length,
            **from_agent_args,
        )


def get_partition(partition, env_name, env, n_policies):
    """Build the ``EnvPartitionPolicy`` selected by ``partition``.

    Args:
        partition: Partition scheme name: ``"random"``, ``"goal_quadrant"``,
            ``"obstacle_id"``, ``"obstacle_orientation"``, ``"goal_cluster"``
            or ``"task_id"``.
        env_name: Unused; kept for interface compatibility.
        env: Environment providing ``spec`` and, for the fixed schemes, the
            attribute that supplies the partition.
        n_policies: Number of partitions (one per policy).

    Returns:
        The configured ``EnvPartitionPolicy``.

    Raises:
        NotImplementedError: If ``partition`` is not a known scheme.
        AssertionError: If ``n_policies`` is not a multiple of the divisor
            required by the scheme, or ``env`` lacks the required attribute.
    """
    if partition == "random":
        return EnvPartitionPolicy(
            env_spec=env.spec, mode="random", num_partitions=n_policies
        )

    # scheme name -> (env attribute supplying the partition,
    #                 required divisor of n_policies, or None for no check)
    fixed_partitions = {
        "goal_quadrant": ("get_goal_quadrant_partition", 4),
        "obstacle_id": ("get_obstacle_id_partition", 4),
        "obstacle_orientation": ("get_obstacle_orientation_partition", 2),
        "goal_cluster": ("get_goal_cluster", 2),
        "task_id": ("get_task_id", None),
    }

    # The isinstance guard keeps unhashable values on the "unknown" path
    # instead of failing inside the dict lookup.
    rule = (
        fixed_partitions.get(partition) if isinstance(partition, str) else None
    )
    if rule is None:
        raise NotImplementedError(
            "Unknown partition {!r}; expected one of: {}".format(
                partition, ", ".join(["random", *fixed_partitions])
            )
        )

    attr_name, divisor = rule
    # AssertionError is raised explicitly (instead of via `assert`) so the
    # exception type stays the same while the checks survive `python -O`.
    if divisor is not None and not n_policies % divisor == 0:
        raise AssertionError(
            "partition {!r} requires n_policies to be a multiple of {}, "
            "got {}".format(partition, divisor, n_policies)
        )
    if not hasattr(env, attr_name):
        raise AssertionError(
            "partition {!r} requires the environment to provide {!r}".format(
                partition, attr_name
            )
        )

    return EnvPartitionPolicy(
        env_spec=env.spec,
        mode="fixed",
        num_partitions=n_policies,
        partitions=getattr(env, attr_name),
    )
