import functools
import glob
import logging
import math
import os
import tempfile
import time
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
from typing import Any, Callable, Dict, Tuple

import hydra
import mlflow
import torch
import torch.distributed
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
from data import ReplayBufferConfig, TTDatasetConfig

from data.data_lib import AIGDatasetConfig
from data.dataset import (
    AIG_Dataset_Collection,
    AIGBatchSampler,
    DistributedBatchSampler,
    EvalAIGBatchSampler,
)
from data.pyaig.aig_env import AIGEnv, AIGEnvConfig
from loss.loss_lib import LossConfig

from misc_utils import time_formatter
from model.model_lib import ActivationConfig, ModelConfig
from model.utils import (
    kldiv_activation,
    load_snapshot,
    mlflow_load_snapshot,
    mlflow_save_snapshot,
    normalize_action,
    save_snapshot,
    Snapshot,
    tanh_activation,
)
from omegaconf import DictConfig, ListConfig, MISSING, OmegaConf
from optimizer.optimizer_lib import OptimizerConfig
from rl import generate_AIG
from rl.her import (
    AIGNegateTarget,
    AIGRewardTransform,
    AIGSubGoalAssigner,
    HERConfig,
    HERSubGoalSampler,
    HindsightExperienceReplayTransform,
)

from rl.mcts_policy import (
    ActionExplorationModule,
    AlphaZeroConfig,
    AlphaZeroExpansionStrategy,
    DirichletNoiseModule,
    MctsPolicy,
    PuctSelectionPolicy,
    SimulatedSearchPolicy,
    UpdateTreeStrategy,
)
from rl.utils import FineTuningCollate, get_actor_value_model, make_alpha_zero_actor
from scheduler.scheduler_lib import SchedulerConfig

from tensordict import LazyStackedTensorDict, TensorDict
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from torch.utils.tensorboard.writer import SummaryWriter

from torchmetrics import MeanMetric
from torchmetrics.functional.classification import binary_hamming_distance
from torchrl._utils import accept_remote_rref_invocation
from torchrl.data.replay_buffers import RemoteReplayBuffer
from torchrl.data.replay_buffers.samplers import SliceSampler
from torchrl.data.replay_buffers.storages import LazyMemmapStorage, ListStorage
from torchrl.data.replay_buffers.writers import RoundRobinWriter
from torchrl.envs.transforms import Reward2GoTransform
from tqdm import tqdm


@dataclass
class FineTuningConfig:
    """Structured configuration for the distributed fine-tuning run.

    Fields whose default is ``MISSING`` carry no usable default and are
    expected to be filled in by the caller's configuration.
    """

    # Model training configs
    model: ModelConfig = MISSING
    action_activation: ActivationConfig = MISSING
    target_activation: ActivationConfig = MISSING
    value_activation: ActivationConfig = MISSING
    value_loss: LossConfig = MISSING
    policy_loss: LossConfig = MISSING
    optimizer: OptimizerConfig = MISSING
    scheduler: SchedulerConfig = MISSING

    # hparams
    # NOTE: `epochs` used to be declared twice in this class (10, then 3).
    # The later declaration silently won, so 3 is kept as the single,
    # effective default and the field keeps its original position.
    epochs: int = 3
    batch_size: int = 64
    use_amp: bool = False
    grad_norm_clip: float = 1.0
    seed: int = 2024
    # number of batches sampled from the replay buffer per epoch
    training_steps: int = 10000
    # discount factor used for the reward-to-go and the HER reward transform
    gamma: float = 0.95
    # maximum number of elements held by the replay buffer
    capacity: int = 10000
    # the trainer broadcasts its weights every this many training steps
    weight_sharing_frequency: int = 200

    # logging
    save_every: int = 1
    log_dir: str | None = None
    model_file: str | None = None
    step_logging: int = 500
    mlflow_uri: str = "http://localhost:5001"

    # Load pretrained model
    pretrained: bool = False
    path: str | None = None
    config_path: str | None = None
    model_path: str | None = None
    model_name: str | None = None

    # Policy
    AZ: AlphaZeroConfig = MISSING
    HER: HERConfig = MISSING

    # Environment
    env: AIGEnvConfig = MISSING

    # data
    dataset: TTDatasetConfig = MISSING
    buffer: ReplayBufferConfig = MISSING
    # training waits until the replay buffer holds at least this many elements
    start_training_buffer_size: int = 100
    # forwarded to the replay buffer as its `prefetch` argument
    nprefetch_batch: int = 10

    # used as `max_steps` for both the search policy and the rollout
    max_nodes: int = 30

    # Distributed Utilities
    nnodes: int = 1

    master_addr: str = "localhost"
    master_port_rpc: str = "29500"
    # port of the TCP rendezvous used for the weight-sharing process group
    master_port_params: str = "29501"
    backend: str = "nccl"

    # Trainers
    trainer_global_world_size: int = 1
    trainer_local_world_size: int = 1
    trainer_rank_offset: int = 0

    # Replay buffer
    replay_buffer_global_world_size: int = 1
    replay_buffer_local_world_size: int = 1
    replay_buffer_rank_offset: int = MISSING

    # Data collectors
    data_collectors_global_world_size: int = 7
    data_collectors_local_world_size: int = 7
    data_collectors_rank_offset: int = 1

    # RPC
    # attempts per data collector before the trainer gives up connecting
    retry_limit: int = 2
    # seconds slept between connection attempts
    retry_delay_sec: int = 3

    debug: bool = False


