import numpy as np
import torch
from src.utils.utils import *
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
import time
import copy
class ActorCritic():
    """Actor-critic learner whose actor is updated with a natural policy gradient.

    Each call to :meth:`train` estimates GAE advantages and return targets with
    the current critic. Then, for every minibatch, it takes one Adam step on the
    critic's squared-error loss and one natural-gradient step on the policy.

    The natural-gradient direction ``s`` is obtained by conjugate gradient on
    damped Fisher-vector products. It is scaled so that the quadratic KL
    estimate ``0.5 * s^T (F + damping*I) s`` equals ``max_kl``, and the scaled
    step is then multiplied by ``lr`` before being applied.

    ``policy`` must expose ``is_disc_action``, ``get_log_prob(states, actions)``
    and ``get_kl(states)``; ``critic`` maps a batch of states to value estimates.
    """

    # Trust-region size used to scale each natural-gradient step. Overridable
    # via ``args.max_kl``; 0.01 matches the previously hard-coded value.
    max_kl = 0.01
    # Damping lambda in the Fisher-vector product (F + lambda*I) p. Overridable
    # via ``args.damping``; 0.1 matches the previously hard-coded value.
    damping = 0.1

    def __init__(self, args, policy, critic, device=torch.device("cpu")):
        self.device = device
        # dtype/device kwargs used to turn raw inputs into model-ready tensors.
        self.tpdv = dict(dtype=torch.float32, device=device)
        self.policy = policy.to(device)
        self.critic = critic.to(device)
        self.is_disc_action = self.policy.is_disc_action
        self.lr = args.lr  # multiplier applied to the KL-scaled policy step
        self.critic_lr = args.critic_lr
        self._use_value_active_masks = args.use_value_active_masks
        self._use_policy_active_masks = args.use_policy_active_masks
        self.opti_eps = args.opti_eps
        self.weight_decay = args.weight_decay
        self.gamma = args.gamma
        self.batch_size = args.batch_size
        # Optional overrides; when absent (or None) the class defaults apply.
        if getattr(args, "max_kl", None) is not None:
            self.max_kl = args.max_kl
        if getattr(args, "damping", None) is not None:
            self.damping = args.damping

        self.total_time = 0  # cumulative seconds spent in ac_npg_update

        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=self.critic_lr, eps=self.opti_eps,
                                                 weight_decay=self.weight_decay)

    def set_lr(self, lr, critic_lr=None):
        """Set the policy step multiplier and, if given, the critic's Adam lr."""
        self.lr = lr
        if critic_lr is not None:
            self.critic_lr = critic_lr
            self.critic_optimizer.param_groups[0]["lr"] = critic_lr

    def _mark_step_if_xla(self):
        """Call torch_xla's ``xm.mark_step()`` when running on an XLA device."""
        if "xla" in str(self.device):
            import torch_xla.core.xla_model as xm
            xm.mark_step()

    def ac_npg_update(self, states, actions, returns, advantages):
        """Run one shuffled pass of minibatch critic + NPG policy updates.

        Args:
            states, actions, returns, advantages: indexable along dim 0 with a
                list of sample indices. ``returns`` and ``advantages`` come from
                :meth:`estimate_advantages`.

        Returns:
            A tuple of:
            - the last minibatch's policy loss;
            - the last minibatch's value loss;
            - the cumulative ``self.total_time``;
            - the KL summed over minibatches (each term being the mean KL
              between the policy before and after that minibatch's step);
            - the mean log-prob of the sampled actions under the pre-step
              policies;
            - the mean of ``advantages``.

        Raises:
            ValueError: if ``states`` is empty.
        """
        start_time = time.time()
        n_samples = states.shape[0]
        if n_samples == 0:
            raise ValueError("ac_npg_update needs at least one sample")

        batch_size = self.batch_size if self.batch_size is not None and self.batch_size > 0 else n_samples
        # With drop_last=True a batch_size larger than the data would yield no
        # batches at all, so clamp it to the number of samples.
        batch_size = min(batch_size, n_samples)
        sampler = BatchSampler(
            SubsetRandomSampler(range(n_samples)),
            batch_size,
            drop_last=True)

        batch_log_probs = []
        kl = 0
        for indices in sampler:
            states_batch = states[indices]
            value_loss = self._critic_step(states_batch, returns[indices])
            policy_loss, action_log_probs, batch_kl = self._npg_policy_step(
                states_batch, actions[indices], advantages[indices])
            kl = kl + batch_kl
            batch_log_probs.append(action_log_probs)
            self._mark_step_if_xla()

        self.total_time += time.time() - start_time
        with torch.no_grad():
            return (policy_loss, value_loss, self.total_time, kl.mean(),
                    torch.cat(batch_log_probs).mean(), advantages.mean())

    def _critic_step(self, states_batch, returns_batch):
        """Take one Adam step on the critic's mean squared error; return the loss."""
        values = self.critic(states_batch)
        value_loss = (values - returns_batch).pow(2).mean()
        self.critic_optimizer.zero_grad()
        value_loss.backward()
        self.critic_optimizer.step()
        return value_loss.detach()

    def _npg_policy_step(self, states_batch, actions_batch, advantages_batch):
        """Apply one KL-scaled natural-gradient step to the policy.

        Returns ``(policy_loss, action_log_probs, kl)``. All three are detached;
        the loss and log-probs are evaluated before the step.
        """
        action_log_probs = self.policy.get_log_prob(states_batch, actions_batch)
        policy_loss = -(action_log_probs * advantages_batch).mean()
        loss_grad = get_flat_grad_from(torch.autograd.grad(policy_loss, self.policy.parameters()))

        # Natural-gradient direction: solve (F + damping*I) s = grad.
        step_dir = self.conjugate_gradient(self.policy, states_batch, loss_grad.detach())

        params = get_flat_params_from(self.policy)
        shs = 0.5 * (step_dir * self.fisher_vector_product(self.policy, states_batch, step_dir)).sum()
        # Scale so that 0.5 * s^T (F + damping*I) s == max_kl. A zero direction
        # (shs == 0) gives a zero step instead of inf * 0 = NaN.
        step_size = torch.where(shs > 0, torch.sqrt(self.max_kl / shs), torch.zeros_like(shs))
        full_step = step_size * step_dir

        old_policy = copy.deepcopy(self.policy).to(self.device)
        # Descend on policy_loss (= ascend on the surrogate objective).
        set_flat_params_to(self.policy, params - self.lr * full_step)

        # Diagnostic only: no autograd graph needs to outlive this step.
        with torch.no_grad():
            kl = self.kl_divergence(self.policy, old_policy, states_batch)
        return policy_loss.detach(), action_log_probs.detach(), kl

    def train(self, states, actions, rewards, masks):
        """Estimate advantages with the current critic and run one update pass.

        Returns a dict with ``policy_loss``, ``value_loss`` (floats, last
        minibatch), ``total_time`` (cumulative seconds), ``kl`` (tensor),
        ``episode_length`` (``states.shape[0]``), ``log_probs`` and
        ``advantages`` (floats, means).
        """
        with torch.no_grad():
            # Only used to build detached targets, so no graph is needed.
            values = self.critic(torch.as_tensor(states, **self.tpdv))
        advantages, returns = self.estimate_advantages(rewards, masks, values, tau=0.95, device=self.device)

        policy_loss, value_loss, total_time, kl, log_probs, adv_mean = self.ac_npg_update(
            states, actions, returns, advantages)
        self._mark_step_if_xla()

        return {
            'policy_loss': policy_loss.item(),
            'value_loss': value_loss.item(),
            'total_time': total_time,
            'kl': kl,
            'episode_length': states.shape[0],
            'log_probs': log_probs.item(),
            'advantages': adv_mean.item(),
        }

    def conjugate_gradient(self, actor, states, b, nsteps=10, residual_tol=1e-10):
        """Approximately solve ``A x = b`` with A applied via ``fisher_vector_product``.

        Runs at most ``nsteps`` CG iterations and stops early once the squared
        residual norm drops below ``residual_tol``. ``b`` is a flat 1-D tensor.
        A zero ``b`` returns zeros without iterating (avoids 0/0 = NaN).
        """
        x = torch.zeros_like(b)
        r = b.clone()
        p = b.clone()
        rdotr = torch.dot(r, r)
        for _ in range(nsteps):
            if rdotr < residual_tol:
                break
            Avp = self.fisher_vector_product(actor, states, p)
            alpha = rdotr / torch.dot(p, Avp)
            x += alpha * p
            r -= alpha * Avp
            new_rdotr = torch.dot(r, r)
            p = r + (new_rdotr / rdotr) * p
            rdotr = new_rdotr
        return x

    def fisher_vector_product(self, actor, states, p):
        """Return ``(F + damping * I) @ p`` for a flat parameter-space vector ``p``.

        F is the Hessian of ``actor.get_kl(states).mean()`` with respect to the
        actor's parameters, obtained by double backprop.
        """
        p = p.detach()  # keep p's autograd history out of the product
        kl = actor.get_kl(states).mean()
        kl_grad = torch.autograd.grad(kl, actor.parameters(), create_graph=True)
        kl_grad = get_flat_grad_from(kl_grad)

        kl_grad_p = (kl_grad * p).sum()
        kl_hessian_p = torch.autograd.grad(kl_grad_p, actor.parameters(), retain_graph=True)
        kl_hessian_p = flat_hessian(kl_hessian_p)
        return kl_hessian_p + self.damping * p

    def kl_divergence(self, new_actor, old_actor, states):
        """KL(old || new) between two actors' action distributions on ``states``.

        Discrete actions: the actors return probabilities over dim 1, and the
        result is the mean KL (a scalar).

        Continuous actions: the actors return ``(mu, std, logstd)`` of a
        diagonal Gaussian, and the result is the per-sample KL with shape
        ``(batch, 1)``. Inputs are converted to float32 on ``self.device``.

        Gradients flow only through ``new_actor``; the old outputs are detached.
        """
        if self.is_disc_action:
            p = new_actor(states)
            q = old_actor(states).detach()
            kl = (q * (q.log() - p.log())).sum(1, keepdim=True)
            return kl.mean()

        states = torch.as_tensor(states, **self.tpdv)
        mu, std, logstd = new_actor(states)
        mu_old, std_old, logstd_old = old_actor(states)
        mu_old = mu_old.detach()
        std_old = std_old.detach()
        logstd_old = logstd_old.detach()

        kl = logstd_old - logstd + (std_old.pow(2) + (mu_old - mu).pow(2)) / \
            (2.0 * std.pow(2)) - 0.5
        return kl.sum(1, keepdim=True)

    def estimate_advantages(self, rewards, mask, values, tau, device):
        """Compute GAE(gamma, tau) advantages and return targets.

        Args:
            rewards: per-step rewards, indexed along dim 0.
            mask: per-step multipliers on the bootstrap term. A 0 at step i cuts
                the recursion there (presumably marking an episode end; verify
                against the caller).
            values: critic estimates indexed as ``values[i, 0]``, i.e. shape
                ``(T, 1)``.
            tau: GAE lambda.
            device: device for the returned tensors.

        Returns:
            ``(advantages, returns)``, both detached. ``returns = values +
            raw_advantages``. The advantages are standardised to zero mean and
            unit std; when the std is zero or undefined (a single step) they are
            only centred, which avoids NaN.
        """
        rewards, masks, values = self.to_device(torch.device('cpu'), rewards, mask, values)
        n_steps = rewards.size(0)

        with torch.no_grad():
            advantages = torch.empty(n_steps, 1)
            prev_value = 0
            prev_advantage = 0
            for i in reversed(range(n_steps)):
                delta = rewards[i] + self.gamma * prev_value * masks[i] - values[i]
                advantages[i] = delta + self.gamma * tau * prev_advantage * masks[i]
                prev_value = values[i, 0]
                prev_advantage = advantages[i, 0]

            returns = (values + advantages).detach()

            mean = advantages.mean()
            std = advantages.std() if n_steps > 1 else torch.zeros(())
            advantages = advantages - mean
            if std > 0:
                advantages = advantages / std
            advantages = advantages.detach()

        advantages, returns = self.to_device(device, advantages, returns)
        return advantages, returns

    def update_model(self, model, new_params):
        """Copy consecutive slices of the flat vector ``new_params`` into ``model``'s parameters."""
        index = 0
        for param in model.parameters():
            length = param.numel()
            param.data.copy_(new_params[index: index + length].view(param.size()))
            index += length

    def to_device(self, device, *args):
        """Return a list with each tensor in ``args`` moved to ``device``."""
        return [x.to(device) for x in args]
