import os
import random
import uuid
from copy import deepcopy
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import d4rl
import gym
import numpy as np
import pyrallis
import torch
import torch.nn as nn
import torch.nn.functional as F
# import wandb
from torch.distributions import Normal, TanhTransform, TransformedDistribution

import argparse
from tqdm import tqdm

from diffusion_SDE.loss import loss_fn
from diffusion_SDE.schedule import marginal_prob_std
from diffusion_SDE.model import ScoreNet


import warnings
# Silence warnings globally for the whole process. ``filterwarnings`` prepends
# each entry to ``warnings.filters``. The first call already matches every
# warning: every warning category subclasses ``Warning``, and the regex
# ".*?.*?.*?" matches any message prefix. The three narrower calls below are
# therefore redundant but harmless.
warnings.filterwarnings("ignore", category=Warning, message=".*?.*?.*?")
warnings.filterwarnings("ignore", category=UserWarning, message=".*?.*.*?")
warnings.filterwarnings("ignore", category=FutureWarning, message=".*?.*.*?")
warnings.filterwarnings("ignore", category=DeprecationWarning, message=".*?.*.*?")


# Type alias for a sampled mini-batch: a plain list of tensors (see
# ReplayBuffer.sample, which returns states, actions, rewards, next states,
# dones and MC returns in that order).
TensorBatch = List[torch.Tensor]

# Substrings of environment names that are presumably treated as
# goal-reaching tasks.
# NOTE(review): not referenced in the visible part of this file; confirm usage.
ENVS_WITH_GOAL = ("antmaze", "pen", "door", "hammer", "relocate")


@dataclass
class TrainConfig:
    """Hyper-parameters for Cal-QL offline pre-training and online fine-tuning.

    After construction, ``name`` is made unique per run by appending the
    environment id and a random 8-character identifier. When a checkpoint
    root is given, ``checkpoints_path`` is nested under that unique name.
    """

    # ---- experiment ----
    device: str = "cuda"
    env: str = "halfcheetah-medium-expert-v2"  # gym / D4RL environment id
    seed: int = 0  # seed for Gym, PyTorch and NumPy
    eval_seed: int = 0  # seed of the evaluation environment
    eval_freq: int = 6_000  # evaluate every this many time steps
    n_episodes: int = 10  # episodes per evaluation round
    online_iterations: int = 210_005  # number of online updates
    checkpoints_path: Optional[str] = None  # checkpoint root directory
    load_model: str = ""  # checkpoint file to load; "" disables loading
    # ---- CQL ----
    buffer_size: int = 2_000_000  # replay buffer capacity
    batch_size: int = 256  # mini-batch size shared by all networks
    discount: float = 0.99  # discount factor
    alpha_multiplier: float = 1.0  # scale applied to the entropy temperature
    use_automatic_entropy_tuning: bool = True  # learn the temperature
    backup_entropy: bool = False  # add entropy term to the TD target
    policy_lr: float = 1e-4  # actor learning rate
    qf_lr: float = 3e-4  # critic learning rate
    soft_target_update_rate: float = 5e-3  # Polyak rate for target critics
    bc_steps: int = 0  # initial behaviour-cloning steps
    target_update_period: int = 1  # target update frequency
    cql_alpha: float = 10.0  # CQL weight during offline training
    cql_alpha_online: float = 10.0  # CQL weight during online training
    cql_n_actions: int = 10  # sampled actions per state for the CQL term
    cql_importance_sample: bool = True  # importance-weighted logsumexp
    cql_lagrange: bool = False  # Lagrangian (auto-tuned) CQL weight
    cql_target_action_gap: float = -1.0  # target gap for the Lagrangian
    cql_temp: float = 1.0  # logsumexp temperature
    cql_max_target_backup: bool = False  # max over sampled target actions
    cql_clip_diff_min: float = -np.inf  # lower clip of the CQL difference
    cql_clip_diff_max: float = np.inf  # upper clip of the CQL difference
    orthogonal_init: bool = True  # orthogonal weight initialisation
    normalize: bool = True  # normalise observations
    normalize_reward: bool = False  # normalise rewards
    q_n_hidden_layers: int = 3  # hidden layers per Q network
    reward_scale: float = 1.0  # multiplicative reward transform
    reward_bias: float = 0.0  # additive reward transform
    # ---- Cal-QL ----
    mixing_ratio: float = 0.5  # offline/online data mixing ratio
    is_sparse_reward: bool = False  # task has sparse rewards
    # ---- logging ----
    project: str = "CORL"
    group: str = "Cal-QL-D4RL"
    name: str = "Cal-QL"

    def __post_init__(self):
        run_id = str(uuid.uuid4())[:8]
        self.name = f"{self.name}-{self.env}-{run_id}"
        if self.checkpoints_path is None:
            return
        self.checkpoints_path = os.path.join(self.checkpoints_path, self.name)


def soft_update(target: nn.Module, source: nn.Module, tau: float):
    """Polyak-average ``source`` into ``target`` in place.

    Every parameter becomes ``(1 - tau) * target + tau * source``; parameters
    are paired in ``.parameters()`` order.
    """
    keep = 1 - tau
    for dst, src in zip(target.parameters(), source.parameters()):
        blended = keep * dst.data + tau * src.data
        dst.data.copy_(blended)


def compute_mean_std(states: np.ndarray, eps: float) -> Tuple[np.ndarray, np.ndarray]:
    """Per-feature mean and (population) standard deviation over axis 0.

    ``eps`` is added to the std so that later division is safe.
    """
    return np.mean(states, axis=0), np.std(states, axis=0) + eps


def normalize_states(states: np.ndarray, mean: np.ndarray, std: np.ndarray):
    """Standardise ``states`` with precomputed ``mean`` and ``std``."""
    centered = states - mean
    return centered / std


def wrap_env(
    env: gym.Env,
    state_mean: Union[np.ndarray, float] = 0.0,
    state_std: Union[np.ndarray, float] = 1.0,
    reward_scale: float = 1.0,
) -> gym.Env:
    """Wrap ``env`` so observations are standardised and rewards optionally scaled.

    Observations always pass through ``(state - state_mean) / state_std``.
    A reward wrapper is added only when ``reward_scale`` differs from 1.0.
    """

    def normalize_state(state):
        # ``state_std`` is expected to already include the epsilon term.
        centered = state - state_mean
        return centered / state_std

    def scale_reward(reward):
        # Multiplicative scaling only: no bias is applied here.
        return reward_scale * reward

    wrapped = gym.wrappers.TransformObservation(env, normalize_state)
    if reward_scale == 1.0:
        return wrapped
    return gym.wrappers.TransformReward(wrapped, scale_reward)


