import torch
import numpy as np

try:
    import horovod.torch as hvd

    hvd_available = True
except ImportError:
    hvd_available = False

try:
    from apex import amp

    amp_available = True
except ImportError:
    amp_available = False
# Key under which the fallback compressor (used for every layer that has no
# explicit entry) is stored in `CompressorManager.compressors_per_layer`.
OTHERS = 'grad_compression_sim.others'


class CompressorManager:
    """Dispatch compression calls to per-layer compressors.

    The manager wraps either a single compressor (stored under `OTHERS`) or a
    dictionary mapping layer names to compressors, so that different layers of
    a model can be compressed differently.  Layers whose name contains any of
    the strings in `excluded_layer_names` are passed through uncompressed.
    """

    def __init__(self, compressors, named_parameters=None, excluded_layer_names=["layer_norm", "bias", "bn"]):
        """Build the layer-name -> compressor table and prune compressor states.

        Args:
            compressors: Either a single `Compressor`, or a dict whose keys are
                a layer name (str), a tuple of layer names, or `OTHERS`.  Keys
                of any other type are ignored.  A dict must contain `OTHERS`.
            named_parameters: Iterable of `(name, parameter)` pairs; required
                when `compressors` is a dict.
            excluded_layer_names: List of substrings; a layer whose name
                contains one of them is never compressed.  The list is copied.
        """
        self.compressors_per_layer = {}
        if isinstance(compressors, dict):
            assert named_parameters is not None, "named_parameters must not be None for multiple compressors setup"
            for key, value in compressors.items():
                if isinstance(key, str):
                    self.compressors_per_layer[key] = value
                elif isinstance(key, tuple):
                    for name in key:
                        self.compressors_per_layer[name] = value
            assert OTHERS in self.compressors_per_layer
        elif isinstance(compressors, Compressor):
            self.compressors_per_layer[OTHERS] = compressors
        self.compressors = set(self.compressors_per_layer.values())
        self.is_adaptive = any(compressor.is_adaptive for compressor in self.compressors)
        if named_parameters:
            named_parameters = list(named_parameters)
            self.parameters_names = {p: name for name, p in named_parameters if p.requires_grad}
            assert isinstance(excluded_layer_names, list)
            # Copy: add_excluded_layer_names() mutates this list in place, and
            # without the copy it would mutate the shared default argument
            # (or the caller's list) and leak into other managers.
            self.excluded_layer_names = list(excluded_layer_names)
            other_layers = []
            if OTHERS in self.compressors_per_layer:
                other_layers = [
                    name
                    for name, _ in named_parameters
                    if name not in self.compressors_per_layer and not self._is_excluded(name)
                ]
            # remove unnecessary states from compressors
            for compressor in self.compressors:
                layers = []
                for name, assigned in self.compressors_per_layer.items():
                    if assigned == compressor:
                        if name == OTHERS:
                            # extend (not rebind) so that explicitly named layers
                            # collected earlier are kept and `other_layers` itself
                            # is never mutated, whatever the dict order is.
                            layers.extend(other_layers)
                        elif not self._is_excluded(name):
                            layers.append(name)
                compressor.clean_states(layers)
        else:
            self.parameters_names = {}
            self.excluded_layer_names = []

    def _is_excluded(self, layer_name):
        """Return True if `layer_name` contains any excluded substring."""
        return any(excl_name in layer_name for excl_name in self.excluded_layer_names)

    def _compressor_for(self, layer_name):
        """Return the compressor registered for `layer_name`, else the `OTHERS` one."""
        if layer_name in self.compressors_per_layer:
            return self.compressors_per_layer[layer_name]
        return self.compressors_per_layer[OTHERS]

    def add_excluded_layer_names(self, layer_names):
        """Add one substring, or a list of substrings, to the exclusion list."""
        if isinstance(layer_names, list):
            self.excluded_layer_names.extend(layer_names)
        else:
            self.excluded_layer_names.append(layer_names)

    def compress_param(self, p):
        """Compress parameter `p` with the compressor responsible for its layer.

        Returns `(p.grad, None)` untouched for excluded layers, otherwise
        whatever the selected compressor's `compress_param` returns.
        """
        if self.parameters_names:
            layer_name = self.parameters_names[p]
            if self._is_excluded(layer_name):
                return p.grad, None
            return self._compressor_for(layer_name).compress_param(p)
        return self.compressors_per_layer[OTHERS].compress_param(p)

    def decompress_param(self, p, ctx):
        """Decompress parameter `p`; excluded layers return `p.grad` unchanged."""
        if self.parameters_names:
            layer_name = self.parameters_names[p]
            if self._is_excluded(layer_name):
                return p.grad
            return self._compressor_for(layer_name).decompress_param(p, ctx)
        return self.compressors_per_layer[OTHERS].decompress_param(p, ctx)

    def compress(self, tensor, state):
        """Compress `tensor` using the compressor selected by `state["name"]`.

        Without a `"name"` entry in `state` the `OTHERS` compressor is used.
        Excluded layers are returned as `(tensor, None)`, the same shape as the
        `(grad, None)` result of `Compressor.compress`.
        """
        if "name" not in state:
            return self.compressors_per_layer[OTHERS].compress(tensor, state)
        layer_name = state["name"]
        if self._is_excluded(layer_name):
            return tensor, None
        return self._compressor_for(layer_name).compress(tensor, state)

    def update_metric_stats(self, parameters):
        """Forward each parameter to `update_metric_stats` of its compressor."""
        for p in parameters:
            if self.parameters_names and self.parameters_names[p] in self.compressors_per_layer:
                self.compressors_per_layer[self.parameters_names[p]].update_metric_stats([p])
            else:
                self.compressors_per_layer[OTHERS].update_metric_stats([p])

    def reset_metrics(self):
        """Call `reset_metrics()` on every managed compressor."""
        for compressor in self.compressors:
            compressor.reset_metrics()

    def adjust_params(self):
        """Call `adjust_params()` on every managed compressor.

        NOTE(review): `Compressor.adjust_params` in this file takes a required
        `states` argument, so this call only works for compressors that
        override it with a no-argument signature -- confirm against callers.
        """
        for compressor in self.compressors:
            compressor.adjust_params()

    def __getattr__(self, name):
        """Proxy unknown attributes to all managed compressors.

        Returns a callable.  When invoked it looks `name` up on every
        compressor, calls it with the given arguments if it is callable, and
        returns a dict `{compressor class name: value}` of the non-None
        results (or None when nothing was collected).  Compressors sharing a
        class name overwrite each other's entry.

        Raises:
            AttributeError: for dunder names and for `compressors` itself.
                Protocol probes such as `copy.deepcopy` looking up
                `__deepcopy__` must see a missing attribute instead of a
                proxy that silently returns None.
        """
        if name == "compressors" or (name.startswith("__") and name.endswith("__")):
            raise AttributeError(name)

        def method(*args, **kw):
            merged_result = {}
            for compressor in self.compressors:
                attr = getattr(compressor, name, None)
                if attr is None:
                    continue
                elif callable(attr):
                    result = attr(*args, **kw)
                    if result is not None:
                        merged_result[type(compressor).__name__] = result
                else:
                    merged_result[type(compressor).__name__] = attr
            if merged_result:
                return merged_result

        return method


