import copy
import time
import os
import queue
import setproctitle

import SMOS
import SMOS_utils
import torch
from core.utils import profile, TimeRecorder, clip_grad_norm

import numpy as np
import torch.optim as optim
import torch.nn.functional as F
import multiprocessing as mp

import zmq
import zlib
from torch.nn import L1Loss
from torch.cuda.amp import autocast as autocast
from torch.cuda.amp import GradScaler as GradScaler

from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist

from core.test import _test
from core.mysgd import create_my_sgd_optimizer
from core.replay_buffer import ReplayBuffer
from core.shared_storage import SharedStorage
from core.selfplay_worker import DataWorker
from core.reanalyze_worker import BatchWorker_GPU, BatchWorker_CPU

from core.storage_config import StorageConfig
from core.replay_buffer import get_replay_buffer
from core.shared_storage import get_shared_storage
from core.watchdog import get_watchdog_server
from core.meta_data_manager import MetaDataSharedMemoryManager

from core.shared_storage import start_shared_storage_server
from core.replay_buffer import start_replay_buffer_server
from core.watchdog import start_watchdog_server
from core.priority import init_priority_array, close_priority_array

from core.selfplay_worker import start_data_worker
from core.reanalyze_worker import start_batch_worker_cpu, start_batch_worker_gpu, start_priority_updater, start_value_updater
from core.test import start_test
from core.watchdog import start_watchdog

from core.model import get_ddp_model_weights

from core.xnode import PacketType, PriorityPacket, SignalPacket, BatchPacket
from core.xnode import batch_sender, batch_receiver, signal_publisher, signal_subscriber,\
    priority_publisher, priority_subscriber, replay_buffer_publisher, replay_buffer_subscriber

# from lamb.optim.lamb import create_lamb_optimizer

# Switch read from the VIRTUAL_TEST environment variable. When the variable is
# unset this is the boolean False; when it is set this is the raw string, so
# ANY non-empty value (including "0" or "false") is truthy.
VIRTUAL_TEST=os.environ.get("VIRTUAL_TEST", False)


# torch.backends.cudnn.enabled = False
# Turn on cuDNN benchmark mode for the whole process (see the
# torch.backends.cudnn documentation for its effect).
torch.backends.cudnn.benchmark = True

# torch.backends.cudnn.benchmark = True
def consist_loss_func(f1, f2):
    """Consistency loss: negative similarity of two L2-normalised feature batches.

    Both inputs are normalised along their last dimension (``eps=1e-5``),
    multiplied element-wise and summed over dimension 1. For 2-D inputs this
    is the negative cosine similarity of each pair of rows.

    Parameters
    ----------
    f1: torch.Tensor
        first batch of features
    f2: torch.Tensor
        second batch of features, multiplied element-wise with ``f1``

    Returns
    -------
    torch.Tensor
        negated sum over dim 1 of the product of the normalised inputs
    """
    unit_f1, unit_f2 = (F.normalize(feat, p=2., dim=-1, eps=1e-5) for feat in (f1, f2))
    similarity = torch.sum(unit_f1 * unit_f2, dim=1)
    return -similarity


def adjust_lr(config, optimizer, step_count):
    """Set the learning rate of every parameter group for the given step.

    While ``step_count < config.lr_warm_step`` the rate grows linearly from 0
    towards ``config.lr_init``. Afterwards it is ``config.lr_init`` multiplied
    by ``config.lr_decay_rate`` once for every ``config.lr_decay_steps`` steps
    completed since the end of warm-up (step decay).

    Parameters
    ----------
    config: Any
        provides lr_init, lr_warm_step, lr_decay_rate and lr_decay_steps
    optimizer: Any
        object whose ``param_groups`` entries get their ``'lr'`` key overwritten
    step_count: int
        current training step

    Returns
    -------
    float
        the learning rate that was written to the parameter groups
    """
    if step_count < config.lr_warm_step:
        # linear warm-up
        new_lr = config.lr_init * step_count / config.lr_warm_step
    else:
        # step decay: one multiplication by lr_decay_rate per completed interval
        completed_decays = (step_count - config.lr_warm_step) // config.lr_decay_steps
        new_lr = config.lr_init * config.lr_decay_rate ** completed_decays

    for group in optimizer.param_groups:
        group['lr'] = new_lr

    return new_lr

def preload_batch(batch_storage_name, smos_client, step_count, rank, max_version_gap=51):
    """Pop the next sufficiently fresh batch from SMOS and move its tensors to the trainer device.

    The batch storage is polled (10 ms sleep between attempts) until a batch is
    available. A popped batch whose mean version gap, i.e. the average of
    ``step_count - batch[3][:, 0]`` and ``step_count - batch[3][:, 1]``,
    exceeds ``max_version_gap`` is dropped: its handle is freed and polling
    continues.

    Parameters
    ----------
    batch_storage_name: str
        name of the SMOS object to pop batches from
    smos_client: Any
        SMOS client used for ``pop_from_object`` / ``free_handle``
    step_count: int
        current training step, used to compute the version gap
    rank: Any
        device the tensors are moved to (passed to ``Tensor.to``); also used in log messages
    max_version_gap: float
        largest accepted mean version gap; defaults to 51, the value that was
        previously hard-coded here

    Returns
    -------
    tuple
        ``(batch, handle)`` where ``batch`` is
        ``(inputs_batch, targets_batch, target_version, worker_node_id)``.
        ``handle`` is the SMOS handle of the popped entry; this function does
        not free it for an accepted batch, so the caller has to.
    """
    # poll until a batch that is fresh enough shows up
    while True:
        status, handle, batch = smos_client.pop_from_object(name=batch_storage_name)
        if not status == SMOS.SMOS_SUCCESS:
            time.sleep(0.01)
            continue
        policy_version_gap = (step_count - batch[3][:, 0].astype(np.int32)).reshape(-1)
        value_version_gap = (step_count - batch[3][:, 1].astype(np.int32)).reshape(-1)
        mean_version_gap = (policy_version_gap.mean() + value_version_gap.mean()) / 2
        if mean_version_gap > max_version_gap:
            print(f"[Trainer {rank}] version gap {mean_version_gap} > {max_version_gap}, drop batch at step {step_count}")
            # a dropped batch must give its storage slot back
            smos_client.free_handle(object_handle=handle)
            continue
        break

    # raw layout: batch[0] is the observation array, batch[1] the list of the
    # remaining inputs, batch[2] the targets, batch[3] the target versions and
    # batch[4] the worker node id
    batch = [[batch[0]] + batch[1], batch[2], batch[3], batch[4]]

    inputs_batch, targets_batch, target_version, worker_node_id = batch
    obs_batch_ori, action_batch, mask_batch, indices, weights_lst, make_time = inputs_batch
    target_value_prefix, target_value, target_policy, target_pred_policy, advantages = targets_batch

    # move the arrays consumed by the model onto the trainer device;
    # indices, make_time, target_pred_policy and advantages are passed through untouched
    obs_batch_ori = torch.from_numpy(obs_batch_ori).to(rank, non_blocking=True)
    action_batch = torch.from_numpy(action_batch).to(rank, non_blocking=True)
    mask_batch = torch.from_numpy(mask_batch).to(rank, non_blocking=True)
    target_value_prefix = torch.from_numpy(target_value_prefix).to(rank, non_blocking=True)
    target_value = torch.from_numpy(target_value).to(rank, non_blocking=True)
    target_policy = torch.from_numpy(target_policy).to(rank, non_blocking=True)
    weights_lst = torch.from_numpy(weights_lst).to(rank, non_blocking=True)

    batch = [obs_batch_ori, action_batch, mask_batch, indices, weights_lst, make_time], [target_value_prefix, target_value, target_policy, target_pred_policy, advantages], target_version, worker_node_id

    return batch, handle

