# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import agent.data_augs as rad
import utils
from agent.sac_ae import Actor, Critic, weight_init, LOG_FREQ
from agent.transition_model import make_transition_model
from agent.decoder import make_decoder
from torch.distributions import MultivariateNormal, kl_divergence
from torch.nn import functional as F


class MYAgent(nn.Module):
    """Baseline algorithm with transition model and various decoder types."""

    def __init__(
            self,
            obs_shape,
            action_shape,
            device,
            hidden_dim=256,
            discount=0.99,
            init_temperature=0.01,
            alpha_lr=1e-3,
            alpha_beta=0.9,
            actor_lr=1e-3,
            actor_beta=0.9,
            actor_log_std_min=-10,
            actor_log_std_max=2,
            actor_update_freq=2,
            encoder_stride=2,
            critic_lr=1e-3,
            critic_beta=0.9,
            critic_tau=0.005,
            critic_target_update_freq=2,
            encoder_type='pixel',
            encoder_feature_dim=50,
            encoder_lr=1e-3,
            encoder_tau=0.005,
            decoder_type='pixel',
            decoder_lr=1e-3,
            decoder_update_freq=1,
            decoder_weight_lambda=0.0,
            transition_model_type='deterministic',
            num_layers=4,
            num_filters=32,
            data_augs='',
    ):
        """Create the networks, optimizers and schedules used by the agent.

        Args:
            obs_shape: Observation shape passed to the actor, the critics and
                the pixel decoder.
            action_shape: Action shape passed to the actor, the critics and
                the transition model. ``action_shape[0]`` also sizes the
                reward decoder input and ``-prod(action_shape)`` is used as
                the target entropy.
            device: Device the networks and ``log_alpha`` are moved to.
            hidden_dim: Passed to ``Actor`` and ``Critic``.
            discount: Discount factor stored as ``self.discount``.
            init_temperature: Initial value of the entropy temperature alpha.
            alpha_lr: Adam learning rate for ``log_alpha``.
            alpha_beta: First Adam beta for ``log_alpha``.
            actor_lr: Adam learning rate for the actor.
            actor_beta: First Adam beta for the actor.
            actor_log_std_min: Passed to ``Actor``.
            actor_log_std_max: Passed to ``Actor``.
            actor_update_freq: Stored as ``self.actor_update_freq``.
            encoder_stride: Passed to ``Actor`` and ``Critic``.
            critic_lr: Adam learning rate for the critic.
            critic_beta: First Adam beta for the critic.
            critic_tau: Stored as ``self.critic_tau``.
            critic_target_update_freq: Stored as
                ``self.critic_target_update_freq``.
            encoder_type: Passed to ``Actor`` and ``Critic``.
            encoder_feature_dim: Latent size; passed to the networks, the
                transition model, the reward decoder and the pixel decoder.
            encoder_lr: Adam learning rate for the critic's encoder.
            encoder_tau: Stored as ``self.encoder_tau``.
            decoder_type: ``'pixel'`` creates a pixel decoder.
                ``'reconstruction'`` is handled as ``'pixel'`` with
                ``self.reconstruction`` set to True. Any other value leaves
                ``self.decoder`` as None.
            decoder_lr: Adam learning rate of the decoder optimizer.
            decoder_update_freq: Stored as ``self.decoder_update_freq``.
            decoder_weight_lambda: Weight decay of the decoder optimizer.
            transition_model_type: Passed to ``make_transition_model``.
            num_layers: Passed to the networks and the pixel decoder.
            num_filters: Passed to the networks and the pixel decoder.
            data_augs: Augmentation names joined by ``'-'`` (for example
                ``'crop-flip'``). An empty string selects no augmentation.

        Raises:
            AssertionError: If ``data_augs`` contains an unknown name.
        """
        super(MYAgent, self).__init__()
        self.reconstruction = False
        if decoder_type == 'reconstruction':
            # 'reconstruction' reuses the pixel decoder; only the flag differs.
            decoder_type = 'pixel'
            self.reconstruction = True
        self.device = device
        self.discount = discount
        self.critic_tau = critic_tau
        self.encoder_tau = encoder_tau
        self.actor_update_freq = actor_update_freq
        self.critic_target_update_freq = critic_target_update_freq
        self.decoder_update_freq = decoder_update_freq
        self.decoder_type = decoder_type
        self.best_reward = -1e10
        self.Q_count = 0
        self.reset = False
        self.action_shape = action_shape
        self.data_augs = data_augs
        self.augs_funcs = {}

        aug_to_func = {
            'crop': rad.random_crop,
            'grayscale': rad.random_grayscale,
            'cutout': rad.random_cutout,
            'cutout_color': rad.random_cutout_color,
            'flip': rad.random_flip,
            'rotate': rad.random_rotation,
            'rand_conv': rad.random_convolution,
            'color_jitter': rad.random_color_jitter,
            'translate': rad.random_translate,
            'no_aug': rad.no_aug,
        }
        for aug_name in self.data_augs.split('-'):
            if not aug_name:
                # ''.split('-') yields [''], so the default data_augs='' has
                # to be skipped instead of being rejected as an invalid name.
                continue
            assert aug_name in aug_to_func, 'invalid data aug string'
            self.augs_funcs[aug_name] = aug_to_func[aug_name]

        self.actor = Actor(
            obs_shape, action_shape, hidden_dim, encoder_type,
            encoder_feature_dim, actor_log_std_min, actor_log_std_max,
            num_layers, num_filters, encoder_stride
        ).to(device)

        self.critic = Critic(
            obs_shape, action_shape, hidden_dim, encoder_type,
            encoder_feature_dim, num_layers, num_filters, encoder_stride
        ).to(device)

        # Separate critic used to score VLM actions in the "dist" actor loss.
        self.critic_vlm = Critic(
            obs_shape, action_shape, hidden_dim, encoder_type,
            encoder_feature_dim, num_layers, num_filters, encoder_stride
        ).to(device)

        self.critic_target = Critic(
            obs_shape, action_shape, hidden_dim, encoder_type,
            encoder_feature_dim, num_layers, num_filters, encoder_stride
        ).to(device)

        self.critic_target.load_state_dict(self.critic.state_dict())

        self.transition_model = make_transition_model(
            transition_model_type, encoder_feature_dim, action_shape
        ).to(device)

        # Predicts the reward from the concatenated latent state and action.
        self.reward_decoder = nn.Sequential(
            nn.Linear(encoder_feature_dim + action_shape[0], 512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Linear(512, 1)).to(device)

        decoder_params = list(self.transition_model.parameters()) + list(self.reward_decoder.parameters())

        # tie encoders between actor and critic
        self.actor.encoder.copy_conv_weights_from(self.critic.encoder)
        self.init_temperature = init_temperature
        self.log_alpha = torch.tensor(np.log(init_temperature)).to(device)
        self.log_alpha.requires_grad = True
        # set target entropy to -|A|
        self.target_entropy = -np.prod(action_shape)

        self.decoder = None
        if decoder_type == 'pixel':
            # create decoder
            self.decoder = make_decoder(
                decoder_type, obs_shape, encoder_feature_dim, num_layers,
                num_filters
            ).to(device)
            self.decoder.apply(weight_init)
            decoder_params += list(self.decoder.parameters())

        # One optimizer for the transition model, the reward decoder and,
        # when present, the pixel decoder.
        self.decoder_optimizer = torch.optim.Adam(
            decoder_params,
            lr=decoder_lr,
            weight_decay=decoder_weight_lambda
        )

        # optimizer for critic encoder for reconstruction loss
        self.encoder_optimizer = torch.optim.Adam(
            self.critic.encoder.parameters(), lr=encoder_lr
        )

        # optimizers
        self.actor_optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=actor_lr, betas=(actor_beta, 0.999)
        )

        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(), lr=critic_lr, betas=(critic_beta, 0.999)
        )

        self.alpha_lr = alpha_lr
        self.alpha_beta = alpha_beta
        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=alpha_lr, betas=(alpha_beta, 0.999))

        self.train()
        self.critic_target.train()

        # Annealing ("kickstarting") baseline: weight of the VLM imitation
        # term used by the "anneal" actor loss.
        self.iter = 0
        self.iter_with_ks = 60000     # iterations during which the term is applied
        self.ks_coef = 3.             # initial coefficient
        self.ks_coef_minimum = 0.1    # floor of the coefficient
        self.ks_coef_descent = 0.01   # amount subtracted at each decay step

    def train(self, training=True):
        """Record the training flag and forward it to the sub-networks.

        The flag is passed to the actor, the critic and, when one exists,
        the decoder.

        Args:
            training: True for training mode, False for evaluation mode.
        """
        self.training = training
        for module in (self.actor, self.critic):
            module.train(training)
        decoder = self.decoder
        if decoder is not None:
            decoder.train(training)

    def update_kickstarting_coef(self):
        """Advance the annealing counter and decay the kickstarting weight.

        ``self.iter`` is incremented on every call. While the coefficient is
        above ``self.ks_coef_minimum`` it is reduced by
        ``self.ks_coef_descent`` on every 100th iteration; it never drops
        below the minimum.
        """
        self.iter += 1
        if self.ks_coef <= self.ks_coef_minimum:
            self.ks_coef = self.ks_coef_minimum
        elif self.iter % 100 == 0:
            # Clamp right away: otherwise a decrement could leave the
            # coefficient below the floor until the next call corrects it.
            self.ks_coef = max(
                self.ks_coef - self.ks_coef_descent, self.ks_coef_minimum
            )

    def Sigmoid(self, x:torch.Tensor):
        """
        Generalized Sigmoid function with adjustable slope, shift, and range.

        Parameters:
            x (torch.Tensor): Input value(s).
            k (float): Slope parameter (default: 1).
            c (float): Shift parameter (default: 0).
            a (float): Lower bound of output range (default: 0).
            b (float): Upper bound of output range (default: 1).

        Returns:
            float or np.ndarray: Output of the generalized Sigmoid function.
        """
        a = 0.0
        b = 0.3
        k = -40
        c = 0.2
        return a + b / (1 + torch.exp(-k * (x - c)))

    @property
    def alpha(self):
        return self.log_alpha.exp()

    def init_alpha(self):
        self.log_alpha = torch.tensor(np.log(0.0)).to(self.device)
        self.log_alpha.requires_grad = True
        self.target_entropy = -np.prod(self.action_shape)
        self.log_alpha_optimizer = torch.optim.Adam(
            [self.log_alpha], lr=self.alpha_lr, betas=(self.alpha_beta, 0.999)
        )

    def select_action(self, obs):
        """Return the first output of the actor for one observation.

        The observation is converted to a float tensor, given a leading
        batch axis and passed to the actor with ``compute_pi`` and
        ``compute_log_pi`` disabled.

        Args:
            obs: A single observation accepted by ``torch.FloatTensor``.

        Returns:
            numpy.ndarray: The first actor output (named ``mu`` here),
            flattened to 1-D.
        """
        with torch.no_grad():
            batch = torch.FloatTensor(obs).to(self.device).unsqueeze(0)
            actor_out = self.actor(batch, compute_pi=False, compute_log_pi=False)
            mean_action, _, _, _, _ = actor_out
            return mean_action.cpu().detach().numpy().flatten()

    def sample_action(self, obs):
        """Return the second output of the actor for one observation.

        The observation is converted to a float tensor, given a leading
        batch axis and passed to the actor with ``compute_log_pi`` disabled.

        Args:
            obs: A single observation accepted by ``torch.FloatTensor``.

        Returns:
            numpy.ndarray: The second actor output (named ``pi`` here),
            flattened to 1-D.
        """
        with torch.no_grad():
            batch = torch.FloatTensor(obs).to(self.device).unsqueeze(0)
            actor_out = self.actor(batch, compute_log_pi=False)
            _, sampled_action, _, _, _ = actor_out
            return sampled_action.cpu().detach().numpy().flatten()

    def update_critic(self, obs, action, reward, next_obs, not_done, L, step):
        with torch.no_grad():
            _, policy_action, log_pi, _, _ = self.actor(next_obs)
            target_Q1, target_Q2 = self.critic_target(next_obs, policy_action)
            target_V = torch.min(target_Q1, target_Q2) - self.alpha.detach() * log_pi
            target_Q = reward + (not_done * self.discount * target_V)

        # get current Q estimates
        current_Q1, current_Q2 = self.critic(obs, action, detach_encoder=False)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
        L.log('train_critic/loss', critic_loss, step)

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        self.critic.log(L, step)

    def update_params(self):
        # 软更新Critic网络的权重
        tau = 0.005
        for target_param, param in zip(self.critic_vlm.parameters(), self.critic.parameters()):
            target_param.data.copy_((1.0 - tau) * target_param.data + tau * param.data)

    #
    def kl_divergence(self, means_RL, means_VLM, std_RL, std_VLM):
        # action1 和 action2 是两个模型的行动输出
        # 每个行动是一个二元组（steer, throttle）
        batch_size = means_RL.shape[0]

        covariances_RL = torch.diag_embed(std_RL).expand(batch_size, -1, -1)
        covariances_VLM = torch.diag_embed(std_VLM).expand(batch_size, -1, -1)

        # 创建多元高斯分布
        distribution_RL = MultivariateNormal(means_RL, covariances_RL)
        distribution_VLM = MultivariateNormal(means_VLM, covariances_VLM)

        # log_prob1 = distribution_RL.log_prob(means_RL)
        # prob2 = distribution_VLM.prob(means_VLM)

        # 计算KL散度
        # 注意：kl_div需要输入log(p(x))和q(x)，所以我们需要计算distribution的log_prob
        log_prob_mvn1 = distribution_RL.log_prob(means_RL)
        kl = nn.KLDivLoss(reduction='batchmean')
        kl_div = kl(log_prob_mvn1, distribution_VLM)

        return kl_div

    def update_actor_and_alpha(self, obs, next_obs, L, step, VLM_action, Loss_type="value", vlm_update_freq=10):
        if Loss_type == "loss":
            _, pi, log_pi, log_std, _ = self.actor(obs, detach_encoder=True)
            actor_Q1, actor_Q2 = self.critic(obs, pi, detach_encoder=True)
            actor_Q = torch.min(actor_Q1, actor_Q2)
            #
            # # 生成一个形状为 [128, 2] 的张量，其元素值在 [0, 1) 之间
            # random_tensor = torch.rand(128, 2)
            # # 对第一列的值进行缩放和偏移，使其在 [-1, 1] 范围内
            # first_column = 2 * random_tensor[:, 0] - 1
            # # 第二列的值已经在 [0, 1] 范围内，无需改变
            # second_column = random_tensor[:, 1]
            # # 将两列合并为一个张量
            # result_tensor = torch.stack((first_column, second_column), dim=1).to(self.device)
            #
            # vlm_loss = F.mse_loss(pi, result_tensor)

            vlm_loss = 3.0 * F.mse_loss(pi, VLM_action)
            actor_loss = (self.alpha.detach() * log_pi - actor_Q).mean() + vlm_loss

        elif Loss_type == "dist":
            mu, pi, log_pi, log_std, _ = self.actor(obs, detach_encoder=True)
            actor_Q1_pi, actor_Q2_pi = self.critic(obs, pi, detach_encoder=True)
            actor_Q_pi = torch.min(actor_Q1_pi, actor_Q2_pi)
            #
            actor_Q1_vlm, actor_Q2_vlm = self.critic_vlm(obs, VLM_action, detach_encoder=True)
            actor_Q_vlm = torch.min(actor_Q1_vlm, actor_Q2_vlm)

            actor_loss = (self.alpha.detach() * log_pi - actor_Q_pi).mean()

            #
            min_val = torch.min(actor_Q_vlm)
            max_val = torch.max(actor_Q_vlm)
            actor_Q_vlm_scale = (actor_Q_vlm - min_val) / (max_val - min_val + 1e-8)  # 归一化
            #
            # mu = actor_Q_vlm.mean()
            # sigma = actor_Q_vlm.std()
            # actor_Q_vlm_scale = (actor_Q_vlm - mu) / sigma  # Z-score标准化

            #
            min_val = torch.min(actor_Q2_pi)
            max_val = torch.max(actor_Q2_pi)
            actor_Q2_pi_scale = (actor_Q2_pi - min_val) / (max_val - min_val + 1e-8)  # 归一化
            # mu = actor_Q2_pi.mean()
            # sigma = actor_Q2_pi.std()
            # actor_Q2_pi_scale = (actor_Q2_pi - mu) / sigma  # Z-score标准化

            error_q = actor_Q_vlm_scale - actor_Q2_pi_scale
            relu = torch.relu(error_q)
            min_val = torch.min(relu)
            max_val = torch.max(relu)
            relu = (relu - min_val) / (max_val - min_val + 1e-8)  # 归一化

            vlm_loss = F.mse_loss(pi, VLM_action, reduction='none')
            vlm_loss = 1.0 * relu.detach() * vlm_loss
            vlm_loss = vlm_loss.mean()
            actor_loss += vlm_loss
            L.log('train_actor/relu', relu.mean(), step)
            L.log('train_actor/vlm_loss', vlm_loss, step)
            L.log('train_actor/actor_Q_vlm', actor_Q_vlm.mean(), step)
            L.log('train_actor/actor_Q_pi', actor_Q1_pi.mean(), step)
            L.log('train_actor/actor_Q_vlm_scale', actor_Q_vlm_scale.mean(), step)
            L.log('train_actor/actor_Q_pi_scale', actor_Q2_pi_scale.mean(), step)

            L.log('train_actor/vlm_loss_pi', F.mse_loss(pi, VLM_action).detach(), step)
            L.log('train_actor/vlm_loss_mu', F.mse_loss(mu, VLM_action).detach(), step)

            h = self.critic.encoder(obs, detach=True)
            pred_next_latent_mu, pred_next_latent_sigma = self.transition_model(torch.cat([h, VLM_action], dim=1))
            if pred_next_latent_sigma is None:
                pred_next_latent_sigma = torch.ones_like(pred_next_latent_mu)
            next_h = self.critic.encoder(next_obs, detach=True)
            diff = (pred_next_latent_mu - next_h.detach()) / pred_next_latent_sigma
            vlm_transition_loss = 0.5 * diff.detach().pow(2) + torch.log(pred_next_latent_sigma.detach())
            L.log('train_actor/diff', diff.mean(), step)
            L.log('train_actor/pred_next_latent_sigma', pred_next_latent_sigma.mean(), step)
            # target = vlm_loss.mean() + 10 * loss.mean()
            vlm_transition_loss = vlm_transition_loss * 10.0
            L.log('train_actor/vlm_transition_loss', vlm_transition_loss.mean(), step)

            # 没用先
            max_val = torch.max(vlm_transition_loss)
            if max_val > 1e-8:  # 检查是否需要进行归一化
                vlm_transition_loss_scale = vlm_transition_loss / (max_val + 1e-8)
            else:  # 或者跳过归一化
                vlm_transition_loss_scale = torch.zeros_like(vlm_transition_loss)
            L.log('train_actor/vlm_transition_loss_scale', vlm_transition_loss_scale.mean(), step)
            #
            target = vlm_transition_loss.mean()

            # if step > 50000:
            #     if self.reset is False:
            #         self.init_alpha()
            #         self.reset = True

        elif Loss_type == "anneal":
            _, pi, log_pi, log_std, _ = self.actor(obs, detach_encoder=True)
            actor_Q1, actor_Q2 = self.critic(obs, pi, detach_encoder=True)
            actor_Q = torch.min(actor_Q1, actor_Q2)
            actor_loss = (self.alpha.detach() * log_pi - actor_Q).mean()
            # print("pi:", pi)
            # print("VLM_action:", VLM_action)

            if self.iter < self.iter_with_ks:
                vlm_loss = F.mse_loss(pi, VLM_action)
                actor_loss += self.ks_coef * vlm_loss
                self.update_kickstarting_coef()

        elif Loss_type == "random":
            _, pi, log_pi, log_std, _ = self.actor(obs, detach_encoder=True)
            actor_Q1, actor_Q2 = self.critic(obs, pi, detach_encoder=True)
            actor_Q = torch.min(actor_Q1, actor_Q2)
            actor_loss = (self.alpha.detach() * log_pi - actor_Q).mean()

            # 创建一个大小为 [128, 1] 的随机数张量，数值范围在 -1 到 1 之间
            error_q = torch.rand(128, 1) * 2 - 1
            error_q = error_q.to(self.device)
            relu = torch.relu(error_q)
            min_val = torch.min(relu)
            max_val = torch.max(relu)
            relu = (relu - min_val) / (max_val - min_val + 1e-8)  # 归一化

            vlm_loss = F.mse_loss(pi, VLM_action, reduction='none')
            # print("vlm_loss1:", vlm_loss)
            # print("3.0 * relu.detach():", 3.0 * relu)
            vlm_loss = 3.0 * relu.detach() * vlm_loss
            # print("vlm_loss2:", vlm_loss)
            actor_loss += vlm_loss.mean()

        elif Loss_type == "value":
            # detach encoder, so we don't update it with the actor loss
            _, pi, log_pi, log_std, _ = self.actor(obs, detach_encoder=True)
            actor_Q1, actor_Q2 = self.critic(obs, pi, detach_encoder=True)

            actor_Q = torch.min(actor_Q1, actor_Q2)
            actor_loss = (self.alpha.detach() * log_pi - actor_Q).mean()
        else:
            sys.exit()

        L.log('train_actor/loss', actor_loss, step)
        L.log('train_actor/target_entropy', self.target_entropy, step)
        entropy = 0.5 * log_std.shape[1] * (1.0 + np.log(2 * np.pi)) + log_std.sum(dim=-1)
        L.log('train_actor/entropy', entropy.mean(), step)
        L.log('train_actor/log_std', log_std.mean(), step)

        # optimize the actor
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self.actor.log(L, step)

        self.log_alpha_optimizer.zero_grad()
        alpha_loss = (self.alpha * (-log_pi - self.target_entropy).detach()).mean()

        if Loss_type == "dist":
            alpha_target = self.Sigmoid(target)
            L.log('train_alpha/alpha_target', alpha_target.mean(), step)
            alpha_loss += 10.0 * F.mse_loss(self.alpha.to(torch.float32), alpha_target.mean().detach())

        L.log('train_alpha/loss', alpha_loss, step)
        L.log('train_alpha/value', self.alpha, step)
        alpha_loss.backward()
        self.log_alpha_optimizer.step()

    def update_transition_reward_model(self, obs, action, next_obs, reward, L, step):
        h = self.critic.encoder(obs)
        pred_next_latent_mu, pred_next_latent_sigma = self.transition_model(torch.cat([h, action], dim=1))
        if pred_next_latent_sigma is None:
            pred_next_latent_sigma = torch.ones_like(pred_next_latent_mu)

        next_h = self.critic.encoder(next_obs)
        diff = (pred_next_latent_mu - next_h.detach()) / pred_next_latent_sigma
        loss = torch.mean(0.5 * diff.pow(2) + torch.log(pred_next_latent_sigma))
        L.log('train_ae/transition_loss', loss, step)

        pred_next_reward = self.reward_decoder(torch.cat([h, action], dim=1))
        reward_loss = F.mse_loss(pred_next_reward, reward)
        total_loss = loss + reward_loss
        self.encoder_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()
        total_loss.backward()
        self.encoder_optimizer.step()
        self.decoder_optimizer.step()

    def update_decoder(self, obs, action, target_obs, L, step):  # uses transition model
        """Take one gradient step on the encoder and the decoder.

        Without the reconstruction flag, the decoder is applied to a latent
        sampled from the transition model and compared (MSE) with the
        preprocessed first three channels of ``target_obs``. With the flag
        set, the decoder is applied to the encoded ``obs`` and compared with
        ``obs`` itself.

        NOTE(review): in reconstruction mode ``obs`` is neither preprocessed
        nor cut to three channels before the comparison; confirm intended.

        Args:
            obs: Batch of observations to encode.
            action: Batch of actions, concatenated to the latent on dim 1.
            target_obs: 4-D batch of target images.
            L: Logger exposing ``log(key, value, step)``.
            step: Step passed to the logger.
        """
        # image might be stacked, just grab the first 3 (rgb)!
        assert target_obs.dim() == 4
        target_obs = target_obs[:, :3, :, :]

        latent = self.critic.encoder(obs)
        if self.reconstruction:
            rec_obs = self.decoder(latent)
            loss = F.mse_loss(obs, rec_obs)
        else:
            latent_action = torch.cat([latent, action], dim=1)
            predicted_latent = self.transition_model.sample_prediction(latent_action)
            if target_obs.dim() == 4:
                # preprocess images to be in [-0.5, 0.5] range
                target_obs = utils.preprocess_obs(target_obs)
            rec_obs = self.decoder(predicted_latent)
            loss = F.mse_loss(target_obs, rec_obs)

        optimizers = (self.encoder_optimizer, self.decoder_optimizer)
        for optimizer in optimizers:
            optimizer.zero_grad()
        loss.backward()
        for optimizer in optimizers:
            optimizer.step()
        L.log('train_ae/ae_loss', loss, step)

        self.decoder.log(L, step, log_freq=LOG_FREQ)

    def update(self, replay_buffer, L, step, Loss_type="value", vlm_update_freq=10):
        if Loss_type == "loss" or Loss_type == "value" or Loss_type == "anneal" or Loss_type == "random" or Loss_type == "dist":
            # 采样出来的VLM_action是动作值
            obs, action, _, VLM_action, _, reward, next_obs, not_done = replay_buffer.sample_raw()
        else:
            # 采样出来的VLM_action是字符串决策文本
            obs, action, _, VLM_action, _, reward, next_obs, not_done = replay_buffer.sample(self.augs_funcs)

        L.log('train/batch_reward', reward.mean(), step)

        self.update_critic(obs, action, reward, next_obs, not_done, L, step)
        self.update_transition_reward_model(obs, action, next_obs, reward, L, step)

        if step % self.actor_update_freq == 0:
            self.update_actor_and_alpha(obs, next_obs, L, step, VLM_action, Loss_type=Loss_type, vlm_update_freq=vlm_update_freq)

        if step % self.critic_target_update_freq == 0:
            utils.soft_update_params(
                self.critic.Q1, self.critic_target.Q1, self.critic_tau
            )
            utils.soft_update_params(
                self.critic.Q2, self.critic_target.Q2, self.critic_tau
            )
            utils.soft_update_params(
                self.critic.encoder, self.critic_target.encoder,
                self.encoder_tau
            )

        if self.decoder is not None and step % self.decoder_update_freq == 0:  # decoder_type is pixel
            self.update_decoder(obs, action, next_obs, L, step)

    def save(self, model_dir, step):
        torch.save(
            self.actor.state_dict(), '%s/actor_%s.pt' % (model_dir, step)
        )
        torch.save(
            self.critic.state_dict(), '%s/critic_%s.pt' % (model_dir, step)
        )
        if self.decoder is not None:
            torch.save(
                self.decoder.state_dict(),
                '%s/decoder_%s.pt' % (model_dir, step)
            )

    def save_best(self, model_dir, episode_reward):
        save_best = False
        if episode_reward > self.best_reward:
            self.best_reward = episode_reward
            save_best = True
        if save_best:
            torch.save(
                self.actor.state_dict(), '%s/actor_best.pt' % (model_dir)
            )
            torch.save(
                self.critic.state_dict(), '%s/critic_best.pt' % (model_dir)
            )
            if self.decoder is not None:
                torch.save(
                    self.decoder.state_dict(),
                    '%s/decoder_best.pt' % (model_dir)
                )
            print('----------------save best model------------------')

    def load(self, model_dir, step):
        self.actor.load_state_dict(
            torch.load('%s/actor_%s.pt' % (model_dir, step))
        )
        self.critic.load_state_dict(
            torch.load('%s/critic_%s.pt' % (model_dir, step))
        )
        if self.decoder is not None:
            self.decoder.load_state_dict(
                torch.load('%s/decoder_%s.pt' % (model_dir, step))
            )

    def load_best(self, model_dir):
        self.actor.load_state_dict(
            torch.load('%s/actor_best.pt' % (model_dir))
        )
        self.critic.load_state_dict(
            torch.load('%s/critic_best.pt' % (model_dir))
        )
        self.critic_target.load_state_dict(
            torch.load('%s/critic_best.pt' % (model_dir))
        )

    def load_critic(self, model_dir):
        self.critic_vlm.load_state_dict(
            torch.load('%s/critic_99999.pt' % (model_dir)), strict=False
        )