class CollectorNode:
    """Data collector node responsible for collecting experiences used for learning.

    Each collector owns a local copy of the model, runs MCTS-guided rollouts
    on its slice of the dataset and pushes the resulting trajectories to the
    remote replay buffer.

    Args:
        cfg (FineTuningConfig): the configuration object for the fine-tuning process
        replay_buffer (rpc.RRef): the RRef associated with the construction of the replay buffer
    """

    def __init__(
        self,
        cfg: FineTuningConfig,
        replay_buffer: rpc.RRef,
    ) -> None:
        self.id = rpc.get_worker_info().id
        self.device = torch.device("cpu")
        self.set_local_device()
        self.cfg = cfg
        self.replay_buffer = replay_buffer
        self.logger = logging.getLogger(__name__ + f"{self.id}")
        self.model = hydra.utils.instantiate(self.cfg.model).to(self.device)
        self.model.eval()
        # Entries of a rollout that are forwarded to the replay buffer
        self.keys = [
            "nodes",
            "action_dist",
            "action_mask",
            "target",
            "num_inputs",
            "reward",
        ]
        # Each collector works on its own slice of the dataset, selected by its RPC worker id
        self.dataset = hydra.utils.instantiate(
            self.cfg.dataset, embedding_size=self.model.embedding_size
        ).get_slice(self.id, self.cfg.data_collectors_global_world_size)
        self.stop_collect_flag = False
        self.logger.info("Data Collector Node constructed")

    def set_local_device(self):
        """Select this worker's device from the ``DEVICE_MAP`` environment variable.

        ``DEVICE_MAP`` is a comma-separated list of ``rank:device_id`` pairs.
        When this worker's id is listed, ``self.device`` becomes the matching
        CUDA device and ``CUDA_VISIBLE_DEVICES`` is set to that id; otherwise
        ``self.device`` is left untouched.
        """
        device_map = {}
        device_map_str = os.environ.get("DEVICE_MAP", "")
        if device_map_str != "":
            for device in device_map_str.split(","):
                rank, device_id = device.split(":")
                device_map[int(rank)] = device_id
        if self.id in device_map:
            # NOTE(review): the device index is used both as `cuda:<id>` and as
            # the only visible device — confirm this is the intended numbering.
            self.device = torch.device(f"cuda:{device_map[self.id]}")
            os.environ["CUDA_VISIBLE_DEVICES"] = device_map[self.id]

    def extend_buffer(self, td: TensorDict) -> rpc.RRef:
        """Send ``td`` to the remote replay buffer.

        Returns:
            rpc.RRef: the reference returned by the remote ``extend`` call.
        """
        return rpc.remote(
            self.replay_buffer.owner(),
            ReplayBufferNode.extend,
            args=(
                self.replay_buffer,
                td,
            ),
            timeout=120.0,
        )

    @accept_remote_rref_invocation
    @torch.no_grad()
    def collect(self):
        """Run experience collection until :meth:`stop_collect` is called.

        For every dataset element an AlphaZero-style search policy is built,
        a rollout of at most ``cfg.max_nodes`` steps is executed, the MCTS
        visit counts are turned into an action distribution, rewards are
        discounted and the selected keys are pushed to the replay buffer.

        `accept_remote_rref_invocation` enables this method to be invoked remotely provided the class instantiation
        `rpc.RRef` is provided in place of the object reference.
        """
        # Create an environment
        aig_env = AIGEnv(
            embedding_size=self.cfg.model.embedding_size,
            const_node=self.cfg.env.const_node,
            # reward_type=self.cfg.env.reward_type,
            reward_type="simple",
        ).to(self.device)
        aig_env.state = aig_env.state.to(self.device)

        # Create an actor value agent
        actor_value_agent = get_actor_value_model(self.model)

        # Create a HER transformer (only consumed by the HER augmentation at
        # the end of the loop, which is currently disabled)
        her_transform = HindsightExperienceReplayTransform(
            SubGoalSampler=HERSubGoalSampler(self.cfg.HER.num_sub_goals),
            SubGoalAssigner=AIGSubGoalAssigner(),
            RewardTransform=AIGRewardTransform(self.cfg.gamma),
            PostTransaform=AIGNegateTarget()
            if self.cfg.HER.allow_inv_experiences
            else None,
        )

        self.stop_collect_flag = False
        while not self.stop_collect_flag:
            for td in self.dataset:
                # Honour stop requests between samples: a single pass over the
                # dataset runs one search-driven rollout per element, so only
                # checking the flag once per pass would delay shutdown.
                if self.stop_collect_flag:
                    break

                td = td.to(self.device, non_blocking=True)
                aig_env.reset(td)

                # Define AlphaZero policy
                tree_strategy = UpdateTreeStrategy(
                    value_network=actor_value_agent.get_value_operator(),
                    use_value_network=self.cfg.AZ.use_value_network,
                )

                expansion_strategy = AlphaZeroExpansionStrategy(
                    policy_module=actor_value_agent.get_policy_operator(),
                )

                selection_strategy = PuctSelectionPolicy(self.cfg.AZ.c_puct)

                exploration_strategy = ActionExplorationModule()

                mcts_policy = MctsPolicy(
                    expansion_strategy=expansion_strategy,
                    selection_strategy=selection_strategy,
                    exploration_strategy=exploration_strategy,
                )

                if self.cfg.AZ.dirichlet_alpha is not None:
                    noise_module = DirichletNoiseModule(self.cfg.AZ.dirichlet_alpha)
                else:
                    noise_module = None

                policy = SimulatedSearchPolicy(
                    policy=mcts_policy,
                    tree_updater=tree_strategy,
                    env=aig_env,
                    num_simulations=self.cfg.AZ.num_simulations,
                    simulation_max_steps=self.cfg.AZ.simulation_max_steps,
                    max_steps=self.cfg.max_nodes,
                    noise_module=noise_module,
                    reutilize_tree=self.cfg.AZ.reutilize_tree,
                )

                # Execute the policy
                rollout = aig_env.rollout(
                    policy=policy, max_steps=self.cfg.max_nodes, return_contiguous=False
                )

                # Compute the action distribution from the MCTS nodes
                mcts_nodes = LazyStackedTensorDict(*policy.root_list)

                for i in range(rollout.batch_size[0]):  # type: ignore
                    action_count = mcts_nodes[i]["children_visits"]
                    rollout[i]["action_dist"] = action_count / torch.sum(action_count)  # type: ignore

                # Unfinished episode: the last reward becomes minus the Hamming
                # distance between the last generated row and the target,
                # folded with its complement via min(d, 1 - d).
                if not rollout[-1]["next", "terminated"]:  # type: ignore
                    last_generated_tt = rollout[-1]["nodes"][-1, :]  # type: ignore
                    reward = binary_hamming_distance(last_generated_tt.unsqueeze(0), rollout[-1]["target"], multidim_average="samplewise").item()  # type: ignore
                    rollout[-1]["next", "reward"][0] = -min(reward, 1 - reward)  # type: ignore
                    # rollout["next", "reward"][-1] = -1.0

                discounted_reward = Reward2GoTransform(
                    gamma=self.cfg.gamma,
                    in_keys=("next", "reward"),  # type: ignore
                    out_keys="reward",  # type: ignore
                )
                rollout = discounted_reward.inv(rollout)  # type: ignore
                rollout[0]["reward"][0] = 0.0  # type: ignore
                self.extend_buffer(rollout.select(*self.keys).to("cpu"))  # type: ignore

                # HER
                # if not rollout[-1]["next", "terminated"]:  # type: ignore
                #     her_experiences = her_transform.her_augmentation(rollout)  # type: ignore
                #     self.extend_buffer(her_experiences.select(*self.keys).to("cpu"))  # type: ignore

    @accept_remote_rref_invocation
    def stop_collect(self):
        """Ask the collection loop started by :meth:`collect` to stop."""
        self.stop_collect_flag = True

    @accept_remote_rref_invocation
    def _init_parameter_sharing(self):
        """Join the process group used to receive model weights from the trainer.

        The rank is this worker's RPC id and the world size is the number of
        trainers plus the number of data collectors.
        """
        # NOTE(review): `init_process_group` is documented to return None, in
        # which case `self.group` is None and the default group is used by the
        # collectives below — confirm.
        self.group = torch.distributed.init_process_group(
            backend=self.cfg.backend,
            init_method=f"tcp://{self.cfg.master_addr}:{self.cfg.master_port_params}",
            rank=rpc.get_worker_info().id,
            world_size=(
                self.cfg.trainer_global_world_size
                + self.cfg.data_collectors_global_world_size
            ),
            device_id=self.device if self.device.type == "cuda" else None,  # type: ignore
        )

    @accept_remote_rref_invocation
    def cleanup(self):
        """Destroy the weight-sharing process group."""
        torch.distributed.destroy_process_group(self.group)

    @accept_remote_rref_invocation
    def _receive_weights(self, block=False):
        """Receive the model parameters broadcast from rank 0.

        Args:
            block: when True, wait on a group barrier after issuing the
                (asynchronous) broadcasts.
        """
        for param in self.model.parameters():
            torch.distributed.broadcast(param.data, 0, async_op=True, group=self.group)
        if block:
            torch.distributed.barrier(group=self.group)


