import numpy as np
from scipy.special import erfinv, binom
from scipy.linalg import hadamard
from sklearn.preprocessing import MinMaxScaler
from pathos.multiprocessing import ProcessPool
from tqdm import tqdm

FLOAT_BITS = 32


class BaseCompressor:
    """Common interface shared by every compressor in this module.

    A compressor is configured for ``m`` clients (machines), each holding a
    ``d``-dimensional vector. Subclasses implement :meth:`compress`, which
    encodes every client's vector and returns the server-side decoded
    estimate, and :meth:`num_bits`, which reports the communication cost.
    """

    def __init__(self, m: int, d: int):
        # m: number of clients; d: dimension of every client vector.
        self.m, self.d = m, d

    def compress(self, client_arr: np.ndarray):
        """Encode the client vectors and return the decoded estimate.

        Args:
            client_arr: array whose rows are the client vectors (subclasses
                treat it as shape ``(m, d)``).

        Returns:
            The decoded vector computed by the server.

        Raises:
            NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError

    def num_bits(self):
        """Return the bits sent by all machines per round, excluding floats.

        Raises:
            NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError

    def num_bits_float(self):
        """Return the bits sent by all machines per round, including floats.

        The default assumes no float side-information is transmitted and
        simply defers to :meth:`num_bits`.
        """
        return self.num_bits()

class SignCompressor(BaseCompressor):
    """Scaled-sign compressor.

    Each client sends ``sign(x_i)`` (one bit per coordinate) plus one float,
    ``||x_i||_1 / d``; the server averages the rescaled sign vectors.
    """

    def compress(self, client_arr: np.ndarray):
        """Return the client average of ``(||x_i||_1 / d) * sign(x_i)``.

        Args:
            client_arr: ``(m, d)`` array whose rows are the client vectors.

        Returns:
            Length-``d`` decoded vector.
        """
        client_arr = np.asarray(client_arr, dtype=float)
        scales = np.linalg.norm(client_arr, ord=1, axis=1, keepdims=True) / self.d
        # Average over clients only (axis 0); a bare ``.mean()`` would collapse
        # the d-dimensional estimate to a single scalar.
        return (scales * np.sign(client_arr)).mean(axis=0)

    def num_bits(self):
        # One sign bit per coordinate per machine.
        return self.m * self.d

    def num_bits_float(self):
        # Sign bits plus one float scale per machine.
        return self.m * (self.d + FLOAT_BITS)
    
class NoisySignCompressor(BaseCompressor):
    """Noisy-sign compressor with an erf-inverse decoder.

    Every client sends ``sign(x_i + N(0, sigma^2))`` for each coordinate,
    ``num_reps`` independent times. Because
    ``E[sign(x + N(0, sigma^2))] = erf(x / (sigma * sqrt(2)))``, the server
    recovers the mean by applying ``sqrt(2) * sigma * erfinv`` to the
    empirical mean of the signs.
    """

    def __init__(self, m: int, d: int, sigma: float, num_reps: int = 1):
        super().__init__(m=m, d=d)
        # Standard deviation of the Gaussian dither added before taking signs
        # (called zeta in the original notes).
        self.sigma = sigma
        # Number of independent noisy-sign repetitions per client.
        self.num_reps = num_reps

    def compress(self, client_arr: np.ndarray):
        """Return the erf-inverse estimate of the client mean.

        Args:
            client_arr: ``(m, d)`` array whose rows are the client vectors.

        Returns:
            Length-``d`` decoded vector (always finite).
        """
        client_arr = np.asarray(client_arr, dtype=float)
        sign_avg = np.zeros(client_arr.shape)
        for _ in range(self.num_reps):
            noise = np.random.normal(0.0, self.sigma, size=client_arr.shape)
            sign_avg += np.sign(client_arr + noise)
        sign_avg /= self.num_reps

        # If every sign for a coordinate agrees, the empirical mean is exactly
        # +-1 and erfinv returns +-inf. The empirical mean moves in steps of
        # 2/N (N = clients * reps), so clipping to +-(1 - 1/(2N)) only pulls
        # the +-1 extremes inside (-1, 1); every other attainable value
        # (|mean| <= 1 - 2/N) is left untouched.
        num_samples = client_arr.shape[0] * self.num_reps
        bound = 1.0 - 1.0 / (2.0 * num_samples)
        empirical_mean = np.clip(sign_avg.mean(axis=0), -bound, bound)

        return np.sqrt(2) * self.sigma * erfinv(empirical_mean)

    def num_bits(self):
        # One sign bit per coordinate, per machine, per repetition.
        return self.m * self.d * self.num_reps
            