# @profile
def new_update_weights(step_count, rank, model, batch, optimizer, config, scaler, smos_client, storage_config: StorageConfig,
                   vis_result=False):
    """Run one optimisation step on a preloaded batch and optionally collect logging data.

    The batch is unpacked, observations are divided by 255, the scalar targets
    are converted to their categorical representation, the model is called to
    obtain the individual losses, and a single optimizer step is taken.

    Parameters
    ----------
    step_count: int
        current training step; used to throttle console output and to compute version gaps
    rank: Any
        DDP trainer rank
    model: Any
        EfficientZero models; ``model.module`` is accessed when ``vis_result`` is True
    batch: Any
        a batch of data: ``(inputs_batch, targets_batch, target_version, worker_node_id)``
    optimizer: Any
        optimizer whose gradients are zeroed and which is stepped here
    config: Any
        experiment configuration (loss coefficients, transforms, amp_type, ...)
    scaler: Any
        scaler for torch amp; returned as the last element of the result
    smos_client: Any
        not referenced inside this function
    storage_config: StorageConfig
        not referenced inside this function
    vis_result: bool
        True -> log some visualization data in tensorboard (some distributions, values, etc)

    Returns
    -------
    tuple
        ``(loss_data, td_data, priority_data, other_data, scaler)``. The first
        three entries are None when ``vis_result`` is False.
    """
    inputs_batch, targets_batch, target_version, worker_node_id = batch
    obs_batch_ori, action_batch, mask_batch, indices, weights_lst, make_time = inputs_batch
    target_value_prefix, target_value, target_policy, target_pred_policy, advantages = targets_batch

    # [:, 0: config.stacked_observations * 3,:,:]
    # obs_batch_ori is the original observations in a batch
    # obs_batch is the observation for hat s_t (predicted hidden states from dynamics function)
    # obs_target_batch is the observations for s_t (hidden states from representation function)
    # to save GPU memory usage, obs_batch_ori contains (stack + unroll steps) frames
    obs_batch_ori = obs_batch_ori.float() / 255.0
    obs_batch = obs_batch_ori[:, 0: config.stacked_observations * config.image_channel, :, :]
    obs_target_batch = obs_batch_ori[:, config.image_channel:, :, :]

    # use GPU tensor
    action_batch = action_batch.unsqueeze(-1).long()
    mask_batch = mask_batch.float()
    target_value_prefix = target_value_prefix.float()
    target_value = target_value.float()
    target_policy = target_policy.float()
    weights = weights_lst.float()
    # NOTE(review): the weights coming from the batch are replaced by ones here,
    # so every sample contributes equally to the loss below
    weights = torch.ones_like(weights)

    '''if step_count <= 4000:
        target_policy = torch.ones_like(target_policy) / target_policy.shape[-1]'''

    # do augmentations
    if config.use_augmentation:
        obs_batch = config.transform(obs_batch)
        obs_target_batch = config.transform(obs_target_batch)

    batch_size = obs_batch.size(0)
    # print("[Trainer]", batch_size, config.batch_size, target_value_prefix.size(0))
    assert (batch_size == config.batch_size == target_value_prefix.size(0)), (batch_size, config.batch_size, target_value_prefix.size(0))
    metric_loss = torch.nn.L1Loss()

    # transform targets to categorical representation
    transformed_target_value_prefix = config.reward_scalar_transform(target_value_prefix)
    target_value_prefix_phi = config.reward_phi(transformed_target_value_prefix)

    transformed_target_value = config.value_scalar_transform(target_value)
    target_value_phi = config.value_phi(transformed_target_value)

    # -------------------------------------------------------------------------------------------

    # the forward pass (including the loss computation inside the model) runs
    # under autocast regardless of config.amp_type
    with autocast():
        policy_loss, value_loss, value_prefix_loss, consistency_loss, hidden_states, policy_logits, value, value_prefix, train_time_gaps = model(config, rank, obs_batch, action_batch, obs_target_batch, mask_batch, consist_loss_func, target_policy, target_value_phi, target_value_prefix_phi, indices)

    if vis_result:
        scaled_step0_value = config.inverse_value_transform(value[0].detach())
        # some logs preparation
        other_log = {}
        other_dist = {}

        # -1 marks "not computed" for every L1 metric
        other_loss = {
            'l1': -1,
            'l1_1': -1,
            'l1_-1': -1,
            'l1_0': -1,
        }
        for i in range(config.num_unroll_steps):
            key = 'unroll_' + str(i + 1) + '_l1'
            other_loss[key] = -1
            other_loss[key + '_1'] = -1
            other_loss[key + '_-1'] = -1
            other_loss[key + '_0'] = -1
        target_value_prefix_cpu = target_value_prefix.detach().cpu()
        # state_lst = hidden_states[0].detach().cpu().numpy()
        predicted_values = scaled_step0_value.detach().cpu()
        predicted_policies = torch.softmax(policy_logits[0], dim=1).detach().cpu()
        predicted_value_prefixs = []
        for step_i in range(config.num_unroll_steps):
            scaled_value_prefixs_cpu = config.inverse_reward_transform(value_prefix[step_i]).detach().cpu()

            predicted_values = torch.cat([predicted_values, config.inverse_value_transform(value[step_i + 1]).detach().cpu()])
            predicted_value_prefixs.append(scaled_value_prefixs_cpu)
            predicted_policies = torch.cat([predicted_policies, torch.softmax(policy_logits[step_i + 1], dim=1).detach().cpu()])
            # state_lst = np.concatenate([state_lst, hidden_states[step_i + 1].detach().cpu().numpy()])

            key = 'unroll_' + str(step_i + 1) + '_l1'

            # boolean masks selecting the samples whose target at this step is exactly 0, -1 or 1
            value_prefix_indices_0 = (target_value_prefix_cpu[:, step_i].unsqueeze(-1) == 0)
            value_prefix_indices_n1 = (target_value_prefix_cpu[:, step_i].unsqueeze(-1) == -1)
            value_prefix_indices_1 = (target_value_prefix_cpu[:, step_i].unsqueeze(-1) == 1)

            target_value_prefix_base = target_value_prefix_cpu[:, step_i].reshape(-1).unsqueeze(-1)

            other_loss[key] = metric_loss(scaled_value_prefixs_cpu, target_value_prefix_base)
            if value_prefix_indices_1.any():
                other_loss[key + '_1'] = metric_loss(scaled_value_prefixs_cpu[value_prefix_indices_1], target_value_prefix_base[value_prefix_indices_1])
            if value_prefix_indices_n1.any():
                other_loss[key + '_-1'] = metric_loss(scaled_value_prefixs_cpu[value_prefix_indices_n1], target_value_prefix_base[value_prefix_indices_n1])
            if value_prefix_indices_0.any():
                other_loss[key + '_0'] = metric_loss(scaled_value_prefixs_cpu[value_prefix_indices_0], target_value_prefix_base[value_prefix_indices_0])
    
    # --------------------------------------------------------------------------------------

    # weighted loss with masks (some invalid states which are out of trajectory.)
    loss = (config.consistency_coeff * consistency_loss + config.policy_loss_coeff * policy_loss +
            config.value_loss_coeff * value_loss + config.reward_loss_coeff * value_prefix_loss)
    weighted_loss = (weights * loss).mean()

    # backward
    # the hook scales the gradient flowing out of the total loss by 1 / num_unroll_steps
    gradient_scale = 1 / config.num_unroll_steps
    parameters = model.parameters()
    if config.amp_type == 'torch_amp':
        with autocast():
            total_loss = weighted_loss
            total_loss.register_hook(lambda grad: grad * gradient_scale)
    else:
        total_loss = weighted_loss
        total_loss.register_hook(lambda grad: grad * gradient_scale)
    optimizer.zero_grad()

    # NOTE(review): for any amp_type other than 'none' / 'torch_amp' no backward
    # pass is executed here, yet optimizer.step() below still runs
    if config.amp_type == 'none':
        total_loss.backward()
    elif config.amp_type == 'torch_amp':
        scaler.scale(total_loss).backward()
        # unscale before clipping so the norm is measured on the real gradients
        scaler.unscale_(optimizer)

    clip_grad_norm(parameters, config.max_grad_norm)
    if config.amp_type == 'torch_amp':
        scaler.step(optimizer)
        scaler.update()
    else:
        optimizer.step()

    # console summary, printed only on visualisation steps that are multiples of 100
    if vis_result and step_count % 100 == 0:
        import sys
        _msg = '#{:<10} Total Loss: {:<8.3f} [weighted Loss:{:<8.3f} Policy Loss: {:<8.3f} Value Loss: {:<8.3f} ' \
            'Reward Sum Loss: {:<8.3f} Consistency Loss: {:<8.3f} ] \n    ' \
            'Policy Version Gap: ({:<4.1f}, {:<4.1f}, {:<4.1f}) Value Version Gap: ({:<4.1f}, {:<4.1f}, {:<4.1f}) Step 0 value: {:<8.3f}, Time: {}'
        _msg = _msg.format(step_count, total_loss, weighted_loss, policy_loss.mean().item(), value_loss.mean().item(), value_prefix_loss.mean().item(), consistency_loss.mean().item(), 
                            step_count - target_version[:, 0].mean(), (step_count - target_version[:, 0]).min(), (step_count - target_version[:, 0]).max(),
                            step_count - target_version[:, 1].mean(), (step_count - target_version[:, 1]).min(), (step_count - target_version[:, 1]).max(),
                            scaled_step0_value.mean().item(), train_time_gaps)
        print(_msg, flush=True)

        print("policy: ", torch.softmax(policy_logits[0, 0], dim=-1).detach().cpu(), "target policy: ", target_policy.transpose(0, 1)[0, 0].detach().cpu())
        print("Time Gap:", time.time() - np.mean(make_time))

    # packing data for logging
    if vis_result:
        # NOTE(review): the last entry is consistency_loss.mean() without .item(),
        # i.e. a tensor, unlike the other entries - confirm this is intended
        loss_data = (total_loss.item(), weighted_loss.item(), loss.mean().item(), 0, policy_loss.mean().item(),
                    value_prefix_loss.mean().item(), value_loss.mean().item(), consistency_loss.mean())
        # new priority: L1 distance between the step-0 predicted value and its target, plus a small epsilon
        new_priority = L1Loss(reduction='none')(scaled_step0_value.squeeze(-1), target_value[:, 0]).data.cpu().numpy() + config.prioritized_replay_eps
        target_value_prefix_cpu = target_value_prefix.detach().cpu()
        reward_w_dist, representation_mean, dynamic_mean, reward_mean, prediction_mean = model.module.get_params_mean()
        other_dist['reward_weights_dist'] = reward_w_dist
        other_log['representation_weight'] = representation_mean
        other_log['dynamic_weight'] = dynamic_mean
        other_log['reward_weight'] = reward_mean
        other_log['prediction_weight'] = prediction_mean

        # reward l1 loss
        value_prefix_indices_0 = (target_value_prefix_cpu[:, :config.num_unroll_steps].reshape(-1).unsqueeze(-1) == 0)
        value_prefix_indices_n1 = (target_value_prefix_cpu[:, :config.num_unroll_steps].reshape(-1).unsqueeze(-1) == -1)
        value_prefix_indices_1 = (target_value_prefix_cpu[:, :config.num_unroll_steps].reshape(-1).unsqueeze(-1) == 1)

        target_value_prefix_base = target_value_prefix_cpu[:, :config.num_unroll_steps].reshape(-1).unsqueeze(-1)

        # stack the per-step predictions and flatten them to line up with target_value_prefix_base
        predicted_value_prefixs = torch.stack(predicted_value_prefixs).transpose(1, 0).squeeze(-1)
        predicted_value_prefixs = predicted_value_prefixs.reshape(-1).unsqueeze(-1)
        other_loss['l1'] = metric_loss(predicted_value_prefixs, target_value_prefix_base)
        if value_prefix_indices_1.any():
            other_loss['l1_1'] = metric_loss(predicted_value_prefixs[value_prefix_indices_1], target_value_prefix_base[value_prefix_indices_1])
        if value_prefix_indices_n1.any():
            other_loss['l1_-1'] = metric_loss(predicted_value_prefixs[value_prefix_indices_n1], target_value_prefix_base[value_prefix_indices_n1])
        if value_prefix_indices_0.any():
            other_loss['l1_0'] = metric_loss(predicted_value_prefixs[value_prefix_indices_0], target_value_prefix_base[value_prefix_indices_0])

        td_data = (new_priority, target_value_prefix.detach().cpu().numpy(), target_value.detach().cpu().numpy(),
                   transformed_target_value_prefix.detach().cpu().numpy(), transformed_target_value.detach().cpu().numpy(),
                   target_value_prefix_phi.detach().cpu().numpy(), target_value_phi.detach().cpu().numpy(),
                   predicted_value_prefixs.detach().cpu().numpy(), predicted_values.detach().cpu().numpy(),
                   target_policy.detach().cpu().numpy(), predicted_policies.detach().cpu().numpy(), None,
                   other_loss, other_log, other_dist)
        priority_data = (weights, indices)
    else:
        loss_data, td_data, priority_data = None, None, None
    # per-sample policy / value version gaps and the mean of the mask
    other_data = (step_count - target_version[:, 0], step_count - target_version[:, 1],  mask_batch.mean())

    return loss_data, td_data, priority_data, other_data, scaler

