import sys
import os
import os.path as osp
import torch
from torch import nn
from torch.optim import Adam
import numpy as np
import time

from irl.algo.base import Algorithm
from irl.buffer import RolloutBuffer
from irl.network import StateIndependentPolicy, StateFunction


def calculate_gae(values, rewards, dones, next_values, gamma, lambd):
    """Compute GAE(lambda) advantages and value targets for one rollout.

    All tensor arguments are indexed by time step along dim 0 and must be
    broadcast-compatible with ``rewards``.

    Args:
        values: critic estimates for the visited states.
        rewards: per-step rewards.
        dones: per-step termination flags; a value of 1 stops both the
            bootstrap from ``next_values`` and the backward accumulation.
        next_values: critic estimates for the successor states.
        gamma: discount factor.
        lambd: GAE lambda.

    Returns:
        ``(targets, advantages)``. ``targets`` is the un-normalized GAE plus
        ``values`` (used as the critic's regression target). ``advantages``
        is the GAE standardized by its mean and std, with 1e-8 added to the
        std.
    """
    not_done = 1 - dones
    # One-step TD residuals.
    td_errors = rewards + gamma * next_values * not_done - values
    decay = gamma * lambd

    advantages = torch.empty_like(rewards)
    # Accumulate backwards in time, starting from the final step.
    advantages[-1] = td_errors[-1]
    last_index = rewards.size(0) - 1
    for idx in range(last_index - 1, -1, -1):
        advantages[idx] = td_errors[idx] + decay * not_done[idx] * advantages[idx + 1]

    targets = advantages + values
    centered = advantages - advantages.mean()
    scale = advantages.std() + 1e-8
    return targets, centered / scale


