import copy
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from torch.distributions import Normal, TanhTransform, TransformedDistribution
from scipy import stats
from copy import deepcopy
import wandb


# Module-level constants. EXP_ADV_MAX / LOG_STD_MIN / LOG_STD_MAX are not read
# anywhere in this part of the file; their meanings below are inferred from names.
EXP_ADV_MAX = 1e2  # NOTE(review): presumably an upper clip for exp(advantage) weights
LOG_STD_MIN = -2e1  # presumably a lower clamp for a policy log standard deviation
LOG_STD_MAX = 2e0  # presumably an upper clamp for a policy log standard deviation
TensorBatch = List[torch.Tensor]  # list-of-tensors batch type (see ReplayBuffer.sample)


def asymmetric_l2_loss(u: torch.Tensor, tau: float) -> torch.Tensor:
    """Return the expectile-style loss ``mean(|tau - 1[u < 0]| * u**2)``.

    Positive residuals are weighted by ``|tau|`` and negative ones by
    ``|tau - 1|``.
    """
    is_negative = (u < 0).float()
    weight = torch.abs(tau - is_negative)
    return (weight * u.pow(2)).mean()


def soft_update(target: nn.Module, source: nn.Module, tau: float):
    """Polyak-average ``source`` into ``target`` in place.

    Each target parameter becomes ``(1 - tau) * target + tau * source``.
    Parameters are paired positionally via ``parameters()``.
    """
    param_pairs = zip(target.parameters(), source.parameters())
    for dst, src in param_pairs:
        blended = (1 - tau) * dst.data + tau * src.data
        dst.data.copy_(blended)


def extend_and_repeat(tensor: torch.Tensor, dim: int, repeat: int) -> torch.Tensor:
    """Insert a new axis at ``dim`` and repeat ``tensor`` ``repeat`` times along it."""
    expanded = torch.unsqueeze(tensor, dim)
    return torch.repeat_interleave(expanded, repeat, dim=dim)


def init_module_weights(module: torch.nn.Module, orthogonal_init: bool = False):
    """Re-initialise ``module`` in place if it is an ``nn.Linear``; otherwise do nothing.

    Suitable for ``nn.Module.apply``.

    Args:
        module: Module to (possibly) initialise.
        orthogonal_init: If true, draw an orthogonal weight matrix with gain
            ``sqrt(2)`` and zero the bias (when the layer has one). Otherwise use
            Xavier-uniform weights with a small gain of ``1e-2`` and leave the
            bias untouched.
    """
    if not isinstance(module, nn.Linear):
        return
    if orthogonal_init:
        nn.init.orthogonal_(module.weight, gain=np.sqrt(2))
        # ``nn.Linear(..., bias=False)`` stores ``bias = None``; skip it rather than crash.
        if module.bias is not None:
            nn.init.constant_(module.bias, 0.0)
    else:
        nn.init.xavier_uniform_(module.weight, gain=1e-2)


class Scalar(nn.Module):
    """A single learnable float32 scalar; calling the module returns the parameter."""

    def __init__(self, init_value: float):
        super().__init__()
        initial = torch.tensor(init_value, dtype=torch.float32)
        self.constant = nn.Parameter(initial)

    def forward(self) -> nn.Parameter:
        """Return the wrapped parameter itself (not a copy)."""
        value = self.constant
        return value


