import math
import struct
import numpy as np


class Identity():
    """Pass-through "compressor" that transmits the full vector unchanged.

    Attributes:
        w: variance parameter; 0 because the output is exactly the input.
        uplink_cost: ``d`` -- every one of the ``d`` coordinates is sent.
        name: human-readable label.
    """

    def __init__(self, d):
        no_variance = 0
        self.w = no_variance
        self.uplink_cost = d
        self.name = "Identity"

    def compress(self, x):
        """Return ``x`` itself (the very same object, no copy)."""
        uncompressed = x
        return uncompressed


class Rand_k():
    """Unbiased random sparsifier (Rand-k).

    Keeps ``k`` of the ``d`` coordinates, chosen uniformly without replacement,
    zeroes the rest, and rescales the survivors by ``d / k`` so that
    ``E[compress(x)] == x``.

    Attributes:
        w: variance parameter ``d / k - 1``.
        uplink_cost: ``k`` values plus ``ceil(k * log2(d))`` index bits,
            expressed in 32-bit words.
        name: human-readable label, e.g. ``"Rand-4"``.
    """

    def __init__(self, k, d):
        self.k = k
        self.d = d
        self.w = d / k - 1
        coordinate_bits = np.ceil(k * np.log2(d))
        self.uplink_cost = k + coordinate_bits/32
        self.name = f"Rand-{self.k}"

    def compress(self, x):
        """Return a ``d / k``-rescaled random ``k``-sparsification of ``x``.

        Every ``k`` (including the ``k == 1`` fast path) returns the same dtype
        as ``(d / k) * np.copy(x)``.
        """
        if self.k == 1:
            # Fast path: draw a single index instead of permuting all d indices.
            index = np.random.randint(self.d)
            # Keep x's dtype (not float32) so this path has the same precision
            # and output dtype as the general path below.
            compressed_x = np.zeros_like(x)
            compressed_x[index] = x[index]
            return (self.d / self.k) * compressed_x

        # The first d - k entries of a random permutation are the dropped coordinates.
        zero_indices = np.random.permutation(self.d)[:self.d - self.k]

        # Copy so the caller's array is never mutated.
        compressed_x = np.copy(x)
        compressed_x[zero_indices] = 0

        return (self.d / self.k) * compressed_x
    
class Scaled_Rand_k():
    """Random ``k``-sparsifier without the ``d / k`` rescaling.

    Keeps ``k`` of the ``d`` coordinates, chosen uniformly without replacement,
    and zeroes the rest. The kept values are returned as-is, which equals
    ``alpha * Rand_k(x)`` with ``alpha = k / d``.

    Attributes:
        alpha: scaling factor ``k / d``.
        uplink_cost: ``k`` values plus ``ceil(k * log2(d))`` index bits,
            expressed in 32-bit words.
        name: human-readable label, e.g. ``"Scaled Rand-4"``.
    """

    def __init__(self, k, d):
        self.k = k
        self.d = d
        self.alpha = k / d
        coordinate_bits = np.ceil(k * np.log2(d))
        self.uplink_cost = k + coordinate_bits/32
        self.name = f"Scaled Rand-{self.k}"

    def compress(self, x):
        """Return ``x`` with all but ``k`` random coordinates zeroed (no rescaling)."""
        if self.k == 1:
            # Fast path: draw a single index instead of permuting all d indices.
            index = np.random.randint(self.d)
            # Keep x's dtype, like the general path. Do NOT multiply by d here:
            # the scaled variant leaves the kept coordinate unscaled for every k.
            compressed_x = np.zeros_like(x)
            compressed_x[index] = x[index]
            return compressed_x

        # The first d - k entries of a random permutation are the dropped coordinates.
        zero_indices = np.random.permutation(self.d)[:self.d - self.k]

        # Copy so the caller's array is never mutated.
        compressed_x = np.copy(x)
        compressed_x[zero_indices] = 0

        return compressed_x

class Top_k():
    """Greedy sparsifier that keeps the ``k`` largest-magnitude entries.

    The kept entries are not rescaled.

    Attributes:
        w: ``1 - k / d`` -- presumably the contraction factor
            (``||C(x) - x||^2 <= w * ||x||^2``); confirm against callers.
        uplink_cost: ``k`` values plus ``ceil(k * log2(d))`` index bits,
            expressed in 32-bit words.
        name: human-readable label, e.g. ``"Top-4"``.
    """

    def __init__(self, k, d):
        self.k = k
        self.d = d
        self.w = 1 - k/d
        coordinate_bits = np.ceil(k * np.log2(d))
        self.uplink_cost = k + coordinate_bits / 32
        self.name = f"Top-{self.k}"

    def compress(self, x):
        """Return a float32 vector holding only the ``k`` largest-|x| entries.

        ``k <= 0`` yields an all-zero vector.
        NOTE(review): the output is always float32, even for float64 input --
        presumably to match the 32-bit words counted in ``uplink_cost``; confirm.
        """
        if self.k <= 0:
            # Slicing with [-0:] would select *every* index, so handle k = 0 here.
            return np.zeros_like(x, dtype=np.float32)

        if self.k == 1:
            # argmax returns the first index among ties.
            index = np.argmax(np.abs(x))
            compressed_x = np.zeros_like(x, dtype=np.float32)
            compressed_x[index] = x[index]
            return compressed_x

        # argpartition places the k largest magnitudes (unordered) at the end.
        topk_indices = np.argpartition(np.abs(x), -self.k)[-self.k:]

        compressed_x = np.zeros_like(x, dtype=np.float32)
        compressed_x[topk_indices] = x[topk_indices]

        return compressed_x