def log_worker_status(data_worker_list, cpu_worker_list, gpu_worker_list, priority_updater_list, value_updater_list, 
                      test_worker,
                      batch_sender_list, signal_subscriber_proc, priority_subscriber_proc,
                      watchdog_process,
                      shared_storage_server, replay_buffer_server, watchdog_server, smos_server,
                      replay_buffer_publisher_proc, replay_buffer_subscriber_proc):
    """Print a liveness report of all worker processes and abort on fatal failures.

    After printing the report, everything is shut down and a RuntimeError is
    raised if any of the following holds (checked in this order):

    * at least one data worker is dead,
    * at least one value updater is dead,
    * cpu workers exist but none is alive,
    * gpu workers exist but none is alive,
    * ``test_worker`` is given and dead,
    * the replay buffer publisher / subscriber is given and dead.

    Dead priority updaters or batch senders are reported but do not abort.

    Raises
    ------
    RuntimeError
        when one of the conditions above is met (after the shutdown completed)
    """
    def count_alive(procs):
        # number of processes in procs that report being alive
        return sum(1 for proc in procs if proc.is_alive())

    def stop_alive(procs):
        # terminate only the processes that are still running
        for proc in procs:
            if proc.is_alive():
                proc.terminate()

    # liveness census, one (label, alive, total) row per worker group
    census = [
        ("Alive data workers", count_alive(data_worker_list), len(data_worker_list)),
        ("Alive cpu workers", count_alive(cpu_worker_list), len(cpu_worker_list)),
        ("Alive gpu workers", count_alive(gpu_worker_list), len(gpu_worker_list)),
        ("Alive priority updaters", count_alive(priority_updater_list), len(priority_updater_list)),
        ("Alive value updaters", count_alive(value_updater_list), len(value_updater_list)),
        ("Alive batch senders", count_alive(batch_sender_list), len(batch_sender_list)),
    ]
    (_, alive_data_workers, total_data_workers) = census[0]
    (_, alive_cpu_workers, total_cpu_workers) = census[1]
    (_, alive_gpu_workers, total_gpu_workers) = census[2]
    (_, alive_value_updaters, total_value_updaters) = census[4]

    # log
    print("************************ Worker Status ************************")
    for label, alive, total in census:
        print(f"{label}: {alive}/{total}")
    print(f"Is signal subscriber alive: {signal_subscriber_proc.is_alive()}")
    print(f"Is priority subscriber alive: {priority_subscriber_proc.is_alive()}")
    print("***************************************************************")

    def terminate():
        # stop the worker groups first
        stop_alive(data_worker_list)
        print("[main process] All data workers have terminated.")
        stop_alive(cpu_worker_list)
        print("[main process] All CPU workers have terminated.")
        stop_alive(gpu_worker_list)
        print("[main process] All GPU workers have terminated.")
        stop_alive(priority_updater_list)
        print("[main process] All priority updaters have terminated.")
        stop_alive(value_updater_list)
        print("[main process] All value updaters have terminated.")

        # optional replay buffer bridge processes
        if replay_buffer_publisher_proc is not None and replay_buffer_publisher_proc.is_alive():
            replay_buffer_publisher_proc.terminate()
        print("[main process] Replay buffer publisher has terminated.")
        if replay_buffer_subscriber_proc is not None and replay_buffer_subscriber_proc.is_alive():
            replay_buffer_subscriber_proc.terminate()
        print("[main process] Replay buffer subscriber has terminated.")

        # NOTE(review): join() is called without a timeout, so this waits for
        # the test worker to exit by itself
        if test_worker is not None:
            test_worker.join()
        watchdog_process.terminate()
        print("[main process] All workers have stopped.")

        # terminate xnode workers
        for sender in batch_sender_list:
            sender.terminate()
        priority_subscriber_proc.terminate()
        signal_subscriber_proc.terminate()

        # stop servers
        shared_storage_server.terminate()
        replay_buffer_server.terminate()
        watchdog_server.terminate()
        smos_server.stop()
        print("[main process] All servers have stopped.")

    def abort(reason):
        # shut everything down, then report the failure to the caller
        terminate()
        raise RuntimeError(reason)

    if alive_data_workers < total_data_workers:
        abort("Some data worker died!")

    if alive_value_updaters < total_value_updaters:
        abort("Some value updater died!")

    if total_cpu_workers > 0 and alive_cpu_workers == 0:
        abort("All cpu worker died!")

    if total_gpu_workers > 0 and alive_gpu_workers == 0:
        abort("All gpu worker died!")

    if test_worker is not None and not test_worker.is_alive():
        abort("Test worker died!")

    if replay_buffer_publisher_proc is not None and not replay_buffer_publisher_proc.is_alive():
        abort("Replay buffer publisher died!")

    if replay_buffer_subscriber_proc is not None and not replay_buffer_subscriber_proc.is_alive():
        abort("Replay buffer subscriber died!")


