from typing import Dict, List, Tuple, Optional
from functools import reduce
import numpy as np
from numpy.core.fromnumeric import partition
from algos.utils import ReplayBuffer
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from logger import logger
from logger import create_stats_ordered_dict
from torch.distributions import Normal
from functorch import combine_state_for_ensemble, vmap

# Module-level bounds for a policy's log standard deviation. ActorVAE takes
# its own bounds as constructor arguments (with the same default values); in
# the code visible here these constants only appear in a commented-out line.
LOG_SIG_MAX = 2
LOG_SIG_MIN = -20


class Actor(nn.Module):
    """Deterministic latent policy: maps a state to a bounded latent action.

    The network is a four-layer MLP (three ReLU hidden layers of width 256)
    whose output is squashed with ``tanh`` and scaled by ``max_action``, so
    every output component lies in ``[-max_action, max_action]``.

    Args:
        state_dim: Size of the state vector fed to the first layer.
        latent_dim: Size of the latent action produced by the last layer.
        max_action: Scale applied to the ``tanh`` output.
        device: Accepted for signature compatibility; not used by this class.
    """

    def __init__(self, state_dim, latent_dim, max_action, device):
        super().__init__()
        width_1, width_2, width_3 = 256, 256, 256

        # Layers are created input -> output so parameter order is stable.
        self.pi1 = nn.Linear(state_dim, width_1)
        self.pi2 = nn.Linear(width_1, width_2)
        self.pi3 = nn.Linear(width_2, width_3)
        self.pi4 = nn.Linear(width_3, latent_dim)

        self.max_action = max_action

    def forward(self, state):
        """Return the latent action for ``state``."""
        hidden = state
        for layer in (self.pi1, self.pi2, self.pi3):
            hidden = F.relu(layer(hidden))
        return self.max_action * torch.tanh(self.pi4(hidden))


