import torch
from .compressor import CompressorManager, Compressor, NoneCompressor
from typing import Any, Tuple, Union, Iterator, List
import torch.distributed as dist
import json
import numpy as np
import os


class SimCompressionState(object):
    """State object passed to ``sim_compression_hook`` on every bucket.

    It stores the process group, the compressor, the hook's step counter and
    ``layers_states``: a list indexed by bucket whose entries are lists of
    per-layer state dictionaries.

    Args:
        process_group: Group to reduce over; ``dist.group.WORLD`` is used when
            ``None`` is given.
        compressor: Object used to compress gradients.
        params_per_layer: Optional nested sequence (bucket, then layer) of
            per-layer compression parameters.
        acc_grad_coef: Optional coefficient for gradient accumulation;
            ``compute_error_stats`` returns ``[]`` while it is ``None``.
        warmup_period: Requested warm-up length. The stored value is never
            below 2.
        adjust_freq: Optional adjustment period, stored as given.
        save_dir: Optional output directory, stored as given.
    """

    def __init__(self, process_group: dist.ProcessGroup,
                 compressor: Union[Compressor, CompressorManager] = NoneCompressor, params_per_layer: List[Any] = None,
                 acc_grad_coef: float = None, warmup_period: int = None, adjust_freq: int = None, save_dir=None):
        if process_group is None:
            process_group = dist.group.WORLD
        self.process_group = process_group

        # Counters and containers that start empty.
        self.step = 0
        self.total_size = 0
        self.layers_states = []
        self.DP_weights = None

        # Configuration, stored exactly as passed in.
        self.compressor = compressor
        self.params_per_layer = params_per_layer
        self.acc_grad_coef = acc_grad_coef
        self.adjust_freq = adjust_freq
        self.save_dir = save_dir

        # A falsy request (None or 0) yields the minimum; otherwise the larger
        # of the request and the minimum is kept.
        minimum_warmup = 2
        self.warmup_period = max(warmup_period, minimum_warmup) if warmup_period else minimum_warmup

    def compute_error_stats(self):
        """Measure how far each accumulated gradient is from its compressed copy.

        Returns:
            A list with one inner list per entry of ``params_per_layer``. Each
            inner list holds, per layer, ``torch.dist`` between the layer's
            ``"acc_grad"`` buffer and a clone of it that was passed through
            ``compressor.compress``; layers whose parameter lies outside
            ``[0, 8]`` get ``0.0``. An empty list is returned when
            ``params_per_layer`` is missing/empty or ``acc_grad_coef`` is ``None``.
        """
        no_params = self.params_per_layer is None or len(self.params_per_layer) == 0
        if no_params or self.acc_grad_coef is None:
            return []

        stats_per_bucket = []
        for bucket_pos, layer_params in enumerate(self.params_per_layer):
            bucket_stats = []
            stats_per_bucket.append(bucket_stats)
            for layer_pos, param in enumerate(layer_params):
                if param < 0 or param > 8:
                    bucket_stats.append(0.0)
                else:
                    layer_state = self.layers_states[bucket_pos][layer_pos]
                    accumulated = layer_state["acc_grad"]
                    # The compressor's return value is ignored, so it is
                    # expected to modify the clone in place.
                    compressed = accumulated.clone()
                    self.compressor.compress(compressed, layer_state)
                    bucket_stats.append(torch.dist(accumulated, compressed).item())
        return stats_per_bucket

    def set_DP_weights(self, DP_weights):
        """Store one weight per tracked layer under the ``"DP_weight"`` key.

        ``DP_weights`` must mirror the nesting of ``layers_states`` (same number
        of buckets, same number of layers per bucket); this is checked with
        assertions.
        """
        assert len(DP_weights) == len(self.layers_states)
        for bucket_states, bucket_weights in zip(self.layers_states, DP_weights):
            assert len(bucket_states) == len(bucket_weights)
            for layer_state, layer_weight in zip(bucket_states, bucket_weights):
                layer_state["DP_weight"] = layer_weight