def log_trainer_status(batch_receiver_list, signal_publisher_proc, priority_publisher_proc):
    """Print how many batch receivers are alive and whether both publishers are alive.

    Parameters
    ----------
    batch_receiver_list: list
        processes exposing ``is_alive()``
    signal_publisher_proc: Any
        process exposing ``is_alive()``
    priority_publisher_proc: Any
        process exposing ``is_alive()``
    """
    receivers_total = len(batch_receiver_list)
    receivers_alive = sum(1 for receiver in batch_receiver_list if receiver.is_alive())

    # log
    report = [
        "************************ Worker Status ************************",
        f"Alive batch receivers: {receivers_alive}/{receivers_total}",
        f"Is signal publisher alive: {signal_publisher_proc.is_alive()}",
        f"Is priority publisher alive: {priority_publisher_proc.is_alive()}",
        "***************************************************************",
    ]
    for line in report:
        print(line)

def train_loop(step_count, rank, config, storage_config, preloaded_batch, batch_receiver_list, signal_publisher_proc, priority_publisher_proc, shared_storage, smos_client, optimizer, scaler, model, watchdog_server):
    """Run a single training step and publish the related signals.

    On rank 0 this also prints the trainer status every 100 steps, increments
    the shared step counter, broadcasts the model weights every
    ``config.checkpoint_interval`` steps and pushes a LOG packet every
    ``config.log_interval`` steps. All packets go to
    ``storage_config.signal_queue_name``.

    The SMOS handle of the consumed batch is always released, even when the
    step raises, so a failing step does not leak a batch storage entry.

    Parameters
    ----------
    step_count: int
        index of the step that is executed
    rank: Any
        DDP trainer rank
    config: Any
        experiment configuration
    storage_config: StorageConfig
        storage configuration (signal queue name)
    preloaded_batch: tuple
        ``(batch, handle)`` as returned by ``preload_batch``
    batch_receiver_list, signal_publisher_proc, priority_publisher_proc: Any
        processes whose liveness is reported by ``log_trainer_status``
    shared_storage: Any
        shared storage whose counter is incremented on rank 0
    smos_client: Any
        SMOS client used to push packets and to free the batch handle
    optimizer, scaler, model: Any
        forwarded to ``new_update_weights``
    watchdog_server: Any
        not referenced inside this function
    """
    batch, handle = preloaded_batch

    try:
        # log
        if step_count % 100 == 0 and rank == 0:
            log_trainer_status(batch_receiver_list, signal_publisher_proc, priority_publisher_proc)

        # rank 0 advances the shared step counter and tells the workers about it
        if rank == 0:
            shared_storage.incr_counter()
            increase_counter_packet = SignalPacket(packet_type=PacketType.SIG_INCREASE_COUNTER,
                                                   create_steps=step_count, data=None)
            smos_client.push_to_object(name=storage_config.signal_queue_name, data=[increase_counter_packet])

        # adjust learning rate
        lr = adjust_lr(config, optimizer, step_count)

        # update model for self-play
        if step_count % config.checkpoint_interval == 0 and rank == 0:
            ddp_weights = get_ddp_model_weights(ddp_model=model)
            model_packet = SignalPacket(packet_type=PacketType.MODEL,
                                        create_steps=step_count, data=(ddp_weights, step_count))
            smos_client.push_to_object(name=storage_config.signal_queue_name, data=[model_packet])

        # only rank 0 collects visualisation data, and only every vis_interval steps
        vis_result = bool(step_count % config.vis_interval == 0 and rank == 0)

        # the call is the same for every amp_type; the scaler is updated in place
        log_data = new_update_weights(step_count=step_count, rank=rank, model=model, batch=batch, optimizer=optimizer,
                                      config=config, scaler=scaler, vis_result=vis_result,
                                      storage_config=storage_config, smos_client=smos_client)

        if step_count % config.log_interval == 0 and rank == 0:
            log_packet = SignalPacket(packet_type=PacketType.LOG, create_steps=step_count,
                                      data=[config, step_count, [*log_data[0:4], None], lr, vis_result])
            smos_client.push_to_object(name=storage_config.signal_queue_name, data=[log_packet])
    finally:
        # clean up: give the batch storage entry back whatever happened above
        smos_client.free_handle(object_handle=handle)

