# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

"""Full iteration CUDA graph for training."""

import os
import torch
from megatron.core.tensor_parallel.random import get_all_rng_states
from .utils import print_rank_0

# The functions below traverse nested data structures (tuples, lists, dicts) in
# src. copy_tensors_in_struct builds a new structure in which every PyTorch tensor
# is cloned, detached from the autograd graph, and moved to the CUDA device;
# non-tensor objects are returned as-is. clone_tensors_in_struct instead copies
# the contents of src into the matching pre-existing structure tgt in place.

# Copy src to new tensors
def copy_tensors_in_struct(src):
    """Return a copy of ``src`` whose tensors are cloned, detached and moved to CUDA.

    Dicts, lists and tuples are rebuilt recursively (as plain ``dict`` /
    ``list`` / ``tuple``). Any other object is returned unchanged; it is not
    copied.
    """
    if isinstance(src, torch.Tensor):
        return src.clone().detach().cuda()
    if isinstance(src, dict):
        return {key: copy_tensors_in_struct(src[key]) for key in src}
    if isinstance(src, (tuple, list)):
        copied = [copy_tensors_in_struct(item) for item in src]
        # Preserve the container kind: tuples stay tuples, lists stay lists.
        return tuple(copied) if isinstance(src, tuple) else copied
    return src

# Copy src to pre-existing tensors in tgt
def clone_tensors_in_struct(tgt, src):
    """Copy the contents of ``src`` into the matching pre-existing structure ``tgt``.

    Tensors are written in place into the corresponding tensors of ``tgt`` via
    ``copy_(..., non_blocking=True)``. List elements and dict values that are
    neither tensors nor containers are assigned into ``tgt`` directly. ``tgt``
    must already have every slot (index or key) that is read from it.

    Raises:
        Exception: if ``src`` (or any nested container in it) is a tuple, or if
            the top-level ``src`` is not a list, dict or tensor.
    """
    if isinstance(src, tuple):
        raise Exception(f"Unsupported copy for tuple yet: {type(src)}")

    if isinstance(src, list):
        slots = range(len(src))
    elif isinstance(src, dict):
        slots = src
    elif isinstance(src, torch.Tensor):
        tgt.copy_(src, non_blocking=True)
        return
    else:
        raise Exception(f"Expect top-level as container type but got: {type(src)}")

    recurse_types = (tuple, list, dict, torch.Tensor)
    for slot in slots:
        value = src[slot]
        if not isinstance(value, recurse_types):
            # Plain value: overwrite the slot in the target container.
            tgt[slot] = value
            continue
        clone_tensors_in_struct(tgt[slot], value)

# Class to copy dataloader output to static CUDA tensors for CUDA graph input. This
# maintains separate static buffers for training and validation CUDA graphs.
class StaticBufferLoader:
    """Load data to static buffers.

    Copies each microbatch produced by the dataloader into persistent CUDA
    buffers so that a captured CUDA graph reads the same memory on every
    replay. Buffers are kept per stage ('training' / 'validation') and per
    microbatch index in the class-level ``static_buffers`` dict, so they are
    shared by every instance of this class.
    """

    # Class-level (shared) storage: stage -> list, indexed by microbatch, of the
    # static input structures handed to the CUDA graph.
    static_buffers = {'training': [], 'validation': []}

    def __init__(self):
        # Side stream on which copies into the static buffers are issued.
        self.stream = torch.cuda.Stream()

    def __call__(self, inputs, stage, microbatch):
        """Copy ``inputs`` into the static buffer for ``(stage, microbatch)``.

        Args:
            inputs: a dict of inputs, or a tuple whose first element is such a
                dict (only that first element is used).
            stage: 'training' or 'validation'.
            microbatch: microbatch index; must not exceed the number of buffers
                already allocated for ``stage`` (equal means "allocate a new one").

        Returns:
            The static buffer dict for ``(stage, microbatch)`` holding a copy of
            the data in ``inputs``.
        """
        assert stage in ['training', 'validation']
        buffers = StaticBufferLoader.static_buffers[stage]
        assert microbatch <= len(buffers)
        # Check for an empty tuple before peeking at its first element.
        if isinstance(inputs, tuple) and inputs and isinstance(inputs[0], dict):
            inputs = inputs[0]

        assert isinstance(inputs, dict)

        # Order the side stream after all work already queued on the current
        # stream (e.g. a CUDA graph replay still reading these buffers) so the
        # copies below cannot overwrite data that is still in use.
        self.stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.stream):
            if microbatch == len(buffers):
                # First time this microbatch index is seen: allocate new buffers.
                buffers.append(copy_tensors_in_struct(inputs))
            else:
                static = buffers[microbatch]
                for k in inputs:
                    if k not in static:
                        # New key: build a matching buffer for any supported
                        # value type. torch.empty_like would fail on non-tensor
                        # entries such as ints or lists.
                        static[k] = copy_tensors_in_struct(inputs[k])
                clone_tensors_in_struct(static, inputs)
        # Make the consuming stream wait for the copies before it reads them.
        torch.cuda.current_stream().wait_stream(self.stream)
        return buffers[microbatch]


