# acktr_actor_critic.py
import time
import torch
import numpy as np
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler

from src.utils.KFAC import KFACOptimizer


class ActorCritic:
    def __init__(self, args, policy, critic, device=torch.device("cpu")):
        """Build the learner: a K-FAC optimizer for the actor, Adam for the critic.

        Args:
            args: Hyper-parameter namespace. Must provide ``lr``,
                ``critic_lr``, ``opti_eps``, ``weight_decay``, ``gamma``,
                ``batch_size``, ``kfac_damping``, ``kfac_factor_decay``,
                ``kfac_kl_clip`` and ``kfac_update_freq``.
                ``use_value_active_masks`` and ``use_policy_active_masks``
                are optional and default to ``False``.
            policy: Actor network; moved to ``device``.
            critic: Value network; moved to ``device``.
            device: Target ``torch.device`` for both networks.
        """
        # Device placement happens before either optimizer is constructed.
        self.device = device
        self.tpdv = {"dtype": torch.float32, "device": device}
        self.policy = policy.to(device)
        self.critic = critic.to(device)

        # Optimisation hyper-parameters shared by the update methods.
        self.lr = args.lr
        self.critic_lr = args.critic_lr
        self.opti_eps = args.opti_eps
        self.weight_decay = args.weight_decay
        self.gamma = args.gamma
        self.batch_size = args.batch_size

        # Optional switches; absent fields fall back to False.
        self._use_value_active_masks = getattr(args, "use_value_active_masks", False)
        self._use_policy_active_masks = getattr(args, "use_policy_active_masks", False)

        # K-FAC settings.
        self.kfac_damping = args.kfac_damping
        self.kfac_factor_decay = args.kfac_factor_decay
        self.kfac_kl_clip = args.kfac_kl_clip
        self.kfac_update_freq = args.kfac_update_freq

        # Running counters: seconds spent in ac_update / mini-batches processed.
        self.total_time = 0
        self._step = 0

        kfac_options = {
            "lr": self.lr,
            "stat_decay": self.kfac_factor_decay,
            "damping": self.kfac_damping,
            "kl_clip": self.kfac_kl_clip,
            "weight_decay": self.weight_decay,
        }
        self.actor_optimizer = KFACOptimizer(self.policy, **kfac_options)

        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(),
            lr=self.critic_lr,
            eps=self.opti_eps,
            weight_decay=self.weight_decay,
        )

    # def set_lr(self, lr, critic_lr=None):
    #     self.lr = lr
    #     # KFACOptimizer exposes the same interface as torch.optim
    #     for g in self.actor_optimizer.param_groups:
    #         g["lr"] = lr
    #     if critic_lr is not None:
    #         self.critic_lr = critic_lr
    #         for g in self.critic_optimizer.param_groups:
    #             g["lr"] = critic_lr

    @torch.no_grad()
    def _maybe_sample_actions(self, states):
        """Try to draw actions for ``states`` from the current policy.

        Probes the policy for, in order:

        1. ``act(states, deterministic=False)`` returning either a tuple of at
           least two items (the first two are used) or a dict holding the
           keys ``"actions"`` and ``"log_probs"``;
        2. ``sample(states)``, paired with ``get_log_prob(states, actions)``
           when the policy has that method.

        An ``act`` result of any other form falls through to step 2.

        Returns:
            ``(actions, log_probs)``. ``log_probs`` is ``None`` when only
            ``sample`` is available; both are ``None`` when the policy offers
            no usable sampler.
        """
        policy = self.policy

        if hasattr(policy, "act"):
            result = policy.act(states, deterministic=False)
            if isinstance(result, tuple):
                if len(result) >= 2:
                    return result[0], result[1]
            elif isinstance(result, dict):
                if "actions" in result and "log_probs" in result:
                    return result["actions"], result["log_probs"]

        if not hasattr(policy, "sample"):
            return None, None

        actions = policy.sample(states)
        log_probs = (
            policy.get_log_prob(states, actions)
            if hasattr(policy, "get_log_prob")
            else None
        )
        return actions, log_probs

    def _acktr_actor_step(self, states_batch, actions_batch, advantages_batch):
        """Run one actor update on a mini-batch.

        When ``self.actor_optimizer.acc_stats`` is truthy on entry, an extra
        backward pass of ``-mean(log pi(a'|s))`` is run first, where ``a'``
        is drawn from the current policy (falling back to ``actions_batch``
        when the policy exposes no sampler). The flag is cleared right after
        that pass, so the advantage-weighted policy-gradient backward below
        runs with ``acc_stats`` off.
        NOTE(review): ``acc_stats`` is presumed to gate curvature-statistics
        collection inside ``KFACOptimizer`` (not visible here) -- confirm.

        Args:
            states_batch: Mini-batch of states fed to the policy.
            actions_batch: Actions actually taken in those states.
            advantages_batch: Advantage estimates weighting the log-probs.

        Returns:
            ``(policy_loss, action_log_probs)`` for the taken actions; both
            still carry their autograd history.
        """
        optimizer = self.actor_optimizer

        if getattr(optimizer, "acc_stats", False):
            optimizer.zero_grad()
            # _maybe_sample_actions is already decorated with torch.no_grad().
            sampled_actions, _ = self._maybe_sample_actions(states_batch)
            if sampled_actions is None:
                sampled_actions = actions_batch
            fisher_logp = self.policy.get_log_prob(states_batch, sampled_actions)
            fisher_loss = -fisher_logp.mean()
            # This graph is not reused: the log-probs for the real update are
            # recomputed below, so there is no need to retain it.
            fisher_loss.backward()
            # Only the sampled-action pass should run with the flag set; the
            # policy-gradient backward below is scaled by the advantages.
            optimizer.acc_stats = False

        # Drop the gradients left by the Fisher pass before the real update.
        optimizer.zero_grad()
        action_log_probs = self.policy.get_log_prob(states_batch, actions_batch)
        policy_loss = -(action_log_probs * advantages_batch).mean()
        policy_loss.backward()
        optimizer.step()

        return policy_loss, action_log_probs


    def _adam_critic_step(self, values, returns):
        """Take one critic optimizer step on half the mean squared error.

        Args:
            values: Critic predictions, still attached to the autograd graph.
            returns: Regression targets, broadcastable against ``values``.

        Returns:
            The loss tensor ``0.5 * mean((values - returns) ** 2)``.
        """
        optimizer = self.critic_optimizer
        optimizer.zero_grad()

        squared_error = (values - returns).pow(2)
        value_loss = 0.5 * squared_error.mean()

        value_loss.backward()
        optimizer.step()
        return value_loss

    def ac_update(self, states, actions, returns, advantages):
        """Run one shuffled pass of mini-batch critic and actor updates.

        The rollout is shuffled and cut into mini-batches of
        ``self.batch_size`` samples; the whole rollout forms a single batch
        when ``batch_size`` is ``None``, non-positive, or not smaller than
        the rollout. A trailing incomplete batch is dropped. Each mini-batch
        gets one critic step followed by one actor step, and ``self._step``
        is advanced once per mini-batch.

        Args:
            states: Rollout states, indexed along dim 0.
            actions: Actions taken, aligned with ``states``.
            returns: Critic regression targets, aligned with ``states``.
            advantages: Advantage estimates, aligned with ``states``.

        Returns:
            Four 0-d tensors: mean policy loss and mean value loss over the
            processed mini-batches, mean log-probability of the taken
            actions, and the mean of ``advantages``.
        """
        start_time = time.time()

        num_samples = states.shape[0]
        if self.batch_size is not None and 0 < self.batch_size < num_samples:
            batch_size = self.batch_size
        else:
            batch_size = num_samples
        sampler = BatchSampler(
            SubsetRandomSampler(range(num_samples)),
            batch_size,
            drop_last=True,
        )

        policy_loss_sum = 0.0
        value_loss_sum = 0.0
        batch_log_probs = []

        for indices in sampler:
            states_batch = states[indices]
            actions_batch = actions[indices]
            advantages_batch = advantages[indices]
            returns_batch = returns[indices]

            # Regress against this mini-batch's own targets. Using the full
            # ``returns`` here mismatches shapes when the batch is smaller
            # than the rollout and pairs shuffled predictions with
            # unshuffled targets otherwise.
            values = self.critic(states_batch)
            value_loss = self._adam_critic_step(values, returns_batch)

            # Request the extra Fisher pass on every kfac_update_freq-th
            # mini-batch; the flag is always cleared again afterwards.
            self.actor_optimizer.acc_stats = (self._step % self.kfac_update_freq) == 0
            policy_loss, action_log_probs = self._acktr_actor_step(
                states_batch, actions_batch, advantages_batch
            )
            self.actor_optimizer.acc_stats = False

            policy_loss_sum += policy_loss.item()
            value_loss_sum += value_loss.item()
            batch_log_probs.append(action_log_probs.detach())

            self._step += 1

        self.total_time += time.time() - start_time

        if "xla" in str(self.device):
            import torch_xla.core.xla_model as xm
            xm.mark_step()

        num_updates = len(batch_log_probs)
        with torch.no_grad():
            return (
                torch.tensor(policy_loss_sum / num_updates),
                torch.tensor(value_loss_sum / num_updates),
                torch.cat(batch_log_probs).mean(),
                advantages.mean(),
            )

    def train(self, states, actions, rewards, masks):
        """Compute advantages for one rollout and run a full update pass.

        Args:
            states: Rollout states; anything ``torch.as_tensor`` accepts.
            actions: Actions taken; anything ``torch.as_tensor`` accepts.
            rewards: Per-step reward tensor.
            masks: Per-step continuation mask tensor, forwarded to
                ``estimate_advantages``.

        Returns:
            Dict with ``policy_loss``, ``value_loss``, ``total_time``
            (cumulative seconds spent in ``ac_update``), ``episode_length``
            (number of states), ``log_probs`` and ``advantages`` (means).
        """
        # Convert once and reuse for both the value estimate and the update.
        states_t = torch.as_tensor(states, device=self.device)
        actions_t = torch.as_tensor(actions, device=self.device)

        # These value estimates only feed the advantage/return targets, which
        # are detached anyway, so skip building an autograd graph for them.
        with torch.no_grad():
            values = self.critic(states_t)
        advantages, returns = self.estimate_advantages(
            rewards, masks, values, tau=0.95, device=self.device
        )

        policy_loss, value_loss, log_probs, mean_advantage = self.ac_update(
            states_t, actions_t, returns, advantages
        )

        if "xla" in str(self.device):
            import torch_xla.core.xla_model as xm
            xm.mark_step()

        return {
            'policy_loss': float(policy_loss.item()),
            'value_loss': float(value_loss.item()),
            'total_time': self.total_time,
            'episode_length': states_t.shape[0],
            'log_probs': float(log_probs.item()),
            'advantages': float(mean_advantage.item()),
        }

    def estimate_advantages(self, rewards, mask, values, tau, device):
        """Compute GAE advantages and value targets for one rollout.

        Walking the rollout backwards::

            delta_t = r_t + gamma * V_{t+1} * mask_t - V_t
            A_t     = delta_t + gamma * tau * A_{t+1} * mask_t

        with ``V`` and ``A`` beyond the last step taken as 0. The recursion
        runs on CPU; results are moved to ``device`` at the end.

        Args:
            rewards: Per-step rewards, indexed along dim 0.
            mask: Per-step multiplier applied to the next-step terms
                (0 cuts the bootstrap at that step).
            values: Critic estimates, indexed as ``values[i, 0]``.
            tau: GAE lambda.
            device: Device the outputs are moved to.

        Returns:
            ``(advantages, returns)``, both detached. ``returns`` is
            ``values + advantages``; ``advantages`` is mean-centred and
            divided by its standard deviation. The division is skipped when
            the deviation is undefined (a single step) or zero (constant
            advantages), because it would turn every entry into NaN.
        """
        rewards, masks, values = self.to_device(torch.device('cpu'), rewards, mask, values)
        # Targets are detached on return, so do not track the recursion.
        values = values.detach()

        num_steps = rewards.size(0)
        advantages = torch.zeros(num_steps, 1, dtype=values.dtype)

        next_value = 0
        next_advantage = 0
        for i in reversed(range(num_steps)):
            delta = rewards[i] + self.gamma * next_value * masks[i] - values[i]
            advantages[i] = delta + self.gamma * tau * next_advantage * masks[i]
            next_value = values[i, 0]
            next_advantage = advantages[i, 0]

        returns = (values + advantages).detach()

        normalized = advantages - advantages.mean()
        if num_steps > 1:
            spread = advantages.std()
            if spread > 0:
                normalized = normalized / spread
        normalized = normalized.detach()

        advantages, returns = self.to_device(device, normalized, returns)
        return advantages, returns

    def to_device(self, device, *args):
        """Return a list holding each positional argument after ``.to(device)``.

        The output keeps the order of ``args``; with no arguments the result
        is an empty list.
        """
        moved = []
        for item in args:
            moved.append(item.to(device))
        return moved
