# based on https://github.com/chwoong/LiRE
import os, sys, random, datetime, copy
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import pyrallis
import tqdm

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.tensorboard import SummaryWriter

from logger import Logger

sys.path.append("./Reward_learning")
from reward_learning import reward_model

# A batch as a list of tensors (e.g. [states, actions, rewards, next_states, dones]).
TensorBatch = List[torch.Tensor]
# Nested list-of-lists variant of TensorBatch.
TrajTensorBatch = List[List[torch.Tensor]]


# Upper clamp applied to the advantage weights exp(adv / beta) in the actor loss.
EXP_ADV_MAX = 100.0
# Clamp range for the Gaussian policy's learned log standard deviation.
LOG_STD_MIN = -20.0
LOG_STD_MAX = 2.0


@dataclass
class TrainConfig:
    """Configuration for XQL-VA training (parsed from the command line by pyrallis).

    ``__post_init__`` derives ``log_path``; the reward model is loaded from it and
    training logs are written beneath it.
    """

    # Experiment
    device: Optional[str] = None  # None: auto-detect; a bare GPU index like "0"; or a torch device string
    dataset: str = "medium-replay"  # "medium-replay" / "medium-expert" (ignored for dmc envs)
    env: str = "metaworld_box-close-v2"  # environment name
    seed: int = 0  # Sets Gym, PyTorch and Numpy seeds
    eval_freq: int = int(5e3)  # How often (training steps) we evaluate
    n_episodes: int = 50  # How many episodes run during evaluation
    max_timesteps: int = 250000  # Number of training updates
    checkpoints_path: Optional[str] = None  # Save path
    load_model: str = ""  # Model load file name, "" doesn't load
    data_quality: Optional[float] = None  # Replay buffer size (data_quality * 100000)
    # 0: GT reward, 1: zero reward, 2: random (uniform) reward, 3: negative reward.
    # Only applied when normalize_reward is True.
    trivial_reward: int = 0
    # Algorithm
    buffer_size: int = 2_000_000  # Replay buffer size (overwritten with the dataset size)
    batch_size: int = 256  # Batch size for all networks (the reward model gets this same config)
    traj_batch_size: int = 16  # Number of segment pairs sampled per Q update
    discount: float = 0.99  # Discount factor
    tau: float = 0.005  # Target network update rate
    beta: float = 1.0  # Temperature
    deterministic: bool = False  # Use deterministic actor
    normalize: bool = True  # Normalize states
    normalize_reward: bool = True  # Normalize reward
    vf_lr: float = 3e-4  # V function learning rate
    qf_lr: float = 3e-4  # Critic learning rate
    actor_lr: float = 3e-4  # Actor learning rate
    actor_dropout: Optional[float] = None  # Adroit uses dropout for policy network
    # Reward model
    feedback_num: int = 1000
    use_reward_model: bool = True
    epochs: int = 0
    activation: str = "tanh"
    lr: float = 1e-3
    threshold: float = 0.5
    segment_size: int = 25
    data_aug: str = "none"
    hidden_sizes: int = 128
    ensemble_num: int = 3
    ensemble_method: str = "mean"
    q_budget: int = 1
    feedback_type: str = "RLT"
    model_type: str = "BT"
    noise: float = 0.0
    human: bool = False

    def __post_init__(self):
        """Derive ``log_path`` from the environment, dataset and experiment settings."""
        if "dmc" in self.env:
            self.log_path = f"log/{self.env}/data_{self.data_quality}_fn_{self.feedback_num}/s_{self.seed}"
        elif self.dataset == "medium-replay":
            self.log_path = f"log/{self.env}/medium-replay/data_{self.data_quality}_fn_{self.feedback_num}_qb_{self.q_budget}_ft_{self.feedback_type}_m_{self.model_type}/s_{self.seed}"
        elif self.dataset == "medium-expert":
            self.log_path = f"log/{self.env}/medium-expert/data_{self.data_quality}_fn_{self.feedback_num}/s_{self.seed}"
        else:
            # Fallback so that log_path is always defined for any other dataset name.
            self.log_path = f"log/{self.env}/{self.dataset}/data_{self.data_quality}_fn_{self.feedback_num}/s_{self.seed}"


def soft_update(target: nn.Module, source: nn.Module, tau: float):
    """Polyak-average ``source`` into ``target`` in place.

    Every target parameter is overwritten with
    ``(1 - tau) * target + tau * source``; parameters are paired in the order
    ``parameters()`` yields them.
    """
    keep = 1 - tau
    for dst, src in zip(target.parameters(), source.parameters()):
        blended = keep * dst.data + tau * src.data
        dst.data.copy_(blended)


def compute_mean_std(states: np.ndarray, eps: float) -> Tuple[np.ndarray, np.ndarray]:
    """Return the per-feature mean and standard deviation over axis 0.

    ``eps`` is added to the standard deviation so it can safely be used as a
    divisor.
    """
    return states.mean(axis=0), states.std(axis=0) + eps


