# Inspired by:
# 1. paper for SAC-N: https://arxiv.org/abs/2110.01548
# 2. implementation: https://github.com/snu-mllab/EDAC
from typing import Any, Dict, List, Optional, Tuple, Union
from copy import deepcopy
from dataclasses import asdict, dataclass
import math
import os
import random
import uuid

import d4rl
import gym
import numpy as np
import pyrallis
import torch
from torch.distributions import Normal
import torch.nn as nn
import torch.nn.functional as F
from tqdm import trange
import wandb

from video_recorder import VideoRecorder


@dataclass
class TrainConfig:
    """Command-line configurable settings for one training run.

    After construction ``name`` is extended with the environment name and a
    random 8-character suffix, and ``checkpoints_path`` (when given) is pointed
    at a sub-directory named after the run.
    """

    # --- wandb params ---
    project: str = "offline-RL-init"
    group: str = "EDAC-D4RL"
    name: str = "EDAC"
    # --- model params ---
    hidden_dim: int = 256
    num_critics: int = 10
    gamma: float = 0.99
    tau: float = 5e-3
    eta: float = 1.0
    # pretrain_eta: float = 1.0
    actor_learning_rate: float = 3e-4
    critic_learning_rate: float = 3e-4
    alpha_learning_rate: float = 3e-4
    max_action: float = 1.0
    # --- training params ---
    buffer_size: int = 1_000_000
    env_name: str = "halfcheetah-medium-v2"
    batch_size: int = 256
    num_epochs: int = 3000
    num_updates_on_epoch: int = 1000
    normalize_reward: bool = False
    # --- evaluation params ---
    eval_episodes: int = 10
    eval_every: int = 5
    # --- general params ---
    checkpoints_path: Optional[str] = None
    deterministic_torch: bool = False
    train_seed: int = 10
    eval_seed: int = 42
    log_every: int = 100
    device: str = "cpu"
    pretrain: Optional[str] = None  # BC or softAC or softC
    pretrain_epochs: int = 10
    percentage_pretrain_BC: float = 0.5
    # --- loss-term weights; the EDAC trainer in this file only enables a term
    # when its weight is strictly positive, so -1.0 means "off" ---
    td_component: float = -1.0
    bc_regulariser: float = -1.0
    soft_bc_regulariser: float = -1.0
    pretrain_cql_regulariser: float = -1.0
    cql_regulariser: float = -1.0
    cql_n_actions: int = 10
    kl_regulariser: float = -1.0
    actor_kl_regulariser: float = -1.0
    # --- LayerNorm switches for the actor / critic networks ---
    actor_LN: bool = True
    critic_LN: bool = True
    render: bool = False

    def __post_init__(self):
        # The first 8 hex digits of a random UUID keep run names unique.
        run_suffix = uuid.uuid4().hex[:8]
        self.name = f"{self.name}-{self.env_name}-{run_suffix}"
        if self.checkpoints_path is None:
            return
        self.checkpoints_path = os.path.join(self.checkpoints_path, self.name)


# general utils
# Type alias for a batch of transitions passed around as a plain list of tensors.
TensorBatch = List[torch.Tensor]


def soft_update(target: nn.Module, source: nn.Module, tau: float):
    """Blend ``source`` parameters into ``target`` in place (Polyak averaging).

    Each target parameter becomes ``(1 - tau) * target + tau * source``.
    Parameters are paired positionally, in ``parameters()`` order.
    """
    for dst, src in zip(target.parameters(), source.parameters()):
        blended = (1 - tau) * dst.data + tau * src.data
        dst.data.copy_(blended)


def wandb_init(config: dict) -> None:
    """Start a wandb run described by ``config`` and snapshot the code.

    ``config`` must contain the keys ``"project"``, ``"group"`` and ``"name"``;
    the whole dict is also stored as the run configuration.
    """
    run_identity = {key: config[key] for key in ("project", "group", "name")}
    wandb.init(config=config, **run_identity)
    run = wandb.run
    run.log_code(".")
    run.save()


def set_seed(
    seed: int, env: Optional[gym.Env] = None, deterministic_torch: bool = False
):
    """Seed every random number source used by the script.

    Seeds the optional environment (and its action space), ``PYTHONHASHSEED``,
    NumPy, Python's ``random`` and torch, then switches torch's deterministic
    algorithms on or off according to ``deterministic_torch``.
    """
    if env is not None:
        env.seed(seed)
        env.action_space.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_rng in (np.random.seed, random.seed, torch.manual_seed):
        seed_rng(seed)
    torch.use_deterministic_algorithms(deterministic_torch)


def compute_mean_std(states: np.ndarray, eps: float) -> Tuple[np.ndarray, np.ndarray]:
    """Return the per-column mean and ``eps``-padded standard deviation.

    Both statistics are taken along axis 0; ``eps`` is added to the standard
    deviation so it can be used as a divisor.
    """
    return states.mean(axis=0), states.std(axis=0) + eps


def normalize_states(states: np.ndarray, mean: np.ndarray, std: np.ndarray):
    """Standardise ``states`` by subtracting ``mean`` and dividing by ``std``."""
    centered = states - mean
    return centered / std


def discount_cumsum(x, discount, include_first=True):
    """Compute discounted cumulative sums along the first axis of ``x``.

    With ``include_first=True`` entry ``t`` is ``sum_{k>=0} discount**k * x[t+k]``.
    With ``include_first=False`` the current element is skipped and entry ``t``
    is ``sum_{k>=1} discount**k * x[t+k]`` (so the last entry is 0).

    Args:
        x: ``torch.Tensor`` or array-like with at least one dimension.
        discount: multiplicative discount applied per step.
        include_first: whether ``x[t]`` itself contributes to entry ``t``.

    Returns:
        A tensor (for tensor input) or ndarray of the same shape as ``x``.
        Floating-point inputs keep their dtype; integer/bool inputs are
        accumulated in floating point so that discounted values are not
        truncated. Empty input yields an empty result.
    """
    if isinstance(x, torch.Tensor):
        keeps_dtype = x.is_floating_point() or x.is_complex()
        out_dtype = x.dtype if keeps_dtype else torch.get_default_dtype()
        disc_cumsum = torch.zeros_like(x, dtype=out_dtype).detach()
    else:
        x = np.asarray(x)
        # An integer accumulator would silently truncate every discounted term.
        out_dtype = x.dtype if np.issubdtype(x.dtype, np.inexact) else np.float64
        disc_cumsum = np.zeros_like(x, dtype=out_dtype)

    n_steps = x.shape[0]
    if n_steps == 0:
        # Nothing to accumulate; indexing [-1] below would raise.
        return disc_cumsum

    if include_first:
        disc_cumsum[-1] = x[-1]
        for t in reversed(range(n_steps - 1)):
            disc_cumsum[t] = x[t] + discount * disc_cumsum[t + 1]
    else:
        disc_cumsum[-1] = 0
        for t in reversed(range(n_steps - 1)):
            disc_cumsum[t] = discount * x[t + 1] + discount * disc_cumsum[t + 1]
    return disc_cumsum


