import logging
import numpy as np
from typing import List, Type, Union

from models.modelv2 import ModelV2
from models.action_dist import ActionDistribution
from trainer.trainer import Trainer
from policy.sample_batch import SampleBatch
from agents.league.trainer import LeagueTrainer, LEAGUE_DEFAULT_CONFIG
from agents.ppo.policy import PPOTorchPolicy
from utils.annotations import override
from utils.framework import try_import_torch
from utils.numpy import convert_to_numpy
from utils.typing import Dict, TensorType, TrainerConfigDict

torch, nn = try_import_torch()

logger = logging.getLogger(__name__)


class PopulationEntropyPolicy(PPOTorchPolicy):
    """PPO policy whose rewards carry a population-entropy bonus.

    During trajectory postprocessing, every policy in the episode's policy
    map whose id contains this policy's team string ("left" or "right")
    evaluates the probability of the actions actually taken. The negative
    log of the mean of those probabilities, scaled by
    ``config["population_entropy_coeff"]``, is added to the rewards before
    the regular PPO postprocessing runs.
    """

    def __init__(self, observation_space, action_space, config):
        # These attributes are read by the overridden hooks below, so they
        # are assigned before the parent constructor runs in case it
        # already invokes those hooks.
        self.population_entropy_coeff = config["population_entropy_coeff"]
        # Mean bonus (before scaling) of the most recently postprocessed
        # batch; None until a bonus has been computed at least once.
        self.mean_population_neg_logp = None
        PPOTorchPolicy.__init__(
            self,
            observation_space,
            action_space,
            config,
        )

    def _warn_skipped_bonus(self, reason):
        """Log once per policy instance why the bonus is being skipped."""
        if not getattr(self, "_warned_skipped_bonus", False):
            self._warned_skipped_bonus = True
            logger.warning(
                "Population-entropy bonus skipped: %s", reason
            )

    def _population_neg_logp(self, sample_batch, policy_map):
        """Return ``-log`` of the team-averaged probability of taken actions.

        Args:
            sample_batch: Batch providing ``SampleBatch.OBS`` and
                ``SampleBatch.ACTIONS``.
            policy_map: Mapping from policy id to policy, taken from the
                episode.

        Returns:
            A NumPy array with the per-sample negative log of the mean
            action probability, or ``None`` when this policy's team cannot
            be determined or no policy id matches that team.
        """
        # The team is inferred from the id under which this policy's own
        # model is registered in the policy map.
        team = None
        for pid, policy in policy_map.items():
            if policy.model is self.model:
                team = "left" if "left" in pid else "right"
        if team is None:
            # Previously this case raised UnboundLocalError further down.
            self._warn_skipped_bonus(
                "this policy's model was not found in episode.policy_map"
            )
            return None

        input_dict = self._lazy_tensor_dict(
            SampleBatch(
                {
                    SampleBatch.OBS: sample_batch[SampleBatch.OBS],
                    SampleBatch.ACTIONS: sample_batch[SampleBatch.ACTIONS],
                }
            )
        )
        actions = input_dict[SampleBatch.ACTIONS]

        population_action_probs = []
        with torch.no_grad():
            for pid, policy in policy_map.items():
                if team not in pid:
                    continue
                logits, _ = policy.model(input_dict)
                action_dist = policy.dist_class(logits, policy.model)
                # Probability each team member assigns to the taken actions
                # (presumably one value per batch row — TODO confirm).
                population_action_probs.append(
                    torch.exp(action_dist.logp(actions))
                )

        if not population_action_probs:
            # Averaging an empty list would yield NaN rewards.
            self._warn_skipped_bonus(
                "no policy id in episode.policy_map contains team "
                "'%s'" % team
            )
            return None

        # Average over the population axis (axis 0 indexes the policies).
        population_mean_action_probs = np.mean(
            convert_to_numpy(population_action_probs), axis=0
        )
        # Guard against log(0) = -inf when the probabilities underflow;
        # only values below the smallest normal float are affected.
        population_mean_action_probs = np.maximum(
            population_mean_action_probs,
            np.finfo(population_mean_action_probs.dtype).tiny,
        )
        return -np.log(population_mean_action_probs)

    @override(PPOTorchPolicy)
    def postprocess_trajectory(
        self, sample_batch, other_agent_batches=None, episode=None
    ):
        """Add the population-entropy bonus, then run PPO postprocessing.

        The bonus is only applied when an ``episode`` is supplied, since the
        population is read from ``episode.policy_map``. The rewards in
        ``sample_batch`` are modified in place.
        """
        if episode is not None:
            population_neg_logp = self._population_neg_logp(
                sample_batch, episode.policy_map
            )
            if population_neg_logp is not None:
                sample_batch[SampleBatch.REWARDS] += (
                    self.population_entropy_coeff * population_neg_logp
                )
                self.mean_population_neg_logp = np.mean(population_neg_logp)

        return PPOTorchPolicy.postprocess_trajectory(
            self, sample_batch, other_agent_batches, episode
        )

    @override(PPOTorchPolicy)
    def loss(
        self,
        model: ModelV2,
        dist_class: Type[ActionDistribution],
        train_batch: SampleBatch,
    ) -> Union[TensorType, List[TensorType]]:
        """Compute the PPO loss and record the latest bonus as a tower stat.

        The loss value itself is the unmodified parent loss.
        """
        total_loss = PPOTorchPolicy.loss(self, model, dist_class, train_batch)
        if self.mean_population_neg_logp is not None:
            # Stored as a tensor because extra_grad_info() feeds the
            # collected tower stats to torch.stack, which rejects NumPy
            # scalars.
            model.tower_stats["mean_population_neg_logp"] = torch.tensor(
                float(self.mean_population_neg_logp)
            )
        return total_loss

    @override(PPOTorchPolicy)
    def extra_grad_info(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
        """Extend the parent's stats with the mean population bonus.

        The extra ``"population_neg_logp"`` entry is only added once a bonus
        has been computed by ``postprocess_trajectory``.
        """
        infos = PPOTorchPolicy.extra_grad_info(self, train_batch)
        if self.mean_population_neg_logp is not None:
            infos["population_neg_logp"] = convert_to_numpy(
                torch.mean(
                    torch.stack(self.get_tower_stats("mean_population_neg_logp"))
                )
            )

        return convert_to_numpy(infos)


# Default trainer config: the league defaults (whose keys remain supported)
# extended with the coefficient that scales the population-entropy reward
# bonus.
POPULATION_DEFAULT_CONFIG = Trainer.merge_trainer_configs(
    LEAGUE_DEFAULT_CONFIG,
    dict(population_entropy_coeff=0.01),
    _allow_unknown_configs=True,
)


class PopulationEntropyTrainer(LeagueTrainer):
    """League trainer that uses ``PopulationEntropyPolicy`` by default."""

    # Extend the parent's key lists with the config keys used here.
    _allow_unknown_subkeys = [
        *LeagueTrainer._allow_unknown_subkeys,
        "league_config",
        "population_entropy_coeff",
    ]
    _override_all_subkeys_if_type_changes = [
        *LeagueTrainer._override_all_subkeys_if_type_changes,
        "league_config",
        "population_entropy_coeff",
    ]

    @classmethod
    @override(LeagueTrainer)
    def get_default_config(cls) -> TrainerConfigDict:
        """Return the module-level ``POPULATION_DEFAULT_CONFIG``."""
        return POPULATION_DEFAULT_CONFIG

    @override(LeagueTrainer)
    def get_default_policy_class(self, config):
        """Return ``PopulationEntropyPolicy`` regardless of ``config``."""
        return PopulationEntropyPolicy
