# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import sys

import numpy as np
import torch
import torch.nn as nn
from torch.nn import init
import agent.data_augs as rad
import utils
from agent.fusion_action import Actor, Critic, weight_init, LOG_FREQ, ICM
from agent.transition_model import make_transition_model
from agent.decoder import make_decoder
from torch.distributions import MultivariateNormal, kl_divergence
from torch.nn import functional as F
import copy
from math import exp


class MYAgent(object):
    """Baseline algorithm with transition model and various decoder types."""

    def __init__(
            self,
            obs_shape,
            action_shape,
            device,
            hidden_dim=256,
            discount=0.99,
            init_temperature=0.01,
            alpha_lr=1e-3,
            alpha_beta=0.9,
            actor_lr=1e-3,
            actor_beta=0.9,
            actor_log_std_min=-10,
            actor_log_std_max=2,
            actor_update_freq=2,
            encoder_stride=2,
            critic_lr=1e-3,
            critic_beta=0.9,
            critic_tau=0.005,
            critic_target_update_freq=2,
            encoder_type='pixel',
            encoder_feature_dim=50,
            encoder_lr=1e-3,
            encoder_tau=0.005,
            decoder_type='pixel',
            decoder_lr=1e-3,
            decoder_update_freq=1,
            decoder_weight_lambda=0.0,
            transition_model_type='deterministic',
            num_layers=4,
            num_filters=32,
            data_augs='',
    ):
        super(MYAgent, self).__init__()
        self.reconstruction = False
        if decoder_type == 'reconstruction':
            decoder_type = 'pixel'
            self.reconstruction = True
        self.device = device
        self.discount = discount
        self.critic_tau = critic_tau
        self.encoder_tau = encoder_tau
        self.actor_update_freq = actor_update_freq
        self.critic_target_update_freq = critic_target_update_freq
        self.decoder_update_freq = decoder_update_freq
        self.decoder_type = decoder_type
        self.best_reward = -1e10
        self.Q_count = 0
        self.reset = False

        self.action_shape = action_shape
        self.data_augs = data_augs
        self.augs_funcs = {}

        aug_to_func = {
            'crop': rad.random_crop,
            'grayscale': rad.random_grayscale,
            'cutout': rad.random_cutout,
            'cutout_color': rad.random_cutout_color,
            'flip': rad.random_flip,
            'rotate': rad.random_rotation,
            'rand_conv': rad.random_convolution,
            'color_jitter': rad.random_color_jitter,
            'translate': rad.random_translate,
            'no_aug': rad.no_aug,
        }
        for aug_name in self.data_augs.split('-'):
            assert aug_name in aug_to_func, 'invalid data aug string'
            self.augs_funcs[aug_name] = aug_to_func[aug_name]

        self.actor = Actor(
            obs_shape, action_shape, hidden_dim, encoder_type,
            encoder_feature_dim, actor_log_std_min, actor_log_std_max,
            num_layers, num_filters, encoder_stride, device
        ).to(device)

        self.critic = Critic(
            obs_shape, action_shape, hidden_dim, encoder_type,
            encoder_feature_dim, num_layers, num_filters, encoder_stride
        ).to(device)

        self.critic_vlm = Critic(
            obs_shape, action_shape, hidden_dim, encoder_type,
            encoder_feature_dim, num_layers, num_filters, encoder_stride
        ).to(device)

        self.critic_target = Critic(
            obs_shape, action_shape, hidden_dim, encoder_type,
            encoder_feature_dim, num_layers, num_filters, encoder_stride
        ).to(device)

        # self.critic_target_vlm = Critic(
        #     obs_shape, action_shape, hidden_dim, encoder_type,
        #     encoder_feature_dim, num_layers, num_filters, encoder_stride
        # ).to(device)

        self.critic_target.load_state_dict(self.critic.state_dict())
        # self.critic_target_vlm.load_state_dict(self.critic_vlm.state_dict())

        self.transition_model = make_transition_model(
            transition_model_type, encoder_feature_dim, action_shape
        ).to(device)

        self.reward_decoder = nn.Sequential(
            nn.Linear(encoder_feature_dim + action_shape[0], 512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Linear(512, 1)).to(device)

        #
        # 逆模型：根据状态变化预测动作
        # self.inverse_model = nn.Sequential(
        #     nn.Linear(encoder_feature_dim * 2, 256),
        #     nn.ReLU(),
        #     nn.Linear(256, action_shape[0])
        # ).to(device)
        # self.icm = ICM(
        #     obs_shape, action_shape, hidden_dim, encoder_type,
        #     encoder_feature_dim, num_layers, num_filters, encoder_stride
        # ).to(device)
        # tie encoders between icm and critic
        # self.icm.encoder.copy_conv_weights_from(self.critic.encoder)

        # decoder_params = list(self.transition_model.parameters()) + list(self.reward_decoder.parameters()) + list(self.inverse_model.parameters())
        decoder_params = list(self.transition_model.parameters()) + list(self.reward_decoder.parameters())

        # tie encoders between actor and critic
        self.actor.encoder.copy_conv_weights_from(self.critic.encoder)
        self.init_temperature = init_temperature
        self.log_alpha = torch.tensor(np.log(init_temperature)).to(device)
        self.log_alpha.requires_grad = True
        # set target entropy to -|A|
        self.target_entropy = -np.prod(action_shape)

        #
        self.eps_min = 0.1
        self.eps_start = 1.0
        self.eps_decay = 20000
        self.eps_test = 0.05
        self.eps_steps = 0
        self.testing = False
        self.eps = 0.1

        self.decoder = None
        if decoder_type == 'pixel':
            # create decoder
            self.decoder = make_decoder(
                decoder_type, obs_shape, encoder_feature_dim, num_layers,
                num_filters
            ).to(device)
            self.decoder.apply(weight_init)
            decoder_params += list(self.decoder.parameters())

        self.decoder_optimizer = torch.optim.Adam(
            decoder_params,
            lr=decoder_lr,
            weight_decay=decoder_weight_lambda
        )
        # optimizer for critic encoder for reconstruction loss
        self.encoder_optimizer = torch.optim.Adam(
            self.critic.encoder.parameters(), lr=encoder_lr
        )

        # optimizers
        self.actor_optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=actor_lr, betas=(actor_beta, 0.999)
        )

        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(), lr=critic_lr, betas=(critic_beta, 0.999)
        )
        # self.critic_optimizer_vlm = torch.optim.Adam(
        #     self.critic_vlm.parameters(), lr=critic_lr, betas=(critic_beta, 0.999)
        # )

        self.alpha_lr = alpha_lr
        self.alpha_beta = alpha_beta
        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=alpha_lr, betas=(alpha_beta, 0.999))

        self.train()
        self.critic_target.train()
        # self.critic_target_vlm.train()

        # 退火baseline
        self.iter = 0
        kickstarting_coef_initial = 2.
        kickstarting_coef_decent = 0.01
        kickstarting_coef_minimum = 0.1
        iter_with_ks = 40000
        self.iter_with_ks = iter_with_ks
        self.ks_coef = kickstarting_coef_initial
        self.ks_coef_minimum = kickstarting_coef_minimum
        self.ks_coef_descent = kickstarting_coef_decent

        #
        # self.gt_fot_vlm = self.get_soft_gt_fot_vlm(128).to(device)
        # self.gt_fot_rl = self.get_soft_gt_fot_rl(128).to(device)

    def train(self, training=True):
        self.training = training
        self.actor.train(training)
        self.critic.train(training)
        # self.critic_vlm.train(training)

        if self.decoder is not None:
            self.decoder.train(training)

    def update_kickstarting_coef(self):
        """Advance the iteration counter and anneal the kickstarting coefficient.

        On every 100th call ``ks_coef`` is lowered by ``ks_coef_descent`` while
        it is still above ``ks_coef_minimum``. The result is always clamped so
        it never drops below ``ks_coef_minimum``.
        """
        self.iter += 1
        if self.ks_coef > self.ks_coef_minimum and self.iter % 100 == 0:
            self.ks_coef -= self.ks_coef_descent
        # Clamp immediately so a decrement that overshoots the floor is never
        # observable between two calls.
        self.ks_coef = max(self.ks_coef, self.ks_coef_minimum)

    def Sigmoid(self, x:torch.Tensor):
        """
        Generalized Sigmoid function with adjustable slope, shift, and range.

        Parameters:
            x (float or torch): Input value(s).
            k (float): Slope parameter (default: 1).
            c (float): Shift parameter (default: 0).
            a (float): Lower bound of output range (default: 0).
            b (float): Upper bound of output range (default: 1).

        Returns:
            float or np.ndarray: Output of the generalized Sigmoid function.
        """
        a = 0.0
        b = 0.3
        k = -40
        c = 0.4
        return a + b / (1 + torch.exp(-k * (x - c)))

    @property
    def alpha(self):
        return self.log_alpha.exp()

    @property
    def epsilon(self):
        if not self.testing:
            eps = self.eps_min + (self.eps_start - self.eps_min) * exp(-self.eps_steps / self.eps_decay)
            self.eps_steps += 1
        else:
            eps = self.eps_test
        return eps

    def init_alpha(self):
        self.log_alpha = torch.tensor(np.log(0.0)).to(self.device)
        self.log_alpha.requires_grad = True
        self.target_entropy = -np.prod(self.action_shape)
        self.log_alpha_optimizer = torch.optim.Adam(
            [self.log_alpha], lr=self.alpha_lr, betas=(self.alpha_beta, 0.999)
        )

    def select_action(self, obs, pre_obs, pre_action):
        """Return the actor's mean action for one observation.

        Each input is converted to a float tensor on ``self.device`` and given
        a leading batch dimension of 1. The VLM action fed to the actor is the
        fixed vector ``[0, 1]``. Sampling, log-probability and std-weight
        computation are switched off in the actor call.

        Returns:
            numpy.ndarray: The flattened mean action.
        """
        def as_batch(value):
            # float tensor -> device -> add batch axis
            return torch.FloatTensor(value).to(self.device).unsqueeze(0)

        with torch.no_grad():
            obs_t = as_batch(obs)
            pre_obs_t = as_batch(pre_obs)
            pre_action_t = as_batch(pre_action)
            vlm_action_t = as_batch([0, 1])
            actor_out = self.actor(
                obs_t, pre_obs_t, pre_action_t, vlm_action_t,
                compute_pi=False, compute_log_pi=False, compute_std_weight=False
            )
            mu, _, _, _, _, _ = actor_out
            return mu.cpu().data.numpy().flatten()

    def sample_action(self, obs, pre_obs, pre_action, VLM_action):
        """Return an action sampled from the actor for one observation.

        Each input is converted to a float tensor on ``self.device`` and given
        a leading batch dimension of 1 before the actor is called with
        ``compute_log_pi=False``.

        Returns:
            numpy.ndarray: The flattened sampled action (the actor's second
            output). The actor's last output is unpacked but not used.
        """
        def as_batch(value):
            # float tensor -> device -> add batch axis
            return torch.FloatTensor(value).to(self.device).unsqueeze(0)

        with torch.no_grad():
            obs_t = as_batch(obs)
            pre_obs_t = as_batch(pre_obs)
            pre_action_t = as_batch(pre_action)
            vlm_action_t = as_batch(VLM_action)
            actor_out = self.actor(
                obs_t, pre_obs_t, pre_action_t, vlm_action_t, compute_log_pi=False
            )
            _, pi, _, _, _, output = actor_out
            # NOTE: a disabled variant also returned a one-hot of
            # argmax(output) next to the sampled action.
            return pi.cpu().data.numpy().flatten()

    def update_critic(self, obs, action, VLM_action, pre_action, reward, next_obs, not_done, L, step):
        with torch.no_grad():
            _, policy_action, log_pi, _, _, _ = self.actor(next_obs, obs, pre_action, VLM_action)
            target_Q1, target_Q2 = self.critic_target(next_obs, policy_action)
            target_V = torch.min(target_Q1, target_Q2) - self.alpha.detach() * log_pi
            target_Q = reward + (not_done * self.discount * target_V)

        # get current Q estimates
        current_Q1, current_Q2 = self.critic(obs, action, detach_encoder=False)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
        L.log('train_critic/loss', critic_loss, step)

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
        self.critic.log(L, step)

    def update_critic_vlm(self, obs, action, VLM_action, pre_action, reward, next_obs, not_done, L, step):
        """Run one TD regression step on ``critic_vlm``.

        Mirrors ``update_critic`` but bootstraps from ``critic_target_vlm`` and
        steps ``critic_optimizer_vlm``.

        Args:
            obs, action, reward, next_obs, not_done: Transition batch.
            VLM_action, pre_action: Extra actor inputs.
            L: Logger exposing ``log(key, value, step)``.
            step: Step index passed to the logger.

        Raises:
            AttributeError: If ``critic_target_vlm`` or ``critic_optimizer_vlm``
                has not been set on the agent.
        """
        # Fail fast with a clear message instead of part-way through the
        # update: these two attributes are not guaranteed to exist.
        missing = [
            name for name in ('critic_target_vlm', 'critic_optimizer_vlm')
            if not hasattr(self, name)
        ]
        if missing:
            raise AttributeError(
                'update_critic_vlm requires %s to be set on the agent'
                % ' and '.join(missing)
            )

        with torch.no_grad():
            _, policy_action, log_pi, _, _, _ = self.actor(next_obs, obs, pre_action, VLM_action)
            target_Q1, target_Q2 = self.critic_target_vlm(next_obs, policy_action)
            target_V = torch.min(target_Q1, target_Q2) - self.alpha.detach() * log_pi
            target_Q = reward + (not_done * self.discount * target_V)

        # get current Q estimates
        current_Q1, current_Q2 = self.critic_vlm(obs, action, detach_encoder=False)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
        # NOTE(review): this is the same log key update_critic() writes to -
        # confirm whether the two losses are meant to share it.
        L.log('train_critic/loss', critic_loss, step)

        # Optimize the critic
        self.critic_optimizer_vlm.zero_grad()
        critic_loss.backward()
        self.critic_optimizer_vlm.step()
        self.critic_vlm.log(L, step)

    def get_soft_gt_fot_rl(self, bs):
        # 批量大小 bs
        # 创建一个形状为 [bs, 1] 的张量，填充值为 0
        zeros_tensor = torch.zeros(bs, 1)
        # 创建一个形状为 [bs, 1] 的张量，填充值为 1
        ones_tensor = torch.ones(bs, 1)
        # 将两个张量堆叠起来，形成 [bs, 2] 的张量
        result_tensor = torch.cat([zeros_tensor, ones_tensor], dim=1)

        return result_tensor

    def get_soft_gt_fot_vlm(self, bs):
        # 批量大小 bs
        # 创建一个形状为 [bs, 1] 的张量，填充值为 0
        zeros_tensor = torch.zeros(bs, 1)
        # 创建一个形状为 [bs, 1] 的张量，填充值为 1
        ones_tensor = torch.ones(bs, 1)
        # 将两个张量堆叠起来，形成 [bs, 2] 的张量
        result_tensor = torch.cat([ones_tensor, zeros_tensor], dim=1)

        return result_tensor

    def update_actor_and_alpha(self, obs, next_obs, pre_obses, pre_actions, L, step, VLM_action, action, Loss_type="value", vlm_update_freq=10):
        # detach encoder, so we don't update it with the actor loss
        mu, pi, log_pi, log_std, _, ouput_weight = self.actor(obs, pre_obses, pre_actions, VLM_action, detach_encoder=True)

        #
        actor_Q1_pi, actor_Q2_pi = self.critic(obs, pi, detach_encoder=True)
        actor_Q_pi = torch.min(actor_Q1_pi, actor_Q2_pi)

        # actor_Q1_mu, actor_Q2_mu = self.critic(obs, mu, detach_encoder=True)
        # actor_Q_mu = torch.min(actor_Q1_mu, actor_Q2_mu)

        actor_Q1_vlm, actor_Q2_vlm = self.critic_vlm(obs, VLM_action, detach_encoder=True)
        actor_Q_vlm = torch.min(actor_Q1_vlm, actor_Q2_vlm)

        actor_loss = (self.alpha.detach() * log_pi - actor_Q_pi).mean()

        #
        # min_val = torch.min(actor_Q_vlm)
        # max_val = torch.max(actor_Q_vlm)
        # actor_Q_vlm_scale = (actor_Q_vlm - min_val) / (max_val - min_val + 1e-8)  # 归一化
        #
        # mu = actor_Q_vlm.mean()
        # sigma = actor_Q_vlm.std()
        # actor_Q_vlm_scale = (actor_Q_vlm - mu) / sigma  # Z-score标准化

        #
        # min_val = torch.min(actor_Q2_pi)
        # max_val = torch.max(actor_Q2_pi)
        # actor_Q2_pi_scale = (actor_Q2_pi - min_val) / (max_val - min_val + 1e-8)  # 归一化
        # mu = actor_Q2_pi.mean()
        # sigma = actor_Q2_pi.std()
        # actor_Q2_pi_scale = (actor_Q2_pi - mu) / sigma  # Z-score标准化

        error_q = actor_Q_vlm - actor_Q1_pi
        # error_q = actor_Q_vlm_scale - actor_Q2_pi_scale
        relu = torch.relu(error_q)
        relu_one_hot = (relu != 0).float()
        min_val = torch.min(relu)
        max_val = torch.max(relu)
        if max_val > 1e-8:  # 检查是否需要进行归一化
            relu = relu / (max_val + 1e-8)
        else:
            relu = torch.zeros_like(relu)  # 或者跳过归一化

        vlm_loss = F.mse_loss(pi, VLM_action, reduction='none')
        vlm_loss = 10.0 * relu.detach() * vlm_loss
        actor_loss += vlm_loss.mean()
        L.log('train_actor/relu', relu.mean(), step)
        L.log('train_actor/vlm_loss_relu', vlm_loss.mean(), step)
        L.log('train_actor/actor_Q_vlm', actor_Q_vlm.mean(), step)
        L.log('train_actor/actor_Q_pi', actor_Q1_pi.mean(), step)

        L.log('train_actor/vlm_loss_pi', F.mse_loss(pi, VLM_action).detach(), step)
        L.log('train_actor/vlm_loss_mu', F.mse_loss(mu, VLM_action).detach(), step)

        #
        h = self.critic.encoder(obs, detach=True)
        next_h = self.critic.encoder(next_obs, detach=True)
        pred_next_latent_mu, pred_next_latent_sigma = self.transition_model(torch.cat([h, VLM_action], dim=1))
        if pred_next_latent_sigma is None:
            pred_next_latent_sigma = torch.ones_like(pred_next_latent_mu)
        diff = (pred_next_latent_mu - next_h.detach()) / pred_next_latent_sigma
        vlm_transition_loss = 0.5 * diff.detach().pow(2) + torch.log(pred_next_latent_sigma.detach())
        L.log('train_actor/diff', diff.mean(), step)
        L.log('train_actor/pred_next_latent_sigma_vlm', pred_next_latent_sigma.mean(), step)
        vlm_transition_loss = vlm_transition_loss * 10.0
        L.log('train_actor/vlm_transition_loss', vlm_transition_loss.mean(), step)

        # 没用先
        pred_next_latent_mu, pred_next_latent_sigma = self.transition_model(torch.cat([h, pi], dim=1))
        if pred_next_latent_sigma is None:
            pred_next_latent_sigma = torch.ones_like(pred_next_latent_mu)
        diff = (pred_next_latent_mu - next_h.detach()) / pred_next_latent_sigma
        pi_transition_loss = 0.5 * diff.detach().pow(2) + torch.log(pred_next_latent_sigma.detach())
        L.log('train_actor/pred_next_latent_sigma_pi', pred_next_latent_sigma.mean(), step)
        pi_transition_loss = pi_transition_loss * 10.0
        L.log('train_actor/pi_transition_loss', pi_transition_loss.mean(), step)

        # 没用先
        max_val = torch.max(vlm_transition_loss)
        if max_val > 1e-8:  # 检查是否需要进行归一化
            vlm_transition_loss_scale = vlm_transition_loss / (max_val + 1e-8)
        else:   # 或者跳过归一化
            vlm_transition_loss_scale = torch.zeros_like(vlm_transition_loss)
        L.log('train_actor/vlm_transition_loss_scale', vlm_transition_loss_scale.mean(), step)
        #

        target = vlm_transition_loss.mean()

        L.log('train_actor/loss', actor_loss, step)
        L.log('train_actor/target_entropy', self.target_entropy, step)
        entropy = 0.5 * log_std.shape[1] * (1.0 + np.log(2 * np.pi)) + log_std.sum(dim=-1)
        L.log('train_actor/entropy', entropy.mean(), step)
        L.log('train_actor/log_std', log_std.mean(), step)
        L.log('train_actor/log_pi', log_pi.mean(), step)

        # optimize the actor
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        self.actor.log(L, step)

        self.log_alpha_optimizer.zero_grad()
        alpha_loss = (self.alpha * (-log_pi - self.target_entropy).detach()).mean()

        alpha_target = self.Sigmoid(target)

        L.log('train_alpha/alpha_target', alpha_target.mean(), step)
        alpha_loss += 15.0 * F.mse_loss(self.alpha.to(torch.float32), alpha_target.mean().detach())

        L.log('train_alpha/loss', alpha_loss, step)
        L.log('train_alpha/value', self.alpha, step)
        alpha_loss.backward()
        self.log_alpha_optimizer.step()

    def update_transition_reward_model(self, obs, action, next_obs, reward, L, step):
        h = self.critic.encoder(obs)
        pred_next_latent_mu, pred_next_latent_sigma = self.transition_model(torch.cat([h, action], dim=1))
        if pred_next_latent_sigma is None:
            pred_next_latent_sigma = torch.ones_like(pred_next_latent_mu)

        next_h = self.critic.encoder(next_obs)
        diff = (pred_next_latent_mu - next_h.detach()) / pred_next_latent_sigma
        loss = torch.mean(0.5 * diff.pow(2) + torch.log(pred_next_latent_sigma))
        L.log('train_ae/transition_loss', loss, step)
        #
        pred_next_reward = self.reward_decoder(torch.cat([h, action], dim=1))
        reward_loss = F.mse_loss(pred_next_reward, reward)
        #
        #
        # # 计算ICM损失和内在奖励
        # forward_loss = loss
        # # 预测动作和下一个状态
        # predicted_action = self.inverse_model(torch.cat([h, next_h], dim=1))
        # inverse_loss = F.cross_entropy(predicted_action, action.squeeze())
        # intrinsic_reward = 0.1 + forward_loss.sum()
        # self.eps = torch.clamp(10 * intrinsic_reward, 0, 1).item()
        # L.log('train_ae/eps', self.eps, step)

        # print("intrinsic_reward:", intrinsic_reward)
        # print("self.eps:", self.eps)
        #
        total_loss = loss + reward_loss
        self.encoder_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()
        total_loss.backward()
        self.encoder_optimizer.step()
        self.decoder_optimizer.step()

    def update_decoder(self, obs, action, target_obs, L, step):  # uses transition model
        """Train the pixel decoder (and the critic encoder) on one batch.

        In reconstruction mode the decoder reconstructs ``obs`` from its own
        latent. Otherwise it decodes the transition model's sampled next
        latent and is compared with the preprocessed first three channels of
        ``target_obs``. The MSE loss is logged as ``train_ae/ae_loss``.

        Args:
            obs, action: Current observation / action batch.
            target_obs: 4-D image batch to predict.
            L: Logger exposing ``log(key, value, step)``.
            step: Step index passed to the logger.
        """
        # image might be stacked, just grab the first 3 (rgb)!
        assert target_obs.dim() == 4
        target_obs = target_obs[:, :3, :, :]

        latent = self.critic.encoder(obs)
        if self.reconstruction:
            rec_obs = self.decoder(latent)
            loss = F.mse_loss(obs, rec_obs)
        else:
            predicted_latent = self.transition_model.sample_prediction(
                torch.cat([latent, action], dim=1)
            )
            if target_obs.dim() == 4:
                # utils.preprocess_obs is expected to rescale pixels into
                # [-0.5, 0.5] (per the original comment).
                target_obs = utils.preprocess_obs(target_obs)
            rec_obs = self.decoder(predicted_latent)
            loss = F.mse_loss(target_obs, rec_obs)

        optimizers = (self.encoder_optimizer, self.decoder_optimizer)
        for optimizer in optimizers:
            optimizer.zero_grad()
        loss.backward()
        for optimizer in optimizers:
            optimizer.step()

        L.log('train_ae/ae_loss', loss, step)
        self.decoder.log(L, step, log_freq=LOG_FREQ)

    def update(self, replay_buffer, L, step, Loss_type="value", vlm_update_freq=10):
        """Run one full training iteration on a batch from ``replay_buffer``.

        Order of work: critic update, transition/reward model update, actor
        and temperature update (every ``actor_update_freq`` steps), soft update
        of the target critic (every ``critic_target_update_freq`` steps) and
        decoder update (every ``decoder_update_freq`` steps, when a decoder
        exists).

        Args:
            replay_buffer: Object whose ``sample_pre()`` returns the 10-tuple
                unpacked below.
            L: Logger exposing ``log(key, value, step)``.
            step: Global training step.
            Loss_type, vlm_update_freq: Forwarded to ``update_actor_and_alpha``.
        """
        batch = replay_buffer.sample_pre()
        obs, pre_obses, pre_actions, action, _, VLM_action, _, reward, next_obs, not_done = batch
        L.log('train/batch_reward', reward.mean(), step)

        self.update_critic(obs, action, VLM_action, pre_actions, reward, next_obs, not_done, L, step)
        # NOTE: training critic_vlm (after step 60000) is currently disabled.
        self.update_transition_reward_model(obs, action, next_obs, reward, L, step)

        actor_due = step % self.actor_update_freq == 0
        if actor_due:
            self.update_actor_and_alpha(
                obs, next_obs, pre_obses, pre_actions, L, step, VLM_action, action,
                Loss_type=Loss_type, vlm_update_freq=vlm_update_freq
            )

        target_due = step % self.critic_target_update_freq == 0
        if target_due:
            # (online net, target net, tau) triples for the soft update.
            soft_update_plan = (
                (self.critic.Q1, self.critic_target.Q1, self.critic_tau),
                (self.critic.Q2, self.critic_target.Q2, self.critic_tau),
                (self.critic.encoder, self.critic_target.encoder, self.encoder_tau),
            )
            for online, target, tau in soft_update_plan:
                utils.soft_update_params(online, target, tau)
            # NOTE: the matching soft updates for critic_vlm and its target
            # are currently disabled.

        decoder_due = self.decoder is not None and step % self.decoder_update_freq == 0
        if decoder_due:  # decoder_type is pixel
            self.update_decoder(obs, action, next_obs, L, step)

    def save(self, model_dir, step):
        torch.save(
            self.actor.state_dict(), '%s/actor_%s.pt' % (model_dir, step)
        )
        torch.save(
            self.critic.state_dict(), '%s/critic_%s.pt' % (model_dir, step)
        )
        if self.decoder is not None:
            torch.save(
                self.decoder.state_dict(),
                '%s/decoder_%s.pt' % (model_dir, step)
            )

    def save_best(self, model_dir, episode_reward):
        save_best = False
        if episode_reward > self.best_reward:
            self.best_reward = episode_reward
            save_best = True
        if save_best:
            torch.save(
                self.actor.state_dict(), '%s/actor_best.pt' % (model_dir)
            )
            torch.save(
                self.critic.state_dict(), '%s/critic_best.pt' % (model_dir)
            )
            if self.decoder is not None:
                torch.save(
                    self.decoder.state_dict(),
                    '%s/decoder_best.pt' % (model_dir)
                )
            print('----------------save best model------------------')

    def load(self, model_dir, step):
        self.actor.load_state_dict(
            torch.load('%s/actor_%s.pt' % (model_dir, step))
        )
        self.critic.load_state_dict(
            torch.load('%s/critic_%s.pt' % (model_dir, step))
        )
        if self.decoder is not None:
            self.decoder.load_state_dict(
                torch.load('%s/decoder_%s.pt' % (model_dir, step))
            )

    def load_best(self, model_dir):
        # self.actor.load_state_dict(
        #     torch.load('%s/actor_best.pt' % (model_dir)), strict=False
        # )
        self.critic.load_state_dict(
            torch.load('%s/critic_best.pt' % (model_dir)), strict=False
        )
        self.critic_target.load_state_dict(
            torch.load('%s/critic_best.pt' % (model_dir)), strict=False
        )

    def load_critic(self, model_dir):
        self.critic_vlm.load_state_dict(
            torch.load('%s/critic_99999.pt' % (model_dir)), strict=False
        )



