from abc import ABC, abstractmethod
import numpy as np
import gym
import torch
from torch._C import device
from torch.distributions.categorical import Categorical
from torch.distributions.utils import probs_to_logits
from torch.nn import functional as F

from utils import DictList, ParallelEnv
from utils.format import default_preprocess_obss
from gym import spaces

class BaseAlgo(ABC):
    """Abstract base for clipped-surrogate (PPO-style) policy-gradient algorithms.

    The class collects rollouts from several environments stepped in
    parallel, keeps per-dimension reward/value/advantage buffers (rewards
    are vectors of length ``reward_dim``), computes GAE advantages and
    offers actor/critic loss helpers. Subclasses implement
    `update_parameters` and are expected to define ``self.batch_size``
    (read by `_get_batches`) and ``self.clip_eps`` (read by `_actor_loss`);
    neither is set here.
    """

    def __init__(self,
                    envs,
                    model,
                    reward_dim,
                    device,
                    num_frames_per_proc,
                    discount,
                    lr,
                    gae_lambda,
                    entropy_coef,
                    value_loss_coef,
                    max_grad_norm,
                    preprocess_obss,
                    action_dim=1,
                    obs_clip=None):
        """
        Initializes a `BaseAlgo` instance.

        Parameters:
        ----------
        envs : list
            a list of environments that will be run in parallel
        model : torch.Module
            the agent; called as ``model(obs)`` and expected to return a
            ``(distribution, value)`` pair
        reward_dim : int
            the dimension of the reward
        device: torch.device
            the device to run model on
        num_frames_per_proc : int
            the number of frames collected by every process for an update
        discount : float
            the discount for future rewards
        lr : float
            the learning rate for optimizers
        gae_lambda : float
            the lambda coefficient in the GAE formula
            ([Schulman et al., 2015](https://arxiv.org/abs/1506.02438))
        entropy_coef : float
            the weight of the entropy cost in the final objective
        value_loss_coef : float
            the weight of the value loss in the final objective
        max_grad_norm : float
            gradient will be clipped to be at most this value
        preprocess_obss : function
            a function that takes observations returned by the environment
            and converts them into the format that the model can handle;
            `default_preprocess_obss` is used when a falsy value is given
        action_dim: int
            The dimension of the action space (1 for discrete actions)
        obs_clip: float (default = None)
            If set, clip obs to given value (only stored here; this class
            itself never applies it)
        """
        # Store parameters
        self.env = ParallelEnv(envs)
        self.model = model
        self.reward_dim = reward_dim
        self.device = device
        self.num_frames_per_proc = num_frames_per_proc
        self.discount = discount
        self.lr = lr
        self.gae_lambda = gae_lambda
        self.entropy_coef = entropy_coef
        self.value_loss_coef = value_loss_coef
        self.max_grad_norm = max_grad_norm
        self.preprocess_obss = preprocess_obss or default_preprocess_obss
        self.action_dim = action_dim
        self.obs_clip = obs_clip

        # Configure models
        self.model.to(self.device)
        self.model.train()

        # Store helpers values
        self.num_procs = len(envs)
        self.num_frames = self.num_frames_per_proc * self.num_procs

        # Initialize experience values (all buffers are time-major: T x P x ...)
        shape = (self.num_frames_per_proc, self.num_procs)
        shape_reward = (self.num_frames_per_proc, self.num_procs, self.reward_dim)
        shape_actions = (self.num_frames_per_proc, self.num_procs, self.action_dim)

        self.obs = self.env.reset()
        self.obss = [None] * shape[0]
        self.mask = torch.ones(shape[1])
        self.masks = torch.zeros(*shape)
        # NOTE(review): the action buffer is integer-typed, so float-valued
        # actions would be truncated on storage - confirm that every action
        # space used with action_dim > 1 is integer-valued.
        self.actions = torch.zeros(*shape_actions, dtype=torch.int)
        # Drop only the trailing action axis (a no-op unless action_dim == 1).
        # A bare squeeze() would also remove the frame/process axes whenever
        # one of them has size 1 and break the T x P layout used below.
        self.actions = torch.squeeze(self.actions, -1)
        self.values = torch.zeros(*shape_reward)
        self.rewards = torch.zeros(*shape_reward)
        self.advantages = torch.zeros(*shape_reward)
        self.log_probs = torch.zeros(*shape)

        # Init reward statistics (running mean and running sum of squared
        # deviations, see `_update_reward_statistics`)
        self.reward_mean = np.zeros(shape=(self.reward_dim,))
        self.reward_std = np.zeros(shape=(self.reward_dim,))
        self.reward_cnt = 0

        # Initialize log values
        self.log_done_counter = 0
        self.log_episode_return = torch.zeros((self.num_procs, self.reward_dim))
        self.log_episode_num_frames = torch.zeros(self.num_procs)
        self.log_reward_per_episode = [[0]*self.reward_dim] # need to init to non-empty list
        self.log_num_frames = []

    def collect_experiences(self):
        """Collects rollouts and computes advantages.

        Runs several environments concurrently. The next actions are computed
        in a batch mode for all environments at the same time. The rollouts
        and advantages from all environments are concatenated together.

        Returns
        -------
        exps : DictList
            Contains actions, rewards, advantages etc as attributes.
            Each attribute, e.g. `exps.reward` has a shape
            (self.num_frames_per_proc * num_envs, ...). k-th block
            of consecutive `self.num_frames_per_proc` frames contains
            data obtained from the k-th environment. Be careful not to mix
            data from different environments!
        logs : dict
            Per-episode returns and lengths of the episodes that finished
            during this rollout, plus the number of frames collected.
        """

        for i in range(self.num_frames_per_proc):
            # Do one agent-environment interaction
            preprocessed_obs = self.preprocess_obss(self.obs, device=self.device)
            with torch.no_grad():
                dist, value = self.model(preprocessed_obs)

            # Sample action
            action = dist.sample()

            # Get obs, reward, done
            obs, reward, done, _ = self.env.step(action.cpu().numpy())
            reward = np.array(reward)
            if reward.ndim == 1:
                # One scalar reward per environment: treat it as a 1-d vector.
                reward = reward[:, np.newaxis]
            # Keep only the first `reward_dim` reward components.
            reward = reward[:, :self.reward_dim]

            for r in reward:
                self._update_reward_statistics(r)

            # Update experiences values
            self.obss[i] = self.obs
            self.obs = obs
            # masks[i] is 0 where the previous step ended an episode.
            self.masks[i] = self.mask
            self.mask = 1 - torch.tensor(done, dtype=torch.float)
            self.actions[i] = action
            self.values[i] = value.cpu()
            self.rewards[i] = torch.tensor(reward)
            log_probs = dist.log_prob(action)

            # sum log probs for action_dim > 1
            if self.action_dim > 1:
                log_probs = log_probs.sum(dim=-1)
            self.log_probs[i] = log_probs

            # Update log values
            self.log_episode_return += torch.tensor(reward, dtype=torch.float)
            self.log_episode_num_frames += 1

            for j, done_ in enumerate(done):
                if done_:
                    self.log_done_counter += 1
                    self.log_reward_per_episode.append(self.log_episode_return[j].tolist())
                    self.log_num_frames.append(self.log_episode_num_frames[j].item())

            # Reset the running return/length of every environment that just finished.
            self.log_episode_return *= self.mask.unsqueeze(-1)
            self.log_episode_num_frames *= self.mask

        # Bootstrap value for the observation following the last collected frame.
        preprocessed_obs = self.preprocess_obss(self.obs, device=self.device)
        with torch.no_grad():
            _, next_value = self.model(preprocessed_obs)

        # NOTE: rewards are used unnormalised; the running statistics kept by
        # `_update_reward_statistics` are not applied to `self.rewards` here.

        # Calculate advantages for all rewards
        for i in range(self.reward_dim):
            self.advantages[:, :, i] = self._calc_advantages(values=self.values[:, :, i],
                                                             next_values=next_value[:, i].cpu(),
                                                             rewards=self.rewards[:, :, i])

        exps = DictList()
        exps.obss = [self.obss[i][j]
                    for j in range(self.num_procs)
                    for i in range(self.num_frames_per_proc)]
        exps.obss = self.preprocess_obss(exps.obss, device=self.device)
        # for all tensors below, T x P -> P x T -> P * T
        exps.actions = self.actions.transpose(0, 1).reshape((-1, self.action_dim))
        # Drop only the action axis (when action_dim == 1), never the batch axis.
        exps.actions = torch.squeeze(exps.actions, -1)
        exps.masks = self.masks.transpose(0, 1).reshape(-1)
        exps.valuess = self.values.transpose(0, 1).reshape((-1, self.reward_dim))
        exps.rewards = self.rewards.transpose(0, 1).reshape((-1, self.reward_dim))
        exps.advantages = self.advantages.transpose(0, 1).reshape((-1, self.reward_dim))
        exps.log_probs = self.log_probs.transpose(0, 1).reshape(-1)
        exps.returns = exps.advantages + exps.valuess

        # Log some values
        # NOTE(review): when no episode finished (keep == 0) the slice
        # `[-0:]` yields the whole history rather than an empty list, and the
        # two log lists are never trimmed. Behaviour is kept as-is because
        # callers may rely on a non-empty "return_per_episode" - confirm.
        keep = self.log_done_counter
        logs = {
            "return_per_episode": self.log_reward_per_episode[-keep:],
            "num_frames_per_episode": self.log_num_frames[-keep:],
            "num_frames": self.num_frames
        }

        # Reset counter
        self.log_done_counter = 0

        return exps, logs

    def _update_reward_statistics(self, reward):
        """Update the running reward statistics with one reward vector.

        Uses Welford's online algorithm: `reward_mean` holds the running
        mean and `reward_std` accumulates the sum of squared deviations
        from the mean (divide by the count to obtain a variance).

        Parameters
        ----------
        reward : np.ndarray
            a reward vector broadcastable against `self.reward_mean`
        """
        # Copy: the in-place `+=` below would otherwise also change `prev`,
        # making the old and the new mean indistinguishable.
        prev = self.reward_mean.copy()
        self.reward_cnt += 1
        self.reward_mean += (reward - self.reward_mean) / self.reward_cnt
        self.reward_std += (reward - self.reward_mean) * (reward - prev)

    def _normalize(self, tens, dim=None):
        """Standardise a tensor to zero mean and unit standard deviation.

        Parameters
        ----------
        tens : torch.Tensor
            the tensor to normalise
        dim : int, optional
            axis along which statistics are computed; `None` (default)
            normalises over all elements at once

        Returns
        -------
        torch.Tensor
            a tensor of the same shape as `tens`
        """
        # `dim=0` is a valid axis, so test against None rather than truthiness.
        if dim is not None:
            # keepdim so the statistics broadcast against `tens` for any axis.
            mean = tens.mean(dim=dim, keepdim=True)
            std = tens.std(dim=dim, keepdim=True)
        else:
            mean = tens.mean()
            std = tens.std()
        return (tens - mean) / (std + 1e-8)

    def _get_batches(self):
        """Generates random batches of size self.batch_size

        `self.batch_size` is not set by this class and must be provided by
        the subclass. The last batch is shorter when `self.num_frames` is
        not a multiple of the batch size.

        Returns
        -------
        batches : list of np.ndarray of int
            shuffled experience indexes, split into consecutive batches
        """

        indices = np.arange(0, self.num_frames)
        np.random.shuffle(indices)

        # Split points are every multiple of batch_size except 0.
        batches_start = np.arange(0, self.num_frames, self.batch_size)[1:]
        batches = np.array_split(indices, batches_start)

        return batches

    def _calc_advantages(self, values, next_values, rewards):
        """Compute GAE advantages for a single reward dimension.

        Reads `self.masks` / `self.mask` as filled by `collect_experiences`
        to cut the recursion at episode boundaries.

        Parameters
        ----------
        values : torch.Tensor
            value estimates, indexed as ``values[t]`` per frame
        next_values : torch.Tensor
            value estimate for the observation after the last frame
        rewards : torch.Tensor
            rewards, indexed as ``rewards[t]`` per frame

        Returns
        -------
        torch.Tensor
            advantages of shape (num_frames_per_proc, num_procs)
        """
        shape = (self.num_frames_per_proc, self.num_procs)
        advantages = torch.zeros(*shape)
        last = self.num_frames_per_proc - 1

        for i in reversed(range(self.num_frames_per_proc)):
            if i < last:
                next_mask = self.masks[i + 1]
                bootstrap = values[i + 1]
                next_advantage = advantages[i + 1]
            else:
                # Final frame: bootstrap from the value of the next observation.
                next_mask = self.mask
                bootstrap = next_values
                next_advantage = 0

            delta = rewards[i] + self.discount * bootstrap * next_mask - values[i]
            advantages[i] = delta + self.discount * self.gae_lambda * next_advantage * next_mask

        return advantages

    def _actor_loss(self, log_probs_new, log_probs_old, advantages):
        """Clipped-surrogate policy loss.

        Requires `self.clip_eps`, which is not set by this class and must
        be provided by the subclass.

        Parameters
        ----------
        log_probs_new : torch.Tensor
            log-probabilities of the taken actions under the current policy
        log_probs_old : torch.Tensor
            log-probabilities of the same actions at collection time
        advantages : torch.Tensor
            advantage estimates; normalised along dim 0 before use

        Returns
        -------
        torch.Tensor
            scalar loss (negated mean of the clipped objective)
        """
        advantages = self._normalize(advantages, dim=0)

        # Probability ratio computed in log space: dividing exponentials with
        # an additive epsilon collapses towards 0 whenever the log-probs are
        # very negative, because the epsilon then dominates the denominator.
        probs_ratio = torch.exp(log_probs_new - log_probs_old)

        # Unclipped and clipped surrogate objectives
        probs_weighted = advantages * probs_ratio
        probs_weighted_clipped = torch.clamp(probs_ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages

        actor_loss = -torch.min(probs_weighted_clipped, probs_weighted).mean()
        return actor_loss

    def _critic_loss(self, critic_value, returns):
        """Mean-squared error between the critic output and normalised returns.

        Parameters
        ----------
        critic_value : torch.Tensor
            value predictions of the critic
        returns : torch.Tensor
            return targets; normalised along dim 0 before the comparison

        Returns
        -------
        torch.Tensor
            scalar loss
        """
        returns = self._normalize(returns, dim=0)
        critic_loss = F.mse_loss(returns, critic_value)
        return critic_loss

    @abstractmethod
    def update_parameters(self):
        """Perform one parameter update; must be implemented by subclasses."""
        pass