class Sign_1():
    """Unbiased single-coordinate sign compressor.

    Samples one index ``j`` with probability ``|x_j| / ||x||_1`` and returns a
    vector that is ``sign(x_j) * ||x||_1`` at ``j`` and zero elsewhere. Summing
    over ``j`` gives ``E[compress(x)] == x``.

    Attributes:
        w: variance parameter ``d - 1``.
        uplink_cost: one value plus ``ceil(log2(d))`` index bits, in 32-bit words.
        name: ``"Sign-1"``.
    """

    def __init__(self, d):
        self.d = d
        self.w = self.d - 1
        coordinate_bits = np.ceil(np.log2(d))
        self.uplink_cost = 1 + coordinate_bits/32
        self.name = "Sign-1"

    def compress(self, x):
        """Return the one-hot, L1-norm-scaled sign sample of ``x``.

        An all-zero ``x`` is returned as zeros. Otherwise the sampling
        distribution would be 0/0 (NaN) and ``np.random.choice`` would raise.
        """
        compressed_x = np.zeros_like(x)
        l1_norm = np.linalg.norm(x, ord=1)
        if l1_norm == 0:
            return compressed_x

        p = np.absolute(x) / l1_norm
        j = np.random.choice(np.arange(self.d), size=1, replace=False, p=p)
        compressed_x[j] = np.sign(x[j]) * l1_norm

        return compressed_x


class Natural():
    """Natural compression: randomly round each magnitude to a power of two.

    A nonzero entry ``t`` with ``2**e <= |t| < 2**(e + 1)`` becomes
    ``sign(t) * 2**e`` or ``sign(t) * 2**(e + 1)``. The upper value is chosen
    with probability ``|t| / 2**e - 1``, so the compressor is unbiased.
    Entries that are already powers of two (and zeros) are returned unchanged.

    Attributes:
        w: variance parameter ``1 / 8``.
        uplink_cost: 9 bits per coordinate, expressed in 32-bit words
            (presumably 1 sign bit + 8 exponent bits -- confirm).
        name: ``'Natural'``.
    """

    def __init__(self, d):
        self.d = d
        self.name = 'Natural'
        self.w = 1 / 8

        self.uplink_cost = d * 9 / 32

    def compress(self, x):
        """Apply natural compression element-wise to a copy of ``x``."""
        # Work on a copy so the caller's array is never touched.
        compressed_x = np.copy(x)
        if compressed_x.size == 0:
            # np.vectorize cannot infer an output dtype from an empty input.
            return compressed_x
        vectorized_natural_compression = np.vectorize(self.natural_compression)
        return vectorized_natural_compression(compressed_x)

    def natural_compression(self, t):
        """Randomly round a single scalar ``t`` to a neighbouring power of two."""
        if t == 0.0:
            return 0.0

        y = np.abs(t)
        # frexp splits y exactly as mantissa * 2**exponent with mantissa in
        # [0.5, 1), so y / (2 * mantissa) is exactly the largest power of two
        # <= y. This avoids log2 rounding, which can overshoot just below a
        # power of two.
        mantissa, _ = np.frexp(y)
        denominator = y / (2 * mantissa)
        # The upper neighbour is 2 * denominator. Probability of rounding up is
        # y/denominator - 1, which is 0 for exact powers of two, so they map to
        # themselves (instead of being doubled).
        p_up = np.clip(y / denominator - 1, 0.0, 1.0)
        return np.sign(t) * denominator * (1 + np.random.binomial(1, p_up))


class Natural_Rand_k():
    """Rand-k sparsification followed by Natural compression of the survivors.

    ``Rand_k`` first selects (and rescales) ``k`` coordinates. ``Natural``
    compression is then applied only to the nonzero entries that remain.

    Attributes:
        w: ``(1 + w_natural) * (1 + w_rand_k) - 1``, the variance parameter of
            the composition.
        uplink_cost: Natural's per-coordinate cost times ``k`` plus the
            ``ceil(k * log2(d))`` index bits, in 32-bit words.
    """

    def __init__(self, d, k):
        self.k = k
        self.d = d
        self.natural = Natural(d=d)
        self.rand_k = Rand_k(d=d, k=k)
        self.name = f"{self.natural.name} + {self.rand_k.name}"
        natural_factor = self.natural.w + 1
        rand_k_factor = self.rand_k.w + 1
        self.w = natural_factor * rand_k_factor - 1
        per_coordinate_cost = self.natural.uplink_cost / d
        index_cost = np.ceil(k * np.log2(d)) / 32
        self.uplink_cost = per_coordinate_cost * k + index_cost

    def compress(self, x):
        """Sparsify ``x`` with Rand-k, then naturally compress the kept entries."""
        sparse = self.rand_k.compress(x)

        # Positions that survived sparsification (and were nonzero to begin with).
        support = np.nonzero(sparse)[0]
        if support.size == 0:
            return sparse

        # Only the surviving entries go through the (slower) element-wise Natural pass.
        sparse[support] = self.natural.compress(sparse[support])
        return sparse


class Scaled_Natural_Rand_k():
    """Natural + Rand-k composition multiplied by ``alpha = 1 / (1 + w)``.

    ``w`` is the variance parameter of the wrapped ``Natural_Rand_k``. The
    uplink cost is unchanged, since scaling happens after compression.
    """

    def __init__(self, d, k):
        self.natural_rand_k = Natural_Rand_k(d=d, k=k)
        self.alpha = 1 / (1 + self.natural_rand_k.w)
        self.name = 'Scaled_' + self.natural_rand_k.name
        self.uplink_cost = self.natural_rand_k.uplink_cost

    def compress(self, x):
        """Compress ``x`` with Natural + Rand-k and shrink the result by ``alpha``."""
        unscaled = self.natural_rand_k.compress(x)
        return unscaled * self.alpha
