import numpy as np
import torch
import time
from src.utils.utils import *
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
from src.algorithms.ac.acPolicy import Actor, DiscreteActor
import copy


class ActorCritic:
    def __init__(self, args, policy, critic, device=torch.device("cpu")):
        """Set up the actor, the critic, the critic optimizer and the NPG state.

        Args:
            args: Configuration object. The attributes read here are ``lr``,
                ``critic_lr``, ``use_value_active_masks``,
                ``use_policy_active_masks``, ``opti_eps``, ``weight_decay``,
                ``gamma``, ``online_update``, ``ief``, ``batch_size``,
                ``damping_factor``, ``matrix_free`` and ``entropy_coeff``.
            policy: Actor module; must expose ``is_disc_action``.
            critic: Value module, optimized with Adam at ``critic_lr``.
            device: Device both modules and the Fisher buffers are placed on.
        """
        self.device = device
        self.tpdv = dict(dtype=torch.float32, device=device)
        self.policy = policy.to(device)
        self.critic = critic.to(device)
        self.is_disc_action = self.policy.is_disc_action

        # Hyper-parameters copied verbatim from the config object.
        self.lr = args.lr
        self.critic_lr = args.critic_lr
        self._use_value_active_masks = args.use_value_active_masks
        self._use_policy_active_masks = args.use_policy_active_masks
        self.opti_eps = args.opti_eps
        self.weight_decay = args.weight_decay
        self.gamma = args.gamma
        self.flag = False
        self.online_update = args.online_update
        self.ief = args.ief
        # ``online`` mirrors ``online_update``; both names are read elsewhere
        # in the class, so both are kept.
        self.online = args.online_update
        self.batch_size = args.batch_size
        self.damping_factor = args.damping_factor
        self.matrix_free = args.matrix_free
        self.entropy_coeff = args.entropy_coeff

        # An explicit (P x P) matrix is only needed when it is updated online
        # or when the matrix-free Sherman-Morrison path is disabled.
        if self.online or not self.matrix_free:
            num_params = sum(p.numel() for p in self.policy.parameters())
            if self.online:
                self.diagonal_fim = torch.zeros(num_params, device=device)
                self.beta = 0.5
            # Allocate directly on the target device instead of building the
            # matrix on the CPU and copying it over.
            self.fim = self.damping_factor * torch.eye(num_params, device=device)

        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=self.critic_lr)

        # State consumed by ``adam_update``.
        self.m = None  # First moment vector for Adam
        self.v = None  # Second moment vector for Adam
        self.beta1 = 0.7  # args.beta1
        self.beta2 = 0.7  # args.beta2
        self.epsilon = 1e-8  # args.epsilon
        self.t = 0  # Time step
        self.total_time = 0

    def set_lr(self, lr, critic_lr=None):
        """Change the policy step size and, optionally, the critic learning rate.

        Args:
            lr: New policy step size, stored on ``self.lr``.
            critic_lr: When given, stored on ``self.critic_lr`` and written to
                the first parameter group of the critic optimizer. ``None``
                leaves the critic untouched.
        """
        self.lr = lr
        if critic_lr is None:
            return
        self.critic_lr = critic_lr
        first_group = self.critic_optimizer.param_groups[0]
        first_group["lr"] = critic_lr

    def ac_npg_update1(self, states, actions, returns, advantages):
        """Run one shuffled pass of critic regression and natural-gradient policy steps.

        For every mini-batch the critic takes one Adam step on the squared
        error against ``returns``; then the policy parameters are moved by
        ``-self.lr * step_dir``, where ``step_dir`` is the policy-loss gradient
        preconditioned by one of the Sherman-Morrison helpers.

        Args:
            states: Indexable by a list of ints along dim 0; ``states.shape[0]``
                is the number of samples.
            actions: Actions aligned with ``states``.
            returns: Critic regression targets aligned with ``states``.
            advantages: Advantage weights aligned with ``states``.

        Returns:
            Tuple ``(mean policy loss, mean value loss, kl.mean(),
            mean of all batch log-probs, advantages.mean())``. All entries are
            tensors without autograd history.
        """
        num_samples = states.shape[0]
        if self.batch_size is not None and self.batch_size > 0:
            batch_size = min(self.batch_size, num_samples)
        else:
            batch_size = num_samples
        # drop_last=True: a trailing partial batch is skipped, so every update
        # sees exactly ``batch_size`` samples.
        sampler = BatchSampler(
            SubsetRandomSampler(range(num_samples)),
            batch_size,
            drop_last=True)

        use_explicit_fim = self.online or not self.matrix_free

        total_policy_loss = 0.0
        total_value_loss = 0.0
        batch_log_probs = []
        kl = 0
        num_updates = 0
        for indices in sampler:
            num_updates += 1
            states_batch = states[indices]
            actions_batch = actions[indices]
            advantages_batch = advantages[indices]
            returns_batch = returns[indices]

            # Critic: one Adam step on the mean squared error.
            values = self.critic(states_batch)
            value_loss = (values - returns_batch).pow(2).mean()
            self.critic_optimizer.zero_grad()
            value_loss.backward()
            self.critic_optimizer.step()

            # Policy: both preconditioning paths need the same two gradients,
            # so they are computed once here.
            action_log_probs, dist_entropy = self.policy.get_log_prob(states_batch, actions_batch, return_entropy=True)
            policy_loss = -(action_log_probs * advantages_batch).mean()
            # retain_graph=True because the same forward graph is differentiated
            # a second time for the log-probability gradient below.
            loss_grad = torch.autograd.grad(policy_loss, self.policy.parameters(), retain_graph=True)
            loss_grad = get_flat_grad_from(loss_grad)
            log_prob_grads = torch.autograd.grad(action_log_probs.mean(), self.policy.parameters())
            log_prob_grads = get_flat_grad_from(log_prob_grads)

            if use_explicit_fim:
                if self.ief:
                    # Divide by the square root of the gradient's L2 norm
                    # (two successive square roots of the squared sum).
                    sq_sum = log_prob_grads.pow(2).sum().item()
                    sief = sq_sum ** 0.5
                    sief = sief ** 0.5
                    log_prob_grads = log_prob_grads / sief

                updated_inverse = self.sherman_morrison_update1(self.fim, log_prob_grads, log_prob_grads)
                if self.online_update:
                    # Online mode carries the updated matrix to the next batch;
                    # otherwise ``self.fim`` stays at its initial value.
                    self.fim = updated_inverse
                step_dir = torch.matmul(updated_inverse, loss_grad)
            else:
                step_dir = self.sherman_morrison_update2(loss_grad, log_prob_grads, lambdaa=self.damping_factor)

            # Nothing below is differentiated, so skip autograd bookkeeping.
            with torch.no_grad():
                step_dir -= dist_entropy * self.entropy_coeff

                params = get_flat_params_from(self.policy)
                # Snapshot of the pre-update policy, used only for the KL.
                old_policy = copy.deepcopy(self.policy).to(self.device)

                new_params = params - self.lr * step_dir
                set_flat_params_to(self.policy, new_params)

                kl += self.kl_divergence(self.policy, old_policy, states_batch)

            # Detach the statistics so the per-batch graphs can be freed
            # instead of living until the end of the pass.
            total_policy_loss += policy_loss.detach()
            total_value_loss += value_loss.detach()
            batch_log_probs.append(action_log_probs.detach())

            if "xla" in str(self.device):
                import torch_xla.core.xla_model as xm
                xm.mark_step()

        with torch.no_grad():
            return (total_policy_loss / num_updates, total_value_loss / num_updates,
                    kl.mean(), torch.cat(batch_log_probs).mean(), advantages.mean())

    def adam_update(self, policy, natural_gradient, learning_rate):
        """Move ``policy`` along ``natural_gradient`` with an added Adam-style term.

        The new flat parameters are
        ``params + learning_rate * natural_gradient + adaptive_lr * m_hat``,
        where ``m_hat`` and ``adaptive_lr`` come from bias-corrected running
        moments of ``natural_gradient`` stored on ``self.m`` / ``self.v``.

        Args:
            policy: Module whose parameters are read and overwritten in place.
            natural_gradient: Flat update direction (added, not subtracted).
            learning_rate: Base step size.
        """
        # Increment the time step
        self.t += 1
        # Identity test: ``== None`` on a tensor goes through tensor comparison
        # machinery and is not a reliable "is unset" check.
        if self.m is None or self.v is None:
            # NOTE(review): moments are seeded with the first gradient rather
            # than zeros, so the bias correction below over-scales early
            # steps compared with textbook Adam - confirm this is intended.
            self.m = natural_gradient
            self.v = natural_gradient ** 2
        # Update the first moment vector (m) and the second moment vector (v)
        self.m = self.beta1 * self.m + (1 - self.beta1) * natural_gradient
        self.v = self.beta2 * self.v + (1 - self.beta2) * natural_gradient ** 2
        # Compute bias-corrected first moment estimate
        m_hat = self.m / (1 - self.beta1 ** self.t)

        # Compute bias-corrected second moment estimate
        v_hat = self.v / (1 - self.beta2 ** self.t)

        # Compute adaptive learning rate
        adaptive_learning_rate = learning_rate / (torch.sqrt(v_hat) + self.epsilon)
        params = get_flat_params_from(policy)
        new_params = params + learning_rate * natural_gradient + adaptive_learning_rate * m_hat
        set_flat_params_to(policy, new_params)

    def sherman_morrison_update2(self, loss_grad, u, lambdaa=1.0):
        """Apply ``(lambdaa * I + u u^T)^-1`` to ``loss_grad`` without forming a matrix.

        Uses the closed form
        ``g / lambdaa - u * (u . g) / (lambdaa**2 + lambdaa * (u . u))``.

        Args:
            loss_grad: 1-D gradient vector ``g``.
            u: 1-D vector defining the rank-one term.
            lambdaa: Damping coefficient on the identity.

        Returns:
            1-D tensor with the preconditioned gradient.
        """
        u_dot_g = torch.inner(u, loss_grad)
        denominator = lambdaa ** 2 + lambdaa * torch.inner(u, u)
        rank_one_correction = (u * u_dot_g) / denominator
        return loss_grad / lambdaa - rank_one_correction

    def sherman_morrison_update1(self, A_inv, u, vt, beta=0.8):
        """Rank-one Sherman-Morrison correction of an explicit inverse matrix.

        With ``c = beta / (1 - beta)`` the result equals
        ``A_inv - c * (A_inv u)(vt A_inv) / (1 + c * vt A_inv u)``, i.e. the
        inverse of ``A + c * u vt`` when ``A_inv`` is the inverse of ``A``.

        Args:
            A_inv: Square matrix holding the current inverse.
            u: 1-D column vector of the rank-one term.
            vt: 1-D row vector of the rank-one term.
            beta: Mixing weight that determines ``c``.

        Returns:
            New matrix of the same shape as ``A_inv``; the input is not modified.
        """
        inv_scale = 1 / (1 - beta)
        weighted_vt = beta * vt

        scaled_ainv_u = torch.matmul(inv_scale * A_inv, u)
        numerator = torch.matmul(torch.outer(scaled_ainv_u, weighted_vt), A_inv)
        denominator = 1 + torch.inner(weighted_vt, scaled_ainv_u)

        return A_inv - numerator / denominator

    def train(self, states, actions, rewards, masks):
        """Estimate advantages for a batch of transitions and run one update pass.

        Args:
            states: Observations; ``states.shape[0]`` is reported as the
                episode length.
            actions: Actions aligned with ``states``.
            rewards: Rewards aligned with ``states``.
            masks: Per-step multipliers forwarded to ``estimate_advantages``.

        Returns:
            dict with keys ``policy_loss``, ``value_loss``, ``log_probs`` and
            ``advantages`` (Python numbers), ``episode_length`` (int) and
            ``kl`` (left as the tensor returned by ``ac_npg_update1``).
        """
        # These value estimates only feed the advantage/return targets, which
        # are detached downstream, so no autograd graph is needed here.
        with torch.no_grad():
            values = self.critic(torch.Tensor(states))
        advantages, returns = self.estimate_advantages(rewards, masks, values, tau=0.95, device=self.device)

        policy_loss, value_loss, kl, log_probs, mean_advantage = self.ac_npg_update1(
            states, actions, returns, advantages)

        if "xla" in str(self.device):
            import torch_xla.core.xla_model as xm
            xm.mark_step()

        return {
            'policy_loss': policy_loss.item(),
            'value_loss': value_loss.item(),
            'kl': kl,
            'episode_length': states.shape[0],
            'log_probs': log_probs.item(),
            'advantages': mean_advantage.item(),
        }

    def init_online_fim(self, states, actions):
        """Recompute ``self.diagonal_fim`` for the current policy.

        Args:
            states: Forwarded unchanged to ``diagonal_fisher_information``.
            actions: Forwarded unchanged to ``diagonal_fisher_information``.
        """
        # NOTE(review): ``diagonal_fisher_information`` is not defined in the
        # visible part of this class - confirm it exists before calling this.
        diagonal = self.diagonal_fisher_information(self.policy, states, actions)
        self.diagonal_fim = diagonal

    def kl_divergence(self, new_actor, old_actor, states):
        """Compute KL(old || new) between two actors evaluated on ``states``.

        Gradients never flow into ``old_actor``: its outputs are detached.

        Args:
            new_actor: Callable returning action probabilities (discrete) or a
                ``(mu, std, logstd)`` triple (continuous).
            old_actor: Callable with the same output convention.
            states: Batch of observations. In the continuous case it is
                converted with ``self.tpdv`` first, so arrays and tensors of
                another dtype or device are accepted.

        Returns:
            Discrete actions: scalar tensor, the batch mean of the per-sample
            KL. Continuous actions: tensor of shape ``(batch, 1)`` with the KL
            summed over dim 1.
        """
        if self.is_disc_action:
            p = new_actor(states)
            q = old_actor(states).detach()
            # xlogy(q, x) is 0 wherever q == 0, so actions with zero
            # probability under the old policy contribute 0 instead of the
            # NaN produced by 0 * log(0).
            kl = (torch.xlogy(q, q) - torch.xlogy(q, p)).sum(1, keepdim=True)
            return kl.mean()

        # as_tensor is a no-op for tensors that already match dtype/device,
        # unlike the legacy torch.Tensor(...) constructor, which rejects
        # tensors of another dtype or device.
        states = torch.as_tensor(states, **self.tpdv)
        mu, std, logstd = new_actor(states)
        mu_old, std_old, logstd_old = (out.detach() for out in old_actor(states))

        # Closed-form KL between diagonal Gaussians, per action dimension.
        kl = logstd_old - logstd + (std_old.pow(2) + (mu_old - mu).pow(2)) / \
            (2.0 * std.pow(2)) - 0.5
        return kl.sum(1, keepdim=True)

    def compute_log_prob_gradient(self, log_prob, parameters):
        """Return the gradient of ``log_prob`` w.r.t. ``parameters`` as one flat vector.

        Args:
            log_prob: Scalar tensor to differentiate.
            parameters: Tensors to differentiate with respect to.

        Returns:
            Detached 1-D tensor: the per-parameter gradients flattened and
            concatenated in the order of ``parameters``. The autograd graph is
            retained so ``log_prob`` can be differentiated again.
        """
        per_param_grads = torch.autograd.grad(log_prob, parameters, retain_graph=True)
        flattened = [grad.view(-1) for grad in per_param_grads]
        return torch.cat(flattened).detach()  # Flatten gradients

    def estimate_advantages(self, rewards, mask, values, tau, device):
        """Compute normalized GAE advantages and the matching value targets.

        Walking backwards over the steps:
        ``delta_i = r_i + gamma * V_{i+1} * mask_i - V_i`` and
        ``A_i = delta_i + gamma * tau * A_{i+1} * mask_i``, with
        ``V`` and ``A`` past the last step taken as 0.

        Args:
            rewards: Per-step rewards; dim 0 is the step index.
            mask: Per-step multiplier applied to the bootstrap terms (0 cuts
                the recursion at that step).
            values: Critic estimates, indexed as ``values[i, 0]``.
            tau: GAE mixing coefficient.
            device: Device the two results are moved to.

        Returns:
            ``(advantages, returns)``, both without autograd history.
            ``returns`` is ``values + raw advantages``. ``advantages`` is
            mean-centered and divided by its standard deviation; when that
            deviation is undefined or zero (a single step, or all advantages
            equal) it is only mean-centered.
        """
        rewards, masks, values = self.to_device(torch.device('cpu'), rewards, mask, values)
        # The targets are constants for the later updates; detaching up front
        # keeps the loop below from recording one autograd node per step.
        values = values.detach()

        num_steps = rewards.size(0)
        deltas = torch.zeros(num_steps, 1)
        advantages = torch.zeros(num_steps, 1)

        prev_value = 0
        prev_advantage = 0
        for i in reversed(range(num_steps)):
            deltas[i] = rewards[i] + self.gamma * prev_value * masks[i] - values[i]
            advantages[i] = deltas[i] + self.gamma * tau * prev_advantage * masks[i]

            prev_value = values[i, 0]
            prev_advantage = advantages[i, 0]

        returns = (values + advantages).detach()

        centered = advantages - advantages.mean()
        spread = advantages.std()
        # std() is NaN for one sample and 0 for constant advantages; dividing
        # by either would poison every later update with NaN/inf.
        if num_steps > 1 and spread > 0:
            centered = centered / spread
        advantages = centered.detach()

        advantages, returns = self.to_device(device, advantages, returns)
        return advantages, returns

    def update_model(self, model, new_params):
        """Overwrite the parameters of ``model`` from a flat vector.

        Args:
            model: Module whose parameters are written in place, in
                ``model.parameters()`` order.
            new_params: 1-D tensor holding the new values back to back. Any
                elements beyond the total parameter count are ignored.
        """
        offset = 0
        for param in model.parameters():
            # numel() needs no temporary view and, unlike view(-1), also works
            # for parameters stored non-contiguously.
            count = param.numel()
            chunk = new_params[offset:offset + count]
            param.data.copy_(chunk.view(param.size()))
            offset += count

    def to_device(self, device, *args):
        """Move every positional argument to ``device``.

        Args:
            device: Target passed to each argument's ``.to(...)``.
            *args: Objects exposing ``.to(device)``.

        Returns:
            list with the moved objects, in the order they were given.
        """
        moved = []
        for item in args:
            moved.append(item.to(device))
        return moved