def _allreduce_fut(
        process_group: dist.ProcessGroup, tensor: torch.Tensor
) -> torch.futures.Future[torch.Tensor]:
    """Average ``tensor`` over the group with an asynchronous allreduce.

    ``tensor`` is divided in place by the group size and then allreduced.
    ``dist.group.WORLD`` is used when ``process_group`` is ``None``.

    Returns:
        A future that resolves to the first element of the allreduce
        future's value.
    """
    if process_group is None:
        group = dist.group.WORLD
    else:
        group = process_group

    # Divide before reducing to avoid overflow, especially for FP16.
    world_size = group.size()
    tensor.div_(world_size)

    def _first_result(done):
        return done.value()[0]

    work = dist.all_reduce(tensor, group=group, async_op=True)
    return work.get_future().then(_first_result)


def _report_compression_sizes(state: SimCompressionState, best_params: Any) -> None:
    """Print a size report for the adjusted parameters and optionally save it as JSON.

    For every tracked layer the compressor is asked for the compressed size of
    the layer's ``"acc_grad"`` buffer, once with the layer's current compression
    parameter and once with the compressor's default parameter. The totals,
    the summed ``"layer_size"`` values and the two ratios are reported together
    with ``best_params``. When ``state.save_dir`` is set, the report is also
    written to a file in that directory (created if missing).
    """
    compressor = state.compressor
    lgreco_compressed_size = 0
    static_compressed_size = 0
    uncompressed_size = 0
    for bucket_states in state.layers_states:
        for l_state in bucket_states:
            lgreco_compressed_size += compressor.get_compressed_size(
                l_state["acc_grad"], l_state[compressor.get_compression_parameter_name()])
            static_compressed_size += compressor.get_compressed_size(
                l_state["acc_grad"], compressor.get_default_param())
            uncompressed_size += l_state["layer_size"]

    report = {"bits": best_params}
    report["lgreco_size"] = int(lgreco_compressed_size)
    report["uniform_size"] = int(static_compressed_size)
    report["uncompressed_size"] = int(uncompressed_size)
    report["lgreco_ratio"] = report["lgreco_size"] / report["uncompressed_size"]
    report["uniform_ratio"] = report["uniform_size"] / report["uncompressed_size"]
    print(report)

    if state.save_dir:
        result_dir = state.save_dir
        os.makedirs(result_dir, exist_ok=True)
        values = compressor.values_set
        file_name = (f"optimal_{np.min(values)}_{compressor.get_default_param()}"
                     f"_{np.max(values)}_{state.step // state.adjust_freq}.txt")
        with open(f"{result_dir}/{file_name}", 'w') as f:
            json.dump(report, f)


