import glob
import logging
import math
import os

# from multiprocessing.sharedctypes import Synchronized
import time
from abc import ABC, abstractmethod
from collections import OrderedDict
from dataclasses import asdict, dataclass
from typing import Any, Callable, Dict, Tuple

import mlflow
import torch
import torch.distributed
import torch.multiprocessing as mp
from accelerate import Accelerator

from data.data_lib import AIGDatasetConfig
from data.dataset import (
    AIG_Dataset_Collection,
    AIGBatchSampler,
    DistributedBatchSampler,
    EvalAIGBatchSampler,
)
from data.pyaig.aig_env import AIGEnv
from loss.loss_lib import LossConfig

from matplotlib import transforms

from matplotlib.pylab import permutation
from misc_utils import time_formatter
from model.model_lib import ActivationConfig, ModelConfig
from model.utils import (
    load_snapshot,
    mlflow_load_snapshot,
    mlflow_save_snapshot,
    save_snapshot,
    Snapshot,
    TrainCollate,
)
from numpy import save
from omegaconf import DictConfig, ListConfig, MISSING, OmegaConf
from optimizer.grokfast import gradfilter_ema
from optimizer.optimizer_lib import OptimizerConfig
from rl import generate_AIG

from rl.mcts_policy import AlphaZeroConfig
from scheduler.scheduler_lib import SchedulerConfig
from sympy import comp
from tensordict import TensorDict
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from torch.utils.tensorboard.writer import SummaryWriter

from torchmetrics import MeanMetric
from tqdm import tqdm


@dataclass
class TrainConfig:
    """Structured training configuration read by the trainers in this module.

    Fields whose default is OmegaConf's ``MISSING`` must be supplied by the
    user configuration. Some fields are not read in this file; those are
    marked below as consumed elsewhere (to be confirmed by the reader).
    """

    # model
    # Architecture config; the trainer reads embedding_size, n_heads and n_layers from it.
    model: ModelConfig = MISSING
    # Presumably used to build the trainer's action/target/value activation
    # callables outside this file — confirm against the caller.
    action_activation: ActivationConfig = MISSING
    target_activation: ActivationConfig = MISSING
    value_activation: ActivationConfig = MISSING
    # Presumably used to build the policy/value criteria outside this file — confirm.
    policy_loss: LossConfig = MISSING
    value_loss: LossConfig = MISSING
    # The trainer logs optimizer.lr as an hparam; the optimizer itself is built elsewhere.
    optimizer: OptimizerConfig = MISSING
    # Built elsewhere; the resulting scheduler is stepped once per epoch.
    scheduler: SchedulerConfig = MISSING

    # hparams
    epochs: int = 1  # total number of epochs to train
    batch_size: int = 64
    use_amp: bool = False  # BasicTrainer: float16 autocast + GradScaler
    use_grokfast: bool = False  # apply gradfilter_ema to gradients after backward
    grad_norm_clip: float = 1.0  # max gradient norm passed to clip_grad_norm_
    train_value: bool = False  # also predict a value and add the value loss
    seed: int = 42  # not read in this file; presumably seeds RNGs elsewhere — confirm

    # logging
    save_every: int = 1  # save a snapshot every N epochs
    # Log/snapshot directory. When set, the trainer attempts to resume from it;
    # when None, the main process creates a fresh "runs/..." directory.
    log_dir: str | None = None
    # Snapshot file name inside log_dir, or "latest" for the newest *.pt file.
    model_file: str | None = None
    step_logging: int = 500  # log running train losses every N batches
    mlflow_uri: str = "http://localhost:5001"  # MLflow tracking server URI

    # environment
    const_node: bool = True  # passed to AIGEnv for the generation test
    return_action_mask: bool = True  # mask invalid actions in the logits before the loss
    get_causal_mask: bool = True  # forwarded to TrainCollate
    reward_type: str = "simple"  # passed to AIGEnv in the CPU generation workers

    # device
    use_ddp: bool = False  # not read in this file; presumably selects DDPTrainer — confirm
    use_compile: bool = True  # DDPTrainer wraps the model in torch.compile
    use_accelerator: bool = True  # not read in this file; presumably selects AcceleratorTrainer — confirm
    device: str = "cpu"  # not read in this file — confirm where it is used
    master_rank: int = 0  # the rank whose process does logging/saving

    # data
    dataset: AIGDatasetConfig = MISSING
    dataloader_workers: int = 8  # DataLoader num_workers (persistent when > 0)
    train_split: float = 0.9  # not read in this file — presumably used when splitting the dataset
    debug: bool = False  # not read in this file
    # Probabilities forwarded to TrainCollate for train / eval data augmentation.
    train_negation_prob: float = 0.5
    train_permutation_prob: float = 0.5
    eval_negation_prob: float = 0.5
    eval_permutation_prob: float = 0.5

    # Evaluation
    test_generation: bool = False  # run the AIG-generation test during training
    test_frequency: int = 10  # run the generation test every N epochs
    test_limit: int = 100  # not read in this file; presumably caps the generation set — confirm
    max_nodes: int = 20  # node budget passed to generate_AIG
    AZ: AlphaZeroConfig = MISSING  # search config passed to generate_AIG
    AZ_workers: int = 1  # worker processes for the CPU generation test


