from os import wait
from typing import Optional, Callable, Dict, Any, List, Tuple, Literal
from contextlib import nullcontext
from functools import partial
import datetime as dt
from pytz import timezone

import warnings

from sympy import use
import torch
from torch import Tensor
from torch_sparse import SparseTensor
from torch_geometric.data import Data
from torch.utils.checkpoint import checkpoint
from pytorch_memlab import MemReporter

from loguru import logger
from tensornvme import DiskOffloader
from torch.profiler import profile, ProfilerActivity

from models.offload import save_on_cpu, grad_offload, grad_upload
from utils.cache import Cache
# from utils.buffer import Buffer
from utils.new_buffer import HostStorageTensors, AcceleratorTensors
from utils.pool import AsyncIOPool
from utils.loader import SubgraphLoader
from utils.adjacency_matrix import AdjacencyMatrixWithOffloader, AdjacencyMatrixWithGDS
from utils.others import report_memory

from utils.debug import get_memory_used
# from utils.debug import *
# from utils.gdtensor import GDTensor


class GriNNderGNN(torch.nn.Module):
    r"""Abstract base class for exact, layer-wise GNN training with offloading.

    The graph is split into partitions (subgraphs) by a :class:`SubgraphLoader`.
    Training proceeds one layer at a time and, within a layer, one partition at
    a time. Every layer's input/output embeddings live in
    :class:`HostStorageTensors` buffers on a secondary device (``'cpu'``, or
    ``'storage'`` when ``storage_offload`` is set). They are staged on the GPU
    through :class:`AcceleratorTensors`. Transfers are double-buffered over two
    sets of CUDA streams so that they overlap with computation, which runs on a
    single dedicated forward/backward stream.

    Subclasses must implement :meth:`forward_layer`.

    Args:
        in_channels (int): Number of input node features.
        hidden_channels (int): Number of hidden features. As a current
            restriction, all intermediate node embeddings use this width.
        out_channels (int): Number of output features of the last layer.
        num_layers (int): Number of layers of the model.
        loader (SubgraphLoader): Re-iterable collection of partitions.
            Requirements on each item:

            * it must expose ``n_id``;
            * it must unpack as ``(data, batch_size, _, boundaries, *rest)``;
            * ``data`` must carry ``adj_t`` and ``true_x``.
        topo_caching (bool, optional): Stored as :attr:`topo_caching`; not read
            elsewhere in this class. (default: :obj:`True`)
        device (optional): Value returned by :attr:`device`.
            (default: :obj:`None`)
        use_cache (bool, optional): Stored as :attr:`use_cache`; not read
            elsewhere in this class. (default: :obj:`False`)
        layer_wise_cache (bool, optional): Stored as :attr:`layer_wise_cache`;
            not read elsewhere in this class. (default: :obj:`False`)
        checkpointing_strategy (str, optional): Forwarded to
            :func:`save_on_cpu`. With ``'scattered'``, the backward pass also
            prefetches each layer's inputs with ``async_gather``.
            (default: :obj:`'scattered'`)
        storage_offload (bool, optional): Place host buffers on ``'storage'``
            instead of ``'cpu'``. The buffers of the layer being processed are
            moved with ``storage_to_cpu`` / ``cpu_to_storage`` around that
            layer. (default: :obj:`False`)
        storage_path (str, optional): Forwarded to :class:`HostStorageTensors`.
            (default: :obj:`None`)
        optimize_dataloader (bool, optional): If :obj:`True`, adjacency matrices
            are moved to the device one partition ahead and discarded after
            use. Otherwise all of them are uploaded once in the constructor.
            (default: :obj:`True`)

    Note:
        CUDA streams are created on ``'cuda:0'``, and two transfer slots
        (``pool_size = 2``) are used for double buffering.
    """

    def __init__(self, in_channels: int, hidden_channels: int, out_channels: int, num_layers: int,
                 loader: SubgraphLoader,
                 topo_caching: Optional[bool] = True,
                 device=None,
                 use_cache: Optional[bool] = False,
                 layer_wise_cache: Optional[bool] = False,
                 checkpointing_strategy: Literal['scattered', 'cpu', 'storage'] = 'scattered',
                 storage_offload: Optional[bool] = False,
                 storage_path: Optional[str] = None,
                 optimize_dataloader: bool = True):
        super().__init__()

        self.verbose = False

        self.layer_wise_cache = layer_wise_cache
        self.checkpointing_strategy = checkpointing_strategy
        self.optimize_dataloader = optimize_dataloader

        # Host buffers live on 'storage' when offloading to disk, otherwise on 'cpu'.
        self.storage_offload = storage_offload
        self.secondary_device = 'storage' if self.storage_offload else 'cpu'

        self.hidden_channels = hidden_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_layers = num_layers
        self._device = device

        self.topo_caching = topo_caching
        self.adj_ts: list[AdjacencyMatrixWithGDS] = []
        self.all_boundaries: list[list[Tensor | None]] = []
        self.all_states: list[list[dict | None]] = []
        self.batch_sizes: list[int] = []

        self.in_sizes: list[int] = [len(subdata.n_id) for subdata in loader]
        self.n_parts = len(loader)

        self.use_cache = use_cache

        # Double buffering: partition ``pid`` uses the transfer streams in slot
        # ``pid % pool_size``.
        self.pool_size = 2
        self.d2h_streams: list[torch.cuda.Stream] = []
        self.h2d_streams: list[torch.cuda.Stream] = []
        self.bw_act_h2d_streams: list[torch.cuda.Stream] = []
        for _ in range(self.pool_size):
            self.d2h_streams.append(torch.cuda.Stream('cuda:0'))
            self.h2d_streams.append(torch.cuda.Stream('cuda:0'))
            self.bw_act_h2d_streams.append(torch.cuda.Stream('cuda:0'))
        # A single compute stream is shared by the forward and backward passes.
        self.fbw_stream = torch.cuda.Stream('cuda:0')

        # Per-(layer, partition) timing events; they are not recorded within this class.
        self.f_events: list[list[torch.cuda.Event]] = [
            [torch.cuda.Event(enable_timing=True) for _ in range(self.n_parts)]
            for _ in range(self.num_layers)
        ]
        self.b_events: list[list[torch.cuda.Event]] = [
            [torch.cuda.Event(enable_timing=True) for _ in range(self.n_parts)]
            for _ in range(self.num_layers)
        ]

        # Append an empty dict to every loader item. For items with at least
        # four fields it ends up as the trailing element of that partition's
        # ``state`` list (which is always a list because of ``*state``).
        parts = [sub_data + ({}, ) for sub_data in loader]
        for data, batch_size, _, boundaries, *state in parts:
            # Keep a handle on the partition's topology; it is moved to the
            # device later (here below or per partition).
            self.adj_ts.append(data.adj_t)
            self.all_boundaries.append(boundaries)
            self.all_states.append(state)
            self.batch_sizes.append(batch_size)

        # Feature buffers on the secondary device:
        # index 0 holds the input features, index ``i + 1`` the output of layer ``i``.
        feature_widths = ([self.in_channels]
                          + [self.hidden_channels] * (self.num_layers - 1)
                          + [self.out_channels])
        self.host_storage_tensors: list[HostStorageTensors | None] = []
        for width in feature_widths:
            self.host_storage_tensors.append(
                HostStorageTensors(self.n_parts, width,
                                   self.batch_sizes, device=self.secondary_device,
                                   storage_path=storage_path))

        # Gradient buffers that carry gradients between layers during backward.
        # host_storage_gradients[0] is None: backward only ever reads entry
        # ``lid + 1``, so no buffer is needed for the input-feature gradient.
        self.host_storage_gradients: list[HostStorageTensors | None] = []
        self.accelerator_gradients: list[AcceleratorTensors | None] = []
        for i in range(self.num_layers):
            if i == 0:
                self.host_storage_gradients.append(None)
            else:
                self.host_storage_gradients.append(
                    HostStorageTensors(self.n_parts, self.hidden_channels,
                                       self.batch_sizes, device=self.secondary_device,
                                       storage_path=storage_path, is_grad=True))
            self.accelerator_gradients.append(
                AcceleratorTensors(self.n_parts, self.hidden_channels,
                                   self.batch_sizes, device='cuda', for_grad=True))

        # Prefill the layer-0 buffer with each partition's input features and
        # drop the reference held by ``data``.
        for i, (data, *_) in enumerate(parts):
            if isinstance(data.true_x, Tensor):
                self.host_storage_tensors[0].sync_fill(i, data.true_x)
            else:  # GDTensor-like object expected to reside on storage
                assert data.true_x.at == 'storage'
                data.true_x.synchronize()
                data.true_x.to_inplace('cpu', async_=False)
                self.host_storage_tensors[0].sync_fill(i, data.true_x._tensor)
            data.true_x = None

        # GPU staging buffers for each layer's (gathered) inputs, sized by ``in_sizes``.
        input_widths = [self.in_channels] + [self.hidden_channels] * (self.num_layers - 1)
        self.accelerator_tensors: list[AcceleratorTensors | None] = []
        for width in input_widths:
            self.accelerator_tensors.append(
                AcceleratorTensors(self.n_parts, width,
                                   self.in_sizes, device='cuda'))

        # One output-activation slot per (layer, partition), filled by the forward pass.
        self.accelerator_activations: list[list[Tensor | None]] = [
            [None] * self.n_parts for _ in range(self.num_layers)
        ]

        logger.info(report_memory('Before DATA'))

        if not self.optimize_dataloader:
            # Upload every adjacency matrix once; they stay on the device.
            for i in range(self.n_parts):
                self.adj_ts[i].storage_to_device(async_=True)
            for i in range(self.n_parts):
                self.adj_ts[i].synchronize()

        logger.info(report_memory('After DATA (Init Fin.)'))

        # logging memory usage
        logger.info(f"After host/storage tensoring: {get_memory_used()/10**9} GB")

    @property
    def device(self):
        """The ``device`` value passed to the constructor."""
        return self._device

    @staticmethod
    def _timestamp() -> dt.datetime:
        """Return the current wall-clock time in the Asia/Seoul timezone (for logs)."""
        return dt.datetime.now().astimezone(timezone('Asia/Seoul'))

    @staticmethod
    def _realloc_storage(tensor: Tensor) -> None:
        """Grow ``tensor``'s untyped storage back to ``numel * element_size`` bytes.

        Undoes :meth:`_free_storage`. The contents are not restored here;
        callers refill the storage afterwards (e.g. with ``async_upload``).
        """
        tensor.untyped_storage().resize_(tensor.numel() * tensor.element_size())

    @staticmethod
    def _free_storage(tensor: Tensor) -> None:
        """Release ``tensor``'s memory by shrinking its untyped storage to 0 bytes.

        The tensor object and its shape/dtype metadata are kept, so the storage
        can be re-grown later with :meth:`_realloc_storage`.
        """
        tensor.untyped_storage().resize_(0)

    def reset_tensors(self, epoch: int):
        r"""Reset per-iteration buffers.

        Buffers that are reset:

        * every host feature buffer except the layer-0 input features, via
          ``reset_tensors(epoch)``;
        * every host gradient buffer, via ``reset_tensors(epoch)``;
        * the GPU feature and gradient staging buffers, via ``reset_tensors()``.

        All references to cached GPU activations are also dropped.

        Args:
            epoch (int): Forwarded to the host buffers' ``reset_tensors``.
        """
        # host_storage_tensors[0] holds the input features and must survive.
        for buffer in self.host_storage_tensors[1:]:
            buffer.reset_tensors(epoch)

        # Drop activation references in place (the inner lists keep their identity).
        for acts in self.accelerator_activations:
            acts[:] = [None] * len(acts)

        # host_storage_gradients[0] is None; there is no input-gradient buffer.
        for buffer in self.host_storage_gradients:
            if buffer is not None:
                buffer.reset_tensors(epoch)

        for buffer in self.accelerator_tensors:
            buffer.reset_tensors()
        for buffer in self.accelerator_gradients:
            buffer.reset_tensors()

    def __call__(self) -> tuple[list[Tensor | None], HostStorageTensors | None]:
        r"""Run :meth:`grinnder_forward` and return its result.

        This overrides :meth:`torch.nn.Module.__call__` directly, so module
        forward hooks are not invoked.
        """
        return self.grinnder_forward()

    def _synchronize(self):
        """Placeholder for stream synchronization; currently does nothing."""
        pass

    def backward(self, losses: List[Tensor]):
        r"""Run the offloaded backward pass over all layers and partitions.

        :meth:`grinnder_forward` builds one autograd graph per
        (layer, partition) and writes every layer's output to host buffers.
        This method therefore connects those graphs by hand, from the last
        layer down:

        * the last layer is driven by ``losses``;
        * every earlier layer ``lid`` is driven by the output gradients read
          from ``host_storage_gradients[lid + 1]``.

        For each (layer, partition):

        * the layer's output activation (and, for hidden layers, its output
          gradient) is uploaded to the GPU;
        * backward runs on :attr:`fbw_stream`;
        * the used device memory is released again.

        The next partition's transfers are issued before the current
        partition's backward, so I/O overlaps with compute.

        Args:
            losses (List[Tensor]): One loss per partition, in partition order.
                ``backward()`` is called on each without an explicit gradient,
                so each must be a scalar.
        """
        logger.info(f'Backward start {self._timestamp()}')

        self._backward_last_layer(losses)

        for lid in reversed(range(self.num_layers - 1)):
            self._backward_hidden_layer(lid)

    def _backward_last_layer(self, losses: List[Tensor]) -> None:
        r"""Backward of the last layer, driven directly by ``losses``.

        What each partition needs:

        1. The output gradient. This comes from the loss, which is already on
           the GPU, so nothing is uploaded.
        2. The layer's output activation. Its storage is re-grown and uploaded
           from host memory.
        3. The layer's input activation.
           * With ``'scattered'`` it is gathered here with ``async_gather``.
           * Otherwise it is presumably restored by the unpack hook of
             ``save_on_cpu``; confirm in ``models.offload``.

        The input gradient is written to ``host_storage_gradients[-1]`` by the
        backward pass itself; this method only synchronizes on that transfer.
        """
        pool = self.pool_size
        scattered = self.checkpointing_strategy == 'scattered'
        acts = self.accelerator_activations[-1]   # output activations, one per partition
        inputs = self.accelerator_tensors[-1]     # GPU staging of the layer inputs
        out_buf = self.host_storage_tensors[-1]   # host copy of the layer outputs
        in_buf = self.host_storage_tensors[-2]    # host copy of the layer inputs
        in_grads = self.host_storage_gradients[-1]

        torch.cuda.nvtx.range_push('Last layer BW')  # set layer profile
        logger.info(f"BW the last layer {len(self.accelerator_activations)-1}")

        if self.storage_offload:
            in_buf.storage_to_cpu()
            if in_grads is not None:
                in_grads.storage_to_cpu()

        # Prologue for partition 0 (step 1 is not needed: the loss is on the GPU).
        # 2) output activation
        self._realloc_storage(acts[0])
        out_buf.async_upload(0, acts[0],
                             h2d_stream=self.h2d_streams[0], wait_stream=torch.cuda.current_stream())
        # 3) input activation
        if scattered:
            inputs.async_gather(0, in_buf, self.all_boundaries[0],
                                h2d_stream=self.bw_act_h2d_streams[0], wait_stream=torch.cuda.current_stream())

        if self.optimize_dataloader:
            self.adj_ts[0].storage_to_device(async_=True)

        last = len(losses) - 1
        for pid, loss in enumerate(losses):
            slot = pid % pool

            # The compute stream must see this partition's uploads.
            self.fbw_stream.wait_stream(self.h2d_streams[slot])
            if scattered:
                self.fbw_stream.wait_stream(self.bw_act_h2d_streams[slot])

            out_buf.h2d_synchronize(self.h2d_streams[slot])

            if self.optimize_dataloader:
                self.adj_ts[pid].synchronize()

            # Prefetch partition ``pid + 1`` so that it overlaps with this backward.
            if pid != last:
                nxt = pid + 1
                nxt_slot = nxt % pool
                if self.optimize_dataloader:
                    self.adj_ts[nxt].storage_to_device(async_=True)
                self._realloc_storage(acts[nxt])
                out_buf.async_upload(nxt, acts[nxt],
                                     h2d_stream=self.h2d_streams[nxt_slot], wait_stream=self.h2d_streams[slot])
                if scattered:
                    inputs.async_gather(nxt, in_buf, self.all_boundaries[nxt],
                                        h2d_stream=self.bw_act_h2d_streams[nxt_slot],
                                        wait_stream=self.bw_act_h2d_streams[slot])

            with torch.cuda.stream(self.fbw_stream):  # start the backward
                if pid != 0:
                    # Once the previous partition's input gradient has reached
                    # host memory, its input activation and gradient can go.
                    if in_grads is not None:
                        in_grads.d2h_synchronize(self.d2h_streams[(pid - 1) % pool])
                    self._free_storage(inputs[pid - 1])
                    inputs[pid - 1].grad = None

                # Keep the graph alive for every partition except the last one.
                loss.backward(retain_graph=pid != last)
                # Free the used output activation.
                self._free_storage(acts[pid])
                acts[pid].grad = None
                if self.optimize_dataloader:
                    # Discarding the device copy of the topology is enough here.
                    # NOTE(original author): this discard was suspected of
                    # causing an accuracy drop; left as an open TODO.
                    self.adj_ts[pid].discard_from_device(async_=True)

        # Epilogue
        if self.optimize_dataloader:
            for i in range(self.n_parts):
                self.adj_ts[i].synchronize()

        if in_grads is not None:
            in_grads.d2h_synchronize(self.d2h_streams[(self.n_parts - 1) % pool])
        self._free_storage(inputs[self.n_parts - 1])
        inputs[self.n_parts - 1].grad = None
        # The loss tensors are owned by the caller's ``losses`` list; nothing
        # else needs to be released for them here.

        if self.storage_offload:  # reuse the already existing storage copy of the inputs
            in_buf.cpu_to_storage(use_duplicate=True)
            if in_grads is not None:
                # The accumulated gradients changed and must be written back.
                in_grads.cpu_to_storage(use_duplicate=False)

        torch.cuda.nvtx.range_pop()  # pop layer profile

        logger.info(report_memory('BW Last'))
        # logging memory usage
        logger.info(f"Host: {get_memory_used()/10**9} GB")

    def _backward_hidden_layer(self, lid: int) -> None:
        r"""Backward of layer ``lid`` (not the last layer), one partition at a time.

        What each partition needs:

        1. The output gradient, pulled from ``host_storage_gradients[lid + 1]``
           with ``async_direct_pull``.
        2. The output activation. Its storage is re-grown and uploaded.
        3. The input activation.
           * With ``'scattered'`` it is gathered with ``async_gather``.
           * Otherwise it is presumably restored by ``save_on_cpu``'s unpack
             hook.

        After each partition's backward, 1) to 3) are released from the device.
        """
        pool = self.pool_size
        scattered = self.checkpointing_strategy == 'scattered'
        acts = self.accelerator_activations[lid]
        inputs = self.accelerator_tensors[lid]
        out_buf = self.host_storage_tensors[lid + 1]
        in_buf = self.host_storage_tensors[lid]
        out_grads = self.accelerator_gradients[lid]          # GPU staging of output grads
        out_grad_buf = self.host_storage_gradients[lid + 1]  # host copy of output grads
        in_grads = self.host_storage_gradients[lid]          # None for lid == 0

        torch.cuda.nvtx.range_push(f'Layer {lid} BW')  # set layer profile
        logger.info(f"Backward layer {lid} {self._timestamp()}")

        if self.storage_offload:
            in_buf.storage_to_cpu()
            if in_grads is not None:
                in_grads.storage_to_cpu()

        # Prologue for partition 0
        if self.optimize_dataloader:
            self.adj_ts[0].storage_to_device(async_=True)

        # 1) output gradient
        out_grads.async_direct_pull(0, out_grad_buf[0],
                                    h2d_stream=self.h2d_streams[0], wait_stream=torch.cuda.current_stream(),
                                    is_gds=self.storage_offload)
        # 2) output activation
        self._realloc_storage(acts[0])
        out_buf.async_upload(0, acts[0],
                             h2d_stream=self.h2d_streams[0], wait_stream=torch.cuda.current_stream())
        # 3) input activation
        if scattered:
            inputs.async_gather(0, in_buf, self.all_boundaries[0],
                                h2d_stream=self.bw_act_h2d_streams[0], wait_stream=torch.cuda.current_stream())

        last = self.n_parts - 1
        for pid in range(self.n_parts):
            slot = pid % pool

            # The compute stream must see this partition's uploads.
            self.fbw_stream.wait_stream(self.h2d_streams[slot])
            if scattered:
                self.fbw_stream.wait_stream(self.bw_act_h2d_streams[slot])

            if self.optimize_dataloader:
                self.adj_ts[pid].synchronize()
            out_buf.h2d_synchronize(self.h2d_streams[slot])
            out_grads.h2d_synchronize(self.h2d_streams[slot], is_gds=self.storage_offload)

            # Prefetch partition ``pid + 1`` so that it overlaps with this backward.
            if pid != last:
                nxt = pid + 1
                nxt_slot = nxt % pool
                if self.optimize_dataloader:
                    self.adj_ts[nxt].storage_to_device(async_=True)
                self._realloc_storage(acts[nxt])
                out_buf.async_upload(nxt, acts[nxt],
                                     h2d_stream=self.h2d_streams[nxt_slot], wait_stream=self.h2d_streams[slot])
                out_grads.async_direct_pull(nxt, out_grad_buf[nxt],
                                            h2d_stream=self.h2d_streams[nxt_slot], wait_stream=self.h2d_streams[slot],
                                            is_gds=self.storage_offload)
                if scattered:
                    inputs.async_gather(nxt, in_buf, self.all_boundaries[nxt],
                                        h2d_stream=self.bw_act_h2d_streams[nxt_slot],
                                        wait_stream=self.bw_act_h2d_streams[slot])

            with torch.cuda.stream(self.fbw_stream):  # start the backward
                if pid != 0:
                    # Once the previous partition's input gradient has reached
                    # host memory, its input activation and gradient can go.
                    if in_grads is not None:
                        in_grads.d2h_synchronize(self.d2h_streams[(pid - 1) % pool])
                    self._free_storage(inputs[pid - 1])
                    inputs[pid - 1].grad = None

                # Keep the graph alive for every partition except the last one.
                acts[pid].backward(out_grads[pid], retain_graph=pid != last)
                # Free the used output activation and its incoming gradient.
                self._free_storage(acts[pid])
                acts[pid].grad = None
                out_grads.free_from_device(pid)
                if self.optimize_dataloader:
                    # Discarding the device copy of the topology is enough here.
                    self.adj_ts[pid].discard_from_device(async_=True)

        # Epilogue
        if self.optimize_dataloader:
            for i in range(self.n_parts):
                self.adj_ts[i].synchronize()
        if in_grads is not None:
            in_grads.d2h_synchronize(self.d2h_streams[(self.n_parts - 1) % pool])
        self._free_storage(inputs[self.n_parts - 1])
        inputs[self.n_parts - 1].grad = None

        if self.storage_offload:
            in_buf.cpu_to_storage(use_duplicate=True)
            if in_grads is not None:
                # The accumulated gradients changed and must be written back.
                in_grads.cpu_to_storage(use_duplicate=False)

        torch.cuda.nvtx.range_pop()  # pop layer profile
        logger.info(report_memory('BW'))
        # logging memory usage
        logger.info(f"Host: {get_memory_used()/10**9} GB")

    def grinnder_forward(self) -> tuple[list[Tensor | None], HostStorageTensors | None]:
        r"""Run the layer-wise forward pass over all layers and partitions.

        For every layer and partition:

        * the inputs are gathered onto the GPU;
        * :meth:`upoffload_wrapper` runs under activation checkpointing
          inside :func:`save_on_cpu`;
        * the output is written asynchronously to the next layer's host
          buffer.

        Per-partition extra data (e.g. for residual connections) is passed to
        :meth:`forward_layer` as ``state``.

        Returns:
            A tuple ``(activations, host_buffer)``.
            ``activations`` holds the last layer's output tensors, one per
            partition. Their storage has been released with ``resize_(0)`` and
            is re-grown in :meth:`backward`. ``host_buffer`` is the host
            buffer holding the last layer's outputs.
        """
        logger.info(f"Forward start {self._timestamp()}")

        logger.info(report_memory('FW Start'))

        for lid in range(self.num_layers):
            self._forward_layer_pass(lid)

        return (self.accelerator_activations[-1], self.host_storage_tensors[-1])  # return last activations and offloader

    def _forward_layer_pass(self, lid: int) -> None:
        r"""Run layer ``lid`` over all partitions; one step of :meth:`grinnder_forward`.

        Iteration ``pid`` computes partition ``pid - 1`` while the gather for
        partition ``pid`` is in flight. The GPU output of partition ``pid - 2``
        is released once it has been copied to the host.
        """
        pool = self.pool_size

        torch.cuda.nvtx.range_push(f'Layer {lid} FW')  # set layer profile
        logger.info(f"Forward layer {lid} {self._timestamp()}")
        logger.info(report_memory(f'FW {lid}'))

        if self.storage_offload:
            self.host_storage_tensors[lid].storage_to_cpu()

        if self.optimize_dataloader:
            self.adj_ts[0].storage_to_device(async_=True)

        logger.info(report_memory(f'FW {lid} 0 prologue'))

        # Prologue: gather partition 0, ordered after work already queued on
        # the current stream.
        self.accelerator_tensors[lid].async_gather(0, self.host_storage_tensors[lid], self.all_boundaries[0],
                                                   h2d_stream=self.h2d_streams[0], wait_stream=torch.cuda.current_stream())

        logger.info(report_memory(f'FW {lid} 0 prologue done'))

        for pid in range(1, self.n_parts + 1):  # start from 1 (as we manually pull 0)
            cur = pid - 1           # partition computed in this iteration
            cur_slot = cur % pool

            # Order the compute stream after the gather of partition ``cur``.
            self.fbw_stream.wait_stream(self.h2d_streams[cur_slot])

            with torch.cuda.stream(self.fbw_stream):  # start the forward
                if self.optimize_dataloader:
                    self.adj_ts[cur].synchronize()
                self.accelerator_tensors[lid].h2d_synchronize(self.h2d_streams[cur_slot])

                # Part 1 (I/O): prefetch the next partition's topology and inputs.
                if pid != self.n_parts:  # not the last partition
                    if self.optimize_dataloader:
                        self.adj_ts[pid].storage_to_device(async_=True)
                    self.accelerator_tensors[lid].async_gather(pid, self.host_storage_tensors[lid], self.all_boundaries[pid],
                                                               h2d_stream=self.h2d_streams[pid % pool],
                                                               wait_stream=self.h2d_streams[cur_slot])

                logger.info(report_memory(f'FW {lid} {pid} Before require_grad'))

                # Part 2 (compute)
                self.accelerator_tensors[lid][cur].requires_grad = True

                logger.info(report_memory(f'FW {lid} {pid} After require_grad'))

                # adj_t is deliberately NOT an argument of ``checkpoint``: PyTorch's
                # checkpointing saves every tensor passed to the function, so
                # upoffload_wrapper fetches the topology itself.
                with save_on_cpu(lid, cur, self.accelerator_tensors,
                                 checkpointing_strategy=self.checkpointing_strategy,
                                 h2d_stream=self.bw_act_h2d_streams[cur_slot]):
                    out = checkpoint(
                        self.upoffload_wrapper,
                        lid, cur, self.accelerator_tensors[lid][cur],
                        use_reentrant=True,
                    )

                logger.info(report_memory(f'FW {lid} {pid} After forward'))

                if self.optimize_dataloader:
                    # Discarding the device copy is enough; the topology is not written back.
                    self.adj_ts[cur].discard_from_device(async_=False)

                logger.info(report_memory(f'FW {lid} {pid} After adj_t discard'))

                # The input of partition ``cur`` is no longer needed on the device.
                self.accelerator_tensors[lid].free_from_device(cur)

                logger.info(report_memory(f'FW {lid} {pid} After free'))

                # Keep the output for backward and start its copy to the next
                # layer's host buffer right away.
                # NOTE(original author): a proper free mechanism for ``out`` is
                # still wanted (e.g. a callback that resizes it to zero).
                self.accelerator_activations[lid][cur] = out
                self.host_storage_tensors[lid + 1].async_fill(cur, out, d2h_stream=self.d2h_streams[cur_slot],
                                                              wait_stream=self.fbw_stream)

                logger.info(report_memory(f'FW {lid} {pid} After fill'))

            # Once partition ``pid - 2`` has reached the host, release its GPU output.
            if pid != 1:  # not the first partition
                prev_slot = (pid - 2) % pool
                torch.cuda.current_stream().wait_stream(self.d2h_streams[prev_slot])  # wait for the previous d2h
                self.host_storage_tensors[lid + 1].d2h_synchronize(self.d2h_streams[prev_slot])
                self._free_storage(self.accelerator_activations[lid][pid - 2])
            logger.info(report_memory(f'FW {lid} {pid}'))

        if self.optimize_dataloader:
            for i in range(self.n_parts):
                self.adj_ts[i].synchronize()

        if self.storage_offload:  # use the already existing storage tensor
            self.host_storage_tensors[lid].cpu_to_storage(use_duplicate=True)

        # Epilogue: flush the last partition's d2h copy and free its GPU output.
        last_slot = (self.n_parts - 1) % pool
        torch.cuda.current_stream().wait_stream(self.d2h_streams[last_slot])  # wait for the last d2h
        self.host_storage_tensors[lid + 1].d2h_synchronize(self.d2h_streams[last_slot])  # flush
        self._free_storage(self.accelerator_activations[lid][-1])  # free the last partition

        logger.info(report_memory(f'FW {lid}'))
        # logging memory usage
        logger.info(f"Host: {get_memory_used()/10**9} GB")

        torch.cuda.nvtx.range_pop()  # pop layer profile

    def upoffload_wrapper(self, lid, pid, x):
        r"""Checkpointed body of one (layer, partition) forward step.

        Steps:

        1. Wrap ``x`` with :func:`grad_offload`. It receives
           ``host_storage_gradients[lid]``, this partition's boundaries, the
           d2h stream of the partition's slot and :attr:`fbw_stream`;
           presumably it routes ``x``'s gradient to that host buffer during
           backward (confirm in ``models.offload``).
        2. Move the partition's adjacency matrix to CUDA.
        3. Call :meth:`forward_layer`.

        The topology is fetched here, instead of being passed as an argument,
        so that ``torch.utils.checkpoint`` does not save it with the inputs.
        """
        state = self.all_states[pid]

        x = grad_offload(x, lid, pid, self.host_storage_gradients[lid], self.all_boundaries[pid],
                         self.d2h_streams[pid % self.pool_size], self.fbw_stream)

        # The topology never needs gradients.
        self.adj_ts[pid].requires_grad_(False)

        adj_t_gpu = self.adj_ts[pid].to('cuda')

        h = self.forward_layer(lid, pid, x, adj_t_gpu, state)

        del adj_t_gpu

        return h

    def forward_layer(self, layer: int, pid: int, x: Tensor, adj_t: SparseTensor,
                      state: List[Any]) -> Tensor:
        r"""Compute layer ``layer`` for partition ``pid``. Must be overridden.

        Args:
            layer (int): Layer index.
            pid (int): Partition index.
            x (Tensor): Input embeddings of this partition, as prepared by
                :meth:`upoffload_wrapper`.
            adj_t: The partition's topology, i.e. the result of
                ``self.adj_ts[pid].to('cuda')`` (presumably a
                :class:`SparseTensor`; confirm against
                ``AdjacencyMatrixWithGDS.to``).
            state (List[Any]): The partition's entry of :attr:`all_states`.
                It contains the loader item's trailing fields, followed by the
                empty dict appended in ``__init__``.

        Note that the user should not use ``@torch.no_grad``. Instead, use
        ``with torch.no_grad() if not gen_grad else nullcontext():``.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError