import itertools
import os, sys
from copy import deepcopy

import numpy as np
import torch
from torch.optim import Adam
import torch.nn.functional as F

from common.models.network import Qfunction, ActorCritic
from racetracks.RaceTrack import RaceTrack
from baselines.rl.safeaac.safesac import SafeSACAgent
from baselines.rl.utils import toLocal, smooth_yaw, SafeController

from scipy.interpolate import interpn

# Compute device for all tensors/modules in this file. Always a torch.device
# (previously a bare 'cpu' string on machines without CUDA).
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)

import ipdb as pdb

class ReplayBuffer:
    """
    Fixed-size FIFO experience buffer that also includes l(x) and a2.

    Besides the usual (obs, act, rew, obs2, done) columns it keeps:
      * ``lx``       -- the l(x) value associated with each transition;
      * ``act2_buf`` -- a second action per transition (in this file, the
                        safety controller's action).

    ``store`` does NOT write ``lx`` or ``act2_buf``: the caller must fill the
    slot at ``self.ptr`` before calling ``store`` for that transition.
    """

    def __init__(self, obs_dim, act_dim, size):
        self.obs_buf = np.zeros((size, obs_dim), dtype=np.float32)
        self.obs2_buf = np.zeros((size, obs_dim), dtype=np.float32)
        self.act_buf = np.zeros((size, act_dim), dtype=np.float32)
        self.act2_buf = np.zeros((size, act_dim), dtype=np.float32)
        self.rew_buf = np.zeros(size, dtype=np.float32)
        self.lx = np.zeros(size, dtype=np.float32)
        self.done_buf = np.zeros(size, dtype=np.float32)
        # ptr: next slot to write; size: number of valid rows (<= max_size).
        self.ptr, self.size, self.max_size = 0, 0, size

    @staticmethod
    def _to_numpy(x):
        """Return ``x`` as a NumPy array, detaching/moving torch tensors to CPU first."""
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    def store(self, obs, act, rew, next_obs, done):
        """Write one transition at ``ptr`` and advance it, overwriting the oldest row when full.

        ``obs``, ``next_obs`` and ``act`` may be torch tensors (any device) or
        array-likes.
        """
        i = self.ptr
        self.obs_buf[i] = self._to_numpy(obs)
        self.obs2_buf[i] = self._to_numpy(next_obs)
        self.act_buf[i] = self._to_numpy(act)
        self.rew_buf[i] = rew
        self.done_buf[i] = done

        self.ptr = (i + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample_batch(self, batch_size=32, device=None):
        """Sample up to ``batch_size`` distinct stored transitions.

        Args:
            batch_size: requested batch size; capped at the number of stored rows.
            device: target device for the returned tensors; defaults to the
                module-level ``DEVICE``.

        Returns:
            dict mapping 'obs', 'obs2', 'act', 'act2', 'rew', 'lx', 'done' to
            float32 tensors on ``device``.
        """
        if device is None:
            device = DEVICE
        idxs = np.random.choice(self.size, size=min(batch_size, self.size), replace=False)
        batch = dict(obs=self.obs_buf[idxs],
                     obs2=self.obs2_buf[idxs],
                     act=self.act_buf[idxs],
                     act2=self.act2_buf[idxs],
                     rew=self.rew_buf[idxs],
                     lx=self.lx[idxs],
                     done=self.done_buf[idxs])
        # Fancy indexing already produced fresh copies, so as_tensor avoids a second copy.
        return {k: torch.as_tensor(v, dtype=torch.float32, device=device) for k, v in batch.items()}


class SPARAgent(SafeSACAgent):
    """
    SafeSAC with Dynamic Updates (SPAR).

    Extends ``SafeSACAgent`` with a safety critic ``safety_q`` that is loaded
    from a pretrained state dict and updated online (``update_safety_q``).
    Each step, the critic scores the action proposed by ``SafeController``;
    when that score is <= ``cfg['safety_margin']`` the safe action is executed
    instead of the learner's action.
    """

    def __init__(self, env, cfg, args, atol = -1,
                 loggers=tuple(), save_episodes=True,  uMode = 'max', store_from_safe = True):
        super().__init__(env, cfg, args, loggers=loggers, save_episodes = save_episodes,
                        atol = atol, # By setting atol = -1, the agent does not take random action when stuck
                        store_from_safe=store_from_safe)
        self.safety_controller = SafeController(uMode = uMode,
                                                verbose = False,
                                                margin = self.cfg['safety_margin'])
        print(f"Safety Marigin = {self.cfg['safety_margin']}")

        ## Instantiate the safety value function from a pretrained checkpoint.
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        self.safety_q = Qfunction(cfg)
        self.safety_q.load_state_dict(torch.load(self.cfg[self.cfg['use_encoder_type']]['safety_q_statedict'],
                    map_location=DEVICE))
        self.safety_q.to(DEVICE)
        self.safety_q_target = deepcopy(self.safety_q)
        self.safety_q_optimizer = Adam(self.safety_q.parameters(), lr=2*self.cfg['lr'])

        # Discount for the safety backup; multiplied by gamma_anneal every
        # 1000 steps in select_action and capped at 0.99.
        self.gamma = 0.85
        self.gamma_anneal = 1.00062

        # Experience buffer (also holds l(x) and the safe action a2).
        self.replay_buffer = ReplayBuffer(obs_dim=self.feat_dim, ## obs = [img_embed, speed]
                act_dim=self.act_dim ,
                size=self.cfg['replay_size'])

        ## For calculating the shortest distance
        self.track = RaceTrack('Thruxton')

        self.record['transition_actor'] = 'random'

    @staticmethod
    def _clip_steering_for_speed(steer, speed):
        """Limit the steering command more tightly as speed (m/s) increases; unchanged at <= 10."""
        for min_speed, limit in ((30, 1/12), (25, 1/6), (10, 1/3)):
            if speed > min_speed:
                return np.clip(steer, a_min=-limit, a_max=limit)
        return steer

    def _check_safety(self, feat, state):
        """Score the safety controller's action with the safety critic.

        Side effects: reloads the safety-set data (``self.safety``, ``self.grid``,
        ``self.nearest_idx``) when the car enters a new track segment, and
        writes l(x) into ``self.replay_buffer.lx`` at the current write pointer.

        Args:
            feat: feature tensor fed to ``safety_q``; its last entry is read as
                speed in m/s (original note: shape (33,) -- TODO confirm).
            state: raw vehicle state, unpacked by ``self._unpack_state``.

        Returns:
            (q, safe_action): the critic's value for ``safe_action`` (tensor,
            computed without autograd) and the safe action with its steering
            clipped according to speed.
        """
        self.current_env = self.env
        track_index = self.current_env.nearest_idx
        # Round down to the segment start; this index selects the safety-set file.
        nearest_idx = track_index // self.segment_len * self.segment_len

        ## Only reload if we moved onto the next segment
        if nearest_idx != self.nearest_idx:
            self.safety = np.load(f"{self.safety_data_path}/{nearest_idx}.npz", allow_pickle=True)
            self.nearest_idx = nearest_idx
            self.grid = (self.safety['x'], self.safety['y'], self.safety['v'], self.safety['yaw'])

        x, y, v, yaw = self._unpack_state(state)

        # Calculate and save l(x) for the transition currently being built.
        lx, _ = self.track._calc_value(np.array([x, y]), u_init=track_index / self.track.raceline_length)
        self.replay_buffer.lx[self.replay_buffer.ptr] = lx

        # Use the racetrack geometry at the segment start (nearest_idx) rather
        # than the exact track index for the coordinate transform.
        origin, yaw0 = self._get_track_info(nearest_idx)
        local_state = toLocal(x, y, v, yaw, origin, yaw0)

        # Shift yaw by 2*pi so it lies inside the grid's yaw range. Read the
        # yaw grid once (npz entries are loaded on every access).
        yaw_grid = self.safety['yaw']
        if local_state[3] > np.max(yaw_grid):
            local_state[3] -= 2*np.pi
        if local_state[3] < np.min(yaw_grid):
            local_state[3] += 2*np.pi

        ## Use lx as a proxy for the current value
        safe_action = self.safety_controller.select_action(local_state, self.safety, lx)

        ## Prevent the steering angle from being too large at high speed
        speed = feat[-1].item()  # in m/s
        safe_action[0] = self._clip_steering_for_speed(safe_action[0], speed)

        # Per-step inference only: no autograd graph is needed here.
        with torch.no_grad():
            q = self.safety_q.forward(feat, torch.as_tensor(safe_action, dtype=torch.float32, device=DEVICE))
        print(f"V_est={q.item():.2f}, l(x)={lx:.2f}, speed = {speed:.2f}")
        return q, safe_action

    def select_action(self, t, feat, state, deterministic=False):
        """Choose the action to execute at step ``t``.

        If the safety critic's value for the safe action is <= ``cfg['safety_margin']``,
        the safe action is returned and the intervention is recorded. Otherwise
        the learner's policy acts once ``t > cfg['start_steps']`` (unless
        ``cfg['make_random_actions']`` is set), else a uniform random action is drawn.

        Side effects: anneals ``self.gamma`` every 1000 steps (capped at 0.99),
        writes the safe action into ``replay_buffer.act2_buf`` at the write
        pointer, may overwrite the last stored reward with -3, and updates
        ``self.record['transition_actor']``.
        """
        if (t + 1) % 1000 == 0:
            self.gamma = min(0.99, self.gamma * self.gamma_anneal)

        safety_value, safe_action = self._check_safety(feat, state)
        buf = self.replay_buffer
        buf.act2_buf[buf.ptr] = safe_action

        if safety_value.item() <= self.cfg['safety_margin']:
            a = safe_action
            print(f"Steer {'Left' if a[0]>0 else 'Right'}")
            ## Penalize the most recently stored transition when the safe
            ## controller takes over. Use modulo indexing so that after the
            ## buffer wraps (ptr == 0) the penalty lands on slot max_size-1,
            ## not on slot 0, which store() is about to overwrite.
            if self.record['transition_actor'] != 'safepol' and buf.size > 0:
                buf.rew_buf[(buf.ptr - 1) % buf.max_size] = -3

            # metadata is inherited from the parent class
            safety_info = self.metadata.setdefault('safety_info', {'ep_interventions': 0})
            safety_info['ep_interventions'] += 1
            self.record['transition_actor'] = 'safepol'
        elif t > self.cfg['start_steps'] and not self.cfg['make_random_actions']:
            a = self.actor_critic.act(feat, deterministic)
            self.record['transition_actor'] = 'learner'
            a[1] = np.clip(a[1], a_min = -0.125, a_max = 1)
        else:
            # Until start_steps have elapsed (or when random actions are forced),
            # sample uniformly: a[0] in [-1, 1], a[1] in [-0.125, 1].
            a = np.random.uniform([-1, -0.125], [1, 1])
            self.record['transition_actor'] = 'random'
        return a

    def update_safety_q(self, data):
        """Take one gradient step on the safety critic, then Polyak-update its target.

        Target (discounted safety backup):
            y = (1 - gamma) * l(x) + gamma * min(l(x), Q_targ(o2, a2)),
        with y = l(x) on terminal transitions.

        Args:
            data: batch dict as returned by ``ReplayBuffer.sample_batch``
                (reads 'obs', 'act', 'lx', 'obs2', 'act2', 'done').
        """
        o, a, lx = data['obs'], data['act'], data['lx']
        o2, a2, done = data['obs2'], data['act2'], data['done']

        with torch.no_grad():
            # NOTE(review): assumes safety_q returns shape (B,) like lx; a (B, 1)
            # output would broadcast the target to (B, B) -- confirm.
            next_q_values = self.safety_q_target.forward(o2, a2)
            target = (1 - self.gamma) * lx + self.gamma * torch.minimum(lx, next_q_values)
            terminal = done.bool()
            target[terminal] = lx[terminal]  ## if done, target = l(x)

        self.safety_q_optimizer.zero_grad()
        q_pred = self.safety_q.forward(o, a)
        loss = F.mse_loss(q_pred, target)
        loss.backward()
        self.safety_q_optimizer.step()

        # Update target network by Polyak averaging, in place:
        # p_targ <- polyak * p_targ + (1 - polyak) * p
        polyak = self.cfg['safety_polyak']
        with torch.no_grad():
            for p, p_targ in zip(self.safety_q.parameters(), self.safety_q_target.parameters()):
                p_targ.mul_(polyak).add_(p, alpha=1 - polyak)

    def update(self, data):
        """Update the safety critic on ``data``, then run the parent SAC update on the same batch."""
        self.update_safety_q(data)
        super().update(data)
        

