import time
import torch
import torch.nn as nn
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler


class ActorCritic:
    """Advantage actor-critic trainer with GAE advantages and minibatch updates.

    Holds a policy (actor) and a value network (critic) together with their
    optimizers. ``train`` estimates advantages/returns for a trajectory and then
    performs one pass of minibatch updates over it via ``ac_update``.
    """

    def __init__(self, args, policy, critic, device=torch.device("cpu")):
        """Move the networks to ``device`` and build their optimizers.

        Args:
            args: namespace providing ``lr``, ``critic_lr``, ``use_value_active_masks``,
                ``use_policy_active_masks``, ``opti_eps``, ``weight_decay``, ``gamma``,
                ``batch_size``, ``entropy_coeff``, ``max_grad_norm`` and ``optimizer``
                (``'adam'`` or ``'sgd'``).
            policy: actor module; must implement
                ``get_log_prob(states, actions, return_entropy=True)``.
            critic: value module mapping a batch of states to values.
            device: device both networks (and computed targets) live on.

        Raises:
            ValueError: if ``args.optimizer`` is neither ``'adam'`` nor ``'sgd'``.
        """
        self.device = device
        self.tpdv = dict(dtype=torch.float32, device=device)
        self.policy = policy.to(device)
        self.critic = critic.to(device)
        self.lr = args.lr
        self.critic_lr = args.critic_lr
        # NOTE(review): the two active-mask flags are stored but not used in this class.
        self._use_value_active_masks = args.use_value_active_masks
        self._use_policy_active_masks = args.use_policy_active_masks
        self.opti_eps = args.opti_eps
        self.weight_decay = args.weight_decay
        self.gamma = args.gamma
        self.batch_size = args.batch_size
        self.entropy_coeff = args.entropy_coeff
        self.max_grad_norm = args.max_grad_norm

        # Accumulated wall-clock seconds spent inside ac_update.
        self.total_time = 0

        if args.optimizer == 'adam':
            self.actor_optimizer = torch.optim.Adam(self.policy.parameters(),
                                                    lr=self.lr, eps=self.opti_eps,
                                                    weight_decay=self.weight_decay)
        elif args.optimizer == 'sgd':
            # NOTE: weight_decay / opti_eps are intentionally not passed to SGD here.
            self.actor_optimizer = torch.optim.SGD(self.policy.parameters(),
                                                   lr=self.lr)
        else:
            raise ValueError(f'Unknown optimizer {args.optimizer}')

        # The critic always uses Adam, regardless of args.optimizer.
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=self.critic_lr, eps=self.opti_eps,
                                                 weight_decay=self.weight_decay)

    def set_lr(self, lr, critic_lr=None):
        """Set the actor learning rate and, if given, the critic learning rate.

        Every parameter group of the respective optimizer is updated.
        """
        self.lr = lr
        for group in self.actor_optimizer.param_groups:
            group["lr"] = lr
        if critic_lr is not None:
            self.critic_lr = critic_lr
            for group in self.critic_optimizer.param_groups:
                group["lr"] = critic_lr

    def _xla_mark_step(self):
        """Call torch_xla's ``xm.mark_step()`` when ``self.device`` is an XLA device; no-op otherwise."""
        if "xla" in str(self.device):
            import torch_xla.core.xla_model as xm
            xm.mark_step()

    def _apply_gradients(self, loss, optimizer, module):
        """Backpropagate ``loss`` and take one ``optimizer`` step on ``module``.

        Gradients of ``module`` are norm-clipped to ``self.max_grad_norm``
        unless it is ``None`` or negative.
        """
        optimizer.zero_grad()
        loss.backward()
        if self.max_grad_norm is not None and self.max_grad_norm >= 0:
            nn.utils.clip_grad_norm_(module.parameters(), self.max_grad_norm)
        optimizer.step()

    def ac_update(self, states, actions, returns, advantages):
        """Run one shuffled pass of minibatch critic and actor updates.

        For each minibatch the critic is fit to ``returns`` with an MSE loss, then
        the actor is updated with a policy-gradient loss weighted by
        ``advantages`` minus an entropy bonus scaled by ``entropy_coeff``.
        Minibatches have ``self.batch_size`` samples (the whole set if it is
        ``None``/non-positive); a trailing partial minibatch is dropped.

        Args:
            states, actions, returns, advantages: tensors indexed along dim 0 by
                sample; ``states``/``returns``/``advantages`` are expected on
                ``self.device``.

        Returns:
            Tuple of detached tensors: (mean policy loss, mean value loss,
            mean action log-prob over all minibatches, mean of ``advantages``).
        """
        start_time = time.perf_counter()

        n_samples = states.shape[0]
        batch_size = self.batch_size if self.batch_size is not None and self.batch_size > 0 else n_samples
        batch_size = min(batch_size, n_samples)
        sampler = BatchSampler(
            SubsetRandomSampler(range(n_samples)),
            batch_size,
            drop_last=True)

        total_policy_loss = 0.0
        total_value_loss = 0.0
        batch_log_probs = []
        num_updates = 0
        for indices in sampler:
            states_batch = states[indices]
            actions_batch = actions[indices]
            advantages_batch = advantages[indices]
            returns_batch = returns[indices]

            values = self.critic(states_batch)
            value_loss = (values - returns_batch).pow(2).mean()
            self._apply_gradients(value_loss, self.critic_optimizer, self.critic)

            action_log_probs, dist_entropy = self.policy.get_log_prob(
                states_batch, actions_batch, return_entropy=True)
            # NOTE(review): assumes action_log_probs has the same shape as
            # advantages_batch (e.g. both (B, 1)); a (B,) vs (B, 1) mismatch
            # would silently broadcast to (B, B) — confirm with the policy.
            policy_loss = -(action_log_probs * advantages_batch).mean() - dist_entropy * self.entropy_coeff
            self._apply_gradients(policy_loss, self.actor_optimizer, self.policy)

            # Detach so per-minibatch autograd graphs are freed immediately
            # instead of being kept alive by the running totals.
            total_policy_loss += policy_loss.detach()
            total_value_loss += value_loss.detach()
            batch_log_probs.append(action_log_probs.detach().reshape(-1))
            num_updates += 1

        self.total_time += time.perf_counter() - start_time

        self._xla_mark_step()

        with torch.no_grad():
            return (total_policy_loss / num_updates, total_value_loss / num_updates,
                    torch.cat(batch_log_probs).mean(), advantages.mean())

    def train(self, states, actions, rewards, masks, tau=0.95):
        """Estimate advantages for one trajectory and update actor and critic on it.

        Args:
            states: per-step states (tensor or array-like); converted to a float32
                tensor on ``self.device``.
            actions: per-step actions, indexed along dim 0.
            rewards, masks: per-step rewards and continuation masks (tensors),
                forwarded to ``estimate_advantages``.
            tau: GAE lambda.

        Returns:
            dict with float ``policy_loss``, ``value_loss``, ``log_probs`` and
            ``advantages`` statistics, plus ``total_time`` and ``episode_length``.
        """
        states = torch.as_tensor(states, **self.tpdv)
        # Values here only serve as fixed targets/baselines; no graph needed.
        with torch.no_grad():
            values = self.critic(states)
        advantages, returns = self.estimate_advantages(rewards, masks, values, tau=tau, device=self.device)

        policy_loss, value_loss, log_probs, mean_advantage = self.ac_update(states, actions, returns, advantages)

        self._xla_mark_step()
        return {
            'policy_loss': policy_loss.item(),
            'value_loss': value_loss.item(),
            'total_time': self.total_time,
            'episode_length': states.shape[0],
            'log_probs': log_probs.item(),
            'advantages': mean_advantage.item(),
        }

    def estimate_advantages(self, rewards, mask, values, tau, device):
        """Compute GAE advantages and returns for one trajectory on the CPU.

        Iterating backwards in time:
            delta_t = r_t + gamma * V_{t+1} * m_t - V_t
            A_t     = delta_t + gamma * tau * A_{t+1} * m_t
        with ``V``/``A`` beyond the last step taken as 0. Returns are
        ``values + A`` (before normalization); advantages are then standardized
        to zero mean and unit (unbiased) std.

        Args:
            rewards, mask: per-step tensors indexed along dim 0.
            values: critic values; indexed as ``values[i, 0]``, so expected (T, 1).
            tau: GAE lambda.
            device: device the results are moved to.

        Returns:
            Tuple ``(advantages, returns)``, both (T, 1), detached, on ``device``.
        """
        with torch.no_grad():
            rewards, masks, values = self.to_device(torch.device('cpu'), rewards, mask, values)
            n_steps = rewards.size(0)
            advantages = torch.zeros(n_steps, 1)

            prev_value = 0
            prev_advantage = 0
            for i in reversed(range(n_steps)):
                delta = rewards[i] + self.gamma * prev_value * masks[i] - values[i]
                advantages[i] = delta + self.gamma * tau * prev_advantage * masks[i]

                prev_value = values[i, 0]
                prev_advantage = advantages[i, 0]

            returns = values + advantages
            # torch's std is unbiased (N-1), i.e. NaN for a single step; use 0 then
            # so the normalized advantage is 0 instead of NaN.
            std = advantages.std() if n_steps > 1 else advantages.new_zeros(())
            advantages = (advantages - advantages.mean()) / (std + 1e-8)

        advantages, returns = self.to_device(device, advantages, returns)
        return advantages, returns

    def to_device(self, device, *args):
        """Return a list with each tensor in ``args`` moved to ``device``."""
        return [x.to(device) for x in args]