"""DnC with cross-entropy-method (CEM) iterations that learn the task-wise KL coefficients (betas). Built on top of DnCSAC."""

import copy

from dowel import tabular
import numpy as np
import torch

from garage.torch import as_torch_dict, as_torch, global_device

from .dnc_sac import DnCSAC
from learning.utils import rollout, log_wandb

# Fixed grid of candidate betas used when cem_configs["fixed_search"] is set:
# six values from 1e-7 to 1e3, each 100x the previous one.
BETAS = np.array([1e-7, 1e-5, 1e-3, 1e-1, 1e1, 1e3])


class DnCCEM(DnCSAC):
    """DnC-SAC variant that tunes the pairwise KL coefficients (betas) with CEM.

    Training alternates between ordinary DnC-SAC epochs and bursts of CEM
    iterations. Each CEM iteration proceeds as follows:

    1. Sample a population of candidate log-beta matrices around
       ``self._log_kl_coeffs``.
    2. Briefly fine-tune the policies under each candidate.
    3. Score each candidate by evaluation return.
    4. Move every policy's row of log-betas to the mean of its elite
       candidates.

    Args:
        cem_policies: Policies used as scratch storage to snapshot and restore
            ``self.policies`` while candidates are evaluated.
        cem_configs (dict or None): CEM hyper-parameters. Missing keys fall
            back to ``DEFAULT_CEM_CONFIGS``.
        *args: Forwarded to :class:`DnCSAC`.
        **kwargs: Forwarded to :class:`DnCSAC`. Must contain
            ``initial_kl_coeff``; its first entry seeds every beta.
    """

    DEFAULT_CEM_CONFIGS = {
        "frequency": 50,
        "sigma": 1.1,
        "population_size": 20,
        "num_elite": 1,
        "iterations": 10,
        "warm_up": 20,
        "update": "RL",
        "num_updates": 1,
        "one_beta": False,
        "discrete": True,
        "individual_evaluation": False,
        "fixed_search": False,
        "restart": False,
    }

    def __init__(
        self,
        cem_policies,
        cem_configs=None,
        *args,
        **kwargs,
    ):

        super().__init__(*args, **kwargs)

        self.cem_policies = cem_policies
        self._cem_iterations = 0
        self._cem_best_score = -np.inf

        # Build a fresh dict per instance (no shared mutable default) and let
        # callers override only the keys they care about.
        self._cem_configs = {**self.DEFAULT_CEM_CONFIGS, **(cem_configs or {})}
        self._sigma = self._cem_configs["sigma"]

        # HACK: assumes a single initial KL coefficient shared by every pair.
        # Betas are stored in log space, so CEM perturbations are additive here
        # and multiplicative on the betas themselves. Keep the tensor on the
        # same device as the sampled perturbations.
        self._log_kl_coeffs = (
            (
                torch.ones((self.n_policies, self.n_policies))
                * kwargs["initial_kl_coeff"][0]
            )
            .log()
            .to(global_device())
        )

        self._cem_env_steps = 0

    def train(self, trainer):
        """Obtain samplers and start actual training for each epoch.

        Epochs selected by :meth:`_is_cem_epoch` run one CEM iteration.
        All other epochs run regular DnC-SAC training steps.

        Args:
            trainer (Trainer): Gives the algorithm the access to
                :method:`~Trainer.step_epochs()`, which provides services
                such as snapshotting and sampler control.

        Returns:
            float: The average return in last epoch cycle.

        """
        if not self._eval_env:
            self._eval_env = trainer.get_env_copy()
        last_return = None
        for epoch in trainer.step_epochs():
            if epoch == 0:
                # Log the initial betas (with placeholder scores) before any
                # CEM iteration has run.
                self._log_cem_statistics(
                    trainer.step_itr,
                    [-np.inf] * self.n_policies,
                    [0] * self.n_policies,
                    [0] * self.n_policies,
                    [0] * self.n_policies,
                )

            if self._is_cem_epoch(epoch):
                best_scores, means, stds, z_scores = self.run_cem_once()
                self._log_cem_statistics(
                    trainer.step_itr, best_scores, means, stds, z_scores
                )
            else:
                for _ in range(self._steps_per_epoch):
                    self.train_once(trainer)

                if epoch % self._evaluation_frequency == 0:
                    last_return = self._evaluate_policy(trainer.step_itr)

                self._log_rl_statistics(trainer)
            trainer.step_itr += 1

        return np.mean(last_return)

    def _is_cem_epoch(self, epoch):
        """Return True if ``epoch`` should run a CEM iteration.

        After ``warm_up`` epochs, each window of ``frequency`` epochs begins
        with ``iterations`` consecutive CEM epochs.
        """
        warm_up = self._cem_configs["warm_up"]
        if epoch < warm_up:
            return False
        offset = (epoch - warm_up) % self._cem_configs["frequency"]
        return offset < self._cem_configs["iterations"]

    def _log_rl_statistics(self, trainer):
        """Log the per-policy-averaged training metrics of an RL epoch."""

        def _average(per_policy):
            return np.mean([np.mean(per_policy[i]) for i in range(self.n_policies)])

        infos = {
            "AverageReturn": _average(self.episode_rewards),
            "SuccessRate": _average(self.success_rates),
            "StagesCompleted": _average(self.stages_completed),
            # CEM evaluation rollouts run outside the trainer's sampler, so
            # their steps are added separately.
            "TotalEnvSteps": trainer.total_env_steps + self._cem_env_steps,
            "TotalRLEnvSteps": trainer.total_env_steps,
        }
        # Nothing in this method produces videos, so no media is attached.
        log_wandb(trainer.step_itr, infos, prefix="Train/")

    ### CEM Code

    def run_cem_once(self):
        """Run one CEM iteration and update ``self._log_kl_coeffs`` in place.

        Returns:
            tuple: Four lists with one entry per policy:

            - the mean fitness of the elite candidates;
            - the population mean fitness;
            - the population fitness standard deviation;
            - the z-score of the elite fitness (0.0 when the std is zero).
        """
        if self._cem_configs["restart"]:
            # Re-center the search on the initial betas every iteration.
            self._log_kl_coeffs = (
                (
                    torch.ones((self.n_policies, self.n_policies))
                    * self._initial_kl_coeffs[0]
                )
                .log()
                .to(global_device())
            )

        population = self._generate_population()
        fitness = self._evaluate_cem(population)
        num_elite = self._cem_configs["num_elite"]

        best_scores, means, stds, zs = [], [], [], []
        for policy_id in range(self.n_policies):
            policy_fitness = fitness[:, policy_id]
            # Indices of the ``num_elite`` highest-fitness candidates (unordered).
            elite_idxs = np.argpartition(policy_fitness, -num_elite)[-num_elite:]
            best = np.mean(policy_fitness[elite_idxs])
            mean = np.mean(policy_fitness)
            std = np.std(policy_fitness)
            best_scores.append(best)
            means.append(mean)
            stds.append(std)
            # A population whose candidates all score the same has zero spread.
            zs.append((best - mean) / std if std != 0 else 0.0)
            # Only row ``policy_id`` (this policy's betas) moves to the elite mean.
            self._log_kl_coeffs[policy_id, :] = torch.mean(
                population[elite_idxs][:, policy_id], dim=0
            )

        self._cem_iterations += 1

        return best_scores, means, stds, zs

    def _generate_population(self):
        """Sample ``population_size`` candidate log-beta matrices.

        Returns:
            torch.Tensor: Shape ``(population_size, n_policies, n_policies)``.
            In the perturbation modes, candidate 0 is always the current
            ``self._log_kl_coeffs``.
        """
        cfg = self._cem_configs
        pop_size = cfg["population_size"]

        if cfg["fixed_search"]:
            # Every entry is drawn independently from the fixed grid BETAS.
            # With pairwise betas, this only works with per-policy evaluation.
            assert cfg["one_beta"] or cfg["individual_evaluation"]
            samples = np.random.choice(
                BETAS, size=(pop_size, self.n_policies, self.n_policies)
            )
            return torch.Tensor(samples).to(global_device()).log()

        # one_beta: one perturbation per candidate, broadcast over all pairs.
        # Otherwise every (i, j) pair is perturbed independently.
        if cfg["one_beta"]:
            shape = (pop_size, 1, 1)
        else:
            shape = (pop_size, self.n_policies, self.n_policies)

        if cfg["discrete"]:
            # Each beta is multiplied by 1/sigma, 1 or sigma with equal probability.
            delta = np.log(cfg["sigma"]) * torch.randint(
                low=-1, high=2, size=shape
            ).to(global_device())
        else:
            # Gaussian perturbation in log space with std ``sigma``.
            delta = cfg["sigma"] * torch.randn(shape).to(global_device())

        # HACK: keep the current betas in the population as candidate 0.
        delta[0] = 0

        return self._log_kl_coeffs + delta

    def _evaluate_cem(self, population):
        """Score every candidate log-beta matrix in ``population``.

        Each candidate is handled in three steps:

        1. The policies are fine-tuned for ``num_updates`` rounds with the
           candidate's betas.
        2. The policies are evaluated for ``self._num_evaluation_episodes``
           episodes.
        3. Policy parameters AND optimizer state are restored, so every
           candidate starts from the same point.

        Returns:
            np.ndarray: Fitness of shape ``(len(population), n_policies)``.
        """
        fitness = np.zeros((len(population), self.n_policies))
        optimizers = [self._policy_optimizers[i] for i in range(self.n_policies)]

        # Snapshot once. The snapshot is restored after every candidate, so
        # it stays valid for the whole population.
        self._copy_policies(self.policies, self.cem_policies)
        optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers]

        for i, log_betas in enumerate(population):
            try:
                self._finetune_policies(log_betas)
                returns = self._collect_cem_returns()
            finally:
                self._copy_policies(self.cem_policies, self.policies)
                for opt, state in zip(optimizers, optimizer_states):
                    # load_state_dict may keep references to the given tensors,
                    # so pass a fresh copy to keep the snapshot intact.
                    opt.load_state_dict(copy.deepcopy(state))
            fitness[i] = self._candidate_fitness(returns)

        return fitness

    def _finetune_policies(self, log_betas):
        """Run ``num_updates`` rounds of policy-only updates under ``log_betas``."""
        # Outer loop over rounds: each round updates every policy once.
        for _ in range(self._cem_configs["num_updates"]):
            for policy_id in range(self.n_policies):
                all_obs, policy_samples = self.replay_buffers.sample_transitions(
                    self._buffer_batch_size, policy_id
                )
                self._optimize_policy_only(
                    [as_torch(obs) for obs in all_obs],
                    as_torch_dict(policy_samples),
                    policy_id=policy_id,
                    log_betas=log_betas,
                )

    def _collect_cem_returns(self):
        """Roll out ``self.policy`` and group undiscounted returns by policy id."""
        returns = [[] for _ in range(self.n_policies)]
        for j in range(self._num_evaluation_episodes):
            if isinstance(self._eval_env, list):
                eval_env = self._eval_env[j % len(self._eval_env)]
            else:
                eval_env = self._eval_env
            eps = rollout(
                eval_env,
                self.policy,
                max_episode_length=self._max_episode_length_eval,
                deterministic=self._use_deterministic_evaluation,
            )
            policy_id = eps["agent_infos"]["policy_id"][0]
            returns[policy_id].append(sum(eps["rewards"]))
            self._cem_env_steps += len(eps["rewards"])
        return returns

    def _candidate_fitness(self, returns):
        """Convert per-policy episode returns into a per-policy fitness vector."""
        cfg = self._cem_configs
        if cfg["individual_evaluation"] and not cfg["one_beta"]:
            # Each policy is scored on its own episodes.
            return np.array([np.mean(r) for r in returns])
        # Shared score: mean return over all episodes, assigned to every policy.
        return np.full(self.n_policies, np.mean(np.concatenate(returns)))

    def _optimize_policy_only(self, all_obs, samples_data, policy_id, log_betas=None):
        """Take one optimizer step on policy ``policy_id`` (Q-functions untouched).

        The loss always includes the KL penalty under ``log_betas``. With
        ``update == "RL"`` the actor objective is added on top.

        Returns:
            tuple: ``(policy_loss, kl_penalty, mean_kl)`` as Python floats.
        """
        if log_betas is None:
            log_betas = []
        obs = samples_data["observation"]

        ### DnC KL Penalty
        kl, kl_penalty = self._compute_kl_penalty(
            obs, all_obs, policy_id, log_betas=log_betas
        )

        update = self._cem_configs["update"]
        if update == "RL":
            action_dists = self.policies[policy_id](obs)[0]
            (
                new_actions_pre_tanh,
                new_actions,
            ) = action_dists.rsample_with_pre_tanh_value()
            log_pi_new_actions = action_dists.log_prob(
                value=new_actions, pre_tanh_value=new_actions_pre_tanh
            )

            rl_loss = self._actor_objective(
                samples_data, new_actions, log_pi_new_actions, policy_id
            )
            # Out-of-place add: ``+=`` on a tensor would mutate ``kl_penalty``
            # itself, and the reported KL penalty would include the RL loss.
            policy_loss = kl_penalty + rl_loss
        else:
            assert update == "KL"
            policy_loss = kl_penalty

        self._policy_optimizers[policy_id].zero_grad()
        policy_loss.backward()

        self._policy_optimizers[policy_id].step()

        return (
            policy_loss.item(),
            kl_penalty.item(),
            kl.mean().item(),
        )

    ### Random Things

    def _log_cem_statistics(self, step, scores, means, stds, z_scores):
        """Log CEM scores and the current betas to wandb under ``Train/CEM/``.

        ``Beta{i}-{j}`` is exp of the stored log-coefficient for pair (i, j).
        ``Beta{i}`` is the mean of those values over j != i.
        """
        infos = {
            "Iteration": self._cem_iterations,
            "Score": np.mean(scores),
            "MeanScore": np.mean(means),
            "StdScore": np.mean(stds),
            "ZScore": np.mean(z_scores),
            "TotalEnvSteps": self._cem_env_steps,
        }

        # Guard against dividing by zero when there is only one policy.
        n_others = max(self.n_policies - 1, 1)
        for policy_id in range(self.n_policies):
            log_betas = self._log_kl_coeffs[policy_id].cpu().detach().numpy()
            sum_betas = 0
            for j in range(self.n_policies):
                if j == policy_id:
                    continue
                beta_ij = np.exp(log_betas[j])
                infos["Beta{}-{}".format(policy_id, j)] = beta_ij
                sum_betas += beta_ij
            infos["Beta{}".format(policy_id)] = sum_betas / n_others
            infos["Score{}".format(policy_id)] = scores[policy_id]
            infos["MeanScore{}".format(policy_id)] = means[policy_id]
            infos["StdScore{}".format(policy_id)] = stds[policy_id]
            infos["ZScore{}".format(policy_id)] = z_scores[policy_id]

        log_wandb(step, infos, prefix="Train/CEM/")

    def _copy_policies(self, source, target):
        """Copy parameters from each ``source`` policy into its ``target`` twin."""
        for target_policy, source_policy in zip(target, source):
            for t_param, param in zip(
                target_policy.parameters(), source_policy.parameters()
            ):
                t_param.data.copy_(param.data)

    @property
    def networks(self):
        """Return all the networks within the model.

        Returns:
            list: A list of networks.

        """
        # TODO: add beta policies once they exist.
        return [
            *self.policies,
            *self._qf1s,
            *self._qf2s,
            *self._target_qf1s,
            *self._target_qf2s,
        ]

    def to(self, device=None):
        """Move the networks (via DnCSAC) and the CEM log-betas to ``device``.

        Args:
            device (torch.device or None): Target device; ``None`` means
                ``global_device()``.
        """
        super().to(device=device)
        if device is None:
            device = global_device()
        # The parent constructor could call ``to`` before the betas exist.
        if getattr(self, "_log_kl_coeffs", None) is not None:
            self._log_kl_coeffs = self._log_kl_coeffs.to(device)
        # TODO: put beta policies on device once they exist.
