import numpy as np
import torch
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
import time
class ActorCritic():
    """Advantage actor-critic trainer.

    The policy is optimised with plain SGD on ``-(log_prob * advantage)`` and
    the critic with Adam on the mean squared error to the returns. Advantages
    are estimated with GAE(gamma, tau) and normalised before use.
    """

    def __init__(self, args, policy, critic, device=torch.device("cpu")):
        """Move both networks to ``device`` and build their optimisers.

        Args:
            args: config object; reads ``lr``, ``critic_lr``,
                ``use_value_active_masks``, ``use_policy_active_masks``,
                ``opti_eps``, ``weight_decay``, ``gamma`` and ``batch_size``
                (``None`` means one mini-batch covering the whole rollout).
            policy: module exposing ``get_log_prob(states, actions)``.
            critic: module mapping states to value estimates; expected to
                return shape (N, 1) since ``estimate_advantages`` indexes
                ``values[i, 0]``.
            device: device the networks and training tensors live on.
        """
        self.device = device
        self.tpdv = dict(dtype=torch.float32, device=device)
        self.policy = policy.to(device)
        self.critic = critic.to(device)
        self.lr = args.lr
        self.critic_lr = args.critic_lr
        # Stored but not read by any method of this class (presumably kept for
        # callers/subclasses -- confirm before removing).
        self._use_value_active_masks = args.use_value_active_masks
        self._use_policy_active_masks = args.use_policy_active_masks
        self.opti_eps = args.opti_eps
        self.weight_decay = args.weight_decay
        self.gamma = args.gamma
        self.batch_size = args.batch_size

        # Accumulated wall-clock seconds spent inside ac_update.
        self.total_time = 0

        self.actor_optimizer = torch.optim.SGD(self.policy.parameters(), lr=self.lr)

        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(),
            lr=self.critic_lr,
            eps=self.opti_eps,
            weight_decay=self.weight_decay,
        )

    def set_lr(self, lr, critic_lr=None):
        """Set the actor learning rate and, if given, the critic learning rate.

        Only the first param group of each optimiser is updated.
        """
        self.lr = lr
        self.actor_optimizer.param_groups[0]["lr"] = lr
        if critic_lr is not None:
            self.critic_lr = critic_lr
            self.critic_optimizer.param_groups[0]["lr"] = critic_lr

    def ac_update(self, states, actions, returns, advantages):
        """Run one shuffled pass of mini-batch critic and actor updates.

        Each mini-batch first takes an Adam step on the critic's squared
        error, then an SGD step on the policy-gradient loss. Samples that do
        not fill a complete final mini-batch are skipped (``drop_last=True``).

        Args:
            states, actions, returns, advantages: per-step tensors indexed
                along dim 0.

        Returns:
            Tuple of detached tensors: policy loss and value loss of the LAST
            mini-batch, mean log-prob over all processed samples, and the
            mean of ``advantages``.

        Raises:
            ValueError: if ``states`` is empty.
        """
        start_time = time.time()

        num_samples = states.shape[0]
        if num_samples == 0:
            raise ValueError("ac_update needs at least one sample")
        # With drop_last=True a batch size larger than the rollout would yield
        # no mini-batches at all (leaving the losses undefined), so clamp it.
        if self.batch_size is None:
            batch_size = num_samples
        else:
            batch_size = min(self.batch_size, num_samples)
        sampler = BatchSampler(
            SubsetRandomSampler(range(num_samples)),
            batch_size,
            drop_last=True)

        batch_log_probs = []
        for indices in sampler:
            states_batch = states[indices]
            actions_batch = actions[indices]
            advantages_batch = advantages[indices]
            returns_batch = returns[indices]

            values = self.critic(states_batch)
            value_loss = (values - returns_batch).pow(2).mean()
            self.critic_optimizer.zero_grad()
            value_loss.backward()
            self.critic_optimizer.step()

            action_log_probs = self.policy.get_log_prob(states_batch, actions_batch)
            policy_loss = -(action_log_probs * advantages_batch).mean()
            self.actor_optimizer.zero_grad()
            policy_loss.backward()
            self.actor_optimizer.step()

            # Detach so the stored tensors do not keep autograd graphs alive.
            batch_log_probs.append(action_log_probs.detach())

        self.total_time += time.time() - start_time

        if "xla" in str(self.device):
            import torch_xla.core.xla_model as xm
            xm.mark_step()

        return (
            policy_loss.detach(),
            value_loss.detach(),
            torch.cat(batch_log_probs).mean(),
            advantages.detach().mean(),
        )

    def train(self, states, actions, rewards, masks, tau=0.95):
        """Estimate advantages for a rollout and run one ``ac_update`` pass.

        Args:
            states: per-step states; converted to float32 on ``self.device``.
            actions: per-step actions, passed through unchanged.
            rewards, masks: per-step tensors consumed by
                ``estimate_advantages``.
            tau: GAE smoothing factor (previously fixed at 0.95).

        Returns:
            dict with ``policy_loss``, ``value_loss``, ``total_time``,
            ``episode_length``, ``log_probs`` and ``advantages``.
        """
        # One conversion serves both the critic forward and ac_update, and
        # keeps the input on the same device/dtype as the networks.
        states = torch.as_tensor(states, **self.tpdv)
        with torch.no_grad():
            # Values only serve as detached targets; no graph is needed.
            values = self.critic(states)
        advantages, returns = self.estimate_advantages(
            rewards, masks, values, tau=tau, device=self.device)

        policy_loss, value_loss, log_probs, mean_advantage = self.ac_update(
            states, actions, returns, advantages)

        if "xla" in str(self.device):
            import torch_xla.core.xla_model as xm
            xm.mark_step()

        return {
            'policy_loss': policy_loss.item(),
            'value_loss': value_loss.item(),
            'total_time': self.total_time,
            'episode_length': states.shape[0],
            'log_probs': log_probs.item(),
            'advantages': mean_advantage.item(),
        }

    def estimate_advantages(self, rewards, mask, values, tau, device, eps=1e-8):
        """Compute GAE advantages (normalised) and returns on the CPU.

        Walks the rollout backwards with
        ``delta_i = r_i + gamma * V_{i+1} * m_i - V_i`` and
        ``A_i = delta_i + gamma * tau * A_{i+1} * m_i``, where ``m_i`` is the
        mask (0 cuts bootstrapping across an episode boundary).

        Args:
            rewards, mask: per-step tensors indexed along dim 0.
            values: value estimates of shape (N, 1); detached internally.
            tau: GAE smoothing factor.
            device: device the outputs are moved to.
            eps: added to the advantage std so constant or single-step
                rollouts do not divide by zero / produce NaN.

        Returns:
            ``(advantages, returns)``: detached (N, 1) tensors on ``device``.
            ``returns`` is built from the un-normalised advantages.
        """
        # Outputs are detached anyway; detaching first avoids building an
        # autograd graph through the Python loop below.
        rewards, masks, values = self.to_device(
            torch.device('cpu'), rewards, mask, values.detach())
        num_steps = rewards.size(0)
        advantages = torch.zeros(num_steps, 1)

        prev_value = 0
        prev_advantage = 0
        for i in reversed(range(num_steps)):
            delta = rewards[i] + self.gamma * prev_value * masks[i] - values[i]
            advantages[i] = delta + self.gamma * tau * prev_advantage * masks[i]

            prev_value = values[i, 0]
            prev_advantage = advantages[i, 0]

        returns = values + advantages
        # The unbiased std of a single element is NaN; treat it as 0 instead.
        std = advantages.std() if num_steps > 1 else advantages.new_zeros(())
        advantages = (advantages - advantages.mean()) / (std + eps)

        advantages, returns = self.to_device(device, advantages, returns)

        return advantages, returns

    def to_device(self, device, *args):
        """Return a list with each tensor in ``args`` moved to ``device``."""
        return [x.to(device) for x in args]