import torch
import os
from datetime import datetime
import numpy as np
from mpi4py import MPI
from raDT.baselines.mpi_utils.mpi_utils import sync_networks, sync_grads
from raDT.baselines.rl_modules.base_agent import BaseAgent
from raDT.baselines.rl_modules.replay_buffer import replay_buffer
from raDT.baselines.rl_modules.models import actor, critic, VAE, LagrangianPIDController
from raDT.baselines.mpi_utils.normalizer import normalizer
from raDT.baselines.her_modules.her import her_sampler
import torch.nn.functional as F

"""
Actionable Model with HER (MPI-version)

"""
class ActionableModel(BaseAgent):
    """Actionable Models (AM) agent trained with HER, extended with a cost critic.

    Trains three networks, each paired with a target copy:

    * ``actor_network``: policy mapping the normalized (obs, goal) input to an
      action.
    * ``critic_network``: reward critic Q(s, a, g), fit to clipped TD targets
      plus the AM conservative penalty on sampled random actions.
    * ``cost_critic_network``: cost critic Qc(s, a, g), fit to clipped TD
      targets built from the per-transition ``costs``.

    The actor maximizes Q and is penalized linearly by Qc, scaled by
    ``self.mutip``. Initial weights are synchronized across MPI workers with
    ``sync_networks`` and gradients with ``sync_grads``.
    """

    # Number of uniformly sampled actions per state used by the AM penalty.
    AM_NUM_RANDOM_ACTIONS = 10

    def __init__(self, args, env, env_params):
        super().__init__(args, env, env_params)
        # create the networks
        self.actor_network = actor(env_params)
        self.critic_network = critic(env_params)
        self.cost_critic_network = critic(env_params)
        # sync the networks across the MPI workers
        sync_networks(self.actor_network)
        sync_networks(self.critic_network)
        sync_networks(self.cost_critic_network)
        # build up the target networks
        self.actor_target_network = actor(env_params)
        self.critic_target_network = critic(env_params)
        self.cost_critic_target_network = critic(env_params)
        # load the weights into the target networks
        self.actor_target_network.load_state_dict(self.actor_network.state_dict())
        self.critic_target_network.load_state_dict(self.critic_network.state_dict())
        self.cost_critic_target_network.load_state_dict(self.cost_critic_network.state_dict())

        # NOTE(review): the VAE and its optimizer are not used elsewhere in this
        # class and, unlike the other networks, are not passed to sync_networks.
        self.vae = VAE(env_params)
        # Threshold on Qc: the discounted sum over 50 steps of a constant
        # per-step cost of 5 / 50 -- presumably a budget of 5 per 50-step
        # episode; TODO confirm against the environment's horizon/cost limit.
        self.q_thres = 5 * (1 - self.args.gamma**50) / (1 - self.args.gamma) / 50
        # NOTE: the PID controller is constructed but not used by
        # _update_network; the actor penalty uses the fixed self.mutip instead.
        self.controller = LagrangianPIDController(10, 0.003, 0.001,
                                                  self.q_thres)
        # Fixed multiplier on the cost-critic penalty in the actor loss.
        self.mutip = 10

        # if use gpu
        if self.args.cuda:
            self.actor_network.cuda()
            self.critic_network.cuda()
            self.actor_target_network.cuda()
            self.critic_target_network.cuda()
            self.cost_critic_network.cuda()
            self.cost_critic_target_network.cuda()
            self.vae.cuda()
        # create the optimizers
        self.actor_optim = torch.optim.Adam(self.actor_network.parameters(), lr=self.args.lr_actor)
        self.critic_optim = torch.optim.Adam(self.critic_network.parameters(), lr=self.args.lr_critic)
        self.cost_critic_optim = torch.optim.Adam(self.cost_critic_network.parameters(), lr=self.args.lr_critic)
        self.vae_optim = torch.optim.Adam(self.vae.parameters(), lr=self.args.lr_critic)

    def _stochastic_actions(self, input_tensor):
        """Return an exploratory action as a numpy array.

        Adds Gaussian noise (scale ``noise_eps * action_max``) to the actor's
        output and clips it to ``[-action_max, action_max]``. With probability
        ``random_eps`` the result is replaced by a uniformly random action.

        Args:
            input_tensor: normalized network input; presumably a single
                (obs, goal) row -- TODO confirm batch size 1 at call sites.
        """
        pi = self.actor_network(input_tensor)
        # detach so .numpy() also works when called outside torch.no_grad()
        action = pi.detach().cpu().numpy().squeeze()
        # add the gaussian noise
        action += self.args.noise_eps * self.env_params['action_max'] * np.random.randn(*action.shape)
        action = np.clip(action, -self.env_params['action_max'], self.env_params['action_max'])
        # random actions...
        random_actions = np.random.uniform(low=-self.env_params['action_max'], high=self.env_params['action_max'],
                                           size=self.env_params['action'])
        # choose if use the random actions
        action += np.random.binomial(1, self.args.random_eps, 1)[0] * (random_actions - action)
        return action

    def _deterministic_action(self, input_tensor):
        """Return the actor's raw output tensor for ``input_tensor`` (no noise)."""
        action = self.actor_network(input_tensor)
        return action

    def _update_network(self, future_p=None):
        """Perform one update of the actor, the reward critic and the cost critic.

        Args:
            future_p: forwarded unchanged to ``self.sample_batch``.

        Returns:
            dict of detached scalar tensors keyed by ``'Train/...'`` tags: the
            total critic loss (TD + AM penalty), actor loss, AM penalty,
            cost-critic loss, and the batch means of Q and Qc.
        """
        # sample the episodes
        sample_batch = self.sample_batch(future_p=future_p)
        transitions = sample_batch['transitions']

        # NOTE: AM goal chaining (relabelling the later half of the batch with
        # shuffled goals and using Q(s, a, g) as their reward) is disabled.

        # normalize observations / goals and build the network inputs
        obs_norm = self.o_norm.normalize(transitions['obs'])
        g_norm = self.g_norm.normalize(transitions['g'])
        inputs_norm = np.concatenate([obs_norm, g_norm], axis=1)
        obs_next_norm = self.o_norm.normalize(transitions['obs_next'])
        g_next_norm = self.g_norm.normalize(transitions['g_next'])
        inputs_next_norm = np.concatenate([obs_next_norm, g_next_norm], axis=1)
        # transfer them into tensors
        inputs_norm_tensor = torch.tensor(inputs_norm, dtype=torch.float32)
        inputs_next_norm_tensor = torch.tensor(inputs_next_norm, dtype=torch.float32)
        actions_tensor = torch.tensor(transitions['actions'], dtype=torch.float32)
        r_tensor = torch.tensor(transitions['r'], dtype=torch.float32)
        c_tensor = torch.tensor(transitions['costs'], dtype=torch.float32)

        if self.args.cuda:
            inputs_norm_tensor = inputs_norm_tensor.cuda()
            inputs_next_norm_tensor = inputs_next_norm_tensor.cuda()
            actions_tensor = actions_tensor.cuda()
            r_tensor = r_tensor.cuda()
            c_tensor = c_tensor.cuda()

        device = inputs_norm_tensor.device
        batch_size = inputs_norm_tensor.shape[0]
        # 1 / (1 - gamma) bounds the discounted sum of per-step values of
        # magnitude <= 1; used to clip both critics' TD targets.
        clip_return = 1 / (1 - self.args.gamma)

        # reward-critic TD target from the target networks; the clamp to
        # [-clip_return, 0] implies rewards are expected to be <= 0
        with torch.no_grad():
            actions_next = self.actor_target_network(inputs_next_norm_tensor)
            q_next_value = self.critic_target_network(inputs_next_norm_tensor, actions_next)
            target_q_value = r_tensor + self.args.gamma * q_next_value
            target_q_value = torch.clamp(target_q_value, -clip_return, 0)
        # the q loss
        real_q_value = self.critic_network(inputs_norm_tensor, actions_tensor)
        critic_loss = (target_q_value - real_q_value).pow(2).mean()

        # AM conservative penalty: sample random actions from the same
        # [-action_max, action_max] box used for exploration, pick one per
        # state with probability proportional to exp(Q), and push its Q down.
        num_random = self.AM_NUM_RANDOM_ACTIONS
        action_scale = torch.as_tensor(self.env_params['action_max'],
                                       dtype=torch.float32, device=device)
        random_actions_tensor = torch.empty(
            batch_size * num_random, actions_tensor.shape[-1], device=device
        ).uniform_(-1, 1) * action_scale
        # row i * num_random + k pairs state i with its k-th random action
        inputs_norm_tensor_repeat = inputs_norm_tensor.repeat_interleave(num_random, dim=0)
        q_random_actions = self.critic_network(inputs_norm_tensor_repeat, random_actions_tensor)
        q_random_actions = q_random_actions.reshape(batch_size, -1)
        # the sampling itself is not differentiated (detached logits)
        sampled_random_actions = torch.distributions.Categorical(logits=q_random_actions.detach()).sample()
        critic_loss_AM = q_random_actions[torch.arange(batch_size, device=device), sampled_random_actions].mean()
        critic_loss = critic_loss + critic_loss_AM

        # cost-critic TD target; costs are accumulated into [0, clip_return]
        with torch.no_grad():
            qc_next_value = self.cost_critic_target_network(inputs_next_norm_tensor, actions_next)
            target_qc_value = c_tensor + self.args.gamma * qc_next_value
            target_qc_value = torch.clamp(target_qc_value, 0, clip_return)
        # qc loss
        real_qc_value = self.cost_critic_network(inputs_norm_tensor, actions_tensor)
        loss_cost_critic = (target_qc_value - real_qc_value).pow(2).mean()

        # the actor loss: maximize Q, with a fixed-multiplier penalty on Qc.
        # Subtracting q_thres shifts the loss value but not its gradient;
        # self.controller is not used here.
        actions_real = self.actor_network(inputs_norm_tensor)
        qc_val = self.cost_critic_network(inputs_norm_tensor, actions_real)
        qc_penalty = ((qc_val - self.q_thres) * self.mutip).mean()
        actor_loss = -self.critic_network(inputs_norm_tensor, actions_real).mean() + qc_penalty

        # NOTE: actor_loss.backward() also accumulates gradients into both
        # critics' parameters. They are discarded by the critics' zero_grad()
        # calls below, so each critic must zero its grads before its backward.
        self.actor_optim.zero_grad()
        actor_loss.backward()
        sync_grads(self.actor_network)
        self.actor_optim.step()

        # update the critic_network
        self.critic_optim.zero_grad()
        critic_loss.backward()
        sync_grads(self.critic_network)
        self.critic_optim.step()

        # update the cost_critic_network
        self.cost_critic_optim.zero_grad()
        loss_cost_critic.backward()
        sync_grads(self.cost_critic_network)
        self.cost_critic_optim.step()

        # detach so callers holding on to the logged values do not keep
        # autograd graph nodes alive
        results = {'Train/critic_loss': critic_loss.detach(),
                   'Train/actor_loss': actor_loss.detach(),
                   'Train/critic_loss_AM': critic_loss_AM.detach(),
                   'Train/loss_cost_critic': loss_cost_critic.detach(),
                   'Train/real_q_value': real_q_value.detach().mean(),
                   'Train/real_qc_value': real_qc_value.detach().mean()}

        return results