class HadamardCompressor(BaseCompressor):
    """Per-coordinate binary-search compressor.

    Each coordinate is rescaled to ``[-B, B]`` by a ``MinMaxScaler`` fitted
    across clients. For every coordinate the client values are randomly
    permuted; the client in position ``i`` (0-based) sends one sign bit, the
    sign of its residual after ``i`` greedy binary-search steps. The server
    combines the bits with weights ``B * 2**-i`` and undoes the scaling.
    """

    def __init__(self, m: int, d: int, B: float, seed: int, num_reps=1):
        super().__init__(m=m, d=d)
        # Input and output are scaled to [-B, B].
        self.B = B
        self.scaler = MinMaxScaler(feature_range=(-1 * self.B, self.B))
        # Number of independent permutations averaged per coordinate.
        self.num_reps = num_reps
        # Base seed; coordinate ``idx`` is processed with seed ``seed + idx``.
        self.seed = seed

    def run_single_compress(self, client_scalars, idx):
        """Encode/decode one coordinate and return its averaged estimate.

        Args:
            client_scalars: the (already scaled) client values of coordinate
                ``idx``, one per client.
            idx: coordinate index, used to derive the RNG seed.

        Returns:
            The decoded (still scaled) value, averaged over ``num_reps``.
        """
        # NOTE: reseeds the global NumPy RNG of the process running this call.
        np.random.seed(self.seed + idx)
        total = 0.0
        for _ in range(self.num_reps):
            shuffled = np.random.permutation(client_scalars)
            # Client in position i reports the bit for level i + 1.
            bits = np.array(
                [self.hadamard1denc(shuffled[i], level=i + 1) for i in range(self.m)]
            )
            total += self.hadamard1ddec(bits)
        return total / self.num_reps

    def compress(self, client_arr: np.ndarray):
        """Compress every coordinate in parallel and return the decoded vector.

        Args:
            client_arr: ``(m, d)`` array whose rows are the client vectors.

        Returns:
            Length-``d`` decoded vector in the original (unscaled) units.
        """
        client_arr = self.scaler.fit_transform(client_arr)

        # Keep the caller's global RNG state intact, even if a worker fails.
        rand_state = np.random.get_state()
        try:
            pool = ProcessPool(100)
            run_single_compress = lambda idx: self.run_single_compress(client_arr[:, idx], idx)
            results = list(tqdm(pool.imap(run_single_compress, range(self.d)), total=self.d))
        finally:
            np.random.set_state(rand_state)

        decoded = np.array(results)
        decoded = self.scaler.inverse_transform(decoded.reshape(1, -1)).reshape(-1)
        # NOTE: the server also needs the per-coordinate min and max used by
        # the scaler; those floats are not counted in num_bits.
        return decoded

    def hadamard1denc(self, scalar: float, level: int):
        """Return the sign bit sent by a client at ``level`` (1-based).

        Runs ``level - 1`` greedy binary-search steps, subtracting
        ``sign(residual) * B / 2**l`` at step ``l``, then returns the sign
        of the remaining residual.
        """
        residual = scalar
        for l in range(level - 1):
            residual -= np.sign(residual) * (self.B / (2 ** l))
        return np.sign(residual)

    def hadamard1ddec(self, scalars: np.ndarray):
        """Combine the bits: ``B * sum_i bits[i] * 2**-i``."""
        return self.B * (scalars * np.power(2.0, -1 * np.arange(scalars.shape[0]))).sum()

    def num_bits(self):
        # One bit per machine, per coordinate, per repetition.
        return (self.m) * self.d * self.num_reps

