import itertools
import os, sys
from copy import deepcopy

import numpy as np
import torch
from torch.optim import Adam
import torch.nn.functional as F

#from common.gym_utils.utils import obs_map
#from common.models.network_gym import ActorCritic
from baselines.rl.ppo.core import MLPQActorCritic as ActorCritic
from baselines.rl.ppo.ppo import PPOAgent, PPOBuffer
import baselines.rl.ppo.core as core
from common.gym_utils.mpi_tools import mpi_avg, mpi_statistics_scalar
from common.gym_utils.mpi_pytorch import mpi_avg_grads

from scipy.interpolate import interpn

# Device for tensors created in this module. Automatic CUDA selection
# (torch.device('cuda') if torch.cuda.is_available() else 'cpu') is disabled.
DEVICE = "cpu"
print(DEVICE)

import ipdb as pdb


GROUP_HAZARD = 3
GROUP_VASE = 4

class SPARBuffer(PPOBuffer):
    """
    PPO rollout buffer extended with the extra per-step fields used by the
    safety actor-critic, plus prioritized sampling for the safety updates.

    In addition to the parent ``PPOBuffer`` storage, every step records the
    next observation, the safety signal ``l(x)``, its one-step change
    ``l(x') - l(x)`` (written later by the agent into ``d_lx_buf``), the done
    flag, the environment cost and which actor chose the action
    (0 = task policy, 1 = safety backup policy in ``SPARGym``).

    ``get()`` hands only the task-policy steps to the PPO update, while
    ``sample_batch()`` draws prioritized mini-batches from all stored steps.
    """

    def __init__(self, obs_dim, act_dim, size, beta_schedule, alpha=0.6, gamma=0.99, lam=0.97):
        """
        Args:
            obs_dim, act_dim: observation / action shapes, forwarded to ``PPOBuffer``.
            size: number of steps the buffer holds.
            beta_schedule: sequence of importance-sampling exponents; only its
                first entry is read (into ``self.beta``), and ``sample_batch``
                does not currently apply importance weights.
            alpha: prioritization exponent (0 gives uniform sampling).
            gamma, lam: discount and GAE-lambda, forwarded to ``PPOBuffer``.
        """
        super().__init__(obs_dim, act_dim, size, gamma=gamma, lam=lam)

        self.obs2_buf = np.zeros(core.combined_shape(size, obs_dim), dtype=np.float32)
        self.lx_buf = np.zeros(size, dtype=np.float32)     # l(x) at the stored state
        self.d_lx_buf = np.zeros(size, dtype=np.float32)   # l(x') - l(x), filled in by the agent
        self.done_buf = np.zeros(size, dtype=np.float32)
        self.actor_buf = np.zeros(size, dtype=np.float32)  # which actor chose the action
        self.cost_buf = np.zeros(size, dtype=np.float32)

        self.current_size = 0

        ## Prioritized replay memory
        self.initial_priority = 1e-1
        self.priorities = self.initial_priority * np.ones(size, dtype=np.float32)
        self.idxs = None           # kept for backward compatibility; not used internally
        self.priority_idxs = None  # indices of the last batch drawn by sample_batch()
        self.alpha = alpha
        self.beta_schedule = beta_schedule
        self.beta = self.beta_schedule[0]

    def store(self, obs, act, rew, cost, val, logp, obs2, lx, d, actor):
        """
        Append one timestep of agent-environment interaction to the buffer.

        ``logp`` may be ``None`` (SPARGym passes ``None`` when the safety actor
        acted); it is stored as NaN. Those rows have ``actor == 1`` and are
        excluded from ``get()``.
        """
        assert self.ptr < self.max_size     # buffer has to have room so you can store
        self.obs_buf[self.ptr] = obs
        self.act_buf[self.ptr] = act
        self.rew_buf[self.ptr] = rew
        self.val_buf[self.ptr] = val
        self.logp_buf[self.ptr] = np.nan if logp is None else logp
        self.obs2_buf[self.ptr] = obs2
        self.lx_buf[self.ptr] = lx
        self.done_buf[self.ptr] = d
        self.cost_buf[self.ptr] = cost
        self.actor_buf[self.ptr] = actor
        # The slot may hold a transition from a previous epoch (get() rewinds
        # ptr); its replay priority belongs to that old transition, so reset it
        # to the value a never-used slot starts with.
        self.priorities[self.ptr] = self.initial_priority
        self.ptr += 1
        self.current_size = min(self.max_size, self.current_size + 1)

    def get(self):
        """
        Return the task-policy steps (``actor == 0``) of a full buffer for the
        PPO update, as float32 tensors.

        Advantages are normalized in place with the MPI-wide mean/std of the
        *whole* buffer (safety-actor steps included) before the task-policy
        rows are selected, so the returned ``adv`` is only approximately
        zero-mean / unit-std. Rewinds the write pointer; ``current_size`` is
        left unchanged so ``sample_batch()`` keeps drawing from older data
        until it is overwritten.
        """
        assert self.ptr == self.max_size    # buffer has to be full before you can get
        self.ptr, self.path_start_idx = 0, 0

        task_mask = self.actor_buf == 0

        # advantage normalization trick (statistics over the full buffer)
        adv_mean, adv_std = mpi_statistics_scalar(self.adv_buf)
        self.adv_buf = (self.adv_buf - adv_mean) / adv_std
        data = dict(obs=self.obs_buf[task_mask], act=self.act_buf[task_mask], ret=self.ret_buf[task_mask],
                    adv=self.adv_buf[task_mask], logp=self.logp_buf[task_mask])

        return {k: torch.as_tensor(v, dtype=torch.float32) for k, v in data.items()}

    def sample_batch(self, batch_size=64):
        """
        Draw a prioritized mini-batch for the safety actor-critic update.

        Indices are drawn with replacement over the filled part of the buffer,
        with probability proportional to ``priority ** alpha``. The indices are
        remembered so ``update_priorities`` can write back new TD errors.

        Returns:
            dict of float32 tensors on ``DEVICE``.
        """
        # float64 keeps the normalized probabilities summing to 1 for np.random.choice
        scaled = self.priorities[:self.current_size].astype(np.float64) ** self.alpha
        prob = scaled / scaled.sum()
        idxs = np.random.choice(self.current_size, size=min(batch_size, self.current_size), p=prob)
        self.priority_idxs = idxs  # needed by update_priorities()

        batch = dict(obs=self.obs_buf[idxs],
                     obs2=self.obs2_buf[idxs],
                     act=self.act_buf[idxs],
                     rew=self.rew_buf[idxs],
                     cost=self.cost_buf[idxs],
                     adv=self.adv_buf[idxs],  # inherited from the parent buffer
                     logp=self.logp_buf[idxs],
                     lx=self.lx_buf[idxs],
                     d_lx=self.d_lx_buf[idxs],
                     done=self.done_buf[idxs],
                     actor=self.actor_buf[idxs])

        return {k: torch.tensor(v, dtype=torch.float32, device=DEVICE) for k, v in batch.items()}

    def update_priorities(self, td_error: np.ndarray):
        """
        Overwrite the priorities of the batch last drawn by ``sample_batch``.

        Args:
            td_error: new priorities, one per sampled index (same order).

        Raises:
            RuntimeError: if no batch has been sampled yet.
        """
        if self.priority_idxs is None:
            raise RuntimeError("update_priorities() called before sample_batch()")
        self.priorities[self.priority_idxs] = td_error

