# Inspired by:
# 1. paper for SAC-N: https://arxiv.org/abs/2110.01548
# 2. implementation: https://github.com/snu-mllab/EDAC
import argparse
import math
import os
import random
import uuid
from copy import deepcopy
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Union

import gym
import h5py
import numpy as np
import pyrallis
import torch
import torch.nn as nn
import yaml
from epicare import evaluations
from epicare.envs import EpiCare  # noqa: F401
from torch.nn import functional as F
from tqdm import trange

import wandb
from drn import load_q_nets


@dataclass
class TrainConfig:
    """Hyper-parameters and bookkeeping options for one EDAC training run.

    The environment id is stored in the ``env`` field. Other code in this
    file reads it as ``env_name``, so ``env_name`` is exposed as a read/write
    alias of ``env``; both spellings always agree.
    """

    # number of episodes
    episodes_avail: int = 65536 * 2
    # path to the dataset
    dataset_path: Optional[str] = None
    # wandb project name
    project: str = "EDAC-Benchmark"
    # wandb group name
    group: Optional[str] = ""  # DEPRECATED
    # wandb run name
    name: str = "EDAC"
    # actor and critic hidden dim
    hidden_dim: int = 256
    # critic ensemble size
    num_critics: int = 230
    # discount factor
    gamma: float = 1.0
    # coefficient for the target critic Polyak's update
    tau: float = 5e-3
    # coefficient for the ensemble diversification loss
    eta: float = 500.0
    # actor learning rate
    actor_learning_rate: float = 3e-4
    # critic learning rate
    critic_learning_rate: float = 3e-4
    # alpha learning rate
    alpha_learning_rate: float = 3e-4
    # maximum size of the replay buffer
    buffer_size: int = 1_000_000
    # training dataset and evaluation environment (also readable as env_name)
    env: str = "EpiCare-v0"
    # training batch size
    batch_size: int = 256
    # total number of training epochs
    num_epochs: int = 500
    # number of gradient updates during one epoch
    num_updates_on_epoch: int = 1000
    # number of episodes to run during evaluation
    eval_episodes: int = 100
    # evaluation frequency: the policy is evaluated every eval_every epochs
    eval_every: int = 5
    # path for checkpoints saving, optional
    checkpoints_path: Optional[str] = "./algorithms/checkpoints"
    # configure PyTorch to use deterministic algorithms instead
    # of nondeterministic ones
    deterministic_torch: bool = False
    # training random seed
    train_seed: int = 10
    # evaluation random seed
    eval_seed: int = 42
    # environment random seed
    env_seed: int = 1
    # frequency of metrics logging to the wandb (in gradient updates)
    log_every: int = 100
    # training device
    device: str = "cuda"
    # environment seed
    seed: int = 1
    # temperature for the Gumbel-Softmax distribution
    temperature: float = 6.5
    # number of checkpoints to save
    num_checkpoints: int = 0
    # frame stacking memory
    frame_stack: int = 1
    # behavior policy of the dataset
    behavior_policy: str = "smart"
    # include previous action in the observation
    include_previous_action: bool = False

    sweep_config: Optional[dict] = field(default=None)

    @property
    def env_name(self) -> str:
        """Alias of ``env`` (not a dataclass field, so it is never serialised)."""
        return self.env

    @env_name.setter
    def env_name(self, value: str) -> None:
        # Sweep parameters may provide ``env_name``; route it to the real field.
        self.env = value

    def update_params(self, params: Dict[str, Any]) -> "TrainConfig":
        """Overwrite attributes from ``params`` and derive the run-specific paths.

        Every key in ``params`` is set on the instance (unknown keys become
        plain attributes). Afterwards ``dataset_path`` is rebuilt from
        ``behavior_policy``/``env_seed``, ``name`` gets the environment id,
        the seeds and a random 8-character suffix appended, and
        ``checkpoints_path`` (when not ``None``) is extended with that name.

        :param params: mapping of attribute name to new value.
        :return: ``self``, to allow chaining.
        """
        for key, value in params.items():
            setattr(self, key, value)
        self.dataset_path = (
            f"./data/{self.behavior_policy}/train_seed_{self.env_seed}.hdf5"
        )
        self.name = f"{self.name}-{self.env_name}-{self.train_seed}-{self.env_seed}-{str(uuid.uuid4())[:8]}"
        if self.checkpoints_path is not None:
            self.checkpoints_path = os.path.join(self.checkpoints_path, self.name)
        return self


# general utils
# A training batch as returned by ReplayBuffer.sample and consumed by
# EDAC.update, ordered [states, actions, rewards, next_states, dones].
TensorBatch = List[torch.Tensor]


def soft_update(target: nn.Module, source: nn.Module, tau: float):
    """Blend ``source`` parameters into ``target`` in place (Polyak averaging).

    Each target parameter becomes ``(1 - tau) * target + tau * source``.
    Parameters are paired in the order yielded by ``parameters()``.
    """
    keep = 1 - tau
    for dst, src in zip(target.parameters(), source.parameters()):
        blended = keep * dst.data + tau * src.data
        dst.data.copy_(blended)