class ReplayBufferNode(RemoteReplayBuffer):
    """Experience replay buffer node that is capable of accepting remote connections. Being a `RemoteReplayBuffer`
    means all of its public methods are remotely invokable using `torch.rpc`.

    The buffer currently stores elements in a `ListStorage` written in round-robin order.
    Using a LazyMemmapStorage is highly advised in distributed settings with shared storage due to the lower serialisation
    cost of MemoryMappedTensors as well as the ability to specify file storage locations which can improve ability to recover from node failures.

    Args:
        capacity (int): the maximum number of elements that can be stored in the replay buffer.
        batch_size (int): the batch size forwarded to the underlying replay buffer.
        prefetch (int): the `prefetch` value forwarded to the underlying replay buffer.
        trainer_device_id (int | None): CUDA index of the trainer device; ``None`` selects the CPU.
    """

    def __init__(
        self,
        capacity: int,
        batch_size: int,
        prefetch: int,
        trainer_device_id: int | None,
    ):
        trainer_device = (
            torch.device(f"cuda:{trainer_device_id}")
            if trainer_device_id is not None
            else torch.device("cpu")
        )
        # The collate function is the only component that receives the trainer
        # device; no sampling transform is installed.
        # presumably FineTuningCollate places batches on `trainer_device` — confirm.
        super().__init__(
            storage=ListStorage(
                max_size=capacity,
            ),
            writer=RoundRobinWriter(),
            collate_fn=FineTuningCollate(trainer_device),
            batch_size=batch_size,
            prefetch=prefetch,
        )


