"""
Integrates EG coding with a DDP communication hook.
"""

import time

import torch
import torch.distributed as dist

from eg_coding import encode, decode

# Fixed-point scale factor for gradient quantization: gradients are multiplied
# by this value before integer conversion and divided by it afterwards.
# Within this file it is referenced only by the disabled (commented-out)
# quantization code in ddp_eg_coding.
QUANT_FAC = 2**48


class EGHookState:
    """
    Mutable statistics container handed to the DDP communication hooks.

    Every counter starts at zero; the hooks increment them in place.
    """

    def __init__(self) -> None:
        # How many times a hook has been invoked.
        self.calls: int = 0
        # Running total of parameters (tensor elements) transferred.
        self.params: int = 0
        # Running total of bytes transferred.
        self.bytes: int = 0

        # Per-stage profiling slots. Only the first seven indices have an
        # assigned meaning; the remaining slots are spare.
        #   0: to_cpu
        #   1: quantize
        #   2: encode
        #   3: all_gather
        #   4: decode
        #   5: dequantize
        #   6: to_device
        # NOTE(review): presumably accumulated time per stage -- confirm
        # against the code that writes to this list.
        self.profiling: list[int] = [0 for _ in range(100)]


def ddp_eg_coding(state: None | EGHookState, bucket):
    """
    DDP communication hook: all-gather every rank's gradient bucket and sum.

    The EG-coded path (C++ ``encode``/``decode``) is currently disabled and
    is kept below as commented-out reference code. The active path exchanges
    the raw, unquantized bucket tensor between ranks with ``all_gather`` and
    accumulates the gathered tensors locally.

    state: Optional, EGHookState instance for recording stats. When given,
        ``calls``, ``params`` and ``bytes`` are updated once per invocation
        with this rank's bucket size.
    bucket: DDP gradient bucket.

    Returns an already-completed ``torch.futures.Future`` whose value is the
    element-wise sum of the bucket over all ranks.

    NOTE(review): the result is a sum, not a mean (no division by world
    size) -- confirm this is what the training setup expects.
    """
    # --- Disabled EG-coded reference implementation -------------------------
    # # Encode tensor with EG.
    # grad = bucket.buffer()
    # orig_device = grad.device
    # grad = (grad * QUANT_FAC).to(torch.int32)
    # positive = grad > 0
    # grad = torch.abs(grad) * 2
    # grad[positive] -= 1

    # words, bits = encode(grad)
    # grad_eg = (words.view(torch.uint8), bits, grad.numel())

    # if state is not None:
    #     # Update stats.
    #     state.calls += 1
    #     state.params += grad.numel()
    #     state.bytes += grad_eg[0].numel()

    # # All gather.
    # world_size = dist.get_world_size()
    # gather_list = [None for _ in range(world_size)]
    # dist.all_gather_object(gather_list, grad_eg)

    # # Decode result.
    # grad_list = []
    # for data in gather_list:
    #     decoded = decode(data[0].view(torch.uint64), data[1], data[2])
    #     positive = decoded % 2 == 1
    #     decoded[positive] += 1
    #     decoded //= 2
    #     decoded[~positive] *= -1
    #     grad_list.append(decoded)
    # grad = torch.stack(grad_list).sum(dim=0).float() / QUANT_FAC
    # grad = grad.to(orig_device)

    # fut = torch.futures.Future()
    # fut.set_result(grad)
    # return fut
    # -------------------------------------------------------------------------

    grad = bucket.buffer()

    if state is not None:
        # Record stats before communicating, mirroring _noop: the payload
        # this rank contributes is the raw (uncompressed) bucket.
        num_params = grad.numel()
        state.calls += 1
        state.params += num_params
        state.bytes += num_params * grad.element_size()

    # Quantization (disabled: the gradient is sent as-is).
    grad_quantized = grad  # (grad * QUANT_FAC).long() / QUANT_FAC

    # Tensor Communication: one receive buffer per rank.
    world_size = dist.get_world_size()
    gather_list = [torch.zeros_like(grad_quantized) for _ in range(world_size)]
    dist.all_gather(gather_list, grad_quantized)

    # Dequantization (disabled) + reduction. Accumulate into a fresh tensor
    # created from the bucket buffer, so it shares the buffer's device and
    # dtype and needs no device transfer afterwards.
    grad_sum = torch.zeros_like(grad)
    for quant_grad in gather_list:
        grad_sum += quant_grad  # .float() / QUANT_FAC

    # The work above is synchronous, so hand DDP a completed future.
    fut = torch.futures.Future()
    fut.set_result(grad_sum)
    return fut


def _noop(state: EGHookState, bucket):
    """
    Baseline DDP hook with no compression; used to record stats.

    state: EGHookState instance whose ``calls``, ``params`` and ``bytes``
        counters are updated. ``None`` is accepted and skips the bookkeeping.
    bucket: DDP gradient bucket.

    The bucket buffer is all-reduced in place and returned wrapped in an
    already-completed future.

    NOTE(review): ``all_reduce`` is called with its default reduction and the
    result is not divided by world size -- confirm this matches the intended
    baseline.
    """
    buf = bucket.buffer()

    if state is not None:
        # Stats are updated before the collective runs.
        n_elems = buf.numel()
        state.calls += 1
        state.params += n_elems
        state.bytes += n_elems * buf.element_size()

    # In-place reduction across all ranks.
    dist.all_reduce(buf)

    done = torch.futures.Future()
    done.set_result(buf)
    return done