def wandb_init(config: dict) -> None:
    """Start a wandb run described by ``config`` and save it.

    ``config`` must contain the keys ``"project"`` and ``"name"``; the whole
    mapping is also attached to the run as its configuration.
    """
    run_kwargs = {
        "config": config,
        "project": config["project"],
        "name": config["name"],
    }
    wandb.init(**run_kwargs)
    wandb.run.save()


def set_seed(
    seed: int, env: Optional[gym.Env] = None, deterministic_torch: bool = False
):
    """Seed every random number source used by the script.

    :param seed: value applied to the environment (when given), to
        ``PYTHONHASHSEED``, NumPy, ``random`` and PyTorch.
    :param env: optional environment; it and its action space are seeded too.
    :param deterministic_torch: forwarded to
        ``torch.use_deterministic_algorithms``.
    """
    if env is not None:
        env.seed(seed)
        env.action_space.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    # Same order as before: NumPy, then the stdlib generator, then torch.
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)
    torch.use_deterministic_algorithms(deterministic_torch)


def wrap_env(
    env: gym.Env,
    state_mean: Union[np.ndarray, float] = 0.0,
    state_std: Union[np.ndarray, float] = 1.0,
    reward_scale: float = 1.0,
) -> gym.Env:
    """Wrap ``env`` so observations are standardised and rewards rescaled.

    Observations are always mapped to ``(state - state_mean) / state_std``.
    The reward wrapper is only added when ``reward_scale`` differs from 1.0.
    """
    wrapped = gym.wrappers.TransformObservation(
        env, lambda state: (state - state_mean) / state_std
    )
    if reward_scale == 1.0:
        return wrapped
    return gym.wrappers.TransformReward(
        wrapped, lambda reward: reward_scale * reward
    )


class ReplayBuffer:
    """Fixed-capacity store of offline transitions held as torch tensors.

    Each stored state is the flattened frame stack
    ``[current obs, obs 1 step ago, ..., obs frame_stack-1 steps ago]``
    (slots reaching before the start of the episode stay zero), optionally
    followed by the one-hot previous action when ``include_previous_action``
    is set. ``state_dim`` must therefore already account for both.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        buffer_size: int,
        device: str = "cpu",
        frame_stack: int = 1,
        include_previous_action: bool = False,
    ):
        self._buffer_size = buffer_size
        self._pointer = 0
        self._size = 0
        self._frame_stack = frame_stack
        self._prev_action = include_previous_action
        self._state_dim = state_dim

        self._states = torch.zeros(
            (buffer_size, state_dim), dtype=torch.float32, device=device
        )
        self._next_states = torch.zeros(
            (buffer_size, state_dim), dtype=torch.float32, device=device
        )
        self._actions = torch.zeros(
            (buffer_size, action_dim), dtype=torch.float32, device=device
        )
        self._rewards = torch.zeros(
            (buffer_size, 1), dtype=torch.float32, device=device
        )
        self._dones = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
        self._device = device

    def _to_tensor(self, data: np.ndarray) -> torch.Tensor:
        """Copy ``data`` into a float32 tensor on the buffer's device."""
        return torch.tensor(data, dtype=torch.float32, device=self._device)

    def _stack_frames(
        self, data: Dict[str, np.ndarray], n_transitions: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build flattened frame stacks for observations and next observations.

        Episodes are delimited by non-zero entries of ``data["terminals"]``;
        transitions after the last terminal form one more (unfinished)
        episode and are stacked like any other.

        :return: two CPU tensors of shape ``[n_transitions, frame_stack * obs_size]``.
        """
        terminals = np.asarray(data["terminals"]).reshape(-1)[:n_transitions]
        # Index of the first transition of the episode each transition belongs to.
        episode_ends = np.flatnonzero(terminals) + 1
        inner_ends = episode_ends[episode_ends < n_transitions]
        episode_start = np.zeros(n_transitions, dtype=np.int64)
        episode_start[inner_ends] = inner_ends
        episode_start = torch.from_numpy(np.maximum.accumulate(episode_start))

        observations = torch.tensor(data["observations"], dtype=torch.float32)
        next_observations = torch.tensor(data["next_observations"], dtype=torch.float32)

        stacked = torch.zeros(
            (n_transitions, self._frame_stack) + tuple(observations.shape[1:])
        )
        stacked_next = torch.zeros_like(stacked)
        index = torch.arange(n_transitions)
        for lag in range(self._frame_stack):
            # Only transitions whose `lag`-steps-ago frame is in the same episode.
            dst = index[index - lag >= episode_start]
            stacked[dst, lag] = observations[dst - lag]
            stacked_next[dst, lag] = next_observations[dst - lag]

        return stacked.flatten(start_dim=1), stacked_next.flatten(start_dim=1)

    # Loads data in d4rl format, i.e. from Dict[str, np.array].
    def preprocess_dataset(self, data: Dict[str, np.ndarray]):
        """Fill the (empty) buffer from a d4rl-style dataset dictionary.

        Reads the keys ``observations``, ``next_observations``, ``actions``,
        ``rewards`` and ``terminals``. One-dimensional ``actions`` are treated
        as integer action indices and one-hot encoded; ``data`` itself is
        left unmodified.

        :raises ValueError: if the buffer already holds data or the dataset
            has more transitions than the buffer capacity.
        """
        if self._size != 0:
            raise ValueError("Trying to load data into non-empty replay buffer")
        n_transitions = data["observations"].shape[0]
        if n_transitions > self._buffer_size:
            raise ValueError(
                "Replay buffer is smaller than the dataset you are trying to load!"
            )
        # Check if actions are already one-hot encoded
        if len(data["actions"].shape) == 1:
            # One-hot encode the actions if they are not already
            print("One-hot encoding actions")
            print(f"Actions shape: {self._actions.shape}")
            actions = np.eye(self._actions.shape[1])[data["actions"].astype(int)]
        else:
            # Actions are already in the correct shape
            print("Actions are already one-hot encoded")
            actions = data["actions"]
        self._actions[:n_transitions] = self._to_tensor(actions)
        self._rewards[:n_transitions] = self._to_tensor(data["rewards"][..., None])
        self._dones[:n_transitions] = self._to_tensor(data["terminals"][..., None])

        print(f"Dataset size: {n_transitions}")

        stacked_states, stacked_next_states = self._stack_frames(data, n_transitions)

        if self._prev_action:
            # Action attached to each *next* state: the action just taken,
            # or zeros when that action ended the episode.
            next_actions = torch.where(
                self._dones[:n_transitions].bool(),
                torch.zeros_like(self._actions[:n_transitions]),
                self._actions[:n_transitions],
            )
            prev_actions = next_actions.roll(1, dims=0)
            if n_transitions > 0:
                # roll() wraps the last row around; the very first transition
                # has no previous action.
                prev_actions[0] = 0.0
            up_to = -self._actions.shape[1]
            self._states[:n_transitions, up_to:] = prev_actions
            self._next_states[:n_transitions, up_to:] = next_actions
        else:
            up_to = self._states.shape[1]

        self._states[:n_transitions, :up_to] = stacked_states.to(self._device)
        self._next_states[:n_transitions, :up_to] = stacked_next_states.to(
            self._device
        )

        # Mark the buffer as populated only once everything was written.
        self._size += n_transitions
        self._pointer = min(self._size, n_transitions)

    def sample(self, batch_size: int) -> "TensorBatch":
        """Draw ``batch_size`` transitions uniformly with replacement.

        Indices come from ``np.random`` so sampling follows the NumPy seed.

        :return: ``[states, actions, rewards, next_states, dones]``.
        """
        indices = np.random.randint(0, min(self._size, self._pointer), size=batch_size)
        states = self._states[indices]
        actions = self._actions[indices]
        rewards = self._rewards[indices]
        next_states = self._next_states[indices]
        dones = self._dones[indices]
        return [states, actions, rewards, next_states, dones]

    def add_transition(self):
        # Use this method to add new data into the replay buffer during fine-tuning.
        # I left it unimplemented since now we do not do fine-tuning.
        raise NotImplementedError


# SAC Actor & Critic implementation
class VectorizedLinear(nn.Module):
    """A batch of ``ensemble_size`` independent linear layers applied in one matmul.

    Weights are stored as ``[ensemble_size, in_features, out_features]`` and
    biases as ``[ensemble_size, 1, out_features]``.
    """

    def __init__(self, in_features: int, out_features: int, ensemble_size: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.ensemble_size = ensemble_size

        self.weight = nn.Parameter(
            torch.empty(ensemble_size, in_features, out_features)
        )
        self.bias = nn.Parameter(torch.empty(ensemble_size, 1, out_features))

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise every ensemble member like a default ``nn.Linear``.

        Both weights and biases end up uniform in
        ``[-1/sqrt(in_features), 1/sqrt(in_features)]``.
        """
        # default pytorch init for nn.Linear module
        for member in range(self.ensemble_size):
            # Each member is stored as [in_features, out_features], the
            # transpose of nn.Linear's layout. The init helper takes fan-in
            # from the second dimension, so hand it the transposed view;
            # otherwise out_features would be used as the fan-in.
            nn.init.kaiming_uniform_(self.weight[member].T, a=math.sqrt(5))

        bound = 1 / math.sqrt(self.in_features) if self.in_features > 0 else 0
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply all ensemble members to their slice of ``x``."""
        # input: [ensemble_size, batch_size, input_size]
        # weight: [ensemble_size, input_size, out_size]
        # out: [ensemble_size, batch_size, out_size]
        return x @ self.weight + self.bias


class Actor(nn.Module):
    """Discrete-action policy: an MLP trunk followed by a logits head.

    Stochastic actions are one-hot samples drawn with
    ``F.gumbel_softmax(..., hard=True)`` at the configured temperature;
    deterministic actions are the one-hot arg-max of the logits.
    """

    def __init__(
        self, state_dim: int, action_dim: int, hidden_dim: int, temperature: float
    ):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        # single head producing one unnormalised logit per discrete action
        self.logits = nn.Linear(hidden_dim, action_dim)

        # init as in the EDAC paper
        for layer in self.trunk[::2]:
            torch.nn.init.constant_(layer.bias, 0.1)

        torch.nn.init.uniform_(self.logits.weight, -1e-3, 1e-3)
        torch.nn.init.uniform_(self.logits.bias, -1e-3, 1e-3)

        self.action_dim = action_dim
        self.temperature = temperature

    def get_action_probabilities(self, state: torch.Tensor) -> torch.Tensor:
        """
        Computes action probabilities as the temperature-scaled softmax of the
        policy logits for the given state. No gradients are tracked.

        :param state: A torch.Tensor representing the state.
        :return: A torch.Tensor of probabilities summing to 1 over the last
            dimension.
        """
        with torch.no_grad():
            hidden = self.trunk(state)
            logits = self.logits(hidden)
            probabilities = F.softmax(logits / self.temperature, dim=-1)
        return probabilities

    def forward(
        self,
        state: torch.Tensor,
        deterministic: bool = False,
        need_log_prob: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Select a one-hot action for every state in the batch.

        :param state: batch of states, last dimension ``state_dim``.
        :param deterministic: take the arg-max action instead of sampling.
        :param need_log_prob: also return a log-probability estimate.
        :return: ``(action, log_prob)``; ``log_prob`` is ``None`` unless
            requested, otherwise it has the batch shape of ``state``.
        """
        hidden = self.trunk(state)
        logits = self.logits(hidden)

        if deterministic:
            action = torch.argmax(logits, dim=-1)
            action = F.one_hot(action.long(), num_classes=self.action_dim).float()
        else:
            # hard=True yields one-hot samples with straight-through gradients
            action_probs = F.gumbel_softmax(logits, tau=self.temperature, hard=True)
            action = action_probs

        log_prob = None
        if need_log_prob:
            # NOTE(review): this is a fresh Gumbel-Softmax draw, independent of
            # the one that produced `action` above.
            soft_prob = F.gumbel_softmax(logits, tau=self.temperature, hard=False)
            # 1e-10 keeps the log finite when the selected entry is ~0
            log_prob = torch.log(torch.sum(action * soft_prob, dim=-1) + 1e-10)

        return action, log_prob

    @torch.no_grad()
    def act(self, state: np.ndarray, device: str = "cpu") -> np.ndarray:
        """Return the one-hot action for a single state as a NumPy array.

        The action is the deterministic arg-max when the module is in eval
        mode and a stochastic sample when it is in training mode.

        :param state: a single state; it is flattened to shape ``(1, -1)``.
        :param device: device the input tensor is created on.
        :return: array of shape ``(1, action_dim)``.
        """
        # Act greedily when evaluating (module.eval()), sample while training.
        deterministic = not self.training
        state = torch.tensor(state.reshape(1, -1), device=device, dtype=torch.float32)
        action = self(state, deterministic=deterministic)[0].cpu().numpy()
        return action


class VectorizedCritic(nn.Module):
    """Ensemble of ``num_critics`` Q-networks evaluated in one vectorised pass."""

    def __init__(
        self, state_dim: int, action_dim: int, hidden_dim: int, num_critics: int
    ):
        super().__init__()
        # Layer widths of the MLP: input -> 3 hidden layers -> scalar Q-value.
        widths = [state_dim + action_dim, hidden_dim, hidden_dim, hidden_dim, 1]
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.append(VectorizedLinear(fan_in, fan_out, num_critics))
            modules.append(nn.ReLU())
        # No activation after the output layer.
        self.critic = nn.Sequential(*modules[:-1])

        # init as in the EDAC paper
        for linear in self.critic[::2]:
            torch.nn.init.constant_(linear.bias, 0.1)

        head = self.critic[-1]
        torch.nn.init.uniform_(head.weight, -3e-3, 3e-3)
        torch.nn.init.uniform_(head.bias, -3e-3, 3e-3)

        self.num_critics = num_critics

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Return Q-values of shape ``[num_critics, batch_size]``.

        Inputs are either ``[batch_size, dim]`` (shared by all critics) or
        already ``[num_critics, batch_size, dim]``.
        """
        # [..., batch_size, state_dim + action_dim]
        joint = torch.cat([state, action], dim=-1)
        if joint.dim() != 3:
            assert joint.dim() == 2
            # [num_critics, batch_size, state_dim + action_dim]
            joint = joint.unsqueeze(0).repeat_interleave(self.num_critics, dim=0)
        assert joint.dim() == 3
        assert joint.shape[0] == self.num_critics
        # [num_critics, batch_size]
        return self.critic(joint).squeeze(-1)


class EDAC:
    """EDAC trainer: SAC with a Q-ensemble and a gradient-diversity penalty.

    Owns the actor, the critic ensemble, a frozen-copy target critic and the
    learnable entropy temperature ``alpha``; ``update`` performs one training
    step on a sampled batch.
    """

    def __init__(
        self,
        actor: Actor,
        actor_optimizer: torch.optim.Optimizer,
        critic: VectorizedCritic,
        critic_optimizer: torch.optim.Optimizer,
        gamma: float = 0.99,
        tau: float = 0.005,
        eta: float = 1.0,
        alpha_learning_rate: float = 1e-4,
        device: str = "cuda",
    ):
        self.device = device

        self.actor = actor
        self.critic = critic
        with torch.no_grad():
            self.target_critic = deepcopy(self.critic)

        self.actor_optimizer = actor_optimizer
        self.critic_optimizer = critic_optimizer

        self.tau = tau
        self.gamma = gamma
        self.eta = eta

        # number of times _actor_loss has been evaluated
        self.actor_loss_visit = 0

        # adaptive alpha setup
        self.target_entropy = -float(self.actor.action_dim)
        self.log_alpha = torch.tensor(
            [0.0], dtype=torch.float32, device=self.device, requires_grad=True
        )
        self.alpha_optimizer = torch.optim.Adam(
            [self.log_alpha], lr=alpha_learning_rate
        )
        self.alpha = self.log_alpha.exp().detach()

    def _alpha_loss(self, state: torch.Tensor) -> torch.Tensor:
        """Temperature loss; only ``log_alpha`` receives gradients."""
        with torch.no_grad():
            action, action_log_prob = self.actor(state, need_log_prob=True)

        loss = (-self.log_alpha * (action_log_prob + self.target_entropy)).mean()

        return loss

    def _actor_loss(self, state: torch.Tensor) -> Tuple[torch.Tensor, float, float]:
        """Policy loss against the minimum over the Q-ensemble.

        :return: ``(loss, batch_entropy, q_value_std)``; the last two are
            plain floats used for logging.
        """
        self.actor_loss_visit += 1
        action, action_log_prob = self.actor(state, need_log_prob=True)
        q_value_dist = self.critic(state, action)
        assert q_value_dist.shape[0] == self.critic.num_critics
        q_value_min = q_value_dist.min(0).values
        # needed for logging
        q_value_std = q_value_dist.std(0).mean().item()
        batch_entropy = -action_log_prob.mean().item()

        assert action_log_prob.shape == q_value_min.shape
        loss = (self.alpha * action_log_prob - q_value_min).mean()

        return loss, batch_entropy, q_value_std

    def _critic_diversity_loss(
        self, state: torch.Tensor, action: torch.Tensor
    ) -> torch.Tensor:
        """Mean pairwise cosine similarity of the critics' action-gradients.

        Returns a zero scalar when the ensemble has fewer than two critics,
        since there are no pairs to compare.
        """
        num_critics = self.critic.num_critics
        if num_critics < 2:
            # The normalisation below divides by (num_critics - 1); with a
            # single critic that would turn the whole critic loss into NaN.
            return torch.zeros((), device=state.device)
        # almost exact copy from the original implementation, only style changes:
        # https://github.com/snu-mllab/EDAC/blob/198d5708701b531fd97a918a33152e1914ea14d7/lifelong_rl/trainers/q_learning/sac.py#L192

        # [num_critics, batch_size, *_dim]
        state = state.unsqueeze(0).repeat_interleave(num_critics, dim=0)
        action = (
            action.unsqueeze(0)
            .repeat_interleave(num_critics, dim=0)
            .requires_grad_(True)
        )
        # [num_critics, batch_size]
        q_ensemble = self.critic(state, action)

        # create_graph=True so the penalty itself can be back-propagated
        q_action_grad = torch.autograd.grad(
            q_ensemble.sum(), action, retain_graph=True, create_graph=True
        )[0]
        q_action_grad = q_action_grad / (
            torch.norm(q_action_grad, p=2, dim=2).unsqueeze(-1) + 1e-10
        )
        # [batch_size, num_critics, action_dim]
        q_action_grad = q_action_grad.transpose(0, 1)

        # [num_critics, num_critics]; broadcast over the batch dimension
        # instead of materialising one mask per sample.
        off_diagonal = 1 - torch.eye(num_critics, device=self.device)
        # removed einsum as it is usually slower than just torch.bmm
        # [batch_size, num_critics, num_critics]
        q_action_grad = q_action_grad @ q_action_grad.permute(0, 2, 1)
        q_action_grad = off_diagonal * q_action_grad

        grad_loss = q_action_grad.sum(dim=(1, 2)).mean()
        grad_loss = grad_loss / (num_critics - 1)

        return grad_loss

    def _critic_loss(
        self,
        state: torch.Tensor,
        action: torch.Tensor,
        reward: torch.Tensor,
        next_state: torch.Tensor,
        done: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """TD loss summed over critics plus ``eta`` times the diversity penalty.

        :return: ``(total_loss, diversity_loss)``.
        """
        with torch.no_grad():
            next_action, next_action_log_prob = self.actor(
                next_state, need_log_prob=True
            )
            q_next = self.target_critic(next_state, next_action).min(0).values
            q_next = q_next - self.alpha * next_action_log_prob

            assert q_next.unsqueeze(-1).shape == done.shape == reward.shape
            q_target = reward + self.gamma * (1 - done) * q_next.unsqueeze(-1)

        q_values = self.critic(state, action)
        # [ensemble_size, batch_size] - [1, batch_size]
        critic_loss = ((q_values - q_target.view(1, -1)) ** 2).mean(dim=1).sum(dim=0)
        diversity_loss = self._critic_diversity_loss(state, action)
        loss = critic_loss + self.eta * diversity_loss

        return loss, diversity_loss

    def update(self, batch: TensorBatch) -> Dict[str, float]:
        """Run one alpha, actor and critic update and return logging scalars.

        :param batch: ``[state, action, reward, next_state, done]`` tensors.
        """
        state, action, reward, next_state, done = [arr.to(self.device) for arr in batch]
        # Usually updates are done in the following order: critic -> actor -> alpha
        # But we found that EDAC paper uses reverse (which gives better results)

        # Alpha update
        alpha_loss = self._alpha_loss(state)
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()

        self.alpha = self.log_alpha.exp().detach()

        # Actor update
        actor_loss, actor_batch_entropy, q_policy_std = self._actor_loss(state)
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # Critic update
        critic_loss, diversity_loss = self._critic_loss(
            state, action, reward, next_state, done
        )
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        #  Target networks soft update
        with torch.no_grad():
            soft_update(self.target_critic, self.critic, tau=self.tau)
            # for logging, Q-ensemble std estimate with the random actions:
            # a ~ U[action_space]
            action_dim = self.actor.action_dim
            # Generate random one-hot encoded actions of shape [batch_size, action_dim]
            random_actions = torch.zeros(
                (state.shape[0], action_dim), device=self.device
            )
            random_actions[
                torch.arange(state.shape[0]),
                torch.randint(action_dim, (state.shape[0],)),
            ] = 1.0

            q_random_std = self.critic(state, random_actions).std(0).mean().item()

        update_info = {
            "alpha_loss": alpha_loss.item(),
            "critic_loss": critic_loss.item(),
            "diversity_loss": diversity_loss.item(),
            "eta": self.eta,
            "actor_loss": actor_loss.item(),
            "batch_entropy": actor_batch_entropy,
            "alpha": self.alpha.item(),
            "q_policy_std": q_policy_std,
            "q_random_std": q_random_std,
        }
        return update_info

    def state_dict(self) -> Dict[str, Any]:
        """Collect network, optimizer and temperature state for checkpointing."""
        state = {
            "actor": self.actor.state_dict(),
            "critic": self.critic.state_dict(),
            "target_critic": self.target_critic.state_dict(),
            "log_alpha": self.log_alpha.item(),
            "actor_optim": self.actor_optimizer.state_dict(),
            "critic_optim": self.critic_optimizer.state_dict(),
            "alpha_optim": self.alpha_optimizer.state_dict(),
        }
        return state

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Restore everything written by ``state_dict``."""
        self.actor.load_state_dict(state_dict["actor"])
        self.critic.load_state_dict(state_dict["critic"])
        self.target_critic.load_state_dict(state_dict["target_critic"])
        self.actor_optimizer.load_state_dict(state_dict["actor_optim"])
        self.critic_optimizer.load_state_dict(state_dict["critic_optim"])
        self.alpha_optimizer.load_state_dict(state_dict["alpha_optim"])
        self.log_alpha.data[0] = state_dict["log_alpha"]
        self.alpha = self.log_alpha.exp().detach()


@torch.no_grad()
def eval_actor(
    env: gym.Env,
    actor: nn.Module,
    device: str,
    n_episodes: int,
    seed: int,
    frame_stack: int,
    include_previous_action: bool = False,
    action_dim: int = 0,
) -> np.ndarray:
    """Roll out ``actor`` for ``n_episodes`` and return the episode returns.

    The actor input is the flattened history of the last ``frame_stack``
    observations (newest first, zeros before the episode start), followed by
    the one-hot previous action when ``include_previous_action`` is set.
    The actor is switched to eval mode for the rollouts and to train mode
    afterwards.
    """
    env.seed(seed)
    actor.eval()
    returns = []
    for _ in range(n_episodes):
        history = np.zeros((frame_stack, env.observation_space.shape[0]))
        last_action = np.zeros((action_dim,))
        observation = env.reset()
        finished = False
        total_reward = 0.0
        while not finished:
            # Shift older frames back one slot and put the newest in front.
            history = np.roll(history, shift=1, axis=0)
            history[0] = observation

            actor_input = history.flatten()
            if include_previous_action:
                actor_input = np.concatenate((actor_input, last_action))

            one_hot_action = actor.act(actor_input, device=device)
            # The environment expects the action index, not the one-hot vector.
            chosen = np.argmax(one_hot_action)
            observation, reward, finished, _ = env.step(chosen)
            total_reward += reward

            if include_previous_action:
                last_action = np.arange(action_dim) == chosen

        returns.append(total_reward)

    actor.train()
    return np.asarray(returns)


def return_reward_range(dataset, max_episode_steps):
    """Return ``(min, max)`` of the episode returns found in ``dataset``.

    An episode ends at a truthy ``terminals`` entry or after
    ``max_episode_steps`` steps. Steps after the last episode end are
    counted for the consistency check but contribute no return.
    """
    episode_returns = []
    steps_accounted = 0
    running_return, running_len = 0.0, 0
    for reward, terminal in zip(dataset["rewards"], dataset["terminals"]):
        running_return += float(reward)
        running_len += 1
        if not (terminal or running_len == max_episode_steps):
            continue
        episode_returns.append(running_return)
        steps_accounted += running_len
        running_return, running_len = 0.0, 0
    # the unfinished tail still counts towards the number of steps
    steps_accounted += running_len
    assert steps_accounted == len(dataset["rewards"])
    return min(episode_returns), max(episode_returns)


def load_custom_dataset(config: TrainConfig) -> Dict[str, np.ndarray]:
    """Read the transition arrays from the HDF5 file at ``config.dataset_path``.

    The file must contain the datasets ``observations``, ``actions``,
    ``rewards``, ``next_observations`` and ``terminals``. All arrays are
    returned as float32 except ``terminals``, which is boolean.
    """
    keys = ("observations", "actions", "rewards", "next_observations", "terminals")

    # Pull everything into memory before the file is closed.
    with h5py.File(config.dataset_path, "r") as dataset_file:
        raw = {key: dataset_file[key][:] for key in keys}

    # float32 for consistency with other Gym environments and D4RL datasets
    custom_dataset = {key: values.astype(np.float32) for key, values in raw.items()}
    # terminals are flags, so expose them as booleans
    custom_dataset["terminals"] = custom_dataset["terminals"].astype(np.bool_)

    return custom_dataset


def train(config: TrainConfig):
    """Train EDAC on the offline dataset described by ``config``.

    Seeds the run, starts wandb logging, loads the dataset into a replay
    buffer, then alternates ``num_updates_on_epoch`` gradient updates with
    periodic evaluation in the environment. When ``num_checkpoints`` is
    positive, trainer state is saved at roughly evenly spaced epochs under
    ``config.checkpoints_path``.
    """
    set_seed(config.train_seed, deterministic_torch=config.deterministic_torch)
    wandb_init(asdict(config))

    # TrainConfig declares the environment id as `env`; `env_name` is only
    # present when it was set explicitly (e.g. from sweep parameters).
    env_name = getattr(config, "env_name", None)
    if env_name is None:
        env_name = config.env

    # data, evaluation, env setup
    eval_env = wrap_env(gym.make(env_name, seed=config.env_seed))
    state_dim = eval_env.observation_space.shape[0] * config.frame_stack
    if config.include_previous_action:
        state_dim += eval_env.action_space.n
    action_dim = eval_env.action_space.n

    dataset = load_custom_dataset(config)

    buffer = ReplayBuffer(
        state_dim=state_dim,
        action_dim=action_dim,
        buffer_size=config.buffer_size,
        device=config.device,
        frame_stack=config.frame_stack,
        include_previous_action=config.include_previous_action,
    )
    buffer.preprocess_dataset(dataset)

    # Actor & Critic setup
    actor = Actor(state_dim, action_dim, config.hidden_dim, config.temperature)
    actor.to(config.device)
    actor_optimizer = torch.optim.Adam(
        actor.parameters(), lr=config.actor_learning_rate
    )
    critic = VectorizedCritic(
        state_dim, action_dim, config.hidden_dim, config.num_critics
    )
    critic.to(config.device)
    critic_optimizer = torch.optim.Adam(
        critic.parameters(), lr=config.critic_learning_rate
    )

    trainer = EDAC(
        actor=actor,
        actor_optimizer=actor_optimizer,
        critic=critic,
        critic_optimizer=critic_optimizer,
        gamma=config.gamma,
        tau=config.tau,
        eta=config.eta,
        alpha_learning_rate=config.alpha_learning_rate,
        device=config.device,
    )

    if config.num_checkpoints > 0:
        print(f"Checkpoints path: {config.checkpoints_path}")
        os.makedirs(config.checkpoints_path, exist_ok=True)
        with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f:
            pyrallis.dump(config, f)

    # Generate a list of training steps as close as possible to evenly spaced
    # throughout the training process. The list is in descending order so the
    # next epoch to checkpoint is always its last element.
    checkpoint_num = 0
    checkpoint_steps = [
        int(round(x))
        for x in np.linspace(
            config.num_epochs - 1, 0, config.num_checkpoints, endpoint=False
        )
    ]

    total_updates = 0
    for epoch in trange(config.num_epochs, desc="Training"):
        # training
        for _ in trange(config.num_updates_on_epoch, desc="Epoch", leave=False):
            batch = buffer.sample(config.batch_size)
            update_info = trainer.update(batch)

            if total_updates % config.log_every == 0:
                wandb.log({"epoch": epoch, **update_info})

            total_updates += 1

        # evaluation
        if epoch % config.eval_every == 0 or epoch == config.num_epochs - 1:
            eval_returns = eval_actor(
                env=eval_env,
                actor=actor,
                n_episodes=config.eval_episodes,
                seed=config.eval_seed,
                device=config.device,
                frame_stack=config.frame_stack,
                include_previous_action=config.include_previous_action,
                action_dim=action_dim,
            )
            eval_log = {
                "eval/reward_mean": np.mean(eval_returns),
                "eval/reward_std": np.std(eval_returns),
                "epoch": epoch,
            }
            if hasattr(eval_env, "get_normalized_score"):
                # Normalise the per-episode returns first, then summarise, so
                # the logged mean is a scalar and the std is the spread of the
                # normalised scores rather than a normalised std.
                normalized_scores = (
                    eval_env.get_normalized_score(eval_returns) * 100.0
                )
                eval_log["eval/normalized_score_mean"] = np.mean(normalized_scores)
                eval_log["eval/normalized_score_std"] = np.std(normalized_scores)

            wandb.log(eval_log)

        if (
            config.num_checkpoints
            and checkpoint_steps
            and epoch == checkpoint_steps[-1]
        ):
            # Rounding can schedule the same epoch more than once; drop all
            # duplicates so they do not block the remaining checkpoints.
            while checkpoint_steps and checkpoint_steps[-1] == epoch:
                checkpoint_steps.pop()
            torch.save(
                trainer.state_dict(),
                os.path.join(
                    config.checkpoints_path, f"checkpoint_{checkpoint_num}.pt"
                ),
            )
            checkpoint_num += 1

    wandb.finish()


def load_model(checkpoint_path, config):
    """Rebuild an ``Actor`` and load its weights from a trainer checkpoint.

    :param checkpoint_path: file written by ``torch.save`` containing a
        dictionary with an ``"actor"`` state dict.
    :param config: training configuration; supplies the environment id and
        seed, network sizes, temperature and target device.
    :return: the actor on ``config.device`` with the checkpoint weights.
    """
    # TrainConfig declares the environment id as `env`; `env_name` is only
    # present when it was set explicitly.
    env_name = getattr(config, "env_name", None)
    if env_name is None:
        env_name = config.env

    # Create an environment to get state_dim and action_dim
    env = gym.make(env_name, seed=config.env_seed)
    state_dim, action_dim = evaluations.state_and_action_dims(env, config)

    # Initialize the actor with the correct dimensions
    actor = Actor(
        state_dim,
        action_dim,
        config.hidden_dim,
        config.temperature,
    ).to(config.device)

    # Load the state dictionary. map_location lets a checkpoint saved on one
    # device (e.g. a GPU) be restored on another (e.g. a CPU-only host).
    state_dict = torch.load(checkpoint_path, map_location=config.device)
    actor.load_state_dict(state_dict["actor"])
    return actor


if __name__ == "__main__":
    # Command line: `eval` scores saved checkpoints, `train` runs one training
    # job from a sweep config; without a subcommand the help text is printed.
    base_parser = argparse.ArgumentParser(add_help=False)
    subparsers = base_parser.add_subparsers(title="subcommands", dest="subcommand")

    eval_parser = subparsers.add_parser("eval", help="Evaluate all trained checkpoints")
    eval_parser.add_argument(
        "--base-path", type=str, metavar="NAME", help="path to the checkpoint directory"
    )
    eval_parser.add_argument(
        "--out-name", type=str, metavar="NAME", help="name of the results file"
    )

    train_parser = subparsers.add_parser("train", help="Train an instance of the model")
    train_parser.add_argument(
        "config_loc", type=str, metavar="NAME", help="location of config file"
    )

    args = base_parser.parse_args()

    if args.subcommand == "eval":
        results_df = evaluations.process_checkpoints(
            args.base_path,
            "EDAC",
            TrainConfig,
            load_model,
            wrap_env,
            out_name=args.out_name,
            load_q_nets=load_q_nets,
        )
        if len(results_df) == 0:
            print("No results to evaluate")
            # `exit` is injected by the `site` module and is not guaranteed to
            # exist; raising SystemExit always works.
            raise SystemExit(1)

        combined_stats_df = evaluations.combine_stats(results_df)
        evaluations.grand_stats(combined_stats_df)

    elif args.subcommand == "train":
        # config_loc is resolved relative to algorithms/sweep_configs/.
        with open(f"algorithms/sweep_configs/{args.config_loc}", "r") as f:
            # NOTE(review): FullLoader accepts more than plain data; consider
            # yaml.safe_load if these files can come from untrusted sources.
            sweep_config = yaml.load(f, Loader=yaml.FullLoader)

        # Start a new wandb run
        wandb.init(config=sweep_config)

        # Update the TrainConfig instance with parameters from wandb
        # This assumes that update_params will handle single value parameters correctly
        config = TrainConfig()
        config.update_params(dict(wandb.config))

        # Now pass the updated config to the train function
        train(config)

    else:
        base_parser.print_help()
