import os
import random
import uuid
from copy import deepcopy
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
from pathlib import Path

import d4rl
import gym
import numpy as np
import pyrallis
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from tqdm import trange
import copy
import sys
from scipy import stats
sys.path.append("../..")
from configs.configs import get_awac_train_configs

# Batch layout returned by ReplayBuffer.sample: [states, actions, rewards, next_states, dones].
TensorBatch = List[torch.Tensor]

# Env-id prefixes of tasks that expose a goal/success signal; matched via str.startswith.
ENVS_WITH_GOAL = (
    "antmaze",
    "pen",
    "door",
    "hammer",
    "relocate",
)


class ReplayBuffer:
    """Fixed-capacity transition buffer backed by pre-allocated torch tensors.

    Stores states, actions, rewards, next states and done flags. When
    ``incre`` is True it additionally keeps a ``(buffer_size, 3)`` table of
    per-transition weights that drives the OOD-filtering helpers
    (``filter_resolved_samples`` / ``filter_in_distribution``).

    Args:
        state_dim: Size of a state vector.
        action_dim: Size of an action vector.
        buffer_size: Maximum number of transitions held.
        device: Torch device on which all storage tensors are allocated.
        incre: Whether to allocate the per-transition weight table.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        buffer_size: int,
        device: str = "cpu",
        incre: bool = False,
    ):
        self.incre = incre
        self._buffer_size = buffer_size
        self._pointer = 0
        self._size = 0
        # ``initialize`` keeps collecting online transitions until _size reaches this.
        self.init_size = buffer_size
        # Window index / window length scanned by filter_in_distribution.
        self.ood_pointer = 0
        self.ood_size = 10000
        # Upper bound on the number of samples returned by filter_resolved_samples.
        self._update_size = 500
        self._states = torch.zeros(
            (buffer_size, state_dim), dtype=torch.float32, device=device
        )
        self._actions = torch.zeros(
            (buffer_size, action_dim), dtype=torch.float32, device=device
        )
        self._rewards = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
        self._next_states = torch.zeros(
            (buffer_size, state_dim), dtype=torch.float32, device=device
        )
        if self.incre:
            # Per-transition weights. As written by ``initialize`` (from
            # ``trainer.get_weights``) the columns are
            # [in-distribution probability, offline Q error, online Q error];
            # a 0 in column 0 is what the filters treat as out-of-distribution.
            self._weights = torch.ones((buffer_size, 3), dtype=torch.float32, device=device)
        # Start of the window of new transitions not yet consumed by
        # filter_resolved_samples.
        self.l_pointer = 0
        # Window length that makes is_update_intervals() fire.
        self.update_interval = 10000
        self._dones = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
        self._device = device

    def _to_tensor(self, data: np.ndarray) -> torch.Tensor:
        """Copy ``data`` into a float32 tensor on the buffer's device."""
        return torch.tensor(data, dtype=torch.float32, device=self._device)

    def load_d4rl_dataset(self, data: Dict[str, np.ndarray]):
        """Load a D4RL-style dict into the (empty) start of the buffer.

        Raises:
            ValueError: if the buffer is not empty or is too small for ``data``.
        """
        if self._size != 0:
            raise ValueError("Trying to load data into non-empty replay buffer")
        n_transitions = data["observations"].shape[0]
        self.n_transitions = n_transitions
        if n_transitions > self._buffer_size:
            raise ValueError(
                "Replay buffer is smaller than the dataset you are trying to load!"
            )
        self._states[:n_transitions] = self._to_tensor(data["observations"])
        self._actions[:n_transitions] = self._to_tensor(data["actions"])
        self._rewards[:n_transitions] = self._to_tensor(data["rewards"][..., None])
        self._next_states[:n_transitions] = self._to_tensor(data["next_observations"])
        self._dones[:n_transitions] = self._to_tensor(data["terminals"][..., None])
        self._size += n_transitions
        self._pointer = min(self._size, n_transitions)
        self.l_pointer = self._pointer

        print(f"Dataset size: {n_transitions}")

    def add_d4rl_dataset(self, data: Dict[str, np.ndarray]):
        """Append extra data in D4RL format (``Dict[str, np.ndarray]``) at the pointer.

        Raises:
            ValueError: if the data does not fit between the pointer and the end.
        """
        print(self._pointer)
        n_transitions = data["observations"].shape[0]
        if n_transitions + self._pointer > self._buffer_size:
            raise ValueError(
                "Replay buffer is smaller than the dataset you are trying to load!"
            )
        self._states[self._pointer:self._pointer + n_transitions] = self._to_tensor(data["observations"])
        self._actions[self._pointer:self._pointer + n_transitions] = self._to_tensor(data["actions"])
        self._rewards[self._pointer:self._pointer + n_transitions] = self._to_tensor(data["rewards"][..., None])
        self._next_states[self._pointer:self._pointer + n_transitions] = self._to_tensor(data["next_observations"])
        self._dones[self._pointer:self._pointer + n_transitions] = self._to_tensor(data["terminals"][..., None])

        self._size += n_transitions
        self._pointer = min(self._size, self._pointer + n_transitions)
        self.l_pointer = self._pointer
        # NOTE(review): _pointer has already been advanced above, so this adds
        # n_transitions a second time — confirm what ``length`` is meant to hold.
        self.length = self._pointer + n_transitions
        print(f"Dataset size: {n_transitions}")
        print(self._pointer)
        print("Added success.\n")

    def sample(self, batch_size: int) -> TensorBatch:
        """Uniformly sample ``batch_size`` stored transitions (with replacement)."""
        indices = np.random.randint(0, self._size, size=batch_size)
        states = self._states[indices]
        actions = self._actions[indices]
        rewards = self._rewards[indices]
        next_states = self._next_states[indices]
        dones = self._dones[indices]
        return [states, actions, rewards, next_states, dones]

    def add_transition(
        self,
        state: np.ndarray,
        action: np.ndarray,
        reward: float,
        next_state: np.ndarray,
        done: bool,
        weight: Union[float, torch.Tensor] = 0,
    ):
        """Write one transition at the pointer (ring-buffer overwrite).

        ``weight`` is stored only when the buffer was built with ``incre=True``;
        otherwise there is no weight table and the argument is ignored.
        """
        # Use this method to add new data into the replay buffer during fine-tuning.
        self._states[self._pointer] = self._to_tensor(state)
        self._actions[self._pointer] = self._to_tensor(action)
        self._rewards[self._pointer] = self._to_tensor(reward)
        self._next_states[self._pointer] = self._to_tensor(next_state)
        self._dones[self._pointer] = self._to_tensor(done)

        if self.incre:
            # as_tensor avoids torch.tensor's copy-construct warning when
            # ``weight`` is already a tensor (as returned by get_weights).
            self._weights[self._pointer] = torch.as_tensor(
                weight, dtype=torch.float32, device=self._device
            )
        self._pointer = (self._pointer + 1) % self._buffer_size
        self._size = min(self._size + 1, self._buffer_size)

    # todo: how to improve this
    def initialize(self, env, trainer, config, max_action, max_steps):
        """Roll out ``trainer._actor`` in ``env`` until ``_size`` reaches ``init_size``.

        Each transition is stored together with ``trainer.get_weights(...)``.
        ``max_action`` is accepted but not used.
        """
        state, done = env.reset(), False
        # max_steps = env._max_episode_steps
        print(self._size)
        while self._size < self.init_size:
            episode_step = 0
            while not done:
                episode_step += 1
                action, _ = trainer._actor(
                    torch.tensor(
                        state.reshape(1, -1), device=config.device, dtype=torch.float32
                    )
                )
                action = action.cpu().data.numpy().flatten()
                next_state, reward, done, env_infos = env.step(action)
                real_done = False  # Episode can timeout which is different from done
                if done and episode_step < max_steps:
                    real_done = True
                weights = trainer.get_weights(self._to_tensor(state).unsqueeze(0), self._to_tensor(action).unsqueeze(0),
                                              reward, self._to_tensor(next_state).unsqueeze(0),
                                              self._to_tensor(np.array(real_done)).unsqueeze(0))
                # NOTE(review): stores ``done`` (which includes timeouts) rather
                # than ``real_done`` — confirm this is intended.
                self.add_transition(state, action, reward, next_state, done, weights)
                state = next_state
            if done:
                state, done = env.reset(), False

    def __len__(self):
        """Number of transitions loaded by ``load_d4rl_dataset`` (not the live size)."""
        return self.n_transitions

    def __getitem__(self, indices):
        """Return [states, actions, rewards, next_states, dones] at ``indices``."""
        states = self._states[indices]
        actions = self._actions[indices]
        rewards = self._rewards[indices]
        next_states = self._next_states[indices]
        dones = self._dones[indices]
        return [states, actions, rewards, next_states, dones]

    def is_update_intervals(self):
        """True when ``[l_pointer, _pointer)`` spans exactly ``update_interval`` slots."""
        if self._pointer - self.l_pointer == self.update_interval or\
           self._pointer + self._buffer_size - self.l_pointer == self.update_interval:
            return True
        else:
            return False

    def filter_resolved_samples(self):
        """Return up to ``_update_size`` OOD transitions from ``[l_pointer, _pointer)``.

        The window may wrap around the end of the buffer once. A transition
        counts as OOD when weight column 0 is 0; those with the smallest
        column-2 weight are returned first.

        Returns:
            [states, actions, rewards, next_states, dones], each ``.squeeze()``-d.

        Raises:
            ValueError: if the window does not span exactly ``update_interval``
                slots (e.g. the pointer lapped the buffer more than once).
        """
        print("Buffer pointer:", self._pointer, "Left pointer:", self.l_pointer)
        if self._pointer + self._buffer_size - self.l_pointer == self.update_interval:
            print("Enrolling the buffer once.")
            ood_indices_1 = (self._weights[self.l_pointer:self._buffer_size, 0] == 0).nonzero() + self.l_pointer
            ood_indices_2 = (self._weights[0:self._pointer, 0] == 0).nonzero()
            print(ood_indices_1.shape, ood_indices_2.shape)
            ood_indices = torch.cat([ood_indices_1, ood_indices_2])
        elif self._pointer - self.l_pointer == self.update_interval:
            ood_indices = (self._weights[self.l_pointer:self._pointer, 0] == 0).nonzero() + self.l_pointer
        else:
            raise ValueError(
                "The buffer size is too small and "
                "the pointer enrolls the whole buffer more than one time."
            )

        # Keep the index vector 1-D so every gathered tensor stays (n, dim),
        # even for n == 1 or a feature dim of 1; squeezing happens on return.
        ood_indices = ood_indices.reshape(-1)
        print(ood_indices.shape)
        weights = self._weights[ood_indices]
        resolved_indices = torch.argsort(weights[:, 2])[:self._update_size]
        print(resolved_indices.shape)
        selected = ood_indices[resolved_indices]

        states = self._states[selected]
        actions = self._actions[selected]
        rewards = self._rewards[selected]
        next_states = self._next_states[selected]
        dones = self._dones[selected]
        print('rewards and dones', rewards.shape, dones.shape)

        return [states.squeeze(), actions.squeeze(), rewards.squeeze(), next_states.squeeze(), dones.squeeze()]

    def update_left_pointer(self):
        """Mark everything written so far as consumed by the window logic."""
        self.l_pointer = self._pointer

    def filter_in_distribution(self):
        """Return in-distribution transitions from window ``ood_pointer``.

        The window covers slots ``[ood_pointer * ood_size, (ood_pointer + 1) * ood_size)``;
        a transition is kept when its weight column 0 is non-zero.

        Returns:
            [states, actions, next_states], each ``.squeeze()``-d.

        Raises:
            ValueError: if the window starts beyond the end of the buffer.
        """
        start = self.ood_pointer * self.ood_size
        if start >= self._buffer_size:
            raise ValueError("OOD indices out of index")
        end = start + self.ood_size
        # nonzero() is relative to the slice, so shift back to absolute slots.
        in_indices = (self._weights[start:end, 0] != 0).nonzero() + start
        states = self._states[in_indices, :].squeeze()
        actions = self._actions[in_indices, :].squeeze()
        next_states = self._next_states[in_indices, :].squeeze()

        return [states.squeeze(), actions.squeeze(), next_states.squeeze()]