## Sparse regression compressor
class SparseRegCompressor(BaseCompressor):
    """Sparse-regression-code compressor (successive greedy refinement).

    Client vectors are normalised to unit norm. In every repetition the
    clients are randomly permuted, and a Gaussian codebook with ``K`` codewords
    per stage is drawn from a shared seed. The client in position ``i`` runs
    ``i + 1`` greedy stages and sends only the index chosen at its last stage
    (``log2(K)`` bits). The server sums the corresponding codewords weighted
    by ``coeff_arr`` and rescales to the average client norm.
    """

    def __init__(self, m: int, d: int, K: int, seed: int, num_reps: int = 1):
        super().__init__(m=m, d=d)
        # Codewords per stage.
        self.K = K

        # Stage weights sqrt((1 - c) * c**i) with c = 1 - 2*log(K)/d.
        # NOTE(review): K == 1 gives c == 1 and therefore all-zero weights.
        coeff = 1 - (2 * np.log(K) / self.d)
        self.coeff_arr = np.sqrt((1 - coeff) * (coeff ** np.arange(self.m)))
        self.num_reps = num_reps
        # Base seed; repetition ``idx`` uses seed ``seed + idx``.
        self.seed = seed

    def run_single_compress(self, client_arr, idx):
        """Run one repetition (seeded with ``seed + idx``).

        Args:
            client_arr: ``(m, d)`` array of unit-norm client vectors.
            idx: repetition index, used to derive the seed.

        Returns:
            Length-``d`` decoded vector for this repetition.
        """
        # NOTE: reseeds the global NumPy RNG of the process running this call.
        np.random.seed(self.seed + idx)
        decoded = np.zeros(self.d)
        perm = np.random.permutation(np.arange(self.m))
        sketch_arr = np.random.normal(loc=0, scale=1, size=(self.m * self.K, self.d))
        for i in range(self.m):
            # Client perm[i] is refined to i + 1 stages; only its last index is used.
            sketch_arr_level = sketch_arr[i * self.K:(i + 1) * self.K, :]
            dec_idx = self.single_vec_enc(client_arr[perm[i]], levels=i + 1, sketch_arr=sketch_arr)
            decoded += self.coeff_arr[i] * sketch_arr_level[dec_idx]
        return decoded

    def compress(self, client_arr: np.ndarray):
        """Return the decoded mean estimate of the client vectors.

        Args:
            client_arr: ``(m, d)`` array whose rows are the client vectors.
                It is not modified.

        Returns:
            Length-``d`` decoded vector whose norm equals the mean client norm
            (zero if the decoded direction is zero).
        """
        client_arr = np.asarray(client_arr, dtype=float)
        norms = np.linalg.norm(client_arr, axis=1).reshape(-1, 1)
        # Normalise into a new array so the caller's data is left untouched;
        # zero rows stay zero instead of becoming NaN.
        safe_norms = np.where(norms > 0, norms, 1.0)
        unit_arr = client_arr / safe_norms

        # Keep the caller's global RNG state intact, even if a worker fails.
        rand_state = np.random.get_state()
        try:
            pool = ProcessPool(100)
            run_single_compress = lambda idx: self.run_single_compress(unit_arr, idx)
            results = list(tqdm(pool.imap(run_single_compress, range(self.num_reps)), total=self.num_reps))
        finally:
            np.random.set_state(rand_state)
        decoded = sum(results) / self.num_reps

        # Rescale to the average norm of the client vectors.
        decoded_norm = np.linalg.norm(decoded)
        if decoded_norm > 0:
            decoded *= norms.mean() / decoded_norm
        # Each machine sends the index of its last level.
        return decoded

    def single_vec_enc(self, vector: np.ndarray, levels: int, sketch_arr: np.ndarray):
        """Greedily encode ``vector`` over ``levels`` stages.

        At each stage the codeword with the largest inner product with the
        current residual is selected and its weighted copy is subtracted.

        Returns:
            Index (within its stage) of the codeword chosen at the last stage.
        """
        # Work on a copy so the residual update never writes into the caller's row.
        residual = np.array(vector, dtype=float)
        dec_idx = None
        for l in range(levels):
            sketch_arr_level = sketch_arr[l * self.K:(l + 1) * self.K, :]
            dec_idx = (sketch_arr_level @ residual).argmax()
            residual -= self.coeff_arr[l] * sketch_arr_level[dec_idx]
        return dec_idx

    def num_bits(self):
        # log2(K) bits per machine per repetition.
        return self.m * np.log2(self.K) * self.num_reps

    def num_bits_float(self):
        # Index bits plus one float norm per machine.
        return self.m * (np.log2(self.K) * self.num_reps + FLOAT_BITS)

        
class OneBitAverage(BaseCompressor):
    """One-bit random-direction compressor.

    Each client sends its norm (one float) and, per repetition, the sign of
    its inner product with a random unit direction. The server averages
    ``sign * direction``, normalises the result and scales it by the mean
    client norm.
    """

    def __init__(self, m: int, d: int, num_reps: int = 1):
        super().__init__(m=m, d=d)
        # Number of independent random directions per client.
        self.num_reps = num_reps

    def compress(self, client_arr: np.ndarray):
        """Return the decoded mean estimate.

        Args:
            client_arr: ``(m, d)`` array whose rows are the client vectors.
                It is not modified.

        Returns:
            Length-``d`` vector with norm equal to the mean client norm (zero
            if the averaged direction vanishes).
        """
        client_arr = np.asarray(client_arr, dtype=float)
        norms = np.linalg.norm(client_arr, axis=1)
        # Normalise into a new array (the caller's data must not change);
        # zero rows stay zero instead of becoming NaN.
        safe_norms = np.where(norms > 0, norms, 1.0)
        unit_arr = client_arr / safe_norms.reshape(-1, 1)

        direction = np.zeros(self.d)
        for _ in range(self.num_reps):
            rand_gaussian = np.random.normal(loc=0, scale=1.0, size=(self.m, self.d))
            rand_uniform = rand_gaussian / np.linalg.norm(rand_gaussian, axis=1).reshape(-1, 1)

            # One sign bit per machine: side of its random hyperplane.
            signs = np.sign((rand_uniform * unit_arr).sum(axis=1))
            direction += (signs.reshape(-1, 1) * rand_uniform).mean(axis=0)
        direction /= self.num_reps

        direction_norm = np.linalg.norm(direction)
        if direction_norm > 0:
            direction = direction / direction_norm
        # Each machine sends its norm and one sign per repetition.
        return norms.mean() * direction

    def num_bits(self):
        return self.m * self.num_reps

    def num_bits_float(self):
        return self.m * (self.num_reps + FLOAT_BITS)

