import copy
import os
import random
import uuid
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import d4rl
import gym
import numpy as np
import pyrallis
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
TensorBatch = List[torch.Tensor]  # [states, actions, rewards, next_states, dones], as returned by ReplayBuffer.sample
import os
import torchopt



@dataclass
class TrainConfig:
    """Configuration for Adaptive TD3+BC training (filled in by pyrallis)."""

    # Experiment
    device: str = "cuda"
    env: str = "halfcheetah-medium-v2"  # D4RL / Gym environment name
    # e.g. halfcheetah-medium-v2, halfcheetah-medium-replay-v2, halfcheetah-medium-expert-v2
    seed: int = 0  # Sets Gym, PyTorch and NumPy seeds
    eval_freq: int = int(5e3)  # Evaluate every `eval_freq` training iterations
    n_episodes: int = 10  # How many episodes run during evaluation
    max_timesteps: int = int(1e6)  # Number of training iterations (gradient steps)
    checkpoints_path: Optional[str] = None  # Save path; a per-run subdirectory is appended
    load_model: str = ""  # Model load file name, "" doesn't load
    # TD3
    buffer_size: int = 10000000  # Replay buffer capacity (must fit the whole dataset)
    batch_size: int = 256  # Batch size for all networks
    discount: float = 0.99  # Discount factor
    expl_noise: float = 0.1  # Std of Gaussian exploration noise
    tau: float = 0.005  # Target network update rate
    policy_noise: float = 0.2  # Noise added to target actor during critic update
    noise_clip: float = 0.5  # Range to clip target actor noise
    policy_freq: int = 2  # Frequency of delayed actor updates
    # TD3 + BC
    alpha: float = 2.5  # Initial value of the learned coefficient on the Q term in the actor loss
    normalize: bool = True  # Normalize states
    normalize_reward: bool = False  # Rescale rewards (see modify_reward)
    # Adaptive alpha
    alpha_freq: int = 10  # alpha is updated every `policy_freq * alpha_freq` iterations
    ema_alpha: float = 0.995  # Weight on the previous EMA of Q (closer to 1 = smoother)

    loss_function: int = 4  # 1: eq1, 2: eq1 + eq2, 3: eq1 + eq3, 4: eq1 + eq2 + eq3

    # Wandb logging
    project: str = "ASPC"
    group: str = "aspc"
    name: str = ""

    def __post_init__(self):
        # Make the run name unique: "<name>-<env>-<8 random hex chars>".
        self.name = f"{self.name}-{self.env}-{str(uuid.uuid4())[:8]}"
        if self.checkpoints_path is not None:
            self.checkpoints_path = os.path.join(self.checkpoints_path, self.name)
        # Fail fast: the trainer only knows meta-losses 1-4; any other value would
        # otherwise only surface as a crash at the first alpha update.
        if self.loss_function not in (1, 2, 3, 4):
            raise ValueError(
                f"loss_function must be one of 1, 2, 3, 4 (got {self.loss_function!r})"
            )


def soft_update(target: nn.Module, source: nn.Module, tau: float):
    """Polyak-average ``source`` into ``target`` in place.

    Every target parameter becomes ``(1 - tau) * target + tau * source``;
    ``source`` is left untouched.
    """
    keep = 1 - tau
    for dst, src in zip(target.parameters(), source.parameters()):
        blended = keep * dst.data + tau * src.data
        dst.data.copy_(blended)


def compute_mean_std(states: np.ndarray, eps: float) -> Tuple[np.ndarray, np.ndarray]:
    """Return the per-feature mean and std of ``states`` along axis 0.

    ``eps`` is added to the std so it can be used directly as a divisor.
    """
    return states.mean(axis=0), states.std(axis=0) + eps


def normalize_states(states: np.ndarray, mean: np.ndarray, std: np.ndarray):
    """Standardize ``states`` as ``(states - mean) / std`` (NumPy broadcasting)."""
    centered = states - mean
    return centered / std


def wrap_env(
    env: gym.Env,
    state_mean: Union[np.ndarray, float] = 0.0,
    state_std: Union[np.ndarray, float] = 1.0,
    reward_scale: float = 1.0,
) -> gym.Env:
    """Wrap ``env`` so observations are standardized and rewards optionally scaled.

    Observations become ``(obs - state_mean) / state_std``. A reward wrapper that
    multiplies rewards by ``reward_scale`` is added only when the scale is not 1.
    """

    def normalize_state(state):
        # No epsilon here: state_std is expected to already include it.
        return (state - state_mean) / state_std

    wrapped = gym.wrappers.TransformObservation(env, normalize_state)
    if reward_scale == 1.0:
        return wrapped

    def scale_reward(reward):
        # NB: rewards are *multiplied* by the scale.
        return reward_scale * reward

    return gym.wrappers.TransformReward(wrapped, scale_reward)


