from collections import deque
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
from agent.ddpg import Actor, Critic
import utils

from typing import Any, Deque, Iterator, List, Optional, Protocol, cast, Dict


class DDPGAgentProtocol(Protocol):
    """Structural type for the agent object that ``HistoryEnsembleMixin`` operates on.

    The mixin annotates ``self`` with this protocol so a type checker can see
    both the attributes the mixin assigns itself and the ones it only reads.
    """

    # --- attributes assigned in HistoryEnsembleMixin.__init__ ---------------
    use_ensemble: bool
    num_policies: int
    store_every_steps: int
    penalty_weight: float
    actor_loss_weight: float
    actor_ensemble: Deque[Actor]
    actor_ensemble_steps: List[int]
    actor_ensemble_sliding_window: bool
    actor_ensemble_loss_fn: Any
    actor_added: bool
    ensemble_dist: Optional[torch.distributions.MixtureSameFamily]
    ensemble_actor: Actor
    ensemble_actor_opt: torch.optim.Optimizer
    optimal_prior_snapshot_path: Optional[str]
    optimal_prior_actor: Optional[Actor]

    # --- attributes the mixin reads but does not assign in this file --------
    # presumably supplied by the DDPG agent class the mixin is combined with
    actor: Actor
    actor_opt: torch.optim.Optimizer
    critic: Critic
    device: Any
    lr: float
    reward_free: bool
    stddev_schedule: Any
    stddev_clip: float
    use_tb: bool
    use_wandb: bool

    def init_from(self, other: "DDPGAgentProtocol") -> None:
        """Initialise this agent from ``other``."""

    def ensemble_policy(
        self,
        obs: torch.Tensor,
        step: int,
    ) -> torch.distributions.MixtureSameFamily:
        """Return the ensemble action distribution for ``obs`` at ``step``."""

    def add_ensemble_loss(
        self,
        obs: torch.Tensor,
        step: int,
        dist: utils.TruncatedNormal,
        log_prob: torch.Tensor,
        action: torch.Tensor,
        stddev: float,
        actor_loss: torch.Tensor,
        metrics: Dict[str, Any],
    ):
        """Add the ensemble penalty to ``actor_loss`` and run the actor update."""


