
import numpy as np
import torch
from torch import nn
from numba import njit

from network import ActorCritic
from utils import RunningMeanStd
from data import Batch



class PPOPolicy(nn.Module):
    """PPO policy: clipped-surrogate actor update plus a value-function critic.

    The class ties an actor, a critic and an optimizer together and exposes a
    three-stage update: ``process_fn`` (returns, advantages and old
    log-probabilities), ``learn`` (the optimisation epochs) and ``update``
    (sample from a buffer, then run both).

    Args:
        actor: callable module; ``actor(obs)`` must return ``(logits, feature)``
            where ``logits`` is an iterable that is unpacked into ``dist_fn``.
        critic: callable module; ``critic(obs)`` returns value estimates that
            are flattened to one value per transition.
        optim: optimizer on which ``zero_grad()`` and ``step()`` are called.
        dist_fn: callable building the action distribution from ``*logits``;
            the result must provide ``sample``, ``log_prob`` and ``entropy``.
        eps_clip: half-width of the probability-ratio clamp
            ``[1 - eps_clip, 1 + eps_clip]``; also the value-clip range.
        discount_factor: reward discount, must lie in [0, 1].
        reward_normalization: if True, returns are divided by the running
            standard deviation tracked in ``ret_rms``.
        action_scaling: if True, ``map_action`` rescales actions from
            [-1, 1] to ``[action_space.low, action_space.high]``.
        action_bound_method: ``"clip"``, ``"tanh"`` or ``""`` (no bounding).
        deterministic_eval: if True and the module is in eval mode,
            ``forward`` returns ``logits[0]`` instead of a sample.
        dual_clip: optional dual-clip constant, must be > 1 when given.
        value_clip: if True, the critic loss uses a clipped value estimate.
            Only allowed together with ``reward_normalization=True``.
        advantage_normalization: if True, advantages are standardised per
            minibatch inside ``learn``.
        recompute_advantage: if True, returns/advantages are recomputed at
            the start of every epoch after the first.
        vf_coef: weight of the value loss in the total loss.
        ent_coef: weight of the entropy bonus in the total loss.
        max_grad_norm: if truthy, gradients are clipped to this norm.
        gae_lambda: GAE lambda, must lie in [0, 1].
        max_batchsize: minibatch size for the gradient-free forward passes
            in ``process_fn`` and ``_compute_returns``.
        observation_space: stored as-is; not read inside this class.
        action_space: object with ``low`` / ``high``, read by ``map_action``
            when ``action_scaling`` is enabled.
        lr_scheduler: optional scheduler stepped once per ``learn`` call
            while the current learning rate is above 1e-5.
        **kwargs: accepted and ignored.

    Note:
        With the default arguments (``reward_normalization=False`` and
        ``value_clip=True``) the constructor's own assertion fails, so callers
        must pass ``reward_normalization=True`` or ``value_clip=False``.
    """

    def __init__(self, actor, critic, optim, dist_fn,
        eps_clip=0.2, discount_factor=0.99, reward_normalization=False,
        action_scaling=True, action_bound_method="clip", deterministic_eval=False,
        dual_clip=None, value_clip=True, advantage_normalization=True, recompute_advantage=False,
        vf_coef=0.5, ent_coef=0.01, max_grad_norm=None, gae_lambda=0.95,
        max_batchsize=256, observation_space=None, action_space=None,lr_scheduler=None,
        **kwargs):
        super().__init__()
        self.observation_space = observation_space
        self.action_space = action_space
        self.agent_id = 0  # not read anywhere inside this class
        self.action_scaling = action_scaling
        # can be one of ("clip", "tanh", ""), empty string means no bounding
        assert action_bound_method in ("", "clip", "tanh")
        self.action_bound_method = action_bound_method

        self.actor = actor
        self.optim = optim
        self.lr_scheduler = lr_scheduler
        self.dist_fn = dist_fn
        assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]"
        self._gamma = discount_factor
        self._rew_norm = reward_normalization
        self.ret_rms = RunningMeanStd()
        # Small constant guarding divisions / square roots against zero.
        self._eps = np.finfo(np.float32).eps.item()
        self._deterministic_eval = deterministic_eval

        self.critic = critic
        assert 0.0 <= gae_lambda <= 1.0, "GAE lambda should be in [0, 1]."
        self._lambda = gae_lambda
        self._weight_vf = vf_coef
        self._weight_ent = ent_coef
        self._grad_norm = max_grad_norm
        self._batch = max_batchsize
        # Wrapper whose parameters() are the target of gradient clipping.
        self._actor_critic = ActorCritic(self.actor, self.critic)

        self._eps_clip = eps_clip
        assert dual_clip is None or dual_clip > 1.0, \
            "Dual-clip PPO parameter should greater than 1.0."
        self._dual_clip = dual_clip
        self._value_clip = value_clip
        if not self._rew_norm:
            assert not self._value_clip, \
                "value clip is available only when `reward_normalization` is True"
        self._norm_adv = advantage_normalization
        self._recompute_adv = recompute_advantage

    def process_fn(self, batch, buffer, **kwargs):
        """Attach returns, advantages and old log-probabilities to ``batch``.

        ``batch.act`` and ``batch.feature`` are converted from numpy arrays to
        tensors matching the dtype/device of the critic output, and
        ``batch.logp_old`` is filled with the log-probability of each stored
        action under the current policy (computed without gradients).

        Args:
            batch: sampled transitions; must support ``split`` and hold numpy
                ``act`` / ``feature`` arrays.
            buffer: the buffer ``batch`` came from; kept on ``self`` when
                advantages have to be recomputed during ``learn``.
            **kwargs: ignored.

        Returns:
            The same ``batch`` object, augmented in place.
        """
        if self._recompute_adv:
            # learn() needs the buffer again to refresh returns each epoch.
            self._buffer = buffer
        batch = self._compute_returns(batch, buffer)
        device = batch.v_s.device
        dtype = batch.v_s.dtype
        batch.act = torch.from_numpy(batch.act).type(dtype).to(device)
        batch.feature = torch.from_numpy(batch.feature).type(dtype).to(device)
        old_log_prob = []
        with torch.no_grad():
            for minibatch in batch.split(self._batch, shuffle=False, merge_last=True):
                old_log_prob.append(self(minibatch).dist.log_prob(minibatch.act))
        batch.logp_old = torch.cat(old_log_prob, dim=0)
        return batch

    def _compute_returns(self, batch, buffer):
        """Evaluate the critic and store ``v_s``, ``returns`` and ``adv``.

        The critic is run without gradients on ``obs`` and ``obs_next``. When
        reward normalization is on, the critic outputs are first scaled back
        by the running return standard deviation, and the resulting returns
        are divided by that same scale before the statistics are updated.

        Returns:
            ``batch`` with tensor fields ``v_s`` (old values), ``returns`` and
            ``adv`` set.
        """
        v_s, v_s_ = [], []
        with torch.no_grad():
            for minibatch in batch.split(self._batch, shuffle=False, merge_last=True):
                v_s.append(self.critic(minibatch.obs))
                v_s_.append(self.critic(minibatch.obs_next))
        batch.v_s = torch.cat(v_s, dim=0).flatten()  # old value
        v_s = batch.v_s.cpu().numpy()
        v_s_ = torch.cat(v_s_, dim=0).flatten().cpu().numpy()
        if self._rew_norm:
            # The running variance is only updated at the end of this method,
            # so one scale serves both the un-normalization and normalization.
            scale = np.sqrt(self.ret_rms.var + self._eps)
            v_s = v_s * scale
            v_s_ = v_s_ * scale
        unnormalized_returns, advantages = self.compute_episodic_return(
            batch,
            buffer,
            v_s_,
            v_s,
            gamma=self._gamma,
            gae_lambda=self._lambda
        )
        if self._rew_norm:
            batch.returns = unnormalized_returns / scale
            self.ret_rms.update(unnormalized_returns)
        else:
            batch.returns = unnormalized_returns
        device = batch.v_s.device
        dtype = batch.v_s.dtype
        batch.returns = torch.from_numpy(batch.returns).type(dtype).to(device)
        batch.adv = torch.from_numpy(advantages).type(dtype).to(device)
        return batch

    @staticmethod
    def compute_episodic_return(batch, buffer,
        v_s_=None, v_s=None, gamma=0.99, gae_lambda=0.95):
        """Compute GAE advantages and the matching returns for ``batch``.

        Args:
            batch: provides numpy arrays ``rew`` and ``done``.
            buffer: provides ``unfinished_index()``; those positions are
                treated as episode ends when accumulating advantages.
            v_s_: value of the next observation per step. ``None`` means all
                zeros and requires ``gae_lambda`` to be (close to) 1.
            v_s: value of the current observation per step. ``None`` means
                ``v_s_`` shifted by one position.
            gamma: discount factor.
            gae_lambda: GAE lambda.

        Returns:
            ``(returns, advantage)`` as numpy arrays, where
            ``returns = advantage + v_s``. No normalization is applied.
        """
        rew = batch.rew
        if v_s_ is None:
            assert np.isclose(gae_lambda, 1.0)
            v_s_ = np.zeros_like(rew)
        else:
            v_s_ = v_s_.flatten()
            # Zero the bootstrap value on terminal steps. logical_not is used
            # instead of `~` so integer 0/1 done flags are handled correctly
            # (bitwise-not on ints would yield -1 / -2).
            v_s_ = v_s_ * np.logical_not(batch.done)
        v_s = np.roll(v_s_, 1) if v_s is None else v_s.flatten()

        end_flag = batch.done.copy()
        end_flag[buffer.unfinished_index()] = True
        advantage = _gae_return(v_s, v_s_, rew, end_flag, gamma, gae_lambda)
        returns = advantage + v_s
        # normalization varies from each policy, so we don't do it here
        return returns, advantage

    def forward(self, batch, **kwargs):
        """Run the actor on ``batch.obs`` and pick an action.

        Returns:
            A ``Batch`` with ``logits``, ``act``, ``dist`` and ``feature``.
            ``act`` is ``logits[0]`` when deterministic evaluation is enabled
            and the module is not in training mode, otherwise a sample drawn
            from ``dist``.
        """
        logits, feature = self.actor(batch.obs)
        dist = self.dist_fn(*logits)
        if self._deterministic_eval and not self.training:
            act = logits[0]
        else:
            act = dist.sample()
        return Batch(logits=logits, act=act, dist=dist, feature=feature)

    def learn(self, batch, batch_size, repeat, **kwargs):
        """Run ``repeat`` epochs of minibatch PPO updates over ``batch``.

        Args:
            batch: output of ``process_fn`` (needs ``adv``, ``returns``,
                ``v_s``, ``logp_old``, ``act`` and ``obs``).
            batch_size: minibatch size passed to ``batch.split``.
            repeat: number of passes over ``batch``.
            **kwargs: ignored.

        Returns:
            Dict of per-minibatch statistics lists: ``"loss"``,
            ``"loss/clip"``, ``"loss/vf"``, ``"loss/ent"``, ``"ratio"`` and
            ``"kl"`` (the mean of ``log(ratio)``).
        """
        losses, clip_losses, vf_losses, ent_losses = [], [], [], []
        ratios, kl_data = [], []
        for step in range(repeat):
            if self._recompute_adv and step > 0:
                batch = self._compute_returns(batch, self._buffer)
            for minibatch in batch.split(batch_size, merge_last=True):
                # calculate loss for actor
                dist = self(minibatch).dist
                if self._norm_adv:
                    mean, std = minibatch.adv.mean(), minibatch.adv.std()
                    # eps keeps a zero-variance minibatch from producing
                    # inf/NaN advantages.
                    minibatch.adv = (minibatch.adv - mean) / (std + self._eps)
                ratio = (dist.log_prob(minibatch.act) -
                         minibatch.logp_old).exp().float()
                ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
                surr1 = ratio * minibatch.adv
                surr2 = ratio.clamp(
                    1.0 - self._eps_clip, 1.0 + self._eps_clip
                ) * minibatch.adv
                if self._dual_clip:
                    clip1 = torch.min(surr1, surr2)
                    clip2 = torch.max(clip1, self._dual_clip * minibatch.adv)
                    clip_loss = -torch.where(minibatch.adv < 0, clip2, clip1).mean()
                else:
                    clip_loss = -torch.min(surr1, surr2).mean()
                # calculate loss for critic
                value = self.critic(minibatch.obs).flatten()
                if self._value_clip:
                    v_clip = minibatch.v_s + \
                        (value - minibatch.v_s).clamp(-self._eps_clip, self._eps_clip)
                    vf1 = (minibatch.returns - value).pow(2)
                    vf2 = (minibatch.returns - v_clip).pow(2)
                    vf_loss = torch.max(vf1, vf2).mean()
                else:
                    vf_loss = (minibatch.returns - value).pow(2).mean()
                # calculate regularization and overall loss
                ent_loss = dist.entropy().mean()
                loss = clip_loss + self._weight_vf * vf_loss \
                    - self._weight_ent * ent_loss
                self.optim.zero_grad()
                loss.backward()
                if self._grad_norm:  # clip large gradient
                    nn.utils.clip_grad_norm_(
                        self._actor_critic.parameters(), max_norm=self._grad_norm
                    )
                self.optim.step()
                clip_losses.append(clip_loss.item())
                vf_losses.append(vf_loss.item())
                ent_losses.append(ent_loss.item())
                losses.append(loss.item())
                ratios.append(ratio.mean().item())
                kl_data.append(ratio.log().mean().item())
        # Decay the learning rate once per learn() call until it reaches the
        # 1e-5 floor. get_last_lr() reports the rate currently in effect;
        # get_lr() is only meaningful inside scheduler.step() and warns (and
        # may return the *next* rate) when called from here.
        if self.lr_scheduler is not None:
            if self.lr_scheduler.get_last_lr()[0] > 1e-5:
                self.lr_scheduler.step()

        return {
            "loss": losses,
            "loss/clip": clip_losses,
            "loss/vf": vf_losses,
            "loss/ent": ent_losses,
            "ratio": ratios,
            "kl": kl_data,
        }

    def update(self, buffer, **kwargs):
        """Sample from ``buffer``, preprocess the batch and run ``learn``.

        ``**kwargs`` is forwarded to both ``process_fn`` and ``learn``, so it
        must contain ``batch_size`` and ``repeat``.

        Returns:
            The statistics dict produced by ``learn``.
        """
        batch = buffer.sample()
        batch = self.process_fn(batch, buffer, **kwargs)
        result = self.learn(batch, **kwargs)
        return result

    def map_action(self, act):
        """Bound a raw action and optionally rescale it to the action space.

        Bounding follows ``action_bound_method`` (``"clip"`` clips to
        [-1, 1], ``"tanh"`` squashes, ``""`` does nothing). With
        ``action_scaling`` the bounded action is mapped linearly from
        [-1, 1] onto ``[action_space.low, action_space.high]``.

        Raises:
            AssertionError: if scaling is requested for an action outside
                [-1, 1] (possible only when no bounding is applied).
        """
        if self.action_bound_method == "clip":
            act = np.clip(act, -1.0, 1.0)
        elif self.action_bound_method == "tanh":
            act = np.tanh(act)
        if self.action_scaling:
            assert np.min(act) >= -1.0 and np.max(act) <= 1.0, \
                "action scaling only accepts raw action range = [-1, 1]"
            low, high = self.action_space.low, self.action_space.high
            act = low + (high - low) * (act + 1.0) / 2.0
        return act



@njit
def _gae_return(
    v_s: np.ndarray,
    v_s_: np.ndarray,
    rew: np.ndarray,
    end_flag: np.ndarray,
    gamma: float,
    gae_lambda: float,
) -> np.ndarray:
    """Accumulate discounted TD residuals backwards through time.

    For every step ``t`` the residual ``rew[t] + gamma * v_s_[t] - v_s[t]``
    is added to ``gamma * gae_lambda`` times the accumulated value of step
    ``t + 1``; the carry-over is dropped wherever ``end_flag[t]`` is set.

    Returns:
        Array shaped like ``rew`` holding the accumulated value per step.
    """
    n_steps = len(rew)
    accumulated = np.zeros(rew.shape)
    # One-step temporal-difference residual for every transition.
    td_residual = rew + v_s_ * gamma - v_s
    # Per-step weight of the following step's value; zero at episode ends.
    carry_weight = (1.0 - end_flag) * (gamma * gae_lambda)
    running = 0.0
    step = n_steps - 1
    while step >= 0:
        running = td_residual[step] + carry_weight[step] * running
        accumulated[step] = running
        step -= 1
    return accumulated

