import torch

from megatron.core.distributed import DistributedDataParallel as DDP
from megatron.core.optimizer.optimizer import ChainedOptimizer

from gpatch.core.utils import clear_memory, print_memory_tracking


class _CheckTensorAttr:
    """One-time sanity check that the ``gcore_*`` attribute names are free on tensors.

    The helpers in this file attach bookkeeping to tensors under the names
    ``gcore_cpu_data`` and ``gcore_untyped_storage_data_size``. The check
    asserts that a freshly created tensor does not already expose either name.
    """

    # Set to True after the first successful check; later calls return at once.
    run_once_flag = False

    @staticmethod
    def check_attr():
        """Assert, on the first call only, that a new CUDA tensor lacks the reserved names.

        Raises:
            AssertionError: If a fresh tensor already has one of the attributes.
        """
        if _CheckTensorAttr.run_once_flag:
            return
        # Probe tensor is created on the current CUDA device.
        probe = torch.rand((10), device=torch.cuda.current_device())
        for reserved_name in ("gcore_cpu_data", "gcore_untyped_storage_data_size"):
            assert not hasattr(probe, reserved_name)
        _CheckTensorAttr.run_once_flag = True


def offload_tensor_to_cpu(tensor):
    """Stash ``tensor``'s contents in a CPU copy and shrink its storage to zero.

    The CPU copy lives on the tensor as ``gcore_cpu_data`` (created on the
    first call, refreshed in place on later calls). The storage size observed
    just before shrinking is kept as ``gcore_untyped_storage_data_size`` so
    that ``onload_tensor_to_gpu`` can reverse the operation.

    Args:
        tensor: A ``torch.Tensor`` or ``None``; ``None`` is ignored.

    Raises:
        AssertionError: If ``tensor`` is not a ``torch.Tensor``, if an existing
            CPU copy differs in shape or dtype, or if the tensor's storage size
            differs from the CPU copy's storage size.
    """
    _CheckTensorAttr.check_attr()
    if tensor is None:
        return
    assert isinstance(tensor, torch.Tensor), f"{tensor=} type must be torch.Tensor"

    # NOTE(review): both copy paths use non_blocking=True and nothing here
    # synchronizes; callers presumably synchronize before reading
    # gcore_cpu_data -- confirm.
    if hasattr(tensor, "gcore_cpu_data"):
        # Reuse the CPU copy made by an earlier call.
        cached = tensor.gcore_cpu_data
        assert cached.shape == tensor.data.shape
        assert cached.dtype == tensor.data.dtype
        cached.copy_(tensor.data, non_blocking=True)
    else:
        tensor.gcore_cpu_data = tensor.data.to("cpu", non_blocking=True)

    storage = tensor.untyped_storage()
    storage_size = storage.size()
    tensor.gcore_untyped_storage_data_size = storage_size
    # Presumably rejects tensors that cover only part of a larger storage,
    # since the whole storage is dropped below -- confirm.
    assert storage_size == tensor.gcore_cpu_data.untyped_storage().size()
    storage.resize_(0)


def onload_tensor_to_gpu(tensor):
    """Re-allocate ``tensor``'s storage and refill it from its saved CPU copy.

    Reads the ``gcore_untyped_storage_data_size`` and ``gcore_cpu_data``
    attributes that ``offload_tensor_to_cpu`` sets on the tensor.

    Args:
        tensor: A tensor carrying both attributes, or ``None`` (ignored).
    """
    if tensor is None:
        return
    saved_size = tensor.gcore_untyped_storage_data_size
    tensor.data.untyped_storage().resize_(saved_size)
    cpu_copy = tensor.gcore_cpu_data
    # The copy is issued with non_blocking=True; this function does not synchronize.
    tensor.data.copy_(cpu_copy, non_blocking=True)


def release_tensor_mem(tensor):
    """Free ``tensor``'s storage without saving its contents.

    The storage size seen before freeing is recorded on the tensor as
    ``gcore_untyped_storage_data_size`` so ``recover_tensor_mem`` can
    re-allocate it later; the data itself is discarded.

    Calling this on a tensor that is already released is a no-op and leaves
    the recorded size untouched. Without that guard a second release would
    overwrite the recorded size with 0 (for example after
    ``onload_megatron_model(..., load_grad=False)`` skipped the recover step
    and the model was offloaded again), and the next ``recover_tensor_mem``
    would "restore" an empty storage.

    Args:
        tensor: A tensor or ``None``; ``None`` is ignored.
    """
    _CheckTensorAttr.check_attr()
    if tensor is None:
        return
    storage = tensor.untyped_storage()
    current_size = storage.size()
    already_released = (
        current_size == 0 and getattr(tensor, "gcore_untyped_storage_data_size", 0) > 0
    )
    if already_released:
        # Keep the size recorded by the earlier release; nothing left to free.
        return
    tensor.gcore_untyped_storage_data_size = current_size
    storage.resize_(0)


