# source: https://github.com/gwthomas/IQL-PyTorch
# https://arxiv.org/pdf/2110.06169.pdf
import copy
import os
import random
import uuid
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import math

import d4rl
import gym
import numpy as np
# import pyrallis
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch.distributions import Normal
from torch.optim.lr_scheduler import CosineAnnealingLR
from scipy import stats
import sys
sys.path.append("../..")
from configs.configs import get_iql_train_configs

# A batch of transitions represented as a list of tensors
# (ReplayBuffer.sample returns [states, actions, rewards, next_states, dones]).
TensorBatch = List[torch.Tensor]


# Upper bound applied to exp(beta * advantage) in the policy update.
EXP_ADV_MAX = 100.0
# Clamp range for the Gaussian policy's learned log standard deviation.
LOG_STD_MIN = -20.0
LOG_STD_MAX = 2.0
# Name fragments of goal-based environments; not referenced in this part of the file.
ENVS_WITH_GOAL = ("antmaze", "pen", "door", "hammer", "relocate")
# Short tags for dataset quality levels; used to build TrainConfig.group.
ENVS_ABBR = {'medium-replay': 'mr', 'medium': 'm', 'medium-expert': 'me'}


class TrainConfig:
    """Default hyper-parameters and paths for the training script.

    This is a plain class (it is not decorated with ``@dataclass``), so every
    field below is a class attribute. The f-string defaults (``env``,
    ``checkpoints_path``, ``load_model``, ``vae_file``, ``group``, ``name``)
    are evaluated once, when the class body runs; overriding ``env_sim_name``
    or ``level`` afterwards does not update the values derived from them.
    """

    # Experiment
    device: str = "cuda"
    env_sim_name: str = "halfcheetah"
    level: str = "medium"
    env: str = f"{env_sim_name}-{level}-v2"  # OpenAI gym environment name
    seed: int = 0  # Sets Gym, PyTorch and Numpy seeds
    eval_seed: int = 0  # Eval environment seed
    eval_freq: int = int(2e3)  # How often (time steps) we evaluate
    save_freq: int = int(1e4)  # presumably how often (time steps) checkpoints are written -- TODO confirm
    n_episodes: int = 100  # How many episodes run during evaluation
    offline_iterations: int = int(300000)  # Number of offline updates
    online_iterations: int = int(500000)  # Number of online updates
    checkpoints_path: Optional[str] = f'checkpoints/{env}'  # Save path
    load_model: str = f"../offline/checkpoints/{env}/checkpoint_999999.pt"  # Model load file name, "" doesn't load
    vae_file: str = f"../../vae_checkpoints/{env}/checkpoint_999999.pt"  # presumably the VAE checkpoint to load -- TODO confirm
    # IQL
    actor_dropout: float = 0.0  # Dropout in actor network
    buffer_size: int = 50000  # Replay buffer size
    vae_hidden_dim: int = 400  # hidden dimension of vae network
    batch_size: int = 256  # Batch size for all networks
    discount: float = 0.99  # Discount factor
    tau: float = 0.005  # Target network update rate
    beta: float = 3.0  # Inverse temperature. Small beta -> BC, big beta -> maximizing Q
    iql_tau: float = 0.7  # Coefficient for asymmetric loss
    expl_noise: float = 0.03  # Std of Gaussian exploration noise
    noise_clip: float = 0.5  # Range to clip noise
    iql_deterministic: bool = False  # Use deterministic actor
    use_off_policy: bool = False  # presumably forwarded to ImplicitQLearning(use_off_policy=...) -- TODO confirm
    normalize: bool = True  # Normalize states
    normalize_reward: bool = False  # Normalize reward
    vf_lr: float = 3e-4  # V function learning rate
    qf_lr: float = 3e-4  # Critic learning rate
    actor_lr: float = 3e-4  # Actor learning rate
    # Wandb logging
    project: str = "CORL"
    group: str = f"{ENVS_ABBR[level]}-{env_sim_name}-increment"
    name: str = f"best-0-{env}"

    # Disabled hook, kept for reference: it would only run if this class were a dataclass.
    # def __post_init__(self):
    #     self.name = f"{self.name}-{self.env}"
    #     if self.checkpoints_path is not None:
    #         self.checkpoints_path = os.path.join(self.checkpoints_path, self.name)


def soft_update(target: nn.Module, source: nn.Module, tau: float):
    """Blend ``source`` parameters into ``target`` in place (Polyak averaging).

    Each target parameter becomes ``(1 - tau) * target + tau * source``.
    Parameters are paired in the order ``parameters()`` yields them.
    """
    keep = 1 - tau
    for dst, src in zip(target.parameters(), source.parameters()):
        blended = keep * dst.data + tau * src.data
        dst.data.copy_(blended)


def to_tensor(data, device) -> torch.Tensor:
    """Convert ``data`` to a float32 tensor and move it to ``device``."""
    as_float = torch.tensor(data, dtype=torch.float32)
    return as_float.to(device)


def compute_mean_std(states: np.ndarray, eps: float) -> Tuple[np.ndarray, np.ndarray]:
    """Return the mean and ``eps``-padded standard deviation along axis 0.

    ``eps`` is added to the standard deviation so that later division by it
    is safe for constant features.
    """
    return states.mean(0), states.std(0) + eps


def normalize_states(states: np.ndarray, mean: np.ndarray, std: np.ndarray):
    """Standardise ``states`` as ``(states - mean) / std``."""
    centered = states - mean
    return centered / std


def wrap_env(
    env: gym.Env,
    state_mean: Union[np.ndarray, float] = 0.0,
    state_std: Union[np.ndarray, float] = 1.0,
    reward_scale: float = 1.0,
) -> gym.Env:
    """Wrap ``env`` so observations are standardised and rewards optionally scaled.

    Observations are always passed through ``(obs - state_mean) / state_std``.
    A reward wrapper is added only when ``reward_scale`` differs from 1.0.
    """

    def _standardize(obs):
        # Any epsilon is expected to be folded into state_std already.
        centered = obs - state_mean
        return centered / state_std

    def _rescale(raw_reward):
        # Note: the reward is multiplied by the scale, not divided.
        return reward_scale * raw_reward

    wrapped = gym.wrappers.TransformObservation(env, _standardize)
    if reward_scale == 1.0:
        return wrapped
    return gym.wrappers.TransformReward(wrapped, _rescale)