def normalize_states(states: np.ndarray, mean: np.ndarray, std: np.ndarray):
    """Standardize ``states`` as ``(states - mean) / std`` (broadcasting allowed)."""
    centered = states - mean
    return centered / std


def wrap_env(
    env: gym.Env,
    state_mean: Union[np.ndarray, float] = 0.0,
    state_std: Union[np.ndarray, float] = 1.0,
    reward_scale: float = 1.0,
) -> gym.Env:
    """Wrap ``env`` so observations are standardized and rewards optionally scaled.

    Args:
        env: environment to wrap.
        state_mean: mean subtracted from every observation.
        state_std: divisor applied after centering (must already include any eps).
        reward_scale: multiplier for rewards; the reward wrapper is only added
            when it differs from 1.0.

    Returns:
        The wrapped environment.
    """

    # PEP 8: E731 do not assign a lambda expression, use a def
    def normalize_state(state):
        # Presumably some env/gym combinations pass the raw ``(obs, info)`` tuple
        # from ``reset`` through the wrapper; keep only the observation. Checking
        # for a tuple (instead of ``len(state) == 2``) leaves genuine
        # 2-dimensional observations untouched.
        if isinstance(state, tuple) and len(state) == 2:
            state = state[0]
        return (state - state_mean) / state_std  # epsilon should be already added in std.

    def scale_reward(reward):
        # Please be careful, here reward is multiplied by scale!
        return reward_scale * reward

    env = gym.wrappers.TransformObservation(env, normalize_state)
    if reward_scale != 1.0:
        env = gym.wrappers.TransformReward(env, scale_reward)
    return env