class RandK(BaseCompressor):
    """Rand-K sparsification: each client keeps K distinct random coordinates.

    The server sums the kept values, divides by ``m`` and rescales by
    ``d / K``, which makes the estimate of the client mean unbiased.
    """

    def __init__(self, m: int, d: int, K: int = 0, perc: float = 0.0):
        super().__init__(m=m, d=d)
        # Number of coordinates each client keeps: either given directly as K,
        # or as a fraction perc of d (K = floor(perc * d)).
        if K != 0:
            self.K = K
        elif perc != 0.0:
            self.K = int(np.floor(perc * self.d))
        else:
            raise ValueError("Atleast one of perc or K should be non-zero")
        if not 1 <= self.K <= self.d:
            raise ValueError(f"K must be between 1 and d={self.d}, got K={self.K}")

    def compress(self, client_arr: np.ndarray):
        """Return the unbiased Rand-K estimate of the client mean.

        Args:
            client_arr: ``(m, d)`` array whose rows are the client vectors.
                It is not modified.

        Returns:
            Length-``d`` decoded vector.
        """
        client_arr = np.asarray(client_arr, dtype=float)
        decoded = np.zeros(self.d)
        for i in range(self.m):
            # K distinct coordinates (no replacement), kept via a boolean mask.
            keep = np.zeros(self.d, dtype=bool)
            keep[np.random.choice(self.d, self.K, replace=False)] = True
            decoded[keep] += client_arr[i, keep]

        # Each machine sends K floats plus K indices.
        return (decoded / self.m) * (self.d / self.K)

    def num_bits(self):
        return self.m * self.K * (FLOAT_BITS + np.log2(self.d))

    

class RandKSpatialAvg(RandK):
    """Rand-K with spatial averaging.

    Each client keeps K distinct random coordinates. Coordinate ``j``, chosen
    by ``M_j`` clients, is scaled by ``beta / (m * T(M_j))``, where
    ``T(M) = 1 + (m/2) * (M - 1)/(m - 1)`` down-weights coordinates that many
    clients reported.
    """

    def __init__(self, m: int, d: int, K: int = 0, perc: float = 0.0):
        super().__init__(m=m, d=d, K=K, perc=perc)
        # Same interpolation as RandKSpatialProjAvg. max(m - 1, 1) avoids a
        # division by zero for a single client, for which T(1) == 1.
        self.T_func = lambda x: 1 + (self.m / 2) * (x - 1) / max(self.m - 1, 1)
        self.beta = self.d / (self.K * self.compute_expectation())

    def compute_expectation(self):
        """Return ``E[1 / T(M)]`` with ``M = 1 + Binomial(m - 1, K/d)``.

        ``M`` is the number of clients that picked a coordinate, given that
        a fixed client picked it. With ``beta = d / (K * E[1/T(M)])`` the
        estimator is unbiased for the client mean. Each client picks a
        coordinate with probability ``p = K/d``, so
        ``E[decoded_j] = beta/m * p * E[1/T(M)] * sum_i x_ij``.
        """
        p = self.K / self.d
        avg = 0.0
        for count in range(1, self.m + 1):
            prob = binom(self.m - 1, count - 1) * p ** (count - 1) * (1 - p) ** (self.m - count)
            avg += prob / self.T_func(count)
        return avg

    def compress(self, client_arr: np.ndarray):
        """Return the spatially averaged Rand-K estimate of the client mean.

        Args:
            client_arr: ``(m, d)`` array whose rows are the client vectors.
                It is not modified.

        Returns:
            Length-``d`` decoded vector (zero at unselected coordinates).
        """
        client_arr = np.asarray(client_arr, dtype=float)
        decoded = np.zeros(self.d)
        idx_count = np.zeros(self.d)

        for i in range(self.m):
            # K distinct coordinates, so fancy-index += counts each one once.
            idx_to_keep = np.random.choice(self.d, self.K, replace=False)
            decoded[idx_to_keep] += client_arr[i, idx_to_keep]
            idx_count[idx_to_keep] += 1.0

        # Only coordinates chosen by at least one client carry data. T(0) can
        # be zero (e.g. m == 2), so do not divide there.
        selected = idx_count > 0
        decoded[selected] *= self.beta / (self.m * self.T_func(idx_count[selected]))
        return decoded
    
    
    
        