def set_env_seed(env: Optional[gym.Env], seed: int):
    """Seed the environment itself and then its action space with ``seed``."""
    for seeder in (env.seed, env.action_space.seed):
        seeder(seed)


class VAE(nn.Module):
    """Conditional variational auto-encoder over (observation, action) pairs.

    The encoder maps ``cat([obs, act])`` to a diagonal Gaussian over a latent
    ``z``. The decoder maps ``cat([obs, act, z])`` to an ``obs_dim``-sized
    output.

    Args:
        obs_dim (int): The dimension of the observation space.
        act_dim (int): The dimension of the action space.
        hidden_size (int): Number of hidden units in the encoder and decoder networks.
        latent_dim (int): The dimensionality of the latent space.
        act_lim (float): Upper limit of the action space (stored; not used by the network).
        obs_lim: Observation limit (stored; not used by the network).
        device (str): Device the layers are created on (cpu or cuda).
    """

    def __init__(self, obs_dim, act_dim, hidden_size, latent_dim, act_lim, obs_lim, device='cuda'):
        super(VAE, self).__init__()
        self.device = device
        self.latent_dim = latent_dim
        self.e1 = nn.Linear(obs_dim + act_dim, hidden_size).to(self.device)
        self.e2 = nn.Linear(hidden_size, hidden_size).to(self.device)

        self.mean = nn.Linear(hidden_size, latent_dim).to(self.device)
        self.log_std = nn.Linear(hidden_size, latent_dim).to(self.device)

        self.d1 = nn.Linear(obs_dim + act_dim + latent_dim, hidden_size).to(self.device)
        self.d2 = nn.Linear(hidden_size, hidden_size).to(self.device)
        self.d3 = nn.Linear(hidden_size, obs_dim).to(self.device)

        self.act_lim = act_lim
        self.obs_lim = obs_lim

    def forward(self, obs, act, next_obs):
        """Encode, sample ``z`` by reparameterization, decode; return (output, mean, std)."""
        z, mean, std = self.encoder(obs, act, next_obs)
        u = self.decode(obs, act, next_obs, z)

        return u, mean, std

    def encoder(self, obs, act, next_obs):
        """Return (reparameterized z, mean, std). ``next_obs`` is accepted but unused."""
        z = F.relu(self.e1(torch.cat([obs, act], 1)))
        z = F.relu(self.e2(z))

        mean = self.mean(z)
        # clamp for numerical stability
        log_std = self.log_std(z).clamp(-4, 15)
        std = torch.exp(log_std)
        z = mean + std * torch.randn_like(std)

        return z, mean, std

    def decode(self, obs, act, next_obs, z=None):
        """Decode ``cat([obs, act, z])``; without ``z`` draw it from a clipped N(0, 1).

        ``next_obs`` is accepted but unused.
        """
        if z is None:
            # Sample directly on the inputs' device so it always matches obs/act.
            z = torch.randn((obs.shape[0], self.latent_dim), device=obs.device).clamp(-0.5, 0.5)
        s = F.relu(self.d1(torch.cat([obs, act, z], 1)))
        s = F.relu(self.d2(s))

        return self.d3(s)

    # for BEARL only
    def decode_multiple(self, obs, z=None, num_decode=10):
        """Decode ``num_decode`` latent samples per row; return (tanh(out), out).

        NOTE: ``d1`` expects ``obs_dim + act_dim + latent_dim`` inputs but only
        ``obs`` and ``z`` are concatenated, so ``obs`` must already carry
        ``obs_dim + act_dim`` features (e.g. observation concatenated with action).
        """
        if z is None:
            z = torch.randn(
                (obs.shape[0], num_decode, self.latent_dim), device=obs.device
            ).clamp(-0.5, 0.5)

        tiled_obs = obs.unsqueeze(0).repeat(num_decode, 1, 1).permute(1, 0, 2)
        a = F.relu(self.d1(torch.cat([tiled_obs, z], 2)))
        a = F.relu(self.d2(a))
        # Compute the head once and derive the squashed variant from it.
        raw = self.d3(a)
        return torch.tanh(raw), raw