def wrap_env(
    env: gym.Env,
    state_mean: Union[np.ndarray, float] = 0.0,
    state_std: Union[np.ndarray, float] = 1.0,
    reward_scale: float = 1.0,
) -> gym.Env:
    """Wrap ``env`` so observations are standardised and rewards rescaled.

    Observations are always transformed with ``(state - state_mean) / state_std``.
    The reward wrapper is only added when ``reward_scale`` differs from 1.0.
    """
    wrapped = gym.wrappers.TransformObservation(
        env, lambda state: (state - state_mean) / state_std
    )
    if reward_scale == 1.0:
        return wrapped
    return gym.wrappers.TransformReward(
        wrapped, lambda reward: reward_scale * reward
    )


class ReplayBuffer:
    """Fixed-capacity store for an offline dataset plus return-to-go targets.

    All fields are pre-allocated as float32 tensors of ``buffer_size`` rows on
    ``device``. Besides the raw transitions the buffer keeps, per transition,
    the discounted Monte-Carlo return-to-go and (optionally) a "soft"
    return-to-go that additionally accumulates policy entropy bonuses.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        buffer_size: int,
        batch_size: int,
        discount: float,
        device: str = "cpu",
    ):
        self._action_dim = action_dim
        self._buffer_size = buffer_size
        self._pointer = 0
        self._size = 0
        self._discount = discount
        self._batch_size = batch_size
        self._device = device
        # sample() reads this flag, so it must exist before any dataset is loaded.
        self._soft_returns_loaded = False

        def zeros(width: int) -> torch.Tensor:
            return torch.zeros(
                (buffer_size, width), dtype=torch.float32, device=device
            )

        self._states = zeros(state_dim)
        self._actions = zeros(action_dim)
        self._rewards = zeros(1)
        self._returns_to_go = zeros(1)
        self._entropy_bonuses = zeros(1)
        self._soft_returns_to_go = zeros(1)
        self._next_states = zeros(state_dim)
        self._dones = zeros(1)

    def _to_tensor(self, data: np.ndarray) -> torch.Tensor:
        """Copy ``data`` into a float32 tensor on the buffer's device."""
        return torch.tensor(data, dtype=torch.float32, device=self._device)

    def compute_returns_to_go(self, data: Dict[str, np.ndarray]):
        """Fill the discounted return-to-go of every transition in ``data``.

        Episodes end where ``terminals`` or ``timeouts`` is set; a trailing
        unfinished episode is treated as ending at the last transition.
        """
        n_transitions = data["observations"].shape[0]
        if n_transitions == 0:
            return
        rewards = np.asarray(data["rewards"])
        terminals, timeouts = data["terminals"], data["timeouts"]

        per_episode = []
        episode_start = 0
        for i in range(n_transitions):
            # TODO: Could better handle incomplete trajectory case?
            if terminals[i] or timeouts[i] or i == n_transitions - 1:
                per_episode.append(
                    discount_cumsum(rewards[episode_start : i + 1], self._discount)
                )
                episode_start = i + 1

        returns_to_go = np.concatenate(per_episode).reshape(-1)
        self._returns_to_go[:n_transitions] = self._to_tensor(returns_to_go[..., None])

    def compute_soft_returns_to_go(self, alpha: torch.Tensor, actor: "Actor"):
        """Fill entropy-augmented returns-to-go for the loaded transitions.

        For every stored state one action is sampled from ``actor`` and its
        negative log-probability is stored as the entropy bonus. The soft
        return of a transition is its discounted reward return plus ``alpha``
        times the discounted sum of the *future* entropy bonuses of the same
        episode. After this call ``sample`` returns the soft returns.

        Args:
            alpha: single-element tensor holding the entropy temperature.
            actor: policy called as ``actor(states, need_log_prob=True)``.
        """
        # Only the loaded prefix is valid data; iterating over the whole
        # capacity would merge zero padding into the last unfinished episode.
        n_transitions = self._size
        if n_transitions == 0:
            return

        with torch.no_grad():
            for start in range(0, n_transitions, self._batch_size):
                end = min(start + self._batch_size, n_transitions)
                _, log_pi = actor(self._states[start:end], need_log_prob=True)
                self._entropy_bonuses[start:end] = -log_pi.detach().unsqueeze(-1)

        alpha_value = alpha.detach().item()
        # One bulk transfer instead of a host sync per transition.
        done_flags = self._dones[:n_transitions, 0].tolist()
        episode_start = 0
        for i, done in enumerate(done_flags):
            if not (done or i == n_transitions - 1):
                continue
            episode = slice(episode_start, i + 1)
            reward_to_go = discount_cumsum(self._rewards[episode, 0], self._discount)
            entropy_to_go = discount_cumsum(
                self._entropy_bonuses[episode, 0],
                self._discount,
                include_first=False,
            )
            self._soft_returns_to_go[episode, 0] = (
                reward_to_go + alpha_value * entropy_to_go
            )
            episode_start = i + 1

        self._soft_returns_loaded = True

    # Loads data in d4rl format, i.e. from Dict[str, np.array].
    def load_d4rl_dataset(self, data: Dict[str, np.ndarray]):
        """Copy a d4rl-style dataset into the (empty) buffer.

        Expects the keys ``observations``, ``actions``, ``rewards``,
        ``next_observations``, ``terminals`` and ``timeouts``.

        Raises:
            ValueError: if the buffer already holds data or is too small.
        """
        if self._size != 0:
            raise ValueError("Trying to load data into non-empty replay buffer")
        n_transitions = data["observations"].shape[0]
        if n_transitions > self._buffer_size:
            raise ValueError(
                "Replay buffer is smaller than the dataset you are trying to load!"
            )

        self._states[:n_transitions] = self._to_tensor(data["observations"])
        self._actions[:n_transitions] = self._to_tensor(data["actions"])
        self._rewards[:n_transitions] = self._to_tensor(data["rewards"][..., None])
        self._next_states[:n_transitions] = self._to_tensor(data["next_observations"])

        # logical_or keeps the flag in {0, 1} even when both markers are set
        # on numeric (non-bool) arrays.
        episode_ends = np.logical_or(data["terminals"], data["timeouts"])
        self._dones[:n_transitions] = self._to_tensor(episode_ends[..., None])
        self._size += n_transitions
        self._pointer = min(self._size, n_transitions)

        self.compute_returns_to_go(data)
        self._soft_returns_loaded = False

        print(f"Dataset size: {n_transitions}")

    def sample(self, batch_size: int) -> "TensorBatch":
        """Draw ``batch_size`` transitions uniformly at random (with replacement).

        Returns:
            ``[states, actions, rewards, returns_to_go, next_states, dones]``
            where ``returns_to_go`` are the soft returns once they have been
            computed and the plain discounted returns otherwise.
        """
        indices = np.random.randint(0, min(self._size, self._pointer), size=batch_size)
        if self._soft_returns_loaded:
            returns_to_go = self._soft_returns_to_go[indices]
        else:
            returns_to_go = self._returns_to_go[indices]
        return [
            self._states[indices],
            self._actions[indices],
            self._rewards[indices],
            returns_to_go,
            self._next_states[indices],
            self._dones[indices],
        ]

    def add_transition(self):
        # Use this method to add new data into the replay buffer during fine-tuning.
        raise NotImplementedError