class ReplayBuffer:
    """Pre-allocated tensor storage for an offline dataset of transitions.

    The dataset is copied in once, via ``load_dataset`` (trajectory-structured
    data) or ``load_d4rl_dataset``. Afterwards ``sample`` draws i.i.d.
    transitions and ``sample_trajectory`` draws contiguous segments.

    Args:
        state_dim: observation width.
        action_dim: action width.
        buffer_size: capacity in transitions.
        device: torch device the storage tensors are allocated on.
        traj_len: length of every stored trajectory. ``load_dataset`` and
            ``sample_trajectory`` assume the data is a concatenation of
            trajectories of exactly this many steps. Defaults to 500.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        buffer_size: int,
        device: str = "cpu",
        traj_len: int = 500,
    ):
        self._buffer_size = buffer_size
        self._pointer = 0
        self._size = 0
        self._traj_len = traj_len

        self._states = torch.zeros(
            (buffer_size, state_dim), dtype=torch.float32, device=device
        )
        self._actions = torch.zeros(
            (buffer_size, action_dim), dtype=torch.float32, device=device
        )
        self._rewards = torch.zeros(
            (buffer_size, 1), dtype=torch.float32, device=device
        )
        self._next_states = torch.zeros(
            (buffer_size, state_dim), dtype=torch.float32, device=device
        )
        self._dones = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
        self._device = device

    def _to_tensor(self, data: np.ndarray) -> torch.Tensor:
        return torch.tensor(data, dtype=torch.float32, device=self._device)

    def _load_arrays(self, data: Dict[str, np.ndarray]) -> int:
        """Copy a d4rl-format dict into the empty buffer and return its length.

        Raises:
            ValueError: if the buffer already holds data or is too small.
        """
        if self._size != 0:
            raise ValueError("Trying to load data into non-empty replay buffer")
        n_transitions = data["observations"].shape[0]
        if n_transitions > self._buffer_size:
            raise ValueError(
                "Replay buffer is smaller than the dataset you are trying to load!"
            )
        self._states[:n_transitions] = self._to_tensor(data["observations"])
        self._actions[:n_transitions] = self._to_tensor(data["actions"])
        self._rewards[:n_transitions] = self._to_tensor(data["rewards"][..., None])
        self._next_states[:n_transitions] = self._to_tensor(data["next_observations"])
        self._dones[:n_transitions] = self._to_tensor(data["terminals"][..., None])
        self._size += n_transitions
        self._pointer = min(self._size, n_transitions)
        return n_transitions

    # Loads data in d4rl format, i.e. from Dict[str, np.array].
    def load_d4rl_dataset(self, data: Dict[str, np.ndarray]):
        n_transitions = self._load_arrays(data)
        print(f"Dataset size: {n_transitions}")

    def load_dataset(self, data: Dict[str, np.ndarray]):
        """Load trajectory-structured data and record how many whole trajectories it holds."""
        n_transitions = self._load_arrays(data)
        self.num_traj = n_transitions // self._traj_len
        print(f"Dataset size: {n_transitions}")

    def sample(self, batch_size: int) -> TensorBatch:
        """Draw ``batch_size`` transitions uniformly with replacement."""
        indices = np.random.randint(0, min(self._size, self._pointer), size=batch_size)
        states = self._states[indices]
        actions = self._actions[indices]
        rewards = self._rewards[indices]
        next_states = self._next_states[indices]
        dones = self._dones[indices]
        return [states, actions, rewards, next_states, dones]

    def sample_trajectory(self, batch_size: int, segment_size: int) -> TensorBatch:
        """Draw ``batch_size`` contiguous segments of ``segment_size`` steps.

        Each segment lies entirely inside one trajectory (chosen uniformly with
        replacement). Requires ``load_dataset`` to have been called.

        Returns:
            ``[states, actions, rewards, next_states]``, each with
            ``batch_size * segment_size`` rows; segment ``i`` occupies rows
            ``i * segment_size`` to ``(i + 1) * segment_size - 1``.
        """
        traj_idx = np.random.choice(self.num_traj, batch_size, replace=True)
        # randint's upper bound is exclusive; +1 lets a segment end on the last step.
        offsets = np.random.randint(
            0, self._traj_len - segment_size + 1, size=batch_size
        )
        starts = self._traj_len * traj_idx + offsets
        indices = (starts[:, None] + np.arange(segment_size)).reshape(-1)
        states = self._states[indices]
        actions = self._actions[indices]
        rewards = self._rewards[indices]
        next_states = self._next_states[indices]
        return [states, actions, rewards, next_states]



def set_seed(
    seed: int, env: Optional[gym.Env] = None, deterministic_torch: bool = False
):
    """Seed every RNG used during training.

    Seeds ``env`` and its action space (when given). Sets ``PYTHONHASHSEED``,
    which only influences child processes; the current interpreter's hash
    seed is fixed at startup. Seeds NumPy's global RNG, Python's ``random`` and
    PyTorch, then toggles PyTorch deterministic algorithms per
    ``deterministic_torch``.
    """
    if env is not None:
        env.seed(seed)
        env.action_space.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)
    torch.use_deterministic_algorithms(deterministic_torch)


@torch.no_grad()
def eval_actor(
    env: gym.Env,
    env_name: str,
    actor: nn.Module,
    device: str,
    n_episodes: int,
    seed: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """Roll out ``actor`` in ``env`` for ``n_episodes`` episodes.

    Uses the old gym API (``reset`` -> obs, ``step`` -> 4-tuple). The actor is
    switched to eval mode for the rollouts, and its previous train/eval mode is
    restored afterwards, even if a rollout raises.

    Args:
        env: environment to evaluate in.
        env_name: used only to decide whether to read ``info["success"]``
            (only when ``"metaworld"`` is in the name).
        actor: policy exposing ``act(state, device)``.
        device: device passed through to ``actor.act``.
        n_episodes: number of episodes to run.
        seed: currently unused (the env is not re-seeded here).

    Returns:
        ``(episode_returns, episode_successes)`` as NumPy arrays of length
        ``n_episodes``. Success is the maximum ``info["success"]`` seen during
        the episode, or 0 for non-metaworld envs.
    """
    # env.seed(seed)
    was_training = actor.training
    actor.eval()
    episode_rewards = []
    episode_successes = []
    try:
        for _ in range(n_episodes):
            state, done = env.reset(), False

            episode_reward = 0.0
            episode_success = 0
            while not done:
                action = actor.act(state, device)
                state, reward, done, info = env.step(action)
                episode_reward += reward
                if "metaworld" in env_name:
                    episode_success = max(episode_success, info["success"])

            episode_rewards.append(episode_reward)
            episode_successes.append(episode_success)
    finally:
        actor.train(was_training)
    return np.array(episode_rewards), np.array(episode_successes)


def return_reward_range(dataset, max_episode_steps):
    """Return ``(min, max)`` of per-episode returns in a flat transition dataset.

    An episode ends at a truthy ``terminals`` entry or after
    ``max_episode_steps`` steps, whichever comes first. A trailing unfinished
    episode counts toward the length sanity check, but its return is ignored.

    Raises:
        ValueError: (from ``min``) if no episode finishes.
        AssertionError: if the episode lengths do not add up to the number of
            rewards (e.g. ``terminals`` is shorter than ``rewards``).
    """
    completed_returns, segment_lengths = [], []
    current_return = 0.0
    current_length = 0
    for reward, terminal in zip(dataset["rewards"], dataset["terminals"]):
        current_return += float(reward)
        current_length += 1
        if not (terminal or current_length == max_episode_steps):
            continue
        completed_returns.append(current_return)
        segment_lengths.append(current_length)
        current_return, current_length = 0.0, 0
    segment_lengths.append(current_length)  # leftover steps of an unfinished episode
    assert sum(segment_lengths) == len(dataset["rewards"])
    return min(completed_returns), max(completed_returns)


def _minmax_normalize(rewards: np.ndarray) -> np.ndarray:
    """Affinely map ``rewards`` onto [0, 1]; a constant array maps to all zeros."""
    r_min = rewards.min()
    r_range = rewards.max() - r_min
    # Guard the constant-reward case, which would otherwise divide 0 by 0.
    return (rewards - r_min) / (r_range if r_range > 0 else 1.0)


def modify_reward(
    dataset,
    max_episode_steps=1000,
    trivial_reward=0,
):
    """Rewrite ``dataset["rewards"]`` according to ``trivial_reward``.

    Args:
        dataset: dict holding a NumPy ``"rewards"`` array; updated in place.
        max_episode_steps: unused; kept for API compatibility.
        trivial_reward:
            0 -> min-max normalize rewards to [0, 1] (GT reward);
            1 -> zero every reward (in place);
            2 -> replace with uniform noise over the normalized reward range;
            3 -> ``1 - normalized`` (ordering inverted, still in [0, 1]);
            any other value leaves the rewards untouched.
    """
    # min_ret, max_ret = return_reward_range(dataset, max_episode_steps)
    rewards = dataset["rewards"]
    # GT reward
    if trivial_reward == 0:
        dataset["rewards"] = _minmax_normalize(rewards)
    # zero reward
    elif trivial_reward == 1:
        dataset["rewards"] *= 0.0
    # random reward
    elif trivial_reward == 2:
        normalized = _minmax_normalize(rewards)
        dataset["rewards"] = np.random.uniform(
            normalized.min(), normalized.max(), size=normalized.shape
        )
    # negative reward
    elif trivial_reward == 3:
        dataset["rewards"] = 1 - _minmax_normalize(rewards)


def gumbel_rescale_loss(diff, beta, max_clip=5.0):
    """Gumbel (XQL) regression loss in a numerically rescaled form.

    With ``z = diff / beta``, the per-element loss ``exp(z) - z - 1`` is
    multiplied by ``exp(-pivot)``, where ``pivot = max(max(z), -1)`` is detached
    from autograd. This keeps ``exp`` from overflowing for large ``z``. When
    ``max_clip`` is not None, ``z`` is clipped to ``[-max_clip, max_clip]``
    (both tails, not only the maximum). Returns the mean over all elements.
    """
    z = diff / beta
    if max_clip is not None:
        z = torch.clamp(z, min=-max_clip, max=max_clip)
    pivot = torch.clamp(torch.max(z), min=-1.0).detach()
    scale = torch.exp(-pivot)
    per_elem = torch.exp(z - pivot) - z * scale - scale
    return per_elem.mean()


class Squeeze(nn.Module):
    """Module wrapper that squeezes one fixed dimension (default: the last)."""

    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # A no-op when the chosen dimension's size is not 1.
        return torch.squeeze(x, self.dim)


class MLP(nn.Module):
    """Plain feed-forward network.

    ``dims`` lists layer widths from input to output. Every hidden ``Linear``
    is followed by ``activation_fn()``, then by ``nn.Dropout(dropout)`` when
    ``dropout`` is given. The final ``Linear`` is optionally followed by
    ``output_activation_fn()`` and by ``Squeeze(-1)`` (which requires
    ``dims[-1] == 1``).

    Raises:
        ValueError: if ``dims`` has fewer than two entries, or if squeezing is
            requested with an output width other than 1.
    """

    def __init__(
        self,
        dims,
        activation_fn: Callable[[], nn.Module] = nn.ReLU,
        output_activation_fn: Callable[[], nn.Module] = None,
        squeeze_output: bool = False,
        dropout: Optional[float] = None,
    ):
        super().__init__()
        if len(dims) < 2:
            raise ValueError("MLP requires at least two dims (input and output)")

        modules = []
        # Consecutive (in, out) width pairs for every hidden layer.
        for fan_in, fan_out in zip(dims[:-2], dims[1:-1]):
            modules += [nn.Linear(fan_in, fan_out), activation_fn()]
            if dropout is not None:
                modules.append(nn.Dropout(dropout))

        modules.append(nn.Linear(dims[-2], dims[-1]))
        if output_activation_fn is not None:
            modules.append(output_activation_fn())
        if squeeze_output:
            if dims[-1] != 1:
                raise ValueError("Last dim must be 1 when squeezing")
            modules.append(Squeeze(-1))
        self.net = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class GaussianPolicy(nn.Module):
    """Stochastic actor: a diagonal Gaussian with an MLP (tanh-bounded) mean and
    a state-independent learned log standard deviation.

    Args:
        state_dim: observation width.
        act_dim: action width.
        max_action: actions from ``act`` are scaled by this and clamped to
            ``[-max_action, max_action]``.
        hidden_dim: width of each hidden layer.
        n_hidden: number of hidden layers.
        dropout: dropout probability after each hidden activation (None: off).
    """

    def __init__(
        self,
        state_dim: int,
        act_dim: int,
        max_action: float,
        hidden_dim: int = 256,
        n_hidden: int = 2,
        dropout: Optional[float] = None,
    ):
        super().__init__()
        self.net = MLP(
            [state_dim, *([hidden_dim] * n_hidden), act_dim],
            output_activation_fn=nn.Tanh,
            dropout=dropout,  # was previously accepted but ignored
        )
        self.log_std = nn.Parameter(torch.zeros(act_dim, dtype=torch.float32))
        self.max_action = max_action

    def forward(self, obs: torch.Tensor) -> Normal:
        """Return the action distribution for ``obs`` (log-std clamped to its bounds)."""
        mean = self.net(obs)
        std = torch.exp(self.log_std.clamp(LOG_STD_MIN, LOG_STD_MAX))
        return Normal(mean, std)

    @torch.no_grad()
    def act(self, state: np.ndarray, device: str = "cpu"):
        """Pick an action for one state.

        In eval mode the mean is used; in training mode the action is sampled.
        Returns a flat NumPy array.
        """
        state = torch.tensor(state.reshape(1, -1), device=device, dtype=torch.float32)
        dist = self(state)
        action = dist.mean if not self.training else dist.sample()
        # Samples are unbounded, so clamp after scaling.
        action = torch.clamp(
            self.max_action * action, -self.max_action, self.max_action
        )
        return action.cpu().data.numpy().flatten()


class DeterministicPolicy(nn.Module):
    """Deterministic actor: a tanh-bounded MLP whose output ``act`` scales by
    ``max_action``.

    Args:
        state_dim: observation width.
        act_dim: action width.
        max_action: scale, and clamp bound, for actions returned by ``act``.
        hidden_dim: width of each hidden layer.
        n_hidden: number of hidden layers.
        dropout: dropout probability after each hidden activation (None: off).
    """

    def __init__(
        self,
        state_dim: int,
        act_dim: int,
        max_action: float,
        hidden_dim: int = 256,
        n_hidden: int = 2,
        dropout: Optional[float] = None,
    ):
        super().__init__()
        layer_dims = [state_dim] + [hidden_dim] * n_hidden + [act_dim]
        self.net = MLP(layer_dims, output_activation_fn=nn.Tanh, dropout=dropout)
        self.max_action = max_action

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Raw network output in [-1, 1] (no scaling)."""
        return self.net(obs)

    @torch.no_grad()
    def act(self, state: np.ndarray, device: str = "cpu"):
        """Return the scaled, clamped action for a single state as a flat NumPy array."""
        obs = torch.tensor(state.reshape(1, -1), device=device, dtype=torch.float32)
        scaled = self(obs) * self.max_action
        bounded = torch.clamp(scaled, -self.max_action, self.max_action)
        return bounded.cpu().data.numpy().flatten()