class Trainer(ABC):
    """Abstract training loop for the circuit-generation policy (and value) model.

    The base class owns the epoch loop, evaluation, loss tracking, TensorBoard /
    MLflow logging, snapshotting and the optional AIG-generation test.
    Subclasses must provide:

    * ``_send_to_device`` - move a ``TensorDict`` to the training device,
    * ``_run_batch`` - forward pass plus (when training) the optimisation step,
    * ``_get_batch_sampler`` - the batch sampler used to build the data loaders.
    """

    def __init__(
        self,
        cfg: TrainConfig,
        model: torch.nn.Module,
        test_model: torch.nn.Module,
        policy_criterion: torch.nn.Module,
        value_criterion: torch.nn.Module,
        optimizer: Optimizer,
        lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
        action_activation: Callable,
        target_activation: Callable,
        value_activation: Callable,
        train_dataset: AIG_Dataset_Collection,
        eval_dataset: AIG_Dataset_Collection,
        generation_dataset: Any,
    ):
        """Store the training components and build the train/eval data loaders.

        Args:
            cfg: Training configuration.
            model: Model to train; must expose ``embedding_size`` (read when
                building the collate function and the generation environment).
            test_model: Second model instance that receives a copy of the
                weights for the CPU-parallel generation test.
            policy_criterion: Loss between activated logits and activated targets.
            value_criterion: Loss between the activated value and the reward
                (only used when ``cfg.train_value`` is set).
            optimizer: Optimizer over ``model``'s parameters.
            lr_scheduler: Scheduler stepped once per epoch.
            action_activation: Applied to the (masked) action logits.
            target_activation: Applied to the target ``actions``.
            value_activation: Applied to the raw value prediction.
            train_dataset: Training data.
            eval_dataset: Evaluation data; a falsy value disables evaluation.
            generation_dataset: Iterable of truth tables for the generation test.
        """
        self.config = cfg
        self.model = model
        self.start_time: float = time.time()

        # parameters
        self.train_value = self.config.train_value
        self.save_every = self.config.save_every
        self.n_total_epochs = self.config.epochs
        self.n_seen_points = 0
        self.epochs_run = 0
        self.test_generation = self.config.test_generation
        self.return_action_mask = self.config.return_action_mask
        self.step_logging = self.config.step_logging
        self.log_dir = self.config.log_dir
        self.model_file = self.config.model_file
        self.test_frequency = self.config.test_frequency
        self.generation_tds = generation_dataset
        self.test_model = test_model
        self.batch_size = self.config.batch_size

        # torchrun variables; default to a single process when they are not set
        self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
        self.bar_pos = self.local_rank
        self.world_size = int(os.environ.get("WORLD_SIZE", 1))
        # NOTE(review): the "main process" is chosen by LOCAL_RANK, so a multi-node
        # run has one logging/saving process per node. Use RANK if that is unwanted.

        # logging
        self.logger = logging.getLogger(__name__ + f"[{self.local_rank}]")
        self.policy_loss_tracker = MeanMetric()
        self.value_loss_tracker = MeanMetric()

        # A log_dir given in the config means "resume from this directory". Record
        # that now, because the main process fills in a fresh log_dir below.
        resume_from_log_dir = self.log_dir is not None

        # Only main process logs the loss
        if self.is_main_process():
            if self.log_dir is None:
                self.log_dir = self._get_dir_name(self.model)
            self.writer = SummaryWriter(log_dir=self.log_dir)
            self.writer.add_text("Config", OmegaConf.to_yaml(cfg))
            OmegaConf.save(cfg, f"{self.writer.log_dir}/config.yaml")
            os.environ["HTTP_PROXY"] = ""
            mlflow.set_tracking_uri(self.config.mlflow_uri)
            # experiment name encodes the number of inputs (log2 of the embedding size)
            mlflow.set_experiment(
                f"Circuit Generation [{int(math.log2(self.config.model.embedding_size))}-input]"
            )
            # TODO: log the datasets and the evaluation tracking to MLflow as well

        # background processes started by the CPU generation test
        self.generation_processes: list = []

        # dataset
        self.train_dataset = train_dataset
        self.train_loader = self._prepare_dataloader(self.train_dataset, train=True)
        self.eval_dataset = eval_dataset
        self.eval_loader = (
            self._prepare_dataloader(self.eval_dataset, train=False)
            if eval_dataset
            else None
        )

        # Grokfast EMA state, created lazily by gradfilter_ema
        self._grads = None
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        self.action_activation = action_activation
        self.target_activation = target_activation
        self.value_activation = value_activation

        self.policy_criterion = policy_criterion
        self.value_criterion = value_criterion

        # Restore the checkpoint last: _load_snapshot also restores the optimizer
        # state, so self.optimizer must already exist at this point.
        if resume_from_log_dir:
            self._load_snapshot()

    ######################### Training Methods ###########################
    @staticmethod
    def mlflow_run_context(func: Callable) -> Callable:
        """Decorate a training entry point so the main process runs it inside an MLflow run.

        On the main process a nested MLflow run is opened. The model params,
        the parameter count, the config, the log directory and the model itself
        are logged, then ``func`` is called inside the run. Other processes
        call ``func`` directly.
        """
        import functools

        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if not self.is_main_process():
                return func(self, *args, **kwargs)

            raw_model = self._unwrap_model()
            with mlflow.start_run(
                run_name=f"{raw_model.__class__.__name__}_heads={self.config.model.n_heads}",
                nested=True,
                tags={
                    "embedding_size": str(self.config.model.embedding_size),
                    "n_heads": str(self.config.model.n_heads),
                    "num_layers": str(self.config.model.n_layers),
                },
                log_system_metrics=True,
            ):
                self._log_params()
                mlflow.log_param(
                    "model_parameters",
                    sum(p.numel() for p in raw_model.parameters()),
                )
                mlflow.log_param("finetuned", False)
                mlflow.log_text(OmegaConf.to_yaml(self.config), "config.yaml")
                mlflow.log_text(self.log_dir, "model_path.txt")
                # TODO: log train/eval datasets with mlflow.log_input

                mlflow.pytorch.log_model(
                    raw_model,
                    "models",
                    registered_model_name=f"{raw_model.__class__.__name__}[{self.config.model.embedding_size}]",
                )
                return func(self, *args, **kwargs)

        return wrapper

    @mlflow_run_context
    def train(self):
        """Train from epoch ``epochs_run + 1`` through ``n_total_epochs`` (1-based).

        Each epoch: train, log, optionally snapshot, evaluate (if an eval
        loader exists), optionally run the generation test, then step the LR
        scheduler. Finally logs hparams and waits for generation processes.
        """
        self.logger.info("Starting Training")
        for epoch in range(self.epochs_run + 1, self.n_total_epochs + 1):
            self._run_epoch(epoch, self.train_loader, train=True)

            # log train info
            self._log_info(epoch, True)

            # save train model
            if epoch % self.save_every == 0:
                self._save_snapshot(epoch)

            # eval run; metrics are only logged when an eval epoch actually ran,
            # otherwise the trackers would still hold the training values
            if self.eval_loader:
                with torch.no_grad():
                    self._run_epoch(epoch, self.eval_loader, train=False)
                self._log_info(epoch, False)

            # test
            if (
                self.is_main_process()
                and self.test_generation
                and epoch % self.test_frequency == 0
            ):
                self.generation_evaluation(epoch)

            # adjust learning rate
            self.lr_scheduler.step()

        # log final parameters
        self._log_hparams()

        # wait for the AZ processes to finish
        for p in self.generation_processes:
            p.join()

    @abstractmethod
    def _run_batch(self, td: TensorDict, train: bool):
        """Forward ``td`` and, when ``train`` is True, perform an optimizer step."""

    def _warm_up(self):
        """Run one forward/backward pass on a random batch, then discard the gradients.

        The comment in the original code states the purpose: warming up the
        PyTorch cache allocator. Tensor shapes mirror what ``_model_loss``
        reads: ``nodes`` (B, N, E), ``causal_mask`` (B, 1, N, N), ``action_mask``
        and ``actions`` (B, 4, N-1, N-1), ``reward`` (B, 1).
        """
        max_nodes = self.train_dataset.max_seq_len
        emb_size = self.config.model.embedding_size
        td = TensorDict(
            {
                "nodes": torch.rand((self.batch_size, max_nodes, emb_size)),
                "causal_mask": torch.ones(
                    (self.batch_size, 1, max_nodes, max_nodes), dtype=torch.bool
                ).tril_(),
                "action_mask": torch.ones(
                    (self.batch_size, 4, max_nodes - 1, max_nodes - 1), dtype=torch.bool  # type: ignore
                ).tril_(),
                "actions": torch.rand(
                    (self.batch_size, 4, max_nodes - 1, max_nodes - 1)  # type: ignore
                ),
                "reward": torch.rand((self.batch_size, 1)),
            },
            batch_size=self.batch_size,
        )
        td = self._send_to_device(td)
        loss = self._model_loss(td)
        loss.backward()
        self.optimizer.zero_grad()

    @abstractmethod
    def _send_to_device(self, td: TensorDict) -> TensorDict:
        """Return ``td`` moved to the training device."""

    def _model_loss(self, td: TensorDict) -> torch.Tensor:
        """Compute the policy loss (plus the value loss when ``train_value``) for ``td``.

        Updates ``policy_loss_tracker`` (and ``value_loss_tracker``) as a side
        effect. Returns the summed loss tensor.
        """
        # forward pass
        if self.train_value:
            action_logits, value = self.model(
                td["nodes"], td["causal_mask"], get_value=True
            )
        else:
            action_logits = self.model(td["nodes"], td["causal_mask"], get_value=False)

        # filter out invalid actions
        if self.return_action_mask:
            # NOTE: The mask is already negated (True marks an invalid action)
            action_logits.masked_fill_(
                td["action_mask"], torch.finfo(action_logits.dtype).min
            )

        # normalize target
        tgt = self.target_activation(td["actions"])

        # Apply the action activation (e.g. log-softmax) to the logits
        action = self.action_activation(action_logits)

        # compute policy loss
        policy_loss = self.policy_criterion(action, tgt)
        self.policy_loss_tracker.update(policy_loss)

        if not self.train_value:
            return policy_loss

        # Scale value, then compare it with the reward
        value = self.value_activation(value)
        value_loss = self.value_criterion(value, td["reward"].unsqueeze(-1))
        self.value_loss_tracker.update(value_loss)

        # out-of-place sum: avoids mutating the policy loss tensor in place
        return policy_loss + value_loss

    def _run_epoch(
        self,
        epoch: int,
        dataloader: DataLoader,
        train: bool = True,
    ):
        """Run one pass over ``dataloader`` in train or eval mode.

        Resets the loss trackers, switches the model's train/eval mode and
        shows a tqdm bar whose description includes running losses.
        """
        self.policy_loss_tracker.reset()
        self.value_loss_tracker.reset()
        # dropout / normalisation layers must follow the phase
        self.model.train(train)

        desc = self._get_desc(epoch, train, False)
        with tqdm(
            unit="batch", total=len(dataloader), desc=desc, position=self.bar_pos
        ) as progress_bar:
            for batch_idx, td in enumerate(dataloader):
                self._run_batch(td, train)
                progress_bar.update()

                # update bar every 100 batches
                if batch_idx % 100 == 0:
                    progress_bar.set_description(self._get_desc(epoch, train, True))
                    progress_bar.refresh()

                # step_logging <= 0 disables step logging instead of dividing by zero
                if (
                    train
                    and self.step_logging > 0
                    and batch_idx % self.step_logging == 0
                ):
                    self._step_log_info()

            # log final loss
            progress_bar.set_description(self._get_desc(epoch, train, True))
            progress_bar.refresh()

    def _prepare_dataloader(self, dataset: AIG_Dataset_Collection, train: bool):
        """Build a DataLoader over ``dataset`` with the subclass's batch sampler.

        Train and eval loaders use different negation/permutation
        probabilities from the config.
        """
        batch_sampler = self._get_batch_sampler(dataset, train)

        if train:
            neg_pr = self.config.train_negation_prob
            per_pr = self.config.train_permutation_prob
        else:
            neg_pr = self.config.eval_negation_prob
            per_pr = self.config.eval_permutation_prob

        # NOTE(review): const_node / return_action_mask come from cfg.dataset here,
        # while _model_loss reads cfg.return_action_mask. Confirm the two agree,
        # otherwise td["action_mask"] may be missing.
        collate_wrapper = TrainCollate(
            device=None,
            embedding_size=self.model.embedding_size,
            train_value=self.config.train_value,
            const_node=self.config.dataset.const_node,
            return_action_mask=self.config.dataset.return_action_mask,
            get_causal_mask=self.config.get_causal_mask,
            negation_prob=neg_pr,
            permutation_prob=per_pr,
        )

        return DataLoader(
            dataset,
            collate_fn=collate_wrapper,
            pin_memory=torch.cuda.is_available(),
            num_workers=self.config.dataloader_workers,
            persistent_workers=self.config.dataloader_workers > 0,
            batch_sampler=batch_sampler,
        )

    @abstractmethod
    def _get_batch_sampler(self, dataset, train):
        """Return the batch sampler for ``dataset`` (train or eval)."""

    ######################### Testing Methods ###########################
    @staticmethod
    def test_truth_tables(
        model: torch.nn.Module,
        data_queue: mp.Queue,
        result_queue: mp.Queue,
        cfg,
    ):
        """Worker: solve truth tables from ``data_queue`` until it is drained.

        Each table resets a fresh ``AIGEnv`` and is attempted with
        ``generate_AIG``. The number of successes is put on ``result_queue``
        exactly once, as the last action.
        """
        import queue

        results = 0
        while True:
            try:
                # A timed get instead of empty()+get(): with several workers the
                # last item can be taken between the two calls, leaving get() blocked.
                td = data_queue.get(timeout=1.0)
            except queue.Empty:
                break
            td = td.clone()
            aig_env = AIGEnv(model.embedding_size, cfg.const_node, cfg.reward_type)
            aig_env.reset(td)
            if generate_AIG(model, aig_env, cfg.max_nodes, cfg.AZ):
                results += 1
        result_queue.put(results)

    @staticmethod
    def generation_evaluation_wrapper(
        model: torch.nn.Module,
        data_queue: mp.Queue,
        cfg: Any,
        epoch: int,
        q_size: int,
        log_dir: str,
    ) -> None:
        """Run the generation test over ``data_queue`` and log the accuracy.

        Uses ``cfg.AZ_workers`` processes when it is above 1, otherwise runs
        in-process. Accuracy (successes / ``q_size``) is written to TensorBoard
        in ``log_dir`` and to MLflow at step ``epoch``.
        """
        if q_size <= 0:
            # nothing to evaluate; avoid dividing by zero
            return

        results_queue = mp.Queue()
        success = 0

        if cfg.AZ_workers > 1:
            processes = [
                mp.Process(
                    target=Trainer.test_truth_tables,
                    args=(model, data_queue, results_queue, cfg),
                )
                for _ in range(cfg.AZ_workers)
            ]
            for p in processes:
                p.start()
            for p in processes:
                p.join()

            # Every worker that exits normally has put exactly one count, so a
            # blocking get per such worker is reliable (unlike Queue.empty()).
            for p in processes:
                if p.exitcode == 0:
                    success += results_queue.get()
                else:
                    logging.getLogger(__name__).warning(
                        "Generation worker exited with code %s", p.exitcode
                    )
        else:
            Trainer.test_truth_tables(model, data_queue, results_queue, cfg)
            # blocking get: the put above may not be flushed yet, so empty() could lie
            success = results_queue.get()

        accuracy = success / q_size
        writer = SummaryWriter(log_dir)
        try:
            writer.add_scalar("Accuracy", accuracy, epoch)
        finally:
            # flush the event file before this process exits
            writer.close()
        # NOTE(review): this runs in a spawned process that has no active MLflow
        # run, so the metric may land in a new run. Confirm this is intended.
        mlflow.log_metric("Accuracy", accuracy, epoch)

    def generation_evaluation(
        self,
        epoch: int,
        use_gpu: bool = True,
    ) -> None:
        """Run the AIG generation test for ``epoch``.

        Args:
            epoch: Step at which the accuracy is logged.
            use_gpu: True (default, the previous hard-coded behaviour) runs the
                test in-process on the training device. False runs it in a
                background CPU process.
        """
        if use_gpu:
            self._gpu_generation(epoch)
        else:
            self._cpu_parallel_generation(epoch)

    def _gpu_generation(self, epoch: int) -> None:
        """Solve every truth table in-process on the training device and log accuracy."""
        n_tables = len(self.generation_tds)
        if n_tables == 0:
            self.logger.warning("Generation dataset is empty; skipping generation test")
            return

        raw_model = self._unwrap_model()
        aig_env = AIGEnv(
            raw_model.embedding_size, self.config.const_node, self.config.reward_type
        )
        aig_env.state = self._send_to_device(aig_env.state)
        aig_env = aig_env.to(aig_env.state.device)  # type: ignore

        # evaluate with dropout etc. disabled, then restore the previous mode
        was_training = raw_model.training
        raw_model.eval()
        successes = 0
        try:
            with torch.no_grad():
                for tt in tqdm(self.generation_tds):
                    tt = self._send_to_device(tt.clone())
                    aig_env.reset(tt)
                    if generate_AIG(
                        raw_model, aig_env, self.config.max_nodes, self.config.AZ
                    ):
                        successes += 1
        finally:
            raw_model.train(was_training)

        accuracy = successes / n_tables
        self.writer.add_scalar("Accuracy", accuracy, epoch)
        mlflow.log_metric("Accuracy", accuracy, epoch)

    def _cpu_parallel_generation(
        self,
        epoch: int,
    ) -> None:
        """Copy the weights to ``test_model`` on CPU and run the test in a spawned process.

        The process is recorded in ``generation_processes`` and joined at the
        end of ``train``.
        """
        raw_model = self._unwrap_model()
        self.test_model.load_state_dict(raw_model.state_dict())
        self.test_model.to("cpu").share_memory().eval()

        # spawn can be problematic
        ctx = mp.get_context("spawn")
        data_queue = ctx.Queue()

        for tt in self.generation_tds:
            data_queue.put_nowait(tt)

        q_size = len(self.generation_tds)
        p = ctx.Process(
            target=self.generation_evaluation_wrapper,
            args=(
                self.test_model,
                data_queue,
                self.config,
                epoch,
                q_size,
                self.log_dir,
            ),
        )
        p.start()
        self.generation_processes.append(p)

    ######################### Logging Methods ###########################
    @staticmethod
    def main_process_only(func: Callable) -> Callable:
        """Decorate a method so it runs only on the main process (returns None elsewhere)."""
        import functools

        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.is_main_process():
                return func(self, *args, **kwargs)

        return wrapper

    def _get_desc(
        self,
        epoch: int,
        train: bool,
        include_loss: bool = False,
    ) -> str:
        """Build the tqdm description: phase, rank, epoch, elapsed time and optionally losses."""
        step_type = "Train" if train else "Eval "

        desc = (
            f"[{step_type}][Rank: {self.local_rank}][Epoch: {epoch:4} / {self.n_total_epochs}]"
            f"[Total time: {self._get_elapsed_time()}]"
        )
        if include_loss:
            desc += f"[Policy Loss: {self.policy_loss_tracker.compute().item():.3f}]"

            if self.train_value:
                desc += f"[Value Loss: {self.value_loss_tracker.compute().item():.3f}]"

        return desc

    def _get_elapsed_time(self) -> str:
        """Return the time since construction, formatted without milliseconds."""
        return time_formatter(
            time.time() - self.start_time,
            show_ms=False,
        )

    @staticmethod
    def _get_dir_name(model: torch.nn.Module) -> str:
        """Return the first unused ``runs/pre_<Class>_emb=<E>_heads=<H>_vNNNN`` directory name."""
        prefix = (
            f"runs/pre_{model.__class__.__name__}"
            f"_emb={model.embedding_size}_"
            f"heads={model.n_heads}_v"
        )
        version = 0
        while os.path.exists(f"{prefix}{version:04d}"):
            version += 1
        return f"{prefix}{version:04d}"

    def _log_info(self, epoch: int, train: bool):
        """Log the epoch-mean losses to TensorBoard and MLflow.

        Must be called on every rank: torchmetrics' ``compute()`` synchronises
        metric state across processes when torch.distributed is initialised.
        Only the main process writes the logs.
        """
        step_type = "train" if train else "eval"

        epoch_policy_loss = self.policy_loss_tracker.compute().item()
        epoch_value_loss = (
            self.value_loss_tracker.compute().item() if self.train_value else None
        )

        if not self.is_main_process():
            return

        self.writer.add_scalar(f"Policy_Loss/{step_type}", epoch_policy_loss, epoch)
        mlflow.log_metric(f"Policy_Loss/{step_type}", epoch_policy_loss, epoch)

        if epoch_value_loss is not None:
            self.writer.add_scalar(f"Value_Loss/{step_type}", epoch_value_loss, epoch)
            mlflow.log_metric(f"Value_Loss/{step_type}", epoch_value_loss, epoch)

    def _step_log_info(self):
        """Log running train losses against the number of seen samples.

        Called on every rank (``compute()`` may be a collective); only the main
        process writes the logs.
        """
        train_epoch_policy_loss = self.policy_loss_tracker.compute().item()
        train_epoch_value_loss = (
            self.value_loss_tracker.compute().item() if self.train_value else None
        )

        if not self.is_main_process():
            return

        self.writer.add_scalar(
            "Policy_Loss/train(total)",
            train_epoch_policy_loss,
            self.n_seen_points,
        )
        mlflow.log_metric(
            "Policy_Loss/train x-total", train_epoch_policy_loss, self.n_seen_points
        )

        if train_epoch_value_loss is not None:
            self.writer.add_scalar(
                "Value_Loss/train(x-total)",
                train_epoch_value_loss,
                self.n_seen_points,
            )
            mlflow.log_metric(
                "Value_Loss/train x-total", train_epoch_value_loss, self.n_seen_points
            )

    def _log_hparams(self):
        """Write hyper-parameters and the final losses to TensorBoard's hparams view.

        Called on every rank (``compute()`` may be a collective); the value
        loss is only reported when it was trained.
        """
        final_metrics = {
            "hparam/policy_loss": self.policy_loss_tracker.compute().item(),
        }
        if self.train_value:
            final_metrics["hparam/value_loss"] = self.value_loss_tracker.compute().item()

        if not self.is_main_process():
            return

        self.writer.add_hparams(
            {
                "lr": self.config.optimizer.lr,
                "embedding_size": self.config.model.embedding_size,
                "n_heads": self.config.model.n_heads,  # type: ignore
                "num_layers": self.config.model.n_layers,  # type: ignore
                "batch size": self.config.batch_size,
                "epochs": self.config.epochs,
            },
            final_metrics,
        )  # type: ignore

    def _log_params(self):
        """Log every leaf of ``config.model`` as an MLflow param, with dotted names."""

        def _explore_recursive(parent_name, element):
            if isinstance(element, DictConfig):
                for k, v in element.items():
                    if isinstance(v, (DictConfig, ListConfig)):
                        _explore_recursive(f"{parent_name}.{k}", v)
                    else:
                        mlflow.log_param(f"{parent_name}.{k}", v)
            elif isinstance(element, ListConfig):
                for i, v in enumerate(element):
                    mlflow.log_param(f"{parent_name}.{i}", v)
            else:
                mlflow.log_param(parent_name, element)

        for param_name, element in self.config.model.items():  # type: ignore
            _explore_recursive(param_name, element)

    def is_main_process(self) -> bool:
        """True on the process whose local rank equals ``config.master_rank``."""
        return self.local_rank == self.config.master_rank

    def _unwrap_model(self) -> torch.nn.Module:
        """Return the model without a DDP-style ``.module`` wrapper."""
        return self.model.module if hasattr(self.model, "module") else self.model

    def _load_snapshot(self):
        """Restore model/optimizer state and the finished-epoch count from ``log_dir``.

        ``model_file == "latest"`` picks the most recently modified ``*.pt`` in
        ``log_dir``; any other value is a file name inside ``log_dir``. A
        missing ``model_file`` or snapshot leaves everything untouched, so
        training starts from scratch.
        """
        if self.model_file is None:
            self.logger.info("No model file given. Training model from scratch")
            return

        if self.model_file == "latest":
            candidates = glob.glob(os.path.join(self.log_dir, "*.pt"))  # type: ignore
            if not candidates:
                self.logger.info("Snapshot not found. Training model from scratch")
                return
            model_path = max(candidates, key=os.path.getmtime)
        else:
            model_path = os.path.join(self.log_dir, self.model_file)  # type: ignore

        try:
            snapshot = load_snapshot(model_path)
        except FileNotFoundError:
            self.logger.info("Snapshot not found. Training model from scratch")
            return

        self.model.load_state_dict(snapshot.model_state)
        self.optimizer.load_state_dict(snapshot.optimizer_state)
        self.epochs_run = snapshot.finished_epoch
        self.logger.info(f"Loaded model from {model_path} at Epoch {self.epochs_run}")

    @main_process_only
    def _save_snapshot(self, epoch, name: str | None = None):
        """Save model + optimizer state to ``log_dir/model_<epoch>.pt`` (or ``<name>.pt``)."""
        file_name = f"model_{epoch}.pt" if name is None else f"{name}.pt"
        path = os.path.join(self.log_dir, file_name)  # type: ignore

        # NOTE(review): saving uses mlflow_save_snapshot but _load_snapshot reads
        # with load_snapshot; confirm the two formats are compatible.
        mlflow_save_snapshot(self.model, self.optimizer, epoch, path)

        self.logger.info(f"Saved model at {path}")