class RandKSpatialProjAvg(RandK):
    """Rand-K with random-projection spatial averaging.

    Client ``i`` applies ``G_i = E_i H D_i``. ``E_i`` selects K distinct rows,
    ``H`` is the orthonormal Walsh-Hadamard matrix and ``D_i`` is a random
    Rademacher diagonal. The client contributes ``G_i^T G_i x_i``. The server
    returns ``beta * pinv(T(S)) @ sum_i G_i^T G_i x_i`` with
    ``S = sum_i G_i^T G_i``, where ``T`` is applied to the eigenvalues of ``S``.
    """

    def __init__(self, m: int, d: int, K: int = 0, perc: float = 0.0):
        super().__init__(m=m, d=d, K=K, perc=perc)
        assert self.d & (self.d - 1) == 0, f"Dimension d={self.d} is not a power of 2."
        # Orthonormal Walsh-Hadamard matrix.
        self.H = hadamard(self.d) / np.sqrt(self.d)
        self.beta = self.d / (self.m * self.K)  # This assumes \delta = 0, from their rebuttal.
        # Interpolation between the zero- and max-correlation cases.
        # max(m - 1, 1) avoids a division by zero for a single client.
        self.T_func = lambda val: 1 + (self.m / 2.0) * (val - 1) / max(self.m - 1, 1)

    def compress(self, client_arr: np.ndarray):
        """Return the projection-averaged estimate of the client mean.

        Args:
            client_arr: ``(m, d)`` array whose rows are the client vectors.

        Returns:
            Length-``d`` decoded vector.
        """
        client_arr = np.asarray(client_arr, dtype=float)
        decoded = np.zeros(self.d)
        # Sum of the per-client projection matrices.
        proj_sum = np.zeros((self.d, self.d))

        for i in range(self.m):
            # Distinct rows: repeated indices would make G_i^T G_i a
            # non-projection (duplicated rows are counted twice).
            idx_to_keep = np.random.choice(self.d, self.K, replace=False)

            # Subsampling matrix.
            E_i = np.zeros((self.K, self.d))
            E_i[np.arange(self.K), idx_to_keep] = 1

            # Diagonal matrix with Rademacher entries.
            D_i = np.diag(2 * np.random.binomial(1, 0.5, size=self.d) - 1)

            G_i = E_i @ self.H @ D_i
            proj_i = G_i.T @ G_i
            decoded += proj_i @ client_arr[i]
            proj_sum += proj_i

        # proj_sum = V diag(lambda) V^T; eigh returns the eigenvectors as the
        # COLUMNS of V.
        eigvals, eigvecs = np.linalg.eigh(proj_sum)
        t_vals = self.T_func(eigvals)

        # decoded lies in range(proj_sum), so restrict the pseudo-inverse to
        # that range (numerically-zero eigenvalues are excluded), and to
        # positive T values.
        tol = max(eigvals.max(), 0.0) * self.d * np.finfo(float).eps
        usable = (eigvals > tol) & (t_vals > 0.0)
        inv_t = np.zeros_like(t_vals)
        inv_t[usable] = 1.0 / t_vals[usable]

        # pinv(T(proj_sum)) = V diag(1 / T(lambda)) V^T.
        pinv_proj_sum = (eigvecs * inv_t) @ eigvecs.T
        return self.beta * (pinv_proj_sum @ decoded)