class SPARGym(PPOAgent):
    """
    PPO agent whose actions are filtered by a separately learned safety
    actor-critic.

    At every step the task action is scored with the safety critic ``q1``.
    If ``q1 / reward_scale + l(x)`` falls below ``cfg['safety_margin']``, the
    safety policy's action is executed instead. ``l(x)`` is computed in
    ``_find_lx`` from the hazard lidar. The safety actor-critic is trained
    online from prioritized batches of ``SPARBuffer``.
    """
    def __init__(self, env, cfg, args, save_episodes=True,
            t_start=0, reward_scale=100, **kwargs):
        """
        Args:
            env: environment instance; wrapped in a factory for ``PPOAgent``.
            cfg: dict-like config; this class reads 'target_kl', 'actor_lr',
                'critic_lr', 'safety_margin', 'safety_polyak' and 'alpha'.
            args: namespace; ``args.safety_margin`` is coerced to float.
            save_episodes, t_start: accepted for interface compatibility; not
                used by this class.
            reward_scale: scale between the safety critic's output and l(x).
            **kwargs: forwarded to ``PPOAgent``.
        """
        super().__init__(lambda: env, **kwargs)
        self.args = args
        self.args.safety_margin = float(self.args.safety_margin)
        self.cfg = cfg
        self.reward_scale = reward_scale

        self.target_kl = self.cfg['target_kl']
        print(f"Target KL ={self.target_kl }")
        print(f"\\epsilon ={self.clip_ratio }")  # escaped: "\e" is not a valid escape sequence
        print(f"pi_lr={self.pi_lr}, critic_lr={self.vf_lr}")

        self.env = env
        beta_schedule = np.linspace(0.4, 1, self.epochs + 1)

        self.buf = SPARBuffer(self.obs_dim, self.act_dim, self.local_steps_per_epoch, beta_schedule,
                              gamma=self.gamma, lam=self.lam)
        print(f"gamma={self.gamma}, lambda={self.lam}")

        ## Safety actor-critic and the target copy of its critic
        self.safety_ac = ActorCritic(self.env.observation_space, self.env.action_space).float()
        self.safety_q_target = deepcopy(self.safety_ac.q1)

        # Optimizers for the safety policy and the safety Q-function
        self.safety_pi_optimizer = Adam(self.safety_ac.pi.parameters(), lr=self.cfg['actor_lr'])
        self.safety_q_optimizer = Adam(self.safety_ac.q1.parameters(), lr=self.cfg['critic_lr'])

        # Per-epoch safety discount: 1 - 0.15**k for k evenly spaced in [1, 2]
        self.safe_gamma_schedule = 1 - np.logspace(1, 2, self.epochs + 1, base=0.15)

        self.current_epoch = 0
        self.safe_gamma = self.safe_gamma_schedule[self.current_epoch]

        self.lx_prev = None  # l(x) of the previous step of the current episode

        self.margin = self.cfg['safety_margin']
        print(f"Safety Margin = {self.margin}")

    def _find_lx(self,):
        """
        Return ``(l(x), goal distance)`` for the current environment state.

        ``l(x)`` is the minimum over the hazard lidar of
        ``(1 - lidar) * lidar_max_dist``.
        """
        lx = 1 - self.env.obs_lidar(self.env.hazards_pos, GROUP_HAZARD)
        lx *= self.env.lidar_max_dist
        return min(lx), self.env.dist_goal()

    def _check_safety(self, obs, act):
        """
        Evaluate the safety critic ``q1`` on one (observation, action) pair.

        Runs under ``torch.no_grad()`` because the value is only used for the
        filtering decision and logging, so no autograd graph is needed.
        """
        with torch.no_grad():
            obs_t = torch.as_tensor(obs, dtype=torch.float32, device=DEVICE)
            act_t = torch.as_tensor(act, dtype=torch.float32, device=DEVICE)
            return self.safety_ac.q1(obs_t, act_t)

    def select_action(self, t, o, deterministic=False):
        """
        Choose and execute one action, then store the transition.

        The task policy proposes an action. If its safety value is below the
        margin, the safety policy's action is used instead. Once the buffer
        holds more than 1000 steps, one safety critic + actor update runs per
        call. ``deterministic`` is accepted but currently ignored.

        Returns:
            (a, v, logp, next_o, r, d, c, g): executed action, task value
            estimate, log-prob (``None`` if the safety actor acted), next
            observation, reward, done flag, cost and goal-met flag.
        """
        if t == 0:
            print(f"Safe Gamma={self.safe_gamma}")
        ## At the epoch end advance the safe-gamma schedule
        if t == self.local_steps_per_epoch - 1:
            self.current_epoch += 1
            # Clamp so running past the configured epochs keeps the last value
            sched_idx = min(self.current_epoch, len(self.safe_gamma_schedule) - 1)
            self.safe_gamma = min(1, self.safe_gamma_schedule[sched_idx])

        lx, lg = self._find_lx()

        a, v, logp = self.ac.step(torch.as_tensor(o, dtype=torch.float32))
        actor = 0

        # Critic output is relative to l(x); add the baseline back
        safety_value = self._check_safety(o, a) / self.reward_scale + lx

        if safety_value.item() < self.margin:
            ## Use safety backup controller
            a = self.safety_ac.act(torch.as_tensor(o, dtype=torch.float32))
            safety_value = self._check_safety(o, a) / self.reward_scale + lx
            actor = 1
            logp = None

        next_o, r, d, info = self.env.step(a)

        c = info.get('cost', 0)
        g = info.get('goal_met', False)

        print(f"{actor}: lx={lx:.2f}, Ax={safety_value.item()-lx:.3f}, c={c:.1f}, l_goal={lg:.2f}, r={r*self.reward_scale:.3f}")
        # l(x') - l(x) of the previous transition is only known now; it is 0
        # right after an episode ended (lx_prev is cleared when d is true).
        # NOTE(review): when get() has just reset buf.ptr to 0, the last
        # transition of the previous epoch never receives its d_lx - confirm.
        d_lx = lx - self.lx_prev if self.lx_prev is not None else 0
        if self.buf.ptr > 0:
            self.buf.d_lx_buf[self.buf.ptr - 1] = d_lx * self.reward_scale
        self.lx_prev = None if d else lx
        self.buf.store(o, a, r, c, v, logp, next_o, lx, d, actor)

        if self.buf.current_size > 1000:
            data = self.buf.sample_batch()
            self.update_safety_critic(data)
            self.update_safety_actor(data)

        return a, v, logp, next_o, r, d, c, g

    def update_safety_critic(self, data):
        """
        One gradient step on the safety critic ``q1`` with the HJ Bellman backup

            target = safe_gamma * min(0, Q_targ(o2, pi(o2)) + d_lx), and
            target = 0 for terminal transitions.

        Afterwards, the sampled replay priorities are refreshed with
        ``|target - q| + 1e-3``, and ``q1`` is polyak-averaged into the target
        critic.
        """
        o, a, o2 = data['obs'], data['act'], data['obs2']
        d_lx, done = data['d_lx'], data['done']

        with torch.no_grad():
            # Target actions come from the *current* safety policy
            a2, _ = self.safety_ac.pi(o2)
            next_q_values = self.safety_q_target(o2, a2)
            ## HJ Bellman update in baseline form (relative to l(x))
            target = self.safe_gamma * torch.minimum(torch.zeros_like(d_lx), next_q_values + d_lx)
            ## Boundary condition: terminal transitions have a zero target
            target[done.bool()] = 0

        self.safety_q_optimizer.zero_grad()
        q_pred = self.safety_ac.q1(o, a)
        td_error = target - q_pred
        self.buf.update_priorities(td_error.abs().detach().cpu().numpy() + 1e-3)
        loss = F.mse_loss(q_pred, target)
        loss.backward()
        self.safety_q_optimizer.step()

        # Polyak-average the online critic into the target critic, in place.
        polyak = self.cfg['safety_polyak']
        with torch.no_grad():
            for p, p_targ in zip(self.safety_ac.q1.parameters(), self.safety_q_target.parameters()):
                p_targ.data.mul_(polyak)
                p_targ.data.add_((1 - polyak) * p.data)

    def update_safety_actor(self, data):
        """
        One entropy-regularized policy step for the safety actor.

        The step minimizes ``mean(cfg['alpha'] * log pi(a|o) - q1(o, a))`` with
        ``a`` sampled from the safety policy. ``q1`` is frozen so only the
        policy gets gradients. It is unfrozen in ``finally``, so an exception
        cannot leave the critic frozen.
        """
        q_params = list(self.safety_ac.q1.parameters())
        for p in q_params:
            p.requires_grad = False
        try:
            o = data['obs']
            pi, logp_pi = self.safety_ac.pi(o)
            q_pi = self.safety_ac.q1(o, pi)

            # TODO: consider restricting the update to samples that need
            # intervention (q_pi <= cfg['safety_margin']).
            loss_pi = (self.cfg['alpha'] * logp_pi - q_pi).mean()

            self.safety_pi_optimizer.zero_grad()
            loss_pi.backward()
            self.safety_pi_optimizer.step()
        finally:
            for p in q_params:
                p.requires_grad = True

        