class FullCudaGraphWrapper:
    """Run ``forward_backward_func`` eagerly for warmup steps, then as a CUDA graph.

    Each stage ('training' when ``forward_only`` is False, 'validation'
    otherwise) has its own call counter, captured graph and cached result. On
    the call where the stage's counter equals ``cuda_graph_warmup_steps``, one
    full ``forward_backward_func`` call is captured into a
    ``torch.cuda.CUDAGraph``. That call and every later one replay the graph
    and return the object produced during capture. All state is class-level,
    so it is shared by every instance of this class.
    """

    # Number of calls made so far, per stage.
    curr_iteration = {'training': 0, 'validation': 0}
    # Captured graph per stage (None until captured).
    cuda_graph = {'training': None, 'validation': None}
    # Latest eager output per stage; once captured, the object produced during
    # capture (its tensors are rewritten in place by each replay).
    result = {'training': None, 'validation': None}
    # num_microbatches used when each stage's graph was captured.
    captured_num_microbatches = {'training': None, 'validation': None}

    def __init__(
        self,
        forward_backward_func,
        cuda_graph_warmup_steps = 1
    ):
        """
        Args:
            forward_backward_func: callable invoked with keyword arguments only.
            cuda_graph_warmup_steps: stage call index at which the graph is
                captured; calls before it run eagerly.
        """
        self.forward_backward_func = forward_backward_func
        self.static_loader = StaticBufferLoader()
        self.cuda_graph_warmup_steps = cuda_graph_warmup_steps

    def data_read(
        self,
        data_iterator,
        model,
        training,
        num_microbatches
    ):
        """Pull ``num_microbatches`` batches per iterator into static buffers.

        Returns a list with one entry per model chunk (a single entry when
        ``model`` is not a list or has length 1). Each entry is an iterator over
        the static buffers, or None where the corresponding data iterator is
        None.
        """
        stage = 'training' if training else 'validation'

        def _load_all(iterator):
            # Read every microbatch now so the graph always sees the same buffers.
            if iterator is None:
                return None
            return iter([
                self.static_loader(next(iterator), stage, b)
                for b in range(num_microbatches)
            ])

        if not isinstance(model, list) or len(model) == 1:
            assert not isinstance(data_iterator, list) or len(data_iterator) == 1
            iterator0 = data_iterator[0] if isinstance(data_iterator, list) else data_iterator
            return [_load_all(iterator0)]

        assert isinstance(data_iterator, list) and len(data_iterator) == len(model)
        return [_load_all(iterator) for iterator in data_iterator]

    def __call__(
        self,
        *args,
        **kwargs):
        """Run one forward/backward step, eagerly or by replaying the CUDA graph.

        Requires keyword arguments ``model``, ``data_iterator``,
        ``num_microbatches``, ``seq_length`` and ``forward_only``. All
        microbatches are first copied into static buffers, and
        ``data_iterator`` is replaced by iterators over those buffers.

        Raises:
            RuntimeError: if a graph was already captured for this stage with a
                different ``num_microbatches``. Replaying it would silently
                process the wrong set of microbatches.
        """
        assert len(args)==0, 'forward_backward_func does not accept positional args'
        assert all(
            kwarg in kwargs
            for kwarg in ('model', 'data_iterator', 'num_microbatches', 'seq_length', 'forward_only')
        )
        model = kwargs['model']
        num_microbatches = kwargs['num_microbatches']
        training = not kwargs['forward_only']
        training_str = 'training' if training else 'validation'

        # Check before reading data so the caller's iterator is not consumed.
        captured = FullCudaGraphWrapper.captured_num_microbatches[training_str]
        if (
            FullCudaGraphWrapper.cuda_graph[training_str] is not None
            and captured is not None
            and captured != num_microbatches
        ):
            raise RuntimeError(
                f'CUDA graph for {training_str} was captured with {captured} '
                f'microbatches but {num_microbatches} were requested'
            )

        kwargs['data_iterator'] = self.data_read(
            kwargs['data_iterator'], model, training, num_microbatches
        )

        if self.curr_iter(training_str) == self.cuda_graph_warmup_steps:
            self._capture(training_str, num_microbatches, args, kwargs)

        graph = FullCudaGraphWrapper.cuda_graph[training_str]
        if graph is None:
            FullCudaGraphWrapper.result[training_str] = self.forward_backward_func(*args, **kwargs)
        else:
            # Capture only records the work, so the capturing call replays too.
            graph.replay()

        self.next_iter(training_str)
        return FullCudaGraphWrapper.result[training_str]

    def _capture(self, training_str, num_microbatches, args, kwargs):
        """Capture one ``forward_backward_func`` call for ``training_str`` into a graph."""
        print_rank_0(f'Capture CUDA graph for {training_str}!!!')
        torch.distributed.barrier()
        assert FullCudaGraphWrapper.cuda_graph[training_str] is None
        graph = torch.cuda.CUDAGraph()
        FullCudaGraphWrapper.cuda_graph[training_str] = graph
        # Register RNG generator states with the graph so their state is
        # updated on replay (see CUDAGraph.register_generator_state).
        for _, state in get_all_rng_states().items():
            graph.register_generator_state(state)
        torch.cuda.synchronize()
        capture_stream = torch.cuda.Stream()
        with torch.cuda.graph(graph, stream=capture_stream, capture_error_mode="thread_local"):
            FullCudaGraphWrapper.result[training_str] = self.forward_backward_func(*args, **kwargs)
        torch.cuda.synchronize()
        torch.distributed.barrier()
        FullCudaGraphWrapper.captured_num_microbatches[training_str] = num_microbatches
        print_rank_0('CUDA graph capture done!!!')

    def curr_iter(self, stage):
        """Return the number of calls made so far for ``stage``."""
        return FullCudaGraphWrapper.curr_iteration[stage]

    def next_iter(self, stage):
        """Advance the call counter for ``stage``."""
        FullCudaGraphWrapper.curr_iteration[stage] += 1