# SAC Actor & Critic implementation
class VectorizedLinear(nn.Module):
    """A stack of ``ensemble_size`` independent linear layers applied at once.

    Weights are stored as ``[ensemble_size, in_features, out_features]`` and
    biases as ``[ensemble_size, 1, out_features]`` so that a single batched
    matrix product evaluates every ensemble member.
    """

    def __init__(self, in_features: int, out_features: int, ensemble_size: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.ensemble_size = ensemble_size

        weight_shape = (ensemble_size, in_features, out_features)
        bias_shape = (ensemble_size, 1, out_features)
        self.weight = nn.Parameter(torch.empty(*weight_shape))
        self.bias = nn.Parameter(torch.empty(*bias_shape))

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise every member following the recipe used by ``nn.Linear``."""
        # NOTE(review): torch derives fan-in from dim 1 of each 2-D slice, which
        # for this [in, out] layout is ``out_features`` - confirm this is intended.
        negative_slope = math.sqrt(5)
        for member in range(self.ensemble_size):
            nn.init.kaiming_uniform_(self.weight[member], a=negative_slope)

        fan_in = nn.init._calculate_fan_in_and_fan_out(self.weight[0])[0]
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # input:  [ensemble_size, batch_size, input_size]
        # weight: [ensemble_size, input_size, out_size]
        # out:    [ensemble_size, batch_size, out_size]
        projected = torch.matmul(x, self.weight)
        return projected + self.bias


class Actor(nn.Module):
    """Tanh-squashed Gaussian policy with separate mean and log-std heads.

    The trunk is three ``Linear -> (LayerNorm | Identity) -> ReLU`` stages.
    Actions are ``tanh(sample) * max_action``; reported log-probabilities are
    those of the tanh-squashed sample (the constant ``max_action`` scaling is
    not included in the density).
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        hidden_dim: int,
        max_action: float = 1.0,
        actor_LN: bool = True,
    ):
        super().__init__()

        def make_norm() -> nn.Module:
            if actor_LN:
                return nn.LayerNorm(hidden_dim, elementwise_affine=False)
            return nn.Identity()

        self.trunk = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            make_norm(),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            make_norm(),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            make_norm(),
            nn.ReLU(),
        )
        # with separate layers works better than with Linear(hidden_dim, 2 * action_dim)
        self.mu = nn.Linear(hidden_dim, action_dim)
        self.log_sigma = nn.Linear(hidden_dim, action_dim)

        # init as in the EDAC paper; every third trunk module is a Linear layer
        for layer in self.trunk[::3]:
            torch.nn.init.constant_(layer.bias, 0.1)

        torch.nn.init.uniform_(self.mu.weight, -1e-3, 1e-3)
        torch.nn.init.uniform_(self.mu.bias, -1e-3, 1e-3)
        torch.nn.init.uniform_(self.log_sigma.weight, -1e-3, 1e-3)
        torch.nn.init.uniform_(self.log_sigma.bias, -1e-3, 1e-3)

        self.action_dim = action_dim
        self.max_action = max_action

    def _policy_dist(self, state: torch.Tensor) -> Tuple[torch.Tensor, Normal]:
        """Return the mean and the pre-tanh Gaussian for ``state``."""
        hidden = self.trunk(state)
        mu, log_sigma = self.mu(hidden), self.log_sigma(hidden)
        # clipping params from EDAC paper, not as in SAC paper (-20, 2)
        log_sigma = torch.clip(log_sigma, -5, 2)
        return mu, Normal(mu, torch.exp(log_sigma))

    def forward(
        self,
        state: torch.Tensor,
        deterministic: bool = False,
        need_log_prob: bool = False,
        need_policy_dist: bool = False,
    ) -> Union[
        Tuple[torch.Tensor, Optional[torch.Tensor]],
        Tuple[torch.Tensor, Optional[torch.Tensor], Normal],
    ]:
        """Produce an action for ``state``.

        Args:
            state: batch of states.
            deterministic: use the mean instead of a reparameterised sample.
            need_log_prob: also compute the log-probability of the action.
            need_policy_dist: additionally return the pre-tanh ``Normal``.

        Returns:
            ``(action, log_prob)`` or ``(action, log_prob, policy_dist)``;
            ``log_prob`` is ``None`` unless ``need_log_prob`` is set.
        """
        mu, policy_dist = self._policy_dist(state)
        # wandb.log({"policy_sigma": torch.exp(log_sigma).mean()})

        action = mu if deterministic else policy_dist.rsample()

        tanh_action, log_prob = torch.tanh(action), None
        if need_log_prob:
            # change of variables formula (SAC paper, appendix C, eq 21)
            log_prob = policy_dist.log_prob(action).sum(axis=-1)
            log_prob = log_prob - torch.log(1 - tanh_action.pow(2) + 1e-6).sum(axis=-1)

        if need_policy_dist:
            return tanh_action * self.max_action, log_prob, policy_dist

        return tanh_action * self.max_action, log_prob

    def log_prob(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Log-probability of a given (already scaled) ``action`` in ``state``.

        The action is first mapped back to the tanh range by dividing by
        ``max_action``; otherwise ``arctanh`` would receive values outside
        (-1, 1) whenever ``max_action != 1``. The result uses the same density
        convention as ``forward(..., need_log_prob=True)``.
        """
        _, policy_dist = self._policy_dist(state)

        # Keep strictly inside (-1, 1) so arctanh stays finite.
        squashed = torch.clip(action / self.max_action, -1 + 1e-6, 1 - 1e-6)
        log_prob = policy_dist.log_prob(torch.arctanh(squashed)).sum(axis=-1)
        log_prob = log_prob - torch.log(1 - squashed.pow(2) + 1e-6).sum(axis=-1)
        return log_prob

    @torch.no_grad()
    def act(self, state: np.ndarray, device: str) -> np.ndarray:
        """Return a NumPy action; deterministic when the module is in eval mode."""
        deterministic = not self.training
        state = torch.tensor(state, device=device, dtype=torch.float32)
        action = self(state, deterministic=deterministic)[0].cpu().numpy()
        return action


class VectorizedCritic(nn.Module):
    """Ensemble of ``num_critics`` Q-networks evaluated in one batched pass.

    Each member is a four-layer MLP built from ``VectorizedLinear`` layers with
    an optional parameter-free LayerNorm before every ReLU.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        hidden_dim: int,
        num_critics: int,
        critic_LN: bool = True,
    ):
        super().__init__()

        def make_norm() -> nn.Module:
            if critic_LN:
                return nn.LayerNorm(hidden_dim, elementwise_affine=False)
            return nn.Identity()

        self.critic = nn.Sequential(
            VectorizedLinear(state_dim + action_dim, hidden_dim, num_critics),
            make_norm(),
            nn.ReLU(),
            VectorizedLinear(hidden_dim, hidden_dim, num_critics),
            make_norm(),
            nn.ReLU(),
            VectorizedLinear(hidden_dim, hidden_dim, num_critics),
            make_norm(),
            nn.ReLU(),
            VectorizedLinear(hidden_dim, 1, num_critics),
        )
        # init as in the EDAC paper; every third module is a VectorizedLinear
        for layer in self.critic[::3]:
            torch.nn.init.constant_(layer.bias, 0.1)

        # The output head is then re-initialised with small uniform values.
        head = self.critic[-1]
        torch.nn.init.uniform_(head.weight, -3e-3, 3e-3)
        torch.nn.init.uniform_(head.bias, -3e-3, 3e-3)

        self.num_critics = num_critics

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Return Q-values of shape ``[num_critics, batch_size]``.

        Inputs may be 2-D (shared by all critics, replicated here) or already
        3-D with a leading ``num_critics`` dimension.
        """
        # [..., batch_size, state_dim + action_dim]
        state_action = torch.cat([state, action], dim=-1)
        if state_action.dim() != 3:
            assert state_action.dim() == 2
            # [num_critics, batch_size, state_dim + action_dim]
            state_action = state_action.repeat(self.num_critics, 1, 1)
        assert state_action.dim() == 3
        assert state_action.shape[0] == self.num_critics
        # [num_critics, batch_size]
        return self.critic(state_action).squeeze(-1)


class EDAC:
    """EDAC (SAC-N + critic diversity) trainer with optional pretraining terms.

    On top of the usual alpha / actor / critic updates the class offers
    behaviour-cloning and Monte-Carlo critic pretraining steps and several
    optional regularisers. A regulariser is active only when its weight is
    strictly positive. When terms are combined, each is rescaled to unit
    magnitude (dividing by the absolute value of its detached value) so the
    weights express relative importance without ever flipping a gradient.

    The KL-style regularisers read ``self.pretrained_actor`` /
    ``self.pretrained_critic``, which must be assigned by the caller.
    """

    def __init__(
        self,
        actor: Actor,
        actor_optimizer: torch.optim.Optimizer,
        critic: VectorizedCritic,
        critic_optimizer: torch.optim.Optimizer,
        gamma: float = 0.99,
        tau: float = 0.005,
        eta: float = 1.0,
        alpha_learning_rate: float = 1e-4,
        bc_regulariser: float = -1.0,
        soft_bc_regulariser: float = -1.0,
        td_component: float = -1.0,
        pretrain_cql_regulariser: float = -1.0,
        cql_regulariser: float = -1.0,
        cql_n_actions: int = 10,
        kl_regulariser: float = -1.0,
        actor_kl_regulariser: float = -1.0,
        device: str = "cpu",  # noqa
    ):
        self.device = device

        self.actor = actor
        self.critic = critic
        with torch.no_grad():
            self.target_critic = deepcopy(self.critic)

        self.actor_optimizer = actor_optimizer
        self.critic_optimizer = critic_optimizer

        self.tau = tau
        self.gamma = gamma
        self.eta = eta

        if bc_regulariser > 0.0:
            assert (
                soft_bc_regulariser <= 0.0
            ), "Only consider either hard or soft BC regularisation. Here bc_regulariser and soft_bc_regulariser > 0."
        self.bc_regulariser = bc_regulariser
        self.soft_bc_regulariser = soft_bc_regulariser
        if td_component > 0.0:
            assert (
                td_component <= 1.0
            ), "TD_component_must be between 0 and 1 (default: 0)."
        self.td_component = td_component
        self.pretrain_cql_regulariser = pretrain_cql_regulariser
        self.cql_regulariser = cql_regulariser
        self.cql_n_actions = cql_n_actions
        self.kl_regulariser = kl_regulariser
        self.actor_kl_regulariser = actor_kl_regulariser

        # adaptive alpha setup
        self.target_entropy = -float(self.actor.action_dim)
        self.log_alpha = torch.tensor(
            [0.0], dtype=torch.float32, device=self.device, requires_grad=True
        )
        self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=alpha_learning_rate)
        self.alpha = self.log_alpha.exp().detach()

    # ------------------------------------------------------------------ helpers
    @staticmethod
    def _unit_scale(term: torch.Tensor, eps: float = 0.0) -> torch.Tensor:
        """Rescale ``term`` to magnitude ~1 while keeping its gradient direction.

        Dividing by the signed detached value would reverse the gradient of a
        negative loss, hence the absolute value.
        """
        return term / (term.detach().abs() + eps)

    @staticmethod
    def _step(optimizer: torch.optim.Optimizer, loss: torch.Tensor) -> None:
        """Run one ``zero_grad -> backward -> step`` cycle."""
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    def _to_device(self, batch: TensorBatch) -> TensorBatch:
        """Move every tensor of ``batch`` to the trainer's device."""
        return [arr.to(self.device) for arr in batch]

    def _require_pretrained(self, attr: str):
        """Return ``self.<attr>`` or raise a descriptive ``AttributeError``."""
        module = getattr(self, attr, None)
        if module is None:
            raise AttributeError(
                f"EDAC.{attr} must be assigned before its regulariser is used"
            )
        return module

    def _td_target(
        self, reward: torch.Tensor, next_state: torch.Tensor, done: torch.Tensor
    ) -> torch.Tensor:
        """Soft Bellman target from the min over the target-critic ensemble."""
        with torch.no_grad():
            next_action, next_action_log_prob = self.actor(
                next_state, need_log_prob=True
            )
            q_next = self.target_critic(next_state, next_action).min(0).values
            q_next = q_next - self.alpha * next_action_log_prob

            assert q_next.unsqueeze(-1).shape == done.shape == reward.shape
            return reward + self.gamma * (1 - done) * q_next.unsqueeze(-1)

    def _cql_penalty(
        self, state: torch.Tensor, action: torch.Tensor, q_values: torch.Tensor
    ) -> torch.Tensor:
        """CQL-style gap: logsumexp of Q over random actions minus Q on data.

        ``q_values`` is ``[num_critics, batch_size]`` for the dataset actions.
        """
        batch_size, action_dim = action.shape
        max_action = self.actor.max_action
        random_actions = action.new_empty(
            (self.cql_n_actions * batch_size, action_dim),
            requires_grad=False,
        ).uniform_(-max_action, max_action)
        # repeat_interleave keeps the copies of one state contiguous, so the
        # reshape below groups the random-action Q-values per state.
        repeated_state = state.repeat_interleave(self.cql_n_actions, dim=0)
        q_random_values = self.critic(repeated_state, random_actions).reshape(
            q_values.shape[0], batch_size, self.cql_n_actions
        )
        return (
            torch.logsumexp(q_random_values.min(0).values, dim=-1)
            - q_values.min(0).values
        ).mean()

    def _random_action_q_std(self, state: torch.Tensor, action: torch.Tensor) -> float:
        """Ensemble std of Q on actions a ~ U[-max_action, max_action] (logging)."""
        max_action = self.actor.max_action
        random_actions = -max_action + 2 * max_action * torch.rand_like(action)
        return self.critic(state, random_actions).std(0).mean().item()

    # ------------------------------------------------------------------- losses
    def _alpha_loss(self, state: torch.Tensor) -> torch.Tensor:
        """Temperature loss driving policy entropy towards ``target_entropy``."""
        with torch.no_grad():
            action, action_log_prob = self.actor(state, need_log_prob=True)

        loss = (-self.log_alpha * (action_log_prob + self.target_entropy)).mean()

        return loss

    def _actor_loss(
        self,
        state: torch.Tensor,
        action: torch.Tensor,
        return_to_go: torch.Tensor,
    ) -> Tuple[torch.Tensor, float, float]:
        """Actor loss plus logging statistics.

        Returns:
            ``(loss, batch_entropy, q_value_std)`` where the last two are
            Python floats used only for logging. ``return_to_go`` is unused.
        """
        pi, log_pi, policy_dist = self.actor(
            state, need_log_prob=True, need_policy_dist=True
        )

        q_value_dist = self.critic(state, pi)
        assert q_value_dist.shape[0] == self.critic.num_critics
        q_value_min = q_value_dist.min(0).values
        # needed for logging
        q_value_std = q_value_dist.std(0).mean().item()
        batch_entropy = -log_pi.mean().item()

        assert log_pi.shape == q_value_min.shape
        sac_loss = (self.alpha * log_pi - q_value_min).mean()
        if self.bc_regulariser > 0.0:
            bc_loss = F.mse_loss(pi, action)
            loss = self._unit_scale(sac_loss) + self.bc_regulariser * self._unit_scale(
                bc_loss
            )
        elif self.soft_bc_regulariser > 0.0:
            log_prob_action = self.actor.log_prob(state, action)
            loss = (
                self._unit_scale(sac_loss)
                - self.soft_bc_regulariser * log_prob_action.mean()
            )
        else:
            loss = sac_loss

        if self.actor_kl_regulariser > 0.0:
            pretrained_actor = self._require_pretrained("pretrained_actor")
            _, _, pretrained_policy_dist = pretrained_actor(
                state, need_policy_dist=True
            )
            KL_regulariser = torch.distributions.kl.kl_divergence(
                policy_dist, pretrained_policy_dist
            ).mean()
            loss = self._unit_scale(
                loss
            ) + self.actor_kl_regulariser * self._unit_scale(KL_regulariser, eps=1e-6)

        return loss, batch_entropy, q_value_std

    def _critic_diversity_loss(
        self, state: torch.Tensor, action: torch.Tensor
    ) -> torch.Tensor:
        """Mean pairwise cosine similarity of the critics' action gradients."""
        num_critics = self.critic.num_critics
        if num_critics < 2:
            # No pairs to diversify; the normalisation below would be 0 / 0.
            return action.new_zeros(())
        # almost exact copy from the original implementation, only style changes:
        # https://github.com/snu-mllab/EDAC/blob/198d5708701b531fd97a918a33152e1914ea14d7/lifelong_rl/trainers/q_learning/sac.py#L192

        # [num_critics, batch_size, *_dim]
        state = state.unsqueeze(0).repeat_interleave(num_critics, dim=0)
        action = (
            action.unsqueeze(0)
            .repeat_interleave(num_critics, dim=0)
            .requires_grad_(True)
        )
        # [num_critics, batch_size]
        q_ensemble = self.critic(state, action)

        q_action_grad = torch.autograd.grad(
            q_ensemble.sum(), action, retain_graph=True, create_graph=True
        )[0]
        q_action_grad = q_action_grad / (
            torch.norm(q_action_grad, p=2, dim=2).unsqueeze(-1) + 1e-10
        )
        # [batch_size, num_critics, action_dim]
        q_action_grad = q_action_grad.transpose(0, 1)

        # removed einsum as it is usually slower than just torch.bmm
        # [batch_size, num_critics, num_critics]
        similarity = q_action_grad @ q_action_grad.permute(0, 2, 1)
        # Zero the diagonal (self-similarity); the mask broadcasts over the batch.
        off_diagonal = 1 - torch.eye(
            num_critics, device=similarity.device, dtype=similarity.dtype
        )
        similarity = off_diagonal * similarity

        grad_loss = similarity.sum(dim=(1, 2)).mean()
        grad_loss = grad_loss / (num_critics - 1)

        return grad_loss

    def _critic_loss(
        self,
        state: torch.Tensor,
        action: torch.Tensor,
        reward: torch.Tensor,
        next_state: torch.Tensor,
        done: torch.Tensor,
    ) -> torch.Tensor:
        """TD loss summed over critics plus diversity and optional regularisers."""
        q_target = self._td_target(reward, next_state, done)

        q_values = self.critic(state, action)

        # [ensemble_size, batch_size] - [1, batch_size]
        critic_loss = ((q_values - q_target.view(1, -1)) ** 2).mean(dim=1).sum(dim=0)
        diversity_loss = self._critic_diversity_loss(state, action)

        loss = critic_loss + self.eta * diversity_loss

        if self.kl_regulariser > 0.0:
            pretrained_critic = self._require_pretrained("pretrained_critic")
            pretrained_q_values = pretrained_critic(state, action)

            KL_regulariser = F.mse_loss(
                q_values.min(0).values, pretrained_q_values.min(0).values
            )
            loss = self._unit_scale(loss) + self.kl_regulariser * self._unit_scale(
                KL_regulariser, eps=1e-6
            )
        if self.cql_regulariser > 0.0:
            cql_regulariser = self._cql_penalty(state, action, q_values)
            loss = self._unit_scale(loss) + self.cql_regulariser * self._unit_scale(
                cql_regulariser
            )

        return loss

    # ----------------------------------------------------------- training steps
    def pretrain_BC(self, batch: TensorBatch) -> Dict[str, float]:
        """One entropy-regularised behaviour-cloning step for alpha and actor."""
        state, action, reward, return_to_go, next_state, done = self._to_device(batch)
        # Alpha update
        alpha_loss = self._alpha_loss(state)
        self._step(self.alpha_optimizer, alpha_loss)
        self.alpha = self.log_alpha.exp().detach()

        # Compute actor loss
        pi, log_pi = self.actor(state, need_log_prob=True)
        log_prob_action = self.actor.log_prob(state, action)
        actor_loss = (self.alpha * log_pi - log_prob_action).mean()

        # Optimize the actor
        self._step(self.actor_optimizer, actor_loss)

        log_dict = {
            "alpha_loss": alpha_loss.item(),
            "actor_loss": actor_loss.item(),
            "batch_entropy": -log_pi.mean().item(),
            "alpha": self.alpha.item(),
        }

        return log_dict

    def pretrain_soft_critic(
        self, batch: TensorBatch, epoch, pretrain_epochs
    ) -> Dict[str, float]:
        """One critic pretraining step regressing onto the batch returns-to-go.

        Optionally mixes in a TD loss (``td_component``) and a CQL penalty
        (``pretrain_cql_regulariser``). ``epoch`` and ``pretrain_epochs`` are
        accepted for interface compatibility and are not used.
        """
        state, action, reward, return_to_go, next_state, done = self._to_device(batch)

        log_dict = {}
        q_values = self.critic(state, action)

        # Compute critic loss
        # [ensemble_size, batch_size] - [1, batch_size]
        MC_critic_loss = ((q_values - return_to_go.view(1, -1)) ** 2).mean()
        diversity_loss = self._critic_diversity_loss(state, action)
        critic_loss = MC_critic_loss + self.eta * diversity_loss
        log_dict["MC_critic_loss"] = MC_critic_loss.item()
        log_dict["diversity_loss"] = diversity_loss.item()

        if self.td_component > 0.0:
            q_target0 = self._td_target(reward, next_state, done)

            # [ensemble_size, batch_size] - [1, batch_size]
            TD_critic_loss = (
                ((q_values - q_target0.view(1, -1)) ** 2).mean(dim=1).sum(dim=0)
            )
            critic_loss = (
                (1 - self.td_component) * self._unit_scale(MC_critic_loss)
                + self.td_component * self._unit_scale(TD_critic_loss)
                + self.eta * diversity_loss
            )
            log_dict["TD_critic_loss"] = TD_critic_loss.item()

        if self.pretrain_cql_regulariser > 0.0:
            cql_regulariser = self._cql_penalty(state, action, q_values)
            critic_loss = self._unit_scale(
                critic_loss
            ) + self.pretrain_cql_regulariser * self._unit_scale(cql_regulariser)
            log_dict["cql_regulariser"] = cql_regulariser.item()

        # Optimize the critic
        self._step(self.critic_optimizer, critic_loss)

        #  Target networks soft update
        with torch.no_grad():
            soft_update(self.target_critic, self.critic, self.tau)
            log_dict["q_random_std"] = self._random_action_q_std(state, action)

        return log_dict

    def update(self, batch: TensorBatch) -> Dict[str, float]:
        """Run one full training step and return scalars for logging."""
        state, action, reward, return_to_go, next_state, done = self._to_device(batch)
        # Usually updates are done in the following order: critic -> actor -> alpha
        # But we found that EDAC paper uses reverse (which gives better results)

        # Alpha update
        alpha_loss = self._alpha_loss(state)
        self._step(self.alpha_optimizer, alpha_loss)
        self.alpha = torch.clip(self.log_alpha, -10, 10).exp().detach()

        # Actor update
        actor_loss, actor_batch_entropy, q_policy_std = self._actor_loss(
            state, action, return_to_go
        )
        self._step(self.actor_optimizer, actor_loss)

        # Critic update
        critic_loss = self._critic_loss(state, action, reward, next_state, done)
        self._step(self.critic_optimizer, critic_loss)

        #  Target networks soft update
        with torch.no_grad():
            soft_update(self.target_critic, self.critic, tau=self.tau)
            q_random_std = self._random_action_q_std(state, action)

        update_info = {
            "alpha_loss": alpha_loss.item(),
            "critic_loss": critic_loss.item(),
            "actor_loss": actor_loss.item(),
            "batch_entropy": actor_batch_entropy,
            "alpha": self.alpha.item(),
            "q_policy_std": q_policy_std,
            "q_random_std": q_random_std,
        }
        return update_info

    # ------------------------------------------------------------ checkpointing
    def state_dict(self) -> Dict[str, Any]:
        """Collect network, optimiser and temperature state for checkpointing."""
        state = {
            "actor": self.actor.state_dict(),
            "critic": self.critic.state_dict(),
            "target_critic": self.target_critic.state_dict(),
            "log_alpha": self.log_alpha.item(),
            "actor_optim": self.actor_optimizer.state_dict(),
            "critic_optim": self.critic_optimizer.state_dict(),
            "alpha_optim": self.alpha_optimizer.state_dict(),
        }
        return state

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Restore everything saved by ``state_dict``."""
        self.actor.load_state_dict(state_dict["actor"])
        self.critic.load_state_dict(state_dict["critic"])
        self.target_critic.load_state_dict(state_dict["target_critic"])
        self.actor_optimizer.load_state_dict(state_dict["actor_optim"])
        self.critic_optimizer.load_state_dict(state_dict["critic_optim"])
        self.alpha_optimizer.load_state_dict(state_dict["alpha_optim"])
        self.log_alpha.data[0] = state_dict["log_alpha"]
        self.alpha = self.log_alpha.exp().detach()


@torch.no_grad()
def eval_actor(
    env: gym.Env,
    actor: Actor,
    device: str,
    n_episodes: int,
    seed: int,
    render: bool,
    name: str,
) -> np.ndarray:
    """Roll out ``actor`` in ``env`` and collect the undiscounted episode returns.

    The actor is put into eval mode for the rollouts and switched back to
    train mode before returning. The whole function runs without gradient
    tracking.

    Args:
        env: Environment to evaluate in; it is seeded with ``seed`` first.
        actor: Policy whose ``act(state, device)`` picks the actions.
        device: Device string forwarded to ``actor.act``.
        n_episodes: Number of episodes to run.
        seed: Seed passed to ``env.seed``.
        render: When true, a ``VideoRecorder`` captures the first episode and
            the video is saved under ``name``.
        name: Name handed to ``VideoRecorder.save``.

    Returns:
        Array with one summed reward per episode.
    """
    env.seed(seed)
    actor.eval()
    recorder = VideoRecorder() if render else None
    returns_per_episode = []

    for episode_idx in range(n_episodes):
        observation = env.reset()
        finished = False
        total_reward = 0.0
        # Only the very first episode is captured on video.
        capture_frames = recorder is not None and episode_idx == 0
        while not finished:
            action = actor.act(observation, device)
            observation, reward, finished, _ = env.step(action)
            total_reward += reward
            if capture_frames:
                recorder.record(env)
        returns_per_episode.append(total_reward)

    actor.train()
    if recorder is not None:
        recorder.save(name, wandb=wandb)
    return np.array(returns_per_episode)


def return_reward_range(dataset, max_episode_steps):
    """Return the smallest and largest episode return found in ``dataset``.

    The flat ``dataset["rewards"]`` / ``dataset["terminals"]`` sequences are
    cut into episodes: an episode closes when its terminal flag is set or when
    it has reached ``max_episode_steps`` steps. Steps after the last closed
    episode do not produce a return, but they are still counted for the
    consistency check.

    Args:
        dataset: Mapping with equally long ``"rewards"`` and ``"terminals"``.
        max_episode_steps: Step count at which an episode is force-closed.

    Returns:
        Tuple ``(min_return, max_return)`` over all closed episodes.
    """
    episode_returns = []
    steps_accounted = 0
    running_return, running_steps = 0.0, 0

    for reward, terminal in zip(dataset["rewards"], dataset["terminals"]):
        running_return += float(reward)
        running_steps += 1
        if not terminal and running_steps != max_episode_steps:
            continue
        # Episode boundary: bank the return and start a fresh accumulator.
        episode_returns.append(running_return)
        steps_accounted += running_steps
        running_return, running_steps = 0.0, 0

    # Leftover steps of an unfinished trailing episode still count as steps.
    steps_accounted += running_steps
    assert steps_accounted == len(dataset["rewards"])
    return min(episode_returns), max(episode_returns)


def modify_reward(dataset, env_name, max_episode_steps=1000):
    """Rescale ``dataset["rewards"]`` in place according to the environment name.

    * Names containing ``halfcheetah``, ``hopper`` or ``walker2d``: rewards are
      divided by the spread between the largest and smallest episode return
      and then multiplied by ``max_episode_steps``.
    * Names containing ``antmaze``: every reward is shifted down by 1.
    * Any other name: the dataset is left untouched.

    Args:
        dataset: Mapping whose ``"rewards"`` entry is modified in place.
        env_name: Environment name used to select the transformation.
        max_episode_steps: Episode length cap used for the return range and
            as the final scale factor.
    """
    is_locomotion = (
        "halfcheetah" in env_name or "hopper" in env_name or "walker2d" in env_name
    )
    if is_locomotion:
        lowest, highest = return_reward_range(dataset, max_episode_steps)
        return_span = highest - lowest
        # Two separate in-place steps, exactly as a divide followed by a scale.
        dataset["rewards"] /= return_span
        dataset["rewards"] *= max_episode_steps
        return

    if "antmaze" in env_name:
        dataset["rewards"] -= 1.0


@pyrallis.wrap()
def train(config: TrainConfig):
    """Run the full offline training pipeline described by ``config``.

    Steps, in order:

    1. Seed everything and start the wandb run.
    2. Build the environment, fetch its D4RL dataset, optionally rescale the
       rewards, and normalize observations with the dataset mean/std.
    3. Fill the replay buffer and build actor, critic ensemble and trainer.
    4. For every epoch run ``num_updates_on_epoch`` gradient steps. When
       ``config.pretrain`` is set, epochs ``0..pretrain_epochs`` (inclusive)
       use the selected pretraining update (``"BC"``, ``"softAC"`` or
       ``"softC"``); afterwards the regular ``trainer.update`` is used.
    5. Evaluate every ``eval_every`` epochs (and on the last epoch), log the
       results and optionally save a checkpoint.
    6. Run a final 100-episode test evaluation and close the wandb run.

    Args:
        config: Training configuration (see ``TrainConfig``).

    Raises:
        ValueError: If ``config.pretrain`` is set to an unknown mode.
    """
    set_seed(config.train_seed, deterministic_torch=config.deterministic_torch)
    wandb_init(asdict(config))

    # data, evaluation, env setup
    env = gym.make(config.env_name)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    d4rl_dataset = env.get_dataset()

    if config.normalize_reward:
        modify_reward(d4rl_dataset, config.env_name)

    state_mean, state_std = compute_mean_std(d4rl_dataset["observations"], eps=1e-3)

    if "next_observations" not in d4rl_dataset:
        # Derive next observations by shifting the observations one step back.
        d4rl_dataset["next_observations"] = np.roll(
            d4rl_dataset["observations"], shift=-1, axis=0
        )  # Terminals/timeouts block next observations
        print("Loaded next state observations from current state observations.")

    d4rl_dataset["observations"] = normalize_states(
        d4rl_dataset["observations"], state_mean, state_std
    )
    d4rl_dataset["next_observations"] = normalize_states(
        d4rl_dataset["next_observations"], state_mean, state_std
    )

    # The evaluation env applies the same state normalization as the dataset.
    eval_env = wrap_env(env, state_mean=state_mean, state_std=state_std)

    buffer = ReplayBuffer(
        state_dim=state_dim,
        action_dim=action_dim,
        buffer_size=config.buffer_size,
        batch_size=config.batch_size,
        discount=config.gamma,
        device=config.device,
    )
    buffer.load_d4rl_dataset(d4rl_dataset)

    # Actor & Critic setup
    actor = Actor(
        state_dim, action_dim, config.hidden_dim, config.max_action, config.actor_LN
    )
    actor.to(config.device)
    actor_optimizer = torch.optim.Adam(actor.parameters(), lr=config.actor_learning_rate)
    critic = VectorizedCritic(
        state_dim, action_dim, config.hidden_dim, config.num_critics, config.critic_LN
    )
    critic.to(config.device)
    critic_optimizer = torch.optim.Adam(
        critic.parameters(), lr=config.critic_learning_rate
    )

    trainer = EDAC(
        actor=actor,
        actor_optimizer=actor_optimizer,
        critic=critic,
        critic_optimizer=critic_optimizer,
        gamma=config.gamma,
        tau=config.tau,
        eta=config.eta,
        # pretrain_eta=config.pretrain_eta,
        alpha_learning_rate=config.alpha_learning_rate,
        bc_regulariser=config.bc_regulariser,
        td_component=config.td_component,
        pretrain_cql_regulariser=config.pretrain_cql_regulariser,
        cql_regulariser=config.cql_regulariser,
        cql_n_actions=config.cql_n_actions,
        kl_regulariser=config.kl_regulariser,
        actor_kl_regulariser=config.actor_kl_regulariser,
        device=config.device,
    )
    # saving config to the checkpoint
    if config.checkpoints_path is not None:
        print(f"Checkpoints path: {config.checkpoints_path}")
        os.makedirs(config.checkpoints_path, exist_ok=True)
        with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f:
            pyrallis.dump(config, f)

    total_updates = 0
    # The pretrained networks must be frozen exactly once. Without this flag
    # the snapshot would be re-taken on every gradient step of the first
    # post-pretraining epoch and would follow the live networks.
    pretrained_snapshot_taken = False
    for epoch in trange(config.num_epochs, desc="Training"):
        # Pretraining covers epochs 0..pretrain_epochs inclusive.
        in_pretraining = (
            config.pretrain is not None and epoch <= config.pretrain_epochs
        )
        # training
        for _ in trange(config.num_updates_on_epoch, desc="Epoch", leave=False):
            batch = buffer.sample(config.batch_size)
            if in_pretraining:
                if config.pretrain == "BC":
                    update_info = trainer.pretrain_BC(batch)
                elif config.pretrain == "softAC":
                    # First a behaviour-cloning phase, then soft-critic
                    # pretraining against the cloned actor.
                    if epoch < config.pretrain_epochs // (
                        1 / config.percentage_pretrain_BC
                    ):
                        update_info = trainer.pretrain_BC(batch)
                    else:
                        if not buffer._soft_returns_loaded:
                            buffer.compute_soft_returns_to_go(
                                alpha=trainer.alpha,
                                actor=trainer.actor,
                            )
                            print("Soft returns to go loaded for BC actor!")
                        assert buffer._soft_returns_loaded
                        update_info = trainer.pretrain_soft_critic(
                            batch, epoch, config.pretrain_epochs
                        )
                elif config.pretrain == "softC":
                    if not buffer._soft_returns_loaded:
                        buffer.compute_soft_returns_to_go(
                            alpha=trainer.alpha,
                            actor=trainer.actor,
                        )
                        print("Soft returns to go loaded for initialised actor!")
                    assert buffer._soft_returns_loaded
                    update_info = trainer.pretrain_soft_critic(
                        batch, epoch, config.pretrain_epochs
                    )
                else:
                    raise ValueError(
                        f"Pretrain type {config.pretrain} not recognised."
                    )
            else:
                if (
                    config.pretrain is not None
                    and epoch == config.pretrain_epochs + 1
                    and not pretrained_snapshot_taken
                ):
                    # Freeze a copy of the networks as they were at the end of
                    # pretraining, before the first regular update.
                    with torch.no_grad():
                        trainer.pretrained_critic = deepcopy(trainer.critic)
                        trainer.pretrained_actor = deepcopy(trainer.actor)
                    pretrained_snapshot_taken = True
                update_info = trainer.update(batch)

            if total_updates % config.log_every == 0:
                wandb.log({"epoch": epoch, **update_info})

            total_updates += 1

        # evaluation
        if epoch % config.eval_every == 0 or epoch == config.num_epochs - 1:
            eval_returns = eval_actor(
                env=eval_env,
                actor=actor,
                n_episodes=config.eval_episodes,
                seed=config.eval_seed,
                device=config.device,
                render=False,
                name=config.name,
            )
            eval_log = {
                "eval/reward_mean": np.mean(eval_returns),
                "eval/reward_std": np.std(eval_returns),
                "epoch": epoch,
            }
            if hasattr(eval_env, "get_normalized_score"):
                normalized_score = eval_env.get_normalized_score(eval_returns) * 100.0
                eval_log["eval/normalized_score_mean"] = np.mean(normalized_score)
                eval_log["eval/normalized_score_std"] = np.std(normalized_score)

            wandb.log(eval_log)

            if config.checkpoints_path is not None:
                torch.save(
                    trainer.state_dict(),
                    os.path.join(config.checkpoints_path, f"{epoch}.pt"),
                )

    # testing
    test_returns = eval_actor(
        env=eval_env,
        actor=actor,
        n_episodes=100,
        seed=config.eval_seed,
        device=config.device,
        render=config.render,
        name=config.name,
    )
    test_log = {
        "test/reward_mean": np.mean(test_returns),
        "test/reward_std": np.std(test_returns),
    }
    if hasattr(eval_env, "get_normalized_score"):
        normalized_score = eval_env.get_normalized_score(test_returns) * 100.0
        test_log["test/normalized_score_mean"] = np.mean(normalized_score)
        test_log["test/normalized_score_std"] = np.std(normalized_score)

    wandb.log(test_log)

    wandb.finish()


# Script entry point. `train` is called without arguments here because it is
# decorated with `pyrallis.wrap()`, which is expected to supply `config`.
if __name__ == "__main__":
    train()