def _train(rank, model, target_model, smos_client, shared_storage, config, storage_config: StorageConfig,
           batch_receiver_list=None, signal_publisher_proc=None, priority_publisher_proc=None):
    """Training loop of one DDP trainer process.

    Wraps the model in DistributedDataParallel, builds the optimizer, waits
    (on rank 0) until all worker nodes are ready, then runs
    ``config.training_steps + config.last_steps`` steps. One batch is always
    preloaded ahead of the step being trained.

    Parameters
    ----------
    rank: Any
        DDP trainer rank, also used as the device index
    model: Any
        EfficientZero models
    target_model: Any
        EfficientZero models for reanalyzing; not referenced inside this function
    smos_client: Any
        SMOS client used for batches and signal packets
    shared_storage: Any
        shared storage queried for the number of ready worker nodes
    config: Any
        experiment configuration
    storage_config: StorageConfig
        storage configuration (batch storage / signal queue names)
    batch_receiver_list, signal_publisher_proc, priority_publisher_proc: Any
        forwarded to ``train_loop`` for status logging

    Returns
    -------
    Any
        the final model weights as returned by ``get_ddp_model_weights``

    Raises
    ------
    RuntimeError
        if ``config.optimizer`` is neither "sgd" nor "lamb"
    """
    time.sleep(10)
    # ----------------------------------------------------------------------------------
    model.to(rank)
    model = DDP(model, device_ids=[rank])

    if config.optimizer == "sgd":
        optimizer = optim.SGD(model.parameters(), lr=config.lr_init, momentum=config.momentum,
                              weight_decay=config.weight_decay)
    elif config.optimizer == "lamb":
        # NOTE(review): the import of create_lamb_optimizer is commented out at the
        # top of the file - confirm the name is defined before enabling "lamb"
        optimizer = create_lamb_optimizer(
            model, lr=config.lr_init, weight_decay=config.weight_decay, bias_correction=True, rank=rank)
    else:
        raise RuntimeError(f"Optimizer {config.optimizer} is not supported.")

    scaler = GradScaler()

    model.train()
    # ----------------------------------------------------------------------------------
    # set augmentation tools
    if config.use_augmentation:
        config.set_transforms()

    # get watchdog server (only rank 0 reports progress to it)
    watchdog_server = get_watchdog_server(storage_config) if rank == 0 else None

    # wait for all worker nodes to start
    print(f"[Trainer {rank}]", VIRTUAL_TEST, shared_storage.get_num_ready_worker_nodes(), config.num_worker_nodes)
    if rank == 0 and not VIRTUAL_TEST:
        while shared_storage.get_num_ready_worker_nodes() < config.num_worker_nodes:
            time.sleep(5.)
            print(f"[Trainer {rank}] Not all nodes are ready.", shared_storage.get_num_ready_worker_nodes(), config.num_worker_nodes)
        print(f"[Trainer {rank}] All worker nodes are ready. Start training.")
        start_training_packet = SignalPacket(packet_type=PacketType.START_TRAINING, create_steps=0, data=None)
        smos_client.push_to_object(name=storage_config.signal_queue_name, data=[start_training_packet])
    # the other ranks wait here until rank 0 has sent the start signal
    dist.barrier()

    time.sleep(30)

    print(f"[Trainer {rank}] Start training.", flush=True)

    # running statistics for the periodic console log
    total_time, step_count = 0, 0
    total_version_gap = 0
    total_batch_time = 0

    # preload batches: `batch` is trained next, `preloaded_batch` the step after
    batch_storage_name = storage_config.batch_storage_name + f"{rank % storage_config.batch_storage_count_trainer}"
    batch = preload_batch(batch_storage_name, smos_client, step_count, rank)
    preloaded_batch = preload_batch(batch_storage_name, smos_client, step_count, rank)

    while step_count < config.training_steps + config.last_steps:
        start_time = time.time()
        time.sleep(0.1)

        train_loop(step_count, rank, config, storage_config, batch, batch_receiver_list, signal_publisher_proc, priority_publisher_proc, shared_storage, smos_client, optimizer, scaler, model, watchdog_server)

        # increase training step
        step_count += 1

        # preload a batch; after the last step nothing will consume another
        # one, so do not fetch it
        batch = preloaded_batch
        preloaded_batch = None
        if step_count < config.training_steps + config.last_steps:
            preload_start = time.time()
            preloaded_batch = preload_batch(batch_storage_name, smos_client, step_count, rank)
            total_batch_time += time.time() - preload_start
            # same mean policy/value version gap that preload_batch uses for its staleness check
            target_version = preloaded_batch[0][2]
            total_version_gap += ((step_count - target_version[:, 0].astype(np.int32)).mean()
                                  + (step_count - target_version[:, 1].astype(np.int32)).mean()) / 2

        # save models
        if step_count % config.save_ckpt_interval == 0 and rank == 0:
            model_path = os.path.join(config.model_dir, 'model_{}.p'.format(step_count))
            torch.save(model.state_dict(), model_path)

        # update watchdog server
        if rank == 0:
            watchdog_server.increase_training_step_count()

        # log training status
        end_time = time.time()
        total_time += end_time - start_time
        if step_count % 100 == 0:
            _, batch_queue_size = smos_client.get_entry_count(name=batch_storage_name)
            print('[Trainer {}] loop={}, Avg. Tloop={:.5f}, Lst. Tloop={:.5f}, batch storage size={}, Avg. Ver. Gap={:.2f}, Avg. Batch={:.2f}'
                    .format(rank, step_count, total_time / step_count, end_time - start_time, batch_queue_size, total_version_gap / step_count, total_batch_time / step_count))

    # release the batches that were preloaded but never trained on
    for pending in (batch, preloaded_batch):
        if pending is not None:
            smos_client.free_handle(object_handle=pending[1])

    ddp_weights = get_ddp_model_weights(ddp_model=model)
    return ddp_weights


