"""Implementation of the Policy Gradient algorithm."""

from __future__ import annotations

import time
from typing import Any
import torch
from omnisafe.adapter import OnPolicyAdapter
from omnisafe.algorithms import registry
from omnisafe.algorithms.base_algo import BaseAlgo
from omnisafe.common.buffer import VectorOnPolicyBuffer
from omnisafe.common.logger import Logger
from omnisafe.models.actor_critic.constraint_actor_critic import ConstraintActorCritic
from omnisafe.utils import distributed


@registry.register
class PolicyGradient(BaseAlgo):
    """Roll out two pretrained actors, combined through a weight, and log episode statistics.

    As written in this file, the class performs no policy update: :meth:`learn`
    only collects rollouts through the environment adapter, drains the buffer
    and dumps the logger once per epoch.

    Args:
        w: Weight forwarded to the adapter's ``rollout`` together with both
            actor-critics. NOTE(review): presumably the mixing coefficient
            between the two actors; the default of ``0.5`` is an assumed
            neutral value -- confirm against the adapter implementation.
        **kwargs: Forwarded unchanged to ``BaseAlgo.__init__``.
    """

    # Placeholder checkpoint locations for the two pretrained actors. Kept as
    # class attributes so they are visible in one place and can be overridden
    # (e.g. in a subclass) without editing ``_init_model``.
    _PRETRAINED_MODEL_PATH1: str = r"path"
    _PRETRAINED_MODEL_PATH2: str = r"path2"

    def __init__(self, w: float = 0.5, **kwargs: Any) -> None:
        # The default must be a concrete value: default expressions are
        # evaluated when the function is defined, so referring to an undefined
        # name there fails as soon as the module is imported.
        super().__init__(**kwargs)
        self.w = w

    def _init_env(self) -> None:
        """Create the environment adapter and derive the per-environment step budget.

        Raises:
            AssertionError: If ``steps_per_epoch`` is not divisible by
                ``world_size * vector_env_nums``.
        """
        self._env: OnPolicyAdapter = OnPolicyAdapter(
            self._env_id,
            self._cfgs.train_cfgs.vector_env_nums,
            self._seed,
            self._cfgs,
        )
        assert (self._cfgs.algo_cfgs.steps_per_epoch) % (
            distributed.world_size() * self._cfgs.train_cfgs.vector_env_nums
        ) == 0, 'The number of steps per epoch is not divisible by the number of environments.'
        # Split the global step budget across processes, then across the
        # vectorized environments of each process.
        self._steps_per_epoch: int = (
            self._cfgs.algo_cfgs.steps_per_epoch
            // distributed.world_size()
            // self._cfgs.train_cfgs.vector_env_nums
        )

    def _build_actor_critic(self) -> ConstraintActorCritic:
        """Build one actor-critic from the configured spaces and move it to the device."""
        return ConstraintActorCritic(
            obs_space=self._env.observation_space,
            act_space=self._env.action_space,
            model_cfgs=self._cfgs.model_cfgs,
            epochs=self._cfgs.train_cfgs.epochs,
        ).to(self._device)

    def _load_pretrained_actor(
        self,
        actor_critic: ConstraintActorCritic,
        checkpoint_path: str,
    ) -> None:
        """Load the ``'pi'`` entry of a checkpoint file into ``actor_critic.actor``."""
        # SECURITY: ``torch.load`` unpickles the file, so only trusted
        # checkpoints must be used here.
        # ``map_location`` lets a checkpoint saved on another device (e.g. a
        # GPU) be loaded on whatever device this run uses.
        loaded_params = torch.load(checkpoint_path, map_location=self._device)
        # strict=False: missing or unexpected keys are tolerated, not raised.
        actor_critic.actor.load_state_dict(loaded_params['pi'], strict=False)

    def _init_model(self) -> None:
        """Build both actor-critics and initialise their actors from the checkpoints."""
        self._actor_critic: ConstraintActorCritic = self._build_actor_critic()
        # Second actor-critic, configured identically to the first.
        self._actor_critic2: ConstraintActorCritic = self._build_actor_critic()

        self._load_pretrained_actor(self._actor_critic, self._PRETRAINED_MODEL_PATH1)
        self._load_pretrained_actor(self._actor_critic2, self._PRETRAINED_MODEL_PATH2)

        if distributed.world_size() > 1:
            distributed.sync_params(self._actor_critic)
            distributed.sync_params(self._actor_critic2)

    def _init(self) -> None:
        """Create the on-policy buffer that the rollout writes into."""
        self._buf: VectorOnPolicyBuffer = VectorOnPolicyBuffer(
            obs_space=self._env.observation_space,
            act_space=self._env.action_space,
            size=self._steps_per_epoch,
            gamma=self._cfgs.algo_cfgs.gamma,
            lam=self._cfgs.algo_cfgs.lam,
            lam_c=self._cfgs.algo_cfgs.lam_c,
            advantage_estimator=self._cfgs.algo_cfgs.adv_estimation_method,
            standardized_adv_r=self._cfgs.algo_cfgs.standardized_rew_adv,
            standardized_adv_c=self._cfgs.algo_cfgs.standardized_cost_adv,
            penalty_coefficient=self._cfgs.algo_cfgs.penalty_coef,
            num_envs=self._cfgs.train_cfgs.vector_env_nums,
            device=self._device,
        )

    def _init_log(self) -> None:
        """Create the logger, set up the model saver and register the logged keys."""
        self._logger = Logger(
            output_dir=self._cfgs.logger_cfgs.log_dir,
            exp_name=self._cfgs.exp_name,
            seed=self._cfgs.seed,
            use_tensorboard=self._cfgs.logger_cfgs.use_tensorboard,
            use_wandb=self._cfgs.logger_cfgs.use_wandb,
            config=self._cfgs,
        )

        # Only the first actor is handed to the saver.
        what_to_save: dict[str, Any] = {'pi': self._actor_critic.actor}
        self._logger.setup_torch_saver(what_to_save)

        self._logger.register_key('Metrics/EpRet', window_length=50)
        self._logger.register_key('Metrics/EpCost', window_length=50)
        self._logger.register_key('Metrics/EpLen', window_length=50)
        self._logger.register_key('Train/Epoch')
        # Register the 'Value/reward' and 'Value/cost' keys.
        # NOTE(review): nothing in this file stores them; presumably the
        # adapter's rollout does -- confirm.
        self._logger.register_key('Value/reward', window_length=1)
        self._logger.register_key('Value/cost', window_length=1)

    def learn(self) -> tuple[float, float, float]:
        """Run the rollout loop for the configured number of epochs.

        Returns:
            The first element of the logger statistics for ``Metrics/EpRet``,
            ``Metrics/EpCost`` and ``Metrics/EpLen``, in that order.
        """
        self._logger.log('INFO: Start training')

        for epoch in range(self._cfgs.train_cfgs.epochs):
            # Sample data with both actors and the weight ``w``.
            self._env.rollout(
                steps_per_epoch=self._steps_per_epoch,
                agent1=self._actor_critic,
                agent2=self._actor_critic2,
                buffer=self._buf,
                logger=self._logger,
                w=self.w,
            )
            # The returned batch is not used because no update is performed;
            # the call is kept in case it has side effects.
            # NOTE(review): presumably ``get`` also resets the buffer for the
            # next epoch -- confirm.
            self._buf.get()

            self._logger.store({'Train/Epoch': epoch})
            self._logger.dump_tabular()

        ep_ret = self._logger.get_stats('Metrics/EpRet')[0]
        ep_cost = self._logger.get_stats('Metrics/EpCost')[0]
        ep_len = self._logger.get_stats('Metrics/EpLen')[0]
        self._logger.close()

        return ep_ret, ep_cost, ep_len