class ActorVAE(nn.Module):
    """Conditional VAE over actions.

    The encoder maps ``(state, action)`` to the mean and log-variance of a
    diagonal Gaussian over a latent vector; the decoder maps
    ``(state, latent)`` back to an action. When ``deterministic`` is False the
    decoder returns a ``Normal`` distribution whose standard deviation is
    produced by one of several heads selected by ``std_architecture``.

    Args:
        state_dim: Size of the state vector.
        action_dim: Size of the action vector.
        latent_dim: Size of the latent vector.
        max_action: Default clip bound for latents sampled inside ``decode``.
        device: Device on which sampled latents are placed.
        deterministic: If True ``decode`` returns a tensor, otherwise a
            ``torch.distributions.Normal``.
        std_architecture: One of ``"noise_vector"`` (state-independent
            learned vector), ``"linear_layer"``, ``"linear_layer_stop_grad"``
            or ``"mlp_stop_grad"`` (the ``*_stop_grad`` variants read a
            detached copy of the decoder trunk).
        log_sig_max: Upper bound of the decoder's log standard deviation.
        log_sig_min: Lower bound of the decoder's log standard deviation.
        constant_std_init: If given, the linear std head starts with zero
            weights and this constant bias (linear architectures only).

    Raises:
        NotImplementedError: If ``deterministic`` is False and
            ``std_architecture`` is not one of the supported names.
    """

    def __init__(
        self,
        state_dim,
        action_dim,
        latent_dim,
        max_action,
        device,
        deterministic: bool = True,
        std_architecture: str = "noise_vector",
        log_sig_max: float = 2.0,
        log_sig_min: float = -20.0,
        constant_std_init: Optional[float] = None,
    ):
        super().__init__()
        hidden_size = (256, 256, 256)

        # Encoder: (state, action) -> hidden features.
        self.e1 = nn.Linear(state_dim + action_dim, hidden_size[0])
        self.e2 = nn.Linear(hidden_size[0], hidden_size[1])
        self.e3 = nn.Linear(hidden_size[1], hidden_size[2])

        # Heads producing the latent Gaussian's parameters.
        self.mean = nn.Linear(hidden_size[2], latent_dim)
        self.log_var = nn.Linear(hidden_size[2], latent_dim)

        # Decoder: (state, latent) -> action mean.
        self.d1 = nn.Linear(state_dim + latent_dim, hidden_size[0])
        self.d2 = nn.Linear(hidden_size[0], hidden_size[1])
        self.d3 = nn.Linear(hidden_size[1], hidden_size[2])
        self.d4 = nn.Linear(hidden_size[2], action_dim)

        if not deterministic:
            if std_architecture == "noise_vector":
                self.action_log_std_logits = nn.Parameter(
                    torch.zeros(action_dim, requires_grad=True)
                )
            elif std_architecture in ("linear_layer", "linear_layer_stop_grad"):
                self.action_log_std_linear = nn.Linear(hidden_size[2], action_dim)
                if constant_std_init is not None:
                    # Zero weights + constant bias => state-independent start.
                    torch.nn.init.constant_(self.action_log_std_linear.weight, 0)
                    torch.nn.init.constant_(
                        self.action_log_std_linear.bias, constant_std_init
                    )
            elif std_architecture == "mlp_stop_grad":
                self.action_log_std_1 = nn.Linear(hidden_size[2], hidden_size[2])
                self.action_log_std_2 = nn.Linear(hidden_size[2], action_dim)
            else:
                raise NotImplementedError

        self.max_action = max_action
        self.action_dim = action_dim
        self.latent_dim = latent_dim
        self.device = device
        self.deterministic = deterministic
        self.std_architecture = std_architecture
        self._log_sig_max = log_sig_max
        self._log_sig_min = log_sig_min

    def forward(self, state, action):
        """Encode ``(state, action)``, sample a latent and decode it.

        Returns:
            Tuple ``(u, z, mean, log_var)`` where ``u`` is the output of
            ``decode`` (tensor or ``Normal``), ``z`` the sampled latent and
            ``mean`` / ``log_var`` the encoder outputs.
        """
        z = F.relu(self.e1(torch.cat([state, action], 1)))
        z = F.relu(self.e2(z))
        z = F.relu(self.e3(z))

        mean = self.mean(z)
        log_var = self.log_var(z)
        std = torch.exp(log_var / 2)
        # Reparameterisation trick keeps the sample differentiable.
        z = mean + std * torch.randn_like(std)

        u = self.decode(state, z)

        return u, z, mean, log_var

    def decode(self, state, z=None, clip=None):
        """Decode a latent (or a freshly sampled one) into an action.

        Args:
            state: Batch of states, concatenated with ``z`` along dim 1.
            z: Latent batch. If None, a standard-normal latent is sampled and
                clamped to ``[-clip, clip]``.
            clip: Clamp bound used only when ``z`` is None. Defaults to
                ``self.max_action``. Previously an explicit value was
                silently overwritten by ``self.max_action``.

        Returns:
            The action mean tensor when ``deterministic`` is True, otherwise
            a ``Normal(action_mean, std)`` distribution.
        """
        if z is None:
            if clip is None:
                clip = self.max_action
            z = (
                torch.randn((state.shape[0], self.latent_dim))
                .to(self.device)
                .clamp(-clip, clip)
            )

        a = F.relu(self.d1(torch.cat([state, z], 1)))
        a = F.relu(self.d2(a))
        trunk = F.relu(self.d3(a))
        action_mean = self.d4(trunk)
        if self.deterministic:
            return action_mean

        # Each head yields a value in (0, 1) that is rescaled below.
        if self.std_architecture == "linear_layer_stop_grad":
            log_std = torch.sigmoid(self.action_log_std_linear(trunk.detach()))
        elif self.std_architecture == "linear_layer":
            log_std = torch.sigmoid(self.action_log_std_linear(trunk))
        elif self.std_architecture == "noise_vector":
            log_std = torch.sigmoid(self.action_log_std_logits)
        elif self.std_architecture == "mlp_stop_grad":
            x = F.relu(self.action_log_std_1(trunk.detach()))
            log_std = torch.sigmoid(self.action_log_std_2(x))
        else:
            raise NotImplementedError
        # Map (0, 1) onto [log_sig_min, log_sig_max].
        log_std = self._log_sig_min + log_std * (
            self._log_sig_max - self._log_sig_min
        )
        std = torch.exp(log_std)
        return Normal(action_mean, std)


class Critic(nn.Module):
    """Single Q-network with an optional separate state-value network.

    Both networks are four-layer MLPs with three ReLU hidden layers of width
    256 and a scalar output.

    Args:
        state_dim: Size of the state vector.
        action_dim: Size of the action vector.
        device: Accepted for signature compatibility; not used by this class.
        train_v_func: If True, the ``v1``..``v4`` layers are created and
            ``v`` may be called.
    """

    def __init__(self, state_dim, action_dim, device, train_v_func: bool = True):
        super().__init__()

        width_1, width_2, width_3 = 256, 256, 256

        # Q-network over the concatenated (state, action) input.
        self.l1 = nn.Linear(state_dim + action_dim, width_1)
        self.l2 = nn.Linear(width_1, width_2)
        self.l3 = nn.Linear(width_2, width_3)
        self.l4 = nn.Linear(width_3, 1)

        self._train_v_func = train_v_func
        if train_v_func:
            # State-value network; only built on request.
            self.v1 = nn.Linear(state_dim, width_1)
            self.v2 = nn.Linear(width_1, width_2)
            self.v3 = nn.Linear(width_2, width_3)
            self.v4 = nn.Linear(width_3, 1)

    def forward(self, state_and_action):
        """Return Q for an input that is already ``cat([state, action])``."""
        features = state_and_action
        for layer in (self.l1, self.l2, self.l3):
            features = F.relu(layer(features))
        return self.l4(features)

    def v(self, state):
        """Return the state value; requires ``train_v_func=True``."""
        assert self._train_v_func
        features = state
        for layer in (self.v1, self.v2, self.v3):
            features = F.relu(layer(features))
        return self.v4(features)


class CriticEnsemble(nn.Module):
    """Ensemble of ``Critic`` Q-networks evaluated in one ``vmap`` call.

    The member networks are stacked with ``combine_state_for_ensemble``; the
    stacked weights live in ``self.combined_params`` (a plain attribute, not
    registered module parameters), so callers must hand ``combined_params``
    to optimizers explicitly. An optional shared state-value network is kept
    in ``self._module_layers``.

    Args:
        state_dim: Size of the state vector.
        action_dim: Size of the action vector.
        device: Device the member critics and the V layers are moved to.
        num_models: Number of Q-networks in the ensemble.
        train_v_func: If True, build the V-network used by ``v``.
    """

    def __init__(
        self,
        state_dim,
        action_dim,
        device,
        num_models: int = 2,
        train_v_func: bool = True,
    ):
        super().__init__()

        hidden_size = (256, 256, 256)

        self._num_models = num_models
        self.models = [
            Critic(state_dim, action_dim, device, train_v_func=False).to(device)
            for _ in range(num_models)
        ]
        (
            self.combined_model,
            self.combined_params,
            self.combined_buffers,
        ) = combine_state_for_ensemble(self.models)
        for param in self.combined_params:
            param.requires_grad_()

        self._train_v_func = train_v_func
        # Fix: this container was never created, so train_v_func=True raised
        # AttributeError. A ModuleDict also registers the V layers so they
        # show up in parameters() / state_dict().
        self._module_layers = nn.ModuleDict()
        if train_v_func:
            self._module_layers["v1"] = nn.Linear(state_dim, hidden_size[0])
            self._module_layers["v2"] = nn.Linear(hidden_size[0], hidden_size[1])
            self._module_layers["v3"] = nn.Linear(hidden_size[1], hidden_size[2])
            self._module_layers["v4"] = nn.Linear(hidden_size[2], 1)
            self._module_layers.to(device)

    def forward(
        self, state, action, model_indices: Optional[List[int]] = None
    ) -> torch.Tensor:
        """Evaluate the Q-networks on ``(state, action)``.

        Args:
            state: Batch of states.
            action: Batch of actions, concatenated with ``state`` on dim 1.
            model_indices: Optional member index or sequence of indices. If
                None, all members are returned. (Previously this argument
                was ignored and every member was always returned.)

        Returns:
            A tensor stacked over the selected members on dim 0. When exactly
            one member is selected (or the ensemble has one member) that
            leading dimension is dropped.
        """
        # All members are evaluated in one vmapped call; a subset is then
        # selected from the stacked result.
        model_returns = vmap(self.combined_model, in_dims=(0, 0, None))(
            self.combined_params, self.combined_buffers, torch.cat([state, action], 1)
        )
        if model_indices is not None:
            if isinstance(model_indices, (int, np.integer)):
                model_indices = [model_indices]
            selection = torch.as_tensor(
                [int(index) for index in model_indices],
                dtype=torch.long,
                device=model_returns.device,
            )
            model_returns = model_returns[selection]
        if len(model_returns) == 1:
            model_returns = model_returns[0]
        return model_returns

    def v(self, state: torch.Tensor) -> torch.Tensor:
        """Return the state value; requires ``train_v_func=True``."""
        assert self._train_v_func
        v = F.relu(self._module_layers["v1"](state))
        v = F.relu(self._module_layers["v2"](v))
        v = F.relu(self._module_layers["v3"](v))
        v = self._module_layers["v4"](v)
        return v


