import contextlib
import re
from typing import Iterator, Optional

import torch.distributed

from utils.logger.logger import Logger
from utils.utils import get_object_name, get_class_name
from .distributed_manager import DistributedManager
from ..utils import named_params_and_buffers


@contextlib.contextmanager
def ddp_sync(model: torch.nn.Module, sync: bool) -> Iterator[None]:
    """Optionally disable DDP gradient synchronization inside a ``with`` block.

    If ``model`` is a ``torch.nn.parallel.DistributedDataParallel`` wrapper and ``sync`` is
    False, the block runs under ``model.no_sync()`` (e.g. for gradient accumulation steps).
    In every other case (``sync`` is True, or ``model`` is not DDP-wrapped) the block runs
    unchanged.

    Args:
        model: module being trained, possibly wrapped in DistributedDataParallel.
        sync: True to keep normal gradient synchronization, False to suppress it for DDP models.

    Raises:
        AssertionError: if ``model`` is not a ``torch.nn.Module`` (raised on entering the block).
    """
    assert isinstance(model, torch.nn.Module), f'{model} is not an instance of torch.nn.Module'
    if sync or not isinstance(model, torch.nn.parallel.DistributedDataParallel):
        yield
    else:
        with model.no_sync():
            yield


def check_consistency(fullname: str, tensor: torch.Tensor, ignore_regex: Optional[str] = None) -> tuple[bool, float]:
    """Compare ``tensor`` on every rank against the main rank's copy.

    This is a collective operation (``broadcast`` + ``all_gather_object``): every rank must
    call it with the same ``fullname``/``ignore_regex`` and a tensor of the same shape and
    dtype, otherwise the collectives will not line up.

    Args:
        fullname: qualified tensor name, matched against ``ignore_regex``.
        tensor: local tensor to compare.
        ignore_regex: if given and it fully matches ``fullname``, the check is skipped and
            no collective is issued.

    Returns:
        ``(consistent, max_diff)``. ``consistent`` is True only if the tensor equals the main
        rank's copy on every rank. ``max_diff`` is the largest absolute element-wise
        difference seen on any rank.
    """
    if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
        return True, 0.0
    tensor = tensor.detach()
    if tensor.is_floating_point():
        # NaN never equals itself; map NaN/inf to finite values so identical tensors compare
        # equal and the difference below stays finite.
        tensor = torch.nan_to_num(tensor)
    reference = tensor.clone()
    torch.distributed.broadcast(tensor=reference, src=DistributedManager.main_rank())

    equal = bool((tensor == reference).all().item())
    if tensor.numel() == 0:
        # torch.max() raises on an empty tensor; there is nothing that can differ.
        local_diff = 0.0
    else:
        lhs, rhs = tensor, reference
        if not (tensor.is_floating_point() or tensor.is_complex()):
            # bool tensors do not support '-', and unsigned integer subtraction wraps around;
            # compute the difference in float64 instead (equality above stays exact).
            lhs, rhs = tensor.to(torch.float64), reference.to(torch.float64)
        local_diff = float(torch.max(torch.abs(lhs - rhs)).item())

    outputs: list[Optional[tuple[bool, float]]] = [None for _ in range(DistributedManager.world_size)]
    torch.distributed.all_gather_object(object_list=outputs, obj=(equal, local_diff))
    consistent: bool = all(output[0] for output in outputs)
    max_diff: float = max(output[1] for output in outputs)
    return consistent, max_diff


def check_ddp_consistency(model: torch.nn.Module, ignore_regex: Optional[str] = None) -> bool:
    """Check that every parameter and buffer of ``model`` is identical across ranks.

    Every rank must call this, because ``check_consistency`` issues collectives for each
    tensor. If distributed is not initialized, the check is skipped.

    Args:
        model: module whose parameters and buffers are compared.
        ignore_regex: tensors whose full name (``<model class name>.<tensor name>``) fully
            matches this regex are skipped.

    Returns:
        True if all checked tensors match the main rank (or the check was skipped), else
        False. Every mismatching tensor is logged as a warning, not just the first.
    """
    assert isinstance(model, torch.nn.Module), f'{model} is not an instance of torch.nn.Module'
    if not DistributedManager.initialized:
        Logger.debug(
            f'{get_class_name(check_ddp_consistency)} - distributed not initialized, skipping consistency check')
        return True
    Logger.debug(f'{get_class_name(check_ddp_consistency)} - checking ddp consistency: {get_object_name(model)}')
    prefix: str = type(model).__name__
    # Run the collective check for every tensor before acting on any result, so that all
    # ranks issue the same sequence of collectives.
    values: dict[str, tuple[bool, float]] = {}
    for name, tensor in named_params_and_buffers(model):
        fullname: str = prefix + '.' + name
        values[fullname] = check_consistency(fullname, tensor, ignore_regex)
    failures: list[tuple[str, float]] = [
        (fullname, max_diff) for fullname, (consistent, max_diff) in values.items() if not consistent
    ]
    for fullname, max_diff in failures:
        Logger.warning(f'ddp consistency check failed: {fullname}: {max_diff}')
    if failures:
        return False
    Logger.debug(f'{get_class_name(check_ddp_consistency)} - ddp consistency check passed')
    return True


