from typing import Dict, Tuple, List, Optional
from functools import reduce
import numpy as np
from numpy.core.fromnumeric import partition
from algos.utils_finetune import Every
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from logger import logger
from logger import create_stats_ordered_dict
from torch.distributions import Normal
import wandb
from functorch import combine_state_for_ensemble, vmap

# Bounds on the action log-std. ActorVAE does not read these; it takes
# log_sig_max / log_sig_min constructor arguments whose defaults have the
# same numeric values.
LOG_SIG_MAX = 2
LOG_SIG_MIN = -20


class Actor(nn.Module):
    """Deterministic latent policy mapping a state to a latent action.

    Three ReLU hidden layers of width 256 feed a linear head whose output is
    squashed by ``tanh`` and scaled by ``max_action``, so every component of
    the result lies in ``[-max_action, max_action]``. ``device`` is accepted
    for signature symmetry with the other networks but is not used.
    """

    def __init__(self, state_dim, latent_dim, max_action, device):
        super().__init__()
        width_1, width_2, width_3 = 256, 256, 256

        # Layers are created in the same order as before so that parameter
        # initialisation consumes the RNG identically.
        self.pi1 = nn.Linear(state_dim, width_1)
        self.pi2 = nn.Linear(width_1, width_2)
        self.pi3 = nn.Linear(width_2, width_3)
        self.pi4 = nn.Linear(width_3, latent_dim)

        self.max_action = max_action

    def forward(self, state):
        features = state
        for hidden_layer in (self.pi1, self.pi2, self.pi3):
            features = F.relu(hidden_layer(features))
        return self.max_action * torch.tanh(self.pi4(features))


class ActorVAE(nn.Module):
    """Conditional VAE over actions, conditioned on the state.

    The encoder maps ``(state, action)`` to a diagonal Gaussian over a latent
    vector of size ``latent_dim``; the decoder maps ``(state, latent)`` back to
    an action. With ``deterministic=True`` ``decode`` returns the action tensor;
    otherwise it returns a ``Normal`` whose log-std is a sigmoid output rescaled
    into ``[log_sig_min, log_sig_max]``.

    ``std_architecture`` (only used when ``deterministic=False``):
      * ``"noise_vector"``: a learned, state-independent vector.
      * ``"linear_layer"``: a linear head on the decoder trunk; std gradients
        flow back into the trunk.
      * ``"linear_layer_stop_grad"``: the same head on a detached trunk.
      * ``"mlp_stop_grad"``: a two-layer head on a detached trunk.
    Any other value raises ``NotImplementedError``.
    """

    def __init__(
        self,
        state_dim,
        action_dim,
        latent_dim,
        max_action,
        device,
        deterministic: bool = True,
        std_architecture: str = "noise_vector",
        log_sig_max: float = 2.0,
        log_sig_min: float = -20.0,
    ):
        super(ActorVAE, self).__init__()
        hidden_size = (256, 256, 256)

        # Encoder: (state, action) -> latent mean / log-variance.
        self.e1 = nn.Linear(state_dim + action_dim, hidden_size[0])
        self.e2 = nn.Linear(hidden_size[0], hidden_size[1])
        self.e3 = nn.Linear(hidden_size[1], hidden_size[2])

        self.mean = nn.Linear(hidden_size[2], latent_dim)
        self.log_var = nn.Linear(hidden_size[2], latent_dim)

        # Decoder: (state, latent) -> action mean.
        self.d1 = nn.Linear(state_dim + latent_dim, hidden_size[0])
        self.d2 = nn.Linear(hidden_size[0], hidden_size[1])
        self.d3 = nn.Linear(hidden_size[1], hidden_size[2])
        self.d4 = nn.Linear(hidden_size[2], action_dim)

        if not deterministic:
            if std_architecture == "noise_vector":
                self.action_log_std_logits = nn.Parameter(
                    torch.zeros(action_dim, requires_grad=True)
                )
            elif std_architecture in ("linear_layer", "linear_layer_stop_grad"):
                self.action_log_std_linear = nn.Linear(hidden_size[2], action_dim)
            elif std_architecture == "mlp_stop_grad":
                self.action_log_std_1 = nn.Linear(hidden_size[2], hidden_size[2])
                self.action_log_std_2 = nn.Linear(hidden_size[2], action_dim)
            else:
                raise NotImplementedError

        self.max_action = max_action
        self.action_dim = action_dim
        self.latent_dim = latent_dim
        self.device = device
        self.deterministic = deterministic
        self.std_architecture = std_architecture
        self._log_sig_max = log_sig_max
        self._log_sig_min = log_sig_min

    def forward(self, state, action):
        """Encode, reparameterise and decode.

        Returns:
            ``(u, z, mean, log_var)`` where ``u`` is the output of ``decode``
            (a tensor, or a ``Normal`` when not deterministic), ``z`` the
            sampled latent and ``mean`` / ``log_var`` the encoder outputs.
        """
        z = F.relu(self.e1(torch.cat([state, action], 1)))
        z = F.relu(self.e2(z))
        z = F.relu(self.e3(z))

        mean = self.mean(z)
        log_var = self.log_var(z)
        std = torch.exp(log_var / 2)
        z = mean + std * torch.randn_like(std)

        u = self.decode(state, z)

        return u, z, mean, log_var

    def decode(self, state, z=None, clip=None):
        """Decode a latent into an action (or action distribution).

        Args:
            state: batch of states, concatenated with ``z`` along dim 1.
            z: latent batch; when ``None`` a standard-normal latent is drawn
                and clamped to ``[-clip, clip]``.
            clip: clamp bound for the drawn latent; defaults to
                ``self.max_action``. Ignored when ``z`` is given.

        Returns:
            The action mean tensor when deterministic, otherwise a ``Normal``.
        """
        if z is None:
            # Previously ``clip`` was unconditionally overwritten here, so a
            # caller-supplied bound was silently ignored.
            if clip is None:
                clip = self.max_action
            z = (
                torch.randn((state.shape[0], self.latent_dim))
                .to(self.device)
                .clamp(-clip, clip)
            )

        a = F.relu(self.d1(torch.cat([state, z], 1)))
        a = F.relu(self.d2(a))
        trunk = F.relu(self.d3(a))
        action_mean = self.d4(trunk)
        if self.deterministic:
            return action_mean

        # ``unit_log_std`` lies in (0, 1) and is rescaled below.
        if self.std_architecture == "noise_vector":
            unit_log_std = torch.sigmoid(self.action_log_std_logits)
        elif self.std_architecture == "linear_layer":
            # Unlike the *_stop_grad variant, std gradients reach the trunk.
            unit_log_std = torch.sigmoid(self.action_log_std_linear(trunk))
        elif self.std_architecture == "linear_layer_stop_grad":
            unit_log_std = torch.sigmoid(self.action_log_std_linear(trunk.detach()))
        elif self.std_architecture == "mlp_stop_grad":
            hidden = F.relu(self.action_log_std_1(trunk.detach()))
            unit_log_std = torch.sigmoid(self.action_log_std_2(hidden))
        else:
            raise NotImplementedError
        log_std = self._log_sig_min + unit_log_std * (
            self._log_sig_max - self._log_sig_min
        )
        return Normal(action_mean, torch.exp(log_std))