class PPO(Algorithm):
    """PPO with a clipped surrogate objective and GAE advantages.

    Each environment step is stored in a ``RolloutBuffer``. When ``update``
    runs, the whole buffer is used for ``epoch_ppo`` full-batch updates of
    the critic (mean squared error to GAE targets) and of the actor (clipped
    ratio objective minus an optional entropy bonus).

    ``step`` also queries the environment for a task-specific "transition
    boundary" flag. It restarts the episode from a state sampled from
    ``start_buffer_exp`` when the episode is done or that flag equals -1.

    NOTE(review): ``step`` reads ``self.args`` (``env``, ``complex_task``) and
    ``self.front``, and ``update`` increments ``self.learning_steps``. None of
    these is assigned in this class's ``__init__``, so they are presumably
    provided by ``Algorithm`` or set by the caller — verify.
    """

    # Names of the ``env.unwrapped`` methods that report the transition
    # boundary flag, keyed by ``args.complex_task``. Each entry is
    # (method used when ``self.front`` is truthy, method used otherwise).
    # ``None`` as the second item means the task has no rear variant and
    # ``self.front`` is not consulted at all.
    _BOUNDARY_CHECKS = {
        'hurdle': ('is_transition_boundary',
                   'is_transition_boundary_rear'),
        'obstacle': ('is_transition_boundary_for_obstacle',
                     'is_transition_boundary_rear_for_obstacle'),
        'pick': ('is_transition_boundary_for_pick', None),
        'catch': ('is_transition_boundary_for_catch', None),
        'serve': ('is_transition_boundary_for_serve', None),
        'patrol': ('is_transition_boundary_for_patrol',
                   'is_transition_boundary_rear_for_patrol'),
    }

    def __init__(self, state_shape, action_shape, device, seed, start_buffer_exp, gamma=0.995,
                 rollout_length=2048, mix_buffer=20, lr_actor=3e-4,
                 lr_critic=3e-4, units_actor=(64, 64), units_critic=(64, 64),
                 epoch_ppo=10, clip_eps=0.2, lambd=0.97, coef_ent=0.0,
                 max_grad_norm=10.0):
        """Build the rollout buffer, actor, critic and their optimizers.

        Args:
            state_shape, action_shape, device, seed, gamma: forwarded to
                ``Algorithm.__init__``. Shapes and device are also used for
                the buffer and networks.
            start_buffer_exp: object whose ``sample()`` result supplies
                (at indices 0-4) the tensors passed to ``env.unwrapped.rollback``
                when an episode is restarted.
            rollout_length: buffer size and update period (see ``is_update``).
            mix_buffer: forwarded to ``RolloutBuffer`` as ``mix``.
            lr_actor, lr_critic: Adam learning rates.
            units_actor, units_critic: hidden layer sizes (Tanh activations).
            epoch_ppo: number of passes over the buffer per update.
            clip_eps: PPO ratio clipping range.
            lambd: GAE lambda.
            coef_ent: weight of the entropy bonus in the actor loss.
            max_grad_norm: gradient-norm clipping threshold for both nets.
        """
        super().__init__(state_shape, action_shape, device, seed, gamma)

        self.start_buffer_exp = start_buffer_exp

        # Rollout buffer.
        self.buffer = RolloutBuffer(
            buffer_size=rollout_length,
            state_shape=state_shape,
            action_shape=action_shape,
            device=device,
            mix=mix_buffer
        )

        # Actor.
        self.actor = StateIndependentPolicy(
            state_shape=state_shape,
            action_shape=action_shape,
            hidden_units=units_actor,
            hidden_activation=nn.Tanh()
        ).to(device)

        # Critic.
        self.critic = StateFunction(
            state_shape=state_shape,
            hidden_units=units_critic,
            hidden_activation=nn.Tanh()
        ).to(device)

        self.optim_actor = Adam(self.actor.parameters(), lr=lr_actor)
        self.optim_critic = Adam(self.critic.parameters(), lr=lr_critic)

        self.learning_steps_ppo = 0
        self.rollout_length = rollout_length
        self.epoch_ppo = epoch_ppo
        self.clip_eps = clip_eps
        self.lambd = lambd
        self.coef_ent = coef_ent
        self.max_grad_norm = max_grad_norm

    def is_update(self, step):
        """Return True when ``step`` is a multiple of ``rollout_length``."""
        return step % self.rollout_length == 0

    def _transition_flag(self, env):
        """Return the task's transition-boundary flag, or None.

        Looks up ``self.args.complex_task`` in ``_BOUNDARY_CHECKS`` and calls
        the matching ``env.unwrapped`` method. Task names not in the table
        get ``None``, so for them only ``done`` can end an episode (instead
        of the previous ``UnboundLocalError``).
        """
        names = self._BOUNDARY_CHECKS.get(self.args.complex_task)
        if names is None:
            return None
        front_name, rear_name = names
        if rear_name is None or self.front:
            return getattr(env.unwrapped, front_name)()
        return getattr(env.unwrapped, rear_name)()

    def step(self, env, state, t, step):
        """Take one environment step with the exploration policy and store it.

        Args:
            env: gym-style environment (``step`` returns a 4-tuple). Its
                ``unwrapped`` object must provide ``rollback`` and the
                boundary methods listed in ``_BOUNDARY_CHECKS``.
            state: current observation.
            t: step counter within the current episode.
            step: unused by this method.

        Returns:
            ``(next_state, t)``. ``t`` is incremented, or reset to 0 when the
            episode was restarted, in which case ``next_state`` is the state
            returned by ``env.unwrapped.rollback``.
        """
        t += 1

        action, log_pi = self.explore(state)
        next_state, reward, done, _ = env.step(action)
        if self.args.env == 'Walker2dCrawl-v1':
            # For this env the buffered next state also carries the value of
            # ``get_dist()`` (as float32). The returned next_state does not.
            dist = env.unwrapped.get_dist()
            stored_next_state = np.append(
                next_state, np.array(dist, dtype=np.float32))
        else:
            stored_next_state = next_state
        self.buffer.append(state, action, reward, done, log_pi,
                           stored_next_state)

        # Queried on every step, even when ``done``, so the env's boundary
        # check is always invoked.
        flag = self._transition_flag(env)

        if done or flag == -1:
            t = 0
            env.reset()
            # Restart from a sampled start state instead of the reset state;
            # the first five entries of the sample are passed to rollback.
            start_sample = self.start_buffer_exp.sample()
            next_state = env.unwrapped.rollback(
                *(start_sample[i].cpu().detach().numpy() for i in range(5)))
        return next_state, t

    def update(self, writer):
        """Run one PPO update on the full contents of the rollout buffer."""
        self.learning_steps += 1
        states, actions, rewards, dones, log_pis, next_states = \
            self.buffer.get()
        self.update_ppo(
            states, actions, rewards, dones, log_pis, next_states, writer)

    def update_ppo(self, states, actions, rewards, dones, log_pis, next_states,
                   writer):
        """Compute GAE from the current critic, then run ``epoch_ppo`` epochs.

        Values are computed once without gradients, so the targets and
        advantages stay fixed across all epochs of this update.
        """
        with torch.no_grad():
            values = self.critic(states)
            next_values = self.critic(next_states)

        targets, gaes = calculate_gae(
            values, rewards, dones, next_values, self.gamma, self.lambd)

        for _ in range(self.epoch_ppo):
            self.learning_steps_ppo += 1
            self.update_critic(states, targets, writer)
            self.update_actor(states, actions, log_pis, gaes, writer)

    def update_critic(self, states, targets, writer):
        """One gradient step on the mean squared error to ``targets``."""
        loss_critic = (self.critic(states) - targets).pow_(2).mean()

        self.optim_critic.zero_grad()
        loss_critic.backward(retain_graph=False)
        nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
        self.optim_critic.step()

        # learning_steps_ppo counts epochs, so this logs on the last epoch of
        # each update.
        if self.learning_steps_ppo % self.epoch_ppo == 0:
            writer.add_scalar(
                'loss/critic', loss_critic.item(), self.learning_steps)

    def update_actor(self, states, actions, log_pis_old, gaes, writer):
        """One gradient step on the clipped PPO surrogate minus entropy bonus."""
        log_pis = self.actor.evaluate_log_pi(states, actions)
        # Entropy estimated as the negative mean log-probability of the
        # buffered actions under the current policy.
        entropy = -log_pis.mean()

        ratios = (log_pis - log_pis_old).exp_()
        loss_actor1 = -ratios * gaes
        loss_actor2 = -torch.clamp(
            ratios,
            1.0 - self.clip_eps,
            1.0 + self.clip_eps
        ) * gaes
        # Pessimistic (max of negated) surrogate, as in clipped PPO.
        loss_actor = torch.max(loss_actor1, loss_actor2).mean()

        self.optim_actor.zero_grad()
        (loss_actor - self.coef_ent * entropy).backward(retain_graph=False)
        nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
        self.optim_actor.step()

        if self.learning_steps_ppo % self.epoch_ppo == 0:
            writer.add_scalar(
                'loss/actor', loss_actor.item(), self.learning_steps)
            writer.add_scalar(
                'stats/entropy', entropy.item(), self.learning_steps)

    def save_models(self, path):
        """Save actor weights and the actor/critic ``mean``/``std`` to path/model.pt.

        NOTE(review): ``actor.mean``/``actor.std`` and ``critic.mean``/
        ``critic.std`` are not assigned in this class; presumably they are
        set externally (e.g. observation-normalization stats) — verify.
        """
        fname = os.path.join(path, 'model.pt')
        os.makedirs(path, exist_ok=True)
        torch.save({
            'actor_mean': self.actor.mean,
            'actor_std': self.actor.std,
            'critic_mean': self.critic.mean,
            'critic_std': self.critic.std,
            'model': self.actor.state_dict(),
            }, fname)