class Latent(object):
    def __init__(
        self,
        state_dim,
        action_dim,
        latent_dim,
        max_action,
        min_v,
        max_v,
        replay_buffer,
        device,
        discount=0.99,
        tau=0.005,
        vae_lr=1e-4,
        actor_lr=1e-4,
        critic_lr=5e-4,
        max_latent_action=1,
        expectile=0.8,
        kl_beta=1.0,
        no_piz=False,
        no_noise=True,
        doubleq_min=0.8,
        deterministic_actions: bool = True,
        std_architecture: str = "noise_vector",
        log_sig_max: float = 2.0,
        log_sig_min: float = -20.0,
        constant_std_init: Optional[float] = None,
        gradient_clipping: Optional[float] = None,
        num_q_functions: int = 2,
        not_train_v_func: bool = False,
        redq_subset_size: int = 2,
    ):
        """Build the VAE, latent policy, critic ensemble and their optimizers.

        Args:
            state_dim, action_dim, latent_dim: Network input/output sizes.
            max_action: Stored on the instance; not used in the constructor.
            min_v, max_v: Bounds used to clamp critic targets.
            replay_buffer: Buffer used for sampling and (un)normalisation.
            device: Anything accepted by ``torch.device``.
            discount: Discount factor for bootstrapped targets.
            tau: Soft-update rate for all target networks.
            vae_lr, actor_lr, critic_lr: Adam learning rates.
            max_latent_action: Bound passed to the VAE and the latent actor.
            expectile: Weight for positive-advantage samples in the VAE loss
                (``1 - expectile`` for the rest).
            kl_beta: Weight of the KL term in the VAE loss.
            no_piz: If True, act by sampling the VAE prior instead of using
                the latent policy.
            no_noise: If False, clipped Gaussian noise is added to target
                policy actions.
            doubleq_min: Mixing weight between min and max over Q-values.
            deterministic_actions, std_architecture, log_sig_max,
            log_sig_min, constant_std_init: Forwarded to ``ActorVAE``.
            gradient_clipping: Max gradient norm, or None.
            num_q_functions: Size of the critic ensemble.
            not_train_v_func: If True, no V-network is trained.
            redq_subset_size: Number of ensemble members sampled per target.

        Raises:
            AssertionError: On incompatible flag combinations.
        """
        self.device = torch.device(device)

        # NOTE: construction order (VAE, actor, critics) is kept so seeded
        # initialisation stays reproducible.
        self.actor_vae = ActorVAE(
            state_dim,
            action_dim,
            latent_dim,
            max_latent_action,
            self.device,
            deterministic_actions,
            std_architecture,
            log_sig_max,
            log_sig_min,
            constant_std_init,
        ).to(self.device)
        self.actor_vae_target = copy.deepcopy(self.actor_vae)
        self.actorvae_optimizer = torch.optim.Adam(
            self.actor_vae.parameters(), lr=vae_lr
        )

        self.actor = Actor(state_dim, latent_dim, max_latent_action, self.device).to(
            self.device
        )
        self.actor_target = copy.deepcopy(self.actor)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)

        train_v_func = not not_train_v_func
        assert not (
            no_piz and not_train_v_func
        ), "no_piz requires a trained V-function"
        assert not (
            not_train_v_func and not no_noise
        ), "target-policy noise requires a trained V-function"
        assert (
            redq_subset_size <= num_q_functions
        ), "redq_subset_size cannot exceed num_q_functions"
        self._num_q_functions = num_q_functions
        self._redq_subset_size = redq_subset_size
        self._train_v_func = train_v_func

        self.critic_ensemble = CriticEnsemble(
            state_dim,
            action_dim,
            self.device,
            num_models=num_q_functions,
            train_v_func=train_v_func,
        )
        self.critic_target = copy.deepcopy(self.critic_ensemble)
        # Rebuild the target's stacked weights so they are independent
        # tensors rather than deep copies that still require grad.
        (
            self.critic_target.combined_model,
            self.critic_target.combined_params,
            self.critic_target.combined_buffers,
        ) = combine_state_for_ensemble(self.critic_target.models)

        # Fix: the optimizer used to receive only the stacked Q weights, so
        # the V-network (whose loss is added to the critic loss) was never
        # updated. Its layers are appended when a V-function is trained.
        critic_parameters = list(self.critic_ensemble.combined_params)
        if train_v_func:
            for layer in self.critic_ensemble._module_layers.values():
                critic_parameters.extend(layer.parameters())
        self.critic_optimizer = torch.optim.Adam(
            critic_parameters,
            lr=critic_lr,
        )

        self.critic_lr = critic_lr
        self.latent_dim = latent_dim
        self.max_action = max_action
        self.max_latent_action = max_latent_action
        self.action_dim = action_dim
        self.discount = discount
        self.tau = tau
        self.tau_vae = tau

        self.expectile = expectile
        self.kl_beta = kl_beta
        self.no_piz = no_piz
        self.no_noise = no_noise
        self.doubleq_min = doubleq_min

        self.replay_buffer = replay_buffer
        self.min_v, self.max_v = min_v, max_v
        self.gradient_clipping = gradient_clipping

    def select_action(
        self, state, deterministic: bool = True, num_actions_to_sample: int = 1
    ) -> Tuple[np.ndarray, Dict]:
        """Choose an action for a single environment state.

        Args:
            state: One observation; it is normalised by the replay buffer and
                reshaped to a batch of one (assumed to be a NumPy array —
                confirm against the caller).
            deterministic: With a stochastic decoder, return the distribution
                mean when True and a sample when False.
            num_actions_to_sample: With a stochastic decoder and
                ``deterministic=False``, draw this many candidates and return
                the one with the highest mean target-Q value.

        Returns:
            ``(action, info)`` where ``action`` is un-normalised and clipped
            to ``[-1, 1]`` and ``info`` holds distribution statistics when a
            stochastic decoder is used with the latent policy.
        """
        info = {}
        with torch.no_grad():
            state = self.replay_buffer.normalize_state(state)
            state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)

            if self.no_piz:
                decoded = self.actor_vae.decode(state)
                if not self.actor_vae.deterministic:
                    # Fix: decode() returns a Normal for stochastic decoders,
                    # which has no .cpu(); reduce it to a tensor first.
                    decoded = decoded.mean if deterministic else decoded.rsample()
                action = decoded.cpu().data.numpy().flatten()
            else:
                latent_a = self.actor(state)
                if self.actor_vae.deterministic:
                    assert num_actions_to_sample == 1
                    action = (
                        self.actor_vae.decode(state, z=latent_a)
                        .cpu()
                        .data.numpy()
                        .flatten()
                    )
                else:
                    action_distribution = self.actor_vae.decode(state, z=latent_a)
                    info[
                        "action_distribution_entropy"
                    ] = action_distribution.entropy().mean()
                    info["action_std_dev"] = (
                        action_distribution.stddev.mean().cpu().data.numpy()
                    )
                    if deterministic:
                        assert num_actions_to_sample == 1
                        action = action_distribution.mean.cpu().data.numpy().flatten()
                    elif num_actions_to_sample <= 1:
                        action = (
                            action_distribution.rsample().cpu().data.numpy().flatten()
                        )
                    else:
                        # Best-of-N: score candidates with the target critics
                        # (averaged over the ensemble) and keep the argmax.
                        action_samples = action_distribution.rsample(
                            [num_actions_to_sample]
                        ).view(num_actions_to_sample, self.action_dim)
                        action_q_values = torch.mean(
                            self.critic_target(
                                state.repeat(num_actions_to_sample, 1),
                                action_samples,
                            ),
                            dim=0,
                        ).view(num_actions_to_sample)
                        action_index = torch.argmax(action_q_values)
                        action = (
                            action_samples[action_index].cpu().data.numpy().flatten()
                        )

            action = self.replay_buffer.unnormalize_action(action)
            action = np.clip(action, -1, 1)

        return action, info

    def kl_loss(self, mu, log_var):
        KL_loss = -0.5 * torch.sum(1 + log_var - mu**2 - log_var.exp(), dim=1).view(
            -1, 1
        )
        return KL_loss

    def train(
        self,
        iterations,
        batch_size=100,
        kl_beta_list=None,
        critic_training_iterations: int = 1,
    ):
        """Run ``iterations`` training steps of critics, VAE and latent policy.

        Each step (1) trains the critics via ``train_critic``, (2) fits the
        conditional VAE with advantage-dependent expectile weights, (3) when
        ``no_piz`` is False, updates the latent policy to maximise the first
        Q-network, and (4) soft-updates the actor and VAE targets.

        Args:
            iterations: Number of outer training steps.
            batch_size: Batch size for every replay-buffer sample.
            kl_beta_list: Accepted for interface compatibility; unused.
            critic_training_iterations: Critic updates per outer step.
        """
        for _ in range(iterations):
            # Same batch for every ensemble member, as in REDQ.
            self.train_critic(
                iterations=critic_training_iterations, batch_size=batch_size
            )
            state, action, next_state, reward, not_done = self.replay_buffer.sample(
                batch_size
            )

            # Everything in this block only feeds the detached sample
            # weights, so no autograd graph is needed.
            with torch.no_grad():
                if self.no_piz:
                    actor_action = self.actor_vae_target.decode(state)
                    if not self.actor_vae.deterministic:
                        actor_action = actor_action.rsample()
                else:
                    latent_action = self.actor_target(state)
                    actor_action = self.actor_vae_target.decode(state, z=latent_action)
                    if not self.actor_vae.deterministic:
                        actor_action = actor_action.mean
                    if not self.no_noise:
                        actor_action += (torch.randn_like(actor_action) * 0.1).clamp(
                            -0.3, 0.3
                        )

                # Advantage of the dataset action over the current policy.
                if self._train_v_func:
                    current_v = self.critic_target.v(state)
                    q_action = reward + not_done * self.discount * self.critic_target.v(
                        next_state
                    )
                else:
                    current_v = torch.mean(
                        self.critic_target(state, actor_action), dim=0
                    )
                    q_action = torch.mean(self.critic_target(state, action), dim=0)

                adv = q_action - current_v

                # Expectile-style weights: `expectile` for positive advantage,
                # `1 - expectile` otherwise.
                w_sign = (adv > 0).float()
                weights = (1 - w_sign) * (1 - self.expectile) + w_sign * self.expectile

            # Train the weighted conditional VAE.
            recons_action, z_sample, mu, log_var = self.actor_vae(state, action)

            if self.actor_vae.deterministic:
                recons_loss_ori = F.mse_loss(recons_action, action, reduction="none")
                recon_loss = torch.sum(recons_loss_ori, 1).view(-1, 1)
                std_loss = 0
            elif self.actor_vae.std_architecture != "linear_layer_stop_grad":
                # Negative log-likelihood trains mean and std jointly.
                recon_loss = -recons_action.log_prob(action)
                recon_loss = torch.sum(recon_loss, dim=-1).view(-1, 1)
                std_loss = 0
            else:
                # Mean is fit with MSE; the std head gets a separate NLL term
                # computed against a detached mean.
                recons_loss_ori = F.mse_loss(
                    recons_action.mean, action, reduction="none"
                )
                recon_loss = torch.sum(recons_loss_ori, 1).view(-1, 1)
                action_distribution = Normal(
                    recons_action.mean.detach(), recons_action.stddev
                )
                std_loss = -action_distribution.log_prob(action).sum(dim=-1).view(-1, 1)

            KL_loss = self.kl_loss(mu, log_var)
            recon_loss = recon_loss * weights.detach()
            KL_loss = (KL_loss * self.kl_beta) * weights.detach()
            # When std_loss is the scalar 0 this yields a zero tensor, so
            # the .mean() below is still valid.
            std_loss = std_loss * weights.detach()
            logger.record_tabular(
                "reconstruction_loss", recon_loss.mean().detach().cpu().data.numpy()
            )
            logger.record_tabular("KL_loss", KL_loss.mean().detach().cpu().data.numpy())
            logger.record_tabular(
                "std_loss", std_loss.mean().detach().cpu().data.numpy()
            )
            actor_vae_loss = (recon_loss + KL_loss + std_loss).mean()

            self.actorvae_optimizer.zero_grad()
            actor_vae_loss.backward()

            # Log the global gradient norm (parameters without a gradient
            # are skipped instead of raising).
            actor_vae_sq_norm = sum(
                p.grad.detach().data.norm(2).item() ** 2
                for p in self.actor_vae.parameters()
                if p.grad is not None
            )
            logger.record_tabular("actor_vae_grad_norm", actor_vae_sq_norm**0.5)

            # Fix: gradient_clipping defaults to None, which
            # clip_grad_norm_ cannot accept as max_norm.
            if self.gradient_clipping is not None:
                torch.nn.utils.clip_grad_norm_(
                    self.actor_vae.parameters(), self.gradient_clipping
                )
            self.actorvae_optimizer.step()

            if not self.no_piz:
                # Train the latent policy through the (updated) decoder.
                latent_actor_action = self.actor(state)
                actor_action = self.actor_vae.decode(state, z=latent_actor_action)
                if not self.actor_vae.deterministic:
                    actor_action = actor_action.rsample()
                actor_action += (torch.randn_like(actor_action) * 0.02).clamp(
                    -0.05, 0.05
                )

                # Only the first ensemble member drives the policy update.
                q_pi = self.critic_ensemble(state, actor_action)[0]

                actor_loss = -q_pi.mean()
                self.actor_optimizer.zero_grad()
                actor_loss.backward()

                actor_sq_norm = sum(
                    p.grad.detach().data.norm(2).item() ** 2
                    for p in self.actor.parameters()
                    if p.grad is not None
                )
                logger.record_tabular("actor_grad_norm", actor_sq_norm**0.5)

                if self.gradient_clipping is not None:
                    torch.nn.utils.clip_grad_norm_(
                        self.actor.parameters(), self.gradient_clipping
                    )
                self.actor_optimizer.step()

            # Soft-update the target networks.
            for param, target_param in zip(
                self.actor.parameters(), self.actor_target.parameters()
            ):
                target_param.data.copy_(
                    self.tau * param.data + (1 - self.tau) * target_param.data
                )
            for param, target_param in zip(
                self.actor_vae.parameters(), self.actor_vae_target.parameters()
            ):
                target_param.data.copy_(
                    self.tau_vae * param.data + (1 - self.tau_vae) * target_param.data
                )

    def train_critic(self, iterations: int, batch_size: int):
        """Run ``iterations`` updates of the Q-ensemble (and V-network).

        Every Q-network regresses onto the same clamped target
        ``reward + not_done * discount * next_value``. ``next_value`` is the
        target V-network's output when a V-function is trained; otherwise it
        is a min/max mix (weighted by ``doubleq_min``) over a random subset
        of ``redq_subset_size`` target Q-networks at the target policy's
        next action.

        Args:
            iterations: Number of critic updates.
            batch_size: Replay-buffer batch size for each update.
        """
        for _ in range(iterations):
            state, action, next_state, reward, not_done = self.replay_buffer.sample(
                batch_size
            )

            with torch.no_grad():
                # Target-policy action at the current state (used for the
                # V-function regression target below).
                if self.no_piz:
                    actor_action = self.actor_vae_target.decode(state)
                    if not self.actor_vae.deterministic:
                        actor_action = actor_action.rsample()
                else:
                    latent_action = self.actor_target(state)
                    actor_action = self.actor_vae_target.decode(state, z=latent_action)
                    if not self.actor_vae.deterministic:
                        actor_action = actor_action.mean
                    if not self.no_noise:
                        actor_action += (torch.randn_like(actor_action) * 0.1).clamp(
                            -0.3, 0.3
                        )

                if self._train_v_func:
                    next_target_v = self.critic_target.v(next_state)
                else:
                    # Target-policy action at the next state.
                    if self.no_piz:
                        next_actor_action = self.actor_vae_target.decode(next_state)
                        if not self.actor_vae.deterministic:
                            next_actor_action = next_actor_action.rsample()
                    else:
                        next_latent_action = self.actor_target(next_state)
                        next_actor_action = self.actor_vae_target.decode(
                            next_state, z=next_latent_action
                        )
                        if not self.actor_vae.deterministic:
                            next_actor_action = next_actor_action.mean
                    # Without a V-function, the REDQ min over a random subset
                    # of target critics is applied here.
                    model_indices = np.random.choice(
                        self._num_q_functions, self._redq_subset_size, replace=False
                    )
                    next_target_vs = self.critic_target(next_state, next_actor_action)
                    next_target_vs = next_target_vs[model_indices]
                    next_target_v = torch.min(
                        next_target_vs, dim=0
                    ).values * self.doubleq_min + torch.max(
                        next_target_vs, dim=0
                    ).values * (
                        1 - self.doubleq_min
                    )
                target_Q = reward + not_done * self.discount * next_target_v

                if self._train_v_func:
                    # Fix: take members 0 and 1 from the stacked ensemble
                    # output explicitly, so each is a per-sample tensor.
                    target_qs = self.critic_target(state, actor_action)
                    target_v1, target_v2 = target_qs[0], target_qs[1]
                    target_v = torch.min(
                        target_v1, target_v2
                    ) * self.doubleq_min + torch.max(target_v1, target_v2) * (
                        1 - self.doubleq_min
                    )

            current_Qs = self.critic_ensemble(state, action)
            if self._train_v_func:
                current_v = self.critic_ensemble.v(state)
                v_loss = F.mse_loss(current_v, target_v.clamp(self.min_v, self.max_v))
            else:
                v_loss = 0

            clamped_target_Q = target_Q.clamp(self.min_v, self.max_v)
            q_losses = [
                F.mse_loss(current_Qs[model_index], clamped_target_Q)
                for model_index in range(self.critic_ensemble._num_models)
            ]
            # Only the first five members are logged to keep the table small.
            for q_index, q_loss in enumerate(q_losses[:5]):
                logger.record_tabular(
                    f"q_loss_{q_index}", q_loss.detach().cpu().data.numpy()
                )
            if self._train_v_func:
                logger.record_tabular("v_loss", v_loss.detach().cpu().data.numpy())
            critic_loss = sum(q_losses) + v_loss

            self.critic_optimizer.zero_grad()
            critic_loss.backward()

            # Log the gradient norm of the stacked Q weights.
            critic_sq_norm = sum(
                p.grad.detach().data.norm(2).item() ** 2
                for p in self.critic_ensemble.combined_params
                if p.grad is not None
            )
            logger.record_tabular("critic_grad_norm", critic_sq_norm**0.5)

            # Fix: clipping used critic_ensemble.parameters(), which does not
            # contain the stacked Q weights actually being trained, and it
            # crashed when gradient_clipping was None (the default).
            if self.gradient_clipping is not None:
                clipped_params = list(self.critic_ensemble.combined_params)
                clipped_params.extend(self.critic_ensemble.parameters())
                torch.nn.utils.clip_grad_norm_(clipped_params, self.gradient_clipping)
            self.critic_optimizer.step()

            # Soft-update the target Q weights.
            for param, target_param in zip(
                self.critic_ensemble.combined_params, self.critic_target.combined_params
            ):
                target_param.data.copy_(
                    self.tau * param.data + (1 - self.tau) * target_param.data
                )
            # Fix: the target V-network is read for bootstrapping, so it must
            # track the online V-network as well.
            if self._train_v_func:
                target_v_layers = self.critic_target._module_layers
                for name, layer in self.critic_ensemble._module_layers.items():
                    for param, target_param in zip(
                        layer.parameters(), target_v_layers[name].parameters()
                    ):
                        target_param.data.copy_(
                            self.tau * param.data + (1 - self.tau) * target_param.data
                        )
            logger.record_tabular(
                "target_q_norm", target_Q.norm().detach().cpu().data.numpy()
            )

    def save(self, filename, directory):
        """Write all networks and optimizers under ``directory``.

        The critic ensemble and its target are saved as whole pickled module
        objects (their stacked Q weights are plain attributes); everything
        else is saved as a ``state_dict``. File names are
        ``<directory>/<filename>_<component>.pth``.
        """
        checkpoints = (
            (self.critic_ensemble, "%s/%s_critic_ensemble.pth"),
            (self.critic_optimizer.state_dict(), "%s/%s_critic_optimizer.pth"),
            (self.critic_target, "%s/%s_critic_target.pth"),
            (self.actor.state_dict(), "%s/%s_actor.pth"),
            (self.actor_optimizer.state_dict(), "%s/%s_actor_optimizer.pth"),
            (self.actor_target.state_dict(), "%s/%s_actor_target.pth"),
            (self.actor_vae.state_dict(), "%s/%s_actor_vae.pth"),
            (self.actorvae_optimizer.state_dict(), "%s/%s_actor_vae_optimizer.pth"),
            (self.actor_vae_target.state_dict(), "%s/%s_actor_vae_target.pth"),
        )
        for payload, path_pattern in checkpoints:
            torch.save(payload, path_pattern % (directory, filename))

    def load(self, filename, directory):
        """Restore all networks and optimizers written by ``save``.

        ``save`` stores the critic ensemble and its target as whole pickled
        modules, so passing the loaded object straight to
        ``load_state_dict`` (as before) fails. The stacked Q weights are
        copied in place so existing optimizer references stay valid, and the
        registered layers are restored through the usual ``state_dict``.
        A checkpoint that already is a state dict is still accepted.

        Args:
            filename: Prefix of the checkpoint files.
            directory: Directory containing the checkpoint files.
        """

        def restore_critic(critic, path):
            # NOTE(review): recent PyTorch releases default torch.load to
            # weights_only=True, which refuses pickled modules — confirm the
            # installed version or pass weights_only=False where supported.
            checkpoint = torch.load(path)
            if isinstance(checkpoint, nn.Module):
                with torch.no_grad():
                    for live, stored in zip(
                        critic.combined_params, checkpoint.combined_params
                    ):
                        live.copy_(stored)
                    for live, stored in zip(
                        critic.combined_buffers, checkpoint.combined_buffers
                    ):
                        live.copy_(stored)
                checkpoint = checkpoint.state_dict()
            critic.load_state_dict(checkpoint)

        restore_critic(
            self.critic_ensemble,
            "%s/%s_critic_ensemble.pth" % (directory, filename),
        )
        self.critic_optimizer.load_state_dict(
            torch.load("%s/%s_critic_optimizer.pth" % (directory, filename))
        )
        restore_critic(
            self.critic_target,
            "%s/%s_critic_target.pth" % (directory, filename),
        )

        self.actor.load_state_dict(
            torch.load("%s/%s_actor.pth" % (directory, filename))
        )
        self.actor_optimizer.load_state_dict(
            torch.load("%s/%s_actor_optimizer.pth" % (directory, filename))
        )
        self.actor_target.load_state_dict(
            torch.load("%s/%s_actor_target.pth" % (directory, filename))
        )

        self.actor_vae.load_state_dict(
            torch.load("%s/%s_actor_vae.pth" % (directory, filename))
        )
        self.actorvae_optimizer.load_state_dict(
            torch.load("%s/%s_actor_vae_optimizer.pth" % (directory, filename))
        )
        self.actor_vae_target.load_state_dict(
            torch.load("%s/%s_actor_vae_target.pth" % (directory, filename))
        )
