import math
import pickle

import numpy as np
import torch
from torch import distributed as dist
from torch.utils.data.sampler import Sampler


def get_rank():
	"""Return the rank of the current process.

	Falls back to 0 when the torch.distributed package is not available in
	this build or when no process group has been initialized yet.
	"""
	if dist.is_available() and dist.is_initialized():
		return dist.get_rank()
	return 0


def synchronize():
	"""Barrier across all processes; a no-op outside multi-process runs.

	Nothing happens when torch.distributed is unavailable, not initialized,
	or the process group contains a single process.
	"""
	group_ready = dist.is_available() and dist.is_initialized()
	if group_ready and dist.get_world_size() != 1:
		dist.barrier()


def get_world_size():
	"""Return the number of processes in the default process group.

	Reports 1 when torch.distributed is unavailable or uninitialized, so
	callers can treat the non-distributed case as a world of one.
	"""
	usable = dist.is_available() and dist.is_initialized()
	return dist.get_world_size() if usable else 1


def reduce_sum(tensor):
	"""Return the element-wise sum of `tensor` over all processes.

	Without an initialized process group the input object itself is returned.
	Otherwise a clone is all-reduced, so the caller's tensor is not modified.
	"""
	in_group = dist.is_available() and dist.is_initialized()
	if not in_group:
		return tensor

	summed = tensor.clone()
	dist.all_reduce(summed, op=dist.ReduceOp.SUM)
	return summed


def gather_grad(params):
	"""Average the gradients of `params` across all processes, in place.

	Every parameter whose `.grad` is populated is all-reduced (summed) and then
	divided by the world size; parameters with `grad is None` are skipped.
	Returns None and does nothing when running as a single process.
	"""
	world_size = get_world_size()
	if world_size == 1:
		return

	# Lazily filter so the `grad is not None` check happens right before use,
	# exactly as in a plain loop.
	with_grad = (p for p in params if p.grad is not None)
	for param in with_grad:
		grad = param.grad.data
		dist.all_reduce(grad, op=dist.ReduceOp.SUM)
		grad.div_(world_size)


def all_gather(data):
	"""Gather an arbitrary picklable object from every process.

	The object is pickled to bytes, the byte lengths are exchanged, every
	payload is padded to the longest one, and the padded byte tensors are
	all-gathered and unpickled.

	Args:
		data: any picklable object.

	Returns:
		A list with one unpickled object per rank, in rank order. In a
		single-process run this is simply `[data]`.
	"""
	world_size = get_world_size()

	if world_size == 1:
		return [data]

	# NCCL only operates on CUDA tensors; other backends (e.g. gloo) do not
	# support all_gather on CUDA tensors, so stage the bytes on the CPU there.
	device = torch.device('cuda' if dist.get_backend() == 'nccl' else 'cpu')

	buffer = pickle.dumps(data)
	# bytearray gives numpy a writable buffer, avoiding the deprecated
	# TypedStorage.from_buffer path and torch.from_numpy's read-only warning.
	payload = torch.from_numpy(np.frombuffer(bytearray(buffer), dtype=np.uint8)).to(device)

	# Exchange payload lengths as plain ints (int64 so large payloads fit).
	local_size = payload.numel()
	size_tensor = torch.tensor([local_size], dtype=torch.int64, device=device)
	size_list = [torch.zeros_like(size_tensor) for _ in range(world_size)]
	dist.all_gather(size_list, size_tensor)
	size_list = [int(size.item()) for size in size_list]
	max_size = max(size_list)

	# all_gather requires equally sized tensors on every rank.
	tensor_list = [
		torch.empty((max_size,), dtype=torch.uint8, device=device)
		for _ in size_list
	]

	if local_size != max_size:
		padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=device)
		payload = torch.cat((payload, padding), 0)

	dist.all_gather(tensor_list, payload)

	data_list = []

	for size, gathered in zip(size_list, tensor_list):
		# Strip the padding before unpickling.
		# NOTE(review): pickle.loads trusts peer ranks; only use inside a
		# trusted cluster.
		chunk = gathered.cpu().numpy().tobytes()[:size]
		data_list.append(pickle.loads(chunk))

	return data_list


def reduce_loss_dict(loss_dict):
	"""Average a dict of loss tensors across processes onto rank 0.

	The values are stacked in sorted-key order (so every rank uses the same
	order), summed onto rank 0 with `dist.reduce`, and divided by the world
	size on rank 0 only. On other ranks the returned values are whatever
	`dist.reduce` leaves in the buffer and should not be treated as averages.

	Args:
		loss_dict: mapping of name -> loss tensor. The tensors must be
			stackable (presumably scalars on the same device; TODO confirm
			against callers). Every rank is expected to pass the same keys.

	Returns:
		The input unchanged when it is empty or in a single-process run;
		otherwise a new dict with the same keys holding the reduced values.
	"""
	# Nothing to reduce; torch.stack would raise on an empty list.
	if not loss_dict:
		return loss_dict

	world_size = get_world_size()

	if world_size < 2:
		return loss_dict

	with torch.no_grad():
		keys = sorted(loss_dict.keys())
		# Stacking creates a new tensor, so the caller's losses are untouched.
		losses = torch.stack([loss_dict[k] for k in keys], 0)
		dist.reduce(losses, dst=0)

		# Only the destination rank holds the full sum.
		if dist.get_rank() == 0:
			losses /= world_size

		reduced_losses = dict(zip(keys, losses))

	return reduced_losses