class ReplayBuffer:
    """Fixed-capacity replay buffer backed by pre-allocated float32 tensors.

    Each slot holds (state, action, reward, next_state, done, mc_return).
    Offline data is bulk-loaded with ``load_d4rl_dataset``. Online transitions
    are appended with ``add_transition``, which overwrites the oldest slot
    once the buffer is full (ring buffer).
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        buffer_size: int,
        device: str = "cpu",
    ):
        self._buffer_size = buffer_size
        self._pointer = 0  # next slot written by add_transition
        self._size = 0  # number of valid slots

        self._states = torch.zeros(
            (buffer_size, state_dim), dtype=torch.float32, device=device
        )
        self._actions = torch.zeros(
            (buffer_size, action_dim), dtype=torch.float32, device=device
        )
        self._rewards = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
        self._next_states = torch.zeros(
            (buffer_size, state_dim), dtype=torch.float32, device=device
        )
        self._dones = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
        self._mc_returns = torch.zeros(
            (buffer_size, 1), dtype=torch.float32, device=device
        )

        self._device = device

    def _to_tensor(self, data: np.ndarray) -> torch.Tensor:
        return torch.tensor(data, dtype=torch.float32, device=self._device)

    def load_d4rl_dataset(self, data: Dict[str, np.ndarray]):
        """Bulk-load a D4RL-style dict into an empty buffer.

        ``data`` must contain "observations", "actions", "rewards",
        "next_observations", "terminals" and "mc_returns".

        Raises:
            ValueError: if the buffer is not empty or is smaller than the data.
        """
        if self._size != 0:
            raise ValueError("Trying to load data into non-empty replay buffer")
        n_transitions = data["observations"].shape[0]
        if n_transitions > self._buffer_size:
            raise ValueError(
                "Replay buffer is smaller than the dataset you are trying to load!"
            )
        self._states[:n_transitions] = self._to_tensor(data["observations"])
        self._actions[:n_transitions] = self._to_tensor(data["actions"])
        self._rewards[:n_transitions] = self._to_tensor(data["rewards"][..., None])
        self._next_states[:n_transitions] = self._to_tensor(data["next_observations"])
        self._dones[:n_transitions] = self._to_tensor(data["terminals"][..., None])
        self._mc_returns[:n_transitions] = self._to_tensor(data["mc_returns"][..., None])
        self._size += n_transitions
        # Wrap the write pointer. A dataset that fills the buffer exactly must
        # leave the pointer at slot 0 (the oldest entry), not one past the
        # end. Otherwise the next add_transition would index out of range.
        self._pointer = n_transitions % self._buffer_size

        print(f"Dataset size: {n_transitions}")

    def sample(self, batch_size: int) -> "TensorBatch":
        """Sample ``batch_size`` transitions uniformly with replacement.

        Returns:
            [states, actions, rewards, next_states, dones, mc_returns].
        """
        indices = torch.randint(0, self._size, (batch_size,), device=self._device)
        states = self._states[indices]
        actions = self._actions[indices]
        rewards = self._rewards[indices]
        next_states = self._next_states[indices]
        dones = self._dones[indices]
        mc_returns = self._mc_returns[indices]
        return [states, actions, rewards, next_states, dones, mc_returns]

    def add_transition(
        self,
        state: np.ndarray,
        action: np.ndarray,
        reward: float,
        next_state: np.ndarray,
        done: bool,
    ):
        """Write one online transition at the pointer, overwriting when full.

        The MC return of online transitions is stored as 0.0.
        """
        self._states[self._pointer] = self._to_tensor(state)
        self._actions[self._pointer] = self._to_tensor(action)
        self._rewards[self._pointer] = self._to_tensor(reward)
        self._next_states[self._pointer] = self._to_tensor(next_state)
        self._dones[self._pointer] = self._to_tensor(done)
        self._mc_returns[self._pointer] = 0.0

        self._pointer = (self._pointer + 1) % self._buffer_size
        self._size = min(self._size + 1, self._buffer_size)


def set_env_seed(env: gym.Env, seed: int):
    """Seed ``env`` and its action-space sampler with ``seed``.

    ``env`` must be a real environment, not ``None``: ``env.seed`` and
    ``env.action_space`` are accessed unconditionally. It relies on the legacy
    ``Env.seed`` method (pre-0.26 gym API).
    """
    env.seed(seed)
    env.action_space.seed(seed)


def set_seed(
    seed: int, env: Optional[gym.Env] = None, deterministic_torch: bool = False
):
    """Seed the environment (if any), Python hashing, NumPy, ``random`` and torch.

    ``deterministic_torch`` is forwarded to
    ``torch.use_deterministic_algorithms``.
    """
    if env is not None:
        set_env_seed(env, seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (np.random.seed, random.seed, torch.manual_seed):
        seeder(seed)
    torch.use_deterministic_algorithms(deterministic_torch)


# def wandb_init(config: dict) -> None:
#     wandb.init(
#         config=config,
#         project=config["project"],
#         group=config["group"],
#         name=config["name"],
#         id=str(uuid.uuid4()),
#     )
#     wandb.run.save()


def is_goal_reached(reward: float, info: Dict) -> bool:
    """Return ``info["goal_achieved"]`` when present, else whether ``reward > 0``.

    The positive-reward fallback assumes a sparse task where reaching the
    target is the only positive reward.
    """
    return info["goal_achieved"] if "goal_achieved" in info else reward > 0


@torch.no_grad()
def eval_actor(
    env: gym.Env, actor: nn.Module, device: str, n_episodes: int, seed: int
) -> Tuple[np.ndarray, float]:
    """Roll out ``actor`` for ``n_episodes`` episodes in ``env``.

    The actor is switched to eval mode for the rollouts and always switched
    back to train mode afterwards, even if a rollout raises.

    Returns:
        (array of undiscounted per-episode returns, fraction of episodes in
        which ``is_goal_reached`` fired at least once).
    """
    env.seed(seed)
    actor.eval()
    episode_rewards = []
    successes = []
    try:
        for _ in range(n_episodes):
            state, done = env.reset(), False
            episode_reward = 0.0
            goal_achieved = False
            while not done:
                action = actor.act(state, device)
                state, reward, done, env_infos = env.step(action)
                episode_reward += reward
                # Success is sticky: once reached it stays reached this episode.
                # Only meaningful for environments with a goal.
                if not goal_achieved:
                    goal_achieved = is_goal_reached(reward, env_infos)
            successes.append(float(goal_achieved))
            episode_rewards.append(episode_reward)
    finally:
        actor.train()
    return np.asarray(episode_rewards), np.mean(successes)


def return_reward_range(dataset: Dict, max_episode_steps: int) -> Tuple[float, float]:
    """Minimum and maximum undiscounted return over complete episodes.

    Episodes end at a terminal flag or after ``max_episode_steps`` steps. A
    trailing incomplete episode is counted for the length check but excluded
    from the returned range.
    """
    episode_returns, episode_lengths = [], []
    running_return, running_len = 0.0, 0
    for reward, terminal in zip(dataset["rewards"], dataset["terminals"]):
        running_return += float(reward)
        running_len += 1
        if not (terminal or running_len == max_episode_steps):
            continue
        episode_returns.append(running_return)
        episode_lengths.append(running_len)
        running_return, running_len = 0.0, 0
    episode_lengths.append(running_len)  # still account for trailing steps
    assert sum(episode_lengths) == len(dataset["rewards"])
    return min(episode_returns), max(episode_returns)


def get_return_to_go(dataset: Dict, env: gym.Env, config: TrainConfig) -> List[float]:
    """Compute the discounted Monte-Carlo return-to-go of every transition.

    The flat dataset is split into trajectories at any of these points:
    a terminal flag, the final transition, a mismatch between
    ``observations[t + 1]`` and ``next_observations[t]``, or
    ``env._max_episode_steps`` steps.

    Args:
        dataset: dict with "rewards", "terminals", "observations" and
            "next_observations" arrays.
        env: provides ``_max_episode_steps``, and ``ref_min_score`` when
            ``config.is_sparse_reward`` is set.
        config: supplies ``discount``, ``is_sparse_reward``, ``reward_scale``
            and ``reward_bias``.

    Returns:
        A list with one return-to-go per transition, in dataset order.
    """
    returns: List[float] = []
    ep_len = 0
    cur_rewards: List[float] = []
    terminals: List[float] = []
    n_transitions = len(dataset["rewards"])
    for t, (r, d) in enumerate(zip(dataset["rewards"], dataset["terminals"])):
        cur_rewards.append(float(r))
        terminals.append(float(d))
        ep_len += 1
        # Short-circuit order matters: ``t + 1`` is only indexed when t is not
        # the final transition.
        is_last_step = (
            (t == n_transitions - 1)
            or (
                np.linalg.norm(
                    dataset["observations"][t + 1] - dataset["next_observations"][t]
                )
                > 1e-6
            )
            or ep_len == env._max_episode_steps
        )
        if not (d or is_last_step):
            continue

        if (
            config.is_sparse_reward
            and r == env.ref_min_score * config.reward_scale + config.reward_bias
        ):
            # NOTE(review): only the episode's final reward is compared here.
            # Presumably this marks trajectories that never reached the goal;
            # they get the constant infinite-horizon value r / (1 - discount).
            discounted_returns = [r / (1 - config.discount)] * ep_len
        else:
            discounted_returns = [0] * ep_len
            prev_return = 0
            for i in reversed(range(ep_len)):
                discounted_returns[i] = cur_rewards[
                    i
                ] + config.discount * prev_return * (1 - terminals[i])
                prev_return = discounted_returns[i]
        returns += discounted_returns
        ep_len = 0
        cur_rewards = []
        terminals = []
    return returns


def modify_reward(
    dataset: Dict,
    env_name: str,
    max_episode_steps: int = 1000,
    reward_scale: float = 1.0,
    reward_bias: float = 0.0,
) -> Dict:
    """Transform ``dataset["rewards"]`` and return the statistics used.

    For halfcheetah / hopper / walker2d, rewards are first divided in place by
    the spread of episode returns and multiplied by ``max_episode_steps``.
    Every task then gets ``reward * reward_scale + reward_bias``.

    Returns:
        ``{"max_ret", "min_ret", "max_episode_steps"}`` for the locomotion
        tasks (the keyword names ``modify_reward_online`` reads), else ``{}``.
    """
    locomotion = ("halfcheetah", "hopper", "walker2d")
    stats = {}
    if any(task in env_name for task in locomotion):
        min_ret, max_ret = return_reward_range(dataset, max_episode_steps)
        dataset["rewards"] /= max_ret - min_ret
        dataset["rewards"] *= max_episode_steps
        stats = dict(
            max_ret=max_ret,
            min_ret=min_ret,
            max_episode_steps=max_episode_steps,
        )
    dataset["rewards"] = dataset["rewards"] * reward_scale + reward_bias
    return stats


def modify_reward_online(
    reward: float,
    env_name: str,
    reward_scale: float = 1.0,
    reward_bias: float = 0.0,
    **kwargs,
) -> float:
    """Apply the offline reward transform to a single online reward.

    For locomotion tasks ``kwargs`` must provide "max_ret", "min_ret" and
    "max_episode_steps".
    """
    locomotion = ("halfcheetah", "hopper", "walker2d")
    if any(task in env_name for task in locomotion):
        reward /= kwargs["max_ret"] - kwargs["min_ret"]
        reward *= kwargs["max_episode_steps"]
    return reward * reward_scale + reward_bias


def extend_and_repeat(tensor: torch.Tensor, dim: int, repeat: int) -> torch.Tensor:
    """Insert a new axis at ``dim`` and copy the tensor ``repeat`` times along it."""
    expanded = torch.unsqueeze(tensor, dim)
    return expanded.repeat_interleave(repeat, dim=dim)


def init_module_weights(module: torch.nn.Module, orthogonal_init: bool = False):
    """Re-initialise an ``nn.Linear`` layer; other module types are left untouched.

    Orthogonal mode uses gain sqrt(2) and zero bias. Otherwise the weight gets
    Xavier-uniform with a tiny gain (1e-2) and the bias is left unchanged.
    """
    if not isinstance(module, nn.Linear):
        return
    if not orthogonal_init:
        nn.init.xavier_uniform_(module.weight, gain=1e-2)
        return
    nn.init.orthogonal_(module.weight, gain=np.sqrt(2))
    nn.init.constant_(module.bias, 0.0)


class ReparameterizedTanhGaussian(nn.Module):
    """Diagonal Gaussian, optionally squashed by tanh, with reparameterised sampling.

    ``log_std`` is clamped to ``[log_std_min, log_std_max]`` before use.
    """

    def __init__(
        self,
        log_std_min: float = -20.0,
        log_std_max: float = 2.0,
        no_tanh: bool = False,
    ):
        super().__init__()
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max
        self.no_tanh = no_tanh

    def _distribution(self, mean: torch.Tensor, log_std: torch.Tensor):
        """Build the (optionally tanh-squashed) Gaussian for a clamped ``log_std``."""
        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        std = torch.exp(log_std)
        base = Normal(mean, std)
        if self.no_tanh:
            return base
        return TransformedDistribution(base, TanhTransform(cache_size=1))

    def log_prob(
        self, mean: torch.Tensor, log_std: torch.Tensor, sample: torch.Tensor
    ) -> torch.Tensor:
        """Log-density of ``sample``, summed over the last (action) dimension."""
        action_distribution = self._distribution(mean, log_std)
        return torch.sum(action_distribution.log_prob(sample), dim=-1)

    def forward(
        self,
        mean: torch.Tensor,
        log_std: torch.Tensor,
        deterministic: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(action, log_prob)``.

        The action is a reparameterised sample, or the deterministic action
        when ``deterministic`` is set. That deterministic action is
        ``tanh(mean)``, or ``mean`` itself when ``no_tanh``.
        """
        action_distribution = self._distribution(mean, log_std)

        if deterministic:
            # Without tanh the action space is unbounded, so squashing the
            # mean here would be inconsistent with the sampled path.
            action_sample = mean if self.no_tanh else torch.tanh(mean)
        else:
            action_sample = action_distribution.rsample()

        log_prob = torch.sum(action_distribution.log_prob(action_sample), dim=-1)

        return action_sample, log_prob


