"""Divide-and-Conquer RL with PPO Policies"""
from dowel import tabular
import torch
import warnings
import numpy as np

from garage import EpisodeBatch
from garage import log_performance
from garage.np.algos import RLAlgorithm
from garage.sampler import LocalSampler
from utils import np_to_torch

from learning.policies.multi_policy_wrapper import MultiPolicyWrapper


class DnC(RLAlgorithm):
    """Divide-and-Conquer RL over an ensemble of local policies.

    One policy, one value function and one inner algorithm instance are
    built per ensemble member. Episodes are collected through a
    ``MultiPolicyWrapper`` around all local policies, split by the policy
    that gathered them, and each inner algorithm is then trained on its own
    policy's episodes only.

    The mutual-KL coupling between policies is currently a placeholder that
    is always zero, so training is equivalent to running the inner
    algorithms independently.

    Args:
        env: Environment to train on. ``env.spec`` is passed to every
            policy, value function and inner algorithm.
        policy_assigner: Object handed to ``MultiPolicyWrapper`` together
            with the local policies (presumably decides which local policy
            acts -- confirm against ``MultiPolicyWrapper``).
        n_policies (int): Number of local policies in the ensemble.
        policy_class: Callable building a policy; called with ``env_spec``,
            ``name`` and ``**policy_kwargs``.
        policy_kwargs (dict): Extra keyword arguments for ``policy_class``.
        value_function_class: Callable building a value function; called
            with ``env_spec``, ``name`` and ``**value_function_kwargs``.
        value_function_kwargs (dict): Extra keyword arguments for
            ``value_function_class``.
        algorithm_class: Callable building the per-policy algorithm; called
            with ``env_spec``, ``policy``, ``value_function``,
            ``sampler=None`` and ``**algorithm_kwargs``. Instances must
            provide ``_set_auxiliary_obj`` and ``_train_once``.
        algorithm_kwargs (dict): Extra keyword arguments for
            ``algorithm_class``.
        steps_per_epoch (int): Number of sample/train iterations per epoch.
        discount (float): Discount factor forwarded to ``log_performance``.
    """

    def __init__(
        self,
        env,
        policy_assigner,
        n_policies,
        policy_class,
        policy_kwargs,
        value_function_class,
        value_function_kwargs,
        algorithm_class,
        algorithm_kwargs,
        steps_per_epoch=1,
        discount=0.99,
    ):
        super().__init__()

        self.env = env
        self.n_policies = n_policies
        self._steps_per_epoch = steps_per_epoch
        self._discount = discount
        self.max_episode_length = env.spec.max_episode_length

        ### Initialize Policy Ensemble ###

        self.policies = [
            policy_class(
                env_spec=self.env.spec, name="LocalPolicy{}".format(i), **policy_kwargs
            )
            for i in range(self.n_policies)
        ]
        self.value_fns = [
            value_function_class(
                env_spec=self.env.spec,
                name="LocalValue{}".format(i),
                **value_function_kwargs
            )
            for i in range(self.n_policies)
        ]
        # The inner algorithms never sample on their own (sampler=None);
        # episodes are gathered once through the sampler built below and
        # then routed to the matching algorithm in `_train_once`.
        self.algorithms = [
            algorithm_class(
                env_spec=self.env.spec,
                policy=policy,
                value_function=value_function,
                sampler=None,
                **algorithm_kwargs
            )
            for (policy, value_function) in zip(self.policies, self.value_fns)
        ]

        ### Initialize HL Policy ###
        self.policy_assigner = policy_assigner

        self.policy = MultiPolicyWrapper(self.policies, self.policy_assigner)
        # Use the public spec value stored above rather than the private
        # `env._max_episode_length`, which is not part of the environment
        # interface and may be unavailable on wrapped environments.
        self._sampler = LocalSampler(
            agents=self.policy,
            envs=self.env,
            max_episode_length=self.max_episode_length,
        )

    def train(self, trainer):
        """Run the full training loop.

        Args:
            trainer: Object providing ``step_epochs()``,
                ``obtain_episodes(itr)`` and a mutable ``step_itr`` counter.

        Returns:
            list: One entry per training iteration, each being the value
                returned by ``_train_once`` for that iteration.
        """
        infos = []

        for _ in trainer.step_epochs():
            for _ in range(self._steps_per_epoch):
                samples = trainer.obtain_episodes(trainer.step_itr)
                info = self._train_once(trainer.step_itr, samples)
                trainer.step_itr += 1
                infos.append(info)

        return infos

    def _train_once(self, itr, all_samples):
        """Train every local policy once on the episodes it collected.

        Args:
            itr (int): Current iteration number.
            all_samples: Batch of episodes gathered by the wrapped
                multi-policy; must support ``split()`` and be accepted by
                ``log_performance``.

        Returns:
            The value returned by ``log_performance`` for ``all_samples``.
        """
        policy_samples = self._extract_policy_samples(all_samples)
        KLs = self._compute_mutual_KL(policy_samples)

        for (algorithm, samples, KL, policy) in zip(
            self.algorithms, policy_samples, KLs, self.policies
        ):
            # `_extract_policy_samples` leaves an empty list in the slot of a
            # policy that gathered nothing; every other slot is an
            # EpisodeBatch. Check the placeholder explicitly instead of
            # comparing a batch object against a list.
            if isinstance(samples, list) and not samples:
                continue
            ### ASDF: right now doing multiple forward passes for KL then RL training loop, also each policy only trained on their own context
            L_KL = -KL
            algorithm._set_auxiliary_obj(L_KL)
            policy_train_info = algorithm._train_once(itr, samples)

            with tabular.prefix(policy.name):
                ### ASDF assume policy train info is the return
                tabular.record("/AverageReturn", policy_train_info)

                tabular.record("/DnCKLLoss", L_KL.item())
                tabular.record("/NumSamples", np.sum(samples.lengths))

        total_undiscounted_returns = log_performance(
            itr, all_samples, discount=self._discount
        )

        return total_undiscounted_returns

    def _extract_policy_samples(self, samples):
        """Group episodes by the local policy that gathered them.

        Args:
            samples: Batch of episodes supporting ``split()``. Every episode
                must carry ``agent_infos["policy_id"]``.

        Returns:
            list: ``n_policies`` entries. Entry ``i`` is an ``EpisodeBatch``
                holding all episodes of policy ``i``, or an empty list when
                that policy collected no episodes (a warning is emitted).
        """
        episodes_by_policy = [[] for _ in range(self.n_policies)]

        for ep in samples.split():
            ### ASDF assuming each episode is gathered by a single policy
            # agent_infos values are array-like; coerce the id to a plain
            # int so it is always a valid list index.
            policy_id = int(ep.agent_infos["policy_id"][0])
            episodes_by_policy[policy_id].append(ep)

        policy_samples = []
        for (i, episodes) in enumerate(episodes_by_policy):
            if not episodes:
                warnings.warn("Policy {} collected no samples this batch".format(i))
                policy_samples.append([])
            else:
                policy_samples.append(EpisodeBatch.concatenate(*episodes))

        return policy_samples

    def _compute_mutual_KL(self, samples, policy_id=None):
        """Return the per-policy mutual-KL terms (placeholder: all zeros).

        Args:
            samples: Currently unused.
            policy_id: Currently unused.

        Returns:
            list: ``n_policies`` tensors, each equal to ``torch.Tensor([0])``.
        """
        ### ASDF compute actual KL later, for now it's just equivalent to running independent PPOs
        # TODO: intended implementation -- evaluate every local policy's
        # action distribution on the sampled observations and, for policy
        # `policy_id`, sum `_compute_symmetric_KL` against each other policy.
        return [torch.Tensor([0]) for _ in range(self.n_policies)]

    def _compute_symmetric_KL(self, dist1, dist2):
        """Return the symmetric KL between two distributions (placeholder).

        Args:
            dist1: Currently unused.
            dist2: Currently unused.

        Returns:
            int: Always 0 until the real computation is implemented.
        """
        return 0