class ReplayBuffer:
    """Fixed-capacity transition store backed by pre-allocated float32 tensors.

    The buffer is filled once from a D4RL-style dict (``load_d4rl_dataset``)
    and then sampled uniformly at random; online insertion is not implemented.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        buffer_size: int,
        device: str = "cpu",
    ):
        self._buffer_size = buffer_size
        self._pointer = 0
        self._size = 0

        def _alloc(width: int) -> torch.Tensor:
            # One zero-filled (buffer_size, width) float32 column on `device`.
            return torch.zeros((buffer_size, width), dtype=torch.float32, device=device)

        self._states = _alloc(state_dim)
        self._actions = _alloc(action_dim)
        self._rewards = _alloc(1)
        self._next_states = _alloc(state_dim)
        self._dones = _alloc(1)
        self._device = device

    def _to_tensor(self, data: np.ndarray) -> torch.Tensor:
        """Copy ``data`` into a new float32 tensor on the buffer's device."""
        return torch.tensor(data, device=self._device, dtype=torch.float32)

    def load_d4rl_dataset(self, data: Dict[str, np.ndarray]):
        """Fill the empty buffer from a D4RL ``qlearning_dataset``-style dict.

        Reads the keys ``observations``, ``actions``, ``rewards``,
        ``next_observations`` and ``terminals``; rewards and terminals get a
        trailing axis so they are stored as ``(N, 1)`` columns.

        Raises:
            ValueError: if the buffer already holds data or is too small.
        """
        if self._size != 0:
            raise ValueError("Trying to load data into non-empty replay buffer")
        n_transitions = data["observations"].shape[0]
        if n_transitions > self._buffer_size:
            raise ValueError(
                "Replay buffer is smaller than the dataset you are trying to load!"
            )

        # (destination column, dataset key, needs trailing axis)
        layout = (
            (self._states, "observations", False),
            (self._actions, "actions", False),
            (self._rewards, "rewards", True),
            (self._next_states, "next_observations", False),
            (self._dones, "terminals", True),
        )
        for column, key, add_axis in layout:
            values = data[key][..., None] if add_axis else data[key]
            column[:n_transitions] = self._to_tensor(values)

        self._size += n_transitions
        self._pointer = min(self._size, n_transitions)

        print(f"Dataset size: {n_transitions}")

    def sample(self, batch_size: int) -> TensorBatch:
        """Draw ``batch_size`` transitions uniformly with replacement.

        Returns ``[states, actions, rewards, next_states, dones]``.
        """
        upper = min(self._size, self._pointer)
        idx = np.random.randint(0, upper, size=batch_size)
        columns = (self._states, self._actions, self._rewards, self._next_states, self._dones)
        return [column[idx] for column in columns]

    def add_transition(self):
        """Placeholder for online insertion (fine-tuning); not implemented."""
        raise NotImplementedError


def set_seed(
    seed: int, env: Optional[gym.Env] = None, deterministic_torch: bool = False
):
    """Seed Python, NumPy and PyTorch RNGs, and optionally ``env`` and its action space.

    Also writes ``PYTHONHASHSEED`` into ``os.environ``. That only affects
    interpreters started afterwards, not the running one. The function then
    toggles ``torch.use_deterministic_algorithms`` to ``deterministic_torch``.
    """
    if env is not None:
        env.seed(seed)
        env.action_space.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (np.random.seed, random.seed, torch.manual_seed):
        seeder(seed)
    torch.use_deterministic_algorithms(deterministic_torch)


def wandb_init(config: dict) -> None:
    """Start a Weights & Biases run described by ``config``.

    Uses ``config["project"]``, ``config["group"]`` and ``config["name"]``,
    logs the whole dict as the run config, and assigns a fresh UUID4 run id.
    """
    run_fields = {key: config[key] for key in ("project", "group", "name")}
    wandb.init(config=config, id=str(uuid.uuid4()), **run_fields)
    wandb.run.save()