class PPOExpert(PPO):
    """Inference-only wrapper that restores a saved PPO actor.

    ``PPO.__init__`` is not called. Only the actor network, its ``mean`` and
    ``std`` attributes, and ``device`` are set up; no critic, optimizers or
    rollout buffer are created here.
    """

    def __init__(self, state_shape, action_shape, device, path,
                 units_actor=(64, 64)):
        """Build the actor on ``device`` and load weights and stats from ``path``.

        Args:
            state_shape, action_shape: network input/output shapes.
            device: device the actor is moved to.
            path: checkpoint file containing the keys ``'model'`` (actor
                state dict), ``'actor_mean'`` and ``'actor_std'``, such as the
                file written by ``PPO.save_models``.
            units_actor: hidden layer sizes of the actor (Tanh activations).
        """
        policy = StateIndependentPolicy(
            state_shape=state_shape,
            action_shape=action_shape,
            hidden_units=units_actor,
            hidden_activation=nn.Tanh()
        )
        self.actor = policy.to(device)

        saved = torch.load(path)
        self.actor.load_state_dict(saved['model'])
        # Restore the stats stored alongside the weights.
        self.actor.mean = saved['actor_mean']
        self.actor.std = saved['actor_std']
        self.device = device

    def get_mean(self):
        """Return the ``mean`` restored from the checkpoint's ``'actor_mean'``."""
        actor = self.actor
        return actor.mean

    def get_std(self):
        """Return the ``std`` restored from the checkpoint's ``'actor_std'``."""
        actor = self.actor
        return actor.std