def recover_tensor_mem(tensor):
    """Re-allocate ``tensor``'s storage to the size recorded by ``release_tensor_mem``.

    Only the storage size is restored; nothing is copied into it.

    Args:
        tensor: A tensor carrying ``gcore_untyped_storage_data_size``, or
            ``None`` (ignored).
    """
    if tensor is not None:
        recorded_size = tensor.gcore_untyped_storage_data_size
        tensor.data.untyped_storage().resize_(recorded_size)


def copy_tensor_to_cpu(tensor):
    """Create or refresh the CPU copy of ``tensor`` without freeing its storage.

    The copy is stored on the tensor as ``gcore_cpu_data``. An existing copy
    is overwritten in place after checking that shape and dtype still match.

    Args:
        tensor: A tensor or ``None``; ``None`` is ignored.

    Raises:
        AssertionError: If an existing CPU copy differs in shape or dtype.
    """
    _CheckTensorAttr.check_attr()
    if tensor is None:
        return
    # Copies are issued with non_blocking=True; this function does not synchronize.
    if hasattr(tensor, "gcore_cpu_data"):
        mirror = tensor.gcore_cpu_data
        assert mirror.shape == tensor.data.shape
        assert mirror.dtype == tensor.data.dtype
        mirror.copy_(tensor.data, non_blocking=True)
        return
    tensor.gcore_cpu_data = tensor.data.to("cpu", non_blocking=True)


def copy_tensor_to_gpu(tensor):
    """Overwrite ``tensor``'s data with the contents of its ``gcore_cpu_data`` copy.

    The storage is not resized, so it must still be allocated.

    Args:
        tensor: A tensor carrying ``gcore_cpu_data``, or ``None`` (ignored).
    """
    if tensor is not None:
        cpu_copy = tensor.gcore_cpu_data
        # Issued with non_blocking=True; this function does not synchronize.
        tensor.data.copy_(cpu_copy, non_blocking=True)


@torch.no_grad()
def offload_megatron_model(models):
    """Move model weights to CPU copies and free the corresponding device storage.

    Background note kept from the original author. In megatron, the model and
    optimizer storage are:
    - bf16 parameter data chunked in model parallel group
    - fp32 grad chunked in model parallel group
    - fp32 main_parameter chunked in model and dp group
    - fp32 optimizer state chunked in model and dp group

    For DDP-wrapped chunks, each buffer's ``param_data`` is offloaded and its
    ``grad_data`` is released without being saved. Parameters with
    ``requires_grad == False`` are offloaded individually. For any other chunk,
    every parameter and every non-``None`` ``.grad`` is offloaded.

    Args:
        models: Iterable of model chunks, or ``None``. With ``None`` the
            function returns right after the initial cleanup and log line.
    """
    clear_memory()
    print_memory_tracking(f"Memory tracking: actor before model offload", verbose=True, rank=0)
    if models is None:
        return

    for chunk in models:
        if not isinstance(chunk, DDP):
            # Plain module: handle every parameter and its gradient, if any.
            for _, weight in chunk.named_parameters():
                offload_tensor_to_cpu(weight)
                if weight.grad is not None:
                    offload_tensor_to_cpu(weight.grad)
            continue

        # DDP chunk: walk the regular and the expert-parallel buffers.
        for buffer_group in (chunk.buffers, chunk.expert_parallel_buffers):
            for buf in buffer_group:
                offload_tensor_to_cpu(buf.param_data)
                release_tensor_mem(buf.grad_data)
        # Frozen parameters are handled one by one; presumably they are not
        # part of the DDP buffers above -- confirm.
        for _, weight in chunk.module.named_parameters():
            if weight.requires_grad:
                continue
            assert not weight._is_view()
            offload_tensor_to_cpu(weight)

    clear_memory()
    print_memory_tracking(f"Memory tracking: actor after model offload", verbose=True, rank=0)