class FineTuningNode:
    """Trainer node responsible for learning from experiences sampled from an experience replay buffer."""

    def __init__(
        self,
        cfg: FineTuningConfig,
        replay_buffer_node: str = "ReplayBuffer",
    ) -> None:
        """Bring up the trainer: remote buffer, collectors, model and optimizer.

        The statements below are order-dependent: the replay buffer and the
        data collectors are created remotely, the weight-sharing process group
        is joined, the model is built (and optionally restored), its weights
        are broadcast, and only then is collection started.

        Args:
            cfg: the fine-tuning configuration.
            replay_buffer_node: RPC worker name hosting the replay buffer.
        """
        self.cfg = cfg
        self.replay_buffer_node = replay_buffer_node
        self.id = rpc.get_worker_info().id
        self.device = torch.device("cpu")
        self.set_local_device()
        self.start_time: float = time.time()

        # parameters
        self.save_every = self.cfg.save_every
        self.n_total_epochs = self.cfg.epochs
        self.n_seen_points = 0
        self.epochs_run = 0
        # self.return_action_mask = self.cfg.return_action_mask
        self.step_logging = self.cfg.step_logging
        self.weight_sharing_frequency = self.cfg.weight_sharing_frequency
        self.log_dir = self.cfg.log_dir
        self.model_file = self.cfg.model_file
        self.local_rank = self.id
        self.bar_pos = self.id

        # logging
        self.logger = logging.getLogger(__name__ + f"[{self.id}]")
        self.policy_loss_tracker = MeanMetric(sync_on_compute=False)
        self.value_loss_tracker = MeanMetric(sync_on_compute=False)
        self.reward_tracker = MeanMetric(sync_on_compute=False)

        # launch replay buffer and collectors
        self.replay_buffer = self._create_replay_buffer()
        self._create_data_collectors()
        self._init_parameter_sharing()

        # instantiate model
        self.model = hydra.utils.instantiate(self.cfg.model)

        # load model from checkpoint (only when a log directory was configured)
        if self.log_dir is not None:
            self._load_snapshot()

        # no log directory configured: derive a fresh, versioned one from the model
        if self.log_dir is None:
            self.log_dir = self._get_dir_name(self.model)
        self.writer = SummaryWriter(log_dir=self.log_dir)
        self.writer.add_text("Config", OmegaConf.to_yaml(cfg))
        OmegaConf.save(cfg, f"{self.writer.log_dir}/config.yaml")
        # HTTP_PROXY is cleared for this process before mlflow is configured
        # (presumably so the tracking server is reached directly — confirm)
        os.environ["HTTP_PROXY"] = ""
        mlflow.set_tracking_uri(self.cfg.mlflow_uri)
        mlflow.set_experiment(
            f"Circuit Generation [{int(math.log2(self.cfg.model.embedding_size))}-input]"
        )  # Model name embedding size

        # Broadcast model weights
        self.model = self.model.to(self.device)
        self.share_weights(block=True)

        # Start collection
        self._launch_data_collectors()

        # Create training stuff
        self.optimizer = hydra.utils.instantiate(
            self.cfg.optimizer, params=self.model.parameters()
        )
        self.lr_scheduler = hydra.utils.instantiate(
            self.cfg.scheduler, optimizer=self.optimizer
        )

        # NOTE(review): the activations are hard-coded here; the
        # `*_activation` entries of the config are not read in this method.
        self.action_activation = kldiv_activation
        self.target_activation = normalize_action
        self.value_activation = tanh_activation

        self.policy_criterion = hydra.utils.instantiate(self.cfg.policy_loss)
        self.value_criterion = hydra.utils.instantiate(self.cfg.value_loss)

        # initialize train states
        self.policy_criterion.to(self.device)
        self.value_criterion.to(self.device)
        self.policy_loss_tracker.to(self.device)
        self.value_loss_tracker.to(self.device)
        self.reward_tracker.to(self.device)
        # the gradient scaler only exists when mixed precision is enabled
        if self.cfg.use_amp:
            self.scaler = torch.amp.GradScaler()  # type: ignore

    @staticmethod
    def mlflow_run_context(func: Callable):
        def wrapper(self, *args, **kwargs):
            if self.is_main_process():
                raw_model = (
                    self.model.module if hasattr(self.model, "module") else self.model
                )
                with mlflow.start_run(
                    run_name=f"{raw_model.__class__.__name__}_heads={self.cfg.model.n_heads}",
                    nested=True,
                    tags={
                        "embedding_size": str(self.cfg.model.embedding_size),
                        "n_heads": str(self.cfg.model.n_heads),
                        "num_layers": str(self.cfg.model.n_layers),
                    },
                    log_system_metrics=True,
                ):
                    self._log_params()
                    mlflow.log_param(
                        "model_parameters",
                        sum(p.numel() for p in raw_model.parameters()),
                    )
                    mlflow.log_param("finetuned", True)
                    mlflow.log_text(OmegaConf.to_yaml(self.cfg), "config.yaml")
                    mlflow.log_text(self.log_dir, "model_path.txt")
                    # mlflow.log_input(self.train_dataset, "training")
                    # mlflow.log_input(self.eval_dataset, "evaluation")

                    mlflow.pytorch.log_model(
                        raw_model,
                        "models",
                        registered_model_name=f"{raw_model.__class__.__name__}[{self.cfg.model.embedding_size}]",
                    )
                    return func(self, *args, **kwargs)
            else:
                return func(self, *args, **kwargs)

        return wrapper

    @mlflow_run_context
    def train(self) -> None:
        """Run the training loop from the first pending epoch to the last.

        Training starts once the replay buffer is sufficiently filled. Each
        epoch runs the training steps, logs the epoch metrics, broadcasts the
        new weights to the collectors, optionally saves a snapshot and steps
        the learning-rate scheduler.
        """
        self.await_replay_buffer()

        self.logger.info("Starting Training")
        # Epochs are numbered from 1; resume right after the last epoch run
        first_epoch = self.epochs_run + 1
        last_epoch = self.n_total_epochs
        for epoch in range(first_epoch, last_epoch + 1):
            self.model.train()
            self._run_epoch(epoch, train=True)

            # log train info
            self._log_info(epoch, True)

            # broadcast new model weights
            self.share_weights(block=True)

            # save train model
            snapshot_due = epoch % self.save_every == 0
            if snapshot_due:
                self._save_snapshot(epoch)

            # NOTE: the evaluation epoch and the periodic generation test are
            # currently disabled; only the training pass runs.

            # adjust learning rate
            self.lr_scheduler.step()

    def await_replay_buffer(self, poll_interval_sec: float = 1.0) -> None:
        """Block until the replay buffer is large enough to start training.

        The buffer length is polled over RPC until it reaches
        ``cfg.start_training_buffer_size``.

        Args:
            poll_interval_sec: seconds slept between two polls, so the buffer
                node is not flooded with back-to-back length requests while
                the collectors are still filling it.
        """
        while True:
            buffer_len = rpc.rpc_sync(
                self.replay_buffer.owner(),
                ReplayBufferNode.__len__,
                args=(self.replay_buffer,),
            )
            if buffer_len >= self.cfg.start_training_buffer_size:
                return
            time.sleep(poll_interval_sec)

    def _model_loss(self, td: TensorDict) -> torch.Tensor:
        """Compute the combined policy + value loss for one batch.

        The loss trackers (policy, value, reward) are updated as a side effect.

        Args:
            td: batch providing "nodes", "causal_mask", "attention_mask",
                "action_dist", "action_mask" and "reward".

        Returns:
            The sum of the policy loss and the value loss.
        """
        # forward pass
        action_logits, value = self.model(
            td["nodes"], td["causal_mask"], td["attention_mask"], get_value=True
        )

        # normalize target
        tgt = self.target_activation(td["action_dist"])

        # filter out invalid actions: masked-out entries get the smallest
        # representable value before the log-softmax
        action = (
            action_logits.flatten(1)
            .masked_fill_(~td["action_mask"], torch.finfo(action_logits.dtype).min)
            .log_softmax(dim=-1)
        )
        # action = self.action_activation(action_logits)

        # compute policy loss
        policy_loss = self.policy_criterion(action, tgt)
        self.policy_loss_tracker.update(policy_loss)

        # Scale value
        value = self.value_activation(value)

        # compute value loss
        value_loss = self.value_criterion(value, td["reward"].unsqueeze(-1))
        self.value_loss_tracker.update(value_loss)
        self.reward_tracker.update(td["reward"].mean())

        # Out-of-place sum: an in-place `+=` on an alias of `policy_loss`
        # would also overwrite the policy loss tensor itself.
        return policy_loss + value_loss

    def _run_epoch(
        self,
        epoch: int,
        train: bool = True,
    ):
        """Run ``cfg.training_steps`` batches sampled from the replay buffer.

        The trackers are reset at the start of the epoch. Weights are shared
        with the collectors every ``weight_sharing_frequency`` steps, the
        progress-bar description is refreshed every 100 steps and step-level
        metrics are logged every ``step_logging`` steps while training.

        Args:
            epoch: the (1-based) epoch number, used for display only.
            train: whether the batches are used for optimization.
        """
        for tracker in (
            self.policy_loss_tracker,
            self.value_loss_tracker,
            self.reward_tracker,
        ):
            tracker.reset()

        n_steps = self.cfg.training_steps
        initial_desc = self._get_desc(epoch, train, False)

        with tqdm(unit="batch", total=n_steps, desc=initial_desc, position=0) as bar:

            def show_losses() -> None:
                # Rebuild the description with the current tracker values
                bar.set_description(self._get_desc(epoch, train, True))
                bar.refresh()

            for step in range(1, n_steps + 1):
                sampled = rpc.rpc_sync(
                    self.replay_buffer.owner(),
                    ReplayBufferNode.sample,
                    args=(self.replay_buffer, None),
                )
                self._run_batch(sampled, train)
                bar.update()

                if step % self.weight_sharing_frequency == 0:
                    self.share_weights(block=True)

                # update bar every 100 batches
                if step % 100 == 0:
                    show_losses()

                if train and step % self.step_logging == 0:
                    self._step_log_info()

            # log final loss
            show_losses()

    def _run_batch(self, td: TensorDict, train: bool):
        """Run the forward pass on one batch and, when training, one optimizer step.

        Args:
            td: the sampled batch; it is moved to ``self.device`` first.
            train: when False only the forward pass (and tracker update) runs.
        """
        # Only the forward pass runs under autocast; the backward pass and the
        # optimizer step are performed outside of it.
        with torch.set_grad_enabled(train), torch.amp.autocast(  # type: ignore
            "cuda", dtype=torch.float16, enabled=(self.cfg.use_amp)  # type: ignore
        ):
            # model forward pass
            td = td.to(self.device, non_blocking=True)
            loss = self._model_loss(td)

        if not train:
            return

        # optimization step
        self.n_seen_points += len(td["nodes"]) * self.cfg.trainer_global_world_size
        self.optimizer.zero_grad(set_to_none=True)
        if self.cfg.use_amp:
            self.scaler.scale(loss).backward()
            # Gradients must be unscaled before clipping, otherwise the clip
            # threshold is compared against gradients multiplied by the scale.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), self.cfg.grad_norm_clip
            )
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), self.cfg.grad_norm_clip
            )
            self.optimizer.step()

    def _create_replay_buffer(self) -> rpc.RRef:
        """Construct the replay buffer on its RPC worker and return its RRef.

        Any failure is logged and the attempt is repeated after
        ``cfg.retry_delay_sec`` seconds; there is no upper bound on the number
        of attempts.
        """
        while True:
            try:
                buffer_worker = rpc.get_worker_info(self.replay_buffer_node)
                ctor_args = (
                    self.cfg.capacity,
                    self.cfg.batch_size,
                    self.cfg.nprefetch_batch,
                    self.device.index,
                )
                remote_buffer = rpc.remote(
                    buffer_worker,
                    ReplayBufferNode,
                    args=ctor_args,
                    timeout=120.0,
                )
                self.logger.info(f"Connected to replay buffer {buffer_worker}")
                return remote_buffer
            except Exception as err:
                self.logger.info(f"Failed to connect to replay buffer: {err}")
                time.sleep(self.cfg.retry_delay_sec)

    def _create_data_collectors(self) -> None:
        """Construct a ``CollectorNode`` on every data-collector RPC worker.

        Workers are looked up by name (``DataCollector<n>``) for ``n`` in the
        configured rank range. Each worker gets up to ``cfg.retry_limit``
        attempts, separated by ``cfg.retry_delay_sec`` seconds, to allow
        collectors to join dynamically.

        Raises:
            RuntimeError: when a collector cannot be reached within the
                configured number of attempts; the last error is chained.
        """
        self.data_collectors = []
        self.data_collector_infos = []

        first_rank = self.cfg.data_collectors_rank_offset
        n_collectors = self.cfg.data_collectors_global_world_size
        retry_limit = self.cfg.retry_limit

        # discover launched data collector nodes (with retry to allow collectors to dynamically join)
        def connect(n, retry):
            data_collector_info = rpc.get_worker_info(f"DataCollector{n}")
            self.logger.info(
                f"Data collector info: {data_collector_info}-retry={retry}"
            )
            dc_ref = rpc.remote(
                data_collector_info,
                CollectorNode,
                args=(
                    self.cfg,
                    self.replay_buffer,
                ),
                timeout=120.0,
            )
            self.logger.info("Finished connecting")
            self.data_collectors.append(dc_ref)
            self.data_collector_infos.append(data_collector_info)

        for n in range(first_rank, first_rank + n_collectors):
            last_error = None
            for retry in range(retry_limit):
                self.logger.info(
                    f"Connecting to DataCollector{n-first_rank+1}/{n_collectors}, retry={retry}"
                )
                try:
                    connect(n, retry)
                    break
                except Exception as e:
                    last_error = e
                    # Report the worker by the same name used for the lookup
                    self.logger.warning(
                        f"Failed to connect to DataCollector{n} "
                        f"(attempt {retry + 1}/{retry_limit}, err={e})"
                    )
                    time.sleep(self.cfg.retry_delay_sec)
            else:
                raise RuntimeError(
                    f"Could not connect to DataCollector{n} "
                    f"after {retry_limit} attempts"
                ) from last_error

    def _init_parameter_sharing(self):
        """Set up the process group used to broadcast weights to the collectors.

        Every collector is first asked (asynchronously) to join the group,
        then the trainer joins it itself using its RPC worker id as rank.
        """
        # The remote calls are fire-and-forget: their futures are not awaited
        for remote_collector, worker in zip(
            self.data_collectors, self.data_collector_infos
        ):
            rpc.rpc_async(
                worker,
                CollectorNode._init_parameter_sharing,
                args=(remote_collector,),
            )

        n_ranks = (
            self.cfg.trainer_global_world_size
            + self.cfg.data_collectors_global_world_size
        )
        rendezvous = f"tcp://{self.cfg.master_addr}:{self.cfg.master_port_params}"
        bound_device = self.device if self.device.type == "cuda" else None

        # NOTE(review): `init_process_group` is documented to return None, in
        # which case `self.group` is None and later collectives fall back to
        # the default group — confirm.
        self.group = torch.distributed.init_process_group(
            backend=self.cfg.backend,
            init_method=rendezvous,
            rank=rpc.get_worker_info().id,
            world_size=n_ranks,
            device_id=bound_device,
        )

    def _launch_data_collectors(self) -> None:
        """Start the collection loop on every data collector.

        The remote calls are issued without waiting for their results.
        """
        targets = zip(self.data_collectors, self.data_collector_infos)
        for remote_collector, worker in targets:
            rpc.remote(
                worker,
                CollectorNode.collect,
                args=(remote_collector,),
                timeout=120.0,
            )

    def set_local_device(self):
        """Pick this worker's device from the ``DEVICE_MAP`` environment variable.

        ``DEVICE_MAP`` holds comma-separated ``rank:device_id`` pairs. A worker
        whose id is listed gets the matching CUDA device and has
        ``CUDA_VISIBLE_DEVICES`` set to that id; any other worker uses the CPU.
        """
        raw_map = os.environ.get("DEVICE_MAP", "")
        rank_to_device = {}
        if raw_map != "":
            for entry in raw_map.split(","):
                rank, device_id = entry.split(":")
                rank_to_device[int(rank)] = device_id

        if self.id in rank_to_device:
            assigned = rank_to_device[self.id]
            self.device = torch.device(f"cuda:{assigned}")
            os.environ["CUDA_VISIBLE_DEVICES"] = assigned
        else:
            self.device = torch.device("cpu")

    def stop_collect(self):
        """Ask every data collector to stop collecting.

        The requests are sent asynchronously and their futures are not awaited.
        """
        targets = zip(self.data_collectors, self.data_collector_infos)
        for remote_collector, worker in targets:
            rpc.rpc_async(
                worker,
                CollectorNode.stop_collect,
                args=(remote_collector,),
            )

    def share_weights(self, block: bool = True):
        """Broadcast the current model weights to the data collectors.

        Args:
            block: when True, wait on a group barrier after the broadcasts
                have been issued and the collectors have been notified.
        """
        # Issue one asynchronous broadcast per parameter, with rank 0 as source
        for weights in (p.data for p in self.model.parameters()):
            torch.distributed.broadcast(weights, 0, async_op=True, group=self.group)

        # Asynchronously ask every collector to start receiving the weights
        for remote_collector, worker in zip(
            self.data_collectors, self.data_collector_infos
        ):
            rpc.rpc_async(
                worker,
                CollectorNode._receive_weights,
                args=(remote_collector, block),
            )

        if not block:
            return

        # Blocking mode: meet the collectors at a barrier
        torch.distributed.barrier(group=self.group)

    def cleanup(self):
        """Tear down the weight-sharing process group on all nodes.

        Collectors are asked asynchronously to clean up, then the trainer
        destroys its own process group.
        """
        targets = zip(self.data_collectors, self.data_collector_infos)
        for remote_collector, worker in targets:
            rpc.rpc_async(
                worker,
                CollectorNode.cleanup,
                args=(remote_collector,),
            )
        torch.distributed.destroy_process_group(self.group)

    ######################### Logging Methods ###########################
    @staticmethod
    def main_process_only(func: Callable):
        def wrapper(self, *args, **kwargs):
            if self.is_main_process():
                return func(self, *args, **kwargs)

        return wrapper

    def _get_desc(
        self,
        epoch: int,
        train: bool,
        include_loss: bool = False,
    ) -> str:
        step_type = "Train" if train else "Eval "

        desc = (
            f"[{step_type}][Rank: {self.local_rank}][Epoch: {epoch:4} / {self.n_total_epochs}]"
            f"[Total time: {self._get_elapsed_time()}]"
        )
        if include_loss:
            desc += f"[Policy Loss: {round(self.policy_loss_tracker.compute().item(), 3):3f}]"

            desc += (
                f"[Value Loss: {round(self.value_loss_tracker.compute().item(), 3):3f}]"
            )

            desc += f"[Reward: {round(self.reward_tracker.compute().item(), 3):3f}]"

        return desc

    def _get_elapsed_time(self) -> str:
        """Return the time elapsed since ``self.start_time``, formatted without milliseconds."""
        elapsed_sec = time.time() - self.start_time
        return time_formatter(elapsed_sec, show_ms=False)

    @staticmethod
    def _get_dir_name(model: torch.nn.Module) -> str:
        i = 0
        dir_name = (
            f"runs/fine_{model.__class__.__name__}"
            f"_emb={model.embedding_size}_"
            f"heads={model.n_heads}_v"
        )

        def aux_get_dirname() -> str:
            return f"{dir_name}{i:04d}"

        while os.path.exists(aux_get_dirname()):
            i += 1
        return aux_get_dirname()

    @main_process_only
    def _log_info(self, epoch: int, train: bool):
        """Log the epoch-level metrics to TensorBoard and mlflow.

        Args:
            epoch: the epoch number, used as the logging step.
            train: selects the "train" or "eval" metric suffix.
        """
        step_type = "train" if train else "eval"

        # Each metric is logged under the same tag, with its own value, in
        # both back-ends (the value loss used to be sent to mlflow with the
        # policy loss value).
        epoch_metrics = {
            "Policy_Loss": self.policy_loss_tracker.compute().item(),
            "Value_Loss": self.value_loss_tracker.compute().item(),
            "Reward": self.reward_tracker.compute().item(),
        }
        for name, value in epoch_metrics.items():
            tag = f"{name}/{step_type}"
            self.writer.add_scalar(tag, value, epoch)
            mlflow.log_metric(tag, value, epoch)

    @main_process_only
    def _step_log_info(self):
        """Log the running training metrics against the number of seen points.

        The current policy loss, value loss and reward tracker values are
        written to the TensorBoard writer and to MLflow, using
        ``self.n_seen_points`` as the step.
        """
        step = self.n_seen_points

        # (TensorBoard tag, MLflow key, tracker). The tag strings are kept
        # exactly as they were so existing run histories stay continuous.
        # NOTE(review): the policy-loss TensorBoard tag reads "train(total)"
        # while the other two read "train(x-total)" - confirm whether that
        # difference is intentional.
        series = (
            (
                "Policy_Loss/train(total)",
                "Policy_Loss/train x-total",
                self.policy_loss_tracker,
            ),
            (
                "Value_Loss/train(x-total)",
                "Value_Loss/train x-total",
                self.value_loss_tracker,
            ),
            (
                "Reward/train(x-total)",
                "Reward/train x-total",
                self.reward_tracker,
            ),
        )
        for writer_tag, mlflow_key, tracker in series:
            value = tracker.compute().item()
            self.writer.add_scalar(writer_tag, value, step)
            # Each metric logs its own value to MLflow; previously the
            # Value_Loss key received the policy loss by mistake.
            mlflow.log_metric(mlflow_key, value, step)

    @main_process_only
    def _log_hparams(self):
        self.writer.add_hparams(
            {
                "lr": self.cfg.optimizer.lr,
                "embedding_size": self.cfg.model.embedding_size,
                "n_heads": self.cfg.model.n_heads,  # type: ignore
                "num_layers": self.cfg.model.n_layers,  # type: ignore
                "batch size": self.cfg.batch_size,
                "epochs": self.cfg.epochs,
            },
            {
                "hparam/policy_loss": self.policy_loss_tracker.compute().item(),
                "hparam/value_loss": self.value_loss_tracker.compute().item(),
            },
        )  # type: ignore

    def _log_params(self):
        """Log every leaf of ``self.cfg.model`` as an MLflow parameter.

        Nested ``DictConfig`` / ``ListConfig`` values are flattened into
        dotted keys (``parent.child`` for mappings, ``parent.<index>`` for
        lists) and each leaf value is passed to ``mlflow.log_param``.

        NOTE(review): unlike the other logging helpers in this section, this
        method is not wrapped in ``main_process_only`` - confirm that callers
        guard it where needed.
        """

        def _explore_recursive(parent_name, element):
            # Recurse uniformly into both container kinds, so a config nested
            # inside a list is flattened the same way as one nested inside a
            # mapping instead of being logged as one opaque value.
            if isinstance(element, DictConfig):
                for key, value in element.items():
                    _explore_recursive(f"{parent_name}.{key}", value)
            elif isinstance(element, ListConfig):
                for index, value in enumerate(element):
                    _explore_recursive(f"{parent_name}.{index}", value)
            else:
                mlflow.log_param(parent_name, element)

        for param_name, element in self.cfg.model.items():  # type: ignore
            _explore_recursive(param_name, element)

    def is_main_process(self) -> bool:
        """Report whether this node is the main process (``self.id`` is 0)."""
        main_process_id = 0
        return self.id == main_process_id

    def _load_snapshot(self):
        if self.model_file == "latest":
            model_path = max(glob.glob(os.path.join(self.log_dir, "*.pt")), key=os.path.getmtime)  # type: ignore
        else:
            model_path = os.path.join(self.log_dir, self.model_file)  # type: ignore

        try:
            snapshot = load_snapshot(model_path)
        except FileNotFoundError:
            self.logger.info("Snapshot not found. Training model from scratch")
            return

        self.model.load_state_dict(snapshot.model_state)
        self.optimizer.load_state_dict(snapshot.optimizer_state)
        self.epochs_run = snapshot.finished_epoch
        self.logger.info(f"Loaded model from {model_path} at Epoch {self.epochs_run}")

    @main_process_only
    def _save_snapshot(self, epoch, name: str | None = None):
        """Write a snapshot of the model and optimizer into ``self.log_dir``.

        Args:
            epoch: Epoch passed on to the snapshot helper; also used in the
                default file name ``model_<epoch>.pt``.
            name: Optional file stem. When given, the snapshot is stored as
                ``<name>.pt`` instead of the epoch-based default.
        """
        file_stem = f"model_{epoch}" if name is None else name
        path = os.path.join(self.log_dir, f"{file_stem}.pt")  # type: ignore

        # Snapshots go through mlflow_save_snapshot rather than the plain
        # save_snapshot helper.
        mlflow_save_snapshot(self.model, self.optimizer, epoch, path)

        self.logger.info(f"Saved model at {path}")

    ######################### Testing Methods ###########################

    @staticmethod
    def test_truth_tables(
        model: torch.nn.Module,
        data_queue: mp.Queue,
        result_queue: mp.Queue,
        cfg,
    ):
        results = 0
        while not data_queue.empty():
            td = data_queue.get().clone()
            aig_env = AIGEnv(model.embedding_size, cfg.const_node, cfg.reward_type)
            aig_env.reset(td)
            success = generate_AIG(model, aig_env, cfg.max_nodes, cfg.AZ)
            if success:
                results += 1
        result_queue.put(results)

    @staticmethod
    def generation_evaluation_wrapper(
        model: torch.nn.Module,
        data_queue: mp.Queue,
        cfg: FineTuningConfig,
        epoch: int,
        q_size: int,
        log_dir: str,
    ) -> None:
        """Run the truth-table evaluation and log the resulting accuracy.

        ``test_truth_tables`` is run either in ``cfg.AZ_workers`` separate
        processes (when more than one worker is configured) or directly in
        the current process. The per-worker success counts are summed and
        ``success / q_size`` is logged as "Accuracy" to TensorBoard (in
        ``log_dir``) and to MLflow, with ``epoch`` as the step.

        Args:
            model: Model passed to every worker.
            data_queue: Queue holding the evaluation inputs.
            cfg: Configuration; ``AZ_workers`` selects the worker count and
                the object is forwarded to the workers.
            epoch: Step under which the accuracy is logged.
            q_size: Number of inputs originally placed on ``data_queue``.
            log_dir: Directory for the TensorBoard event file.
        """
        results_queue = mp.Queue()
        worker_args = (model, data_queue, results_queue, cfg)

        if cfg.AZ_workers > 1:  # type: ignore
            processes = []
            for _ in range(cfg.AZ_workers):  # type: ignore
                p = mp.Process(
                    target=FineTuningNode.test_truth_tables,
                    args=worker_args,
                )
                p.start()
                processes.append(p)

            for p in processes:
                p.join()

            # Only workers that exited cleanly are known to have put a result.
            n_results = sum(p.exitcode == 0 for p in processes)
            if n_results < len(processes):
                logging.getLogger(__name__).warning(
                    "%d of %d evaluation workers exited abnormally; "
                    "their results are ignored",
                    len(processes) - n_results,
                    len(processes),
                )
        else:
            FineTuningNode.test_truth_tables(*worker_args)
            n_results = 1

        # Take exactly one result per finished worker with a blocking get().
        # Polling empty() is unreliable here: put() hands the item to a
        # background feeder thread, so empty() may still be True right after
        # the in-process call above returns.
        success = sum(results_queue.get() for _ in range(n_results))

        # Avoid ZeroDivisionError when there was nothing to evaluate.
        accuracy = success / q_size if q_size else 0.0

        # The context manager closes the writer so the event is flushed
        # before this (possibly short-lived) process ends.
        with SummaryWriter(log_dir) as writer:
            writer.add_scalar("Accuracy", accuracy, epoch)
        mlflow.log_metric("Accuracy", accuracy, epoch)

    def generation_evaluation(
        self,
        epoch: int,
    ) -> None:
        """Start a background generation evaluation for ``epoch`` (currently disabled).

        NOTE(review): the unconditional ``return`` below makes this method a
        no-op; everything after it is unreachable and is kept only as the
        disabled implementation.

        If re-enabled, the code below would move ``self.model`` to the CPU in
        shared memory and eval mode, put every item of ``self.generation_tds``
        on a fork-context queue, start a process running
        ``generation_evaluation_wrapper`` and append that process to
        ``self.generation_processes``.

        Args:
            epoch: Epoch forwarded to ``generation_evaluation_wrapper``.
        """
        return

        # Unwrap a wrapper exposing ``.module`` (e.g. DDP) to get the raw model.
        raw_model = self.model.module if hasattr(self.model, "module") else self.model
        self.model.load_state_dict(raw_model.state_dict())
        self.model.to("cpu").share_memory().eval()

        # spawn can be problematic
        ctx = mp.get_context("fork")
        data_queue = ctx.Queue()

        # Enqueue every evaluation input before the worker process starts.
        for tt in self.generation_tds:
            data_queue.put_nowait(tt)

        # Queue length is passed along as the denominator for the accuracy.
        q_size = len(self.generation_tds)
        p = ctx.Process(
            target=self.generation_evaluation_wrapper,
            args=(
                self.model,
                data_queue,
                self.cfg,
                epoch,
                q_size,
                self.log_dir,
            ),
        )
        p.start()
        # Keep a handle on the process; it is not joined here.
        self.generation_processes.append(p)