class Compressor:
    """Base class for in-place gradient compressors.

    Subclasses implement `compress_` (the actual in-place compression) and the
    parameter accessors below.  `compress` wraps `compress_` with an optional
    warm-up period and optional error correction, keeping per-tensor
    bookkeeping in a caller-supplied `state` dict (keys used here: `"step"`,
    `"error_correction"`, `"acc_grad"`).
    """

    def __init__(self, enable_error_correction=False, warmup_steps=None):
        """
        Args:
            enable_error_correction: Turns on both accumulating the
                compression error and applying it to the next gradient.
            warmup_steps: While `state["step"]` is below this value `compress`
                leaves the gradient untouched.  None or 0 disables warm-up.
        """
        self.save_error_correction = enable_error_correction
        self.apply_error_correction = enable_error_correction
        self.is_adaptive = False
        self.warmup_steps = warmup_steps
        self.adjuster = None
        self.values_set = None

    def add_adjuster(self, adjuster, values_set=None):
        """Attach (or, with `adjuster=None`, detach) a parameter adjuster.

        Args:
            adjuster: Adjuster object; the compressor is adaptive iff it is
                not None.
            values_set: Candidate compression parameters.  When given, it must
                contain values both above and below `get_default_param()`.
                When None, no validation is performed.

        Raises:
            RuntimeError: if `values_set` fails the validation above.  The
                compressor is left unmodified in that case.
        """
        # Validate before touching any attribute so a rejected call does not
        # leave the compressor half-configured.
        if values_set is not None:
            default_param = self.get_default_param()
            candidates = np.array(values_set)
            makes_sense = np.sum(candidates > default_param) > 0 and np.sum(candidates < default_param) > 0
            if not makes_sense:
                raise RuntimeError(
                    "Adjsuter: values set: {} has to contain values higher and lower than default param {}".format(
                        candidates, default_param))
        self.is_adaptive = adjuster is not None
        self.adjuster = adjuster
        self.values_set = values_set

    def get_all_metrics(self, states):
        """Collect `get_metric` for every state.

        Returns:
            dict with `"errs"` (one entry per state, -1.0 where no usable
            metric was available) and the accumulated `"static"` /
            `"dynamic"` compression errors filled in by `get_metric`.
        """
        d = {"static": 0.0, "dynamic": 0.0, "errs": []}
        for state in states:
            val = self.get_metric(state, d)
            d["errs"].append(val if val else -1.0)
        return d

    def get_metric(self, state, d=None):
        """Return the L2 norm of `state["acc_grad"]`, or None if unusable.

        None is returned when `"acc_grad"` is missing or its norm is inf, NaN
        or below 1e-10.  For adaptive compressors, when a non-empty `d` is
        given, the compression error under the current parameter is added to
        `d["dynamic"]` and the error under the default parameter to
        `d["static"]`; the state's compression parameter is restored after.
        """
        if "acc_grad" not in state:
            return None
        buf = state["acc_grad"]
        value = torch.norm(buf, p=2).item()
        if value == float("inf") or value != value or value < 1e-10:
            # don't take this value into account
            return None
        if d and self.is_adaptive:
            # NOTE(review): these calls go through `compress`, which honours
            # warm-up and updates state["error_correction"] when error
            # correction is enabled; `get_compression_error` uses `compress_`
            # instead -- confirm the side effects here are intended.
            old_params = self.get_compression_parameter(state)
            buf_copy = buf.clone()
            self.compress(buf_copy, state)
            d["dynamic"] += torch.norm(buf_copy - buf, p=2).item()
            self.set_compression_parameter(state, self.get_default_param())
            buf_copy = buf.clone()
            self.compress(buf_copy, state)
            d["static"] += torch.norm(buf_copy - buf, p=2).item()
            self.set_compression_parameter(state, old_params)
        return value

    def get_compression_error(self, state, comp_param):
        """Return the L2 error of compressing `state["acc_grad"]` with `comp_param`.

        Returns None if `"acc_grad"` is missing and `float("inf")` if it
        contains an infinity.  The state's compression parameter is
        temporarily replaced by `comp_param` and restored afterwards.
        """
        if "acc_grad" not in state:
            return None
        buf = state["acc_grad"]
        if torch.isinf(buf).sum() > 0:
            return float("inf")
        old_comp_param = self.get_compression_parameter(state)
        self.set_compression_parameter(state, comp_param)
        val = (buf - self.compress_(buf.detach().clone(), state)).norm(p=2).item()
        self.set_compression_parameter(state, old_comp_param)
        return val

    def compress(self, grad, state):
        """Compress `grad` in place, with optional warm-up and error correction.

        Args:
            grad: Tensor modified in place.
            state: Per-tensor dict; `"step"` is initialised to 0 if absent
                (it is not advanced here) and `"error_correction"` is created
                on first use when error correction is enabled.

        Returns:
            `(grad, None)`.
        """
        step = state.setdefault("step", 0)
        if self.warmup_steps and step < self.warmup_steps:
            return grad, None
        grad_ = grad

        if self.save_error_correction:
            if "error_correction" not in state:
                state["error_correction"] = torch.zeros_like(grad_)
            e_c = state["error_correction"]
            # update error correction before subtraction
            e_c.add_(grad_)
            # add error correction
            if self.apply_error_correction:
                grad_.copy_(e_c)

        self.compress_(grad_, state)

        if self.save_error_correction:
            # what remains in e_c is the part of the signal lost to compression
            e_c.sub_(grad_)
        return grad, None

    def adjust_params(self, states):
        """Choose new compression parameters and store them in `states`.

        Does nothing (returns None) when `states` is empty or the compressor
        is not adaptive.  Otherwise rank 0 computes the parameters with
        `do_adjust` and broadcasts them to the other ranks via
        `torch.distributed.broadcast_object_list`; each parameter is written
        to its state under `get_compression_parameter_name()`.

        Returns:
            The list of chosen parameters, one per state.
        """
        if not states or not self.is_adaptive:
            return
        assert isinstance(states[0], dict)
        if torch.distributed.get_rank() == 0:
            best_params = self.do_adjust(states)
            torch.distributed.broadcast_object_list([best_params], src=0)
        else:
            broadcast_output = [None]
            torch.distributed.broadcast_object_list(broadcast_output, src=0)
            best_params = broadcast_output[0]
        for p, state in zip(best_params, states):
            state[self.get_compression_parameter_name()] = p
        return best_params

    # --- hooks for subclasses -------------------------------------------

    def compress_(self, grad, state):
        """Compress `grad` in place; must be implemented by subclasses."""
        raise NotImplementedError

    def get_default_param(self):
        """Return the default compression parameter."""
        raise NotImplementedError

    def set_default_param(self, param):
        """Set the default compression parameter."""
        raise NotImplementedError

    def set_states_by_compression_scheme(self, best_params):
        raise NotImplementedError

    def get_compression_parameter(self, state):
        """Return the compression parameter stored in `state`."""
        raise NotImplementedError

    def set_compression_parameter(self, state, parameter):
        """Store `parameter` as the compression parameter in `state`."""
        raise NotImplementedError

    def get_compression_scheme(self):
        """Return the compression scheme; the base implementation is empty."""
        return {}

    def get_compressed_size(self, t, b):
        raise NotImplementedError

    def do_adjust(self, states):
        """Return one new compression parameter per entry of `states`."""
        raise NotImplementedError

    def get_compression_parameter_name(self):
        """Return the `state` key under which the compression parameter lives."""
        raise NotImplementedError