class RotatedKLevelQuant(BaseCompressor):
    """Randomly rotated stochastic k-level quantization.

    Every client vector is rotated by ``R = H D / sqrt(d)`` (an orthogonal
    matrix). Each machine then stochastically rounds its rotated coordinates
    to ``levels`` evenly spaced values between its own min and max. The
    server averages the quantized vectors and undoes the rotation.
    """

    def __init__(self, m: int, d: int, levels: int):
        super().__init__(m=m, d=d)
        self.levels = levels
        assert self.d & (self.d - 1) == 0, f"Dimension d={self.d} is not a power of 2."
        # Unnormalised Walsh-Hadamard matrix (entries +-1).
        self.H = hadamard(self.d)

    def klevelquant(self, Z_i: np.ndarray):
        """Stochastically quantize each row of ``Z_i`` (unbiased).

        Args:
            Z_i: ``(m, d)`` array; row ``i`` is machine ``i``'s rotated vector.

        Returns:
            Array of the same shape, where every entry is one of ``levels``
            evenly spaced values between that row's min and max.

        Raises:
            ValueError: if ``levels < 2``.
        """
        if self.levels < 2:
            raise ValueError(f"levels must be >= 2 for k-level quantization, got {self.levels}")

        # Min and max coordinate values for each machine (row).
        Z_i_min = Z_i.min(axis=1, keepdims=True)
        Z_i_max = Z_i.max(axis=1, keepdims=True)
        Z_i_diff = (Z_i_max - Z_i_min) / (self.levels - 1)

        # Constant rows have zero spacing; every value already equals the min.
        safe_diff = np.where(Z_i_diff > 0, Z_i_diff, 1.0)
        Z_i_coeff = (Z_i - Z_i_min) / safe_diff
        Z_i_idx = np.floor(Z_i_coeff)
        Y_i_prob = np.clip(Z_i_coeff - Z_i_idx, 0.0, 1.0)

        # Dequantize relative to the row minimum.
        return Z_i_min + (Z_i_idx + np.random.binomial(1, p=Y_i_prob)) * Z_i_diff

    def compress(self, client_arr: np.ndarray):
        """Return the rotated k-level quantized estimate of the client mean.

        Args:
            client_arr: ``(m, d)`` array whose rows are the client vectors.

        Returns:
            Length-``d`` decoded vector.
        """
        R = self.H @ np.diag(2 * np.random.binomial(1, 0.5, size=self.d) - 1) / np.sqrt(self.d)
        # Rows are rotated as z = x @ R.
        Z_i = client_arr @ R

        Y_i = self.klevelquant(Z_i)

        # R is orthogonal, so the inverse of z = x @ R is x = z @ R.T.
        return Y_i.mean(axis=0) @ R.T

    def num_bits(self):
        # NOTE: excludes the per-machine min/max floats.
        return self.m * (self.d * np.log2(self.levels))
    
class DriveOneBit(RotatedKLevelQuant):
    """DRIVE-style one-bit compressor with a random Hadamard rotation.

    Each client rotates its vector with ``R = H D / sqrt(d)`` and sends the
    signs of the rotated coordinates plus one float scale
    ``S = ||z||_2^2 / ||z||_1``. The server averages ``S * sign(z)`` and
    rotates back.
    """

    def __init__(self, m: int, d: int, levels: int, num_reps: int = 1):
        super().__init__(m, d, levels)
        # Number of independent rotations averaged.
        self.num_reps = num_reps

    def compress(self, client_arr: np.ndarray):
        """Return the decoded estimate of the client mean.

        Args:
            client_arr: ``(m, d)`` array whose rows are the client vectors.

        Returns:
            Length-``d`` decoded vector.
        """
        client_arr = np.asarray(client_arr, dtype=float)
        decoded = np.zeros(self.d)
        for _ in range(self.num_reps):
            # Matrix product (not element-wise *): R = H D / sqrt(d) is orthogonal.
            R = self.H @ np.diag(2 * np.random.binomial(1, 0.5, size=self.d) - 1) / np.sqrt(self.d)
            Z = client_arr @ R
            l1 = np.abs(Z).sum(axis=1)
            # Zero rows get S = 0 instead of 0/0.
            S = np.divide(np.linalg.norm(Z, axis=1) ** 2, l1, out=np.zeros_like(l1), where=l1 > 0)
            signs = np.sign(Z)

            # Undo the row rotation z = x @ R with x = z @ R.T.
            decoded += (S.reshape(-1, 1) * signs).mean(axis=0) @ R.T
        decoded /= self.num_reps
        # Each machine sends its scale S.
        return decoded

    def num_bits(self):
        return self.m * (self.d + FLOAT_BITS) * self.num_reps