class Critic(nn.Module):
    """Q-network over concatenated ``(state, action)``, with an optional V head.

    Both networks use three ReLU hidden layers of width 256 and a scalar
    linear output. The V-network (``v1``..``v4``) is only built when
    ``train_v_func`` is true; calling ``v`` otherwise fails its assertion.
    ``device`` is accepted but not used.
    """

    def __init__(self, state_dim, action_dim, device, train_v_func: bool = True):
        super().__init__()

        widths = (256, 256, 256)

        # Q-network layers (created first, as before, to keep init order).
        self.l1 = nn.Linear(state_dim + action_dim, widths[0])
        self.l2 = nn.Linear(widths[0], widths[1])
        self.l3 = nn.Linear(widths[1], widths[2])
        self.l4 = nn.Linear(widths[2], 1)

        self._train_v_func = train_v_func
        if not train_v_func:
            return

        # State-value network layers.
        self.v1 = nn.Linear(state_dim, widths[0])
        self.v2 = nn.Linear(widths[0], widths[1])
        self.v3 = nn.Linear(widths[1], widths[2])
        self.v4 = nn.Linear(widths[2], 1)

    def forward(self, state_and_action):
        """Return Q-values for an already-concatenated ``(state, action)`` batch."""
        hidden = state_and_action
        for layer in (self.l1, self.l2, self.l3):
            hidden = F.relu(layer(hidden))
        return self.l4(hidden)

    def v(self, state):
        """Return state values; requires ``train_v_func=True`` at construction."""
        assert self._train_v_func
        hidden = state
        for layer in (self.v1, self.v2, self.v3):
            hidden = F.relu(layer(hidden))
        return self.v4(hidden)


class CriticEnsemble(nn.Module):
    """Ensemble of ``Critic`` Q-networks evaluated in a single batched call.

    The members' weights are stacked with functorch's
    ``combine_state_for_ensemble`` and the stateless model is mapped over the
    ensemble dimension with ``vmap``. The stacked tensors are held in the
    ``combined_params`` / ``combined_buffers`` tuples. ``nn.Module`` does not
    register tuple attributes, so ``self.parameters()`` does not return the
    stacked weights. Optimizers, gradient clipping and target updates must use
    ``combined_params`` directly.

    This ensemble has no state-value head; ``v`` raises ``NotImplementedError``.
    """

    def __init__(
        self,
        state_dim,
        action_dim,
        device,
        num_models: int = 2,
    ):
        super(CriticEnsemble, self).__init__()

        self._num_models = num_models
        self.models = [
            Critic(state_dim, action_dim, device, train_v_func=False).to(device)
            for _ in range(num_models)
        ]
        (
            self.combined_model,
            self.combined_params,
            self.combined_buffers,
        ) = combine_state_for_ensemble(self.models)
        for param in self.combined_params:
            param.requires_grad_()

    def forward(
        self, state, action, model_indices: Optional[List[int]] = None
    ) -> torch.Tensor:
        """Evaluate ensemble members on ``(state, action)``.

        Args:
            state, action: batches concatenated along dim 1 before evaluation.
            model_indices: ``None`` for every member, a single int, or a
                sequence of ints selecting which members' outputs to return.

        Returns:
            Member outputs stacked along dim 0, i.e. ``(num_selected, batch, 1)``.
            When exactly one member is selected (a single index, or a
            one-member ensemble) the leading dimension is dropped.
        """
        all_returns = vmap(self.combined_model, in_dims=(0, 0, None))(
            self.combined_params, self.combined_buffers, torch.cat([state, action], 1)
        )
        if model_indices is None:
            model_returns = all_returns
        else:
            # The original computed the selection but never applied it, so
            # per-index calls silently returned the whole ensemble.
            if isinstance(model_indices, (int, np.integer)):
                model_indices = [model_indices]
            model_returns = all_returns[[int(index) for index in model_indices]]
        if len(model_returns) == 1:
            model_returns = model_returns[0]
        return model_returns

    def v(self, state: torch.Tensor) -> torch.Tensor:
        """Not supported: this ensemble does not build a state-value network."""
        # The previous body referenced attributes that are never created and
        # failed with AttributeError; fail explicitly instead.
        raise NotImplementedError(
            "CriticEnsemble does not implement a state-value function (v)."
        )