class BasicTrainer(Trainer):
    """Trainer for a single explicit device, with optional AMP and Grokfast filtering."""

    def __init__(
        self,
        cfg: TrainConfig,
        model: torch.nn.Module,
        test_model: torch.nn.Module,
        policy_criterion: torch.nn.Module,
        value_criterion: torch.nn.Module,
        optimizer: Optimizer,
        lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
        action_activation: Callable,
        target_activation: Callable,
        value_activation: Callable,
        train_dataset: AIG_Dataset_Collection,
        eval_dataset: AIG_Dataset_Collection,
        generation_dataset: Any,
        device: torch.device,
    ):
        """Place the model, criteria and loss trackers on ``device``.

        The model is moved *before* the base initialiser runs. Any snapshot
        restored there therefore loads onto the final device: optimizer
        ``load_state_dict`` casts state to the device its parameters have at
        load time.
        """
        self.device = torch.device(device)
        # nn.Module.to moves parameters in place, so the optimizer still refers to them
        model = model.to(self.device)

        super().__init__(
            cfg,
            model,
            test_model,
            policy_criterion,
            value_criterion,
            optimizer,
            lr_scheduler,
            action_activation,
            target_activation,
            value_activation,
            train_dataset,
            eval_dataset,
            generation_dataset,
        )
        # initialize train states
        self.policy_criterion.to(self.device)
        self.value_criterion.to(self.device)
        self.policy_loss_tracker.to(self.device)
        self.value_loss_tracker.to(self.device)
        if self.config.use_amp:
            self.scaler = torch.amp.GradScaler()  # type: ignore

    def _run_batch(self, td: TensorDict, train: bool = True) -> None:
        """Forward ``td`` and, when ``train`` is True, take one optimizer step.

        With AMP, the scaled gradients are unscaled *before* Grokfast filtering
        and norm clipping, so both operate on true gradient magnitudes.
        """
        use_amp = self.config.use_amp
        with torch.set_grad_enabled(train), torch.autocast(
            device_type=self.device.type, dtype=torch.float16, enabled=use_amp
        ):
            # model forward pass
            td = self._send_to_device(td)
            loss = self._model_loss(td)

        if not train:
            return

        # optimization step (backward outside autocast, as PyTorch recommends)
        self.n_seen_points += len(td["nodes"]) * self.world_size
        self.optimizer.zero_grad(set_to_none=True)

        if use_amp:
            self.scaler.scale(loss).backward()
            # without this, clipping would compare loss-scale-inflated norms
            self.scaler.unscale_(self.optimizer)
        else:
            loss.backward()

        if self.config.use_grokfast:
            self._grads = gradfilter_ema(
                self.model, grads=self._grads, alpha=0.98, lamb=2.0
            )

        torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), self.config.grad_norm_clip
        )

        if use_amp:
            # scaler.step skips the update if unscale_ found inf/NaN gradients
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

    def _get_batch_sampler(self, dataset, train):
        """Use a training sampler for training and a separate sampler for evaluation."""
        if train:
            return AIGBatchSampler(dataset, self.config.batch_size)
        return EvalAIGBatchSampler(dataset, self.config.batch_size)

    def _send_to_device(self, td: TensorDict) -> TensorDict:
        """Move ``td`` to the device the model was placed on."""
        return td.to(self.device, non_blocking=True)


