import os
import random
import weakref
from datetime import datetime
from functools import partial
from pathlib import Path
from typing import List, Optional, Type

import gymnasium as gym
import hydra
import minari
import numpy as np
import torch
import torch.nn.functional as F
import wandb
from gymnasium.wrappers import PixelObservationWrapper
from omegaconf import DictConfig
from torch import distributions, nn
from torch.nn import functional as F
from torch.utils.data import DataLoader

from common import MLP, EnsembleMLP, LinearEnsemble
from lamb import Lamb
from model import DecisionMaxTransformer, DecisionMaxConvTransformer
from utils import (
    KMeansEpisodicTrajectoryDataset,
    MaxEpisodicTrajectoryDataset,
    VisionKMeansEpisodicTrajectoryDataset,
    VisionMaxEpisodicTrajectoryDataset,
    get_test_start_state_goals,
    get_lr,
    AntmazeWrapper,
    PreprocessObservationWrapper,
)

def weight_init(m: nn.Module, gain: float = 1) -> None:
    """Orthogonally initialise the weights of ``m`` and zero its bias.

    Meant to be passed to ``nn.Module.apply``. ``nn.Linear`` layers get one
    orthogonal weight matrix; ``LinearEnsemble`` layers get an independent
    orthogonal matrix per ensemble member (``weight.data[i]`` for each of
    ``m.ensemble_size`` members). Other module types are left untouched.

    Args:
        m: module visited by ``nn.Module.apply``.
        gain: scale passed to ``nn.init.orthogonal_``. Callers pass floats
            (e.g. ``float(True)`` or an output-layer gain), hence ``float``.
    """
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data, gain=gain)
        # The bias may be None (layer built without bias); only zero a real tensor.
        if hasattr(m.bias, "data"):
            m.bias.data.fill_(0.0)
    if isinstance(m, LinearEnsemble):
        for i in range(m.ensemble_size):
            # Orthogonal initialization doesn't care which dim is first,
            # so each member's matrix can be initialised as a normal weight.
            nn.init.orthogonal_(m.weight.data[i], gain=gain)
        if hasattr(m.bias, "data"):
            m.bias.data.fill_(0.0)