@torch.no_grad()
def eval_actor(
    env: gym.Env, actor: nn.Module, device: str, n_episodes: int, seed: int
) -> np.ndarray:
    """Run ``n_episodes`` rollouts of ``actor.act`` and return each episode's summed reward.

    The env is seeded with ``seed`` once before the rollouts. The actor is put
    in eval mode for the rollouts and back in train mode afterwards.
    """
    env.seed(seed)
    actor.eval()
    totals = []
    for _episode in range(n_episodes):
        state = env.reset()
        done = False
        total = 0.0
        while not done:
            state, reward, done, _ = env.step(actor.act(state, device))
            total += reward
        totals.append(total)

    actor.train()
    return np.asarray(totals)


def return_reward_range(dataset, max_episode_steps):
    """Return ``(min, max)`` undiscounted return over complete episodes in ``dataset``.

    An episode ends at a truthy ``terminals`` flag or after ``max_episode_steps``
    steps. A trailing, unfinished episode is not included in the returns.

    Raises:
        ValueError: if ``rewards`` and ``terminals`` have different lengths, or
            if the dataset contains no complete episode.
    """
    returns, lengths = [], []
    ep_ret, ep_len = 0.0, 0
    for r, d in zip(dataset["rewards"], dataset["terminals"]):
        ep_ret += float(r)
        ep_len += 1
        if d or ep_len == max_episode_steps:
            returns.append(ep_ret)
            lengths.append(ep_len)
            ep_ret, ep_len = 0.0, 0
    lengths.append(ep_len)  # count the trailing partial episode's steps too
    # zip() silently truncates; make sure every reward was consumed.
    if sum(lengths) != len(dataset["rewards"]):
        raise ValueError(
            "dataset['rewards'] and dataset['terminals'] have different lengths"
        )
    if not returns:
        raise ValueError(
            "Cannot compute the return range: the dataset contains no complete "
            f"episode (no terminal and no run of max_episode_steps={max_episode_steps} steps)"
        )
    return min(returns), max(returns)


def modify_reward(dataset, env_name, max_episode_steps=1000):
    """Rescale ``dataset["rewards"]`` in place according to the environment family.

    * halfcheetah / hopper / walker2d: divide by (max - min) complete-episode
      return, then multiply by ``max_episode_steps``.
    * antmaze: subtract 1 from every reward.
    * anything else: rewards are left unchanged.
    """
    locomotion_tags = ("halfcheetah", "hopper", "walker2d")
    if any(tag in env_name for tag in locomotion_tags):
        min_ret, max_ret = return_reward_range(dataset, max_episode_steps)
        dataset["rewards"] /= max_ret - min_ret
        dataset["rewards"] *= max_episode_steps
        return
    if "antmaze" in env_name:
        dataset["rewards"] -= 1.0


class Actor(nn.Module):
    """Deterministic policy network.

    Two 256-unit ReLU hidden layers and a tanh output scaled by ``max_action``,
    so every action component lies in ``[-max_action, max_action]``.
    """

    def __init__(self, state_dim: int, action_dim: int, max_action: float):
        super().__init__()

        width = 256
        self.net = nn.Sequential(
            nn.Linear(state_dim, width), nn.ReLU(),
            nn.Linear(width, width), nn.ReLU(),
            nn.Linear(width, action_dim), nn.Tanh(),
        )

        self.max_action = max_action

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        squashed = self.net(state)
        return self.max_action * squashed

    @torch.no_grad()
    def act(self, state: np.ndarray, device: str = "cpu") -> np.ndarray:
        """Return the action for a single (unbatched) ``state`` as a flat NumPy array."""
        batch = torch.tensor(state.reshape(1, -1), dtype=torch.float32, device=device)
        action = self(batch)
        return action.cpu().data.numpy().flatten()


# class Critic(nn.Module):
#     def __init__(self, state_dim: int, action_dim: int):
#         super(Critic, self).__init__()

#         self.net = nn.Sequential(
#             nn.Linear(state_dim + action_dim, 256),
#             nn.ReLU(),
#             nn.Linear(256, 256),
#             nn.ReLU(),
#             nn.Linear(256, 1),
#         )

#     def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
#         sa = torch.cat([state, action], 1)
#         return self.net(sa)