## Correlated Quantization for Distributed Mean Estimation and Optimization. Suresh et al.
class RotatedKLevelCorrelatedQ(RotatedKLevelQuant):
    """Rotated correlated quantization.

    Client vectors are normalised to unit norm, randomly rotated, scaled by
    ``normalization`` and clipped to ``[-1, 1]``. Each of the ``d`` rotated
    coordinates is then quantized jointly across clients by
    :meth:`onedimcorrelatedq`. The server undoes the scaling and the rotation
    and multiplies by the mean client norm.
    """

    def __init__(self, m: int, d: int, levels: int):
        super().__init__(m=m, d=d, levels=levels)

        self.beta = (self.levels + 1) / (self.levels * (self.levels - 1))
        # Normalization constant.
        self.normalization = np.sqrt(self.d) / np.sqrt(8 * np.log(self.d * self.m))
        # Assuming each rotated coordinate is in [-B, B] where B = 1.
        self.B = 1

    def compress(self, client_arr: np.ndarray):
        """Return the decoded estimate of the client mean.

        Args:
            client_arr: ``(m, d)`` array whose rows are the client vectors.
                It is not modified.

        Returns:
            Length-``d`` decoded vector.
        """
        client_arr = np.asarray(client_arr, dtype=float)
        norms = np.linalg.norm(client_arr, axis=1).reshape(-1, 1)
        # Normalise into a new array (no in-place change to the caller's
        # data); zero rows stay zero instead of becoming NaN.
        safe_norms = np.where(norms > 0, norms, 1.0)
        unit_arr = client_arr / safe_norms

        # Randomised Walsh-Hadamard rotation (orthogonal).
        W = self.H @ np.diag(2 * np.random.binomial(1, 0.5, size=self.d) - 1) / np.sqrt(self.d)

        Y = self.normalization * unit_arr @ W
        Y = np.clip(Y, -1, 1)

        # Quantize every one of the d rotated coordinates (columns) across clients.
        decoded = np.array([self.onedimcorrelatedq(Y[:, j]) for j in range(self.d)])

        # Rows were rotated as y = x @ W; undo with x = y @ W.T.
        # NOTE: scaling the mean of unit vectors by the mean norm is an
        # approximation of the mean of the original vectors.
        return norms.mean() * (decoded @ W.T) / self.normalization

    def onedimcorrelatedq(self, client_arr_one_dim: np.ndarray):
        """Jointly quantize one coordinate's client values and return their mean.

        Args:
            client_arr_one_dim: length-``m`` array of client values in [-B, B].

        Returns:
            Scalar mean of the quantized values.
        """
        # NOTE(review): c_arr has m entries (np.arange(self.m)) rather than one
        # per quantization level, and y (divided by beta) is compared with
        # U in [0, 1]; confirm both against the reference algorithm.
        ## Generate random permutation
        perm  = np.random.permutation(np.arange(self.m))
        ## Generate levels
        c_1 = np.random.uniform(low=(-1)/self.levels, high=0.0)
        c_arr = c_1 + np.arange(self.m)*self.beta

        y = (client_arr_one_dim + self.B)/(2*self.B*self.beta)

        ## Compute c_max(i) = max_{c(i) < y(i)} c(i)
        c_expanded = np.tile(c_arr.reshape(1,-1), (y.shape[0], 1))
        y_expanded = np.tile(y.reshape(-1,1), (1, c_arr.shape[0]))
        c_expanded[c_expanded >= y_expanded] = -1
        c_max = c_expanded.max(axis=1)

        # Stratified shared randomness: client perm[i] falls in [i/m, (i+1)/m).
        gamma = np.random.uniform(low=0.0, high=1.0/self.m, size=(self.m))
        U = perm/self.m + gamma
        Q = (2*self.B)*(c_max + self.beta*(U <= y).astype(int))
        return Q.mean()

    def num_bits(self):
        return self.m * (self.d * np.log2(self.levels) + FLOAT_BITS)
    

#PERMUTATION COMPRESSORS FOR PROVABLY FASTER DISTRIBUTED NONCONVEX OPTIMIZATION. Szlendak et al.
class PermK(BaseCompressor):
    """Perm-K compressor: clients report disjoint random coordinate subsets.

    If ``d >= m``, a random permutation of the coordinates is split into
    ``m`` blocks and client ``i`` reports block ``i`` (the last block takes
    the remainder). If ``d < m``, every coordinate is assigned to at least one
    client. The result estimates the client mean.
    """

    def __init__(self, m: int, d: int, num_reps: int = 1):
        super().__init__(m=m, d=d)
        # Number of independent permutations averaged.
        self.num_reps = num_reps

    def compress(self, client_arr: np.ndarray):
        """Return the Perm-K estimate of the client mean.

        Args:
            client_arr: ``(m, d)`` array whose rows are the client vectors.
                It is not modified.

        Returns:
            Length-``d`` decoded vector.
        """
        client_arr = np.asarray(client_arr, dtype=float)
        decoded = np.zeros(self.d)
        for _ in range(self.num_reps):
            if self.d >= self.m:
                decoded += self._compress_blocks(client_arr)
            else:
                decoded += self._compress_shared(client_arr)
        # Each transmitted value costs FLOAT_BITS (see num_bits).
        return decoded / self.num_reps

    def _compress_blocks(self, client_arr):
        """d >= m: each coordinate is reported by exactly one client."""
        block = self.d // self.m
        perm = np.random.permutation(self.d)
        estimate = np.zeros(self.d)
        for i in range(self.m):
            lo = i * block
            hi = self.d if i == self.m - 1 else (i + 1) * block
            coords = perm[lo:hi]
            # Perm-K scales each client's block by m and averages by 1/m; the
            # factors cancel, leaving the reported values themselves.
            estimate[coords] = client_arr[i, coords]
        return estimate

    def _compress_shared(self, client_arr):
        """d < m: coordinate j is estimated by the mean of its assigned clients."""
        # Every coordinate appears at least once (m > d), roughly m/d times.
        assignment = np.random.permutation(np.arange(self.m) % self.d)
        estimate = np.zeros(self.d)
        for j in range(self.d):
            estimate[j] = client_arr[assignment == j, j].mean()
        return estimate

    def num_bits(self):
        return max(self.d, self.m) * FLOAT_BITS * self.num_reps
