try:
    from scipy import optimize
    from scipy.cluster.vq import whiten, kmeans2, vq
    scipy_available = True
except ModuleNotFoundError:
    scipy_available = False
    pass

try:
    from bayes_optim import BO, OrdinalSpace
    from bayes_optim.Surrogate import GaussianProcess, RandomForest
    bayes_opt_available = True
except ModuleNotFoundError:
    bayes_opt_available = False


import math

import numpy as np
import torch


class Adjuster:
    """Base class for strategies that choose per-layer compression parameters.

    Subclasses implement :meth:`fit_predict`. This base class tracks which layers
    must be left out of the search and provides :meth:`compute_error`, which
    measures how far compression would move each layer's accumulated gradient.
    """

    def __init__(self, compressor, alpha=1.0):
        # Object exposing get/set_compression_parameter and compress_ hooks.
        self.compressor = compressor
        # layer_id -> parameter supplied through add_excluded(param=...).
        self.static_mapping = {}
        # Layers the caller explicitly asked to exclude.
        self.excluded_provided = set()
        # Layers skipped by compute_error; grows when degenerate gradients are found.
        self.excluded = set()
        # Slack factor for the error budget (DPAdjuster uses static_sum * (1 + alpha)).
        self.alpha = alpha

    def fit_predict(self, states, values_set):
        raise NotImplementedError("fit_predict is not implemented")

    def add_excluded(self, layer, param=None):
        """Exclude ``layer`` from adjustment, optionally remembering a fixed parameter.

        Args:
            layer: layer identifier (matched against ``state["layer_id"]``).
            param: optional parameter to record in ``static_mapping``. Any value
                other than ``None`` is stored, including falsy ones such as ``0``.
        """
        self.excluded_provided.add(layer)
        if param is not None:
            self.static_mapping[layer] = param

    @staticmethod
    def _l2_norm(tensor):
        """Return the L2 norm of ``tensor`` as a Python float.

        Half-precision tensors are upcast to float32 first, so the norm does not
        overflow the narrow half-precision range.
        """
        if tensor.dtype in (torch.float16, torch.bfloat16):
            tensor = tensor.to(torch.float32)
        return torch.norm(tensor, p=2).item()

    def compute_error(self, states, compress_params):
        """Sum the compression error over all usable layers.

        For every state not in ``self.excluded``, the accumulated gradient
        ``state["acc_grad"]`` is cloned and compressed with ``compress_params[i]``.
        The L2 norm of the difference, optionally scaled by ``state["DP_weight"]``,
        is added to the total. The compressor's previous parameter for the state
        is always restored afterwards.

        Layers whose gradient norm is non-finite or below 1e-10 are added to
        ``self.excluded`` and skipped. Non-finite (inf/NaN) per-layer errors are
        also skipped, so they cannot poison the sum.

        Args:
            states: sequence of per-layer dicts, indexed in parallel with
                ``compress_params``.
            compress_params: one compression parameter per entry of ``states``.

        Returns:
            float: total (weighted) compression error.
        """
        total = 0.0
        for i, state in enumerate(states):
            if state["layer_id"] in self.excluded:
                continue
            buf = state["acc_grad"]
            grad_norm = self._l2_norm(buf)
            if not math.isfinite(grad_norm) or grad_norm < 1e-10:
                # Degenerate gradient: don't take this layer into account.
                self.excluded.add(state["layer_id"])
                continue
            buf_copy = buf.clone()
            old_param = self.compressor.get_compression_parameter(state)
            self.compressor.set_compression_parameter(state, compress_params[i])
            try:
                self.compressor.compress_(buf_copy, state)
            finally:
                # Leave the compressor's per-state configuration as we found it.
                self.compressor.set_compression_parameter(state, old_param)
            err = self._l2_norm(buf_copy - buf)
            if "DP_weight" in state:
                err *= state["DP_weight"]
            if math.isfinite(err):
                total += err
        return total

