import warnings
from typing import Any, Dict, Optional, Type, Union

import numpy as np
import torch
import torch as th
import gym
from gym import spaces
from torch.nn import functional as F
import time

from stable_baselines3.common.utils import explained_variance


from lending_experiment.agents.on_policy_algorithm import OnPolicyAlgorithm


class POCAR(OnPolicyAlgorithm):
    """PPO-style on-policy learner with fairness-regularized advantages.

    For every minibatch, three per-sample signals are each min-max rescaled
    to ``[0, 1]`` within the minibatch and combined as::

        A' = beta_0 * A + beta_1 * vt + beta_2 * div

    where ``A`` is the rollout buffer's advantage,
    ``vt = min(0, omega - |delta_obs|)`` is the value-thresholding term and
    ``div = min(0, -1[|delta_obs| > omega] * delta_deltas)`` is the
    decrease-in-violation term (both referred to as parts of Eq. 3 of the
    paper). ``A'`` then drives a clipped-surrogate policy loss.

    Besides the usual PPO fields, the rollout buffer must provide
    ``delta_obs`` and ``delta_deltas`` on each minibatch, and ``groups``,
    ``labels``, ``preds``, ``deltas``, ``delta_obs``, ``delta_deltas``,
    ``actions`` and ``rewards`` as buffer attributes (used for logging).
    """

    def __init__(
        self,
        env: gym.Env,
        learning_rate: float = 1e-5,
        beta_0: float = 1,
        beta_1: float = 0.5,
        beta_2: float = 0.5,
        omega: float =0.1,
        n_steps: int = 2048,
        batch_size: int = 64,
        n_epochs: int = 10,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_range: float = 0.2,
        normalize_advantage: bool = True,
        ent_coef: float = 0.,
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        target_kl: Optional[float] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        **kwargs,
    ):
        """
        :param env: Environment to learn from; must expose a ``utility_method``
            attribute, which is copied onto this instance.
        :param learning_rate: Learning rate (forwarded to the base class).
        :param beta_0: Weight on the rescaled task advantage.
        :param beta_1: Weight on the rescaled value-thresholding term.
        :param beta_2: Weight on the rescaled decrease-in-violation term.
        :param omega: Threshold on ``|delta_obs|``. The value-thresholding
            term is zero while ``|delta_obs| <= omega``, and the
            decrease-in-violation term only applies where
            ``|delta_obs| > omega``.
        :param n_steps: Rollout length (forwarded to the base class).
        :param batch_size: Minibatch size for each gradient step.
        :param n_epochs: Passes over the rollout buffer per ``train`` call.
        :param gamma: Discount factor (forwarded to the base class).
        :param gae_lambda: GAE lambda (forwarded to the base class).
        :param clip_range: Clip range for the policy probability ratio. The
            same value also clips the value prediction around the old values.
        :param normalize_advantage: Whether to standardize the combined
            advantage within each minibatch.
        :param ent_coef: Entropy coefficient in the loss.
        :param vf_coef: Value-loss coefficient in the loss.
        :param max_grad_norm: Maximum gradient norm for clipping.
        :param target_kl: If set, the current ``train`` call stops early once
            the approximate KL divergence exceeds ``1.5 * target_kl``.
        :param policy_kwargs: Extra policy arguments (forwarded).
        :param seed: Random seed (forwarded).
        :param device: Torch device (forwarded).
        :param kwargs: Accepted for signature compatibility but NOT forwarded
            to the base class; a warning is emitted if any are given.
        """
        if kwargs:
            # These would otherwise be dropped silently (e.g. ``verbose=1``).
            warnings.warn(
                f"POCAR ignores unsupported keyword arguments: {sorted(kwargs)}",
                stacklevel=2,
            )

        super(POCAR, self).__init__(
            env=env,
            learning_rate=learning_rate,
            n_steps=n_steps,
            gamma=gamma,
            gae_lambda=gae_lambda,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm,
            policy_kwargs=policy_kwargs,
            device=device,
            seed=seed,
        )

        # Sanity check, otherwise it will lead to noisy gradient and NaN
        # because of the advantage normalization
        if normalize_advantage:
            assert (
                batch_size > 1
            ), "`batch_size` must be greater than 1. See https://github.com/DLR-RM/stable-baselines3/issues/440"

        self.utility_method = env.utility_method
        self.beta_0 = beta_0
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.omega = omega
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.clip_range = clip_range
        self.normalize_advantage = normalize_advantage
        self.target_kl = target_kl
        # Not read by this class's own methods; presumably consumed by the
        # base class or external code (TODO confirm).
        self.predictor_steps = 300  # initial value

        self._setup_model()

    def train(self) -> None:
        """
        Update policy using the currently gathered rollout buffer.

        Runs up to ``n_epochs`` passes over the buffer in minibatches of
        ``batch_size``. Each step minimizes the clipped-surrogate policy loss
        on the fairness-regularized advantages, plus ``vf_coef`` times a
        clipped value loss and ``ent_coef`` times the entropy loss. Training
        stops early if ``target_kl`` is exceeded. Training and per-group
        statistics are then recorded to the logger.
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.train()

        entropy_losses = []
        pg_losses, value_losses = [], []
        clip_fractions = []
        continue_training = True

        # train for n_epochs epochs
        for epoch in range(self.n_epochs):
            approx_kl_divs = []
            # Do a complete pass on the rollout buffer
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                actions = rollout_data.actions
                if isinstance(self.action_space, spaces.Discrete):
                    # Convert discrete action from float to long
                    actions = rollout_data.actions.long().flatten()

                _, values, log_prob, entropy = self.policy.get_action_and_value(
                    rollout_data.observations, actions
                )
                values = values.flatten()

                # Task advantage augmented with the fairness terms (Eq. 3)
                advantages = self._regularized_advantages(rollout_data)

                # ratio between old and new policy, should be one at the first iteration
                log_ratio = log_prob - rollout_data.old_log_prob
                ratio = th.exp(log_ratio)

                # clipped surrogate loss
                policy_loss_1 = advantages * ratio
                policy_loss_2 = advantages * th.clamp(
                    ratio, 1 - self.clip_range, 1 + self.clip_range
                )
                policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()

                # Logging
                pg_losses.append(policy_loss.item())
                clip_fraction = th.mean(
                    (th.abs(ratio - 1) > self.clip_range).float()
                ).item()
                clip_fractions.append(clip_fraction)

                # Value predictions are clipped around the rollout-time values
                # using the same clip_range as the policy ratio.
                values_pred = rollout_data.old_values + th.clamp(
                    values - rollout_data.old_values, -self.clip_range, self.clip_range
                )
                # Value loss using the TD(gae_lambda) target
                value_loss = F.mse_loss(rollout_data.returns, values_pred)
                value_losses.append(value_loss.item())

                # Entropy loss favor exploration
                if entropy is None:
                    # Approximate entropy when no analytical form
                    entropy_loss = -th.mean(-log_prob)
                else:
                    entropy_loss = -th.mean(entropy)
                entropy_losses.append(entropy_loss.item())

                loss = (
                    policy_loss
                    + self.ent_coef * entropy_loss
                    + self.vf_coef * value_loss
                )

                # Calculate approximate form of reverse KL Divergence for early stopping
                # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
                # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
                # and Schulman blog: http://joschu.net/blog/kl-approx.html
                with th.no_grad():
                    approx_kl_div = th.mean((ratio - 1) - log_ratio).cpu().numpy()
                    approx_kl_divs.append(approx_kl_div)

                if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
                    continue_training = False
                    if self.verbose >= 1:
                        print(
                            f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}"
                        )
                    break

                # Optimization step
                self.policy.optimizer.zero_grad()
                loss.backward()
                # Clip grad norm
                th.nn.utils.clip_grad_norm_(
                    self.policy.parameters(), self.max_grad_norm
                )
                self.policy.optimizer.step()

            if not continue_training:
                break

        # NOTE: ``self._n_updates`` is not incremented by this implementation.
        explained_var = explained_variance(
            self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten()
        )

        # Logs
        self.logger.record("train/entropy_loss", np.mean(entropy_losses))
        self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
        self.logger.record("train/value_loss", np.mean(value_losses))
        self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
        self.logger.record("train/clip_fraction", np.mean(clip_fractions))
        self.logger.record("train/loss", loss.item())
        self.logger.record("train/explained_variance", explained_var)
        self.logger.record(
            "train/accept_rate", np.mean(self.rollout_buffer.actions.flatten())
        )
        self.logger.record(
            "train/pos_rate", np.mean(self.rollout_buffer.labels.flatten())
        )
        self.logger.record("train/reward", self.rollout_buffer.rewards.mean().item())

        self._record_group_stats()

        if hasattr(self.policy, "log_std"):
            self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())

    @staticmethod
    def _min_max_scale(x: th.Tensor) -> th.Tensor:
        """Linearly rescale ``x`` to ``[0, 1]``; a constant ``x`` maps to zeros."""
        return (x - x.min()) / (x.max() - x.min() + 1e-8)

    def _regularized_advantages(self, rollout_data) -> th.Tensor:
        """Combine the buffer advantages with the two fairness terms for one minibatch.

        Each term is min-max rescaled within the minibatch before weighting,
        for numerical stability. The result is optionally standardized.
        """
        # Fairness signals are treated as one scalar per sample, matching the
        # flattened ``values`` (assumption — confirm against the rollout buffer).
        delta_obs_abs = rollout_data.delta_obs.flatten().abs()
        delta_deltas = rollout_data.delta_deltas.flatten()

        # Value-thresholding (vt) term: min(0, omega - |delta_obs|).
        # clamp(max=0) needs no zeros tensor and cannot accidentally broadcast
        # a (batch, 1) input against a (batch,) tensor into (batch, batch).
        vt_term = th.clamp(self.omega - delta_obs_abs, max=0.0)

        # Decrease-in-violation (div) term: only active where the current
        # violation |delta_obs| exceeds omega.
        violating = (delta_obs_abs > self.omega).float()
        div_term = th.clamp(-violating * delta_deltas, max=0.0)

        advantages = (
            self.beta_0 * self._min_max_scale(rollout_data.advantages)
            + self.beta_1 * self._min_max_scale(vt_term)
            + self.beta_2 * self._min_max_scale(div_term)
        )

        # Standardize. Skipped for single-sample minibatches (e.g. a trailing
        # partial minibatch, if the buffer yields one), whose std is NaN.
        if self.normalize_advantage and len(advantages) > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        return advantages

    def _record_group_stats(self) -> None:
        """Log per-group acceptance rates, prediction accuracy and delta statistics."""
        buffer = self.rollout_buffer

        # Columns 0 and 1 of ``groups`` are used as membership indicators for
        # group 0 and group 1 respectively.
        g0_idx = (buffer.groups[:, 0] == 1).nonzero()
        g1_idx = (buffer.groups[:, 1] == 1).nonzero()

        accept_g0 = buffer.actions[g0_idx, 0].mean().item()
        accept_g1 = buffer.actions[g1_idx, 0].mean().item()
        accuracy = (buffer.labels[:, 0] == buffer.preds[:, 0]).mean().item()

        self.logger.record("train/accept_g0", accept_g0)
        self.logger.record("train/accept_g1", accept_g1)
        self.logger.record("train/delta", buffer.deltas.mean().item())
        self.logger.record("train/delta_obs", buffer.delta_obs.mean().item())
        self.logger.record("train/delta_delta", buffer.delta_deltas.mean().item())
        self.logger.record("train/accuracy", accuracy)