def start_train(rank, model, target_model, config, storage_config: StorageConfig,
                batch_receiver_list=None, signal_publisher_proc=None, priority_publisher_proc=None):
    """Run the trainer for ``rank`` in the current process.

    Opens an SMOS client and the shared storage for this process, announces
    the trainer on stdout and hands over to ``_train``.

    Returns
    -------
    Any
        whatever ``_train`` returns (the final model weights)
    """
    # per-process connections to the shared services
    client = SMOS.Client(connection=storage_config.smos_connection)
    storage = get_shared_storage(storage_config=storage_config)

    print(f"[Trainer] Start trainer {rank} at process {os.getpid()}.")

    # start trainer
    return _train(
        rank=rank,
        model=model,
        target_model=target_model,
        smos_client=client,
        shared_storage=storage,
        config=config,
        storage_config=storage_config,
        batch_receiver_list=batch_receiver_list,
        signal_publisher_proc=signal_publisher_proc,
        priority_publisher_proc=priority_publisher_proc,
    )



def initialize_trainer(config, model_path=None, local_rank=-1):
    """Set up one trainer rank, run training, and return the results.

    Every rank builds the model and runs ``start_train``. Rank 0 additionally acts
    as the master trainer: before training it starts the SMOS server, the watchdog
    server, the shared storage server, the batch receivers, the signal and priority
    publishers and the watchdog, and it creates the batch / signal / priority
    storages. After training, the distributed process group is destroyed.

    Parameters
    ----------
    config: Any
        experiment configuration
    model_path: str
        model path for resuming (weights are only loaded when ``local_rank == 0``)
        default: train from scratch
    local_rank: Any
        local rank of DDP

    Returns
    -------
    tuple
        ``(model, final_weights, terminate_all)``. On rank 0 ``terminate_all`` is a
        zero-argument callable that stops the helper processes started here; it is
        NOT called by this function. On every other rank it is ``None``.
    """
    # initialize model
    model = config.get_uniform_network(is_trainer=True)
    model.to(local_rank)
    model.set_device(config.device)
    target_model = None
    if model_path and local_rank == 0:
        print('resume model from path: ', model_path)
        weights = torch.load(model_path)

        model.load_state_dict(weights)
    print(f"[main process] trainer {local_rank} will start at {os.getpid()}")

    # set process title
    setproctitle.setproctitle(f"EfficientZero-trainer_worker{local_rank}")

    # initialize storage
    storage_config = StorageConfig()
    if local_rank == 0:
        # ------------------------------ Storages ------------------------------
        # metadata storage
        # NOTE(review): not used again in this function; presumably the reference is
        # held to keep the shared memory alive -- confirm before removing.
        meta_data_manager = MetaDataSharedMemoryManager(config, storage_config, create=True)

        # mp context
        ctx = mp.get_context('spawn')

        # start server if it's master trainer
        smos_server = SMOS.Server(connection=storage_config.smos_connection)
        smos_server.start()
        print("[main process] SMOS server has been started from main process.")

        # watchdog server
        watchdog_server = ctx.Process(target=start_watchdog_server, args=(storage_config,))
        watchdog_server.start()
        print("[main process] Watchdog server has been started from main process.")

        # create storages
        smos_client = SMOS.Client(connection=storage_config.smos_connection)
        # batch storage
        for batch_storage_idx in range(storage_config.batch_storage_count_trainer):
            batch_storage_name = storage_config.batch_storage_name + f"{batch_storage_idx}"
            smos_client.create_object(name=batch_storage_name, max_capacity=storage_config.batch_storage_capacity_trainer,
                                      track_count=5, block_size=storage_config.batch_storage_block_size_list + [128])
        print("[main process] Batch storage has been initialized.")

        # xnode storages
        smos_client.create_object(name=storage_config.signal_queue_name,
                                  max_capacity=storage_config.signal_queue_capacity,
                                  track_count=1, block_size=storage_config.signal_queue_block_size_list)
        smos_client.create_object(name=storage_config.priority_queue_name,
                                  max_capacity=storage_config.priority_queue_capacity,
                                  track_count=4, block_size=storage_config.priority_queue_block_size_list)

        # shared_storage server
        time.sleep(0.1)
        shared_storage_server = ctx.Process(target=start_shared_storage_server,
                                            args=(config, storage_config, None, None))
        shared_storage_server.start()
        print("[main process] Shared storage server has been started from main process.")
        time.sleep(0.1)

        # ------------------------------ Xnode ------------------------------
        # batch receiver
        batch_receiver_list = [ctx.Process(target=batch_receiver,
                                           args=(receiver_id, config, storage_config))
                               for receiver_id in range(0, storage_config.batch_receiver_count)]
        for batch_receiver_proc in batch_receiver_list:
            batch_receiver_proc.start()
            time.sleep(0.1)
        print("[main process] Batch receivers have all been launched.")

        # signal publisher
        signal_publisher_proc = ctx.Process(target=signal_publisher, args=(storage_config, ))
        signal_publisher_proc.start()
        time.sleep(0.1)
        print("[main process] Signal publisher has been launched.")

        # priority publisher
        priority_publisher_proc = ctx.Process(target=priority_publisher, args=(storage_config, ))
        priority_publisher_proc.start()
        time.sleep(0.1)
        print("[main process] Priority publisher has been launched.")

        # ------------------------------ Training ------------------------------
        # watchdog
        watchdog_process = ctx.Process(target=start_watchdog, args=(config, storage_config, "trainer"))
        watchdog_process.start()
        print("[main process] Watchdog has been launched.")

        # start training
        final_weights = start_train(rank=local_rank, model=model, target_model=target_model, config=config,
                                    storage_config=storage_config, batch_receiver_list=batch_receiver_list,
                                    signal_publisher_proc=signal_publisher_proc,
                                    priority_publisher_proc=priority_publisher_proc)
        time.sleep(15)

        # ------------------------------ Clean up ------------------------------

        def terminate_all():
            """Stop every helper process / server started by the master trainer."""
            # terminate all xnode workers
            for batch_receiver_proc in batch_receiver_list:
                batch_receiver_proc.terminate()
            signal_publisher_proc.terminate()
            priority_publisher_proc.terminate()

            # stop watchdog, shared storage server and smos.
            # The shared storage server used to be left running here, which leaked
            # a non-daemon child process; the worker node already terminates its own.
            watchdog_process.terminate()
            shared_storage_server.terminate()
            watchdog_server.terminate()
            smos_server.stop()

    else:
        # start training
        final_weights = start_train(rank=local_rank, model=model, target_model=target_model, config=config,
                                    storage_config=storage_config)

        # non-master ranks start no helper processes, so there is nothing to stop
        terminate_all = None

    # clean up and return
    dist.destroy_process_group()
    return model, final_weights, terminate_all