class HistoryEnsembleMixin:
    """Mixin that penalises the actor against an ensemble of its own past snapshots.

    The mixin keeps a bounded deque of deep-copied actors. On every actor
    update it can add ``penalty_weight * KLDivLoss(log pi(a|s), log pi_ens(a|s))``
    to the critic-based actor loss, where ``pi_ens`` is either a uniform
    mixture over the stored snapshots plus the current actor, or a single
    "optimal prior" actor loaded from a snapshot file.

    It is meant to be combined with an agent class satisfying
    ``DDPGAgentProtocol``; that class must be initialised by
    ``super().__init__`` so ``self.actor``, ``self.lr`` and ``self.device``
    exist before the ensemble state is created.
    """

    def __init__(
        self: DDPGAgentProtocol,
        use_ensemble: bool,
        store_every_steps: int,
        num_policies: int,
        penalty_weight: float,
        actor_loss_weight: float,
        actor_ensemble_steps: List[int],
        actor_ensemble_sliding_window: bool,
        optimal_prior_snapshot_path: Optional[str],
        *args,
        **kwargs,
    ):
        """Set up the ensemble state on top of the wrapped agent.

        Args:
            use_ensemble: Whether the ensemble penalty is added to the actor loss.
            store_every_steps: Snapshot period used in sliding-window mode.
            num_policies: Maximum number of snapshots kept (deque ``maxlen``).
            penalty_weight: Multiplier applied to the ensemble loss term.
            actor_loss_weight: Multiplier applied to the ``-Q`` actor loss.
            actor_ensemble_steps: Explicit steps at which to snapshot the actor
                when sliding-window mode is off.
            actor_ensemble_sliding_window: Select periodic snapshots instead of
                the explicit step list.
            optimal_prior_snapshot_path: Optional path to a saved snapshot whose
                ``["agent"].actor`` replaces the mixture as penalty target.
            *args, **kwargs: Forwarded to the next ``__init__`` in the MRO.
        """
        super().__init__(*args, **kwargs)
        self.num_policies = num_policies
        self.penalty_weight = penalty_weight
        self.actor_loss_weight = actor_loss_weight
        self.store_every_steps = store_every_steps
        self.actor_ensemble_steps = actor_ensemble_steps
        self.use_ensemble = use_ensemble
        # Bounded history: appending beyond ``num_policies`` drops the oldest.
        self.actor_ensemble = deque(maxlen=num_policies)
        # Both arguments handed to the loss are log-probabilities.
        self.actor_ensemble_loss_fn = nn.KLDivLoss(
            reduction="batchmean", log_target=True)
        self.ensemble_dist = None
        # NOTE(review): ensemble_actor / ensemble_actor_opt are created here but
        # not used anywhere in this class; kept for external users.
        self.ensemble_actor = deepcopy(self.actor)
        self.ensemble_actor_opt = torch.optim.Adam(
            self.ensemble_actor.parameters(), lr=self.lr)
        self.actor_ensemble_sliding_window = actor_ensemble_sliding_window
        self.optimal_prior_snapshot_path = optimal_prior_snapshot_path
        self.optimal_prior_actor = None
        if self.optimal_prior_snapshot_path is not None:
            # NOTE(security): torch.load unpickles arbitrary Python objects;
            # only point this at snapshot files from a trusted source.
            # map_location puts the prior on the same device as the
            # observations it will later be evaluated on.
            snapshot = torch.load(
                self.optimal_prior_snapshot_path, map_location=self.device)
            self.optimal_prior_actor = snapshot["agent"].actor
        self.actor_added = False

    def update_actor(self: DDPGAgentProtocol, obs: torch.Tensor, step: int) -> Dict[str, Any]:
        """Run one actor update, snapshotting the actor first when scheduled.

        Args:
            obs: Observation batch fed to the actor and critic.
            step: Current training step; drives the stddev schedule and the
                snapshot schedule.

        Returns:
            Metrics dict filled in by ``add_ensemble_loss``.
        """
        metrics: Dict[str, Any] = {}

        # Pick the schedule first, then gate on reward_free. The previous
        # one-line condition parsed as ``(reward_free and A) or B`` so
        # sliding-window snapshots were taken even outside reward-free training.
        if self.actor_ensemble_sliding_window:
            snapshot_due = step % self.store_every_steps == 0
        else:
            snapshot_due = step in self.actor_ensemble_steps

        if self.reward_free and snapshot_due:
            print(f"Adding actor to ensemble at step {step}")
            self.actor_ensemble.append(deepcopy(self.actor))
            self.actor_added = True

        if (self.use_tb or self.use_wandb) and self.reward_free:
            # Default the ensemble metrics so the keys exist even when the
            # penalty is not applied on this step.
            metrics['ensemble_logprob'] = 0
            metrics['ensemble_mean'] = 0
            metrics['ensemble_stddev'] = 0
            metrics['ensemble_loss'] = 0

        stddev = utils.schedule(self.stddev_schedule, step)
        dist = self.actor(obs, stddev)
        action = dist.sample(clip=self.stddev_clip)
        log_prob = dist.log_prob(action)
        Q1, Q2 = self.critic(obs, action)
        Q = torch.min(Q1, Q2)

        actor_loss = -Q.mean() * self.actor_loss_weight

        return self.add_ensemble_loss(
            obs=obs, step=step, dist=dist, log_prob=log_prob, action=action,
            stddev=stddev, actor_loss=actor_loss, metrics=metrics)

    def add_ensemble_loss(
        self: DDPGAgentProtocol,
        obs: torch.Tensor,
        step: int,
        dist: utils.TruncatedNormal,
        log_prob: torch.Tensor,
        action: torch.Tensor,
        stddev: float,
        actor_loss: torch.Tensor,
        metrics: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Add the ensemble penalty (if active), step the actor optimiser, log metrics.

        Args:
            obs: Observation batch the actor distribution was computed from.
            step: Current training step.
            dist: Current actor's action distribution for ``obs``.
            log_prob: Log-probability of ``action`` under ``dist`` (metrics only).
            action: Action batch sampled from ``dist``.
            stddev: Scheduled stddev for this step (unused here; part of the
                interface).
            actor_loss: Base actor loss to which the penalty is added.
            metrics: Dict to update with logging values.

        Returns:
            The same ``metrics`` dict.
        """
        if self.actor_ensemble and self.use_ensemble:
            # The ensemble distribution is conditioned on ``obs`` and includes
            # the current actor, so it has to be rebuilt for every batch.
            # Reusing one built for an earlier batch would score ``action``
            # against distributions for unrelated observations.
            self.ensemble_dist = self.ensemble_policy(obs, step)
            self.actor_added = False
            ensemble_log_prob = self.ensemble_dist.log_prob(action)
            ensemble_loss = self.actor_ensemble_loss_fn(
                dist.log_prob(action), ensemble_log_prob)
            # Out-of-place add so the caller's tensor is not mutated.
            actor_loss = actor_loss + self.penalty_weight * ensemble_loss
            if self.use_tb or self.use_wandb:
                metrics['ensemble_logprob'] = ensemble_log_prob.sum(
                    -1, keepdim=True).mean().item()
                metrics['ensemble_mean'] = self.ensemble_dist.mean.mean().item()
                metrics['ensemble_stddev'] = self.ensemble_dist.stddev.mean().item()
                metrics['ensemble_loss'] = ensemble_loss.item()

        # optimize actor
        self.actor_opt.zero_grad(set_to_none=True)
        actor_loss.backward()
        self.actor_opt.step()

        if self.use_tb or self.use_wandb:
            metrics['actor_loss'] = actor_loss.item()
            metrics['actor_logprob'] = log_prob.sum(
                -1, keepdim=True).mean().item()
            metrics['actor_ent'] = dist.entropy().sum(dim=-1).mean().item()

        return metrics

    def ensemble_policy(self: DDPGAgentProtocol, obs: torch.Tensor, step: int):
        """Build the penalty target distribution for ``obs`` without tracking gradients.

        Args:
            obs: Observation batch.
            step: Current training step (selects the scheduled stddev).

        Returns:
            The optimal-prior actor's distribution when one was loaded;
            otherwise a uniformly weighted ``MixtureSameFamily`` over every
            stored snapshot plus the current actor.
        """
        with torch.no_grad():
            stddev = utils.schedule(self.stddev_schedule, step)
            if self.optimal_prior_actor is not None:
                return self.optimal_prior_actor(obs, stddev)

            # Stored snapshots followed by the live policy.
            members = [*self.actor_ensemble, self.actor]
            dists = [member(obs, stddev) for member in members]
            # Equal (unnormalised) weights -> uniform mixture.
            mix = torch.distributions.Categorical(
                torch.ones(len(members), device=self.device))
            # Components are stacked on a new trailing axis, which is the
            # axis MixtureSameFamily mixes over.
            components = utils.TruncatedNormal(
                loc=torch.stack([d.loc for d in dists], dim=-1),
                scale=torch.stack([d.scale for d in dists], dim=-1))
            return torch.distributions.MixtureSameFamily(mix, components)

    def init_from(self: DDPGAgentProtocol, other: DDPGAgentProtocol):
        """Initialise from ``other`` via the parent ``init_from`` and adopt its ensemble.

        The snapshot deque is taken over by reference, not copied, so both
        agents share the same history object afterwards.
        """
        cast(DDPGAgentProtocol, super()).init_from(other)
        self.actor_ensemble = other.actor_ensemble