class TanhGaussianPolicy(nn.Module):
    """Gaussian actor with tanh squashing (unless ``no_tanh``) scaled by ``max_action``.

    A 3-hidden-layer MLP maps observations to a mean and a raw log-std. The
    log-std is transformed as ``log_std_multiplier * raw + log_std_offset``
    (both learnable scalars).
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        max_action: float,
        log_std_multiplier: float = 1.0,
        log_std_offset: float = -1.0,
        orthogonal_init: bool = False,
        no_tanh: bool = False,
    ):
        super().__init__()
        self.observation_dim = state_dim
        self.action_dim = action_dim
        self.max_action = max_action
        self.orthogonal_init = orthogonal_init
        self.no_tanh = no_tanh

        self.base_network = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 2 * action_dim),
        )

        if orthogonal_init:
            self.base_network.apply(lambda m: init_module_weights(m, True))
        else:
            init_module_weights(self.base_network[-1], False)

        self.log_std_multiplier = Scalar(log_std_multiplier)
        self.log_std_offset = Scalar(log_std_offset)
        self.tanh_gaussian = ReparameterizedTanhGaussian(no_tanh=no_tanh)

    def _mean_and_log_std(
        self, observations: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the MLP and split its output into mean and transformed log-std."""
        base_network_output = self.base_network(observations)
        mean, log_std = torch.split(base_network_output, self.action_dim, dim=-1)
        log_std = self.log_std_multiplier() * log_std + self.log_std_offset()
        return mean, log_std

    def log_prob(
        self, observations: torch.Tensor, actions: torch.Tensor
    ) -> torch.Tensor:
        """Log-likelihood of the given ``actions`` under the current policy.

        ``actions`` may be (batch, action_dim) or (batch, n, action_dim); in
        the 3-D case ``observations`` are repeated along dim 1 to match.
        Actions are divided by ``max_action`` to undo the output scaling of
        ``forward``. The constant log-Jacobian of that scaling is omitted.
        """
        if actions.ndim == 3:
            observations = extend_and_repeat(observations, 1, actions.shape[1])
        mean, log_std = self._mean_and_log_std(observations)
        unscaled = actions / self.max_action
        if not self.no_tanh:
            # atanh(+-1) is infinite; dataset actions can lie exactly on the
            # bounds, so keep them strictly inside (-1, 1).
            eps = 1e-6
            unscaled = torch.clamp(unscaled, -1.0 + eps, 1.0 - eps)
        return self.tanh_gaussian.log_prob(mean, log_std, unscaled)

    def forward(
        self,
        observations: torch.Tensor,
        deterministic: bool = False,
        repeat: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(max_action * action, log_prob)`` for each observation.

        When ``repeat`` is given, every observation is repeated ``repeat``
        times along a new dim 1 before sampling.
        """
        if repeat is not None:
            observations = extend_and_repeat(observations, 1, repeat)
        mean, log_std = self._mean_and_log_std(observations)
        actions, log_probs = self.tanh_gaussian(mean, log_std, deterministic)
        return self.max_action * actions, log_probs

    @torch.no_grad()
    def act(self, state: np.ndarray, device: str = "cpu"):
        """Select one action for a single numpy state.

        The action is deterministic in eval mode and sampled in train mode.
        """
        state = torch.tensor(state.reshape(1, -1), device=device, dtype=torch.float32)
        actions, _ = self(state, not self.training)
        return actions.cpu().data.numpy().flatten()


class FullyConnectedQFunction(nn.Module):
    """MLP critic Q(s, a) with 256-unit ReLU hidden layers and a scalar output.

    ``forward`` accepts either one action per state, (batch, action_dim), or
    several, (batch, n, action_dim) with 2-D observations. It returns Q-values
    of shape (batch,) or (batch, n) respectively.
    """

    def __init__(
        self,
        observation_dim: int,
        action_dim: int,
        orthogonal_init: bool = False,
        n_hidden_layers: int = 2,
    ):
        super().__init__()
        self.observation_dim = observation_dim
        self.action_dim = action_dim
        self.orthogonal_init = orthogonal_init

        modules = [nn.Linear(observation_dim + action_dim, 256), nn.ReLU()]
        for _ in range(n_hidden_layers - 1):
            modules += [nn.Linear(256, 256), nn.ReLU()]
        modules.append(nn.Linear(256, 1))
        self.network = nn.Sequential(*modules)

        if not orthogonal_init:
            init_module_weights(self.network[-1], False)
        else:
            self.network.apply(lambda m: init_module_weights(m, True))

    def forward(self, observations: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        batch_size = observations.shape[0]
        flatten_actions = actions.ndim == 3 and observations.ndim == 2
        if flatten_actions:
            # Pair every observation with each of its n candidate actions.
            n_candidates = actions.shape[1]
            observations = extend_and_repeat(observations, 1, n_candidates).reshape(
                -1, observations.shape[-1]
            )
            actions = actions.reshape(-1, actions.shape[-1])
        q_values = self.network(torch.cat([observations, actions], dim=-1)).squeeze(-1)
        return q_values.reshape(batch_size, -1) if flatten_actions else q_values


class Scalar(nn.Module):
    """One learnable float32 value; calling the module returns the parameter."""

    def __init__(self, init_value: float):
        super().__init__()
        initial = torch.tensor(init_value, dtype=torch.float32)
        self.constant = nn.Parameter(initial)

    def forward(self) -> nn.Parameter:
        return self.constant


class CalQL:
    def __init__(
        self,
        critic_1,
        critic_1_optimizer,
        critic_2,
        critic_2_optimizer,
        actor,
        actor_optimizer,
        target_entropy: float,
        discount: float = 0.99,
        alpha_multiplier: float = 1.0,
        use_automatic_entropy_tuning: bool = True,
        backup_entropy: bool = False,
        policy_lr: float = 3e-4,
        qf_lr: float = 3e-4,
        soft_target_update_rate: float = 5e-3,
        bc_steps: int = 100000,
        target_update_period: int = 1,
        cql_n_actions: int = 10,
        cql_importance_sample: bool = True,
        cql_lagrange: bool = False,
        cql_target_action_gap: float = -1.0,
        cql_temp: float = 1.0,
        cql_alpha: float = 5.0,
        cql_max_target_backup: bool = False,
        cql_clip_diff_min: float = -np.inf,
        cql_clip_diff_max: float = np.inf,
        device: str = "cpu",
    ):
        """Store hyper-parameters and networks, and build the auxiliary parameters.

        Target critics are deep copies of the online critics on ``device``.
        The entropy temperature ``log_alpha`` (only with automatic tuning) and
        the CQL Lagrange multiplier ``log_alpha_prime`` are created on
        ``device`` before their Adam optimizers are built.
        """
        super().__init__()

        self.discount = discount
        self.target_entropy = target_entropy
        self.alpha_multiplier = alpha_multiplier
        self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
        self.backup_entropy = backup_entropy
        self.policy_lr = policy_lr
        self.qf_lr = qf_lr
        self.soft_target_update_rate = soft_target_update_rate
        self.bc_steps = bc_steps
        self.target_update_period = target_update_period
        self.cql_n_actions = cql_n_actions
        self.cql_importance_sample = cql_importance_sample
        self.cql_lagrange = cql_lagrange
        self.cql_target_action_gap = cql_target_action_gap
        self.cql_temp = cql_temp
        self.cql_alpha = cql_alpha
        self.cql_max_target_backup = cql_max_target_backup
        self.cql_clip_diff_min = cql_clip_diff_min
        self.cql_clip_diff_max = cql_clip_diff_max
        self._device = device

        self.total_it = 0

        self.critic_1 = critic_1
        self.critic_2 = critic_2

        self.target_critic_1 = deepcopy(self.critic_1).to(device)
        self.target_critic_2 = deepcopy(self.critic_2).to(device)

        self.actor = actor

        self.actor_optimizer = actor_optimizer
        self.critic_1_optimizer = critic_1_optimizer
        self.critic_2_optimizer = critic_2_optimizer

        # Temperature parameters live on the same device as the networks, so
        # their losses do not mix CPU and accelerator tensors. They are moved
        # before the optimizers are constructed, as torch.optim requires.
        if self.use_automatic_entropy_tuning:
            self.log_alpha = Scalar(0.0).to(self._device)
            self.alpha_optimizer = torch.optim.Adam(
                self.log_alpha.parameters(),
                lr=self.policy_lr,
            )
        else:
            self.log_alpha = None

        self.log_alpha_prime = Scalar(1.0).to(self._device)
        self.alpha_prime_optimizer = torch.optim.Adam(
            self.log_alpha_prime.parameters(),
            lr=self.qf_lr,
        )

        # Cal-QL lower-bound clipping is off until switch_calibration() is called.
        self._calibration_enabled = False

    def update_target_network(self, soft_target_update_rate: float):
        """Polyak-average each online critic into its target copy."""
        pairs = (
            (self.target_critic_1, self.critic_1),
            (self.target_critic_2, self.critic_2),
        )
        for target, online in pairs:
            soft_update(target, online, soft_target_update_rate)

    def switch_calibration(self):
        """Toggle Cal-QL's MC-return lower-bound clipping of CQL Q-values in ``_q_loss``."""
        currently_enabled = self._calibration_enabled
        self._calibration_enabled = not currently_enabled

    def _alpha_and_alpha_loss(self, observations: torch.Tensor, log_pi: torch.Tensor):
        """Return ``(alpha, alpha_loss)`` for the entropy temperature.

        With automatic tuning, alpha is ``exp(log_alpha) * alpha_multiplier``
        and the loss pushes the policy entropy towards ``target_entropy``.
        Otherwise alpha is the constant multiplier and the loss is zero. Both
        constant tensors are created like ``observations``.
        """
        if not self.use_automatic_entropy_tuning:
            alpha = observations.new_tensor(self.alpha_multiplier)
            alpha_loss = observations.new_tensor(0.0)
            return alpha, alpha_loss
        entropy_gap = (log_pi + self.target_entropy).detach()
        alpha_loss = -(self.log_alpha() * entropy_gap).mean()
        alpha = self.log_alpha().exp() * self.alpha_multiplier
        return alpha, alpha_loss

    def _policy_loss(
        self,
        observations: torch.Tensor,
        actions: torch.Tensor,
        new_actions: torch.Tensor,
        alpha: torch.Tensor,
        log_pi: torch.Tensor,
    ) -> torch.Tensor:
        """Actor loss: ``mean(alpha * log_pi - target)``.

        While ``total_it <= bc_steps`` the target is the actor's log-likelihood
        of the dataset ``actions`` (behaviour cloning). Afterwards it is the
        minimum of the two critics evaluated at ``new_actions``.
        """
        in_bc_phase = self.total_it <= self.bc_steps
        if in_bc_phase:
            target = self.actor.log_prob(observations, actions)
        else:
            target = torch.min(
                self.critic_1(observations, new_actions),
                self.critic_2(observations, new_actions),
            )
        return (alpha * log_pi - target).mean()

    def _q_loss(
        self,
        observations: torch.Tensor,
        actions: torch.Tensor,
        next_observations: torch.Tensor,
        rewards: torch.Tensor,
        dones: torch.Tensor,
        mc_returns: torch.Tensor,
        alpha: torch.Tensor,
        log_dict: Dict,
        flag: bool
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        q1_predicted = self.critic_1(observations, actions)
        q2_predicted = self.critic_2(observations, actions)

        if self.cql_max_target_backup:
            new_next_actions, next_log_pi = self.actor(
                next_observations, repeat=self.cql_n_actions
            )
            target_q_values, max_target_indices = torch.max(
                torch.min(
                    self.target_critic_1(next_observations, new_next_actions),
                    self.target_critic_2(next_observations, new_next_actions),
                ),
                dim=-1,
            )
            next_log_pi = torch.gather(
                next_log_pi, -1, max_target_indices.unsqueeze(-1)
            ).squeeze(-1)
        else:
            new_next_actions, next_log_pi = self.actor(next_observations)
            target_q_values = torch.min(
                self.target_critic_1(next_observations, new_next_actions),
                self.target_critic_2(next_observations, new_next_actions),
            )

        if self.backup_entropy:
            target_q_values = target_q_values - alpha * next_log_pi

        target_q_values = target_q_values.unsqueeze(-1)
        td_target = rewards + (1.0 - dones) * self.discount * target_q_values.detach()
        td_target = td_target.squeeze(-1)
        qf1_loss = F.mse_loss(q1_predicted, td_target.detach())
        qf2_loss = F.mse_loss(q2_predicted, td_target.detach())

        if flag:
            # CQL
            batch_size = actions.shape[0]
            action_dim = actions.shape[-1]
            cql_random_actions = actions.new_empty(
                (batch_size, self.cql_n_actions, action_dim), requires_grad=False
            ).uniform_(-1, 1)
            cql_current_actions, cql_current_log_pis = self.actor(
                observations, repeat=self.cql_n_actions
            )
            cql_next_actions, cql_next_log_pis = self.actor(
                next_observations, repeat=self.cql_n_actions
            )
            cql_current_actions, cql_current_log_pis = (
                cql_current_actions.detach(),
                cql_current_log_pis.detach(),
            )
            cql_next_actions, cql_next_log_pis = (
                cql_next_actions.detach(),
                cql_next_log_pis.detach(),
            )

            cql_q1_rand = self.critic_1(observations, cql_random_actions)
            cql_q2_rand = self.critic_2(observations, cql_random_actions)
            cql_q1_current_actions = self.critic_1(observations, cql_current_actions)
            cql_q2_current_actions = self.critic_2(observations, cql_current_actions)
            cql_q1_next_actions = self.critic_1(observations, cql_next_actions)
            cql_q2_next_actions = self.critic_2(observations, cql_next_actions)

            # Calibration
            lower_bounds = mc_returns.reshape(-1, 1).repeat(
                1, cql_q1_current_actions.shape[1]
            )

            num_vals = torch.sum(lower_bounds == lower_bounds)
            bound_rate_cql_q1_current_actions = (
                torch.sum(cql_q1_current_actions < lower_bounds) / num_vals
            )
            bound_rate_cql_q2_current_actions = (
                torch.sum(cql_q2_current_actions < lower_bounds) / num_vals
            )
            bound_rate_cql_q1_next_actions = (
                torch.sum(cql_q1_next_actions < lower_bounds) / num_vals
            )
            bound_rate_cql_q2_next_actions = (
                torch.sum(cql_q2_next_actions < lower_bounds) / num_vals
            )

            """ Cal-QL: bound Q-values with MC return-to-go """
            if self._calibration_enabled:
                cql_q1_current_actions = torch.maximum(cql_q1_current_actions, lower_bounds)
                cql_q2_current_actions = torch.maximum(cql_q2_current_actions, lower_bounds)
                cql_q1_next_actions = torch.maximum(cql_q1_next_actions, lower_bounds)
                cql_q2_next_actions = torch.maximum(cql_q2_next_actions, lower_bounds)

            cql_cat_q1 = torch.cat(
                [
                    cql_q1_rand,
                    torch.unsqueeze(q1_predicted, 1),
                    cql_q1_next_actions,
                    cql_q1_current_actions,
                ],
                dim=1,
            )
            cql_cat_q2 = torch.cat(
                [
                    cql_q2_rand,
                    torch.unsqueeze(q2_predicted, 1),
                    cql_q2_next_actions,
                    cql_q2_current_actions,
                ],
                dim=1,
            )
            cql_std_q1 = torch.std(cql_cat_q1, dim=1)
            cql_std_q2 = torch.std(cql_cat_q2, dim=1)

            if self.cql_importance_sample:
                random_density = np.log(0.5**action_dim)
                cql_cat_q1 = torch.cat(
                    [
                        cql_q1_rand - random_density,
                        cql_q1_next_actions - cql_next_log_pis.detach(),
                        cql_q1_current_actions - cql_current_log_pis.detach(),
                    ],
                    dim=1,
                )
                cql_cat_q2 = torch.cat(
                    [
                        cql_q2_rand - random_density,
                        cql_q2_next_actions - cql_next_log_pis.detach(),
                        cql_q2_current_actions - cql_current_log_pis.detach(),
                    ],
                    dim=1,
                )

            cql_qf1_ood = torch.logsumexp(cql_cat_q1 / self.cql_temp, dim=1) * self.cql_temp
            cql_qf2_ood = torch.logsumexp(cql_cat_q2 / self.cql_temp, dim=1) * self.cql_temp

            """Subtract the log likelihood of data"""
            cql_qf1_diff = torch.clamp(
                cql_qf1_ood - q1_predicted,
                self.cql_clip_diff_min,
                self.cql_clip_diff_max,
            ).mean()
            cql_qf2_diff = torch.clamp(
                cql_qf2_ood - q2_predicted,
                self.cql_clip_diff_min,
                self.cql_clip_diff_max,
            ).mean()

            if self.cql_lagrange:
                alpha_prime = torch.clamp(
                    torch.exp(self.log_alpha_prime()), min=0.0, max=1000000.0
                )
                cql_min_qf1_loss = (
                    alpha_prime
                    * self.cql_alpha
                    * (cql_qf1_diff - self.cql_target_action_gap)
                )
                cql_min_qf2_loss = (
                    alpha_prime
                    * self.cql_alpha
                    * (cql_qf2_diff - self.cql_target_action_gap)
                )

                self.alpha_prime_optimizer.zero_grad()
                alpha_prime_loss = (-cql_min_qf1_loss - cql_min_qf2_loss) * 0.5
                alpha_prime_loss.backward(retain_graph=True)
                self.alpha_prime_optimizer.step()
            else:
                cql_min_qf1_loss = cql_qf1_diff * self.cql_alpha
                cql_min_qf2_loss = cql_qf2_diff * self.cql_alpha
                alpha_prime_loss = observations.new_tensor(0.0)
                alpha_prime = observations.new_tensor(0.0)

            qf_loss = qf1_loss + qf2_loss + cql_min_qf1_loss + cql_min_qf2_loss

            log_dict.update(
                dict(
                    qf1_loss=qf1_loss.item(),
                    qf2_loss=qf2_loss.item(),
                    alpha=alpha.item(),
                    average_qf1=q1_predicted.mean().item(),
                    average_qf2=q2_predicted.mean().item(),
                    average_target_q=target_q_values.mean().item(),
                )
            )

            log_dict.update(
                dict(
                    cql_std_q1=cql_std_q1.mean().item(),
                    cql_std_q2=cql_std_q2.mean().item(),
                    cql_q1_rand=cql_q1_rand.mean().item(),
                    cql_q2_rand=cql_q2_rand.mean().item(),
                    cql_min_qf1_loss=cql_min_qf1_loss.mean().item(),
                    cql_min_qf2_loss=cql_min_qf2_loss.mean().item(),
                    cql_qf1_diff=cql_qf1_diff.mean().item(),
                    cql_qf2_diff=cql_qf2_diff.mean().item(),
                    cql_q1_current_actions=cql_q1_current_actions.mean().item(),
                    cql_q2_current_actions=cql_q2_current_actions.mean().item(),
                    cql_q1_next_actions=cql_q1_next_actions.mean().item(),
                    cql_q2_next_actions=cql_q2_next_actions.mean().item(),
                    alpha_prime_loss=alpha_prime_loss.item(),
                    alpha_prime=alpha_prime.item(),
                    bound_rate_cql_q1_current_actions=bound_rate_cql_q1_current_actions.item(),  # noqa
                    bound_rate_cql_q2_current_actions=bound_rate_cql_q2_current_actions.item(),  # noqa
                    bound_rate_cql_q1_next_actions=bound_rate_cql_q1_next_actions.item(),
                    bound_rate_cql_q2_next_actions=bound_rate_cql_q2_next_actions.item(),
                )
            )
            return qf_loss, alpha_prime, alpha_prime_loss

        else:
            qf_loss = qf1_loss + qf2_loss
            log_dict.update(
                dict(
                    qf1_loss=qf1_loss.item(),
                    qf2_loss=qf2_loss.item(),
                    alpha=alpha.item(),
                    average_qf1=q1_predicted.mean().item(),
                    average_qf2=q2_predicted.mean().item(),
                    average_target_q=target_q_values.mean().item(),
                )
            )

            return qf_loss

    def train(self, batch: "TensorBatch", offline_size: int = 128) -> Dict[str, float]:
        """Run one Cal-QL update on a mixed batch (offline rows first, then online rows).

        The first ``offline_size`` transitions are passed to ``_q_loss`` with the
        CQL flag set to ``True``; the remaining transitions are passed with the
        flag set to ``False``. The two critic losses are averaged before the critic step.

        Args:
            batch: ``[observations, actions, rewards, next_observations, dones,
                mc_returns]``, every field indexed by transition along dim 0.
            offline_size: number of leading rows treated as the offline part.
                Defaults to 128, the split that used to be hard-coded here.

        Returns:
            Dictionary of scalar diagnostics. Keys written by both ``_q_loss``
            calls (e.g. ``qf1_loss``) end up holding the second (online) call's values.
        """
        (
            observations,
            actions,
            rewards,
            next_observations,
            dones,
            mc_returns,
        ) = batch
        self.total_it += 1

        new_actions, log_pi = self.actor(observations)

        alpha, alpha_loss = self._alpha_and_alpha_loss(observations, log_pi)

        # Policy loss
        policy_loss = self._policy_loss(
            observations, actions, new_actions, alpha, log_pi
        )

        log_dict = dict(
            log_pi=log_pi.mean().item(),
            policy_loss=policy_loss.item(),
            alpha_loss=alpha_loss.item(),
            alpha=alpha.item(),
        )

        # Q-function loss: CQL-flagged on the offline rows, unflagged on the online rows.
        off = slice(None, offline_size)
        on = slice(offline_size, None)

        qf_loss_off, alpha_prime, alpha_prime_loss = self._q_loss(
            observations[off],
            actions[off],
            next_observations[off],
            rewards[off],
            dones[off],
            mc_returns[off],
            alpha,
            log_dict,
            True,
        )

        # mc_returns is sliced like every other field; it used to be passed whole.
        qf_loss_on = self._q_loss(
            observations[on],
            actions[on],
            next_observations[on],
            rewards[on],
            dones[on],
            mc_returns[on],
            alpha,
            log_dict,
            False,
        )

        if self.use_automatic_entropy_tuning:
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()

        self.actor_optimizer.zero_grad()
        policy_loss.backward()
        self.actor_optimizer.step()

        # zero_grad also clears critic gradients left by the policy backward and by
        # any backward performed inside _q_loss.
        self.critic_1_optimizer.zero_grad()
        self.critic_2_optimizer.zero_grad()
        qf_loss = (qf_loss_off + qf_loss_on) / 2.
        # Last backward of the step: no later pass reuses this graph.
        qf_loss.backward()
        self.critic_1_optimizer.step()
        self.critic_2_optimizer.step()

        if self.total_it % self.target_update_period == 0:
            self.update_target_network(self.soft_target_update_rate)

        return log_dict

    def state_dict(self) -> Dict[str, Any]:
        """Bundle everything needed to resume training into one checkpoint dict.

        Networks and optimizers contribute their own ``state_dict()``. The
        entropy temperature (``log_alpha``) and the CQL temperature
        (``log_alpha_prime``) objects are stored as-is, together with the
        update counter ``total_it``.
        """
        entries = [
            ("actor", self.actor.state_dict()),
            ("critic1", self.critic_1.state_dict()),
            ("critic2", self.critic_2.state_dict()),
            ("critic1_target", self.target_critic_1.state_dict()),
            ("critic2_target", self.target_critic_2.state_dict()),
            ("critic_1_optimizer", self.critic_1_optimizer.state_dict()),
            ("critic_2_optimizer", self.critic_2_optimizer.state_dict()),
            ("actor_optim", self.actor_optimizer.state_dict()),
            # Temperature objects themselves (not their state dicts).
            ("sac_log_alpha", self.log_alpha),
            ("sac_log_alpha_optim", self.alpha_optimizer.state_dict()),
            ("cql_log_alpha", self.log_alpha_prime),
            ("cql_log_alpha_optim", self.alpha_prime_optimizer.state_dict()),
            ("total_it", self.total_it),
        ]
        return dict(entries)

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Restore a checkpoint (as produced by ``state_dict``) into this trainer.

        Networks and optimizers are restored via their own ``load_state_dict``.
        The entropy (``sac_log_alpha``) and CQL (``cql_log_alpha``) temperature
        entries are copied *into* the existing ``log_alpha`` /
        ``log_alpha_prime`` modules rather than replacing them. Assigning the
        unpickled objects would create new parameters. Assuming the temperature
        optimizers were built over the original parameters (the constructor is
        not visible here), those optimizers would then keep updating parameters
        that are no longer used.

        Both the module objects that ``state_dict`` stores and plain state dicts
        are accepted for the temperature entries.
        """

        def _restore_scalar(current: Any, saved: Any) -> Any:
            # Nothing to copy into (e.g. the attribute is not a module): keep the
            # previous assignment semantics.
            if not isinstance(current, nn.Module):
                return saved
            if isinstance(saved, nn.Module):
                saved = saved.state_dict()
            current.load_state_dict(saved)  # in-place copy; keeps optimizer links
            return current

        self.actor.load_state_dict(state_dict=state_dict["actor"])
        self.critic_1.load_state_dict(state_dict=state_dict["critic1"])
        self.critic_2.load_state_dict(state_dict=state_dict["critic2"])

        self.target_critic_1.load_state_dict(state_dict=state_dict["critic1_target"])
        self.target_critic_2.load_state_dict(state_dict=state_dict["critic2_target"])

        self.critic_1_optimizer.load_state_dict(
            state_dict=state_dict["critic_1_optimizer"]
        )
        self.critic_2_optimizer.load_state_dict(
            state_dict=state_dict["critic_2_optimizer"]
        )
        self.actor_optimizer.load_state_dict(state_dict=state_dict["actor_optim"])

        self.log_alpha = _restore_scalar(self.log_alpha, state_dict["sac_log_alpha"])
        self.alpha_optimizer.load_state_dict(
            state_dict=state_dict["sac_log_alpha_optim"]
        )

        self.log_alpha_prime = _restore_scalar(
            self.log_alpha_prime, state_dict["cql_log_alpha"]
        )
        self.alpha_prime_optimizer.load_state_dict(
            state_dict=state_dict["cql_log_alpha_optim"]
        )
        self.total_it = state_dict["total_it"]


# @pyrallis.wrap()
def train(args):
    config=TrainConfig()
    config.env = args.env 
    config.seed = args.seed
    config.load_model = args.load_model

    env = gym.make(config.env)
    eval_env = gym.make(config.env)

    is_env_with_goal = config.env.startswith(ENVS_WITH_GOAL)
    batch_size_offline = int(config.batch_size * config.mixing_ratio)
    batch_size_online = config.batch_size - batch_size_offline
    max_steps = env._max_episode_steps

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    dataset = d4rl.qlearning_dataset(env)

    reward_mod_dict = {}
    if config.normalize_reward:
        reward_mod_dict = modify_reward(
            dataset,
            config.env,
            reward_scale=config.reward_scale,
            reward_bias=config.reward_bias,
        )
    mc_returns = get_return_to_go(dataset, env, config)
    dataset["mc_returns"] = np.array(mc_returns)
    assert len(dataset["mc_returns"]) == len(dataset["rewards"])

    if config.normalize:
        state_mean, state_std = compute_mean_std(dataset["observations"], eps=1e-3)
    else:
        state_mean, state_std = 0, 1

    dataset["observations"] = normalize_states(
        dataset["observations"], state_mean, state_std
    )
    dataset["next_observations"] = normalize_states(
        dataset["next_observations"], state_mean, state_std
    )
    env = wrap_env(env, state_mean=state_mean, state_std=state_std)
    eval_env = wrap_env(eval_env, state_mean=state_mean, state_std=state_std)
    offline_buffer = ReplayBuffer(
        state_dim,
        action_dim,
        config.buffer_size,
        config.device,
    )
    online_buffer = ReplayBuffer(
        state_dim,
        action_dim,
        config.buffer_size,
        config.device,
    )
    offline_buffer.load_d4rl_dataset(dataset)

    max_action = float(env.action_space.high[0])

    if config.checkpoints_path is not None:
        print(f"Checkpoints path: {config.checkpoints_path}")
        os.makedirs(config.checkpoints_path, exist_ok=True)
        with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f:
            pyrallis.dump(config, f)

    # Set seeds
    seed = config.seed
    set_seed(seed, env)
    set_env_seed(eval_env, config.eval_seed)

    critic_1 = FullyConnectedQFunction(
        state_dim,
        action_dim,
        config.orthogonal_init,
        config.q_n_hidden_layers,
    ).to(config.device)
    critic_2 = FullyConnectedQFunction(
        state_dim,
        action_dim,
        config.orthogonal_init,
        config.q_n_hidden_layers,
    ).to(config.device)
    critic_1_optimizer = torch.optim.Adam(list(critic_1.parameters()), config.qf_lr)
    critic_2_optimizer = torch.optim.Adam(list(critic_2.parameters()), config.qf_lr)

    actor = TanhGaussianPolicy(
        state_dim,
        action_dim,
        max_action,
        orthogonal_init=config.orthogonal_init,
    ).to(config.device)
    actor_optimizer = torch.optim.Adam(actor.parameters(), config.policy_lr)

    kwargs = {
        "critic_1": critic_1,
        "critic_2": critic_2,
        "critic_1_optimizer": critic_1_optimizer,
        "critic_2_optimizer": critic_2_optimizer,
        "actor": actor,
        "actor_optimizer": actor_optimizer,
        "discount": config.discount,
        "soft_target_update_rate": config.soft_target_update_rate,
        "device": config.device,
        # CQL
        "target_entropy": -np.prod(env.action_space.shape).item(),
        "alpha_multiplier": config.alpha_multiplier,
        "use_automatic_entropy_tuning": config.use_automatic_entropy_tuning,
        "backup_entropy": config.backup_entropy,
        "policy_lr": config.policy_lr,
        "qf_lr": config.qf_lr,
        "bc_steps": config.bc_steps,
        "target_update_period": config.target_update_period,
        "cql_n_actions": config.cql_n_actions,
        "cql_importance_sample": config.cql_importance_sample,
        "cql_lagrange": config.cql_lagrange,
        "cql_target_action_gap": config.cql_target_action_gap,
        "cql_temp": config.cql_temp,
        "cql_alpha": config.cql_alpha,
        "cql_max_target_backup": config.cql_max_target_backup,
        "cql_clip_diff_min": config.cql_clip_diff_min,
        "cql_clip_diff_max": config.cql_clip_diff_max,
    }

    print("---------------------------------------")
    print(f"Training Cal-QL, Env: {config.env}, Seed: {seed}")
    print("---------------------------------------")

    # Initialize actor
    trainer = CalQL(**kwargs)

    if config.load_model != "":
        #policy_file = Path(config.load_model)
        trainer.load_state_dict(torch.load(config.load_model))
        actor = trainer.actor
        print(f"========================================================Loaded model from {config.load_model}")

    # wandb_init(asdict(config))

    evaluations = []
    state, done = env.reset(), False
    episode_return = 0
    episode_step = 0
    goal_achieved = False

    eval_successes = []
    train_successes = []

    save_path = './logs_k_'+f"{args.k}/{args.log_name}/{config.env}_seed_{config.seed}/"
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    #========================
    import functools
    from torch.utils.data import DataLoader
    from dataset.dataset import D4RL_dataset

    marginal_prob_std_fn = functools.partial(marginal_prob_std, device=args.device)
    args.marginal_prob_std_fn = marginal_prob_std_fn
    score_model= ScoreNet(input_dim=state_dim+action_dim, output_dim=action_dim, marginal_prob_std=marginal_prob_std_fn, args=args).to(args.device)
    score_model.q[0].to(args.device)

    print("loading actor...")
    ckpt = torch.load(args.actor_load_path, map_location=args.device)
    score_model.load_state_dict(ckpt)
    score_model.q[0].guidance_scale = args.s
    dataset_dm = D4RL_dataset(args)
    data_loader = DataLoader(dataset_dm, batch_size=256, shuffle=True)
    dataset_dm.fake_actions = torch.Tensor(np.load('./models_rl/'+args.env+'/actions{}_raw.npy'.format(args.diffusion_steps)).astype(np.float32)).to(args.device)

    def datas_():
        while True:
            yield from data_loader
    datas = datas_()
    #========================

    print("Offline pretraining")
    for t in tqdm(range(int(config.online_iterations)), ncols=100):

        online_log = {}
        episode_step += 1

        action, _ = actor(
            torch.tensor(
                state.reshape(1, -1),
                device=config.device,
                dtype=torch.float32,
            )
        )
        action = action.cpu().data.numpy().flatten()
        next_state, reward, done, env_infos = env.step(action)

        if not goal_achieved:
            goal_achieved = is_goal_reached(reward, env_infos)
        episode_return += reward
        real_done = False  # Episode can timeout which is different from done
        if done and episode_step < max_steps:
            real_done = True

        if config.normalize_reward:
            reward = modify_reward_online(
                reward,
                config.env,
                reward_scale=config.reward_scale,
                reward_bias=config.reward_bias,
                **reward_mod_dict,
            )
        online_buffer.add_transition(state, action, reward, next_state, real_done)
        state = next_state

        if done:
            state, done = env.reset(), False
            # Valid only for envs with goal, e.g. AntMaze, Adroit
            if is_env_with_goal:
                train_successes.append(goal_achieved)
                online_log["train/regret"] = np.mean(1 - np.array(train_successes))
                online_log["train/is_success"] = float(goal_achieved)
            online_log["train/episode_return"] = episode_return
            normalized_return = eval_env.get_normalized_score(episode_return)
            online_log["train/d4rl_normalized_episode_return"] = (
                normalized_return * 100.0
            )
            online_log["train/episode_length"] = episode_step
            episode_return = 0
            episode_step = 0
            goal_achieved = False

        offline_batch = offline_buffer.sample(batch_size_offline)   
        online_batch  = online_buffer.sample(batch_size_online)     

        if t > 1000:
            dev = config.device

            # ----- distances (smaller = more offline-like) -----
            obs_off, act_off = offline_batch[0], offline_batch[1]
            obs_on,  act_on  = online_batch[0],  online_batch[1]
            with torch.no_grad():
                a_hat_off = score_model.select_actions_ours(obs_off)
                a_hat_on  = score_model.select_actions_ours(obs_on)
            adim = act_off.shape[1]
            d_off = 0.5 * ((a_hat_off - act_off)**2).sum(dim=1) / float(adim)   # [128]
            d_on  = 0.5 * ((a_hat_on  - act_on )**2).sum(dim=1) / float(adim)   # [128]

            # ----- fit Gaussians -----
            mu_off = d_off.mean().item()
            mu_on  = d_on.mean().item()
            s_off  = max(d_off.std(unbiased=False).item(), 1e-12)
            s_on   = max(d_on.std(unbiased=False).item(),  1e-12)

            # ----- rare case guard: offline has larger distances -> do nothing -----
            if mu_off >= mu_on:
                # final batch: offline first, then online (no exchange)
                batch_exchanged = [
                    torch.cat([off.to(dev), on.to(dev)], dim=0)
                    for off, on in zip(offline_batch, online_batch)
                ]
                log_dict = trainer.train(batch_exchanged)
            else:
                # ----- PDF intersection: solve N(mu_off,s_off^2) = N(mu_on,s_on^2) -----
                A = (1.0/(s_on*s_on)) - (1.0/(s_off*s_off))
                B = -2.0*(mu_on/(s_on*s_on) - mu_off/(s_off*s_off))
                C = (mu_on*mu_on)/(s_on*s_on) - (mu_off*mu_off)/(s_off*s_off) \
                    - 2.0*float(torch.log(torch.tensor(s_on/s_off)))

                if abs(A) < 1e-14:
                    tau = 0.5*(mu_off + mu_on)
                else:
                    disc = B*B - 4*A*C
                    if disc < 0:
                        tau = 0.5*(mu_off + mu_on)
                    else:
                        sqrt_disc = disc**0.5
                        x1 = (-B + sqrt_disc)/(2*A)
                        x2 = (-B - sqrt_disc)/(2*A)
                        lo, hi = (mu_off, mu_on) if mu_off <= mu_on else (mu_on, mu_off)
                        between = [x for x in (x1, x2) if lo <= x <= hi]
                        if between:
                            mid = 0.5*(mu_off + mu_on)
                            tau = between[0] if abs(between[0]-mid) <= abs(between[-1]-mid) else between[-1]
                        else:
                            mid = 0.5*(mu_off + mu_on)
                            tau = x1 if abs(x1-mid) <= abs(x2-mid) else x2

                # ----- exchange pools -----
                off2on_pool = torch.nonzero(d_off >= tau, as_tuple=False).squeeze(1)  # offline that look online
                on2off_pool = torch.nonzero(d_on  <  tau, as_tuple=False).squeeze(1)  # online  that look offline

                # balanced K, cap at 32 (≈25% of 128)
                K = min(off2on_pool.numel(), on2off_pool.numel(), args.k)
                # select most confident K from each pool
                if K > 0:
                    off_sel = off2on_pool[torch.topk(d_off[off2on_pool], K, largest=True ).indices]
                    on_sel  = on2off_pool[torch.topk(d_on[on2off_pool],   K, largest=False).indices]
                else:
                    off_sel = off2on_pool[:0]
                    on_sel  = on2off_pool[:0]

                # complements
                def complement_idx(n, idx):
                    mask = torch.ones(n, dtype=torch.bool, device=dev)
                    if idx.numel() > 0: mask[idx] = False
                    return torch.nonzero(mask, as_tuple=False).squeeze(1)

                off_rem = complement_idx(d_off.shape[0], off_sel)
                on_rem  = complement_idx(d_on.shape[0],  on_sel)

                # swapped per-field tensors (still 128 each), then offline first + online
                offline_new = [
                    torch.cat([off[off_rem], on[on_sel]], dim=0)
                    for off, on in zip(offline_batch, online_batch)
                ]
                online_new = [
                    torch.cat([on[on_rem],  off[off_sel]], dim=0)
                    for off, on in zip(offline_batch, online_batch)
                ]
                batch_exchanged = [
                    torch.cat([off_f, on_f], dim=0)
                    for off_f, on_f in zip(offline_new, online_new)
                ]
                log_dict = trainer.train(batch_exchanged)
        else:
            batch = [
                torch.vstack((off.to(config.device, non_blocking=True),
                            on.to(config.device, non_blocking=True)))
                for off, on in zip(offline_batch, online_batch)
            ]
            log_dict = trainer.train(batch)


        log_dict["online_iter"] = (t)
        log_dict.update(online_log)
        
        data = next(datas)
        s = data['s']
        fake_a = data['fake_a']
        B, N, _ = fake_a.shape
        s_flat = s.unsqueeze(1).expand(-1, N, -1).reshape(B * N, -1)
        a_flat = fake_a.reshape(B * N, -1)

        energy = torch.min(
            trainer.target_critic_1(s_flat, a_flat),
            trainer.target_critic_2(s_flat, a_flat)
        ).view(B, N).detach()
        loss_energy = score_model.q[0].update_qt(data, energy)
        log_dict["loss_energy"] = loss_energy.item()
        # wandb.log(log_dict, step=trainer.total_it)
        # Evaluate episode

        if t % config.eval_freq == 0:
            print(f"Time steps: {t + 1}")
            eval_scores, success_rate = eval_actor(
                eval_env,
                actor,
                device=config.device,
                n_episodes=config.n_episodes,
                seed=config.seed,
            )
            eval_score = eval_scores.mean()
            eval_log = {}
            normalized = eval_env.get_normalized_score(np.mean(eval_scores))
            # Valid only for envs with goal, e.g. AntMaze, Adroit
            if is_env_with_goal:
                eval_successes.append(success_rate)
                eval_log["eval/regret"] = np.mean(1 - np.array(train_successes))
                eval_log["eval/success_rate"] = success_rate
            
            normalized_eval_score = normalized * 100.0
            eval_log["eval/d4rl_normalized_score"] = normalized_eval_score
            evaluations.append(normalized_eval_score)
            print("---------------------------------------")
            print(
                f"Evaluation over {config.n_episodes} episodes: "
                f"{eval_score:.3f} , D4RL score: {normalized_eval_score:.3f}"
            )
            print("---------------------------------------")
            
            # wandb.log(eval_log, step=trainer.total_it)

            with open(os.path.join(save_path, "eval_results.csv"), "a") as f:
                f.write(f"{t}, {eval_score},{normalized_eval_score}\n")

        if t % 210000 ==0:
                torch.save(
                    trainer.state_dict(),
                    os.path.join(save_path, f"checkpoint_{t}.pth"),
                )

# Per-environment guidance scale for the diffusion actor.
_GUIDANCE_SCALES = {
    # Locomotion: medium
    'walker2d-medium-v2': 10.0,
    'halfcheetah-medium-v2': 10.0,
    'hopper-medium-v2': 8.0,
    # Locomotion: medium-expert
    'walker2d-medium-expert-v2': 5.0,
    'halfcheetah-medium-expert-v2': 3.0,
    'hopper-medium-expert-v2': 2.0,
    # Locomotion: medium-replay
    'walker2d-medium-replay-v2': 5.0,
    'halfcheetah-medium-replay-v2': 8.0,
    'hopper-medium-replay-v2': 3.0,
    # AntMaze: fixed-goal
    'antmaze-umaze-v2': 3.0,
    'antmaze-medium-play-v2': 4.0,
    'antmaze-large-play-v2': 3.0,  # Assuming from table
    # AntMaze: diverse
    'antmaze-umaze-diverse-v2': 1.0,
    'antmaze-medium-diverse-v2': 3.0,
    'antmaze-large-diverse-v2': 2.0,  # Assuming from table
}


def get_guidance_scale(env_name):
    """Return the guidance scale configured for ``env_name``, or ``None`` if unknown."""
    return _GUIDANCE_SCALES.get(env_name)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Experiment selection
    parser.add_argument("--env", type=str, default="hopper-medium-expert-v2")
    parser.add_argument("--seed", type=int, default=10)
    parser.add_argument("--log_name", type=str, default="ours")
    parser.add_argument("--load_model", type=str, default="")  # "" skips loading a Cal-QL checkpoint

    # Diffusion behaviour model
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--save_model", type=int, default=1)
    parser.add_argument("--alpha", type=float, default=3.0)  # beta in the paper; named alpha for legacy reasons
    parser.add_argument(
        "--actor_load_path",
        type=str,
        default="./models_rl/hopper-medium-expert-v2/behavior_ckpt600.pth",
    )
    parser.add_argument("--diffusion_steps", type=int, default=15)
    parser.add_argument("--M", type=int, default=16)  # number of support actions
    parser.add_argument("--s", type=float, default=None)  # guidance scale; None -> per-env table
    parser.add_argument("--method", type=str, default="CEP")
    parser.add_argument("--k", type=int, default=128)  # maximum exchange size per step

    args = parser.parse_args()

    # --M is always overridden: 32 support actions for AntMaze, 16 otherwise.
    args.M = 32 if "antmaze" in args.env else 16

    if args.s is None:
        args.s = get_guidance_scale(args.env)
    if args.s is None:
        raise ValueError(f"No guidance scale defined for {args.env}")

    # Retarget the (hopper-medium-expert) default checkpoint path to the chosen env.
    args.actor_load_path = args.actor_load_path.replace('hopper-medium-expert-v2', args.env)
    train(args)