class DDPTrainer(BasicTrainer):
    """BasicTrainer wrapped in DistributedDataParallel, one process per GPU (torchrun)."""

    def __init__(
        self,
        cfg: TrainConfig,
        model: torch.nn.Module,
        test_model: torch.nn.Module,
        policy_criterion: torch.nn.Module,
        value_criterion: torch.nn.Module,
        optimizer: Optimizer,
        lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
        action_activation: Callable,
        target_activation: Callable,
        value_activation: Callable,
        train_dataset: AIG_Dataset_Collection,
        eval_dataset: AIG_Dataset_Collection,
        generation_dataset: Any,
    ):
        """Place the model on ``cuda:LOCAL_RANK``, optionally compile it, wrap it in DDP and warm up."""
        # default to rank 0 (as the base class does) instead of building "cuda:None"
        local_rank = os.environ.get("LOCAL_RANK", "0")
        device = torch.device(f"cuda:{local_rank}")
        super().__init__(
            cfg,
            model,
            test_model,
            policy_criterion,
            value_criterion,
            optimizer,
            lr_scheduler,
            action_activation,
            target_activation,
            value_activation,
            train_dataset,
            eval_dataset,
            generation_dataset,
            device,
        )

        # wrap with DDP; this step syncs the model across all the processes.
        if self.config.use_compile:
            self.model = torch.compile(self.model)
        self.model = DDP(
            self.model, device_ids=[self.local_rank], find_unused_parameters=True
        )
        # warm-up gradients are discarded, so skip the gradient all-reduce
        with self.model.no_sync():
            self._warm_up()

    def _get_batch_sampler(
        self, dataset: AIG_Dataset_Collection, train: bool
    ) -> DistributedBatchSampler:
        """Shard batches across ranks (the same sampler type is used for train and eval)."""
        return DistributedBatchSampler(dataset, batch_size=self.batch_size)

    def _run_epoch(self, epoch: int, dataloader: DataLoader, train: bool = True):
        """Re-seed the distributed sampler for ``epoch``, then run the epoch."""
        dataloader.batch_sampler.set_epoch(epoch)  # type: ignore
        super()._run_epoch(epoch, dataloader, train)

    @Trainer.mlflow_run_context
    def train(self):
        """Distributed training loop from epoch ``epochs_run + 1`` to ``n_total_epochs``.

        Train and eval epochs run under ``DDP.join`` to tolerate uneven inputs.
        The MLflow run is opened and closed by ``mlflow_run_context``.
        """
        self.logger.info("Starting Training")
        for epoch in range(self.epochs_run + 1, self.n_total_epochs + 1):
            with self.model.join(throw_on_early_termination=True):
                self._run_epoch(epoch, self.train_loader, train=True)
            torch.cuda.synchronize(self.local_rank)

            # log train info
            self._log_info(epoch, True)

            # save train model
            if epoch % self.save_every == 0:
                self._save_snapshot(epoch)

            # eval run; only log eval metrics if an eval epoch actually ran,
            # otherwise the trackers would still hold the training values
            if self.eval_loader:
                with self.model.join(throw_on_early_termination=True):
                    self._run_epoch(epoch, self.eval_loader, train=False)
                torch.cuda.synchronize(self.local_rank)
                self._log_info(epoch, False)

            # test
            if (
                self.is_main_process()
                and self.test_generation
                and epoch % self.test_frequency == 0
            ):
                self.generation_evaluation(epoch)

            # adjust learning rate
            self.lr_scheduler.step()

        # log final parameters
        self._log_hparams()

        # wait for the AZ processes to finish
        for p in self.generation_processes:
            p.join()