class Actor(nn.Module):
    """Gaussian policy: an MLP predicts the mean; the log-std is a learned vector.

    Args:
        state_dim: Size of the state vector.
        action_dim: Size of the action vector.
        hidden_dim: Width of the three hidden layers.
        min_log_std / max_log_std: Clamp range for the learned log-std.
        min_action / max_action: Clamp range applied to sampled actions in ``forward``.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        hidden_dim: int,
        min_log_std: float = -20.0,
        max_log_std: float = 2.0,
        min_action: float = -1.0,
        max_action: float = 1.0,
    ):
        super().__init__()
        self._mlp = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        )
        # State-independent log standard deviation, shared across states.
        self._log_std = nn.Parameter(torch.zeros(action_dim, dtype=torch.float32))
        self._min_log_std = min_log_std
        self._max_log_std = max_log_std
        self._min_action = min_action
        self._max_action = max_action

    def _get_policy(self, state: torch.Tensor) -> torch.distributions.Distribution:
        """Build the Normal(mean(state), exp(clamped log_std)) distribution."""
        mean = self._mlp(state)
        log_std = self._log_std.clamp(self._min_log_std, self._max_log_std)
        policy = torch.distributions.Normal(mean, log_std.exp())
        return policy

    def log_prob(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Log-density of ``action`` under the policy, summed over action dims (keepdim)."""
        policy = self._get_policy(state)
        log_prob = policy.log_prob(action).sum(-1, keepdim=True)
        return log_prob

    def forward(self, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Reparameterized sample clamped to the action range, plus its log-prob."""
        policy = self._get_policy(state)
        action = policy.rsample()
        action.clamp_(self._min_action, self._max_action)
        log_prob = policy.log_prob(action).sum(-1, keepdim=True)
        return action, log_prob

    def act(self, state: np.ndarray, device: str) -> np.ndarray:
        """Select an action for a single state as a NumPy array.

        Samples from the policy in training mode and returns the mean in eval
        mode. Runs under ``torch.no_grad`` so converting the (grad-tracked)
        mean to NumPy works even when the caller has grad enabled.
        """
        state_t = torch.tensor(state[None], dtype=torch.float32, device=device)
        with torch.no_grad():
            policy = self._get_policy(state_t)
            if self._mlp.training:
                action_t = policy.sample()
            else:
                action_t = policy.mean
        return action_t[0].cpu().numpy()


class Critic(nn.Module):
    """State-action value network Q(s, a): an MLP over ``cat([state, action])``.

    Args:
        state_dim: Size of the state vector.
        action_dim: Size of the action vector.
        hidden_dim: Width of each of the three hidden layers.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        hidden_dim: int,
    ):
        super().__init__()
        layers = [nn.Linear(state_dim + action_dim, hidden_dim), nn.ReLU()]
        for _ in range(2):
            layers.extend((nn.Linear(hidden_dim, hidden_dim), nn.ReLU()))
        layers.append(nn.Linear(hidden_dim, 1))
        self._mlp = nn.Sequential(*layers)

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Return Q-values with a trailing dimension of size 1."""
        state_action = torch.cat((state, action), dim=-1)
        return self._mlp(state_action)


def soft_update(target: nn.Module, source: nn.Module, tau: float):
    """Polyak-average ``source`` into ``target`` in place.

    Every target parameter becomes ``(1 - tau) * target + tau * source``;
    parameters are paired in ``parameters()`` order.
    """
    keep = 1 - tau
    for tgt, src in zip(target.parameters(), source.parameters()):
        blended = keep * tgt.data + tau * src.data
        tgt.data.copy_(blended)


def to_tensor(data, device) -> torch.Tensor:
    """Copy ``data`` into a float32 tensor and move it to ``device``."""
    as_float = torch.tensor(data, dtype=torch.float32)
    return as_float.to(device)


class AdvantageWeightedActorCritic:
    """AWAC trainer extended with a VAE-gated blend of online and offline critics.

    Besides the actor, twin critics and their Polyak target copies, it keeps
    snapshot "offline" target critics (``off_target_critic_*``) and a ``VAE``.
    While ``total_it <= offline_time`` the TD target mixes online and offline
    target-critic values using a per-sample score derived from the VAE
    encoder mean and the Gaussian parameters in ``prob_mean``.

    NOTE(review): ``prob_mean`` / ``prob_std`` start as empty tuples and are
    presumably assigned by the caller (e.g. from ``update_vae_parameters``)
    before ``_critic_loss`` / ``get_weights`` run; indexing them earlier raises.
    """

    def __init__(
        self,
        state_dim,
        action_dim,
        max_action,
        actor: nn.Module,
        actor_optimizer: torch.optim.Optimizer,
        critic_1: nn.Module,
        critic_1_optimizer: torch.optim.Optimizer,
        critic_2: nn.Module,
        critic_2_optimizer: torch.optim.Optimizer,
        vae_lr,
        pm: float = 0.6,
        gamma: float = 0.99,
        tau: float = 5e-3,  # parameter for the soft target update,
        awac_lambda: float = 1.0,
        exp_adv_max: float = 100.0,
        device: str = "cpu",
        vae_hidden_dim: int = 400,
    ):
        self._actor = actor
        self._actor_optimizer = actor_optimizer

        self._critic_1 = critic_1
        self._critic_1_optimizer = critic_1_optimizer
        self._target_critic_1 = deepcopy(critic_1)

        self._critic_2 = critic_2
        self._critic_2_optimizer = critic_2_optimizer
        self._target_critic_2 = deepcopy(critic_2)

        # Snapshot critics used as the "offline" bootstrap targets; they are
        # re-snapshotted in update_vae and load_state_dict.
        self.off_target_critic_1 = copy.deepcopy(critic_1).to(device)
        self.off_target_critic_2 = copy.deepcopy(critic_2).to(device)

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_action = max_action
        self.vae_hidden_dim = vae_hidden_dim
        self.latent_dim = state_dim + action_dim
        # ``device`` must be passed by keyword: positionally it would land in
        # the VAE's ``obs_lim`` slot and the layers would be built on "cuda".
        self.vae = VAE(
            self.state_dim,
            self.action_dim,
            self.vae_hidden_dim,
            self.latent_dim,
            self.max_action,
            obs_lim=None,
            device=device,
        ).to(device)
        self.vae_optim = torch.optim.Adam(self.vae.parameters(), lr=vae_lr)
        self.prob_mean, self.prob_std = (), ()
        # Iteration count up to which the offline-critic blend is applied.
        self.offline_time = 1e9
        self.count = 0
        self.vae_batch_size = 64
        self.vae_step = 0
        self.total_it = 0
        # Threshold below which a VAE score is zeroed.
        self.pm = pm
        self.device = device

        self._gamma = gamma
        self._tau = tau
        self._awac_lambda = awac_lambda
        self._exp_adv_max = exp_adv_max

    def _actor_loss(
        self,
        states: torch.Tensor,
        actions: torch.Tensor,
    ) -> torch.Tensor:
        """Advantage-weighted negative log-likelihood of the dataset actions.

        Weights are ``exp(adv / awac_lambda)`` clipped at ``exp_adv_max``, with
        ``adv = min-Q(s, a) - min-Q(s, pi(s))``.
        """
        with torch.no_grad():
            pi_action, _ = self._actor(states)
            v = torch.min(
                self._critic_1(states, pi_action), self._critic_2(states, pi_action)
            )

            q = torch.min(
                self._critic_1(states, actions), self._critic_2(states, actions)
            )
            adv = q - v
            weights = torch.clamp_max(
                torch.exp(adv / self._awac_lambda), self._exp_adv_max
            )

        action_log_prob = self._actor.log_prob(states, actions)
        loss = (-action_log_prob * weights).mean()
        return loss

    def _critic_loss(
        self,
        states: torch.Tensor,
        actions: torch.Tensor,
        rewards: torch.Tensor,
        dones: torch.Tensor,
        next_states: torch.Tensor,
    ):
        """Twin-critic TD loss, optionally blending in the offline target critics.

        While ``total_it <= offline_time`` the bootstrap value becomes
        ``(1 - prob) * q_next + prob * off_q_next``. Here ``prob`` is the
        two-sided normal tail probability of the VAE encoder mean under
        ``N(prob_mean[0], prob_mean[1])``. It is zeroed per latent dimension
        below ``pm`` and then averaged over dimensions.

        Returns:
            (loss, prob): summed MSE of both critics and the blend weights
            (zeros once the offline phase is over).
        """
        prob = torch.zeros(rewards.shape[0], device=rewards.device)
        with torch.no_grad():
            next_actions, _ = self._actor(next_states)
            q_next = torch.min(
                self._target_critic_1(next_states, next_actions),
                self._target_critic_2(next_states, next_actions),
            )
            if self.total_it <= self.offline_time:
                off_q_next = torch.min(
                    self.off_target_critic_1(next_states, next_actions),
                    self.off_target_critic_2(next_states, next_actions),
                )
                z, mean, std = self.vae.encoder(states, actions, next_states)
                prob = stats.norm.cdf((self.prob_mean[0] - abs(mean.cpu() - self.prob_mean[0])), self.prob_mean[0],
                                      self.prob_mean[1]) * 2
                prob = torch.tensor(prob, dtype=torch.float32).to(self.device)
                zeros = torch.zeros_like(prob)
                # Threshold per latent dimension, then average.
                # NOTE(review): get_weights averages first and thresholds after.
                prob = torch.where(prob < self.pm, zeros, prob)
                prob = torch.mean(prob, dim=1).unsqueeze(1)
                q_next = (1 - prob) * q_next + prob * off_q_next
            q_target = rewards + self._gamma * (1.0 - dones) * q_next

        q1 = self._critic_1(states, actions)
        q2 = self._critic_2(states, actions)

        q1_loss = nn.functional.mse_loss(q1, q_target)
        q2_loss = nn.functional.mse_loss(q2, q_target)
        loss = q1_loss + q2_loss
        return loss, prob

    def _update_critic(
        self,
        states: torch.Tensor,
        actions: torch.Tensor,
        rewards: torch.Tensor,
        dones: torch.Tensor,
        next_states: torch.Tensor,
    ):
        """One optimizer step on both critics; returns (loss value, blend weights)."""
        loss, prob = self._critic_loss(states, actions, rewards, dones, next_states)
        self._critic_1_optimizer.zero_grad()
        self._critic_2_optimizer.zero_grad()
        loss.backward()
        self._critic_1_optimizer.step()
        self._critic_2_optimizer.step()
        return loss.item(), prob

    def _update_actor(self, states, actions):
        """One optimizer step on the actor; returns the loss value."""
        loss = self._actor_loss(states, actions)
        self._actor_optimizer.zero_grad()
        loss.backward()
        self._actor_optimizer.step()
        return loss.item()

    def update(self, batch: TensorBatch) -> Dict[str, float]:
        """Run one critic step, one actor step and a soft target update.

        Args:
            batch: [states, actions, rewards, next_states, dones].

        Returns:
            Dict with "critic_loss", "actor_loss" and the mean blend weight "prob".
        """
        states, actions, rewards, next_states, dones = batch
        self.total_it += 1
        critic_loss, prob = self._update_critic(states, actions, rewards, dones, next_states)
        actor_loss = self._update_actor(states, actions)

        soft_update(self._target_critic_1, self._critic_1, self._tau)
        soft_update(self._target_critic_2, self._critic_2, self._tau)

        result = {"critic_loss": critic_loss, "actor_loss": actor_loss, 'prob': prob.mean().item()}
        return result

    def state_dict(self) -> Dict[str, Any]:
        """State dicts of the actor and both critics."""
        return {
            "actor": self._actor.state_dict(),
            "critic_1": self._critic_1.state_dict(),
            "critic_2": self._critic_2.state_dict(),
        }

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Load actor/critics and reset both offline and online targets to the critics."""
        self._actor.load_state_dict(state_dict["actor"])
        self._critic_1.load_state_dict(state_dict["critic_1"])
        self._critic_2.load_state_dict(state_dict["critic_2"])

        self.off_target_critic_1 = copy.deepcopy(self._critic_1)
        self.off_target_critic_2 = copy.deepcopy(self._critic_2)

        self._target_critic_1 = copy.deepcopy(self._critic_1)
        self._target_critic_2 = copy.deepcopy(self._critic_2)

    @torch.no_grad()
    def get_weights(self,
                    observations,
                    actions,
                    rewards,
                    next_observations,
                    dones,
                    ):
        """Score a transition for the replay buffer's weight table.

        Returns:
            ``torch.tensor([prob, offline_q_loss, online_q_loss])``: the VAE
            score (averaged, then zeroed below ``pm``), and the summed MSE of
            both critics against the offline and the online next-state target
            values. ``rewards`` and ``dones`` are accepted but unused.

        NOTE(review): building the result with ``torch.tensor([...])`` only
        works while ``prob`` holds a single element, i.e. a batch of one
        transition (as ``ReplayBuffer.initialize`` passes).
        """
        z, vae_mean, std = self.vae.encoder(observations, actions, next_observations)
        prob = stats.norm.cdf((self.prob_mean[0] - abs(vae_mean.cpu() - self.prob_mean[0])), self.prob_mean[0],
                              self.prob_mean[1]) * 2
        prob = torch.tensor(prob, dtype=torch.float32).to(self.device)
        prob = torch.mean(prob, dim=1).unsqueeze(1)
        zeros = torch.zeros_like(prob)
        prob = torch.where(prob < self.pm, zeros, prob)

        next_actions, _ = self._actor(next_observations)
        q_next = torch.min(
            self._target_critic_1(next_observations, next_actions),
            self._target_critic_2(next_observations, next_actions),
        )
        off_q_next = torch.min(
            self.off_target_critic_1(next_observations, next_actions),
            self.off_target_critic_2(next_observations, next_actions),
        )
        q1 = self._critic_1(observations, actions)
        q2 = self._critic_2(observations, actions)

        online_q_loss_1 = torch.nn.functional.mse_loss(q1, q_next)
        online_q_loss_2 = torch.nn.functional.mse_loss(q2, q_next)
        online_q_loss = online_q_loss_1 + online_q_loss_2

        offline_q_loss_1 = torch.nn.functional.mse_loss(q1, off_q_next)
        offline_q_loss_2 = torch.nn.functional.mse_loss(q2, off_q_next)
        offline_q_loss = offline_q_loss_1 + offline_q_loss_2

        return torch.tensor([prob, offline_q_loss, online_q_loss])

    def vae_loss(self, obs, act, next_observations, t):
        """One VAE optimizer step: next-observation MSE + 20 * KL(q(z) || N(0, I)).

        ``t`` is accepted but unused. Returns (loss tensor, stats dict).
        """
        recon, mean, std = self.vae(obs, act, next_observations)
        recon_loss = nn.functional.mse_loss(recon, next_observations)
        KL_loss = - 0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()

        loss_vae = recon_loss + 20 * KL_loss

        self.vae_optim.zero_grad()
        loss_vae.backward()
        self.vae_optim.step()
        stats_vae = {"loss/loss_vae": loss_vae.item(),
                     "loss/KL_loss": KL_loss.item(),
                     "loss/recon_loss": recon_loss.item(),
                     }
        print(stats_vae)
        return loss_vae, stats_vae

    def train_one_step(self, observations, next_observations, actions, rewards, done, t):
        """Update the VAE on one batch and return its stats dict."""
        loss_vae, log_dict = self.vae_loss(observations, actions, next_observations, t)

        return log_dict

    def update_vae(self, buffer):
        """Snapshot the offline critics and retrain the VAE on the buffer's OOD samples.

        Runs ``n // vae_batch_size`` steps on random minibatches (with
        replacement) drawn from ``buffer.filter_resolved_samples()``, logging
        each step to wandb. Skips training when fewer than ``vae_batch_size``
        samples are available.
        """
        batch = buffer.filter_resolved_samples()
        batch = [b.to(self.device) for b in batch]
        (
            observations,
            actions,
            rewards,
            next_observations,
            dones,
        ) = batch
        # A single sample comes back fully squeezed; restore the batch dim.
        if len(observations.shape) == 1:
            observations = observations.unsqueeze(0)
            actions = actions.unsqueeze(0)
            rewards = rewards.unsqueeze(0)
            next_observations = next_observations.unsqueeze(0)
            dones = dones.unsqueeze(0)
        self.off_target_critic_1 = copy.deepcopy(self._critic_1)
        self.off_target_critic_2 = copy.deepcopy(self._critic_2)
        print('observation shape', observations.shape)
        print('action shape', actions.shape)
        print('next observation shape', next_observations.shape)
        if observations.shape[0] >= self.vae_batch_size:
            for epoch in range(observations.shape[0] // self.vae_batch_size):
                indices = np.random.randint(0, observations.shape[0], size=self.vae_batch_size)
                log_dict = self.train_one_step(observations[indices].to(self.device),
                                               next_observations[indices].to(self.device),
                                               actions[indices].to(self.device),
                                               rewards[indices].to(self.device),
                                               dones[indices].to(self.device),
                                               0)
                log_dict['vae_step'] = self.vae_step
                wandb.log(log_dict, commit=False)
                self.vae_step += 1
        else:
            print("No extra OOD samples detected.")

    def update_vae_parameters(self, observations, actions, next_observations):
        """Fit 1-D normal distributions to the VAE latent means and stds.

        Each transition is encoded on its own with ``self.vae.encoder``. All
        latent-mean entries are pooled into one sample and all latent-std
        entries into another, and each pool is fitted with ``fitter``'s
        ``'norm'`` distribution.

        Args:
            observations, actions, next_observations: Row-aligned batches
                (iterated along dim 0; presumably torch tensors).

        Returns:
            (mean_dict, std_dict): ``fitted_param['norm']`` for the pooled
            latent means and the pooled latent stds.
        """
        import fitter  # third-party; only needed for this fitting step

        data_mean_list = []
        data_std_list = []
        with torch.no_grad():
            # Distinct per-row names so the input batches are not overwritten
            # mid-loop; the VAE in this file exposes ``encoder`` -> (z, mean, std).
            for obs, act, next_obs in zip(observations, actions, next_observations):
                _, mean, std = self.vae.encoder(
                    obs.unsqueeze(0).to(self.device),
                    act.unsqueeze(0).to(self.device),
                    next_obs.unsqueeze(0).to(self.device),
                )
                # Plain floats keep the fitter input a flat numeric list.
                data_mean_list.extend(mean.cpu().reshape(-1).tolist())
                data_std_list.extend(std.cpu().reshape(-1).tolist())

        mean_f = fitter.Fitter(data_mean_list, distributions=['norm'], timeout=10000)
        mean_f.fit()
        mean_dict = mean_f.fitted_param['norm']

        std_f = fitter.Fitter(data_std_list, distributions=['norm'], timeout=10000)
        std_f.fit()
        std_dict = std_f.fitted_param['norm']
        print(mean_dict, std_dict)

        return mean_dict, std_dict

def set_seed(
    seed: int, env: Optional[gym.Env] = None, deterministic_torch: bool = False
):
    """Seed every RNG the run relies on.

    Seeds the optional environment first (via ``set_env_seed``). Then sets
    ``PYTHONHASHSEED``, seeds NumPy, ``random`` and torch, and finally
    switches torch's deterministic-algorithms mode to ``deterministic_torch``.
    """
    if env is not None:
        set_env_seed(env, seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)
    torch.use_deterministic_algorithms(deterministic_torch)


def compute_mean_std(states: np.ndarray, eps: float) -> Tuple[np.ndarray, np.ndarray]:
    """Per-feature mean and standard deviation over axis 0; ``eps`` is added to the std."""
    return states.mean(axis=0), states.std(axis=0) + eps


def normalize_states(states: np.ndarray, mean: np.ndarray, std: np.ndarray):
    """Standardize ``states`` as ``(states - mean) / std`` (NumPy broadcasting)."""
    centered = states - mean
    return centered / std


def wrap_env(
    env: gym.Env,
    state_mean: Union[np.ndarray, float] = 0.0,
    state_std: Union[np.ndarray, float] = 1.0,
) -> gym.Env:
    """Wrap ``env`` so each observation is returned as ``(obs - state_mean) / state_std``."""
    return gym.wrappers.TransformObservation(
        env, lambda state: (state - state_mean) / state_std
    )


def is_goal_reached(reward: float, info: Dict) -> bool:
    """Report goal success for one step.

    Uses ``info["goal_achieved"]`` when the env provides it; otherwise treats
    any positive reward as reaching the target.
    """
    return info["goal_achieved"] if "goal_achieved" in info else reward > 0


@torch.no_grad()
def eval_actor(
    env: gym.Env, actor: Actor, device: str, n_episodes: int, seed: int
) -> Tuple[np.ndarray, np.ndarray]:
    """Roll out ``actor`` for ``n_episodes`` episodes in ``env``.

    The actor is switched to eval mode for the rollouts and back to train
    mode afterwards.

    Returns:
        (per-episode returns as an array, mean success rate). Success is only
        meaningful for environments with a goal.
    """
    env.seed(seed)
    actor.eval()
    episode_rewards = []
    successes = []
    for _ in range(n_episodes):
        obs = env.reset()
        done = False
        total = 0.0
        reached = False
        while not done:
            obs, reward, done, info = env.step(actor.act(obs, device))
            total += reward
            # Only query the goal check until it first succeeds.
            reached = reached or is_goal_reached(reward, info)
        successes.append(float(reached))
        episode_rewards.append(total)

    actor.train()
    return np.asarray(episode_rewards), np.mean(successes)


def return_reward_range(dataset: Dict, max_episode_steps: int) -> Tuple[float, float]:
    """Return (min, max) undiscounted episode return in a D4RL-style dataset.

    An episode ends on a terminal flag or after ``max_episode_steps`` steps. A
    trailing unfinished episode counts toward the step total but not toward
    the returns. ``min``/``max`` raise ``ValueError`` if no episode completes.
    """
    returns = []
    lengths = []
    running_return, running_len = 0.0, 0
    for reward, terminal in zip(dataset["rewards"], dataset["terminals"]):
        running_return += float(reward)
        running_len += 1
        if not (terminal or running_len == max_episode_steps):
            continue
        returns.append(running_return)
        lengths.append(running_len)
        running_return, running_len = 0.0, 0
    lengths.append(running_len)  # but still keep track of number of steps
    assert sum(lengths) == len(dataset["rewards"])
    return min(returns), max(returns)


def modify_reward(dataset: Dict, env_name: str, max_episode_steps: int = 1000) -> Dict:
    """Rescale or shift ``dataset["rewards"]`` in place, depending on the env family.

    - halfcheetah / hopper / walker2d: rewards are divided by the spread of
      episode returns and multiplied by ``max_episode_steps``. The statistics
      used are returned so the same transform can be applied online.
    - antmaze: 1.0 is subtracted from every reward.
    - anything else: rewards are left unchanged.

    Returns:
        The scaling statistics for the locomotion case, otherwise ``{}``.
    """
    locomotion_envs = ("halfcheetah", "hopper", "walker2d")
    if any(name in env_name for name in locomotion_envs):
        lowest, highest = return_reward_range(dataset, max_episode_steps)
        dataset["rewards"] /= highest - lowest
        dataset["rewards"] *= max_episode_steps
        return {
            "max_ret": highest,
            "min_ret": lowest,
            "max_episode_steps": max_episode_steps,
        }
    if "antmaze" in env_name:
        dataset["rewards"] -= 1.0
    return {}


def modify_reward_online(reward: float, env_name: str, **kwargs) -> float:
    """Apply ``modify_reward``'s per-env transform to a single online reward.

    For locomotion envs, ``kwargs`` must hold the ``max_ret``, ``min_ret`` and
    ``max_episode_steps`` values that ``modify_reward`` returned.
    """
    if any(name in env_name for name in ("halfcheetah", "hopper", "walker2d")):
        return_span = kwargs["max_ret"] - kwargs["min_ret"]
        return reward / return_span * kwargs["max_episode_steps"]
    if "antmaze" in env_name:
        return reward - 1.0
    return reward


def wandb_init(config) -> None:
    """Start a wandb run named from ``config`` (project / group / name) and save it."""
    run_settings = wandb.Settings(start_method="thread")
    wandb.init(
        project=config.project,
        group=config.group,
        name=config.name,
        settings=run_settings,
    )
    wandb.run.save()


def _extra_human_dataset(env_name: str) -> Optional[str]:
    """Return the D4RL human-demonstration env id to merge for Adroit tasks.

    Tasks are tested in order and the first substring match wins
    (``pen`` -> ``pen-human-v1``, ``door`` -> ``door-human-v1``,
    ``relocate`` -> ``relocate-human-v1``). Any other env gives ``None``.
    """
    for task in ("pen", "door", "relocate"):
        if task in env_name:
            return f"{task}-human-v1"
    return None


def train(config):
    """Run online AWAC fine-tuning on ``config.env`` and log progress to wandb.

    Steps:
    1. Build the train and eval envs and seed them.
    2. Load the D4RL dataset, optionally reward-normalized. Standardize
       observations with the dataset's statistics and wrap both envs so live
       observations get the same transform.
    3. Fill an offline buffer with the dataset, plus the human demos for
       pen/door/relocate.
    4. Build the actor, twin critics and AWAC agent, optionally restoring a
       saved policy and VAE.
    5. Interact for ``config.online_iterations`` steps with one AWAC update per
       step, evaluating every ``config.eval_frequency`` steps.

    Args:
        config: experiment config (see ``get_awac_train_configs``).
    """
    # Online steps from which periodic checkpoints are written.
    checkpoint_start_step = 490000

    env = gym.make(config.env)
    eval_env = gym.make(config.env)

    is_env_with_goal = config.env.startswith(ENVS_WITH_GOAL)

    max_steps = env._max_episode_steps

    set_seed(config.seed, env, deterministic_torch=config.deterministic_torch)
    set_env_seed(eval_env, config.eval_seed)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    dataset = d4rl.qlearning_dataset(env)

    reward_mod_dict = {}
    if config.normalize_reward:
        reward_mod_dict = modify_reward(dataset, config.env)

    # Standardize offline observations with the dataset's own statistics;
    # wrap_env applies the same transform to live env observations.
    state_mean, state_std = compute_mean_std(dataset["observations"], eps=1e-3)
    dataset["observations"] = normalize_states(
        dataset["observations"], state_mean, state_std
    )
    dataset["next_observations"] = normalize_states(
        dataset["next_observations"], state_mean, state_std
    )
    env = wrap_env(env, state_mean=state_mean, state_std=state_std)
    eval_env = wrap_env(eval_env, state_mean=state_mean, state_std=state_std)
    # The trailing True is ReplayBuffer's ``incre`` flag, which also allocates
    # per-transition weights.
    replay_buffer = ReplayBuffer(
        state_dim,
        action_dim,
        config.buffer_size,
        config.device,
        True,
    )

    offline_buffer = ReplayBuffer(
        state_dim,
        action_dim,
        2_000_000,
        config.device,
    )
    offline_buffer.load_d4rl_dataset(dataset)
    human_env_name = _extra_human_dataset(config.env)
    if human_env_name is not None:
        # NOTE(review): these human-demo observations are not passed through
        # normalize_states, unlike the main dataset above. Confirm whether
        # add_d4rl_dataset normalizes them or whether the mix is intended.
        human_dataset = d4rl.qlearning_dataset(gym.make(human_env_name))
        offline_buffer.add_d4rl_dataset(human_dataset)

    max_action = float(env.action_space.high[0])

    if config.checkpoints_path is not None:
        print(f"Checkpoints path: {config.checkpoints_path}")
        os.makedirs(config.checkpoints_path, exist_ok=True)

    actor_critic_kwargs = {
        "state_dim": state_dim,
        "action_dim": action_dim,
        "hidden_dim": config.hidden_dim,
    }

    actor = Actor(**actor_critic_kwargs)
    actor.to(config.device)
    actor_optimizer = torch.optim.Adam(actor.parameters(), lr=config.learning_rate)
    critic_1 = Critic(**actor_critic_kwargs)
    critic_2 = Critic(**actor_critic_kwargs)
    critic_1.to(config.device)
    critic_2.to(config.device)
    critic_1_optimizer = torch.optim.Adam(critic_1.parameters(), lr=config.learning_rate)
    critic_2_optimizer = torch.optim.Adam(critic_2.parameters(), lr=config.learning_rate)

    awac = AdvantageWeightedActorCritic(
        state_dim=env.observation_space.shape[0],
        action_dim=env.action_space.shape[0],
        max_action=max_action,
        actor=actor,
        actor_optimizer=actor_optimizer,
        critic_1=critic_1,
        critic_1_optimizer=critic_1_optimizer,
        critic_2=critic_2,
        critic_2_optimizer=critic_2_optimizer,
        gamma=config.gamma,
        vae_lr=config.vae_lr,
        tau=config.tau,
        awac_lambda=config.awac_lambda,
        device=config.device,
        vae_hidden_dim=config.vae_hidden_dim,
    )

    if config.load_model != "":
        policy_file = Path(config.load_model)
        awac.load_state_dict(torch.load(policy_file))
        actor = awac._actor
        # Unpickle the VAE checkpoint exactly once. If the checkpoint stores the
        # optimizer object, two separate torch.load calls would produce two
        # independent VAE copies, and vae_optim would step parameters that
        # awac.vae does not own.
        vae_checkpoint = torch.load(config.vae_file)
        awac.vae = vae_checkpoint['vae']
        awac.vae_optim = vae_checkpoint['vae_optim']

    awac.prob_mean, awac.prob_std = awac.vae.judge_normal_distribution(offline_buffer)
    print('vae information', awac.prob_mean, awac.prob_std, '\n\n')

    replay_buffer.initialize(env, awac, config, max_action, max_steps)

    full_eval_scores, full_normalized_eval_scores = [], []
    state, done = env.reset(), False
    episode_step = 0
    episode_return = 0
    goal_achieved = False

    eval_successes = []
    train_successes = []

    print("Online tuning")
    for t in trange(int(config.online_iterations), ncols=80):
        online_log = {}
        episode_step += 1
        # Acting needs no autograd graph: only the returned action is used,
        # as a NumPy array.
        with torch.no_grad():
            action, _ = actor(
                torch.tensor(
                    state.reshape(1, -1), device=config.device, dtype=torch.float32
                )
            )
        action = action.cpu().data.numpy().flatten()
        next_state, reward, done, env_infos = env.step(action)

        if not goal_achieved:
            goal_achieved = is_goal_reached(reward, env_infos)
        episode_return += reward
        real_done = False  # Episode can timeout which is different from done
        if done and episode_step < max_steps:
            real_done = True

        if config.normalize_reward:
            reward = modify_reward_online(reward, config.env, **reward_mod_dict)
        weights = awac.get_weights(
            to_tensor(state, config.device).unsqueeze(0),
            to_tensor(action, config.device).unsqueeze(0),
            reward,
            to_tensor(next_state, config.device).unsqueeze(0),
            to_tensor(np.array(real_done), config.device).unsqueeze(0),
        )
        replay_buffer.add_transition(state, action, reward, next_state, real_done, weights)
        state = next_state

        if replay_buffer.is_update_intervals():
            # The VAE is refreshed only once awac.count has reached 4.
            if awac.count >= 4:
                print("To update vae")
                print(f'Model updated on epoch {t}', "\n"*2)
                awac.update_vae(replay_buffer)
            awac.count += 1
            replay_buffer.update_left_pointer()

        if done:
            state, done = env.reset(), False
            # Valid only for envs with goal, e.g. AntMaze, Adroit
            if is_env_with_goal:
                train_successes.append(goal_achieved)
                online_log["train/regret"] = np.mean(1 - np.array(train_successes))
                online_log["train/is_success"] = float(goal_achieved)
            online_log["train/episode_return"] = episode_return
            normalized_return = eval_env.get_normalized_score(episode_return)
            online_log["train/d4rl_normalized_episode_return"] = (
                normalized_return * 100.0
            )
            online_log["train/episode_length"] = episode_step
            episode_return = 0
            episode_step = 0
            goal_achieved = False

        batch = replay_buffer.sample(config.batch_size)
        batch = [b.to(config.device) for b in batch]
        update_result = awac.update(batch)
        update_result["online_iter"] = t
        update_result.update(online_log)
        wandb.log(update_result, step=t)
        if (t + 1) % config.eval_frequency == 0:
            eval_scores, success_rate = eval_actor(
                eval_env, actor, config.device, config.n_test_episodes, config.seed
            )
            eval_log = {}

            full_eval_scores.append(eval_scores)
            wandb.log({"eval/eval_score": eval_scores.mean()}, step=t)
            if hasattr(eval_env, "get_normalized_score"):
                normalized = eval_env.get_normalized_score(np.mean(eval_scores))
                # Valid only for envs with goal, e.g. AntMaze, Adroit.
                # NOTE(review): t counts online steps only in this loop, so the
                # offline_iterations gate may delay success logging. Also,
                # eval/regret is computed from train_successes, not
                # eval_successes. Confirm both are intended.
                if t >= config.offline_iterations and is_env_with_goal:
                    eval_successes.append(success_rate)
                    eval_log["eval/regret"] = np.mean(1 - np.array(train_successes))
                    eval_log["eval/success_rate"] = success_rate
                normalized_eval_scores = normalized * 100.0
                full_normalized_eval_scores.append(normalized_eval_scores)
                eval_log["eval/d4rl_normalized_score"] = normalized_eval_scores
                wandb.log(eval_log, step=t)
            if config.checkpoints_path and t >= checkpoint_start_step:
                torch.save(
                    awac.state_dict(),
                    os.path.join(config.checkpoints_path, f"checkpoint_{t}.pt"),
                )


if __name__ == "__main__":
    # cuBLAS workspace setting that PyTorch documents as required for
    # deterministic CUDA matmuls; set before any training run starts.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    env = "halfcheetah-medium-replay-v2"
    # One full train + wandb run per seed 0..4.
    for seed in range(5):
        config = get_awac_train_configs(env, "algo-AWAC", f"incremental-{seed}", seed)
        wandb_init(config)
        train(config)
        wandb.finish()
        sys.stdout.flush()