class DPAdjuster(Adjuster):
    """Pick one compression parameter per layer with a knapsack-style DP.

    The error obtained with the compressor's default parameter on every layer
    (``static_sum``) is relaxed to a budget ``static_sum * (1 + alpha)``. Each
    candidate parameter is tried on each usable layer to record its compression
    error and compressed size. The dynamic program then chooses, per layer, the
    parameter that minimises the summed compressed size while the summed
    (bucket-quantised) error stays within the budget.
    """

    # Resolution of the error-budget discretisation used by the DP table.
    DP_BUCKETS = 10000

    def __init__(self, compressor, alpha=0.5):
        super(DPAdjuster, self).__init__(compressor, alpha)

    @staticmethod
    def _l2_norm(tensor):
        """Return the L2 norm of ``tensor`` as a float, upcasting half precision."""
        if tensor.dtype in (torch.float16, torch.bfloat16):
            tensor = tensor.to(torch.float32)
        return torch.norm(tensor, p=2).item()

    def _default_params(self, states):
        """Assign the compressor's default parameter to every state and return them."""
        name = self.compressor.get_compression_parameter_name()
        result = []
        for state in states:
            default = self.compressor.get_default_param()
            state[name] = default
            result.append(default)
        return result

    def _measure_layers(self, states, values_set):
        """Measure error and compressed size of every candidate on every usable layer.

        Layers already in ``self.excluded`` are skipped. Layers whose gradient
        norm is non-finite or below 1e-10 are added to ``self.excluded`` and
        skipped. The compressor's parameter for each state is restored after the
        trials.

        Returns:
            (errors, sizes): one list per usable layer, in ``states`` order, each
            aligned with ``values_set``.
        """
        errors, sizes = [], []
        for state in states:
            if state["layer_id"] in self.excluded:
                continue
            buf = state["acc_grad"]
            grad_norm = self._l2_norm(buf)
            if not math.isfinite(grad_norm) or grad_norm < 1e-10:
                # Degenerate gradient: don't take this layer into account.
                self.excluded.add(state["layer_id"])
                continue
            old_param = self.compressor.get_compression_parameter(state)
            layer_errors, layer_sizes = [], []
            try:
                for compression_param in values_set:
                    buf_copy = buf.clone()
                    self.compressor.set_compression_parameter(state, compression_param)
                    self.compressor.compress_(buf_copy, state)
                    err = self._l2_norm(buf_copy - buf)
                    if "DP_weight" in state:
                        err *= state["DP_weight"]
                    layer_errors.append(err)
                    layer_sizes.append(self.compressor.get_compressed_size(buf, compression_param))
            finally:
                # Trials must not leave the last tried parameter in place.
                self.compressor.set_compression_parameter(state, old_param)
            errors.append(layer_errors)
            sizes.append(layer_sizes)
        return errors, sizes

    def _solve(self, errors, sizes, target_score):
        """Run the DP; return one chosen value index per layer, or None if infeasible."""
        n_buckets = self.DP_BUCKETS
        bucket_size = target_score / n_buckets
        infeasible = n_buckets + 1

        def to_bucket(err):
            # Quantise to whole buckets, at least one so every layer consumes budget.
            # Options that alone exceed the budget (or are inf/NaN) are unusable.
            if not math.isfinite(err) or err / bucket_size > n_buckets:
                return infeasible
            return max(int(np.ceil(err / bucket_size)), 1)

        bucketed = [[to_bucket(err) for err in layer] for layer in errors]
        n_layers = len(bucketed)
        # choice[l][e]: value index chosen for layer l when layers 0..l sum to error e.
        choice = np.full((n_layers, n_buckets + 1), -1)
        # prev[e]: minimal total size of the processed layers with bucketed error
        # exactly e; the empty prefix costs 0 at error 0.
        prev = np.full(n_buckets + 1, float("inf"))
        prev[0] = 0.0
        for layer_idx in range(n_layers):
            cur = np.full(n_buckets + 1, float("inf"))
            for val_idx, (err, size) in enumerate(zip(bucketed[layer_idx], sizes[layer_idx])):
                if err == infeasible:
                    continue
                candidate = prev[:-err] + size
                better = candidate < cur[err:]
                cur[err:][better] = candidate[better]
                choice[layer_idx][err:][better] = val_idx
            prev = cur

        best_error = int(np.argmin(prev))
        if not np.isfinite(prev[best_error]):
            return None
        picks = [0] * n_layers
        for layer_idx in reversed(range(n_layers)):
            val_idx = int(choice[layer_idx][best_error])
            picks[layer_idx] = val_idx
            best_error -= bucketed[layer_idx][val_idx]
        return picks

    def fit_predict(self, states, values_set):
        """Choose a compression parameter for every state.

        Args:
            states: sequence of per-layer dicts. This method reads ``"layer_id"``,
                ``"acc_grad"`` and optional ``"DP_weight"``, and writes the chosen
                parameter under ``compressor.get_compression_parameter_name()``.
            values_set: indexable sequence of candidate parameters.

        Returns:
            list: chosen parameter per state, in ``states`` order. Excluded layers
            get ``compressor.get_default_param()``. If no assignment fits the
            budget, or the budget is degenerate, every state gets the default.
        """
        self.excluded = self.excluded_provided.copy()
        static_sum = self.compute_error(
            states, [self.compressor.get_default_param() for _ in range(len(states))])
        target_score = static_sum * (1 + self.alpha)
        if len(values_set) == 0:
            print("Empty values_set; using default parameters")
            return self._default_params(states)

        errors, sizes = self._measure_layers(states, values_set)
        if errors and not (math.isfinite(target_score) and target_score > 0):
            print(f"Degenerate error budget {target_score}; using default parameters")
            return self._default_params(states)

        picks = self._solve(errors, sizes, target_score)
        if picks is None:
            print(f"No assignment fits the error budget {target_score}; using default parameters")
            return self._default_params(states)

        # NOTE(review): layers registered via add_excluded(param=...) still receive the
        # default parameter; self.static_mapping is not consulted here — confirm intent.
        name = self.compressor.get_compression_parameter_name()
        remaining = iter(picks)
        result = []
        for state in states:
            if state["layer_id"] in self.excluded:
                param = self.compressor.get_default_param()
            else:
                param = values_set[next(remaining)]
            state[name] = param
            result.append(param)
        return result

    # def compute_compression_metric(self, states, compress_params):
    #     result = 0
    #     for i, (p, state) in enumerate(states.items()):
    #         result += p.numel() * compress_params[i]
    #     return result
    #
    # # Here the adjuster assumes that compressor computes metric for provided compression parameter
    # def get_cumulative_objective(self, states, compress_params, static_sum):
    #     dynamic_sum = self.compute_error(states, compress_params)
    #     err_diff = dynamic_sum - static_sum
    #     if err_diff / static_sum > self.alpha:
    #         # print("Didn't pass: ", compress_params)
    #         return err_diff / static_sum * self.compute_compression_metric(states, [2 * self.compressor.get_default_param()]*len(states))
    #     # print("Passed: ", compress_params)
    #     return self.compute_compression_metric(states, compress_params)