## 
class Kashin(RotatedKLevelQuant):
    """Kashin-representation compressor with QSGD-style quantization.

    Each client vector is expanded into ``D = int(d * lamb)`` frame
    coefficients by ``r`` rounds of truncated projection. The coefficients are
    stochastically quantized to ``levels`` levels relative to their norm. The
    server averages them and maps them back with the frame ``U``.
    """

    def __init__(self, m: int, d: int, levels: int, lamb: float, r: int):
        super().__init__(m=m, d=d, levels=levels)

        # Kashin parameters: redundancy factor and number of refinement rounds.
        self.lamb = lamb
        self.r = r
        self.D = int(self.d * self.lamb)

        # Frame: first d rows of a random D x D orthogonal matrix, so U is
        # d x D with orthonormal rows (U @ U.T = I_d).
        from scipy.stats import ortho_group
        U = ortho_group.rvs(dim=self.D)
        self.U = U[:self.d, :]
        self.eta = 3 / 4 + 1 / (4 * np.sqrt(self.lamb))
        self.delta = (1 / np.power(5.0, 4)) * ((1 - 1 / np.sqrt(self.lamb)) ** 2)

    def get_kashin_coeffs(self, vec: np.ndarray):
        """Return the length-``D`` Kashin coefficients of ``vec``.

        Each round projects the residual onto the frame and truncates every
        coefficient to ``[-M, M]``. It then subtracts the reconstruction and
        shrinks ``M`` by ``eta``.
        """
        coeff = np.zeros(self.D)
        M = np.linalg.norm(vec) / np.sqrt(self.delta * self.D)
        x = vec
        for _ in range(self.r):
            b = self.U.T @ x
            # Element-wise truncation: sign(b_j) * min(|b_j|, M).
            b_hat = np.clip(b, -M, M)
            x = x - self.U @ b_hat
            coeff = coeff + b_hat
            M = self.eta * M
        return coeff

    def compress(self, client_arr: np.ndarray):
        """Return the decoded estimate of the client mean.

        Args:
            client_arr: ``(m, d)`` array whose rows are the client vectors.

        Returns:
            Length-``d`` decoded vector.
        """
        kashin_coeffs = np.zeros(self.D)
        for i in range(self.m):
            # Standard QSGD on the coefficients with self.levels levels.
            coeffs = self.get_kashin_coeffs(client_arr[i, :])
            coeff_norm = np.linalg.norm(coeffs)
            if coeff_norm == 0:
                # A zero vector quantizes to zero (avoids 0/0).
                continue
            scaled = (np.abs(coeffs) / coeff_norm) * self.levels
            l_vals = np.floor(scaled)
            rounded = l_vals + np.random.binomial(1, p=scaled - l_vals)
            kashin_coeffs += coeff_norm * np.sign(coeffs) * rounded / self.levels

        kashin_coeffs /= self.m
        return self.U @ kashin_coeffs

    def num_bits(self):
        return self.m * (FLOAT_BITS + self.levels * (self.levels + np.sqrt(self.D)) * (3 + 1.5 * np.log(2 * (self.levels ** 2 + self.D) / (self.levels * (self.levels + np.sqrt(self.D))))))


# Registry mapping configuration names to compressor classes.
COMPRESSORS = {
    "sign": SignCompressor,
    "noisy_sign": NoisySignCompressor,
    "hadamard": HadamardCompressor,
    "onebitavg": OneBitAverage,
    "randk": RandK,
    "randkspatial": RandKSpatialAvg,
    "randkproj": RandKSpatialProjAvg,
    "rotatedquant": RotatedKLevelQuant,
    "rotatedcorrelatedquant": RotatedKLevelCorrelatedQ,
    "drive": DriveOneBit,
    "permk": PermK,
    "kashin": Kashin,
    "sparsereg": SparseRegCompressor,
}