class ReplayBuffer:
    """Fixed-capacity transition buffer backed by preallocated float32 tensors.

    Stores ``(state, action, reward, next_state, done)`` transitions on
    ``device``. When ``incre`` is true it also allocates a ``(buffer_size, 3)``
    weight table, initialised to ones. ``add_transition`` writes to it and the
    filtering helpers read from it.

    Args:
        state_dim: Length of one observation vector.
        action_dim: Length of one action vector.
        buffer_size: Maximum number of transitions held.
        device: Torch device for all storage tensors.
        incre: Whether to allocate the per-transition weight table.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        buffer_size: int,
        device: str = "cpu",
        incre: bool = False,
    ):
        self._buffer_size = buffer_size
        self._pointer = 0  # index of the next slot add_transition writes to
        self._size = 0  # number of valid transitions currently stored
        self.init_size = 2000  # ``initialize`` collects until at least this many
        self.incre = incre

        self._states = torch.zeros(
            (buffer_size, state_dim), dtype=torch.float32, device=device
        )
        self._actions = torch.zeros(
            (buffer_size, action_dim), dtype=torch.float32, device=device
        )
        self._rewards = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
        self._next_states = torch.zeros(
            (buffer_size, state_dim), dtype=torch.float32, device=device
        )
        self._dones = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
        if self.incre:
            # The filters test column 0 against zero; filter_resolved_samples
            # sorts on column 2.
            self._weights = torch.ones((buffer_size, 3), dtype=torch.float32, device=device)
        self._device = device
        self.eps = 1e-3  # kept for compatibility; no longer read by weighted_random_sample
        # filter_in_distribution inspects rows
        # [ood_pointer * ood_size, (ood_pointer + 1) * ood_size).
        self.ood_pointer = 0
        self.ood_size = 10000
        self._update_size = 500  # max rows returned by filter_resolved_samples
        self.l_pointer = 0  # left edge of the window checked by the interval helpers
        self.update_interval = 10000

    def _to_tensor(self, data) -> torch.Tensor:
        """Convert array-like / scalar / tensor ``data`` to a float32 tensor on the buffer device.

        Tensors are detached and moved rather than re-wrapped with ``torch.tensor``,
        which would copy them and emit a warning.
        """
        if isinstance(data, torch.Tensor):
            return data.detach().to(device=self._device, dtype=torch.float32)
        return torch.tensor(data, dtype=torch.float32).to(self._device)

    # Loads data in d4rl format, i.e. from Dict[str, np.array].
    def load_d4rl_dataset(self, data: Dict[str, np.ndarray]):
        """Fill an empty buffer from a D4RL-style dict starting at index 0.

        Raises:
            ValueError: if the buffer is not empty or is too small for the data.
        """
        if self._size != 0:
            raise ValueError("Trying to load data into non-empty replay buffer")
        n_transitions = data["observations"].shape[0]
        self.n_transitions = n_transitions
        if n_transitions > self._buffer_size:
            raise ValueError(
                "Replay buffer is smaller than the dataset you are trying to load!"
            )
        self._states[:n_transitions] = self._to_tensor(data["observations"])
        self._actions[:n_transitions] = self._to_tensor(data["actions"])
        self._rewards[:n_transitions] = self._to_tensor(data["rewards"][..., None])
        self._next_states[:n_transitions] = self._to_tensor(data["next_observations"])
        self._dones[:n_transitions] = self._to_tensor(data["terminals"][..., None])
        self._size += n_transitions
        self._pointer = min(self._size, n_transitions)
        self.l_pointer = self._pointer

        print(f"Dataset size: {n_transitions}")

    def add_d4rl_dataset(self, data: Dict[str, np.ndarray]):
        """Append a D4RL-style dict at the current write pointer.

        Sets ``self.length`` to the end index of the appended region.

        Raises:
            ValueError: if the data would run past the end of the buffer.
        """
        print(self._pointer)
        n_transitions = data["observations"].shape[0]
        start = self._pointer
        end = start + n_transitions
        if end > self._buffer_size:
            raise ValueError(
                "Replay buffer is smaller than the dataset you are trying to load!"
            )
        self._states[start:end] = self._to_tensor(data["observations"])
        self._actions[start:end] = self._to_tensor(data["actions"])
        self._rewards[start:end] = self._to_tensor(data["rewards"][..., None])
        self._next_states[start:end] = self._to_tensor(data["next_observations"])
        self._dones[start:end] = self._to_tensor(data["terminals"][..., None])

        self._size += n_transitions
        self._pointer = min(self._size, end)
        self.l_pointer = self._pointer
        # End index of the loaded region. Use ``end``: adding ``n_transitions`` to the
        # already-advanced pointer would count the dataset twice.
        self.length = end
        print(f"Dataset size: {n_transitions}")
        print(self._pointer)
        print("Added success.\n")

    def is_update_intervals(self):
        """Return True when exactly ``update_interval`` slots lie between
        ``l_pointer`` and ``_pointer``, either directly or across one wrap-around."""
        advanced = self._pointer - self.l_pointer
        return (
            advanced == self.update_interval
            or advanced + self._buffer_size == self.update_interval
        )

    def filter_resolved_samples(self):
        """Collect up to ``_update_size`` rows since ``l_pointer`` whose weight column 0 is zero.

        The rows are taken in ascending order of weight column 2. Requires
        ``incre=True``.

        Returns:
            ``[states, actions, rewards, next_states, dones]``, each ``squeeze()``-d.

        Raises:
            ValueError: if the pointer is not exactly ``update_interval`` ahead of
                ``l_pointer`` (with at most one wrap-around).
        """
        print("Buffer pointer:", self._pointer, "Left pointer:", self.l_pointer)
        if self._pointer + self._buffer_size - self.l_pointer == self.update_interval:
            print("Enrolling the buffer once.")
            ood_indices_1 = (self._weights[self.l_pointer:self._buffer_size, 0] == 0).nonzero() + self.l_pointer
            ood_indices_2 = (self._weights[0:self._pointer, 0] == 0).nonzero()
            print(ood_indices_1.shape, ood_indices_2.shape)
            ood_indices = torch.cat([ood_indices_1, ood_indices_2]).squeeze()
        elif self._pointer - self.l_pointer == self.update_interval:
            ood_indices = (self._weights[self.l_pointer:self._pointer, 0] == 0).nonzero().squeeze() + self.l_pointer
        else:
            raise ValueError(
                "The buffer size is too small and "
                "the pointer enrolls the whole buffer more than one time."
            )

        print(ood_indices.shape)
        states = self._states[ood_indices, :].squeeze()
        actions = self._actions[ood_indices, :].squeeze()
        rewards = self._rewards[ood_indices, :]
        next_states = self._next_states[ood_indices, :].squeeze()
        dones = self._dones[ood_indices, :]
        weights = self._weights[ood_indices, :].squeeze()
        if len(weights.shape) == 1:
            # A single match: restore the leading batch axis lost to squeeze().
            weights = weights.unsqueeze(0)
            states = states.unsqueeze(0)
            actions = actions.unsqueeze(0)
            next_states = next_states.unsqueeze(0)
            rewards = rewards.unsqueeze(0)
            dones = dones.unsqueeze(0)
        resolved_indices = torch.argsort(weights[:, 2])[:self._update_size]
        print(resolved_indices.shape)
        print('rewards and dones', rewards.shape, dones.shape)
        states = states[resolved_indices, :]
        actions = actions[resolved_indices, :]
        rewards = rewards[resolved_indices, :]
        next_states = next_states[resolved_indices, :]
        dones = dones[resolved_indices, :]

        return [states.squeeze(), actions.squeeze(), rewards.squeeze(), next_states.squeeze(), dones.squeeze()]

    def update_left_pointer(self):
        """Move ``l_pointer`` up to the current write pointer."""
        self.l_pointer = self._pointer

    def filter_in_distribution(self):
        """Return rows of the current OOD window whose weight column 0 is non-zero.

        The window is ``[ood_pointer * ood_size, (ood_pointer + 1) * ood_size)``.
        Requires ``incre=True``.

        Returns:
            ``[states, actions, next_states]``, each ``squeeze()``-d.
        """
        start = self.ood_pointer * self.ood_size
        end = start + self.ood_size
        # nonzero() yields positions relative to ``start``; shift back to buffer indices.
        in_indices = (self._weights[start:end, 0] != 0).nonzero(as_tuple=True)[0] + start
        states = self._states[in_indices].squeeze()
        actions = self._actions[in_indices].squeeze()
        next_states = self._next_states[in_indices].squeeze()

        return [states, actions, next_states]

    def weighted_random_sample(self, weights, batch_size):
        """Draw ``batch_size`` indices in ``[0, _size)`` with replacement, proportional to ``weights``.

        Weights are normalised to sum exactly to one, which ``np.random.choice``
        requires. If they sum to zero or less, sampling falls back to uniform.
        """
        this_weights = (
            weights.detach().cpu().squeeze().numpy()[:self._size].astype(np.float64)
        )
        total = this_weights.sum()
        probs = this_weights / total if total > 0 else None

        return np.random.choice(
            np.arange(0, self._size), size=batch_size, replace=True, p=probs
        )

    def sample(self, batch_size: int) -> List[torch.Tensor]:
        """Uniformly sample ``batch_size`` stored transitions with replacement.

        Returns:
            ``[states, actions, rewards, next_states, dones]``.
        """
        indices = np.random.randint(0, self._size, size=batch_size)
        states = self._states[indices]
        actions = self._actions[indices]
        rewards = self._rewards[indices]
        next_states = self._next_states[indices]
        dones = self._dones[indices]
        return [states, actions, rewards, next_states, dones]

    def add_transition(
        self,
        state: np.ndarray,
        action: np.ndarray,
        reward: float,
        next_state: np.ndarray,
        done: bool,
        weight=None,
    ):
        """Write one transition at the pointer, then advance it as a ring buffer.

        Intended for adding new data during fine-tuning.

        Args:
            weight: Row for the weight table (array-like or tensor). It is only
                stored when the buffer was built with ``incre=True``. If it is
                None, the row is reset to ones (the table's initial value).
        """
        slot = self._pointer
        self._states[slot] = self._to_tensor(state)
        self._actions[slot] = self._to_tensor(action)
        self._rewards[slot] = self._to_tensor(reward)
        self._next_states[slot] = self._to_tensor(next_state)
        self._dones[slot] = self._to_tensor(done)
        if self.incre:
            if weight is None:
                self._weights[slot] = 1.0
            else:
                self._weights[slot] = self._to_tensor(weight)

        self._pointer = (slot + 1) % self._buffer_size
        self._size = min(self._size + 1, self._buffer_size)

    def initialize(self, env, trainer, config, max_action, max_steps):
        """Roll out ``trainer.actor`` in ``env`` until at least ``init_size`` transitions are stored.

        Whole episodes are collected. The stored/weighted done flag is the *real*
        termination: an episode that ends at or after ``max_steps`` steps is
        treated as a timeout, not a terminal.
        ``max_action`` is accepted but unused.
        """
        state, done = env.reset(), False
        while self._size < self.init_size:
            episode_step = 0
            while not done:
                episode_step += 1
                with torch.no_grad():
                    action, _ = trainer.actor(
                        torch.tensor(
                            state.reshape(1, -1), device=config.device, dtype=torch.float32
                        )
                    )
                action = action.cpu().data.numpy().flatten()
                next_state, reward, done, env_infos = env.step(action)
                # Episode can timeout which is different from done
                real_done = bool(done) and episode_step < max_steps
                weights = trainer.get_weights(self._to_tensor(state).unsqueeze(0), self._to_tensor(action).unsqueeze(0),
                                              reward, self._to_tensor(next_state).unsqueeze(0),
                                              self._to_tensor(np.array(real_done)).unsqueeze(0))
                self.add_transition(state, action, reward, next_state, real_done, weights)
                state = next_state
            if done:
                state, done = env.reset(), False

    def __len__(self):
        """Number of transitions loaded by ``load_d4rl_dataset`` (not the current size)."""
        return self.n_transitions

    def __getitem__(self, indices):
        """Index all storage tensors with ``indices``; returns ``[s, a, r, s', d]``."""
        states = self._states[indices]
        actions = self._actions[indices]
        rewards = self._rewards[indices]
        next_states = self._next_states[indices]
        dones = self._dones[indices]
        return [states, actions, rewards, next_states, dones]


class VAE(nn.Module):
    """Variational auto-encoder conditioned on (observation, action).

    The encoder maps ``cat([obs, act])`` to a diagonal Gaussian over a latent
    ``z``. The decoder maps ``cat([obs, act, z])`` to an ``obs_dim``-sized
    output; presumably a next-observation reconstruction (training code is not
    shown here). ``next_obs`` is accepted by ``forward``/``encoder``/``decode``
    but none of them reads it.

    Args:
        obs_dim (int): The dimension of the observation space.
        act_dim (int): The dimension of the action space.
        hidden_size (int): Hidden units in each encoder/decoder layer.
        latent_dim (int): The dimensionality of the latent space.
        act_lim (float): Stored as ``self.act_lim``; not read by the methods below.
        obs_lim: Stored as ``self.obs_lim``; not read by the methods below.
            Defaults to None.
        device (str): Device the layers are created on and that ``decode``
            moves sampled latents to.
    """

    def __init__(self, obs_dim, act_dim, hidden_size, latent_dim, act_lim, obs_lim=None, device='cuda'):
        super(VAE, self).__init__()
        self.device = device
        self.latent_dim = latent_dim
        self.e1 = nn.Linear(obs_dim + act_dim, hidden_size).to(self.device)
        self.e2 = nn.Linear(hidden_size, hidden_size).to(self.device)

        self.mean = nn.Linear(hidden_size, latent_dim).to(self.device)
        self.log_std = nn.Linear(hidden_size, latent_dim).to(self.device)

        self.d1 = nn.Linear(obs_dim + act_dim + latent_dim, hidden_size).to(self.device)
        self.d2 = nn.Linear(hidden_size, hidden_size).to(self.device)
        self.d3 = nn.Linear(hidden_size, obs_dim).to(self.device)

        self.act_lim = act_lim
        self.obs_lim = obs_lim

    def forward(self, obs, act, next_obs):
        """Encode, sample ``z`` and decode. Returns ``(reconstruction, mean, std)``."""
        z, mean, std = self.encoder(obs, act, next_obs)
        u = self.decode(obs, act, next_obs, z)
        return u, mean, std

    def encoder(self, obs, act, next_obs):
        """Return ``(z, mean, std)``, where ``z = mean + std * eps`` is a reparameterised sample."""
        z = F.relu(self.e1(torch.cat([obs, act], 1)))
        z = F.relu(self.e2(z))

        mean = self.mean(z)
        # clamp for numerical stability
        log_std = self.log_std(z).clamp(-4, 15)
        std = torch.exp(log_std)
        z = mean + std * torch.randn_like(std)

        return z, mean, std

    def decode(self, obs, act, next_obs, z=None):
        """Decode ``cat([obs, act, z])``.

        When ``z`` is None, a standard-normal latent clipped to [-0.5, 0.5] is
        sampled instead.
        """
        if z is None:
            z = torch.randn((obs.shape[0], self.latent_dim)).clamp(-0.5, 0.5).to(self.device)
        s = F.relu(self.d1(torch.cat([obs, act, z], 1)))
        s = F.relu(self.d2(s))

        return self.d3(s)

    # for BEARL only
    def decode_multiple(self, obs, z=None, num_decode=10):
        """Decode ``num_decode`` latents per observation.

        Returns ``(tanh(out), out)``, each of shape ``(batch, num_decode, obs_dim)``.

        Note: ``d1`` expects ``obs_dim + act_dim + latent_dim`` inputs but only
        ``obs`` and ``z`` are concatenated here. The shapes therefore only line
        up when ``obs.shape[-1] == obs_dim + act_dim``.
        """
        if z is None:
            z = torch.randn(
                (obs.shape[0], num_decode, self.latent_dim)).clamp(-0.5,
                                                                   0.5).to(self.device)

        repeated_obs = obs.unsqueeze(0).repeat(num_decode, 1, 1).permute(1, 0, 2)
        a = F.relu(self.d1(torch.cat([repeated_obs, z], 2)))
        a = F.relu(self.d2(a))
        raw = self.d3(a)  # compute the output layer once and reuse it
        return torch.tanh(raw), raw

    def test_one_step(self, observations, next_observations, actions, rewards, done):
        """Return the encoder's ``(mean, std)`` without tracking gradients.

        ``rewards`` and ``done`` are unused.
        """
        with torch.no_grad():
            _, mean, std = self.encoder(observations, actions, next_observations)
            return mean, std

    def judge_normal_distribution(self, replay_buffer):
        """Fit a normal distribution (via the ``fitter`` package) to latent means and to latent stds.

        Every transition in ``replay_buffer`` is used, and all latent dimensions
        are pooled together.

        Returns:
            ``(mean_params, std_params)`` as given by ``Fitter.fitted_param['norm']``.
        """
        data_mean_list = []
        data_std_list = []
        for i in range(len(replay_buffer)):
            batch = replay_buffer[i]
            (
                observations,
                actions,
                rewards,
                next_observations,
                dones,
            ) = batch
            observations = observations.unsqueeze(0)
            actions = actions.unsqueeze(0)
            next_observations = next_observations.unsqueeze(0)
            mean, std = self.test_one_step(observations, next_observations, actions, rewards, dones)
            # reshape(-1) keeps this iterable even when latent_dim == 1, where
            # squeeze() would yield a 0-d tensor that list.extend cannot iterate.
            data_mean_list.extend(mean.cpu().reshape(-1).tolist())
            data_std_list.extend(std.cpu().reshape(-1).tolist())
        import fitter
        mean_f = fitter.Fitter(data_mean_list, distributions=['norm'], timeout=10000)
        mean_f.fit()
        mean_dict = mean_f.fitted_param['norm']

        std_f = fitter.Fitter(data_std_list, distributions=['norm'], timeout=10000)
        std_f.fit()
        std_dict = std_f.fitted_param['norm']
        print(mean_dict, std_dict)

        return mean_dict, std_dict


class Squeeze(nn.Module):
    """Module form of ``Tensor.squeeze`` on a fixed dimension (default: last)."""

    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Drop ``self.dim`` from ``x`` if it has size 1; otherwise return ``x`` unchanged."""
        return torch.squeeze(x, self.dim)


class MLP(nn.Module):
    """Feed-forward stack: ``Linear -> activation [-> Dropout]`` per hidden layer, then a final ``Linear``.

    Args:
        dims: Layer widths ``[in, hidden..., out]``; at least two entries.
        activation_fn: Factory for each hidden-layer activation module.
        output_activation_fn: Optional factory for a module appended after the
            final Linear.
        squeeze_output: If true, append ``Squeeze(-1)``; requires ``dims[-1] == 1``.
        dropout: Dropout probability after each hidden activation (0 disables it).

    Raises:
        ValueError: if ``dims`` has fewer than two entries, or if ``squeeze_output``
            is requested while ``dims[-1] != 1``.
    """

    def __init__(
        self,
        dims,
        activation_fn: Callable[[], nn.Module] = nn.ReLU,
        output_activation_fn: Callable[[], nn.Module] = None,
        squeeze_output: bool = False,
        dropout: float = 0.0,
    ):
        super().__init__()
        if len(dims) < 2:
            raise ValueError("MLP requires at least two dims (input and output)")

        modules = []
        # Every consecutive width pair except the last forms a hidden block.
        for fan_in, fan_out in zip(dims[:-2], dims[1:-1]):
            modules.append(nn.Linear(fan_in, fan_out))
            modules.append(activation_fn())
            if dropout > 0.0:
                modules.append(nn.Dropout(dropout))
        modules.append(nn.Linear(dims[-2], dims[-1]))

        if output_activation_fn is not None:
            modules.append(output_activation_fn())
        if squeeze_output:
            if dims[-1] != 1:
                raise ValueError("Last dim must be 1 when squeezing")
            modules.append(Squeeze(-1))

        self.net = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer stack to ``x``."""
        return self.net(x)


class FullyConnectedQFunction(nn.Module):
    """Critic ``Q(s, a)``: 256-wide ReLU MLP over ``cat([s, a])`` with a scalar head.

    Args:
        observation_dim: Observation size.
        action_dim: Action size.
        orthogonal_init: Orthogonally initialise every Linear layer when true;
            otherwise only the output layer is re-initialised (small Xavier).
        n_hidden_layers: Number of 256-unit hidden layers.
    """

    def __init__(
        self,
        observation_dim: int,
        action_dim: int,
        orthogonal_init: bool = False,
        n_hidden_layers: int = 2,
    ):
        super().__init__()
        self.observation_dim = observation_dim
        self.action_dim = action_dim
        self.orthogonal_init = orthogonal_init

        width = 256
        layers = [nn.Linear(observation_dim + action_dim, width), nn.ReLU()]
        for _ in range(n_hidden_layers - 1):
            layers += [nn.Linear(width, width), nn.ReLU()]
        layers.append(nn.Linear(width, 1))
        self.network = nn.Sequential(*layers)

        if not orthogonal_init:
            init_module_weights(self.network[-1], False)
        else:
            self.network.apply(lambda m: init_module_weights(m, True))

    def forward(self, observations: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        """Return Q-values with the trailing size-1 axis removed.

        With 2-D ``observations`` (B, S) and 3-D ``actions`` (B, N, A), each
        observation is paired with its N actions and the result is (B, N).
        """
        batch_size = observations.shape[0]
        multiple_actions = actions.ndim == 3 and observations.ndim == 2
        if multiple_actions:
            n_actions = actions.shape[1]
            observations = extend_and_repeat(observations, 1, n_actions).reshape(
                -1, observations.shape[-1]
            )
            actions = actions.reshape(-1, actions.shape[-1])
        q_values = self.network(torch.cat([observations, actions], dim=-1)).squeeze(-1)
        if multiple_actions:
            q_values = q_values.reshape(batch_size, -1)
        return q_values


class ReparameterizedTanhGaussian(nn.Module):
    """Diagonal Gaussian action distribution, squashed by tanh unless ``no_tanh``.

    ``log_std`` is clamped to ``[log_std_min, log_std_max]`` before use. Log
    probabilities are summed over the last (action) dimension.
    """

    def __init__(
        self, log_std_min: float = -20.0, log_std_max: float = 2.0, no_tanh: bool = False
    ):
        super().__init__()
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max
        self.no_tanh = no_tanh

    def _distribution(self, mean: torch.Tensor, log_std: torch.Tensor):
        """Build the (optionally tanh-transformed) Normal with clamped log-std."""
        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        std = torch.exp(log_std)
        base = Normal(mean, std)
        if self.no_tanh:
            return base
        return TransformedDistribution(base, TanhTransform(cache_size=1))

    def log_prob(
        self, mean: torch.Tensor, log_std: torch.Tensor, sample: torch.Tensor
    ) -> torch.Tensor:
        """Log-density of ``sample``, summed over the last dimension."""
        action_distribution = self._distribution(mean, log_std)
        return torch.sum(action_distribution.log_prob(sample), dim=-1)

    def forward(
        self, mean: torch.Tensor, log_std: torch.Tensor, deterministic: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(action, log_prob)``.

        The action is a reparameterised sample, or the deterministic action when
        ``deterministic`` is true: ``tanh(mean)`` when squashing, ``mean``
        otherwise.
        """
        action_distribution = self._distribution(mean, log_std)

        if deterministic:
            # In no_tanh mode the distribution is an unsquashed Normal, so the
            # deterministic action must not be pushed through tanh.
            action_sample = mean if self.no_tanh else torch.tanh(mean)
        else:
            action_sample = action_distribution.rsample()

        log_prob = torch.sum(action_distribution.log_prob(action_sample), dim=-1)

        return action_sample, log_prob


class TanhGaussianPolicy(nn.Module):
    """Stochastic actor: MLP -> (mean, log_std) -> tanh-Gaussian, scaled by ``max_action``.

    The raw ``log_std`` is transformed as
    ``log_std_multiplier * raw + log_std_offset`` (both learnable ``Scalar``s)
    before it is passed to ``ReparameterizedTanhGaussian``. With ``no_tanh`` the
    Gaussian is left unsquashed.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        max_action: float,
        log_std_multiplier: float = 1.0,
        log_std_offset: float = -1.0,
        orthogonal_init: bool = False,
        no_tanh: bool = False,
    ):
        super().__init__()
        self.observation_dim = state_dim
        self.action_dim = action_dim
        self.max_action = max_action
        self.orthogonal_init = orthogonal_init
        self.no_tanh = no_tanh

        self.base_network = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 2 * action_dim),
        )

        if orthogonal_init:
            self.base_network.apply(lambda m: init_module_weights(m, True))
        else:
            init_module_weights(self.base_network[-1], False)

        self.log_std_multiplier = Scalar(log_std_multiplier)
        self.log_std_offset = Scalar(log_std_offset)
        self.tanh_gaussian = ReparameterizedTanhGaussian(no_tanh=no_tanh)

    def log_prob(
        self, observations: torch.Tensor, actions: torch.Tensor
    ) -> torch.Tensor:
        """Return ``log pi(actions | observations)``, summed over action dims.

        Args:
            observations: Observation batch. If ``actions`` is 3-D (several actions
                per observation), observations are repeated along dim 1 to match.
            actions: Actions on the scale ``forward`` returns, i.e. already
                multiplied by ``max_action``.

        The constant ``-action_dim * log(max_action)`` Jacobian term of the
        rescaling is omitted; it does not affect gradients.
        """
        if actions.ndim == 3:
            observations = extend_and_repeat(observations, 1, actions.shape[1])
        base_network_output = self.base_network(observations)
        mean, log_std = torch.split(base_network_output, self.action_dim, dim=-1)
        log_std = self.log_std_multiplier() * log_std + self.log_std_offset()
        # Score the *given* actions (not a freshly drawn sample). Undo the
        # max_action scaling that forward() applies.
        unit_actions = actions / self.max_action
        if not self.no_tanh:
            # atanh(+-1) is infinite; keep targets strictly inside (-1, 1).
            unit_actions = torch.clamp(unit_actions, -1.0 + 1e-6, 1.0 - 1e-6)
        return self.tanh_gaussian.log_prob(mean, log_std, unit_actions)

    def forward(
        self,
        observations: torch.Tensor,
        deterministic: bool = False,
        repeat: int = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(max_action * action, log_prob)``.

        If ``repeat`` is given, each observation is repeated ``repeat`` times
        along a new dim 1 first.
        """
        if repeat is not None:
            observations = extend_and_repeat(observations, 1, repeat)
        base_network_output = self.base_network(observations)
        mean, log_std = torch.split(base_network_output, self.action_dim, dim=-1)
        log_std = self.log_std_multiplier() * log_std + self.log_std_offset()
        actions, log_probs = self.tanh_gaussian(mean, log_std, deterministic)
        return self.max_action * actions, log_probs

    @torch.no_grad()
    def act(self, state: np.ndarray, device: str = "cpu"):
        """Return a flat numpy action for one state.

        The action is deterministic in eval mode and sampled in train mode.
        """
        state = torch.tensor(state.reshape(1, -1), device=device, dtype=torch.float32)
        actions, _ = self(state, not self.training)
        return actions.cpu().data.numpy().flatten()


class ContinuousCQL:
    def __init__(
        self,
        critic_1,
        critic_1_optimizer,
        critic_2,
        critic_2_optimizer,
        actor,
        actor_optimizer,
        state_dim,
        action_dim,
        max_action,
        vae_lr,
        target_entropy: float,
        discount: float = 0.99,
        alpha_multiplier: float = 1.0,
        use_automatic_entropy_tuning: bool = True,
        backup_entropy: bool = False,
        policy_lr: float = 3e-4,
        qf_lr: float = 3e-4,
        soft_target_update_rate: float = 5e-3,
        bc_steps=100000,
        target_update_period: int = 1,
        cql_n_actions: int = 10,
        cql_importance_sample: bool = True,
        cql_lagrange: bool = False,
        cql_target_action_gap: float = -1.0,
        cql_temp: float = 1.0,
        cql_alpha: float = 5.0,
        cql_max_target_backup: bool = False,
        cql_clip_diff_min: float = -np.inf,
        cql_clip_diff_max: float = np.inf,
        device: str = "cpu",
        vae_hidden_dim: int = 400,
        latent_dim: int = 1,
        normalize_mean: float = 0,
        normalize_std: float = 1.0
    ):
        """Store hyper-parameters and build target critics, the VAE and the temperature optimizers.

        The given critics, actor and optimizers are used as-is. Two deep copies
        of each critic are made: ``target_critic_*``, soft-updated by
        ``update_target_network``, and ``off_target_critic_*``, read under
        ``no_grad`` in ``_q_loss`` (no update of them is visible in this file
        section).

        Note:
            ``latent_dim`` is accepted but ignored; the VAE latent size is
            ``state_dim + action_dim``.
        """
        super().__init__()

        self.discount = discount
        self.target_entropy = target_entropy
        self.alpha_multiplier = alpha_multiplier
        self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
        self.backup_entropy = backup_entropy
        self.policy_lr = policy_lr
        self.qf_lr = qf_lr
        self.soft_target_update_rate = soft_target_update_rate
        self.bc_steps = bc_steps
        self.target_update_period = target_update_period
        self.cql_n_actions = cql_n_actions
        self.cql_importance_sample = cql_importance_sample
        self.cql_lagrange = cql_lagrange
        self.cql_target_action_gap = cql_target_action_gap
        self.cql_temp = cql_temp
        self.cql_alpha = cql_alpha
        self.cql_max_target_backup = cql_max_target_backup
        self.cql_clip_diff_min = cql_clip_diff_min
        self.cql_clip_diff_max = cql_clip_diff_max
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_action = max_action
        self.vae_hidden_dim = vae_hidden_dim
        # NOTE(review): the ``latent_dim`` argument is ignored here.
        self.latent_dim = state_dim + action_dim
        self._device = device
        self.normalize_mean = normalize_mean
        self.normalize_std = normalize_std
        self.total_it = 0

        self.critic_1 = critic_1
        self.critic_2 = critic_2

        self.off_target_critic_1 = deepcopy(critic_1).to(device)
        self.off_target_critic_2 = deepcopy(critic_2).to(device)

        self.target_critic_1 = deepcopy(self.critic_1).to(device)
        self.target_critic_2 = deepcopy(self.critic_2).to(device)

        self.actor = actor
        self.pm = 0.7  # get_prob zeroes probabilities below this threshold

        self.actor_optimizer = actor_optimizer
        self.critic_1_optimizer = critic_1_optimizer
        self.critic_2_optimizer = critic_2_optimizer

        # Pass ``device`` by keyword. Positionally, the 6th argument is VAE's
        # ``obs_lim``, which would leave the VAE on its default 'cuda' device.
        self.vae = VAE(
            self.state_dim,
            self.action_dim,
            self.vae_hidden_dim,
            self.latent_dim,
            self.max_action,
            None,
            device=self._device,
        ).to(self._device)
        self.vae_optim = torch.optim.Adam(self.vae.parameters(), lr=vae_lr)
        self.prob_mean, self.prob_std = (), ()
        self.offline_time = 10000
        self.count = 0
        self.vae_batch_size = 64
        self.vae_step = 0

        if self.use_automatic_entropy_tuning:
            self.log_alpha = Scalar(0.0)
            self.alpha_optimizer = torch.optim.Adam(
                self.log_alpha.parameters(),
                lr=self.policy_lr,
            )
        else:
            self.log_alpha = None

        self.log_alpha_prime = Scalar(1.0)
        self.alpha_prime_optimizer = torch.optim.Adam(
            self.log_alpha_prime.parameters(),
            lr=self.qf_lr,
        )

    def update_target_network(self, soft_target_update_rate: float):
        """Polyak-update each target critic towards its online critic at the given rate."""
        pairs = (
            (self.target_critic_1, self.critic_1),
            (self.target_critic_2, self.critic_2),
        )
        for target, online in pairs:
            soft_update(target, online, soft_target_update_rate)

    def _alpha_and_alpha_loss(self, observations: torch.Tensor, log_pi: torch.Tensor):
        if self.use_automatic_entropy_tuning:
            alpha_loss = -(
                self.log_alpha() * (log_pi + self.target_entropy).detach()
            ).mean()
            alpha = self.log_alpha().exp() * self.alpha_multiplier
        else:
            alpha_loss = observations.new_tensor(0.0)
            alpha = observations.new_tensor(self.alpha_multiplier)
        return alpha, alpha_loss

    def _policy_loss(
        self,
        observations: torch.Tensor,
        actions: torch.Tensor,
        new_actions: torch.Tensor,
        alpha: torch.Tensor,
        log_pi: torch.Tensor,
    ) -> torch.Tensor:
        if self.total_it <= self.bc_steps:
            log_probs = self.actor.log_prob(observations, actions)
            policy_loss = (alpha * log_pi - log_probs).mean()
        else:
            q_new_actions = torch.min(
                self.critic_1(observations, new_actions),
                self.critic_2(observations, new_actions),
            )
            policy_loss = (alpha * log_pi - q_new_actions).mean()
        return policy_loss

    def get_prob(self, observations, actions, next_observations):
        multiple_actions = False
        batch_size = observations.shape[0]
        if actions.ndim == 3 and observations.ndim == 2:
            multiple_actions = True
            observations = extend_and_repeat(observations, 1, actions.shape[1]).reshape(
                -1, observations.shape[-1]
            )
            next_observations = extend_and_repeat(next_observations, 1, actions.shape[1]).reshape(
                -1, observations.shape[-1]
            )
            actions = actions.reshape(-1, actions.shape[-1])
        with torch.no_grad():
            if self.cql_max_target_backup:
                z, mean, std = self.vae.encoder(observations, actions, next_observations)
                prob = stats.norm.cdf((self.prob_mean[0] - abs(mean.cpu() - self.prob_mean[0])),
                                      self.prob_mean[0],
                                      self.prob_mean[1]) * 2
                prob = torch.tensor(prob, dtype=torch.float32).to(self._device)
                prob = torch.mean(prob, dim=1).unsqueeze(1)
                zeros = torch.zeros_like(prob)
                prob = torch.where(prob < self.pm, zeros, prob)
            else:
                print(self.normalize_std, self.normalize_mean, observations.shape)
                z, mean, std = self.vae.encoder(observations, actions, next_observations)
                prob = stats.norm.cdf(self.prob_mean[0] - abs(mean.cpu() - self.prob_mean[0]), self.prob_mean[0],
                                      self.prob_mean[1]) * 2
                prob = torch.tensor(prob, dtype=torch.float32).to(self._device)
                prob = torch.mean(prob, dim=1).unsqueeze(1)
                zeros = torch.zeros_like(prob)
                prob = torch.where(prob < self.pm, zeros, prob)
            if multiple_actions:
                prob = prob.reshape(batch_size, -1)
        return prob

    def _q_loss(
        self,
        observations: torch.Tensor,
        actions: torch.Tensor,
        next_observations: torch.Tensor,
        rewards: torch.Tensor,
        dones: torch.Tensor,
        alpha: torch.Tensor,
        log_dict: Dict,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        q1_predicted = self.critic_1(observations, actions)
        q2_predicted = self.critic_2(observations, actions)

        if self.cql_max_target_backup:
            new_next_actions, next_log_pi = self.actor(
                next_observations, repeat=self.cql_n_actions
            )
            target_q_values, max_target_indices = torch.max(
                torch.min(
                    self.target_critic_1(next_observations, new_next_actions),
                    self.target_critic_2(next_observations, new_next_actions),
                ),
                dim=-1,
            )
            next_log_pi = torch.gather(
                next_log_pi, -1, max_target_indices.unsqueeze(-1)
            ).squeeze(-1)

            with torch.no_grad():
                off_target_q_values, off_max_target_indices = torch.max(
                    torch.min(
                        self.off_target_critic_1(next_observations, new_next_actions),
                        self.off_target_critic_2(next_observations, new_next_actions),
                    ),
                    dim=-1,
                )
        else:
            new_next_actions, next_log_pi = self.actor(next_observations)
            target_q_values = torch.min(
                self.target_critic_1(next_observations, new_next_actions),
                self.target_critic_2(next_observations, new_next_actions),
            )
            with torch.no_grad():
                off_target_q_values = torch.min(
                    self.off_target_critic_1(next_observations, new_next_actions),
                    self.off_target_critic_2(next_observations, new_next_actions),
                )

        if self.backup_entropy:
            target_q_values = target_q_values - alpha * next_log_pi

        target_q_values = target_q_values.unsqueeze(-1)
        off_target_q_values = off_target_q_values.unsqueeze(-1)
        target_q_values = (1 - self.prob) * target_q_values + self.prob * off_target_q_values
        td_target = rewards + (1.0 - dones) * self.discount * target_q_values.detach()

        td_target = td_target.squeeze(-1)
        qf1_loss = F.mse_loss(q1_predicted, td_target.detach())
        qf2_loss = F.mse_loss(q2_predicted, td_target.detach())

        # CQL
        batch_size = actions.shape[0]
        action_dim = actions.shape[-1]
        cql_random_actions = actions.new_empty(
            (batch_size, self.cql_n_actions, action_dim), requires_grad=False
        ).uniform_(-1, 1)
        cql_current_actions, cql_current_log_pis = self.actor(
            observations, repeat=self.cql_n_actions
        )
        cql_next_actions, cql_next_log_pis = self.actor(
            next_observations, repeat=self.cql_n_actions
        )
        cql_current_actions, cql_current_log_pis = (
            cql_current_actions.detach(),
            cql_current_log_pis.detach(),
        )
        cql_next_actions, cql_next_log_pis = (
            cql_next_actions.detach(),
            cql_next_log_pis.detach(),
        )

        cql_q1_rand = self.critic_1(observations, cql_random_actions)
        cql_q2_rand = self.critic_2(observations, cql_random_actions)
        cql_q1_current_actions = self.critic_1(observations, cql_current_actions)
        cql_q2_current_actions = self.critic_2(observations, cql_current_actions)
        cql_q1_next_actions = self.critic_1(observations, cql_next_actions)
        cql_q2_next_actions = self.critic_2(observations, cql_next_actions)

        cql_cat_q1 = torch.cat(
            [
                cql_q1_rand,
                torch.unsqueeze(q1_predicted, 1),
                cql_q1_next_actions,
                cql_q1_current_actions,
            ],
            dim=1,
        )
        cql_cat_q2 = torch.cat(
            [
                cql_q2_rand,
                torch.unsqueeze(q2_predicted, 1),
                cql_q2_next_actions,
                cql_q2_current_actions,
            ],
            dim=1,
        )
        cql_std_q1 = torch.std(cql_cat_q1, dim=1)
        cql_std_q2 = torch.std(cql_cat_q2, dim=1)

        if self.cql_importance_sample:
            random_density = np.log(0.5**action_dim)
            cql_cat_q1 = torch.cat(
                [
                    cql_q1_rand - random_density,
                    cql_q1_next_actions - cql_next_log_pis.detach(),
                    cql_q1_current_actions - cql_current_log_pis.detach(),
                ],
                dim=1,
            )
            cql_cat_q2 = torch.cat(
                [
                    cql_q2_rand - random_density,
                    cql_q2_next_actions - cql_next_log_pis.detach(),
                    cql_q2_current_actions - cql_current_log_pis.detach(),
                ],
                dim=1,
            )

        cql_qf1_ood = torch.logsumexp(cql_cat_q1 / self.cql_temp, dim=1) * self.cql_temp
        cql_qf2_ood = torch.logsumexp(cql_cat_q2 / self.cql_temp, dim=1) * self.cql_temp

        """Subtract the log likelihood of data"""
        with torch.no_grad():
            q1_predicted = self.prob * q1_predicted + (1 - self.prob) * self.off_target_critic_1(observations, actions)
            q2_predicted = self.prob * q2_predicted + (1 - self.prob) * self.off_target_critic_2(observations, actions)
        cql_qf1_diff = torch.clamp(
            cql_qf1_ood - q1_predicted,
            self.cql_clip_diff_min,
            self.cql_clip_diff_max,
        ).mean()
        cql_qf2_diff = torch.clamp(
            cql_qf2_ood - q2_predicted,
            self.cql_clip_diff_min,
            self.cql_clip_diff_max,
        ).mean()

        if self.cql_lagrange:
            alpha_prime = torch.clamp(
                torch.exp(self.log_alpha_prime()), min=0.0, max=1000000.0
            )
            cql_min_qf1_loss = (
                alpha_prime
                * self.cql_alpha
                * (cql_qf1_diff - self.cql_target_action_gap)
            )
            cql_min_qf2_loss = (
                alpha_prime
                * self.cql_alpha
                * (cql_qf2_diff - self.cql_target_action_gap)
            )

            self.alpha_prime_optimizer.zero_grad()
            alpha_prime_loss = (-cql_min_qf1_loss - cql_min_qf2_loss) * 0.5
            alpha_prime_loss.backward(retain_graph=True)
            self.alpha_prime_optimizer.step()
        else:
            cql_min_qf1_loss = cql_qf1_diff * self.cql_alpha
            cql_min_qf2_loss = cql_qf2_diff * self.cql_alpha
            alpha_prime_loss = observations.new_tensor(0.0)
            alpha_prime = observations.new_tensor(0.0)

        qf_loss = qf1_loss + qf2_loss + cql_min_qf1_loss + cql_min_qf2_loss

        log_dict.update(
            dict(
                qf1_loss=qf1_loss.item(),
                qf2_loss=qf2_loss.item(),
                alpha=alpha.item(),
                average_qf1=q1_predicted.mean().item(),
                average_qf2=q2_predicted.mean().item(),
                average_target_q=target_q_values.mean().item(),
            )
        )
        if self.total_it < self.offline_time:
            log_dict.update(
                dict(
                    off_average_target_q=off_target_q_values.mean().item(),
                    prob=self.prob.mean().item(),
                )
            )

        log_dict.update(
            dict(
                cql_std_q1=cql_std_q1.mean().item(),
                cql_std_q2=cql_std_q2.mean().item(),
                cql_q1_rand=cql_q1_rand.mean().item(),
                cql_q2_rand=cql_q2_rand.mean().item(),
                cql_min_qf1_loss=cql_min_qf1_loss.mean().item(),
                cql_min_qf2_loss=cql_min_qf2_loss.mean().item(),
                cql_qf1_diff=cql_qf1_diff.mean().item(),
                cql_qf2_diff=cql_qf2_diff.mean().item(),
                cql_q1_current_actions=cql_q1_current_actions.mean().item(),
                cql_q2_current_actions=cql_q2_current_actions.mean().item(),
                cql_q1_next_actions=cql_q1_next_actions.mean().item(),
                cql_q2_next_actions=cql_q2_next_actions.mean().item(),
                alpha_prime_loss=alpha_prime_loss.item(),
                alpha_prime=alpha_prime.item(),
            )
        )

        return qf_loss, alpha_prime, alpha_prime_loss

    def train(self, batch: TensorBatch) -> Dict[str, float]:
        (
            observations,
            actions,
            rewards,
            next_observations,
            dones,
        ) = batch
        self.total_it += 1
        new_actions, log_pi = self.actor(observations)
        with torch.no_grad():
            if self.cql_max_target_backup:
                z, mean, std = self.vae.encoder(observations, actions, next_observations)
                prob = stats.norm.cdf((self.prob_mean[0] - abs(mean.cpu() - self.prob_mean[0])),
                                      self.prob_mean[0],
                                      self.prob_mean[1]) * 2
                prob = torch.tensor(prob, dtype=torch.float32).to(self._device)
                prob = torch.mean(prob, dim=1).unsqueeze(1)
                zeros = torch.zeros_like(prob)
                prob = torch.where(prob < self.pm, zeros, prob)
                self.prob = prob
            else:
                z, mean, std = self.vae.encoder(observations, actions, next_observations)
                prob = stats.norm.cdf(self.prob_mean[0] - abs(mean.cpu() - self.prob_mean[0]), self.prob_mean[0],
                                      self.prob_mean[1]) * 2
                prob = torch.tensor(prob, dtype=torch.float32).to(self._device)
                prob = torch.mean(prob, dim=1).unsqueeze(1)
                zeros = torch.zeros_like(prob)
                prob = torch.where(prob < self.pm, zeros, prob)
                self.prob = prob

        alpha, alpha_loss = self._alpha_and_alpha_loss(observations, log_pi)

        """ Policy loss """
        policy_loss = self._policy_loss(
            observations, actions, new_actions, alpha, log_pi
        )

        log_dict = dict(
            log_pi=log_pi.mean().item(),
            policy_loss=policy_loss.item(),
            alpha_loss=alpha_loss.item(),
            alpha=alpha.item(),
        )

        """ Q function loss """
        qf_loss, alpha_prime, alpha_prime_loss = self._q_loss(
            observations, actions, next_observations, rewards, dones, alpha, log_dict
        )

        if self.use_automatic_entropy_tuning:
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()

        self.actor_optimizer.zero_grad()
        policy_loss.backward()
        self.actor_optimizer.step()

        self.critic_1_optimizer.zero_grad()
        self.critic_2_optimizer.zero_grad()
        qf_loss.backward(retain_graph=True)
        self.critic_1_optimizer.step()
        self.critic_2_optimizer.step()

        if self.total_it % self.target_update_period == 0:
            self.update_target_network(self.soft_target_update_rate)

        return log_dict

    def state_dict(self) -> Dict[str, Any]:
        """Gather everything needed to resume training into one dictionary.

        Networks and optimizers contribute their own ``state_dict()``.  The two
        temperature holders ``log_alpha`` and ``log_alpha_prime`` are stored
        as the objects themselves, not as their state dicts.
        """
        stateful = (
            ("actor", self.actor),
            ("critic1", self.critic_1),
            ("critic2", self.critic_2),
            ("critic1_target", self.target_critic_1),
            ("critic2_target", self.target_critic_2),
            ("critic_1_optimizer", self.critic_1_optimizer),
            ("critic_2_optimizer", self.critic_2_optimizer),
            ("actor_optim", self.actor_optimizer),
        )
        checkpoint: Dict[str, Any] = {key: obj.state_dict() for key, obj in stateful}
        checkpoint["sac_log_alpha"] = self.log_alpha
        checkpoint["sac_log_alpha_optim"] = self.alpha_optimizer.state_dict()
        checkpoint["cql_log_alpha"] = self.log_alpha_prime
        checkpoint["cql_log_alpha_optim"] = self.alpha_prime_optimizer.state_dict()
        checkpoint["total_it"] = self.total_it
        return checkpoint

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Restore networks, optimizers and temperatures from ``state_dict``.

        Also snapshots the restored target critics into fresh
        ``off_target_critic_1/2`` deep copies.  ``total_it`` is reset to 0
        rather than restored; presumably this restarts the ``offline_time``
        schedule after loading a pre-trained checkpoint (NOTE(review): confirm).

        The saved temperature objects (``sac_log_alpha`` / ``cql_log_alpha``,
        which ``state_dict`` stores as objects) are copied *into* the live
        modules.  This way ``alpha_optimizer`` / ``alpha_prime_optimizer``, which
        presumably hold the live modules' parameters, keep updating the values
        that are actually used.
        """

        def _restore_scalar(current, saved):
            # Copy in place when there is a live module to update; otherwise
            # fall back to rebinding to the saved object.
            if isinstance(current, nn.Module) and saved is not None:
                saved_state = saved.state_dict() if isinstance(saved, nn.Module) else saved
                current.load_state_dict(saved_state)
                return current
            return saved

        self.actor.load_state_dict(state_dict=state_dict["actor"])
        self.critic_1.load_state_dict(state_dict=state_dict["critic1"])
        self.critic_2.load_state_dict(state_dict=state_dict["critic2"])

        self.target_critic_1.load_state_dict(state_dict=state_dict["critic1_target"])
        self.target_critic_2.load_state_dict(state_dict=state_dict["critic2_target"])

        # Frozen offline critics start from the restored targets.
        self.off_target_critic_1 = copy.deepcopy(self.target_critic_1)
        self.off_target_critic_2 = copy.deepcopy(self.target_critic_2)

        self.critic_1_optimizer.load_state_dict(
            state_dict=state_dict["critic_1_optimizer"]
        )
        self.critic_2_optimizer.load_state_dict(
            state_dict=state_dict["critic_2_optimizer"]
        )
        self.actor_optimizer.load_state_dict(state_dict=state_dict["actor_optim"])

        self.log_alpha = _restore_scalar(self.log_alpha, state_dict["sac_log_alpha"])
        self.alpha_optimizer.load_state_dict(
            state_dict=state_dict["sac_log_alpha_optim"]
        )

        self.log_alpha_prime = _restore_scalar(
            self.log_alpha_prime, state_dict["cql_log_alpha"]
        )
        self.alpha_prime_optimizer.load_state_dict(
            state_dict=state_dict["cql_log_alpha_optim"]
        )
        self.total_it = 0

    @torch.no_grad()
    def get_weights(self,
                    observations,
                    actions,
                    rewards,
                    next_observations,
                    dones,
                    ):
        """Score a batch for online/offline weighting, without gradients.

        Returns:
            ``[prob, off_values, on_values]``:

            - ``prob`` is the per-sample ``(batch, 1)`` weight derived from the
              VAE latent mean (two-sided normal tail probability under
              ``self.prob_mean``, averaged over the latent axis, zeroed below
              ``self.pm``).
            - ``off_values`` is the sum of both online critics' MSE TD errors
              against targets bootstrapped from the frozen offline critics.
            - ``on_values`` is the same sum using the online target critics.
        """
        _, latent_mean, _ = self.vae.encoder(observations, actions, next_observations)
        centre, scale = self.prob_mean[0], self.prob_mean[1]
        tail = 2 * stats.norm.cdf(centre - abs(latent_mean.cpu() - centre), centre, scale)
        weight = torch.tensor(tail, dtype=torch.float32).to(self._device).mean(dim=1, keepdim=True)
        weight = torch.where(weight < self.pm, torch.zeros_like(weight), weight)

        q1 = self.critic_1(observations, actions)
        q2 = self.critic_2(observations, actions)
        next_actions, _ = self.actor(next_observations)

        def td_mse(critic_a, critic_b):
            # Clipped double-Q bootstrap from the given critic pair.
            bootstrap = torch.min(
                critic_a(next_observations, next_actions),
                critic_b(next_observations, next_actions),
            ).unsqueeze(-1)
            target = (rewards + (1.0 - dones) * self.discount * bootstrap).squeeze(-1)
            return F.mse_loss(q1, target) + F.mse_loss(q2, target)

        off_values = td_mse(self.off_target_critic_1, self.off_target_critic_2)
        on_values = td_mse(self.target_critic_1, self.target_critic_2)
        return [weight, off_values, on_values]

    def vae_loss(self, obs, act, next_observations, t, kl_weight: float = 20.0):
        """Take one optimisation step on the VAE and return its losses.

        ``self.vae(obs, act, next_observations)`` yields
        ``(reconstruction, mean, std)``.  The loss is the MSE between the
        reconstruction and ``next_observations``, plus ``kl_weight`` times the
        element-averaged closed-form KL of ``N(mean, std**2)`` from a standard
        normal.

        Args:
            obs, act, next_observations: mini-batch fed to the VAE.
            t: unused; kept for call-site compatibility.
            kl_weight: weight of the KL term (default 20.0).

        Returns:
            ``(loss_vae, stats_vae)``: the loss tensor (already back-propagated
            and stepped with ``self.vae_optim``) and a dict of scalar losses.
        """
        recon, mean, std = self.vae(obs, act, next_observations)
        recon_loss = nn.functional.mse_loss(recon, next_observations)
        KL_loss = - 0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()

        loss_vae = recon_loss + kl_weight * KL_loss

        self.vae_optim.zero_grad()
        loss_vae.backward()
        self.vae_optim.step()
        # The stats are returned for the caller to log instead of being
        # printed: this runs once per VAE mini-batch.
        stats_vae = {"loss/loss_vae": loss_vae.item(),
                     "loss/KL_loss": KL_loss.item(),
                     "loss/recon_loss": recon_loss.item(),
                     }
        return loss_vae, stats_vae

    def train_one_step(self, observations, next_observations, actions, rewards, done, t):
        """Run a single VAE update on one mini-batch and return its statistics.

        Only the VAE is trained here.  ``rewards`` and ``done`` are accepted but
        not used.  Note that ``vae_loss`` takes actions before next observations.
        """
        _, vae_stats = self.vae_loss(observations, actions, next_observations, t)
        return vae_stats

    def update_vae(self, buffer):
        """Snapshot the target critics and fit the VAE on resolved samples.

        Transitions come from ``buffer.filter_resolved_samples()`` and are moved
        to ``self._device``.  A single transition without a batch axis is
        promoted to a batch of one.  Both target critics are deep-copied into
        ``off_target_critic_1/2``.  If at least ``vae_batch_size`` transitions
        are available, the VAE is trained for
        ``(num_samples // vae_batch_size) * 10`` mini-batches sampled with
        replacement.  Each step's stats are logged to wandb (``commit=False``)
        and ``self.vae_step`` is advanced.
        """
        batch = [b.to(self._device) for b in buffer.filter_resolved_samples()]
        if batch[0].dim() == 1:
            # A single resolved transition comes back without a batch axis.
            batch = [b.unsqueeze(0) for b in batch]
        observations, actions, rewards, next_observations, dones = batch

        # Refresh BOTH frozen offline critics (critic 2 was previously never
        # copied because critic 1 was assigned twice).
        self.off_target_critic_1 = copy.deepcopy(self.target_critic_1).to(self._device)
        self.off_target_critic_2 = copy.deepcopy(self.target_critic_2).to(self._device)

        print('observation shape', observations.shape)
        print('action shape', actions.shape)
        print('next observation shape', next_observations.shape)

        num_samples = observations.shape[0]
        if num_samples < self.vae_batch_size:
            return
        for _ in range(num_samples // self.vae_batch_size * 10):
            indices = np.random.randint(0, num_samples, size=self.vae_batch_size)
            # Tensors are already on self._device; indexing keeps them there.
            log_dict = self.train_one_step(observations[indices],
                                           next_observations[indices],
                                           actions[indices],
                                           rewards[indices],
                                           dones[indices],
                                           0)
            log_dict['vae_step'] = self.vae_step
            wandb.log(log_dict, commit=False)
            self.vae_step += 1