def set_seed(seed):
    """Seed every random number generator the script draws from.

    Seeds, in order: Python's ``random`` module, NumPy's global generator,
    PyTorch's CPU generator and the generators of all CUDA devices.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# Per-env cache: env -> {(state_dim, goal_dim): offset of achieved_goal inside
# the observation vector}. Weak keys let envs be garbage-collected normally.
_GOAL_SLICE_START_CACHE = weakref.WeakKeyDictionary()


def _goal_slice_start(env, state_dim, goal_dim):
    """Return (and cache per env) the offset of ``achieved_goal`` inside ``observation``.

    The layout is only observable through an actual observation, so ``env`` is
    reset the first time a given env/shape pair is seen; later calls reuse the
    cached offset and leave the env alone.

    Raises:
        ValueError: if ``achieved_goal`` is not a contiguous slice of the
            first ``state_dim`` entries of the observation.
    """
    key = (state_dim, goal_dim)
    try:
        per_env = _GOAL_SLICE_START_CACHE.setdefault(env, {})
    except TypeError:
        # env cannot be weak-referenced or hashed: search without caching.
        per_env = {}
    if key not in per_env:
        obs, _ = env.reset()
        observation = np.asarray(obs["observation"])
        achieved_goal = np.asarray(obs["achieved_goal"])
        for i in range(state_dim - goal_dim + 1):
            if np.array_equal(observation[i:i + goal_dim], achieved_goal):
                per_env[key] = i
                break
        else:
            raise ValueError(
                "achieved_goal was not found as a contiguous slice of the observation"
            )
    return per_env[key]


def norm_distance(env, state, goal, info=None):
    """Euclidean distance between the goal slice of ``state`` and ``goal``.

    The position of the achieved goal inside the state vector is discovered
    from an observation of ``env`` (see ``_goal_slice_start``). That requires
    one ``env.reset()`` the first time an env is seen; subsequent calls do not
    touch the env.

    Args:
        env: goal-conditioned env whose ``reset()`` returns a dict observation
            with ``"observation"`` and ``"achieved_goal"`` entries.
        state: tensor shaped ``(..., state_dim)``.
        goal: tensor shaped ``(..., goal_dim)`` with the same leading shape.
        info: unused.

    Returns:
        NumPy array of distances with shape ``state.shape[:-1]`` (a NumPy
        scalar for 1-D inputs).

    Raises:
        ValueError: if the goal slice cannot be located, or if its shape does
            not match ``goal``.
    """
    state_dim = state.shape[-1]
    goal_dim = goal.shape[-1]
    start = _goal_slice_start(env, state_dim, goal_dim)
    imaged_goal = state[..., start:start + goal_dim]
    if imaged_goal.shape != goal.shape:
        raise ValueError(
            f"goal slice shape {tuple(imaged_goal.shape)} does not match "
            f"goal shape {tuple(goal.shape)}"
        )
    diff = imaged_goal.detach().cpu() - goal.detach().cpu()
    return np.linalg.norm(diff.numpy(), axis=-1)

def get_distance(self, logits=None):
    """Collapse per-bin distance logits into a distance estimate.

    Bin ``k`` (of ``self.bins``) stands for the distance ``k / self.bins``.
    With ``self.alpha is None`` the expected distance under
    ``softmax(logits)`` is returned. Otherwise the distribution-weighted
    log-sum-exp ``-alpha * log(sum_k p_k * exp(-d_k / alpha))`` is returned.
    The result is then reduced with a max over dim 0.

    This is written as a free function taking ``self``; it only reads
    ``self.bins`` and ``self.alpha``.

    Args:
        self: object exposing ``bins`` (int) and ``alpha`` (float or None).
        logits: tensor whose last dim has size ``self.bins``; presumably
            ``(E, B, D)`` (ensemble, batch, bins) -- confirm at call sites.

    Returns:
        ``predicted_distance.max(dim=0)`` values.

    Raises:
        ValueError: if ``logits`` is None.
    """
    if logits is None:
        raise ValueError("get_distance() requires logits")
    bin_values = torch.arange(start=0, end=self.bins, device=logits.device) / self.bins
    bin_values = bin_values.view(1, 1, -1)  # broadcast against leading (E, B) dims
    if self.alpha is None:
        # Return the expectation.
        probs = torch.nn.functional.softmax(logits, dim=-1)
        predicted_distance = (probs * bin_values).sum(dim=-1)
    else:
        # Same quantity as -alpha * log(sum(p * exp(-d / alpha))), evaluated in
        # log space so that tiny probabilities or exp(-d / alpha) terms cannot
        # underflow to log(0) = -inf.
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        predicted_distance = -self.alpha * torch.logsumexp(
            log_probs - bin_values / self.alpha, dim=-1
        )
    return torch.max(predicted_distance, dim=0)[0]

class DiscreteMLPDistance(nn.Module):
    """MLP mapping a concatenated ``(state, goal)`` pair to ``bins`` logits.

    The state size is read from ``env.observation_space['observation']`` and
    the goal size is given by ``goal_dim``. With ``ensemble_size > 1`` an
    ``EnsembleMLP`` is used, otherwise a single ``MLP``.

    Args:
        env: environment whose ``observation_space['observation']`` fixes the
            state size.
        hidden_layers: hidden layer widths; ``None`` means ``[512, 512]``.
        bins: number of output logits.
        goal_dim: size of the goal vector.
        ensemble_size: number of ensemble members (``> 1`` selects ``EnsembleMLP``).
        ortho_init: if truthy, apply ``weight_init`` with gain ``float(ortho_init)``.
        output_gain: if not None (and ``ortho_init`` is truthy), re-initialise
            ``self.mlp.last_layer`` with this gain.
    """

    def __init__(
        self, env, hidden_layers=None, bins=100, goal_dim=2, ensemble_size=1, ortho_init=True, output_gain=None):
        super().__init__()
        # A literal list default would be a single object shared by every
        # instance (and by whatever the MLP keeps a reference to).
        if hidden_layers is None:
            hidden_layers = [512, 512]
        self._bins = bins
        self.state_dim = env.observation_space['observation'].shape[0]
        self.goal_dim = goal_dim

        input_dim = self.state_dim + self.goal_dim  # forward() concatenates [s, g]
        self.ensemble_size = ensemble_size
        if self.ensemble_size > 1:
            self.mlp = EnsembleMLP(input_dim=input_dim, output_dim=self.bins, ensemble_size=self.ensemble_size, hidden_layers=hidden_layers)
        else:
            self.mlp = MLP(input_dim=input_dim, output_dim=self.bins, hidden_layers=hidden_layers)
        self.ortho_init = ortho_init
        self.output_gain = output_gain
        self.reset_parameters()

    def reset_parameters(self):
        """Re-run orthogonal initialisation when ``ortho_init`` is enabled."""
        if self.ortho_init:
            # float(True) == 1.0, so ortho_init=True means unit gain.
            self.apply(partial(weight_init, gain=float(self.ortho_init)))
            if self.output_gain is not None:
                self.mlp.last_layer.apply(partial(weight_init, gain=self.output_gain))

    @property
    def bins(self):
        """Number of output logits."""
        return self._bins

    def forward(self, s, g):
        """Return the MLP logits for the concatenation of ``s`` and ``g`` on the last dim."""
        x = torch.cat([s, g], -1)
        return self.mlp(x)



# A superseded version of eval_env (which fed only model-predicted
# returns-to-go, with no goal-distance seeding) used to sit here as a disabled
# triple-quoted string. It was never executed and has been removed; recover it
# from version-control history if it is needed for reference.

def _make_eval_env(cfg, render_mode):
    """Build the evaluation env named by ``cfg.env_name`` (pixel-wrapped when ``cfg.vision``).

    Raises:
        NotImplementedError: if ``cfg.env_name`` is not a supported maze.
    """
    if cfg.env_name in ['PointMaze_UMaze-v3', 'PointMaze_Medium-v3', 'PointMaze_Large-v3']:
        if not cfg.vision:
            return gym.make(cfg.env_name, continuing_task=False, render_mode=render_mode)
        env = gym.make(cfg.env_name, continuing_task=False, render_mode='rgb_array')
        # Top-down camera, pulled further back when maze_map has more than 8 entries.
        env.point_env.mujoco_renderer.default_cam_config = {
            "distance": 14 if len(env.maze.maze_map) > 8 else 8.8,
            "elevation": -90,
            "lookat": [0, 0, 0,],
        }
        env = PixelObservationWrapper(env, pixels_only=False)
        return PreprocessObservationWrapper(env, shape=64, grayscale=True)
    if cfg.env_name in ['AntMaze_UMaze-v4', 'AntMaze_Medium-v4', 'AntMaze_Large-v4']:
        return AntmazeWrapper(env=gym.make(cfg.env_name, continuing_task=False, render_mode=render_mode))
    raise NotImplementedError


def eval_env(cfg, model, device, render=False):
    """Roll out ``model`` on every test start/goal pair and collect reward stats.

    For each pair from ``get_test_start_state_goals(cfg)`` the model runs
    ``cfg.num_eval_ep`` episodes. At every step the current goal distance is
    written into the return-to-go slot, a first forward pass predicts a
    return-to-go that replaces it, and a second forward pass yields the action.

    Args:
        cfg: Hydra config; reads ``env_name``, ``vision``, ``context_len`` and
            ``num_eval_ep``.
        model: transformer exposing ``act_dim``, ``state_dim``, ``goal_dim``,
            ``eval()`` and ``forward(timesteps, states, actions, returns_to_go)``
            returning a 3-tuple (return-to-go preds, action preds, _).
        device: torch device for all rollout buffers.
        render: use ``render_mode='human'`` (ignored for vision envs).

    Returns:
        dict with ``eval/<name>_avg_reward`` and ``eval/<name>_avg_ep_len`` per
        start/goal pair, plus ``eval/avg_reward`` over all episodes.

    Raises:
        NotImplementedError: if ``cfg.env_name`` is not a supported maze.
    """
    env = _make_eval_env(cfg, 'human' if render else None)
    test_start_state_goal = get_test_start_state_goals(cfg)

    model.eval()
    results = {}
    eval_batch_size = 1
    max_steps = env.spec.max_episode_steps
    context_len = cfg.context_len

    # Timestep indices are the same for every episode, so build them once.
    timesteps = torch.arange(start=0, end=max_steps, step=1)
    timesteps = timesteps.repeat(eval_batch_size, 1).to(device)

    try:
        with torch.no_grad():
            cum_reward = 0
            for ss_g in test_start_state_goal:
                total_reward = 0
                total_timesteps = 0
                print(ss_g['name'] + ':')
                for _ in range(cfg.num_eval_ep):
                    # Zero placeholders for the whole episode.
                    m_actions = torch.zeros((eval_batch_size, max_steps, model.act_dim),
                                            dtype=torch.float32, device=device)
                    m_states = torch.zeros((eval_batch_size, max_steps, model.state_dim),
                                           dtype=torch.float32, device=device)
                    m_goals = torch.zeros((eval_batch_size, max_steps, model.goal_dim),
                                          dtype=torch.float32, device=device)
                    returns_to_go = torch.zeros((eval_batch_size, max_steps, 1),
                                                dtype=torch.float32, device=device)
                    obs, _ = env.reset(options=ss_g)

                    for t in range(max_steps):
                        total_timesteps += 1

                        m_states[0, t] = torch.tensor(obs['observation'], dtype=torch.float32, device=device)
                        m_goals[0, t] = torch.tensor(obs['desired_goal'], dtype=torch.float32, device=device)

                        # Seed the return-to-go slot with the current goal distance.
                        # achieved_goal is the observation slice norm_distance() locates,
                        # so this is the same value, without norm_distance()'s
                        # env.reset() (which would restart the episode mid-rollout).
                        achieved = torch.tensor(obs['achieved_goal'], dtype=torch.float32, device=device)
                        returns_to_go[0, t] = torch.linalg.norm(achieved - m_goals[0, t])

                        # Until context_len steps exist, feed the first context_len slots
                        # and read position t; afterwards slide the window so the current
                        # step is its last position. Position -1 must not be used early:
                        # it would be a zero-padded future slot.
                        if t < context_len:
                            window, pos = slice(0, context_len), t
                        else:
                            window, pos = slice(t - context_len + 1, t + 1), -1

                        rtg_preds, _, _ = model.forward(
                            timesteps[:, window],
                            m_states[:, window],
                            m_actions[:, window],
                            returns_to_go[:, window],
                        )
                        # Replace the distance with the model's predicted return-to-go.
                        returns_to_go[0, t] = rtg_preds[0, pos].detach()

                        _, act_preds, _ = model.forward(
                            timesteps[:, window],
                            m_states[:, window],
                            m_actions[:, window],
                            returns_to_go[:, window],
                        )
                        act = act_preds[0, pos].detach()

                        obs, running_reward, done, _, _ = env.step(act.cpu().numpy())

                        # Record the action taken in the placeholder.
                        m_actions[0, t] = act
                        total_reward += running_reward

                        if done:
                            break

                    print('Achievied goal: ', tuple(obs['achieved_goal'].tolist()))
                    print('Desired goal: ', tuple(obs['desired_goal'].tolist()))

                print("=" * 60)
                cum_reward += total_reward
                results['eval/' + str(ss_g['name']) + '_avg_reward'] = total_reward / cfg.num_eval_ep
                results['eval/' + str(ss_g['name']) + '_avg_ep_len'] = total_timesteps / cfg.num_eval_ep

            results['eval/avg_reward'] = cum_reward / (cfg.num_eval_ep * len(test_start_state_goal))
    finally:
        # Close the env even if a rollout raises.
        env.close()
    return results

_MAZE_LAYOUTS = (
    # (substring of cfg.dataset_name, size tag in the gymnasium env id, default nclusters)
    ("umaze", "UMaze", 20),
    ("medium", "Medium", 40),
    ("large", "Large", 80),
)


def _configure_maze(cfg, family, version):
    """Set ``cfg.env_name`` to ``<family>_<Size>-<version>`` and default ``cfg.nclusters``.

    The size is taken from the first ``_MAZE_LAYOUTS`` tag found in
    ``cfg.dataset_name``. ``cfg.nclusters`` is only filled in when it is None.

    Raises:
        NotImplementedError: if no known size tag occurs in the dataset name.
    """
    for tag, size, default_nclusters in _MAZE_LAYOUTS:
        if tag in cfg.dataset_name:
            cfg.env_name = f"{family}_{size}-{version}"
            if cfg.nclusters is None:
                cfg.nclusters = default_nclusters
            return
    raise NotImplementedError(f"no maze size recognised in dataset_name={cfg.dataset_name!r}")


def _make_train_env(cfg):
    """Derive env settings from ``cfg.dataset_name`` (mutating ``cfg``) and build the env.

    Raises:
        NotImplementedError: for unsupported dataset names.
    """
    if cfg.dataset_name in ["pointmaze-umaze-v0", "pointmaze-medium-v0", "pointmaze-large-v0"]:
        cfg.vision = False
        _configure_maze(cfg, "PointMaze", "v3")
        return gym.make(cfg.env_name, continuing_task=False)

    if "antmaze" in cfg.dataset_name:
        _configure_maze(cfg, "AntMaze", "v4")
        return AntmazeWrapper(gym.make(cfg.env_name, continuing_task=False))

    if cfg.dataset_name in ["v-pointmaze-umaze-v0", "v-pointmaze-medium-v0", "v-pointmaze-large-v0"]:
        cfg.vision = True
        _configure_maze(cfg, "PointMaze", "v3")
        env = gym.make(cfg.env_name, continuing_task=False, render_mode='rgb_array')
        # Top-down camera, pulled further back when maze_map has more than 8 entries.
        env.point_env.mujoco_renderer.default_cam_config = {
            "distance": 14 if len(env.maze.maze_map) > 8 else 8.8,
            "elevation": -90,
            "lookat": [0, 0, 0,],
        }
        env = PixelObservationWrapper(env, pixels_only=False)
        return PreprocessObservationWrapper(env, shape=64, grayscale=True)

    raise NotImplementedError


def _build_train_dataset(cfg, env):
    """Pick the trajectory dataset class from ``cfg.vision`` and ``cfg.max_return_strategy``."""
    dataset_args = (cfg.dataset_name, cfg.datasize, cfg.context_len, cfg.augment_data,
                    cfg.augment_prob, cfg.nclusters, cfg.vision)
    if cfg.vision:
        if cfg.max_return_strategy:
            return VisionMaxEpisodicTrajectoryDataset(*dataset_args)
        return VisionKMeansEpisodicTrajectoryDataset(*dataset_args)
    if cfg.max_return_strategy:
        print("right dataset!!")
        # Only this dataset class takes the env, as its first argument.
        return MaxEpisodicTrajectoryDataset(env, *dataset_args)
    return KMeansEpisodicTrajectoryDataset(*dataset_args)


def train(cfg, hydra_cfg):
    """Train a Decision-Max transformer on an offline maze dataset.

    Setup:
        ``cfg.env_name`` and ``cfg.nclusters`` are derived from
        ``cfg.dataset_name``; ``cfg.vision`` is also set for PointMaze
        datasets. The dataset, model, LAMB optimizer and warm-up scheduler
        are then built.

    Training loop:
        For ``cfg.max_train_iters`` iterations, run ``cfg.num_updates_per_iter``
        gradient updates, then ``eval_env``.

    Outputs:
        The last evaluation's results are written to
        ``<output_dir>/{max_return_|no_augment_}log.txt``. With
        ``cfg.save_snapshot``, the best-scoring weights are saved to
        ``<output_dir>/checkpoint/best_model.pt``.

    Args:
        cfg: Hydra config (mutated in place as described above).
        hydra_cfg: ``HydraConfig`` providing ``runtime.output_dir``.
    """
    set_seed(cfg.seed)
    device = torch.device(cfg.device)

    # Needed for the final log file whether or not snapshots are enabled.
    log_path = Path(hydra_cfg['runtime']['output_dir'])
    checkpoint_path = log_path / 'checkpoint'
    if cfg.save_snapshot:
        checkpoint_path.mkdir(exist_ok=True)
    best_eval_returns = 0

    start_time = datetime.now().replace(microsecond=0)
    start_time_str = start_time.strftime("%y-%m-%d-%H-%M-%S")

    env = _make_train_env(cfg)
    env.action_space.seed(cfg.seed)
    env.observation_space.seed(cfg.seed)

    print(cfg.nclusters)

    train_dataset = _build_train_dataset(cfg, env)
    train_data_loader = DataLoader(
        train_dataset,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
    )
    train_data_iter = iter(train_data_loader)

    model_cls = DecisionMaxConvTransformer if cfg.vision else DecisionMaxTransformer
    model = model_cls(
        cfg.env_name, env, cfg.n_blocks, cfg.embed_dim, cfg.context_len, cfg.n_heads, cfg.drop_p,
        goal_dim=train_dataset.goal_dim,
    ).to(device)

    optimizer = Lamb(model.parameters(), lr=cfg.lr, weight_decay=cfg.wt_decay, eps=1e-8)
    # Linear warm-up: the LR multiplier grows to 1 over cfg.warmup_steps updates.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda steps: min((steps + 1) / cfg.warmup_steps, 1),
    )

    stars = "*******************************************************************************"
    total_updates = 0
    success_rate_list = []
    results = {}  # still written to the log file if max_train_iters == 0
    for _train_iter in range(cfg.max_train_iters):
        log_action_losses = []  # NOTE: records the total (RTG + action) loss per update
        model.train()

        for _update in range(cfg.num_updates_per_iter):
            try:
                timesteps, states, goal, actions, traj_mask = next(train_data_iter)
            except StopIteration:
                # Dataset exhausted: start a fresh (reshuffled) pass.
                train_data_iter = iter(train_data_loader)
                timesteps, states, goal, actions, traj_mask = next(train_data_iter)

            # Every model input must live on the model's device, timesteps included.
            timesteps = timesteps.to(device)
            states = states.to(device)
            if cfg.vision:
                states = states.squeeze(-1).unsqueeze(2)              # B x T x pixel_dim
            goal = goal.to(device).repeat(1, cfg.context_len, 1)    # B x T x goal_dim
            actions = actions.to(device)
            traj_mask = traj_mask.to(device)                        # B x T

            # Targets: distance between each state's goal slice and the goal.
            returns_to_go = torch.tensor(
                norm_distance(env=env, state=states, goal=goal)
            ).to(device).unsqueeze(dim=-1)                          # B x T x 1

            returns_to_go_preds, action_preds, _ = model.forward(
                timesteps=timesteps,
                states=states,
                actions=actions,
                returns_to_go=returns_to_go,
            )

            # Padded steps (traj_mask == 0) are excluded from both losses.
            valid = traj_mask.reshape(-1) > 0
            returns_to_go_target = returns_to_go.reshape(-1, 1)[valid]
            returns_to_go_preds = returns_to_go_preds.reshape(-1, 1)[valid]

            # Expectile regression: squared error weighted by |tau| when the target
            # is >= the prediction and by |tau - 1| otherwise, on errors scaled by the
            # mean target magnitude (clamped so all-zero targets do not divide by 0).
            norm = returns_to_go_target.abs().mean().clamp_min(1e-8)
            u = (returns_to_go_target - returns_to_go_preds) / norm
            returns_to_go_loss = torch.mean(
                torch.abs(cfg.tau - (u < 0).float()) * u ** 2
            )

            act_dim = action_preds.shape[-1]
            action_loss = F.mse_loss(
                action_preds.reshape(-1, act_dim)[valid],
                actions.reshape(-1, act_dim)[valid],
            )

            loss = returns_to_go_loss + action_loss

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_norm)
            optimizer.step()
            scheduler.step()

            log_action_losses.append(loss.detach().cpu().item())

        time_elapsed = datetime.now().replace(microsecond=0) - start_time
        total_updates += cfg.num_updates_per_iter
        mean_action_loss = np.mean(log_action_losses)

        results = eval_env(cfg, model, device, render=cfg.render)

        log_str = ("=" * 60 + '\n' +
                   "time elapsed: " + str(time_elapsed) + '\n' +
                   "num of updates: " + str(total_updates) + '\n' +
                   "train action loss: " + format(mean_action_loss, ".5f"))

        success_rate_list.append(results['eval/avg_reward'])
        print(log_str)
        print(results)
        print("eval_reward_list:", success_rate_list)

        # The best return is tracked unconditionally; it is reported after training.
        if results['eval/avg_reward'] >= best_eval_returns:
            best_eval_returns = results['eval/avg_reward']
            if cfg.save_snapshot:
                print("=" * 60)
                print("saving best model!")
                print("=" * 60)
                # Persist the weights this message announces.
                torch.save(model.state_dict(), checkpoint_path / 'best_model.pt')
                print(stars)
                print("total_updates:!!!!!!!!!!!!!!!!!", total_updates)
                print(stars)
                print(stars)
                print("best_eval_returns:!!!!!!!!!!!!!!!!!", best_eval_returns)
                print(stars)

    log_filename = 'log.txt'
    new_log_filename = ('max_return_' if cfg.augment_data else 'no_augment_') + log_filename
    print("log_path:", log_path)

    log_file_path = os.path.join(log_path, new_log_filename)
    os.makedirs(log_path, exist_ok=True)

    with open(log_file_path, 'w', encoding='utf-8') as file:
        for key, value in results.items():
            file.write(f'{key}: {value}\n')

        print(f'Data has been written to {log_file_path}')

    print("=" * 60)
    print("finished training!")
    print("=" * 60)
    end_time = datetime.now().replace(microsecond=0)
    time_elapsed = str(end_time - start_time)
    end_time_str = end_time.strftime("%y-%m-%d-%H-%M-%S")
    print("started training at: " + start_time_str)
    print("finished training at: " + end_time_str)
    print("total training time: " + time_elapsed)
    print(stars)
    print("total_updates:!!!!!!!!!!!!!!!!!", total_updates)
    print(stars)
    print(stars)
    print("eval_reward_list:!!!!!!!!!!!!!!!!!", success_rate_list)
    print(stars)
    print(stars)
    print("best_eval_returns:!!!!!!!!!!!!!!!!!", best_eval_returns)
    print(stars)
    print("=" * 60)

@hydra.main(config_path='cfgs', config_name='max_dt', version_base=None)
def main(cfg: DictConfig):
    """Hydra entry point: optionally start a Weights & Biases run, then train.

    When ``cfg.wandb_log`` is set, the run is logged under the project
    ``cfg.dataset_name``. ``cfg.wandb_dir`` defaults to the Hydra output
    directory when it is None.
    """
    # Local import: only needed to serialise the config for W&B.
    from omegaconf import OmegaConf

    hydra_cfg = hydra.core.hydra_config.HydraConfig.get()

    if cfg.wandb_log:
        if cfg.wandb_dir is None:
            cfg.wandb_dir = hydra_cfg['runtime']['output_dir']

        # dict(cfg) is shallow: nested sections stay DictConfig objects and
        # ${...} interpolations stay unresolved. Hand W&B plain containers.
        wandb.init(
            project=cfg.dataset_name,
            entity=cfg.wandb_entity,
            config=OmegaConf.to_container(cfg, resolve=True),
            dir=cfg.wandb_dir,
            group=cfg.wandb_group_name,
        )
        wandb.run.name = cfg.wandb_run_name

    train(cfg, hydra_cfg)


if __name__ == "__main__":
    # Hydra parses command-line overrides and supplies cfg to main().
    main()