def initialize_worker(config, exp_path, model_path=None):
    """
    Initialize a worker node, keep it alive until training ends, then shut it down.

    Steps, in order:
      1. build the model / target model (optionally resuming from ``model_path``);
      2. start the SMOS server and create the MCTS, batch, weights, priority,
         reanalyze-control, replay-buffer and zombie-queue storages;
      3. start the shared storage, replay buffer and watchdog servers;
      4. start the cross-node (xnode) processes: replay buffer publisher or
         subscriber, signal subscriber, priority subscriber and batch senders;
      5. start data workers, CPU/GPU batch workers, priority/value updaters and
         the watchdog;
      6. log worker status every 30 seconds until the shared training-step
         counter reaches ``config.training_steps + config.last_steps``;
      7. join the data workers and terminate everything else.

    Parameters
    ----------
    config: Any
        experiment configuration
    exp_path: Any
        experiment path, forwarded unchanged to the signal subscriber process
    model_path: str
        model path for resuming
        default: start from freshly initialized weights
    """

    def launch_staggered(procs):
        """Start each process in ``procs``, pausing 0.1s after each start."""
        for proc in procs:
            proc.start()
            time.sleep(0.1)

    # ------------------------------ Init ------------------------------
    # initialize model
    model = config.get_uniform_network(is_data_worker=True)
    target_model = config.get_uniform_network(is_data_worker=True)
    if model_path:
        print('resume model from path: ', model_path)
        weights = torch.load(model_path)

        model.load_state_dict(weights)
        target_model.load_state_dict(weights)

    # initialize storage config and multiprocessing context
    storage_config = StorageConfig(worker_id=config.worker_node_id)

    config.update_config(storage_config)

    ctx = mp.get_context('spawn')

    # ------------------------------ Storages ------------------------------
    # metadata storage
    # NOTE(review): not used again in this function; presumably the reference is
    # held to keep the shared memory alive -- confirm before removing.
    meta_data_manager = MetaDataSharedMemoryManager(config, storage_config, create=True)

    # smos server
    smos_server = SMOS.Server(connection=storage_config.smos_connection)
    smos_server.start()
    print("[main process] SMOS server has been started from main process.")

    # mcts storage and batch storage
    smos_client = SMOS.Client(connection=storage_config.smos_connection)
    for mcts_idx in range(storage_config.mcts_storage_count):
        mcts_storage_name = storage_config.mcts_storage_name + f"{mcts_idx}"
        smos_client.create_object(name=mcts_storage_name, max_capacity=storage_config.mcts_storage_capacity,
                                  track_count=9, block_size=storage_config.mcts_block_size_list)
    print("[main process] MCTS storage has been initialized.")
    for batch_storage_idx in range(storage_config.batch_storage_count_worker):
        batch_storage_name = storage_config.batch_storage_name + f"{batch_storage_idx}"
        smos_client.create_object(name=batch_storage_name, max_capacity=storage_config.batch_storage_capacity_worker,
                                  track_count=6, block_size=storage_config.batch_storage_block_size_list_worker)
    print("[main process] Batch storage has been initialized.")

    # weights storage
    for weights_storage_idx in range(storage_config.weights_storage_cycle_size):
        weights_storage_name = storage_config.weights_storage_name + str(weights_storage_idx)
        smos_client.create_object(name=weights_storage_name, max_capacity=8, track_count=2, block_size=[256 * (1024 ** 2), 1024])
    print("[main process] Weights storage has been initialized.")

    # priority storage
    for priority_idx in range(storage_config.priority_storage_count):
        priority_storage_name = storage_config.priority_storage_name + f"{priority_idx}"
        smos_client.create_object(name=priority_storage_name, max_capacity=storage_config.priority_storage_capacity, track_count=1, block_size=[256 * 1024])
    print("[main process] Priority storage has been initialized.")

    # control storage: pre-filled with the integers 0..capacity-1
    smos_client.create_object(name=storage_config.reanalyze_control_storage_name, max_capacity=storage_config.reanalyze_control_storage_capacity, track_count=1, block_size=[1024])
    for i in range(storage_config.reanalyze_control_storage_capacity):
        status, _ = smos_client.push_to_object(name=storage_config.reanalyze_control_storage_name,
                                               data=[i])
        assert status == SMOS.SMOS_SUCCESS
    print("[main process] Reanalyze control storage has been initialized.")

    time.sleep(0.1)
    # shared storage server
    shared_storage_server = ctx.Process(target=start_shared_storage_server,
                                        args=(config, storage_config, model, target_model))
    shared_storage_server.start()
    print("[main process] Shared storage server has been started from main process.")

    time.sleep(0.1)

    # replay buffer
    smos_client.create_object(name=storage_config.replay_buffer_name,
                              max_capacity=storage_config.replay_buffer_capacity,
                              track_count=6, block_size=storage_config.replay_buffer_block_size_list)
    smos_client.create_object(name=storage_config.zombie_queue_name,
                              max_capacity=storage_config.zombie_queue_capacity,
                              track_count=1, block_size=storage_config.zombie_queue_block_size)
    replay_buffer_server = ctx.Process(target=start_replay_buffer_server,
                                       args=(storage_config, config))
    replay_buffer_server.start()
    print("[main process] Replay buffer server has been started from main process.")

    # ------------------------------ Watchdog ------------------------------

    # watchdog server
    watchdog_server = ctx.Process(target=start_watchdog_server, args=(storage_config,))
    watchdog_server.start()
    print("[main process] Watchdog server has been started from main process.")

    # ------------------------------ Xnode ------------------------------
    # worker node id, taken from the config
    worker_node_id = config.worker_node_id

    if storage_config.is_data_worker:
        # data worker node
        replay_buffer_publisher_proc = ctx.Process(target=replay_buffer_publisher, args=(storage_config, config))
        replay_buffer_subscriber_proc = None
        replay_buffer_publisher_proc.start()
        time.sleep(0.1)
        print("[main process] Replay buffer publisher has been launched.")
    else:
        # normal worker node
        time.sleep(1.)
        replay_buffer_publisher_proc = None
        replay_buffer_subscriber_proc = ctx.Process(target=replay_buffer_subscriber, args=(worker_node_id, storage_config, config))
        replay_buffer_subscriber_proc.start()
        time.sleep(0.1)
        print("[main process] Replay buffer subscriber has been launched.")

    # signal subscriber
    signal_subscriber_proc = ctx.Process(target=signal_subscriber,
                                         args=(exp_path, config, storage_config))
    signal_subscriber_proc.start()
    time.sleep(0.1)
    print("[main process] Signal subscriber has been launched.")

    # priority subscriber
    priority_subscriber_proc = ctx.Process(target=priority_subscriber,
                                           args=(worker_node_id, config, storage_config))
    priority_subscriber_proc.start()
    time.sleep(0.1)
    print("[main process] Priority subscriber has been launched.")

    # batch sender
    batch_sender_list = [ctx.Process(target=batch_sender,
                                     args=(sender_id, worker_node_id, config.batch_size // config.reanalyze_batch_size, config, storage_config))
                         for sender_id in range(0, storage_config.per_worker_batch_sender_count)]
    launch_staggered(batch_sender_list)
    print("[main process] Batch senders have all been launched.")

    # ------------------------------ Workers ------------------------------
    # data workers
    data_workers = [ctx.Process(target=start_data_worker, args=(rank, config, storage_config))
                    for rank in range(0, config.num_actors)]
    launch_staggered(data_workers)
    print("[main process] Data workers have all been launched.")

    # cpu workers
    cpu_workers = [ctx.Process(target=start_batch_worker_cpu, args=(worker_idx, config, storage_config))
                   for worker_idx in range(config.cpu_actor)]
    launch_staggered(cpu_workers)
    print("[main process] CPU batch workers have all been launched.", config.worker_node_id, config.cpu_actor)

    # gpu workers
    gpu_workers = [ctx.Process(target=start_batch_worker_gpu, args=(worker_idx, config, storage_config))
                   for worker_idx in range(config.gpu_actor)]
    launch_staggered(gpu_workers)
    print("[main process] GPU batch workers have all been launched.")

    # priority updater
    priority_updaters = [ctx.Process(target=start_priority_updater, args=(worker_idx, config, storage_config))
                         for worker_idx in range(config.priority_updater)]
    launch_staggered(priority_updaters)
    print("[main process] Priority updaters have all been launched.")

    # value updater
    value_updaters = [ctx.Process(target=start_value_updater, args=(worker_idx, config, storage_config))
                      for worker_idx in range(config.value_updater)]
    launch_staggered(value_updaters)
    print("[main process] Value updaters have all been launched.")

    # watchdog
    watchdog_process = ctx.Process(target=start_watchdog, args=(config, storage_config, "worker"))
    watchdog_process.start()
    print("[main process] Watchdog has been launched.")

    # test
    # test_process = ctx.Process(target=start_test, args=(config, storage_config, False, None, None, False))
    # test_process.start()
    # print("[main process] Test process has been launched.")

    time.sleep(30)

    # ------------------------------ Logging ------------------------------
    shared_storage = get_shared_storage(storage_config=storage_config)
    while True:
        # ending condition
        trained_steps = shared_storage.get_counter()
        if trained_steps >= config.training_steps + config.last_steps:
            time.sleep(60)
            break

        # logging
        # The subscriber handle is now passed as the subscriber; previously the
        # publisher handle was passed for both keyword arguments.
        log_worker_status(data_worker_list=data_workers, gpu_worker_list=gpu_workers, priority_updater_list=priority_updaters, value_updater_list=value_updaters,
                          # cpu_worker_list=cpu_workers, test_worker=test_process,
                          cpu_worker_list=cpu_workers, test_worker=None,
                          batch_sender_list=batch_sender_list, signal_subscriber_proc=signal_subscriber_proc,
                          priority_subscriber_proc=priority_subscriber_proc, watchdog_process=watchdog_process,
                          shared_storage_server=shared_storage_server, replay_buffer_server=replay_buffer_server, watchdog_server=watchdog_server, smos_server=smos_server,
                          replay_buffer_publisher_proc=replay_buffer_publisher_proc, replay_buffer_subscriber_proc=replay_buffer_subscriber_proc)
        time.sleep(30)

    # ------------------------------ Clean up ------------------------------
    # wait for data workers to finish on their own; everything else is terminated
    for data_worker in data_workers:
        data_worker.join()
    print("[main process] All data workers have terminated.")
    for cpu_worker in cpu_workers:
        cpu_worker.terminate()
    print("[main process] All CPU workers have terminated.")
    for gpu_worker in gpu_workers:
        gpu_worker.terminate()
    print("[main process] All GPU workers have terminated.")
    # test_process.join()
    for priority_updater in priority_updaters:
        priority_updater.terminate()
    print("[main process] All priority updaters have terminated.")
    for value_updater in value_updaters:
        # terminate each process; the list itself has no terminate()
        value_updater.terminate()
    print("[main process] All value updaters have terminated.")
    watchdog_process.terminate()
    print("[main process] All workers have stopped.")

    # terminate xnode workers
    for batch_sender_proc in batch_sender_list:
        batch_sender_proc.terminate()
    priority_subscriber_proc.terminate()
    signal_subscriber_proc.terminate()
    # exactly one of these is a process, depending on the node type; it used to be leaked
    for replay_buffer_proc in (replay_buffer_publisher_proc, replay_buffer_subscriber_proc):
        if replay_buffer_proc is not None:
            replay_buffer_proc.terminate()

    # stop servers
    shared_storage_server.terminate()
    replay_buffer_server.terminate()
    watchdog_server.terminate()
    smos_server.stop()
    print("[main process] All servers have stopped.")