class DyT(nn.Module):
    """Elementwise ``tanh(alpha * x) * weight + bias`` with learnable parameters.

    ``alpha`` is a learnable scalar. ``weight``/``bias`` are per-feature and
    initialized to 1/0. The Critic's commented-out code uses this as a drop-in
    replacement for LayerNorm.
    """

    def __init__(self, num_features, alpha_init_value=0.5):
        super().__init__()
        init_alpha = torch.ones(1) * alpha_init_value
        self.alpha = nn.Parameter(init_alpha)
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return squashed * self.weight + self.bias

class Critic(nn.Module):
    """Q-network Q(s, a): an MLP applied to the concatenated state-action vector.

    Architecture: ``max(n_hiddens, 1)`` blocks of ``Linear -> ReLU``, each
    followed by ``LayerNorm`` when ``layernorm`` is true, then a scalar
    ``Linear`` head. ``use_dyt`` is accepted but currently has no effect
    (the DyT variant is disabled).
    """

    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 256, layernorm: bool = True, use_dyt: bool = True, n_hiddens: int = 3):
        super().__init__()

        in_features = state_dim + action_dim
        blocks = []
        # The first hidden block always exists, so n_hiddens < 1 behaves like 1.
        for depth in range(max(n_hiddens, 1)):
            blocks.append(nn.Linear(in_features if depth == 0 else hidden_dim, hidden_dim))
            blocks.append(nn.ReLU())
            if layernorm:
                blocks.append(nn.LayerNorm(hidden_dim))
        blocks.append(nn.Linear(hidden_dim, 1))

        self.net = nn.Sequential(*blocks)

        # Override PyTorch's default Linear initialization.
        self._initialize_weights(state_dim, action_dim, hidden_dim)

    def _initialize_weights(self, state_dim, action_dim, hidden_dim):
        """Re-initialize every Linear layer of ``self.net`` in place.

        First layer: weight ~ U(+-1/sqrt(state_dim + action_dim)), bias = 0.1.
        Hidden layers: weight ~ U(+-1/sqrt(hidden_dim)), bias = 0.1.
        Output layer: weight and bias ~ U(+-3e-3).
        """
        modules = list(self.net)
        head, tail = modules[0], modules[-1]

        in_bound = 1.0 / np.sqrt(state_dim + action_dim)
        nn.init.uniform_(head.weight, -in_bound, in_bound)
        nn.init.constant_(head.bias, 0.1)

        hidden_bound = 1.0 / np.sqrt(hidden_dim)
        for module in modules[1:-2]:
            if isinstance(module, nn.Linear):
                nn.init.uniform_(module.weight, -hidden_bound, hidden_bound)
                nn.init.constant_(module.bias, 0.1)

        nn.init.uniform_(tail.weight, -3e-3, 3e-3)
        nn.init.uniform_(tail.bias, -3e-3, 3e-3)

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        return self.net(torch.cat((state, action), dim=1))