class ReplayBuffer:
    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        buffer_size: int,
        device: str = "cpu",
        incre: bool = False,
    ):
        """Pre-allocate fixed-size float32 storage for transitions.

        Args:
            state_dim: Width of each state row.
            action_dim: Width of each action row.
            buffer_size: Number of transitions the buffer can hold.
            device: Torch device the storage tensors live on.
            incre: When True, an extra ``(buffer_size, 3)`` weight table
                (initialised to ones) is allocated alongside the transitions.
        """
        self._buffer_size = buffer_size
        self._pointer = 0
        self._size = 0
        # __len__ reports this value; initialise it so len() works on a
        # buffer that has not loaded a dataset yet.
        self.n_transitions = 0
        self.init_size = 2000
        self.incre = incre

        self._states = torch.zeros(
            (buffer_size, state_dim), dtype=torch.float32, device=device
        )
        self._actions = torch.zeros(
            (buffer_size, action_dim), dtype=torch.float32, device=device
        )
        self._rewards = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
        self._next_states = torch.zeros(
            (buffer_size, state_dim), dtype=torch.float32, device=device
        )
        self._dones = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
        if self.incre:
            self._weights = torch.ones((buffer_size, 3), dtype=torch.float32, device=device)
        self._device = device
        self.eps = 1e-3
        # Window bookkeeping used by the filter_* methods.
        self.ood_pointer = 0
        self.ood_size = 10000
        self._update_size = 500

    def _to_tensor(self, data: np.ndarray) -> torch.Tensor:
        return torch.tensor(data, dtype=torch.float32).to(self._device)

    # Loads data in d4rl format, i.e. from Dict[str, np.array].
    def load_d4rl_dataset(self, data: Dict[str, np.ndarray]):
        if self._size != 0:
            raise ValueError("Trying to load data into non-empty replay buffer")
        n_transitions = data["observations"].shape[0]
        self.n_transitions = n_transitions
        if n_transitions > self._buffer_size:
            raise ValueError(
                "Replay buffer is smaller than the dataset you are trying to load!"
            )
        self._states[:n_transitions] = self._to_tensor(data["observations"])
        self._actions[:n_transitions] = self._to_tensor(data["actions"])
        self._rewards[:n_transitions] = self._to_tensor(data["rewards"][..., None])
        self._next_states[:n_transitions] = self._to_tensor(data["next_observations"])
        self._dones[:n_transitions] = self._to_tensor(data["terminals"][..., None])
        self._size += n_transitions
        self._pointer = min(self._size, n_transitions)

        print(f"Dataset size: {n_transitions}")

    # Add extra data in d4rl format, i.e. from Dict[str, np.array].
    def add_d4rl_dataset(self, data: Dict[str, np.ndarray]):
        print(self._pointer)
        n_transitions = data["observations"].shape[0]
        if n_transitions + self._pointer > self._buffer_size:
            raise ValueError(
                "Replay buffer is smaller than the dataset you are trying to load!"
            )
        self._states[self._pointer:self._pointer + n_transitions] = self._to_tensor(data["observations"])
        self._actions[self._pointer:self._pointer + n_transitions] = self._to_tensor(data["actions"])
        self._rewards[self._pointer:self._pointer + n_transitions] = self._to_tensor(data["rewards"][..., None])
        self._next_states[self._pointer:self._pointer + n_transitions] = self._to_tensor(data["next_observations"])
        self._dones[self._pointer:self._pointer + n_transitions] = self._to_tensor(data["terminals"][..., None])
        self._size += n_transitions
        self._pointer = min(self._size, self._pointer + n_transitions)
        self.length = self._pointer + n_transitions
        print(f"Dataset size: {n_transitions}")
        print(self._pointer)
        print("Added success.\n")

    def filter_resolved_samples(self):
        """Pick transitions from the current window whose first weight is zero.

        The window covers rows ``[ood_pointer * ood_size, (ood_pointer + 1) *
        ood_size)``. Rows whose ``_weights[:, 0]`` equals 0 are gathered,
        ordered by ascending ``_weights[:, 2]``, and at most ``_update_size``
        of them are kept. The window pointer is then advanced.

        Requires the weight table, which is only allocated when the buffer was
        created with ``incre=True``.

        Returns:
            ``[states, actions, rewards, next_states, dones]``, each with
            singleton dimensions squeezed away.
        """
        try:
            ood_indices = (self._weights[self.ood_pointer * self.ood_size:(self.ood_pointer + 1) * self.ood_size, 0] == 0).nonzero()
        except ValueError:
            # NOTE(review): slicing and comparing a tensor is not expected to
            # raise ValueError (an out-of-range window just selects nothing);
            # confirm whether this handler is reachable.
            raise ValueError("OOD indices out of index")
        # nonzero() gives positions relative to the window; shift to absolute rows.
        ood_indices += self.ood_pointer * self.ood_size
        print(ood_indices.shape)
        # ood_indices has shape (k, 1), so indexing adds a singleton axis that
        # squeeze() removes again. Rewards and dones are left unsqueezed here.
        states = self._states[ood_indices, :].squeeze()
        actions = self._actions[ood_indices, :].squeeze()
        rewards = self._rewards[ood_indices, :]
        next_states = self._next_states[ood_indices, :].squeeze()
        dones = self._dones[ood_indices, :]
        weights = self._weights[ood_indices, :].squeeze()
        if len(weights.shape) == 1:
            # squeeze() collapsed the row axis as well; restore a leading axis
            # on every tensor so the row indexing below still works.
            weights = weights.unsqueeze(0)
            states = states.unsqueeze(0)
            actions = actions.unsqueeze(0)
            next_states = next_states.unsqueeze(0)
            rewards = rewards.unsqueeze(0)
            dones = dones.unsqueeze(0)

        # Ascending sort on the third weight column; keep the first _update_size rows.
        resolved_indices = torch.argsort(weights[:, 2])[:self._update_size]
        print(resolved_indices.shape)
        print('rewards and dones', rewards.shape, dones.shape)
        states = states[resolved_indices, :]
        actions = actions[resolved_indices, :]
        rewards = rewards[resolved_indices, :]
        next_states = next_states[resolved_indices, :]
        dones = dones[resolved_indices, :]

        self.ood_pointer += 1
        # Wraps only when the next window starts exactly at the buffer size.
        if self.ood_pointer * self.ood_size == self._buffer_size:
            print('ood index rounded')
            self.ood_pointer = 0

        return [states.squeeze(), actions.squeeze(), rewards.squeeze(), next_states.squeeze(), dones.squeeze()]

    def filter_in_distribution(self):
        try:
            in_indices = (self._weights[self.ood_pointer * self.ood_size:(self.ood_pointer + 1) * self.ood_size, 0] != 0).nonzero()
        except ValueError:
            raise ValueError("In distribution indices out of index")
        states = self._states[in_indices, :].squeeze()
        actions = self._actions[in_indices, :].squeeze()
        next_states = self._next_states[in_indices, :].squeeze()

        return [states.squeeze(), actions.squeeze(), next_states.squeeze()]

    def weighted_random_sample(self, weights, batch_size):
        this_weights = weights.cpu().squeeze().numpy()[:self._size]
        Ni = np.sum(this_weights) + self.eps
        w_i = this_weights / Ni
        result = np.random.choice(np.arange(0, self._size), size=batch_size, replace=True, p=w_i)

        return result

    def sample(self, batch_size: int) -> TensorBatch:
        if self.incre:
            indices = np.random.randint(0, self._size, size=batch_size)
        else:
            indices = np.random.randint(0, self._size, size=batch_size)
        states = self._states[indices]
        actions = self._actions[indices]
        rewards = self._rewards[indices]
        next_states = self._next_states[indices]
        dones = self._dones[indices]
        return [states, actions, rewards, next_states, dones]

    def add_transition(
        self,
        state: np.ndarray,
        action: np.ndarray,
        reward: float,
        next_state: np.ndarray,
        done: bool,
        weight: 0,
    ):
        # Use this method to add new data into the replay buffer during fine-tuning.
        self._states[self._pointer] = self._to_tensor(state)
        self._actions[self._pointer] = self._to_tensor(action)
        self._rewards[self._pointer] = self._to_tensor(reward)
        self._next_states[self._pointer] = self._to_tensor(next_state)
        self._dones[self._pointer] = self._to_tensor(done)
        self._weights[self._pointer] = self._to_tensor(weight)

        self._pointer = (self._pointer + 1) % self._buffer_size
        self._size = min(self._size + 1, self._buffer_size)

    # todo: how to improve this
    def initialize(self, env, trainer, config, max_action, max_steps):
        state, done = env.reset(), False
        while self._size < self.init_size:
            episode_step = 0
            while not done:
                episode_step += 1
                action = trainer.actor(
                    torch.tensor(
                        state.reshape(1, -1), device=config.device, dtype=torch.float32
                    )
                )
                if not config.iql_deterministic:
                    action = action.sample()
                else:
                    noise = (torch.randn_like(action) * config.expl_noise).clamp(
                        -config.noise_clip, config.noise_clip
                    )
                    action += noise
                action = torch.clamp(max_action * action, -max_action, max_action)
                action = action.cpu().data.numpy().flatten()
                next_state, reward, done, env_infos = env.step(action)
                real_done = False  # Episode can timeout which is different from done
                if done and episode_step < max_steps:
                    real_done = True
                weights = trainer.get_weights(self._to_tensor(state).unsqueeze(0), self._to_tensor(action).unsqueeze(0),
                                              reward, self._to_tensor(next_state).unsqueeze(0),
                                              self._to_tensor(np.array(real_done)).unsqueeze(0))
                self.add_transition(state, action, reward, next_state, done, weights)
                state = next_state
                if self._size == self.init_size:
                    break
            if done:
                state, done = env.reset(), False

    def __len__(self):
        """Return ``n_transitions``, the row count recorded by ``load_d4rl_dataset``."""
        loaded_rows = self.n_transitions
        return loaded_rows

    def __getitem__(self, indices):
        states = self._states[indices]
        actions = self._actions[indices]
        rewards = self._rewards[indices]
        next_states = self._next_states[indices]
        dones = self._dones[indices]
        return [states, actions, rewards, next_states, dones]


def set_env_seed(env: Optional[gym.Env], seed: int):
    """Seed the environment, then its action space, with ``seed``.

    Despite the ``Optional`` annotation, ``env`` must not be ``None`` here;
    ``set_seed`` checks for that before calling.
    """
    env.seed(seed)
    action_space = env.action_space
    action_space.seed(seed)


def set_seed(
    seed: int, env: Optional[gym.Env] = None, deterministic_torch: bool = False
):
    """Seed the environment (if given), ``PYTHONHASHSEED``, NumPy, ``random`` and PyTorch.

    ``deterministic_torch`` is forwarded to
    ``torch.use_deterministic_algorithms``.
    """
    if env is not None:
        set_env_seed(env, seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (np.random.seed, random.seed, torch.manual_seed):
        seeder(seed)
    torch.use_deterministic_algorithms(deterministic_torch)


def wandb_init(config) -> None:
    """Start a wandb run from ``config.project``, ``config.group`` and ``config.name``.

    The run is started with the "thread" start method and ``wandb.run.save()``
    is called right after initialisation.
    """
    run_kwargs = dict(
        project=config.project,
        group=config.group,
        name=config.name,
        settings=wandb.Settings(start_method="thread"),
    )
    wandb.init(**run_kwargs)
    wandb.run.save()


def is_goal_reached(reward: float, info: Dict) -> bool:
    """Report whether the goal was reached on this step.

    Uses ``info["goal_achieved"]`` when the key is present; otherwise assumes
    that a positive reward means the target was reached.
    """
    if "goal_achieved" not in info:
        return reward > 0
    return info["goal_achieved"]


@torch.no_grad()
def eval_actor(
    env: gym.Env, actor: nn.Module, device: str, n_episodes: int, seed: int
) -> Tuple[np.ndarray, np.ndarray]:
    """Run ``n_episodes`` evaluation episodes and summarise them.

    The actor is switched to eval mode for the roll-outs and back to train
    mode afterwards.

    Returns:
        A pair ``(episode_returns, success_rate)``: the per-episode summed
        rewards as an array, and the mean of the per-episode goal flags.
        The success rate is only meaningful for goal-based environments.
    """
    env.seed(seed)
    actor.eval()
    returns = []
    goal_flags = []
    for _ in range(n_episodes):
        obs = env.reset()
        finished = False
        total = 0.0
        reached = False
        while not finished:
            obs, reward, finished, info = env.step(actor.act(obs, device))
            total += reward
            # Once the goal was reached in an episode, keep the flag set.
            if not reached:
                reached = is_goal_reached(reward, info)
        goal_flags.append(float(reached))
        returns.append(total)

    actor.train()
    return np.asarray(returns), np.mean(goal_flags)


def return_reward_range(dataset: Dict, max_episode_steps: int) -> Tuple[float, float]:
    """Return the smallest and largest episode return found in ``dataset``.

    Episodes are split where ``terminals`` is truthy or after
    ``max_episode_steps`` steps. A trailing unfinished episode contributes to
    the step count check but not to the returns.
    """
    episode_returns = []
    episode_lengths = []
    running_return = 0.0
    running_length = 0
    for reward, terminal in zip(dataset["rewards"], dataset["terminals"]):
        running_return += float(reward)
        running_length += 1
        if not (terminal or running_length == max_episode_steps):
            continue
        episode_returns.append(running_return)
        episode_lengths.append(running_length)
        running_return = 0.0
        running_length = 0
    episode_lengths.append(running_length)  # but still keep track of number of steps
    assert sum(episode_lengths) == len(dataset["rewards"])
    return min(episode_returns), max(episode_returns)


def modify_reward(dataset: Dict, env_name: str, max_episode_steps: int = 1000) -> Dict:
    """Rescale ``dataset["rewards"]`` in place according to the environment family.

    For halfcheetah/hopper/walker2d, rewards are divided by the spread of
    episode returns and multiplied by ``max_episode_steps``; the statistics
    used are returned. For antmaze, 1.0 is subtracted from every reward.
    Any other environment is left untouched.

    Returns:
        ``{"max_ret", "min_ret", "max_episode_steps"}`` for the locomotion
        tasks, otherwise an empty dict.
    """
    is_locomotion = any(
        tag in env_name for tag in ("halfcheetah", "hopper", "walker2d")
    )
    if is_locomotion:
        min_ret, max_ret = return_reward_range(dataset, max_episode_steps)
        dataset["rewards"] /= max_ret - min_ret
        dataset["rewards"] *= max_episode_steps
        stats_used = {
            "max_ret": max_ret,
            "min_ret": min_ret,
            "max_episode_steps": max_episode_steps,
        }
        return stats_used
    if "antmaze" in env_name:
        dataset["rewards"] -= 1.0
    return {}


def modify_reward_online(reward: float, env_name: str, **kwargs) -> float:
    """Apply to a single reward the same rescaling ``modify_reward`` applies to a dataset.

    For halfcheetah/hopper/walker2d the keyword arguments ``max_ret``,
    ``min_ret`` and ``max_episode_steps`` are required.
    """
    is_locomotion = any(
        tag in env_name for tag in ("halfcheetah", "hopper", "walker2d")
    )
    if is_locomotion:
        reward /= kwargs["max_ret"] - kwargs["min_ret"]
        reward *= kwargs["max_episode_steps"]
        return reward
    if "antmaze" in env_name:
        reward -= 1.0
    return reward


def asymmetric_l2_loss(u: torch.Tensor, tau: float) -> torch.Tensor:
    """Mean squared error with asymmetric weighting (expectile loss).

    Negative entries of ``u`` are weighted by ``|tau - 1|`` and non-negative
    entries by ``|tau|``.
    """
    is_negative = (u < 0).float()
    weight = torch.abs(tau - is_negative)
    return torch.mean(weight * u**2)


class Squeeze(nn.Module):
    """Module wrapper around ``squeeze`` so it can be used inside ``nn.Sequential``."""

    def __init__(self, dim=-1):
        super().__init__()
        # Dimension removed in forward(); it is only dropped when its size is 1.
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.squeeze(x, dim=self.dim)


class VAE(nn.Module):
    """
        Conditional Variational Auto-Encoder over transitions.

        The encoder maps (observation, action) to a diagonal Gaussian in latent
        space. The decoder maps (observation, action, latent) to a vector of
        size ``obs_dim``.

        Args:
            obs_dim (int): The dimension of the observation space.
            act_dim (int): The dimension of the action space.
            hidden_size (int): The number of hidden units in the encoder and decoder networks.
            latent_dim (int): The dimensionality of the latent space.
            act_lim (float): The upper limit of the action space (stored, not used inside this class).
            obs_lim: Observation bound (stored, not used inside this class).
            device (str): The device to use for computation (cpu or cuda), default='cuda'.
        """

    def __init__(self, obs_dim, act_dim, hidden_size, latent_dim, act_lim, obs_lim, device='cuda'):
        super(VAE, self).__init__()
        self.device = device
        self.latent_dim = latent_dim
        # Encoder trunk: (obs, act) -> hidden.
        self.e1 = nn.Linear(obs_dim + act_dim, hidden_size).to(self.device)
        self.e2 = nn.Linear(hidden_size, hidden_size).to(self.device)

        # Heads for the latent Gaussian.
        self.mean = nn.Linear(hidden_size, latent_dim).to(self.device)
        self.log_std = nn.Linear(hidden_size, latent_dim).to(self.device)

        # Decoder: (obs, act, z) -> obs_dim.
        self.d1 = nn.Linear(obs_dim + act_dim + latent_dim, hidden_size).to(self.device)
        self.d2 = nn.Linear(hidden_size, hidden_size).to(self.device)
        self.d3 = nn.Linear(hidden_size, obs_dim).to(self.device)

        self.act_lim = act_lim
        self.obs_lim = obs_lim

    def forward(self, obs, act, next_obs):
        """Encode, sample a latent, decode. Returns ``(reconstruction, mean, std)``."""
        z, mean, std = self.encoder(obs, act, next_obs)
        u = self.decode(obs, act, next_obs, z)
        return u, mean, std

    def encoder(self, obs, act, next_obs):
        """Return ``(z, mean, std)`` of the latent Gaussian for ``(obs, act)``.

        ``next_obs`` is accepted for signature symmetry but not used.
        """
        z = F.relu(self.e1(torch.cat([obs, act], 1)))
        z = F.relu(self.e2(z))

        mean = self.mean(z)
        # clamp for numerical stability
        log_std = self.log_std(z).clamp(-4, 15)
        std = torch.exp(log_std)
        # Reparameterised sample.
        z = mean + std * torch.randn_like(std)

        return z, mean, std

    def decode(self, obs, act, next_obs, z=None):
        """Decode ``(obs, act, z)``; when ``z`` is None a clipped standard-normal latent is drawn.

        ``next_obs`` is accepted for signature symmetry but not used.
        """
        if z is None:
            z = torch.randn((obs.shape[0], self.latent_dim)).clamp(-0.5, 0.5).to(self.device)
        s = F.relu(self.d1(torch.cat([obs, act, z], 1)))
        s = F.relu(self.d2(s))

        return self.d3(s)

    # for BEARL only
    def decode_multiple(self, obs, z=None, num_decode=10):
        """Decode ``num_decode`` latents per observation.

        Returns ``(tanh(raw), raw)`` where ``raw`` is the decoder output.

        NOTE(review): the input built here is ``(obs, z)`` only, while ``d1``
        is sized for ``obs_dim + act_dim + latent_dim`` features, so this
        only runs when ``act_dim == 0``.
        """
        if z is None:
            z = torch.randn(
                (obs.shape[0], num_decode, self.latent_dim)).clamp(-0.5,
                                                                   0.5).to(self.device)

        a = F.relu(
            self.d1(
                torch.cat(
                    [obs.unsqueeze(0).repeat(num_decode, 1, 1).permute(1, 0, 2), z], 2)))
        a = F.relu(self.d2(a))
        # Run the output layer once and reuse the result for both return values.
        raw = self.d3(a)
        return torch.tanh(raw), raw

    def test_one_step(self, observations, next_observations, actions, rewards, done):
        """Return the encoder's ``(mean, std)`` without tracking gradients."""
        with torch.no_grad():
            _, mean, std = self.encoder(observations, actions, next_observations)
            return mean, std

    def judge_normal_distribution(self, replay_buffer):
        """Fit normal distributions to the encoder's latent means and stds.

        Each transition ``replay_buffer[i]`` for ``i < len(replay_buffer)`` is
        encoded on its own; all latent mean entries and all latent std entries
        are pooled and a normal distribution is fitted to each pool with the
        third-party ``fitter`` package.

        Returns:
            ``(mean_params, std_params)``: the fitted 'norm' parameters for
            the pooled means and the pooled stds.
        """
        data_mean_list = []
        data_std_list = []
        for i in range(len(replay_buffer)):
            batch = replay_buffer[i]
            (
                observations,
                actions,
                rewards,
                next_observations,
                dones,
            ) = batch
            # A single row is returned; add the batch axis the encoder expects.
            observations = observations.unsqueeze(0)
            actions = actions.unsqueeze(0)
            next_observations = next_observations.unsqueeze(0)
            mean, std = self.test_one_step(observations, next_observations, actions, rewards, dones)
            data_mean_list.extend(mean.cpu().squeeze())
            data_std_list.extend(std.cpu().squeeze())
        import fitter
        mean_f = fitter.Fitter(data_mean_list, distributions=['norm'], timeout=10000)
        mean_f.fit()
        mean_dict = mean_f.fitted_param['norm']

        std_f = fitter.Fitter(data_std_list, distributions=['norm'], timeout=10000)
        std_f.fit()
        std_dict = std_f.fitted_param['norm']
        print(mean_dict, std_dict)

        return mean_dict, std_dict


class MLP(nn.Module):
    """Fully connected network built from a list of layer widths.

    Args:
        dims: Layer widths, input first and output last (at least two).
        activation_fn: Factory for the activation after each hidden layer.
        output_activation_fn: Optional factory for an activation after the
            final layer.
        squeeze_output: When True the last width must be 1 and the trailing
            axis is squeezed away.
        dropout: When greater than 0, a ``Dropout`` module is placed after
            every hidden activation. The module is created with ``p=0.0``,
            so the requested rate itself is not applied.
            NOTE(review): presumably this keeps layer indices aligned with
            checkpoints trained with dropout -- confirm before changing.

    Raises:
        ValueError: If fewer than two widths are given, or ``squeeze_output``
            is set while the last width is not 1.
    """

    def __init__(
        self,
        dims,
        activation_fn: Callable[[], nn.Module] = nn.ReLU,
        output_activation_fn: Callable[[], nn.Module] = None,
        squeeze_output: bool = False,
        dropout: float = 0.0,
    ):
        super().__init__()
        if len(dims) < 2:
            raise ValueError("MLP requires at least two dims (input and output)")

        modules = []
        # Hidden layers: every consecutive pair of widths except the last one.
        for fan_in, fan_out in zip(dims[:-2], dims[1:-1]):
            modules.append(nn.Linear(fan_in, fan_out))
            modules.append(activation_fn())
            if dropout > 0.0:
                modules.append(nn.Dropout(0.0))
        modules.append(nn.Linear(dims[-2], dims[-1]))
        if output_activation_fn is not None:
            modules.append(output_activation_fn())
        if squeeze_output:
            if dims[-1] != 1:
                raise ValueError("Last dim must be 1 when squeezing")
            modules.append(Squeeze(-1))
        self.net = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class GaussianPolicy(nn.Module):
    """Stochastic policy: tanh-bounded mean from an MLP plus a state-independent log-std.

    Args:
        state_dim: Width of the state input.
        act_dim: Width of the action output.
        max_action: Scale applied to actions in ``act``.
        hidden_dim: Width of each hidden layer.
        n_hidden: Number of hidden layers.
        dropout: Forwarded to the MLP.
    """

    def __init__(
        self,
        state_dim: int,
        act_dim: int,
        max_action: float,
        hidden_dim: int = 256,
        n_hidden: int = 2,
        dropout: float = 0.0,
    ):
        super().__init__()
        layer_dims = [state_dim, *([hidden_dim] * n_hidden), act_dim]
        self.net = MLP(layer_dims, output_activation_fn=nn.Tanh, dropout=dropout)
        self.log_std = nn.Parameter(torch.zeros(act_dim, dtype=torch.float32))
        self.max_action = max_action

    def forward(self, obs: torch.Tensor) -> Normal:
        """Return the action distribution for ``obs``."""
        clipped_log_std = self.log_std.clamp(LOG_STD_MIN, LOG_STD_MAX)
        return Normal(self.net(obs), torch.exp(clipped_log_std))

    @torch.no_grad()
    def act(self, state: np.ndarray, device: str = "cpu"):
        """Return a flat NumPy action for one state.

        Samples from the distribution in training mode and uses its mean in
        eval mode; the result is scaled by ``max_action`` and clamped to
        ``[-max_action, max_action]``.
        """
        obs = torch.tensor(state.reshape(1, -1), device=device, dtype=torch.float32)
        dist = self(obs)
        if self.training:
            raw_action = dist.sample()
        else:
            raw_action = dist.mean
        bounded = torch.clamp(
            self.max_action * raw_action, -self.max_action, self.max_action
        )
        return bounded.cpu().data.numpy().flatten()


class DeterministicPolicy(nn.Module):
    """Deterministic policy: an MLP with a tanh output.

    Args:
        state_dim: Width of the state input.
        act_dim: Width of the action output.
        max_action: Scale applied to actions in ``act``.
        hidden_dim: Width of each hidden layer.
        n_hidden: Number of hidden layers.
        dropout: Accepted but not forwarded; the MLP is always built with the
            literal ``dropout=1``.
            NOTE(review): with this file's MLP that inserts ``p=0.0`` Dropout
            modules after every hidden activation, presumably to keep layer
            indices aligned with existing checkpoints -- confirm before
            changing.
    """

    def __init__(
        self,
        state_dim: int,
        act_dim: int,
        max_action: float,
        hidden_dim: int = 256,
        n_hidden: int = 2,
        dropout: float = 0.0,
    ):
        super().__init__()
        layer_dims = [state_dim, *([hidden_dim] * n_hidden), act_dim]
        self.net = MLP(layer_dims, output_activation_fn=nn.Tanh, dropout=1)
        self.max_action = max_action

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.net(obs)

    @torch.no_grad()
    def act(self, state: np.ndarray, device: str = "cpu"):
        """Return a flat NumPy action for one state, scaled and clamped to ``max_action``."""
        obs = torch.tensor(state.reshape(1, -1), device=device, dtype=torch.float32)
        scaled = self(obs) * self.max_action
        bounded = torch.clamp(scaled, -self.max_action, self.max_action)
        return bounded.cpu().data.numpy().flatten()


class TwinQ(nn.Module):
    """Two independent Q-networks over the concatenated (state, action) input."""

    def __init__(
        self, state_dim: int, action_dim: int, hidden_dim: int = 256, n_hidden: int = 2
    ):
        super().__init__()
        layer_dims = [state_dim + action_dim, *([hidden_dim] * n_hidden), 1]
        self.q1 = MLP(layer_dims, squeeze_output=True)
        self.q2 = MLP(layer_dims, squeeze_output=True)

    def both(
        self, state: torch.Tensor, action: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the estimates of both Q-networks."""
        joint = torch.cat([state, action], 1)
        first = self.q1(joint)
        second = self.q2(joint)
        return first, second

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Return the element-wise minimum of the two estimates."""
        first, second = self.both(state, action)
        return torch.min(first, second)


class ValueFunction(nn.Module):
    """State-value network: an MLP with a single squeezed output."""

    def __init__(self, state_dim: int, hidden_dim: int = 256, n_hidden: int = 2):
        super().__init__()
        layer_dims = [state_dim, *([hidden_dim] * n_hidden), 1]
        self.v = MLP(layer_dims, squeeze_output=True)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        value = self.v(state)
        return value


class ImplicitQLearning:
    def __init__(
        self,
        max_action: float,
        actor: nn.Module,
        actor_optimizer: torch.optim.Optimizer,
        q_network: nn.Module,
        q_optimizer: torch.optim.Optimizer,
        v_network: nn.Module,
        v_optimizer: torch.optim.Optimizer,
        state_dim,
        action_dim,
        vae_hidden_dim,
        vae_lr,
        use_off_policy=False,
        iql_tau: float = 0.7,
        beta: float = 3.0,
        max_steps: int = 1000000,
        discount: float = 0.99,
        tau: float = 0.005,
        device: str = "cpu",
    ):
        """Set up the IQL trainer, frozen offline copies of its networks, and a VAE.

        Args:
            max_action: Action magnitude bound.
            actor, q_network, v_network: The trainable networks.
            actor_optimizer, q_optimizer, v_optimizer: Their optimizers.
            state_dim, action_dim: Sizes used to build the VAE; the latent
                size is their sum.
            vae_hidden_dim: Hidden width of the VAE.
            vae_lr: Learning rate of the VAE's Adam optimizer.
            use_off_policy: Stored on the instance.
            iql_tau: Expectile used in the value loss.
            beta: Inverse temperature of the advantage weights.
            max_steps: Period of the actor's cosine learning-rate schedule.
            discount: Discount factor.
            tau: Soft-update rate of the target Q-network.
            device: Device for the frozen copies and the VAE.
        """
        self.max_action = max_action
        self.qf = q_network
        # Frozen snapshots of the networks as they are at construction time.
        self.offline_qf = copy.deepcopy(self.qf).requires_grad_(False).to(device)
        self.q_target = copy.deepcopy(self.qf).requires_grad_(False).to(device)
        self.vf = v_network
        self.offline_vf = copy.deepcopy(self.vf).requires_grad_(False).to(device)

        self.actor = actor
        self.offline_actor = copy.deepcopy(self.actor).requires_grad_(False).to(device)

        self.v_optimizer = v_optimizer
        self.q_optimizer = q_optimizer
        self.actor_optimizer = actor_optimizer

        self.actor_lr_schedule = CosineAnnealingLR(self.actor_optimizer, max_steps)
        self.iql_tau = iql_tau
        self.beta = beta
        self.discount = discount
        self.tau = tau
        self.total_it = 0
        self.device = device

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.vae_hidden_dim = vae_hidden_dim
        self.latent_dim = state_dim + action_dim
        self.t = 0
        self.prob = torch.zeros(256)
        self.use_off_policy = use_off_policy
        self.pm = 0.7
        # Pass the device by keyword. Positionally it would land in the VAE's
        # ``obs_lim`` slot, and the VAE would then build its layers and draw
        # its latents on the default 'cuda' device. No observation bound is
        # known here, so ``obs_lim`` is None.
        self.vae = VAE(
            self.state_dim,
            self.action_dim,
            self.vae_hidden_dim,
            self.latent_dim,
            self.max_action,
            None,
            device=self.device,
        ).to(self.device)
        self.vae_optim = torch.optim.Adam(self.vae.parameters(), lr=vae_lr)
        # Read in _update_q as prob_mean[0] / prob_mean[1]; must be filled in
        # before training starts.
        self.prob_mean, self.prob_std = (), ()
        self.offline_time = 10000
        self.count = 0
        self.vae_batch_size = 64
        self.vae_step = 0

    def _update_v(self, observations, actions, log_dict) -> torch.Tensor:
        """Take one optimizer step on the value network.

        The loss is the asymmetric L2 loss of ``q_target(s, a) - V(s)`` with
        expectile ``iql_tau``; it is logged under ``"value_loss"``.

        Returns:
            The advantage ``q_target(s, a) - V(s)`` computed before the step.
        """
        with torch.no_grad():
            q_estimate = self.q_target(observations, actions)

        advantage = q_estimate - self.vf(observations)
        value_loss = asymmetric_l2_loss(advantage, self.iql_tau)
        log_dict["value_loss"] = value_loss.item()

        self.v_optimizer.zero_grad()
        value_loss.backward()
        self.v_optimizer.step()
        return advantage

    def _update_off_v(self, observations, actions, log_dict) -> torch.Tensor:
        # Update value function
        target_q = self.offline_qf(observations, actions)

        v = self.offline_vf(observations)
        adv = target_q - v
        return adv

    def _update_q(
        self,
        next_v: torch.Tensor,
        offline_next_v: torch.Tensor,
        observations: torch.Tensor,
        next_observations: torch.Tensor,
        actions: torch.Tensor,
        rewards: torch.Tensor,
        terminals: torch.Tensor,
        log_dict: Dict,
    ):
        """Take one optimizer step on the twin Q-networks and soft-update the target.

        Each sample's TD error against the online target is blended with its
        error against a target built from the frozen offline Q-networks. The
        blend weight ``prob`` comes from where the VAE's latent mean falls
        under the normal distribution described by ``self.prob_mean``; values
        below ``self.pm`` are set to zero.

        Args:
            next_v: V(next_observations) from the online value network.
            offline_next_v: Accepted but not used here.
            observations, next_observations, actions: Batch tensors.
            rewards, terminals: Per-sample tensors of shape ``(batch,)``.
            log_dict: Receives diagnostics and the scalar ``"q_loss"``.
        """
        targets = rewards + (1.0 - terminals.float()) * self.discount * next_v.detach()
        qs = self.qf.both(observations, actions)
        with torch.no_grad():
            log_dict['q0'] = torch.sum(qs[0])
            log_dict['q1'] = torch.sum(qs[1])
            log_dict['online_target'] = targets
        q_error0 = qs[0] - targets
        q_error1 = qs[1] - targets

        with torch.no_grad():
            # NOTE(review): these "offline" targets use offline Q(s, a), not a
            # value of the next state, and the resulting errors carry no
            # gradient because they are computed under no_grad -- confirm
            # this is intended.
            offline_next_q = self.offline_qf.both(observations, actions)
            off_targets1 = rewards + (1.0 - terminals.float()) * self.discount * offline_next_q[0].detach()
            off_targets2 = rewards + (1.0 - terminals.float()) * self.discount * offline_next_q[1].detach()
            off_q_error0 = qs[0] - off_targets1
            off_q_error1 = qs[1] - off_targets2
            log_dict['off_q_error_0'] = off_q_error0.sum()
            log_dict['off_q_error_1'] = off_q_error1.sum()

            z, mean, std = self.vae.encoder(observations, actions, next_observations)
            # Two-sided tail probability of each latent-mean entry under
            # N(prob_mean[0], prob_mean[1]).
            prob = stats.norm.cdf((self.prob_mean[0] - abs(mean.cpu() - self.prob_mean[0])), self.prob_mean[0],
                                  self.prob_mean[1]) * 2
            prob = torch.tensor(prob, dtype=torch.float32).to(self.device)
            prob = torch.mean(prob, dim=1).unsqueeze(1)
            zeros = torch.zeros_like(prob)
            prob = torch.where(prob < self.pm, zeros, prob)
            log_dict['prob'] = prob.mean()
            self.prob = prob
        # ``prob`` is (batch, 1) while the TD errors are (batch,). Multiplying
        # them directly would broadcast to (batch, batch) and pair every
        # sample's error with every other sample's weight. Flatten the weight
        # so the blend stays per-sample.
        mix = prob.squeeze(1)
        q_loss0 = torch.square((1 - mix) * q_error0 + mix * off_q_error0)
        q_loss1 = torch.square((1 - mix) * q_error1 + mix * off_q_error1)
        q_loss = (q_loss0 + q_loss1) / 2
        q_loss = torch.mean(q_loss)
        log_dict["q_loss"] = q_loss.item()
        self.q_optimizer.zero_grad()
        q_loss.backward()
        self.q_optimizer.step()
        # Update target Q network
        soft_update(self.q_target, self.qf, self.tau)

    def _update_policy(
        self,
        adv: torch.Tensor,
        observations: torch.Tensor,
        actions: torch.Tensor,
        log_dict: Dict,
    ):
        """Take one advantage-weighted behaviour-cloning step on the actor.

        Per-sample weights are ``exp(beta * adv)`` clamped at ``EXP_ADV_MAX``.
        The per-sample loss is the negative log-likelihood of the dataset
        action when the actor returns a distribution, or the squared error
        when it returns a tensor. The loss is logged under ``"actor_loss"``
        and the learning-rate schedule advances by one step.

        Raises:
            RuntimeError: If a tensor-valued actor output does not match the
                shape of ``actions``.
            NotImplementedError: If the actor returns neither a distribution
                nor a tensor.
        """
        sample_weights = torch.exp(self.beta * adv.detach()).clamp(max=EXP_ADV_MAX)
        policy_out = self.actor(observations)

        if isinstance(policy_out, torch.distributions.Distribution):
            log_likelihood = policy_out.log_prob(actions).sum(-1, keepdim=False)
            bc_losses = -log_likelihood
        elif torch.is_tensor(policy_out):
            if policy_out.shape != actions.shape:
                raise RuntimeError("Actions shape missmatch")
            squared_error = (policy_out - actions) ** 2
            bc_losses = torch.sum(squared_error, dim=1)
        else:
            raise NotImplementedError

        policy_loss = torch.mean(sample_weights * bc_losses)
        log_dict["actor_loss"] = policy_loss.item()

        self.actor_optimizer.zero_grad()
        policy_loss.backward()
        self.actor_optimizer.step()
        self.actor_lr_schedule.step()

    def train(self, batch: TensorBatch) -> Dict[str, float]:
        """Run one full IQL update (value, Q, policy) on ``batch``.

        Args:
            batch: ``[observations, actions, rewards, next_observations, dones]``.

        Returns:
            The dictionary of diagnostics filled in by the three sub-updates.
        """
        self.total_it += 1
        observations, actions, rewards, next_observations, dones = batch
        log_dict = {}
        self.t += 1
        wandb.log({'time': self.t}, commit=False)

        # Next-state values are taken before the value network is updated.
        with torch.no_grad():
            next_v = self.vf(next_observations)
            offline_next_v = self.offline_vf(next_observations)

        # Update value function
        adv = self._update_v(observations, actions, log_dict)

        # Drop the trailing axis so rewards/dones line up with the squeezed critics.
        rewards = rewards.squeeze(dim=-1)
        dones = dones.squeeze(dim=-1)

        # Update Q function
        self._update_q(
            next_v,
            offline_next_v,
            observations,
            next_observations,
            actions,
            rewards,
            dones,
            log_dict,
        )
        # Update policy
        self._update_policy(adv, observations, actions, log_dict)
        return log_dict

    def vae_loss(self, obs, act, next_observations, t):
        """Compute the VAE loss and apply one optimizer step to the VAE.

        Args:
            obs: Batch of observations passed to ``self.vae``.
            act: Batch of actions passed to ``self.vae``.
            next_observations: Batch of next observations; also the
                reconstruction target.
            t: Not used by this method.

        Returns:
            A ``(loss_vae, stats_vae)`` pair: the total loss tensor and a dict
            of Python floats for the total, KL and reconstruction losses.
        """
        reconstruction, latent_mean, latent_std = self.vae(obs, act, next_observations)

        recon_loss = nn.functional.mse_loss(reconstruction, next_observations)
        # KL term written in terms of the variance std**2, averaged over all
        # elements.
        variance = latent_std.pow(2)
        KL_loss = -0.5 * torch.mean(1 + torch.log(variance) - latent_mean.pow(2) - variance)

        # The KL term is weighted 20x relative to the reconstruction term.
        loss_vae = recon_loss + 20 * KL_loss

        self.vae_optim.zero_grad()
        loss_vae.backward()
        self.vae_optim.step()

        stats_vae = {
            "loss/loss_vae": loss_vae.item(),
            "loss/KL_loss": KL_loss.item(),
            "loss/recon_loss": recon_loss.item(),
        }
        print(stats_vae)
        return loss_vae, stats_vae

    def train_one_step(self, observations, next_observations, actions, rewards, done, t):
        """Perform a single VAE update and return its statistics.

        ``rewards`` and ``done`` are accepted but not used; only the
        observations, actions, next observations and ``t`` are forwarded to
        ``self.vae_loss``.

        Returns:
            The statistics dict produced by ``self.vae_loss``.
        """
        # The loss tensor itself is not needed here, only the logged stats.
        _, vae_stats = self.vae_loss(observations, actions, next_observations, t)
        return vae_stats

    def update_vae(self, buffer):
        """Refresh the offline critic snapshots and train the VAE on buffer data.

        The samples returned by ``buffer.filter_resolved_samples()`` are moved
        to ``self.device``. ``self.offline_vf`` and ``self.offline_qf`` are
        replaced by deep copies of the current ``self.vf`` / ``self.qf``. If at
        least ``self.vae_batch_size`` samples are available, the VAE is trained
        for ``num_samples // self.vae_batch_size`` steps on random mini-batches
        drawn with replacement; otherwise a message is printed and the VAE is
        left untouched.

        Args:
            buffer: Object exposing ``filter_resolved_samples()``, which must
                return five tensors in the order (observations, actions,
                rewards, next_observations, dones).
        """
        # Use the trainer's own device rather than a module-level ``config``
        # global, which only exists when this file is run as a script.
        observations, actions, rewards, next_observations, dones = [
            tensor.to(self.device) for tensor in buffer.filter_resolved_samples()
        ]

        # A single un-batched sample arrives as 1-D tensors: add a batch axis.
        if observations.dim() == 1:
            observations = observations.unsqueeze(0)
            actions = actions.unsqueeze(0)
            rewards = rewards.unsqueeze(0)
            next_observations = next_observations.unsqueeze(0)
            dones = dones.unsqueeze(0)

        # Snapshot the current critics as the new "offline" references.
        self.offline_vf = copy.deepcopy(self.vf).to(self.device)
        self.offline_qf = copy.deepcopy(self.qf).to(self.device)

        print('observation shape', observations.shape)
        print('action shape', actions.shape)
        print('next observation shape', next_observations.shape)

        num_samples = observations.shape[0]
        if num_samples < self.vae_batch_size:
            print("No extra OOD samples detected.")
            return

        for _ in range(num_samples // self.vae_batch_size):
            # Mini-batch indices are sampled uniformly with replacement.
            indices = np.random.randint(0, num_samples, size=self.vae_batch_size)
            # The tensors already live on self.device, so no further transfer
            # is needed here.
            log_dict = self.train_one_step(
                observations[indices],
                next_observations[indices],
                actions[indices],
                rewards[indices],
                dones[indices],
                0,
            )
            log_dict['vae_step'] = self.vae_step
            wandb.log(log_dict, commit=False)
            self.vae_step += 1

    def update_vae_parameters(self, observations, actions, next_observations):
        """Fit normal distributions to the VAE's per-transition mean and std.

        Each transition is passed on its own (with a leading batch axis) to
        ``self.vae.test_one_step``; the returned mean and std values are
        pooled over all transitions and each pool is fitted with ``fitter``
        using only the ``'norm'`` distribution.

        Args:
            observations: Indexable batch of observations.
            actions: Indexable batch of actions.
            next_observations: Indexable batch of next observations.

        Returns:
            ``(mean_dict, std_dict)``: whatever ``fitter`` reports as the
            fitted ``'norm'`` parameters for the pooled means and stds.
        """
        pooled_means = []
        pooled_stds = []
        for idx in range(len(observations)):
            obs_i, act_i, next_obs_i = (
                data[idx].unsqueeze(0)
                for data in (observations, actions, next_observations)
            )
            # NOTE(review): the trailing (1, 0) arguments are passed through
            # unchanged; their meaning is defined by the VAE implementation.
            mean, std = self.vae.test_one_step(obs_i, next_obs_i, act_i, 1, 0)
            pooled_means.extend(mean.cpu().squeeze())
            pooled_stds.extend(std.cpu().squeeze())

        import fitter

        def fit_normal(samples):
            # Fit a single normal distribution and return its parameters.
            fitted = fitter.Fitter(samples, distributions=['norm'], timeout=10000)
            fitted.fit()
            return fitted.fitted_param['norm']

        mean_dict = fit_normal(pooled_means)
        std_dict = fit_normal(pooled_stds)
        print(mean_dict, std_dict)

        return mean_dict, std_dict

    def state_dict(self) -> Dict[str, Any]:
        """Collect the checkpointable state of the trainer.

        Returns:
            A dict holding the ``state_dict()`` of the Q-network, value
            network and actor, of their optimizers and of the actor
            learning-rate schedule, plus the iteration counter ``total_it``.
        """
        # Each key doubles as the name of the attribute it is read from.
        component_names = (
            "qf",
            "q_optimizer",
            "vf",
            "v_optimizer",
            "actor",
            "actor_optimizer",
            "actor_lr_schedule",
        )
        snapshot: Dict[str, Any] = {
            name: getattr(self, name).state_dict() for name in component_names
        }
        snapshot["total_it"] = self.total_it
        return snapshot

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Restore the trainer from a dict with the layout of ``state_dict()``.

        After the Q-network, value network and actor are restored, deep
        copies of them are stored as ``q_target`` / ``offline_qf``,
        ``offline_vf`` and ``offline_actor`` respectively.

        Args:
            state_dict: Mapping with the keys ``"qf"``, ``"q_optimizer"``,
                ``"vf"``, ``"v_optimizer"``, ``"actor"``,
                ``"actor_optimizer"``, ``"actor_lr_schedule"`` and
                ``"total_it"``.
        """

        def restore(name):
            # Each key doubles as the name of the attribute it is loaded into.
            getattr(self, name).load_state_dict(state_dict[name])

        # Q-function, its target network and the frozen offline copy.
        restore("qf")
        restore("q_optimizer")
        self.q_target = copy.deepcopy(self.qf)
        self.offline_qf = copy.deepcopy(self.qf)

        # Value function and its frozen offline copy.
        restore("vf")
        self.offline_vf = copy.deepcopy(self.vf)
        restore("v_optimizer")

        # Actor, its frozen offline copy, optimizer and LR schedule.
        restore("actor")
        self.offline_actor = copy.deepcopy(self.actor)
        restore("actor_optimizer")
        restore("actor_lr_schedule")

        self.total_it = state_dict["total_it"]

    @torch.no_grad()
    def get_weights(self,
                    observations,
                    actions,
                    rewards,
                    next_observations,
                    dones,
                    ):
        """Compute three weighting statistics for the given transition(s).

        Returns:
            A tensor built from ``[prob, off_values, online_values]`` where
            ``prob`` is the VAE-based probability (zeroed below ``self.pm``),
            ``off_values`` is the summed absolute Q-error against targets from
            ``self.offline_vf`` and ``online_values`` is the same quantity
            against targets from ``self.vf``.
        """
        with torch.no_grad():
            _, latent_mean, _ = self.vae.encoder(observations, actions, next_observations)

            # Two-sided tail probability of the latent mean under the fitted
            # normal: reflect each value to the lower side of the centre,
            # take the CDF there and double it.
            centre, spread = self.prob_mean[0], self.prob_mean[1]
            reflected = centre - abs(latent_mean.cpu() - centre)
            prob = torch.tensor(
                stats.norm.cdf(reflected, centre, spread) * 2, dtype=torch.float32
            ).to(self.device)
            prob = prob.mean(dim=1).unsqueeze(1)
            # Probabilities below the threshold self.pm are zeroed out.
            prob = torch.where(prob < self.pm, torch.zeros_like(prob), prob)

            q_pair = self.qf.both(observations, actions)

            def summed_abs_error(targets):
                # |sum of errors| for each of the two Q-heads, added together.
                first_head = torch.abs((q_pair[0] - targets).sum())
                second_head = torch.abs((q_pair[1] - targets).sum())
                return torch.mean(first_head + second_head)

            not_done = 1.0 - dones.float()

            # Error against targets bootstrapped from the offline value network.
            offline_next_v = self.offline_vf(next_observations)
            off_targets = rewards + not_done * self.discount * offline_next_v.detach()
            off_values = summed_abs_error(off_targets)

            # Error against targets bootstrapped from the current value network.
            online_next_v = self.vf(next_observations)
            online_targets = rewards + not_done * self.discount * online_next_v.detach()
            online_values = summed_abs_error(online_targets)
        return torch.tensor([prob, off_values, online_values])


def train(config):
    """Fine-tune a pre-trained IQL agent online on a D4RL environment.

    The function builds the training and evaluation environments, loads the
    D4RL dataset into an offline buffer, constructs the networks and the
    ``ImplicitQLearning`` trainer, optionally restores the policy and VAE
    checkpoints, and then runs ``config.online_iterations`` steps. Each step
    collects one environment transition, stores it in the replay buffer
    together with the weights from ``trainer.get_weights``, and performs one
    ``trainer.train`` update. Every ``config.eval_freq`` steps the actor is
    evaluated, and late in training checkpoints are written.

    Args:
        config: Configuration object. The attributes read here are ``env``,
            ``normalize_reward``, ``normalize``, ``buffer_size``, ``device``,
            ``checkpoints_path``, ``seed``, ``eval_seed``, ``actor_dropout``,
            ``iql_deterministic``, ``vf_lr``, ``qf_lr``, ``actor_lr``,
            ``vae_hidden_dim``, ``discount``, ``tau``, ``vae_lr``, ``beta``,
            ``iql_tau``, ``offline_iterations``, ``use_off_policy``,
            ``load_model``, ``vae_file``, ``online_iterations``,
            ``expl_noise``, ``noise_clip``, ``batch_size``, ``eval_freq``,
            ``n_episodes`` and ``save_freq``.
    """
    env = gym.make(config.env)
    eval_env = gym.make(config.env)

    is_env_with_goal = config.env.startswith(ENVS_WITH_GOAL)

    max_steps = env._max_episode_steps

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    dataset = d4rl.qlearning_dataset(env)

    reward_mod_dict = {}
    if config.normalize_reward:
        reward_mod_dict = modify_reward(dataset, config.env)

    if config.normalize:
        state_mean, state_std = compute_mean_std(dataset["observations"], eps=1e-3)
    else:
        state_mean, state_std = 0, 1

    # The same statistics normalize the dataset and both environments.
    dataset["observations"] = normalize_states(
        dataset["observations"], state_mean, state_std
    )
    dataset["next_observations"] = normalize_states(
        dataset["next_observations"], state_mean, state_std
    )
    env = wrap_env(env, state_mean=state_mean, state_std=state_std)
    eval_env = wrap_env(eval_env, state_mean=state_mean, state_std=state_std)
    # Buffer for transitions collected online.
    replay_buffer = ReplayBuffer(
        state_dim,
        action_dim,
        config.buffer_size,
        config.device,
        True
    )

    # Buffer holding the D4RL dataset.
    offline_buffer = ReplayBuffer(
        state_dim,
        action_dim,
        2_000_000,
        config.device,
    )
    offline_buffer.load_d4rl_dataset(dataset)

    max_action = float(env.action_space.high[0])

    if config.checkpoints_path is not None:
        print(f"Checkpoints path: {config.checkpoints_path}")
        os.makedirs(config.checkpoints_path, exist_ok=True)

    # Set seeds
    seed = config.seed
    set_seed(seed, env)
    set_env_seed(eval_env, config.eval_seed)

    q_network = TwinQ(state_dim, action_dim).to(config.device)
    v_network = ValueFunction(state_dim).to(config.device)
    actor = (
        DeterministicPolicy(
            state_dim, action_dim, max_action, dropout=config.actor_dropout
        )
        if config.iql_deterministic
        else GaussianPolicy(
            state_dim, action_dim, max_action, dropout=config.actor_dropout
        )
    ).to(config.device)
    v_optimizer = torch.optim.Adam(v_network.parameters(), lr=config.vf_lr)
    q_optimizer = torch.optim.Adam(q_network.parameters(), lr=config.qf_lr)
    actor_optimizer = torch.optim.Adam(actor.parameters(), lr=config.actor_lr)

    kwargs = {
        "max_action": max_action,
        "actor": actor,
        "actor_optimizer": actor_optimizer,
        "q_network": q_network,
        "q_optimizer": q_optimizer,
        "v_network": v_network,
        "v_optimizer": v_optimizer,
        "state_dim": state_dim,
        "action_dim": action_dim,
        "vae_hidden_dim": config.vae_hidden_dim,
        "discount": config.discount,
        "tau": config.tau,
        "vae_lr": config.vae_lr,
        "device": config.device,
        # IQL
        "beta": config.beta,
        "iql_tau": config.iql_tau,
        "max_steps": config.offline_iterations,
        "use_off_policy": config.use_off_policy
    }

    print("---------------------------------------")
    print(f"Training IQL, Env: {config.env}, Seed: {seed}")
    print("---------------------------------------")

    # Initialize actor
    trainer = ImplicitQLearning(**kwargs)

    if config.load_model != "":
        policy_file = Path(config.load_model)
        trainer.load_state_dict(torch.load(policy_file))
        # Deserialize the VAE checkpoint exactly once. Loading it twice reads
        # the file twice and, assuming the checkpoint stores the pickled
        # module and optimizer objects, would leave the optimizer bound to the
        # parameters of a second, discarded VAE copy instead of trainer.vae.
        vae_checkpoint = torch.load(config.vae_file)
        trainer.vae = vae_checkpoint['vae']
        trainer.vae_optim = vae_checkpoint['vae_optim']
        actor = trainer.actor

    trainer.prob_mean, trainer.prob_std = trainer.vae.judge_normal_distribution(offline_buffer)
    print('vae information', trainer.prob_mean, trainer.prob_std, '\n\n')

    replay_buffer.initialize(env, trainer, config, max_action, max_steps)
    print(replay_buffer._pointer)
    evaluations = []

    state, done = env.reset(), False
    episode_return = 0
    episode_step = 0
    goal_achieved = False

    eval_successes = []
    train_successes = []

    print("Online incremental tuning")
    for t in range(int(config.online_iterations)):
        online_log = {}
        episode_step += 1
        action = actor(
            torch.tensor(
                state.reshape(1, -1), device=config.device, dtype=torch.float32
            )
        )
        if not config.iql_deterministic:
            # Gaussian policy: the actor returns a distribution to sample from.
            action = action.sample()
        else:
            # Deterministic policy: add clipped Gaussian exploration noise.
            noise = (torch.randn_like(action) * config.expl_noise).clamp(
                -config.noise_clip, config.noise_clip
            )
            action += noise
        action = torch.clamp(max_action * action, -max_action, max_action)
        action = action.cpu().data.numpy().flatten()
        next_state, reward, done, env_infos = env.step(action)

        if not goal_achieved:
            goal_achieved = is_goal_reached(reward, env_infos)
        episode_return += reward

        real_done = False  # Episode can timeout which is different from done
        if done and episode_step < max_steps:
            real_done = True

        if config.normalize_reward:
            reward = modify_reward_online(reward, config.env, **reward_mod_dict)
        # Per-transition weights stored alongside the transition.
        weights = trainer.get_weights(to_tensor(state, config.device).unsqueeze(0),
                                      to_tensor(action, config.device).unsqueeze(0),
                                      reward, to_tensor(next_state, config.device).unsqueeze(0),
                                      to_tensor(np.array(real_done), config.device).unsqueeze(0))
        replay_buffer.add_transition(state, action, reward, next_state, real_done, weights)
        state = next_state

        # Every trainer.offline_time updates the period counter restarts; from
        # the fifth period onwards (count >= 4) the VAE and the offline critic
        # snapshots are refreshed from the replay buffer first.
        if trainer.t == trainer.offline_time:
            if trainer.count >= 4:
                print('updating vae and set t to zero')
                print(f'Offline model updated on epoch {t}')
                trainer.update_vae(replay_buffer)

            trainer.count += 1
            trainer.t = 0

        if done:
            state, done = env.reset(), False
            # Valid only for envs with goal, e.g. AntMaze, Adroit
            if is_env_with_goal:
                train_successes.append(goal_achieved)
                online_log["train/regret"] = np.mean(1 - np.array(train_successes))
                online_log["train/is_success"] = float(goal_achieved)
            online_log["train/episode_return"] = episode_return
            normalized_return = eval_env.get_normalized_score(episode_return)
            online_log["train/d4rl_normalized_episode_return"] = (
                normalized_return * 100.0
            )
            online_log["train/episode_length"] = episode_step
            episode_return = 0
            episode_step = 0
            goal_achieved = False

        batch = replay_buffer.sample(config.batch_size)
        batch = [b.to(config.device) for b in batch]
        log_dict = trainer.train(batch)
        log_dict["online_iter"] = t
        log_dict.update(online_log)
        wandb.log(log_dict)
        # Evaluate episode
        if t % config.eval_freq == 0:
            print(f"Time steps: {t + 1}")
            eval_scores, success_rate = eval_actor(
                eval_env,
                actor,
                device=config.device,
                n_episodes=config.n_episodes,
                seed=config.seed,
            )
            eval_score = eval_scores.mean()
            eval_log = {}
            normalized = eval_env.get_normalized_score(eval_score)
            # Valid only for envs with goal, e.g. AntMaze, Adroit
            if is_env_with_goal:
                eval_successes.append(success_rate)
                eval_log["eval/regret"] = np.mean(1 - np.array(train_successes))
                eval_log["eval/success_rate"] = success_rate
            normalized_eval_score = normalized * 100.0
            evaluations.append(normalized_eval_score)
            eval_log["eval/d4rl_normalized_score"] = normalized_eval_score
            print("---------------------------------------")
            print(
                f"Evaluation over {config.n_episodes} episodes: "
                f"{eval_score:.3f} , D4RL score: {normalized_eval_score:.3f}"
            )
            print("---------------------------------------")

            # NOTE(review): eval_log is filled in above but not sent to wandb;
            # only the two fields below are logged (uncommitted, so they are
            # attached to the next committed wandb.log call).
            wandb.log({"mean_return": eval_score, "eval/d4rl_normalized_score": normalized_eval_score}, commit=False)

            # Checkpoints are written only on evaluation steps that are also
            # multiples of save_freq and lie beyond step 450000.
            if t % config.save_freq == 0 and t > 450000:
                if config.checkpoints_path is not None:
                    torch.save(
                        trainer.state_dict(),
                        os.path.join(config.checkpoints_path, f"checkpoint_{t}.pt"),
                    )


if __name__ == "__main__":
    # Script entry point: run seeds 0-4 one after another on a single task,
    # each with its own wandb run.
    env = 'antmaze-large-diverse-v2'
    for seed in range(5):
        config = get_iql_train_configs(env, 'algo-IQL', f'incremental-{seed}', seed)
        wandb_init(config)
        train(config)
        # Close this seed's wandb run and flush stdout before the next seed.
        wandb.finish()
        sys.stdout.flush()