@torch.no_grad()
def copy_megatron_model_to_cpu(models):
    """Refresh the CPU copies of the model weights, leaving device storage allocated.

    Background note kept from the original author. In megatron, the model and
    optimizer storage are:
    - bf16 parameter data chunked in model parallel group
    - fp32 grad chunked in model parallel group
    - fp32 main_parameter chunked in model and dp group
    - fp32 optimizer state chunked in model and dp group

    For DDP-wrapped chunks, each buffer's ``param_data`` is copied, followed by
    every parameter with ``requires_grad == False``. For other chunks every
    parameter is copied. Gradients are not touched. The function ends with
    ``torch.cuda.synchronize()``.

    Args:
        models: Iterable of model chunks, or ``None`` (returns immediately).
    """
    if models is None:
        return

    for chunk in models:
        if isinstance(chunk, DDP):
            for buffer_group in (chunk.buffers, chunk.expert_parallel_buffers):
                for buf in buffer_group:
                    copy_tensor_to_cpu(buf.param_data)
            # Frozen parameters are copied individually.
            frozen = (p for _, p in chunk.module.named_parameters() if not p.requires_grad)
            for weight in frozen:
                assert not weight._is_view()
                copy_tensor_to_cpu(weight)
        else:
            for _, weight in chunk.named_parameters():
                copy_tensor_to_cpu(weight)
    # The per-tensor copies are non-blocking; wait for them here.
    torch.cuda.synchronize()


@torch.no_grad()
def onload_megatron_model(models, load_grad=True):
    """Bring offloaded model weights back onto the device.

    For DDP-wrapped chunks, each buffer's ``grad_data`` storage is re-allocated
    (only when ``load_grad`` is true) and its ``param_data`` is restored from
    the CPU copy. Parameters with ``requires_grad == False`` are then restored
    individually. For other chunks every parameter and every non-``None``
    ``.grad`` is restored; ``load_grad`` is not consulted on that path.

    Args:
        models: Iterable of model chunks, or ``None``. With ``None`` the
            function returns right after the initial cleanup and log line.
        load_grad: Whether to re-allocate the DDP gradient buffers.
    """
    clear_memory()
    print_memory_tracking(f"Memory tracking: actor before model onload", verbose=True, rank=0)
    if models is None:
        return

    for chunk in models:
        if not isinstance(chunk, DDP):
            # Plain module: restore every parameter and its gradient, if any.
            for _, weight in chunk.named_parameters():
                onload_tensor_to_gpu(weight)
                if weight.grad is not None:
                    onload_tensor_to_gpu(weight.grad)
            continue

        for buffer_group in (chunk.buffers, chunk.expert_parallel_buffers):
            for buf in buffer_group:
                # Gradient storage is re-allocated before the parameter data.
                if load_grad:
                    recover_tensor_mem(buf.grad_data)
                onload_tensor_to_gpu(buf.param_data)
        # Frozen parameters were offloaded individually, so restore them the same way.
        for _, weight in chunk.module.named_parameters():
            if weight.requires_grad:
                continue
            assert not weight._is_view()
            onload_tensor_to_gpu(weight)

    clear_memory()
    print_memory_tracking(f"Memory tracking: actor after model onload", verbose=True, rank=0)


@torch.no_grad()
def copy_megatron_model_to_gpu(models):
    """Overwrite the model weights with their saved CPU copies.

    No storage is resized. For DDP-wrapped chunks, each buffer's ``param_data``
    is overwritten, followed by every parameter with
    ``requires_grad == False``. For other chunks every parameter is
    overwritten. The function ends with ``torch.cuda.synchronize()``.

    Args:
        models: Iterable of model chunks, or ``None`` (returns immediately).
    """
    if models is None:
        return

    for chunk in models:
        if not isinstance(chunk, DDP):
            for _, weight in chunk.named_parameters():
                copy_tensor_to_gpu(weight)
            continue

        for buffer_group in (chunk.buffers, chunk.expert_parallel_buffers):
            for buf in buffer_group:
                copy_tensor_to_gpu(buf.param_data)
        # Frozen parameters are restored individually.
        for _, weight in chunk.module.named_parameters():
            if weight.requires_grad:
                continue
            assert not weight._is_view()
            copy_tensor_to_gpu(weight)
    # The per-tensor copies are non-blocking; wait for them here.
    torch.cuda.synchronize()