class TwinQ(nn.Module):
    """Two independent Q-networks over concatenated (state, action) inputs.

    ``both`` returns each head's estimate; calling the module returns their
    element-wise minimum.
    """

    def __init__(
        self, state_dim: int, action_dim: int, hidden_dim: int = 256, n_hidden: int = 2
    ):
        super().__init__()
        layer_dims = [state_dim + action_dim] + [hidden_dim] * n_hidden + [1]
        self.q1 = MLP(layer_dims, squeeze_output=True)
        self.q2 = MLP(layer_dims, squeeze_output=True)

    def both(
        self, state: torch.Tensor, action: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        state_action = torch.cat((state, action), dim=1)
        return self.q1(state_action), self.q2(state_action)

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        first, second = self.both(state, action)
        return torch.min(first, second)


class ValueFunction(nn.Module):
    """State-value network V(s): an MLP with a single, squeezed output."""

    def __init__(self, state_dim: int, hidden_dim: int = 256, n_hidden: int = 2):
        super().__init__()
        self.v = MLP(
            [state_dim] + [hidden_dim] * n_hidden + [1],
            squeeze_output=True,
        )

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        value = self.v(state)
        return value


class XQL_VA:
    """XQL-style offline learner whose Q-function is trained with a
    segment-level value-alignment loss instead of a per-transition TD loss.

    Each ``train`` call performs, in order:
      1. a V step: Gumbel regression of V(s) toward the target network's Q(s, a);
      2. a Q step on paired trajectory segments (see ``_update_q``), followed by
         a soft update of the target Q-network;
      3. an actor step: advantage-weighted regression with weights
         ``min(exp(adv / beta), EXP_ADV_MAX)``.
    """

    def __init__(
        self,
        max_action: float,
        actor: nn.Module,
        actor_optimizer: torch.optim.Optimizer,
        q_network: nn.Module,
        q_optimizer: torch.optim.Optimizer,
        v_network: nn.Module,
        v_optimizer: torch.optim.Optimizer,
        beta: float = 1.0,
        max_steps: int = 1000000,
        traj_batch_size: int = 16,
        discount: float = 0.99,
        tau: float = 0.005,
        device: str = "cpu",
    ):
        self.max_action = max_action
        self.qf = q_network
        # Frozen copy of Q; tracked via soft_update after every Q step.
        self.q_target = copy.deepcopy(self.qf).requires_grad_(False).to(device)
        self.vf = v_network
        self.actor = actor
        self.v_optimizer = v_optimizer
        self.q_optimizer = q_optimizer
        self.actor_optimizer = actor_optimizer
        # The actor learning rate is cosine-annealed over max_steps scheduler steps.
        self.actor_lr_schedule = CosineAnnealingLR(self.actor_optimizer, max_steps)
        self.beta = beta
        self.discount = discount
        self.tau = tau
        # Number of segment pairs per Q step; the trajectory batch is expected
        # to hold 2 * traj_batch_size segments.
        self.traj_batch_size = traj_batch_size

        self.total_it = 0
        self.device = device

    def _update_v(self, observations, actions, log_dict) -> torch.Tensor:
        """Take one V step and return ``adv = Q_target(s, a) - V(s)``.

        The returned advantage still carries V's autograd history; callers
        should detach it before reuse (``_update_policy`` does).
        """
        with torch.no_grad():
            target_q = self.q_target(observations, actions)

        v = self.vf(observations)
        adv = target_q - v
        v_loss = gumbel_rescale_loss(adv, self.beta)
        log_dict["train/value_loss"] = v_loss.item()
        log_dict["train/v_mean"] = v.mean().item()
        log_dict["train/v_std"] = v.std().item()
        log_dict["train/adv_mean"] = adv.mean().item()
        log_dict["train/adv_std"] = adv.std().item()
        self.v_optimizer.zero_grad()
        v_loss.backward()
        self.v_optimizer.step()
        return adv

    def _update_q(
        self,
        traj_batch: TensorBatch,
        log_dict: Dict,
    ):
        """Value-alignment step on Q followed by a soft target update.

        ``traj_batch`` is ``[obs, actions, rewards, next_obs]`` holding
        ``2 * traj_batch_size`` equal-length segments back to back. Segment
        ``i`` is paired with segment ``i + traj_batch_size``. For each pair, the
        difference of summed Q-values is regressed onto the difference of summed
        one-step targets ``r + discount * V(s')`` (divided by the segment
        length), for both Q heads. No terminal masking is applied.
        """
        # MetaWorld has no terminals (only time limit)
        obs_t, action_t, reward_t, next_obs_t = traj_batch

        half_size = self.traj_batch_size
        segment_size = obs_t.size(0) // (2 * half_size)

        # value alignment loss
        with torch.no_grad():
            next_v_t = self.vf(next_obs_t).flatten()
            target_t = reward_t.flatten() + self.discount * next_v_t
            target_t = torch.sum(target_t.view(-1, segment_size), dim=-1)

        q1_t, q2_t = self.qf.both(obs_t, action_t)
        q1_t = torch.sum(q1_t.view(-1, segment_size), dim=-1)
        q2_t = torch.sum(q2_t.view(-1, segment_size), dim=-1)
        err1 = (q1_t[:half_size] - target_t[:half_size] - q1_t[half_size:] + target_t[half_size:]) / segment_size
        err2 = (q2_t[:half_size] - target_t[:half_size] - q2_t[half_size:] + target_t[half_size:]) / segment_size
        q_loss = torch.mean(err1 ** 2) + torch.mean(err2 ** 2)

        log_dict["train/q_loss"] = q_loss.item()
        self.q_optimizer.zero_grad()
        q_loss.backward()
        self.q_optimizer.step()

        soft_update(self.q_target, self.qf, self.tau)

    def _update_policy(
        self,
        adv: torch.Tensor,
        observations: torch.Tensor,
        actions: torch.Tensor,
        log_dict: Dict,
    ):
        """Advantage-weighted behaviour-cloning step on the actor.

        The per-sample BC loss is the negative log-likelihood (summed over
        action dims) for a distributional actor, or the squared error for a
        tensor-output actor. The actor LR scheduler is stepped once per call.
        """
        exp_adv = torch.exp(adv.detach() / self.beta).clamp(max=EXP_ADV_MAX)
        policy_out = self.actor(observations)
        if isinstance(policy_out, torch.distributions.Distribution):
            bc_losses = -policy_out.log_prob(actions).sum(-1, keepdim=False)
        elif torch.is_tensor(policy_out):
            if policy_out.shape != actions.shape:
                raise RuntimeError("Actions shape missmatch")
            bc_losses = torch.sum((policy_out - actions) ** 2, dim=1)
        else:
            raise NotImplementedError
        policy_loss = torch.mean(exp_adv * bc_losses)
        log_dict["train/actor_loss"] = policy_loss.item()
        self.actor_optimizer.zero_grad()
        policy_loss.backward()
        self.actor_optimizer.step()
        self.actor_lr_schedule.step()

    def train(self, batch: TensorBatch, traj_batch: TensorBatch) -> Dict[str, float]:
        """Run one V, Q and actor update; return scalar metrics keyed by log tag.

        Only the observations and actions of ``batch`` are used.
        """
        self.total_it += 1
        (
            observations,
            actions,
            rewards,
            next_observations,
            dones,
        ) = batch
        log_dict = {}

        # Update value function
        adv = self._update_v(observations, actions, log_dict)
        # Update Q function
        self._update_q(traj_batch, log_dict)
        # Update actor
        self._update_policy(adv, observations, actions, log_dict)

        return log_dict

    def state_dict(self) -> Dict[str, Any]:
        """Collect network, optimizer and scheduler state (the target Q is not saved)."""
        return {
            "qf": self.qf.state_dict(),
            "q_optimizer": self.q_optimizer.state_dict(),
            "vf": self.vf.state_dict(),
            "v_optimizer": self.v_optimizer.state_dict(),
            "actor": self.actor.state_dict(),
            "actor_optimizer": self.actor_optimizer.state_dict(),
            "actor_lr_schedule": self.actor_lr_schedule.state_dict(),
            "total_it": self.total_it,
        }

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Restore state saved by ``state_dict``; the target Q is re-cloned from Q."""
        self.qf.load_state_dict(state_dict["qf"])
        self.q_optimizer.load_state_dict(state_dict["q_optimizer"])
        # Rebuild the target exactly as in __init__: frozen and on self.device.
        self.q_target = copy.deepcopy(self.qf).requires_grad_(False).to(self.device)

        self.vf.load_state_dict(state_dict["vf"])
        self.v_optimizer.load_state_dict(state_dict["v_optimizer"])

        self.actor.load_state_dict(state_dict["actor"])
        self.actor_optimizer.load_state_dict(state_dict["actor_optimizer"])
        self.actor_lr_schedule.load_state_dict(state_dict["actor_lr_schedule"])

        self.total_it = state_dict["total_it"]

def _resolve_device(device: Optional[str]) -> str:
    """Translate the ``device`` option into a torch device string.

    * ``None`` -> ``"cuda"`` when CUDA is available, otherwise ``"cpu"``.
    * A bare GPU index such as ``"1"`` -> checks that the index exists, exposes
      only that GPU via ``CUDA_VISIBLE_DEVICES``, and returns ``"cuda"``.
      NOTE: the variable only takes effect if CUDA has not yet been initialised
      in this process.
    * Anything else is returned unchanged.

    Raises:
        ValueError: if the GPU index is not below ``torch.cuda.device_count()``.
    """
    if device is None:
        return "cuda" if torch.cuda.is_available() else "cpu"
    if device.isdigit():
        if torch.cuda.device_count() <= int(device):
            raise ValueError("invalid device")
        os.environ["CUDA_VISIBLE_DEVICES"] = f"{device}"
        return "cuda"
    return device


def _make_env_and_dataset(config: TrainConfig):
    """Build the evaluation env and load the offline dataset selected by ``config``.

    Raises:
        ValueError: for a non-dmc env whose dataset is neither "medium-replay"
            nor "medium-expert".
    """
    if "dmc" in config.env:
        from env import utils_dmc
        env = utils_dmc.make_dmc_env(config.env, config.seed)
        return env, utils_dmc.DMC_dataset(config)
    if config.dataset in ("medium-replay", "medium-expert"):
        from env import utils_metaworld
        env = utils_metaworld.make_metaworld_env(config.env, config.seed)
        if config.dataset == "medium-replay":
            dataset = utils_metaworld.MetaWorld_mr_dataset(config)
        else:
            dataset = utils_metaworld.MetaWorld_me_dataset(config)
        return env, dataset
    # No loader exists for other datasets; fail early with a clear message
    # instead of a NameError on an undefined dataset later on.
    raise ValueError(
        f"Unsupported dataset '{config.dataset}' for env '{config.env}': "
        "expected a dmc env or dataset 'medium-replay' / 'medium-expert'"
    )


@pyrallis.wrap()
def train(config: TrainConfig):
    """Train XQL-VA offline and evaluate periodically.

    Resolves the device and builds the env and dataset. Normalizes
    observations with dataset statistics. Optionally relabels rewards with a
    reward model loaded from ``config.log_path`` and normalizes/replaces them
    via ``modify_reward``. Then runs ``config.max_timesteps`` updates.

    Training metrics are logged every 5000 steps. Every ``config.eval_freq``
    steps the actor is evaluated. Every ``20 * config.eval_freq`` steps a
    checkpoint is written when ``config.checkpoints_path`` is set.
    """
    config.device = _resolve_device(config.device)

    log_path = os.path.join(config.log_path, f"XQL-VA_beta_{config.beta}")
    writer = SummaryWriter(log_path)
    logger = Logger(writer=writer, log_path=log_path)

    env, dataset = _make_env_and_dataset(config)

    state_dim = env.observation_space.shape[0]  # 39 for metaworld
    action_dim = env.action_space.shape[0]  # 4 for metaworld
    # Input width handed to the reward model (presumably concatenated state + action).
    dimension = state_dim + action_dim

    seed = config.seed
    set_seed(seed, env)

    if config.normalize:
        state_mean, state_std = compute_mean_std(dataset["observations"], eps=1e-3)
    else:
        state_mean, state_std = 0, 1

    dataset["observations"] = normalize_states(
        dataset["observations"], state_mean, state_std
    )
    dataset["next_observations"] = normalize_states(
        dataset["next_observations"], state_mean, state_std
    )

    if config.use_reward_model:
        # Relabel rewards with the reward model stored under config.log_path.
        model = reward_model.RewardModel(config, None, None, None, dimension, None)
        model.load_model(config.log_path)
        dataset["rewards"] = model.get_reward(dataset)
        print("labeled by reward model")

    if config.normalize_reward:
        modify_reward(
            dataset,
            max_episode_steps=500,
            trivial_reward=config.trivial_reward,
        )

    # The evaluation env must see observations normalized like the dataset.
    env = wrap_env(env, state_mean=state_mean, state_std=state_std)
    # Size the buffer to exactly fit the offline dataset.
    config.buffer_size = dataset["observations"].shape[0]

    replay_buffer = ReplayBuffer(
        state_dim,
        action_dim,
        config.buffer_size,
        config.device,
    )
    replay_buffer.load_dataset(dataset)

    max_action = float(env.action_space.high[0])

    q_network = TwinQ(state_dim, action_dim).to(config.device)
    v_network = ValueFunction(state_dim).to(config.device)
    actor = (
        DeterministicPolicy(
            state_dim, action_dim, max_action, dropout=config.actor_dropout
        )
        if config.deterministic
        else GaussianPolicy(
            state_dim, action_dim, max_action, dropout=config.actor_dropout
        )
    ).to(config.device)
    v_optimizer = torch.optim.Adam(v_network.parameters(), lr=config.vf_lr)
    q_optimizer = torch.optim.Adam(q_network.parameters(), lr=config.qf_lr)
    actor_optimizer = torch.optim.Adam(actor.parameters(), lr=config.actor_lr)

    kwargs = {
        "max_action": max_action,
        "actor": actor,
        "actor_optimizer": actor_optimizer,
        "q_network": q_network,
        "q_optimizer": q_optimizer,
        "v_network": v_network,
        "v_optimizer": v_optimizer,
        "discount": config.discount,
        "tau": config.tau,
        "device": config.device,
        # algorithm parameters
        "beta": config.beta,
        "max_steps": config.max_timesteps,
        "traj_batch_size": config.traj_batch_size,
    }

    # Initialize actor
    trainer = XQL_VA(**kwargs)

    if config.load_model != "":
        policy_file = Path(config.load_model)
        # map_location lets a checkpoint saved on another device load here.
        trainer.load_state_dict(torch.load(policy_file, map_location=config.device))
        actor = trainer.actor

    if config.checkpoints_path is not None:
        os.makedirs(config.checkpoints_path, exist_ok=True)

    for t in tqdm.tqdm(range(int(config.max_timesteps))):
        batch = replay_buffer.sample(config.batch_size)
        batch = [b.to(config.device) for b in batch]
        # Twice traj_batch_size segments: XQL_VA._update_q pairs the two halves.
        traj_batch = replay_buffer.sample_trajectory(2 * config.traj_batch_size, config.segment_size)
        traj_batch = [b.to(config.device) for b in traj_batch]
        log_dict = trainer.train(batch, traj_batch)
        if (t + 1) % 5000 == 0:
            for k, v in log_dict.items():
                logger.record(k, v, trainer.total_it)
        # Evaluate episode
        if (t + 1) % config.eval_freq == 0:
            print(f"Time steps: {t + 1}")
            eval_scores, eval_success = eval_actor(
                env,
                config.env,
                actor,
                device=config.device,
                n_episodes=config.n_episodes,
                seed=config.seed,
            )
            eval_score = eval_scores.mean()
            eval_success = eval_success.mean() * 100
            print("---------------------------------------")
            print(
                f"Evaluation over {config.n_episodes} episodes: "
                f"{eval_score:.3f} , success: {eval_success:.3f}"
            )
            print("---------------------------------------")
            if (config.checkpoints_path is not None) and (t + 1) % (
                20 * config.eval_freq
            ) == 0:
                torch.save(
                    trainer.state_dict(),
                    os.path.join(config.checkpoints_path, f"checkpoint_{t}.pt"),
                )
            logger.record("eval/eval_success", eval_success, trainer.total_it)
            logger.record("eval/eval_score", eval_score, trainer.total_it)
            logger.flush(trainer.total_it)


if __name__ == "__main__":
    train()