class AcceleratorTrainer(Trainer):
    """Trainer that delegates device placement, DDP and mixed precision to HF Accelerate."""

    def __init__(
        self,
        cfg: TrainConfig,
        model: torch.nn.Module,
        test_model: torch.nn.Module,
        policy_criterion: torch.nn.Module,
        value_criterion: torch.nn.Module,
        optimizer: Optimizer,
        lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
        action_activation: Callable,
        target_activation: Callable,
        value_activation: Callable,
        train_dataset: AIG_Dataset_Collection,
        eval_dataset: AIG_Dataset_Collection,
        generation_dataset: Any,
        accelerator: Accelerator,
    ):
        """Initialise the base trainer, prepare components with ``accelerator`` and warm up."""
        super().__init__(
            cfg,
            model,
            test_model,
            policy_criterion,
            value_criterion,
            optimizer,
            lr_scheduler,
            action_activation,
            target_activation,
            value_activation,
            train_dataset,
            eval_dataset,
            generation_dataset,
        )
        self.accelerator = accelerator
        self._prepare_accelerator()
        # warm-up gradients are discarded, so skip gradient synchronisation
        with self.accelerator.no_sync(self.model):
            self._warm_up()

    def _run_batch(self, td: TensorDict, train: bool):
        """Forward ``td`` and, when ``train`` is True, take one optimizer step."""
        # model forward pass; the prepared data loader already placed td on device
        with torch.set_grad_enabled(train):
            loss = self._model_loss(td)

        if not train:
            return

        # optimization step
        self.n_seen_points += len(td["nodes"]) * self.world_size
        self.optimizer.zero_grad(set_to_none=True)
        self.accelerator.backward(loss)

        if self.config.use_grokfast:
            self._grads = gradfilter_ema(
                self.model, grads=self._grads, alpha=0.98, lamb=2.0  # type: ignore
            )

        self.accelerator.clip_grad_norm_(
            self.model.parameters(), self.config.grad_norm_clip
        )
        self.optimizer.step()

    def _prepare_accelerator(self) -> None:
        """Prepare model/optimizer/loaders with Accelerate; place the rest on its device.

        The LR scheduler is deliberately *not* prepared. This trainer steps it
        once per epoch, whereas Accelerate's AcceleratedScheduler (without
        ``split_batches``) steps the wrapped scheduler once per process on each
        call. The scheduler still drives the same parameter groups, because
        the prepared optimizer wraps the original one. The torchmetrics loss
        trackers and the criteria are plain modules that only need device
        placement.
        """
        (
            self.model,
            self.optimizer,
            self.train_loader,
            self.eval_loader,
        ) = self.accelerator.prepare(
            self.model,
            self.optimizer,
            self.train_loader,
            self.eval_loader,
        )

        device = self.accelerator.device
        self.policy_loss_tracker.to(device)
        self.value_loss_tracker.to(device)
        self.policy_criterion.to(device)
        self.value_criterion.to(device)

    def _get_batch_sampler(self, dataset, train):
        """Use a training sampler for training and a separate sampler for evaluation."""
        if train:
            return AIGBatchSampler(dataset, self.config.batch_size)
        return EvalAIGBatchSampler(dataset, self.config.batch_size)

    def _send_to_device(self, td: TensorDict) -> TensorDict:
        """Move ``td`` to the accelerator's device."""
        return td.to(self.accelerator.device, non_blocking=True)