@torch.no_grad()
def offload_megatron_copy_params(optimizers):
    """Offload an optimizer's ``shard_fp32_from_float16_groups`` tensors to CPU.

    The attribute may be a single tensor, a list of tensors, or a list whose
    entries are themselves lists of tensors; nesting is followed at most two
    levels deep. Nothing happens when the attribute is missing or ``None``.

    Args:
        optimizers: The optimizer object to inspect.
    """

    def _leaves(group):
        # Yield the tensors in the same order a nested two-level walk would visit them.
        if not isinstance(group, list):
            yield group
            return
        for entry in group:
            if isinstance(entry, list):
                yield from entry
            else:
                yield entry

    groups = getattr(optimizers, 'shard_fp32_from_float16_groups', None)
    if groups is None:
        return
    for tensor in _leaves(groups):
        offload_tensor_to_cpu(tensor)


@torch.no_grad()
def onload_megatron_copy_params(optimizers):
    """Restore an optimizer's ``shard_fp32_from_float16_groups`` tensors to the device.

    The attribute may be a single tensor, a list of tensors, or a list whose
    entries are themselves lists of tensors; nesting is followed at most two
    levels deep. Nothing happens when the attribute is missing or ``None``.

    Args:
        optimizers: The optimizer object to inspect.
    """

    def _leaves(group):
        # Yield the tensors in the same order a nested two-level walk would visit them.
        if not isinstance(group, list):
            yield group
            return
        for entry in group:
            if isinstance(entry, list):
                yield from entry
            else:
                yield entry

    groups = getattr(optimizers, 'shard_fp32_from_float16_groups', None)
    if groups is None:
        return
    for tensor in _leaves(groups):
        onload_tensor_to_gpu(tensor)


@torch.no_grad()
def offload_megatron_optimizer(optimizers):
    """Offload optimizer-held tensors to CPU and free their device storage.

    A ``ChainedOptimizer`` is expanded into its ``chained_optimizers``; any
    other object is treated as a single optimizer. For each one, the
    ``shard_fp32_from_float16_groups`` tensors are offloaded, then the
    ``exp_avg`` and ``exp_avg_sq`` entries of every state in
    ``optimizer.optimizer.state``.

    Args:
        optimizers: An optimizer, a ``ChainedOptimizer``, or ``None``. With
            ``None`` the function returns right after the initial cleanup and
            log line.

    Raises:
        KeyError: If a state entry lacks ``exp_avg`` or ``exp_avg_sq``.
    """
    clear_memory()
    print_memory_tracking(f"Memory tracking: actor before optimizer offload", verbose=True, rank=0)
    if optimizers is None:
        return

    if isinstance(optimizers, ChainedOptimizer):
        members = list(optimizers.chained_optimizers)
    else:
        members = [optimizers]

    for member in members:
        offload_megatron_copy_params(member)
        for state in member.optimizer.state.values():
            # Only these two state entries are offloaded.
            for key in ('exp_avg', 'exp_avg_sq'):
                offload_tensor_to_cpu(state[key])

    clear_memory()
    print_memory_tracking(f"Memory tracking: actor after optimizer offload", verbose=True, rank=0)


@torch.no_grad()
def onload_megatron_optimizer(optimizers):
    """Restore optimizer-held tensors that were previously offloaded.

    A ``ChainedOptimizer`` is expanded into its ``chained_optimizers``; any
    other object is treated as a single optimizer. For each one, the
    ``shard_fp32_from_float16_groups`` tensors are restored, then the
    ``exp_avg`` and ``exp_avg_sq`` entries of every state in
    ``optimizer.optimizer.state``.

    Args:
        optimizers: An optimizer, a ``ChainedOptimizer``, or ``None``. With
            ``None`` the function returns right after the initial cleanup and
            log line.

    Raises:
        KeyError: If a state entry lacks ``exp_avg`` or ``exp_avg_sq``.
    """
    clear_memory()
    print_memory_tracking(f"Memory tracking: actor before optimizer onload", verbose=True, rank=0)
    if optimizers is None:
        return

    if isinstance(optimizers, ChainedOptimizer):
        members = list(optimizers.chained_optimizers)
    else:
        members = [optimizers]

    for member in members:
        onload_megatron_copy_params(member)
        for state in member.optimizer.state.values():
            # Only these two state entries are restored.
            for key in ('exp_avg', 'exp_avg_sq'):
                onload_tensor_to_gpu(state[key])

    clear_memory()
    print_memory_tracking(f"Memory tracking: actor after optimizer onload", verbose=True, rank=0)