class Vnet(nn.Module):
    """State-value network V(s): two ReLU hidden layers and a scalar linear output."""

    def __init__(self, state_dim, hidden_dim=256):
        super().__init__()

        self.l1 = nn.Linear(state_dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        self.l3 = nn.Linear(hidden_dim, 1)

    def forward(self, state):
        hidden = F.relu(self.l2(F.relu(self.l1(state))))
        return self.l3(hidden)

class Adaptive_TD3_BC:
    """TD3+BC trainer with a learned actor-loss coefficient ``alpha``.

    Critic updates follow TD3: target policy smoothing, clipped double-Q, and
    delayed actor/target updates every ``policy_freq`` iterations. The actor
    loss is ``-softplus(alpha) / mean|Q| * mean Q(s, pi(s)) + MSE(pi(s), a)``.

    The actor is stepped with a functional torchopt optimizer (``init`` /
    ``update`` API), so the one-step-updated actor parameters remain a
    differentiable function of ``alpha``. Every ``policy_freq * alpha_freq``
    iterations, ``alpha`` takes one optimizer step on a meta-loss selected by
    ``loss_function`` that is back-propagated through that actor step.
    """

    def __init__(
        self,
        max_action: float,
        actor: nn.Module,
        actor_optimizer: torchopt.Optimizer,
        critic_1: nn.Module,
        critic_1_optimizer: torch.optim.Optimizer,
        critic_2: nn.Module,
        critic_2_optimizer: torch.optim.Optimizer,
        vnet: nn.Module,
        vnet_optimizer: torch.optim.Optimizer,
        discount: float = 0.99,
        tau: float = 0.005,
        policy_noise: float = 0.2,
        noise_clip: float = 0.5,
        policy_freq: int = 2,
        alpha: float = 2.5,
        alpha_freq: int = 10,
        ema_alpha: float = 0.995,
        loss_function: int = 4,
        device: str = "cpu",
        action_dim: int = 1,
        reward_mean: float = 0.0,
    ):
        """
        Notes:
            ``actor_optimizer`` is used through the functional torchopt API
            (``init(params)`` / ``update(grads, state, inplace=False)``), e.g.
            ``torchopt.adam(...)``. Its state lives in ``self.actor_opt_state``.
            ``reward_mean`` is subtracted from rewards in the TD target; the
            default 0.0 gives the plain TD3 target.

        Raises:
            ValueError: if ``loss_function`` is not 1, 2, 3 or 4.
        """
        if loss_function not in (1, 2, 3, 4):
            raise ValueError(
                f"loss_function must be one of 1, 2, 3, 4 (got {loss_function!r})"
            )

        self.actor = actor
        self.actor_target = copy.deepcopy(actor)
        self.actor_optimizer = actor_optimizer
        # Functional optimizer state for the actor parameters (named_parameters order).
        self.actor_opt_state = self.actor_optimizer.init(list(self.actor.parameters()))
        self.critic_1 = critic_1
        self.critic_1_target = copy.deepcopy(critic_1)
        self.critic_1_optimizer = critic_1_optimizer
        self.critic_2 = critic_2
        self.critic_2_target = copy.deepcopy(critic_2)
        self.critic_2_optimizer = critic_2_optimizer
        # Stored for completeness; not used by train().
        self.vnet = vnet
        self.vnet_optimizer = vnet_optimizer

        self.max_action = max_action
        self.discount = discount
        self.tau = tau
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.policy_freq = policy_freq

        self.alpha_freq = alpha_freq
        # Store alpha through its inverse softplus so that softplus(self.alpha) == alpha > 0.
        self.alpha = nn.Parameter(torch.log(torch.exp(torch.tensor(alpha)) - 1.0))
        self.alpha_optimizer = torch.optim.Adam([self.alpha], lr=2e-3 / 10 * self.alpha_freq)
        # Decay the alpha learning rate by a total factor of 0.01 over the number
        # of alpha updates that happen in 1e6 training iterations.
        self.gamma = 0.01 ** (1 / int(1e6 / (self.policy_freq * self.alpha_freq)))
        self.alpha_scheduler = torch.optim.lr_scheduler.ExponentialLR(self.alpha_optimizer, gamma=self.gamma)

        self.total_it = 0
        self.device = device

        self.ema_alpha = ema_alpha
        self.ema_q = torch.tensor(0.0, device=device)  # EMA of the post-update mean Q
        self.previous_ema_q = 0.0

        self.loss_function = loss_function
        self.action_dim = action_dim
        self.reward_mean = reward_mean

    def train(self, batch: TensorBatch) -> Dict[str, float]:
        """Run one training iteration and return the scalars to log.

        ``batch`` is ``[state, action, reward, next_state, done]``. Critics are
        updated on every call. The actor and all target networks are updated
        every ``policy_freq`` calls, and ``alpha`` every
        ``policy_freq * alpha_freq`` calls.
        """
        log_dict = {}
        self.total_it += 1

        state, action, reward, next_state, done = batch
        not_done = 1 - done

        # --- Critic update: target policy smoothing + clipped double-Q target ---
        with torch.no_grad():
            noise = (torch.randn_like(action) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)
            next_action = (self.actor_target(next_state) + noise).clamp(-self.max_action, self.max_action)
            target_q = torch.min(
                self.critic_1_target(next_state, next_action),
                self.critic_2_target(next_state, next_action),
            )
            target_q = (reward - self.reward_mean) + not_done * self.discount * target_q

        current_q1 = self.critic_1(state, action)
        current_q2 = self.critic_2(state, action)

        critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)
        self.critic_1_optimizer.zero_grad()
        self.critic_2_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_1_optimizer.step()
        self.critic_2_optimizer.step()

        if self.total_it % self.policy_freq != 0:
            return log_dict

        # --- Delayed actor update, kept differentiable w.r.t. alpha ---
        update_alpha = self.total_it % (self.policy_freq * self.alpha_freq) == 0
        alpha = F.softplus(self.alpha)
        log_dict["alpha"] = alpha.item()

        named_params = list(self.actor.named_parameters())
        param_names = [n for n, _ in named_params]
        params = [p for _, p in named_params]

        pi = torch.func.functional_call(self.actor, dict(named_params), (state,))
        q = self.critic_1(state, pi)
        log_dict["q"] = q.mean().item()
        lmbda = alpha / q.abs().mean().detach()
        actor_loss = -lmbda * q.mean() + F.mse_loss(pi, action)

        # The second-order graph is only needed when alpha is meta-updated this step.
        grads = torch.autograd.grad(actor_loss, params, create_graph=update_alpha)
        updates, self.actor_opt_state = self.actor_optimizer.update(grads, self.actor_opt_state, inplace=False)
        # Cut the optimizer state out of the autograd graph before the next iteration.
        self.actor_opt_state = torchopt.pytree.tree_map(
            lambda x: x.detach() if isinstance(x, torch.Tensor) else x, self.actor_opt_state
        )
        updated_params = torchopt.apply_updates(params, list(updates), inplace=False)

        if update_alpha:
            self._update_alpha(alpha, state, action, pi, param_names, updated_params, log_dict)

        # Commit the functional update to the actor module.
        with torch.no_grad():
            for p_old, p_new in zip(params, updated_params):
                p_old.copy_(p_new)

        soft_update(self.critic_1_target, self.critic_1, self.tau)
        soft_update(self.critic_2_target, self.critic_2, self.tau)
        soft_update(self.actor_target, self.actor, self.tau)

        return log_dict

    def _update_alpha(self, alpha, state, action, pi, param_names, updated_params, log_dict):
        """Take one optimizer (and LR-scheduler) step on ``self.alpha``.

        The step uses the meta-loss selected by ``self.loss_function``. The loss
        is evaluated with the one-step-updated actor parameters
        ``updated_params``, whose graph still depends on ``alpha``.
        """
        updated_param_dict = dict(zip(param_names, updated_params))
        pi_new = torch.func.functional_call(self.actor, updated_param_dict, (state,))
        q_new = self.critic_1(state, pi_new)
        q_new_mean = q_new.mean()

        # EMA of the post-update Q; only the newest term carries gradient.
        self.ema_q = (1 - self.ema_alpha) * q_new_mean + self.ema_alpha * self.ema_q.detach()
        delta_q = self.ema_q - self.previous_ema_q

        # Per-sample BC error before (detached) and after the actor step.
        bc_old = F.mse_loss(pi, action, reduction="none").mean(dim=1).detach()
        bc_new = F.mse_loss(pi_new, action, reduction="none").mean(dim=1)
        bc_sup = bc_old.max()
        delta_bc_sup = (bc_old - bc_new).abs().max()

        eq1 = -alpha.detach() * (q_new_mean / q_new.abs().mean().detach()) + F.mse_loss(pi_new, action)
        eq2 = delta_q ** 2
        eq3 = eq2.detach() * bc_sup * delta_bc_sup

        if self.loss_function == 1:
            alpha_loss = eq1
        elif self.loss_function == 2:
            alpha_loss = eq1 + eq2
        elif self.loss_function == 3:
            alpha_loss = eq1 + eq3
        elif self.loss_function == 4:
            alpha_loss = eq1 + eq2 + eq3
        else:
            raise ValueError(
                f"loss_function must be one of 1, 2, 3, 4 (got {self.loss_function!r})"
            )

        log_dict["bc_loss_sup"] = bc_sup.item()
        log_dict["delta_Q"] = torch.abs(delta_q).item()
        log_dict["delta_bc_sup"] = delta_bc_sup.item()

        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        self.alpha_scheduler.step()

        self.previous_ema_q = self.ema_q.detach().item()

    def compute_actor_loss(self, actor_params, state, action):
        """Actor loss for the functional parameters ``actor_params``.

        This is the same form as the actor loss in ``train``, but ``train``
        does not call it. It also caches the detached mean Q in
        ``self.old_q_mean``.
        """
        pi = torch.func.functional_call(self.actor, actor_params, (state,))
        q = self.critic_1(state, pi)
        self.old_q_mean = q.mean().detach()
        # self.alpha is the unconstrained (inverse-softplus) parameter; use the
        # positive coefficient softplus(self.alpha), as train() does.
        lmbda = F.softplus(self.alpha) / q.abs().mean().detach()
        return -lmbda * q.mean() + F.mse_loss(pi, action)

    def state_dict(self) -> Dict[str, Any]:
        """Return everything needed to resume training.

        The functional torchopt optimizer has no ``state_dict()``, so its
        state (``self.actor_opt_state``) is stored directly under
        ``"actor_optimizer"``.

        NOTE(review): that state is a pytree of torchopt namedtuples. Loading
        it with ``torch.load(..., weights_only=True)`` may be rejected; confirm
        against the installed torch version.
        """
        return {
            "critic_1": self.critic_1.state_dict(),
            "critic_1_optimizer": self.critic_1_optimizer.state_dict(),
            "critic_2": self.critic_2.state_dict(),
            "critic_2_optimizer": self.critic_2_optimizer.state_dict(),
            "actor": self.actor.state_dict(),
            "actor_optimizer": self.actor_opt_state,
            "alpha": self.alpha.detach().clone(),
            "alpha_optimizer": self.alpha_optimizer.state_dict(),
            "alpha_scheduler": self.alpha_scheduler.state_dict(),
            "ema_q": self.ema_q.detach().clone(),
            "previous_ema_q": self.previous_ema_q,
            "total_it": self.total_it,
        }

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Restore state produced by ``state_dict``.

        Target networks are reset to copies of the loaded online networks.
        Keys for the alpha / EMA state are optional.
        """
        self.critic_1.load_state_dict(state_dict["critic_1"])
        self.critic_1_optimizer.load_state_dict(state_dict["critic_1_optimizer"])
        self.critic_1_target = copy.deepcopy(self.critic_1)

        self.critic_2.load_state_dict(state_dict["critic_2"])
        self.critic_2_optimizer.load_state_dict(state_dict["critic_2_optimizer"])
        self.critic_2_target = copy.deepcopy(self.critic_2)

        self.actor.load_state_dict(state_dict["actor"])
        self.actor_target = copy.deepcopy(self.actor)
        if "actor_optimizer" in state_dict:
            self.actor_opt_state = state_dict["actor_optimizer"]

        if "alpha" in state_dict:
            with torch.no_grad():
                self.alpha.copy_(state_dict["alpha"])
        if "alpha_optimizer" in state_dict:
            self.alpha_optimizer.load_state_dict(state_dict["alpha_optimizer"])
        if "alpha_scheduler" in state_dict:
            self.alpha_scheduler.load_state_dict(state_dict["alpha_scheduler"])
        if "ema_q" in state_dict:
            self.ema_q = torch.as_tensor(state_dict["ema_q"], device=self.device)
        self.previous_ema_q = state_dict.get("previous_ema_q", self.previous_ema_q)

        self.total_it = state_dict["total_it"]


@pyrallis.wrap()
def train(config: TrainConfig):
    """Run offline Adaptive TD3+BC training on a D4RL dataset.

    ``config`` is supplied by ``pyrallis.wrap()``; the script calls
    ``train()`` with no arguments. The function:

    * loads ``d4rl.qlearning_dataset`` for ``config.env``;
    * optionally rescales rewards and normalizes states;
    * fills a ``ReplayBuffer`` and seeds everything;
    * builds the actor, critics, V-net and the ``Adaptive_TD3_BC`` trainer;
    * runs ``config.max_timesteps`` training iterations.

    Every ``config.eval_freq`` iterations it evaluates the actor and logs the
    D4RL-normalized score to wandb. If ``config.checkpoints_path`` is set, it
    also writes a checkpoint.
    """
    env = gym.make(config.env)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    dataset = d4rl.qlearning_dataset(env)

    if config.normalize_reward:
        # Rescales dataset["rewards"] in place (see modify_reward).
        modify_reward(dataset, config.env)

    if config.normalize:
        state_mean, state_std = compute_mean_std(dataset["observations"], eps=1e-3)
    else:
        # Identity statistics: the normalization below becomes a no-op.
        state_mean, state_std = 0, 1

    dataset["observations"] = normalize_states(
        dataset["observations"], state_mean, state_std
    )
    dataset["next_observations"] = normalize_states(
        dataset["next_observations"], state_mean, state_std
    )
    # Apply the same state normalization to observations seen during evaluation.
    env = wrap_env(env, state_mean=state_mean, state_std=state_std)
    replay_buffer = ReplayBuffer(
        state_dim,
        action_dim,
        config.buffer_size,
        config.device,
    )
    replay_buffer.load_d4rl_dataset(dataset)

    # NOTE(review): assumes a symmetric action box with the same bound in every dimension; confirm per env.
    max_action = float(env.action_space.high[0])

    if config.checkpoints_path is not None:
        print(f"Checkpoints path: {config.checkpoints_path}")
        os.makedirs(config.checkpoints_path, exist_ok=True)
        with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f:
            pyrallis.dump(config, f)

    # Set seeds (before any network is built, so weight initialization is reproducible)
    seed = config.seed
    set_seed(seed, env)

    actor = Actor(state_dim, action_dim, max_action).to(config.device)
    # Functional (init/update) torchopt optimizer; the trainer keeps its state.
    actor_optimizer = torchopt.adam(lr=3e-4, use_accelerated_op=True)

    critic_1 = Critic(state_dim, action_dim).to(config.device)
    critic_1_optimizer = torch.optim.Adam(critic_1.parameters(), lr=3e-4)
    critic_2 = Critic(state_dim, action_dim).to(config.device)
    critic_2_optimizer = torch.optim.Adam(critic_2.parameters(), lr=3e-4)

    # Passed to the trainer, but its train step does not use the V-network.
    vnet = Vnet(state_dim).to(config.device)
    vnet_optimizer = torch.optim.Adam(vnet.parameters(), lr=3e-4)

    kwargs = {
        "max_action": max_action,
        "actor": actor,
        "actor_optimizer": actor_optimizer,
        "critic_1": critic_1,
        "critic_1_optimizer": critic_1_optimizer,
        "critic_2": critic_2,
        "critic_2_optimizer": critic_2_optimizer,
        "vnet": vnet,
        "vnet_optimizer": vnet_optimizer,
        "discount": config.discount,
        "tau": config.tau,
        "device": config.device,
        # TD3 (noise settings are given relative to max_action)
        "policy_noise": config.policy_noise * max_action,
        "noise_clip": config.noise_clip * max_action,
        "policy_freq": config.policy_freq,
        # TD3 + BC
        "alpha": config.alpha,
        "alpha_freq": config.alpha_freq,
        "ema_alpha": config.ema_alpha,
        "loss_function": config.loss_function,
        "action_dim": action_dim,
    }

    print("---------------------------------------")
    print(f"Training Adaptive TD3 + BC, Env: {config.env}, Seed: {seed}")
    print("---------------------------------------")

    # Build the trainer (it creates target-network copies of the actor and critics)
    trainer = Adaptive_TD3_BC(**kwargs)

    if config.load_model != "":
        policy_file = Path(config.load_model)
        # NOTE(review): torch.load defaults may differ across torch versions (weights_only); confirm checkpoints load.
        trainer.load_state_dict(torch.load(policy_file))
        actor = trainer.actor

    wandb_init(asdict(config))

    # Normalized evaluation scores (collected, not otherwise used here).
    evaluations = []
    for t in range(int(config.max_timesteps)):
        batch = replay_buffer.sample(config.batch_size)
        # The buffer already stores tensors on config.device, so this is effectively a no-op.
        batch = [b.to(config.device) for b in batch]
        log_dict = trainer.train(batch)
        wandb.log(log_dict, step=trainer.total_it)
        # Periodic evaluation (the env is re-seeded with config.seed every time)
        if (t + 1) % config.eval_freq == 0:
            print(f"Time steps: {t + 1}")
            eval_scores = eval_actor(
                env,
                actor,
                device=config.device,
                n_episodes=config.n_episodes,
                seed=config.seed,
            )
            eval_score = eval_scores.mean()
            normalized_eval_score = env.get_normalized_score(eval_score) * 100.0
            evaluations.append(normalized_eval_score)
            print("---------------------------------------")
            print(
                f"Evaluation over {config.n_episodes} episodes: "
                f"{eval_score:.3f} , D4RL score: {normalized_eval_score:.3f}"
            )
            print("---------------------------------------")

            if config.checkpoints_path is not None:
                # The file name uses the 0-based loop index t.
                torch.save(
                    trainer.state_dict(),
                    os.path.join(config.checkpoints_path, f"checkpoint_{t}.pt"),
                )

            wandb.log(
                {"d4rl_normalized_score": normalized_eval_score},
                step=trainer.total_it,
            )

if __name__ == "__main__":
    # Script entry point: the @pyrallis.wrap() decorator on train() supplies the TrainConfig.
    train()