class Latent(object):
    def __init__(
        self,
        state_dim,
        action_dim,
        latent_dim,
        max_action,
        min_v,
        max_v,
        replay_buffer,
        device,
        discount=0.99,
        tau=0.005,
        vae_lr=1e-4,
        actor_lr=1e-4,
        critic_lr=5e-4,
        max_latent_action=1,
        expectile=0.8,
        kl_beta=1.0,
        no_piz=False,
        no_noise=True,
        doubleq_min=0.8,
        deterministic_actions: bool = True,
        std_architecture: str = "noise_vector",
        num_q_functions: int = 2,
        entropy_term_weight: float = 0,
        dataset_noise: bool = False,
        log_sig_max: float = 2.0,
        log_sig_min: float = -20.0,
        optimism_parameter: float = 0.0,
        optimism_threshold: float = -1,  # If disagreement > threshold, don't apply optimism
        gradient_clipping: Optional[float] = None,
        not_train_v_func: bool = False,
        redq_subset_size: int = 2,
    ):
        """Build the action VAE, the latent policy, their targets and optimizers.

        Args:
            state_dim, action_dim, latent_dim: network sizes forwarded to
                ``ActorVAE`` and ``Actor``.
            max_action: stored as ``self.max_action``.
            min_v, max_v: critic regression targets are clamped into
                ``[min_v, max_v]`` in ``train_critic``.
            replay_buffer: object used via ``sample(batch_size)``,
                ``normalize_state`` and ``unnormalize_action``.
            device: anything ``torch.device`` accepts.
            discount: bootstrap discount for critic targets.
            tau: Polyak coefficient for the critic target update.
                ``self.tau_vae`` is set to the same value.
            vae_lr, actor_lr: Adam learning rates for the VAE and latent actor.
            critic_lr: stored as ``self.critic_lr``; no critic optimizer is
                built in this method.
            max_latent_action: passed as ``max_action`` to both ``ActorVAE``
                (clip for prior latents) and ``Actor`` (output bound).
            expectile, kl_beta: stored; presumably consumed by ``train``
                (not fully visible here) — verify.
            no_piz: when True, actions are decoded from prior latents instead
                of the latent actor's output.
            no_noise: when False, clipped Gaussian noise is added to
                target-policy latents for critic/advantage targets.
            doubleq_min: weight on the min (vs. max) when combining critic
                estimates.
            deterministic_actions, std_architecture, log_sig_max, log_sig_min:
                forwarded to ``ActorVAE``.
            num_q_functions, redq_subset_size: ensemble size, and the size of
                the random subset used for the REDQ min when no V-function is
                trained.
            entropy_term_weight: stored; not used in the code visible here.
            dataset_noise: when True, ``train`` perturbs dataset actions when
                computing advantages.
            optimism_parameter, optimism_threshold: optimism bonus applied to
                ensemble disagreement in ``select_action``.
            gradient_clipping: max grad norm used by ``train_critic``.
            not_train_v_func: disables the V-function path.
        """

        self.device = torch.device(device)
        self.actor_vae = ActorVAE(
            state_dim,
            action_dim,
            latent_dim,
            max_latent_action,
            self.device,
            deterministic_actions,
            std_architecture,
            log_sig_max,
            log_sig_min,
        ).to(self.device)
        self.actor_vae_target = copy.deepcopy(self.actor_vae)
        self.actorvae_optimizer = torch.optim.Adam(
            self.actor_vae.parameters(), lr=vae_lr
        )

        # Latent policy: state -> latent action in [-max_latent_action, max_latent_action].
        self.actor = Actor(state_dim, latent_dim, max_latent_action, self.device).to(
            self.device
        )
        self.actor_target = copy.deepcopy(self.actor)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)

        train_v_func = not not_train_v_func
        # Configuration constraints: without a V-function the prior-latent
        # (no_piz) mode is disallowed and target-latent noise must be off.
        assert not (no_piz and not_train_v_func)
        assert not (not_train_v_func and not no_noise)
        assert redq_subset_size <= num_q_functions
        self._num_q_functions = num_q_functions
        self._redq_subset_size = redq_subset_size
        self._train_v_func = train_v_func
        # NOTE(review): critic construction below is commented out, yet
        # select_action / train_critic / train use self.critic_ensemble,
        # self.critic_target and self.critic_optimizer. These are presumably
        # assigned elsewhere (not visible here) — confirm before use.
        # self.critic = Critic(state_dim, action_dim, self.device).to(self.device)
        # self.critic_ensemble = CriticEnsemble(
        #     state_dim,
        #     action_dim,
        #     self.device,
        #     num_models=num_q_functions,
        #     # train_v_func=train_v_func,
        # )  # .to(self.device)
        # self.critic_target = copy.deepcopy(self.critic_ensemble)
        # (
        #     self.critic_target.combined_model,
        #     self.critic_target.combined_params,
        #     self.critic_target.combined_buffers,
        # ) = combine_state_for_ensemble(self.critic_target.models)
        # self.critic_optimizer = torch.optim.Adam(
        #     self.critic_ensemble.combined_params, lr=critic_lr
        # )

        self.critic_lr = critic_lr
        self.latent_dim = latent_dim
        self.max_action = max_action
        self.max_latent_action = max_latent_action
        self.action_dim = action_dim
        self.discount = discount
        self.tau = tau
        self.tau_vae = tau

        self.expectile = expectile
        self.kl_beta = kl_beta
        self.no_piz = no_piz
        self.no_noise = no_noise
        self.doubleq_min = doubleq_min

        self.replay_buffer = replay_buffer
        self.min_v, self.max_v = min_v, max_v
        self.gradient_clipping = gradient_clipping
        self._entropy_term_weight = entropy_term_weight
        self._dataset_noise = dataset_noise
        self._optimism_parameter = optimism_parameter
        self._optimism_threshold = optimism_threshold
        # Called as a predicate in select_action to gate wandb histogram
        # logging; presumably true once every 10000 calls — verify Every.
        self._log_disagreement_histogram = Every(10000)

    def explore(self, state):
        """Return an exploratory action for a single raw environment state.

        With probability 0.1 the action is decoded from a latent drawn from the
        VAE prior (``decode`` with ``z=None``); otherwise it is decoded from the
        latent actor's output. For a stochastic VAE (``decode`` returns a
        ``Normal``), prior draws are sampled and actor draws use the mean,
        mirroring how ``train_critic`` treats the two cases. The action is
        unnormalized with ``self.replay_buffer`` and clipped to ``[-1, 1]``.
        """
        with torch.no_grad():
            state = self.replay_buffer.normalize_state(state)
            state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)

            if np.random.rand() > 0.9:
                decoded = self.actor_vae.decode(state)
                if not self.actor_vae.deterministic:
                    decoded = decoded.sample()
            else:
                latent_a = self.actor(state)
                decoded = self.actor_vae.decode(state, z=latent_a)
                if not self.actor_vae.deterministic:
                    decoded = decoded.mean
            # Previously a Normal reached .cpu() here and raised AttributeError.
            action = decoded.cpu().data.numpy().flatten()

            action = self.replay_buffer.unnormalize_action(action)
            action = np.clip(action, -1, 1)

        return action

    def select_action(
        self,
        state,
        deterministic: bool = True,
        num_actions_to_sample: int = 1,
        max_std_dev: float = -1,
        latent_noise: float = 0,
        action_space_noise: float = 0,
    ) -> Tuple[np.ndarray, Dict]:
        """Choose an environment action for a single raw state.

        Args:
            state: raw state; normalized with ``replay_buffer.normalize_state``
                and reshaped to a batch of one.
            deterministic: for a stochastic VAE, use the decoder mean instead
                of sampling.
            num_actions_to_sample: when > 1 and candidates are generated, each
                candidate is scored by the mean of ``self.critic_target`` over
                the ensemble (plus an optional optimism bonus) and the argmax
                is returned.
            max_std_dev: if not -1, caps the decoder std before stochastic
                sampling.
            latent_noise: std of Gaussian noise added to
                ``num_actions_to_sample - 2`` candidate latents, which are then
                clamped to ``[-actor_vae.max_action, actor_vae.max_action]``.
            action_space_noise: std of Gaussian noise added to decoded actions.

        Returns:
            ``(action, info)``: ``action`` is a flat numpy array unnormalized by
            ``replay_buffer.unnormalize_action`` and clipped to ``[-1, 1]``.
            ``info`` may hold ``action_distribution_entropy``,
            ``action_std_dev``, ``action_samples``,
            ``q_values_disagreement_mean`` and ``q_values_mean``.
        """
        info = {}
        if (latent_noise > 0 or action_space_noise > 0) and num_actions_to_sample > 1:
            # For now, we only support adding noise and sampling multiple actions
            # when the model is deterministic.
            assert deterministic
        with torch.no_grad():
            state = self.replay_buffer.normalize_state(state)
            state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
            action_samples = None

            if self.no_piz:
                decoded = self.actor_vae.decode(state)
                if not self.actor_vae.deterministic:
                    # decode returns a Normal here, which has no .cpu().
                    decoded = decoded.mean if deterministic else decoded.sample()
                action = decoded.cpu().data.numpy().flatten()
            else:
                latent_a = self.actor(state)
                if latent_noise > 0:
                    # We sample num_actions_to_sample - 2 because we're going to add
                    # 1 sample with original latent_a and action space noise, and another
                    # sample with neither latent noise nor action space noise.
                    max_latent = self.actor_vae.max_action
                    latent_action_samples = [
                        # Clamp to the symmetric range; the original used
                        # (-max, -max), which pinned every sample to -max.
                        (latent_a + torch.randn_like(latent_a) * latent_noise).clamp(
                            -max_latent, max_latent
                        )
                        for _ in range(num_actions_to_sample - 2)
                    ]
                    latent_action_samples.append(latent_a)
                else:
                    latent_action_samples = None
                if self.actor_vae.deterministic:
                    action_no_noise = self.actor_vae.decode(state, z=latent_a)
                    if num_actions_to_sample == 1:
                        # Add noise in tensor space before converting; the
                        # original referenced an undefined ``action`` here.
                        noisy_action = (
                            action_no_noise
                            + torch.randn_like(action_no_noise) * action_space_noise
                        )
                        action = noisy_action.cpu().data.numpy().flatten()
                    else:
                        if latent_action_samples is not None:
                            samples = [
                                self.actor_vae.decode(state, z=latent_action_sample)
                                + torch.randn_like(action_no_noise)
                                * action_space_noise
                                for latent_action_sample in latent_action_samples
                            ]
                        else:
                            samples = [
                                action_no_noise
                                + torch.randn_like(action_no_noise) * action_space_noise
                                for _ in range(num_actions_to_sample - 1)
                            ]
                        samples.append(action_no_noise)
                        # Stack into (num_actions_to_sample, action_dim) so the
                        # critic can score the candidates as one batch.
                        action_samples = torch.stack(samples).view(
                            num_actions_to_sample, self.action_dim
                        )

                else:
                    action_distribution = self.actor_vae.decode(state, z=latent_a)
                    info[
                        "action_distribution_entropy"
                    ] = action_distribution.entropy().mean()
                    info["action_std_dev"] = (
                        action_distribution.stddev.mean().cpu().data.numpy()
                    )
                    if deterministic:
                        if num_actions_to_sample == 1 or (
                            latent_noise == 0 and action_space_noise == 0
                        ):
                            action = (
                                action_distribution.mean.cpu().data.numpy().flatten()
                            )
                        else:
                            action_no_noise = action_distribution.mean.flatten()
                            if latent_action_samples is not None:
                                samples = [
                                    (
                                        self.actor_vae.decode(
                                            state, z=latent_action_sample
                                        ).mean
                                        + torch.randn_like(action_no_noise)
                                        * action_space_noise
                                    ).flatten()
                                    for latent_action_sample in latent_action_samples
                                ]
                            else:
                                samples = [
                                    action_no_noise
                                    + torch.randn_like(action_no_noise)
                                    * action_space_noise
                                    for _ in range(num_actions_to_sample - 1)
                                ]
                            samples.append(action_no_noise)
                            action_samples = torch.stack(samples).view(
                                num_actions_to_sample, self.action_dim
                            )
                    else:
                        if max_std_dev != -1:
                            action_distribution = Normal(
                                action_distribution.mean,
                                torch.clamp(
                                    action_distribution.stddev, max=max_std_dev
                                ),
                            )
                        if num_actions_to_sample <= 1:
                            action = (
                                action_distribution.rsample()
                                .cpu()
                                .data.numpy()
                                .flatten()
                            )
                        else:
                            action_samples = action_distribution.rsample(
                                [num_actions_to_sample]
                            ).view(num_actions_to_sample, self.action_dim)
            if action_samples is not None:
                info["action_samples"] = action_samples
                # Shape (num_models, num_actions_to_sample, 1) from the ensemble.
                action_qs_values = self.critic_target(
                    state.repeat(num_actions_to_sample, 1),
                    action_samples,
                )
                action_q_values = torch.mean(action_qs_values, dim=0).view(
                    num_actions_to_sample
                )
                if self._optimism_parameter > 0:
                    q_values_disagreement = torch.std(action_qs_values, dim=0).view(
                        num_actions_to_sample
                    )
                    info["q_values_disagreement_mean"] = (
                        q_values_disagreement.mean().cpu().data.numpy()
                    )
                    info["q_values_mean"] = action_q_values.mean().cpu().data.numpy()
                    if self._log_disagreement_histogram():
                        data = [
                            [d, q]
                            for d, q in zip(
                                q_values_disagreement.cpu().data.numpy(),
                                action_q_values.cpu().data.numpy(),
                            )
                        ]
                        table = wandb.Table(
                            data=data, columns=["q_values_disagreements", "q_values"]
                        )
                        wandb.log(
                            {
                                "q_value_disagreement_histogram": wandb.plot.histogram(
                                    table,
                                    "q_values_disagreements",
                                    title="Q-value disagreement distribution",
                                )
                            }
                        )
                        wandb.log(
                            {
                                "q_values_histogram": wandb.plot.histogram(
                                    table,
                                    "q_values",
                                    title="Q-value distribution",
                                )
                            }
                        )

                    # The bonus only applies where disagreement is within the
                    # threshold (or everywhere when the threshold is -1).
                    if self._optimism_threshold != -1:
                        bonus_mask = q_values_disagreement <= self._optimism_threshold
                    else:
                        bonus_mask = torch.ones_like(q_values_disagreement)
                    action_q_values = (
                        action_q_values
                        + self._optimism_parameter * q_values_disagreement * bonus_mask
                    )

                action_index = torch.argmax(action_q_values)
                action = action_samples[action_index].cpu().data.numpy().flatten()

            action = self.replay_buffer.unnormalize_action(action)
            action = np.clip(action, -1, 1)

        return action, info

    def kl_loss(self, mu, log_var):
        """Per-sample KL(N(mu, exp(log_var)) || N(0, I)), summed over dim 1.

        Returns a column tensor of shape ``(batch, 1)``.
        """
        per_dim_terms = 1 + log_var - mu.pow(2) - torch.exp(log_var)
        return (-0.5 * per_dim_terms.sum(dim=1)).view(-1, 1)

    def train_critic(self, iterations: int, batch_size: int):
        """Run ``iterations`` gradient steps on the critic ensemble.

        Each step draws one batch from ``self.replay_buffer``, shared by every
        ensemble member as in REDQ. It then regresses each Q-network toward
        ``reward + not_done * discount * V(next_state)``:

        * with a V-function (``self._train_v_func``), ``V(next_state)`` comes
          from ``self.critic_target.v``. The V-function itself is regressed
          toward a ``doubleq_min``-weighted min/max mix of target Q-networks 0
          and 1 at the target policy's action;
        * without one, ``V(next_state)`` is that min/max mix over a random
          subset of ``self._redq_subset_size`` target Q-networks at the target
          policy's next action.

        Regression targets are clamped to ``[self.min_v, self.max_v]``.
        Gradients are clipped to ``self.gradient_clipping`` when it is not
        None. The target critic is Polyak-averaged with ``self.tau`` after
        every step. Losses, the pre-clipping gradient norm and target
        statistics are logged via ``logger.record_tabular``.
        """
        for _ in range(iterations):
            state, action, next_state, reward, not_done = self.replay_buffer.sample(
                batch_size
            )
            with torch.no_grad():
                # TODO: add option for sampling from action distribution multiple times.
                if self._train_v_func:
                    actor_action = self._critic_target_policy_action(
                        state, add_latent_noise=not self.no_noise
                    )
                    next_target_v = self.critic_target.v(next_state)
                else:
                    next_actor_action = self._critic_target_policy_action(
                        next_state, add_latent_noise=False
                    )
                    # Since we're not training a V-function, we place the
                    # min operation from REDQ here.
                    model_indices = np.random.choice(
                        self._num_q_functions, self._redq_subset_size, replace=False
                    )
                    next_target_vs = self.critic_target(next_state, next_actor_action)
                    next_target_v = self._doubleq_mix(next_target_vs[model_indices])
                target_Q = reward + not_done * self.discount * next_target_v

                if self._train_v_func:
                    # Evaluate the ensemble once and mix members 0 and 1. This
                    # does not rely on per-index calls to the ensemble.
                    target_qs = self.critic_target(state, actor_action)
                    target_v = self._doubleq_mix(target_qs[:2])

            current_Qs = self.critic_ensemble(state, action)
            if self._train_v_func:
                current_v = self.critic_ensemble.v(state)
                v_loss = F.mse_loss(current_v, target_v.clamp(self.min_v, self.max_v))
            else:
                v_loss = 0
            clamped_target_Q = target_Q.clamp(self.min_v, self.max_v)
            q_losses = [
                F.mse_loss(current_Qs[model_index], clamped_target_Q)
                for model_index in range(self.critic_ensemble._num_models)
            ]
            # Only the first five per-model losses are logged.
            for q_index, q_loss in enumerate(q_losses[:5]):
                logger.record_tabular(
                    f"q_loss_{q_index}",
                    q_loss.detach().cpu().data.numpy(),
                )
            if self._train_v_func:
                logger.record_tabular("v_loss", v_loss.detach().cpu().data.numpy())
            critic_loss = sum(q_losses) + v_loss

            logger.record_tabular(
                "critic_loss", critic_loss.detach().cpu().data.numpy()
            )

            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            # For the CriticEnsemble defined in this file, the stacked weights
            # live in the ``combined_params`` tuple. That tuple is not returned
            # by ``.parameters()``, so the norm and clipping use it directly.
            critic_params = [
                p for p in self.critic_ensemble.combined_params if p.grad is not None
            ]
            squared_norm = sum(
                p.grad.detach().norm(2).item() ** 2 for p in critic_params
            )
            logger.record_tabular("critic_grad_norm", squared_norm**0.5)

            # clip_grad_norm_ cannot take max_norm=None, so skip when unset.
            if self.gradient_clipping is not None:
                torch.nn.utils.clip_grad_norm_(critic_params, self.gradient_clipping)
            self.critic_optimizer.step()

            # Update Target Networks (Polyak averaging).
            for param, target_param in zip(
                self.critic_ensemble.combined_params, self.critic_target.combined_params
            ):
                target_param.data.copy_(
                    self.tau * param.data + (1 - self.tau) * target_param.data
                )
            logger.record_tabular(
                "target_q_norm", target_Q.norm().detach().cpu().data.numpy()
            )
            logger.record_tabular(
                "target_q_mean", target_Q.mean().detach().cpu().data.numpy()
            )

    def _critic_target_policy_action(self, state, add_latent_noise: bool):
        """Return the target policy's action at ``state``; call under no_grad.

        With ``no_piz`` the target VAE decodes a prior latent, using ``rsample``
        when the VAE is stochastic. Otherwise the target actor's latent is
        decoded, optionally perturbed by N(0, 0.2) noise clipped to
        [-0.5, 0.5]; the decoder mean is used for a stochastic VAE.
        """
        if self.no_piz:
            actor_action = self.actor_vae_target.decode(state)
            if not self.actor_vae.deterministic:
                actor_action = actor_action.rsample()
            return actor_action
        latent_action = self.actor_target(state)
        if add_latent_noise:
            latent_action = latent_action + (
                torch.randn_like(latent_action) * 0.2
            ).clamp(-0.5, 0.5)
        actor_action = self.actor_vae_target.decode(state, z=latent_action)
        if not self.actor_vae.deterministic:
            actor_action = actor_action.mean
        return actor_action

    def _doubleq_mix(self, values):
        """Blend min and max over dim 0, weighting the min by ``self.doubleq_min``."""
        return values.min(dim=0).values * self.doubleq_min + values.max(
            dim=0
        ).values * (1 - self.doubleq_min)

    def train(
        self,
        iterations,
        step: int,
        batch_size=100,
        train_critic: bool = True,
        critic_training_iterations: int = 1,
        entropy_training_only: bool = False,
    ):
        for it in range(iterations):
            # Sample replay buffer / batch

            # # As in MSG, get one batch per model in ensemble
            # (
            #     state_for_ensemble,
            #     action_for_ensemble,
            #     next_state_for_ensemble,
            #     reward_for_ensemble,
            #     not_done_for_ensemble,
            # ) = self.replay_buffer.sample(batch_size * self.critic_ensemble._num_models)

            # state = state_for_ensemble[:batch_size]
            # action = action_for_ensemble[:batch_size]
            # next_state = next_state_for_ensemble[:batch_size]
            # reward = reward_for_ensemble[:batch_size]
            # not_done = not_done_for_ensemble[:batch_size]
            # For now, we're going to do as in REDQ and have the same batch for all models.
            state, action, next_state, reward, not_done = self.replay_buffer.sample(
                batch_size
            )
            with torch.no_grad():
                if self.no_piz:
                    actor_action = self.actor_vae_target.decode(state)
                    if not self.actor_vae.deterministic:
                        actor_action = actor_action.rsample()
                else:
                    latent_action = self.actor_target(state)
                    if not self.no_noise:
                        latent_action += (torch.randn_like(latent_action) * 0.2).clamp(
                            -0.5, 0.5
                        )
                    actor_action = self.actor_vae_target.decode(state, z=latent_action)
                    if not self.actor_vae.deterministic:
                        actor_action = actor_action.mean

            # Critic Training
            if train_critic:
                self.train_critic(
                    iterations=critic_training_iterations,
                    batch_size=batch_size,
                )

            # compute adv and weight
            if self._dataset_noise:
                # TODO: clone the actions and add the exploration noise only to the copy
                state_explore = torch.tile(state, [2, 1])
                action_explore = torch.tile(action, [2, 1])
                action_explore[0 : action.shape[0]] = action + (
                    torch.randn_like(action) * 0.1
                ).clamp(-0.3, 0.3)
                # q_a1, q_a2 = (
                #     self.critic_target(state, action_explore, model_index=0),
                #     self.critic_target(state, action_explore, model_index=1),
                # )
                q_action = torch.mean(
                    self.critic_target(state_explore, action_explore),
                    dim=0,
                )
                # q_a1, q_a2 = (
                #     self.critic_target(state_explore, action_explore, model_index=0),
                #     self.critic_target(state_explore, action_explore, model_index=1),
                # )
                # q_action = torch.min(q_a1, q_a2)
            else:
                state_explore = state
                action_explore = action
                if self._train_v_func:
                    q_action = reward + not_done * self.discount * self.critic_target.v(
                        next_state
                    )
                else:
                    q_action = torch.mean(self.critic_target(state, action), dim=0)
            # action_explore_raw = replay_buffer.unnormalize_action(action_explore.cpu()).clamp(-1, 1)
            # action_explore = replay_buffer.normalize_action(action_explore_raw).to(self.device).float()

            if self._train_v_func:
                current_v = self.critic_target.v(state)
            else:
                current_v = torch.mean(self.critic_target(state, actor_action), dim=0)
            assert len(current_v.shape) == 2, len(current_v.shape)
            if self._dataset_noise:
                current_v = current_v.tile([2, 1])
            adv = q_action - current_v

            w_sign = (adv > 0).float()
            weights = (1 - w_sign) * (1 - self.expectile) + w_sign * self.expectile

            entropy_weights = F.softmax(-adv, dim=0).view(-1, 1) * adv.numel()

            # train weighted CVAE
            recons_action, z_sample, mu, log_var = self.actor_vae(
                state_explore, action_explore
            )

            if self.actor_vae.deterministic:
                recons_loss_ori = F.mse_loss(
                    recons_action, action_explore, reduction="none"
                )
                recon_loss = torch.sum(recons_loss_ori, 1).view(-1, 1)
                std_loss = 0
                entropy = 0
            else:
                if self.actor_vae.std_architecture != "linear_layer_stop_grad":
                    recon_loss = -recons_action.log_prob(action)
                    recon_loss = torch.sum(recon_loss, dim=-1).view(-1, 1)
                    entropy = recons_action.entropy().mean(dim=-1, keepdim=True)
                    assert recon_loss.shape == entropy.shape
                    std_loss = 0
                else:
                    recons_loss_ori = F.mse_loss(
                        recons_action.mean, action_explore, reduction="none"
                    )
                    recon_loss = torch.sum(recons_loss_ori, 1).view(-1, 1)
                    action_distribution = Normal(
                        recons_action.mean.detach(), recons_action.stddev
                    )
                    std_loss = (
                        -action_distribution.log_prob(action_explore)
                        .sum(dim=-1)
                        .view(-1, 1)
                    )
                    entropy = action_distribution.entropy().mean(dim=-1, keepdim=True)
                    assert recon_loss.shape == entropy.shape

            KL_loss = self.kl_loss(mu, log_var)
            recon_loss = recon_loss * weights.detach()
            KL_loss = (KL_loss * self.kl_beta) * weights.detach()
            std_loss = std_loss * weights.detach()
            actor_entropy_loss = (
                self._entropy_term_weight * entropy_weights.detach() * (-entropy)
            )
            if step % 1000 == 0:
                logger.record_tabular(
                    "reconstruction_loss", recon_loss.mean().detach().cpu().data.numpy()
                )
                logger.record_tabular(
                    "KL_loss", KL_loss.mean().detach().cpu().data.numpy()
                )
                logger.record_tabular(
                    "std_loss", std_loss.mean().detach().cpu().data.numpy()
                )
                logger.record_tabular(
                    "actor_entropy_loss",
                    actor_entropy_loss.mean().detach().cpu().data.numpy(),
                )
            if entropy_training_only:
                actor_vae_loss = std_loss + actor_entropy_loss
            else:
                actor_vae_loss = recon_loss + KL_loss + std_loss + actor_entropy_loss
            actor_vae_loss = actor_vae_loss.mean()
            self.actorvae_optimizer.zero_grad()
            actor_vae_loss.backward()
            # Log grad norm
            actor_vae_norm = 0
            for p in self.actor_vae.parameters():
                param_norm = p.grad.detach().data.norm(2)
                actor_vae_norm += param_norm.item() ** 2
            total_actor_vae_norm = actor_vae_norm**0.5
            logger.record_tabular("actor_vae_grad_norm", total_actor_vae_norm)

            torch.nn.utils.clip_grad_norm_(
                self.actor_vae.parameters(), self.gradient_clipping
            )
            self.actorvae_optimizer.step()
            # total_norm = 0.0
            # children = [self.actor_vae, self.actor, self.critic]
            # for child in children:
            #     for p in child.parameters():
            #         if p.grad is not None:
            #             param_norm = p.grad.detach().data.norm(2)
            #             total_norm += param_norm.item() ** 2
            # total_norm = total_norm ** (1.0 / 2)

            if not self.no_piz:
                # train latent policy
                latent_actor_action = self.actor(state)
                actor_action_ori = self.actor_vae.decode(state, z=latent_actor_action)
                if not self.actor_vae.deterministic:
                    # TODO: SAMPLE MULTIPLE TIMES.
                    actor_action_ori = actor_action_ori.rsample()
                actor_action = actor_action_ori + (
                    torch.randn_like(actor_action_ori) * 0.02
                ).clamp(-0.05, 0.05)

                # q_pi = self.critic.q1(state, actor_action)
                q_pi = self.critic_ensemble(state, actor_action)[0]

                actor_loss = -q_pi.mean()
                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                # Log grad norm
                actor_norm = 0
                for p in self.actor.parameters():
                    param_norm = p.grad.detach().data.norm(2)
                    actor_norm += param_norm.item() ** 2
                total_actor_norm = actor_norm**0.5
                logger.record_tabular("actor_grad_norm", total_actor_norm)

                torch.nn.utils.clip_grad_norm_(
                    self.actor.parameters(), self.gradient_clipping
                )
                self.actor_optimizer.step()

            # Update Target Networks
            for param, target_param in zip(
                self.actor.parameters(), self.actor_target.parameters()
            ):
                target_param.data.copy_(
                    self.tau * param.data + (1 - self.tau) * target_param.data
                )
            for param, target_param in zip(
                self.actor_vae.parameters(), self.actor_vae_target.parameters()
            ):
                target_param.data.copy_(
                    self.tau_vae * param.data + (1 - self.tau_vae) * target_param.data
                )

    def save(self, filename, directory):
        """Write every network and optimizer of the agent to disk.

        Each artifact goes to ``<directory>/<filename>_<component>.pth``.

        The critic ensemble and the critic target are pickled as whole module
        objects, not as state dicts. The actor, actor VAE, their targets and
        all three optimizers are stored via ``state_dict()``.

        Args:
            filename: Prefix shared by every written file.
            directory: Directory to write into. Nothing here creates it, so it
                must already exist.
        """

        def _path(component):
            # e.g. "<directory>/<filename>_actor_vae.pth"
            return "%s/%s_%s.pth" % (directory, filename, component)

        # Critic side: whole-module pickles for the ensembles, state dict for the optimizer.
        torch.save(self.critic_ensemble, _path("critic_ensemble"))
        torch.save(self.critic_optimizer.state_dict(), _path("critic_optimizer"))
        torch.save(self.critic_target, _path("critic_target"))

        # Latent actor, its optimizer and its soft-updated target.
        torch.save(self.actor.state_dict(), _path("actor"))
        torch.save(self.actor_optimizer.state_dict(), _path("actor_optimizer"))
        torch.save(self.actor_target.state_dict(), _path("actor_target"))

        # Actor VAE, its optimizer and its soft-updated target.
        # Note the attribute is `actorvae_optimizer` but the file suffix is
        # `actor_vae_optimizer`.
        torch.save(self.actor_vae.state_dict(), _path("actor_vae"))
        torch.save(self.actorvae_optimizer.state_dict(), _path("actor_vae_optimizer"))
        torch.save(self.actor_vae_target.state_dict(), _path("actor_vae_target"))

    def load(self, filename, directory, map_location="cuda:0"):
        """Restore the agent from files written by ``save``.

        Reads ``<directory>/<filename>_<component>.pth``.

        Critic ensemble and critic target:
            Both are unpickled as whole modules. Their functorch
            ``combined_model`` and ``combined_buffers`` are rebuilt from
            ``.models``. The pickled ``combined_params`` are kept, and the
            freshly combined copy is discarded.

        Critic optimizer:
            Not restored from disk. A new Adam over the loaded
            ``combined_params`` is created with ``self.critic_lr``, so the Adam
            moment estimates start empty.

        Actor, actor VAE and their optimizers:
            The actor, actor VAE, their targets and their optimizers are
            restored from state dicts.

        Args:
            filename: Prefix shared by every file (the value given to ``save``).
            directory: Directory containing the checkpoint files.
            map_location: Forwarded to every ``torch.load`` call. Defaults to
                ``"cuda:0"``, which was previously hard-coded. Pass e.g.
                ``"cpu"`` to load on a machine without that GPU.

        Note:
            ``torch.load`` unpickles its input, which can execute arbitrary
            code, so only load checkpoints from trusted sources.

            NOTE(review): PyTorch releases whose ``torch.load`` defaults to
            ``weights_only=True`` will refuse the whole-module critic pickles
            unless ``weights_only=False`` is passed. Confirm against the
            installed version.
        """

        def _load(component):
            # One place for the path format and device mapping, so every
            # component lands on the same device.
            return torch.load(
                "%s/%s_%s.pth" % (directory, filename, component),
                map_location=map_location,
            )

        # Critic ensemble: keep the pickled combined_params (the `_` discards
        # the freshly combined copy), but rebuild the functional model and
        # buffers from the member modules.
        self.critic_ensemble = _load("critic_ensemble")
        (
            self.critic_ensemble.combined_model,
            _,
            self.critic_ensemble.combined_buffers,
        ) = combine_state_for_ensemble(self.critic_ensemble.models)
        for param in self.critic_ensemble.combined_params:
            param.requires_grad_()

        # The saved critic optimizer state is not loaded (the original code has
        # that path commented out). The optimizer is rebuilt over the newly
        # loaded tensors so it references the right objects.
        self.critic_optimizer = torch.optim.Adam(
            self.critic_ensemble.combined_params,
            lr=self.critic_lr,
        )

        # Critic target: same rebuild as above. No requires_grad / optimizer,
        # since it is only read.
        self.critic_target = _load("critic_target")
        (
            self.critic_target.combined_model,
            _,
            self.critic_target.combined_buffers,
        ) = combine_state_for_ensemble(self.critic_target.models)

        # Latent actor, its optimizer and its target.
        self.actor.load_state_dict(_load("actor"))
        self.actor_optimizer.load_state_dict(_load("actor_optimizer"))
        self.actor_target.load_state_dict(_load("actor_target"))

        # Actor VAE, its optimizer (attribute `actorvae_optimizer`, file
        # suffix `actor_vae_optimizer`) and its target.
        self.actor_vae.load_state_dict(_load("actor_vae"))
        self.actorvae_optimizer.load_state_dict(_load("actor_vae_optimizer"))
        self.actor_vae_target.load_state_dict(_load("actor_vae_target"))