class NoneCompressor(Compressor):
    """Compressor whose compression step leaves the gradient unchanged."""

    def __init__(self):
        # Base-class defaults: no error correction, no warm-up.
        super().__init__()

    def compress_(self, grad, state):
        """Return `grad` as-is; `state` is not used."""
        return grad


class NoiseCompressor(Compressor):
    """Replace a compressor's error by random noise of the same L2 norm.

    Instead of applying `internal_compressor` to the gradient, this class
    measures the L2 error that compressor would introduce and adds Gaussian
    noise rescaled to exactly that norm.
    """

    def __init__(self, internal_compressor, named_parameters=None):
        """
        Args:
            internal_compressor: Compressor whose error magnitude is imitated;
                its `apply_error_correction` and `warmup_steps` settings are
                inherited.
            named_parameters: Unused; kept so existing callers that pass it
                keep working.
        """
        # Compressor.__init__ takes only (enable_error_correction, warmup_steps);
        # forwarding named_parameters as a third positional argument raised
        # TypeError on every construction.
        super().__init__(
            enable_error_correction=internal_compressor.apply_error_correction,
            warmup_steps=internal_compressor.warmup_steps,
        )
        self.internal_compressor = internal_compressor

    def compress_(self, grad, state):
        """Add noise to `grad` in place, matching the internal compressor's error.

        The internal compressor runs on a clone, so `grad` itself is only
        changed by the added noise.  Returns `grad`.
        """
        grad_copy = grad.clone()
        self.internal_compressor.compress_(grad_copy, state)
        error_l2 = torch.norm(grad_copy - grad, p=2)
        noise = torch.randn_like(grad)
        # normalise the noise to unit L2 norm, then scale it to the error norm
        noise.div_(torch.norm(noise, p=2)).mul_(error_l2)
        grad.add_(noise)
        return grad