def sim_compression_hook(
        state: SimCompressionState, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    """Communication hook that compresses a bucket's large gradients and allreduces a prefix of its buffer.

    While ``state.step < state.warmup_period`` the whole bucket buffer is
    allreduced. Afterwards, every gradient with at least 1024 elements and at
    least two dimensions is processed as follows:

    * its per-layer state dictionary is created on first sight;
    * during the last 10 steps of each ``adjust_freq`` window the gradient is
      accumulated into ``"acc_grad"`` (scaled by ``acc_grad_coef``);
    * it is passed to ``state.compressor.compress`` together with its state;
    * its element count, multiplied by the layer's compression parameter, is
      added to the length of the buffer prefix that gets allreduced.

    Gradients that fail the size/dimension filter add nothing to that length.

    On the last bucket of a step the step counter is advanced. Before that, if
    the compressor has an adjuster and the step is a multiple of
    ``adjust_freq`` beyond the warm-up, ``adjust_params`` is called with all
    layer states and rank 0 prints (and optionally saves) a size report.

    Args:
        state: Hook state; its ``layers_states`` and ``step`` are updated.
        bucket: The gradient bucket to process.

    Returns:
        The future produced by ``_allreduce_fut`` for ``bucket.buffer()[:numel]``.
    """
    numel = bucket.buffer().numel()
    if state.step >= state.warmup_period:
        # Compression only starts after the warm-up because DDP rebuilds its buckets at
        # iteration 1, which can conflict with any tensor indexing done before the rebuild.
        bucket_index = bucket.index()
        numel = 0
        l_idx = 0

        for tensor in bucket.gradients():
            # Small or one-dimensional gradients are neither compressed nor counted.
            if tensor.numel() < 1024 or tensor.dim() < 2:
                continue
            # Pad with empty lists up to this bucket's index: an earlier bucket may have
            # had no gradient passing the filter above, in which case a single append
            # would leave ``layers_states`` too short and the lookup below would fail.
            while len(state.layers_states) <= bucket_index:
                state.layers_states.append([])
            bucket_states = state.layers_states[bucket_index]

            if len(bucket_states) <= l_idx:
                bucket_states.append({"step": 0, "layer_id": (bucket_index, l_idx), "layer_size": tensor.numel(), "DP_weight": 1.0})
                if state.acc_grad_coef is not None:
                    bucket_states[l_idx]["acc_grad"] = torch.zeros_like(tensor, dtype=torch.float32)
                if state.compressor:
                    bucket_states[l_idx][state.compressor.get_compression_parameter_name()] = state.compressor.get_default_param()

            layer_state = bucket_states[l_idx]

            layer_state["step"] = state.step
            # Accumulate only in the last 10 steps before each multiple of adjust_freq.
            if state.acc_grad_coef is not None and state.adjust_freq is not None and (state.step % state.adjust_freq) >= (state.adjust_freq - 10):
                layer_state["acc_grad"].add_(tensor.to(torch.float32), alpha=state.acc_grad_coef)
            l_numel = tensor.numel()
            if state.compressor:
                if state.params_per_layer:
                    assert len(state.params_per_layer) > bucket_index
                    if len(state.params_per_layer[bucket_index]) == 0:
                        # NOTE(review): this skips the l_idx increment too, so every
                        # remaining gradient of this bucket reuses the same layer
                        # state - confirm that this is intended.
                        continue
                    assert len(state.params_per_layer[bucket_index]) > l_idx
                    layer_state[state.compressor.get_compression_parameter_name()] = state.params_per_layer[bucket_index][l_idx]
                state.compressor.compress(tensor, layer_state)
                l_numel *= layer_state[state.compressor.get_compression_parameter_name()]
                l_numel = int(l_numel)
            l_idx += 1
            numel += l_numel
    if bucket.is_last():
        # We can not rely on is_the_last_bucket_to_allreduce as it returns True after the first bucket
        if state.compressor and state.compressor.adjuster and state.adjust_freq is not None \
                and state.step > state.warmup_period and state.step % state.adjust_freq == 0:
            # adjust_params runs on every rank; only rank 0 reports.
            best_params = state.compressor.adjust_params([layer_state for b in state.layers_states for layer_state in b])
            if torch.distributed.get_rank() == 0:
                _report_compression_sizes(state, best_params)

        state.step += 1
    return _allreduce_fut(state.process_group, bucket.buffer()[:numel])


class SimCompressionStateFused(SimCompressionState):
    """State for ``sim_compression_hook_fused``: all gradients are fused into one flat buffer.

    Args:
        process_group: Forwarded to ``SimCompressionState``.
        compressor: Forwarded to ``SimCompressionState``; when truthy, its default
            parameter is stored in ``compression_state`` under the compressor's
            parameter name.
        warmup_period: Number of warm-up steps. When ``None`` the value set by
            the base class is kept, so the hook can always compare it with ``step``.
        parameters: Iterable of model parameters. It is materialised into a list;
            the element counts of those with ``requires_grad`` define ``model_size``.

    Raises:
        TypeError: If ``parameters`` is ``None``.
    """

    def __init__(self, process_group: dist.ProcessGroup,
                 compressor: Union[Compressor, CompressorManager] = NoneCompressor,
                 warmup_period: int = None, parameters: Iterator[torch.nn.Parameter] = None):
        if parameters is None:
            # Same exception type ``list(None)`` would raise, with a usable message.
            raise TypeError("parameters must be an iterable of torch.nn.Parameter, got None")
        super().__init__(process_group, compressor)
        # Only override the base-class warm-up when a value was actually given;
        # storing None would make ``step < warmup_period`` fail in the hook.
        if warmup_period is not None:
            self.warmup_period = warmup_period
        self.parameters = list(parameters)
        self.model_size = sum(p.numel() for p in self.parameters if p.requires_grad)
        # Flat buffer of ``model_size`` elements; allocated by the hook on first use.
        self.memory = None
        self.compression_state = {"layer_id": (0, 0)}
        self.step = 0
        if compressor:
            self.compression_state[compressor.get_compression_parameter_name()] = compressor.get_default_param()

    def fuse_buffers(self):
        """Copy the gradients of all trainable parameters into ``self.memory``, back to back.

        ``self.memory`` must already be allocated and every trainable parameter
        must have a gradient.
        """
        idx = 0
        for p in self.parameters:
            if p.requires_grad:
                count = p.numel()
                # reshape (unlike view) also accepts non-contiguous gradients.
                self.memory[idx:idx + count].copy_(p.grad.reshape(-1))
                idx += count

    def copy_buffers_back(self, parameters=None):
        """Copy ``self.memory`` back into the gradients of the trainable parameters.

        Args:
            parameters: Parameters to write to, in the same order used by
                ``fuse_buffers``; defaults to ``self.parameters``.
        """
        # Presumably waits for outstanding work on the buffer before reading it.
        # Skipped when CUDA is unavailable, where synchronize() would raise.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        if parameters is None:
            parameters = self.parameters
        idx = 0
        for p in parameters:
            if p.requires_grad:
                count = p.numel()
                p.grad.data.copy_(self.memory[idx:idx + count].view_as(p.grad.data))
                idx += count


def _none_future(tensor: torch.Tensor) -> torch.futures.Future[torch.Tensor]:
    fut = torch.futures.Future()
    # zeros = torch.zeros_like(tensor)
    fut.set_result(tensor)
    return fut


def sim_compression_hook_fused(state: SimCompressionStateFused, bucket: dist.GradBucket) -> torch.futures.Future[
    torch.Tensor]:
    """Communication hook that compresses and allreduces all gradients as one fused buffer.

    While ``state.step < state.warmup_period`` every bucket is allreduced
    as-is. After the warm-up only the last bucket of a step communicates:
    ``state.fuse_buffers()`` gathers the gradients into ``state.memory``, the
    compressor (if any) is applied to that buffer, and the buffer is
    allreduced. All other buckets are returned unchanged, without communication.

    ``state.step`` is advanced once per step, on the last bucket.

    Args:
        state: Fused hook state; ``memory`` is allocated here on first use with
            the device and dtype of the first bucket buffer seen.
        bucket: The gradient bucket to process.

    Returns:
        A future for the bucket buffer (warm-up and non-last buckets) or for
        the fused buffer (last bucket after the warm-up).
    """
    input_tensor = bucket.buffer()
    state.compression_state["step"] = state.step

    if state.memory is None:
        state.memory = torch.empty(state.model_size, device=input_tensor.device, dtype=input_tensor.dtype)

    if state.step < state.warmup_period:
        if bucket.is_last():
            state.step += 1
        return _allreduce_fut(state.process_group, bucket.buffer())

    if not bucket.is_last():
        # The actual reduction happens once, when the last bucket arrives.
        return _none_future(input_tensor)

    # Last bucket: do the compression and the actual reduction.
    state.fuse_buffers()
    if state.compressor:
        state.compressor.compress(state.memory, state.compression_state)
    # The whole fused buffer is reduced; it is not shortened after compression.
    num_elems = state.memory.numel()
    fut = _allreduce_fut(state.process_group, state.memory[:num_elems])
    state.step += 1
    return fut