def check_ddp_params_consistency(model: torch.nn.Module, ignore_regex: Optional[str] = None) -> bool:
    """Check that every parameter of ``model`` is identical across ranks.

    Every rank must call this, because ``check_consistency`` issues collectives for each
    parameter. If distributed is not initialized, the check is skipped.

    Args:
        model: module whose ``named_parameters()`` are compared.
        ignore_regex: parameters whose full name (``<model class name>.<param name>``) fully
            matches this regex are skipped.

    Returns:
        True if all checked parameters match the main rank (or the check was skipped), else
        False. Every mismatching parameter is logged as a warning, not just the first.
    """
    assert isinstance(model, torch.nn.Module), f'{model} is not an instance of torch.nn.Module'
    if not DistributedManager.initialized:
        Logger.debug(
            f'{get_class_name(check_ddp_params_consistency)} - '
            f'distributed not initialized, skipping ddp params consistency check'
        )
        return True
    Logger.debug(
        f'{get_class_name(check_ddp_params_consistency)} - checking ddp params consistency: {get_object_name(model)}')
    prefix: str = type(model).__name__
    # Run the collective check for every parameter before acting on any result, so that all
    # ranks issue the same sequence of collectives.
    values: dict[str, tuple[bool, float]] = {}
    for name, tensor in model.named_parameters():
        fullname: str = prefix + '.' + name
        values[fullname] = check_consistency(fullname, tensor, ignore_regex)
    failures: list[tuple[str, float]] = [
        (fullname, max_diff) for fullname, (consistent, max_diff) in values.items() if not consistent
    ]
    for fullname, max_diff in failures:
        Logger.warning(f'ddp params consistency check failed: {fullname}: {max_diff}')
    if failures:
        return False
    Logger.debug(f'{get_class_name(check_ddp_params_consistency)} - ddp params consistency check passed')
    return True


def check_ddp_buffers_consistency(model: torch.nn.Module, ignore_regex: Optional[str] = None) -> bool:
    """Check that every buffer of ``model`` is identical across ranks.

    Every rank must call this, because ``check_consistency`` issues collectives for each
    buffer. If distributed is not initialized, the check is skipped.

    Args:
        model: module whose ``named_buffers()`` are compared.
        ignore_regex: buffers whose full name (``<model class name>.<buffer name>``) fully
            matches this regex are skipped.

    Returns:
        True if all checked buffers match the main rank (or the check was skipped), else
        False. Every mismatching buffer is logged as a warning, not just the first.
    """
    assert isinstance(model, torch.nn.Module), f'{model} is not an instance of torch.nn.Module'
    if not DistributedManager.initialized:
        Logger.debug(
            f'{get_class_name(check_ddp_buffers_consistency)} - '
            f'distributed not initialized, skipping ddp buffers consistency check'
        )
        return True
    Logger.debug(
        f'{get_class_name(check_ddp_buffers_consistency)} - checking ddp buffers consistency: {get_object_name(model)}')
    prefix: str = type(model).__name__
    # Run the collective check for every buffer before acting on any result, so that all
    # ranks issue the same sequence of collectives.
    values: dict[str, tuple[bool, float]] = {}
    for name, tensor in model.named_buffers():
        fullname: str = prefix + '.' + name
        values[fullname] = check_consistency(fullname, tensor, ignore_regex)
    failures: list[tuple[str, float]] = [
        (fullname, max_diff) for fullname, (consistent, max_diff) in values.items() if not consistent
    ]
    for fullname, max_diff in failures:
        Logger.warning(f'ddp buffers consistency check failed: {fullname}: {max_diff}')
    if failures:
        return False
    Logger.debug(f'{get_class_name(check_ddp_buffers_consistency)} - ddp buffers consistency check passed')
    return True
