import torch
import numpy as np
from scipy import stats
import pickle
import time

from utils.vecs_io import fvecs_read
from utils.vec_np import normalize
from .probabilistic_scalar_compressor import ProbabilisticScalarCompressor

class VoronoiCompressor(object):
    """Quantise a tensor to nearest-codeword indices of a random codebook.

    The input is reshaped to rows of ``dim`` values and standardised with
    its global mean and standard deviation.  Every row is then replaced by
    the index of the closest row of ``codewords`` (a ``K x dim`` matrix of
    standard-normal samples).  ``decompress`` looks the codewords up again
    and undoes the standardisation.

    When ``args.quantizer == 'voronoi'`` a fresh codebook is sampled on
    every ``compress`` call and stored on the instance; ``decompress`` then
    uses whatever codebook the instance currently holds.
    NOTE(review): a remote receiver would need the same codebook (e.g. via
    a shared NumPy seed) -- nothing in this class arranges that; confirm
    against the caller.
    """

    def __init__(self, size, shape, args):
        """Configure the compressor for tensors of ``size`` elements.

        Args:
            size: number of elements in the tensors to be compressed.
            shape: shape that ``decompress`` restores the result to.
            args: options object; the attributes read here are ``c_dim``
                (requested row length), ``k_bit`` (codebook holds
                ``2 ** k_bit`` entries, or ``dim`` entries when 0),
                ``n_bit`` (a norm compressor is built unless it is 32),
                ``quantizer``, ``debias``, ``device`` and ``no_cuda``.

        Raises:
            AssertionError: on non-positive ``c_dim``/``n_bit``, negative
                ``k_bit``, or when no row length dividing ``size`` is found.
        """
        c_dim = args.c_dim
        k_bit = args.k_bit
        n_bit = args.n_bit
        assert c_dim > 0
        assert k_bit >= 0
        assert n_bit > 0

        self.quantizer = args.quantizer
        self.debias = args.debias
        self.device = args.device

        self.cuda = not args.no_cuda
        self.size = size
        self.shape = shape

        if c_dim == 0 or self.size < c_dim:
            # Tensor shorter than the requested row length: use one row.
            self.dim = self.size
        else:
            # Try up to ten times to reach a row length that divides `size`
            # by repeatedly scaling the candidate by 3/2 (integer maths).
            self.dim = c_dim
            for _ in range(10):
                if size % self.dim == 0:
                    break
                self.dim = self.dim // 2 * 3

        if c_dim != self.dim:
            print("alternate dimension form"
                  " {} to {}, size {} shape {}"
                  .format(c_dim, self.dim, size, shape))

        assert size % self.dim == 0, \
            "not divisible size {}  c_dim {} self.dim {}"\
                .format(size, c_dim, self.dim)

        self.K = 2 ** k_bit if k_bit > 0 else self.dim

        # The codebook is random Gaussian for every K (learned codebooks
        # are not loaded by this class).
        self.codewords = self._sample_codebook()

        # Pick the index dtype from the real codebook size.  Deriving it
        # from k_bit alone is wrong for k_bit == 0, where K == dim may
        # exceed 256 and indices would wrap around in uint8.
        self.code_dtype = torch.uint8 if self.K <= 256 else torch.int32

        # Lazily filled cache for the radial-bias model used when debiasing.
        self._bias_model = None

        self.compressed_norm = n_bit != 32
        if self.compressed_norm:
            # Built for API parity; compress()/decompress() below do not
            # currently route anything through it.
            self.norm_compressor = ProbabilisticScalarCompressor(n_bit, args)

    def _sample_codebook(self):
        """Return a fresh ``K x dim`` float32 standard-normal codebook."""
        codebook = np.random.normal(size=(self.K, self.dim)).astype(np.float32)
        codebook = torch.from_numpy(codebook)
        if self.cuda:
            codebook = codebook.to(self.device)
        return codebook

    def _radial_bias_model(self):
        """Return ``(min, max, poly)`` describing the radial bias table.

        The table is unpickled from ``./radial_biases`` once per instance
        and a degree-4 polynomial is fitted to its entries (first entry
        skipped) on a 0.1-spaced radius grid.  Both depend only on ``dim``
        and ``K``, so the result is cached instead of being recomputed on
        every ``compress`` call.

        SECURITY: ``pickle.load`` executes arbitrary code from the file;
        the bias tables must come from a trusted source.

        Raises:
            OSError: if the bias table for this ``(dim, K)`` is missing.
        """
        model = getattr(self, '_bias_model', None)
        if model is None:
            path = ("./radial_biases/radial_bias_" + str(self.dim) + "_"
                    + str(self.K) + "_" + ".txt")
            with open(path, "rb") as fp:
                table = pickle.load(fp)
            radial = [entry[0] for entry in table][1:]
            radii = [r / 10 for r in range(1, len(radial) + 1)]
            poly = np.poly1d(np.polyfit(radii, radial, 4))
            model = (min(radial), max(radial), poly)
            self._bias_model = model
        return model

    def compress(self, vec):
        """Compress ``vec`` into ``[mu, sigma, codes, debias]``.

        Args:
            vec: tensor whose element count is a multiple of ``self.dim``.

        Returns:
            A list of the global mean, the global standard deviation, one
            codeword index per row (dtype ``self.code_dtype``) and the
            per-row debias factors (an empty list when debiasing is off).
        """
        # Sample a different codebook at each call (i.e. per step/worker).
        if self.quantizer == 'voronoi':
            self.codewords = self._sample_codebook()

        vec = vec.view(-1, self.dim)
        mu = torch.mean(vec)
        sigma = torch.std(vec)
        # A constant input has sigma == 0 (and a single element gives NaN);
        # dividing by it would turn every entry into NaN.  Normalise with 1
        # in that case but still return the raw sigma so decompress()
        # reconstructs exactly what it did before.
        scale = torch.where(sigma > 0, sigma, torch.ones_like(sigma))
        vec = (vec - mu) / scale

        # Inner products between every row and every codeword: O(d * K).
        p = torch.mm(self.codewords, vec.transpose(0, 1)).transpose(0, 1)

        # argmax(2<c, v> - |c|^2) is the codeword nearest to v in L2.
        codes = torch.argmax(
            2*p - torch.norm(self.codewords.t(), dim=0)**2, dim=1)

        debias = []
        if self.debias:
            lo, hi, poly = self._radial_bias_model()
            # Hand NumPy a real ndarray instead of a torch tensor.
            norms = torch.norm(vec, dim=1).detach().cpu().numpy()
            debias = self.fix_bias2(lo, hi, poly(norms))
            # Pin the dtype: depending on NumPy's promotion rules the
            # polynomial may come back as float64, which would silently
            # promote the whole reconstruction in decompress().
            debias = debias.to(dtype=vec.dtype)
            if self.cuda:
                debias = debias.to(self.device)

        return [mu, sigma, codes.type(self.code_dtype), debias]

    def decompress(self, signature):
        """Rebuild a tensor of ``self.shape`` from a ``compress`` result.

        Args:
            signature: the ``[mu, sigma, codes, debias]`` list returned by
                ``compress``.

        Returns:
            The selected codewords, rescaled by ``sigma`` (and divided by
            the debias factors when debiasing is on), shifted by ``mu``.
        """
        mu, sigma, codes, debias = signature

        codes = codes.view(-1).type(torch.long)
        vec = self.codewords[codes]
        if self.debias:
            # Transposes let the per-row factors broadcast over columns.
            recover = (vec.t()*sigma / debias).t() + mu
        else:
            recover = vec*sigma + mu
        return recover.view(self.shape)

    def fix_bias(self, p, x):
        """Look up, for every row of ``x``, the entry of ``p`` nearest to
        ten times the row's L2 norm (the table is sampled every 0.1).

        The factor of ten matches ``fix_bias0`` and the 0.1-spaced grid the
        polynomial model is fitted on.  The table is placed on the same
        device as the norms so the lookup also works on GPU inputs.
        """
        norm = torch.norm(x, dim=1)
        grid = torch.arange(len(p), device=norm.device)
        distance = torch.abs(grid.unsqueeze(0) - 10 * norm.unsqueeze(1))
        nearest = torch.argmin(distance, 1)
        return torch.tensor(p, device=norm.device)[nearest]

    def fix_bias0(self, p, x):
        """Scalar lookup: entry of ``p`` whose index is nearest ``10 * x``."""
        grid = torch.arange(len(p))
        return torch.tensor(p)[torch.argmin(torch.abs(grid - 10*x))]

    def fix_bias1(self, p, x):
        """Randomly pick the first or last entry of ``p``; the last one is
        chosen with probability ``x / (len(p) - 1)`` (``x`` is a tensor)."""
        last = len(p) - 1
        return torch.tensor(p)[torch.bernoulli(x/last).type(torch.long)*last]

    def fix_bias2(self, mi, ma, approx_norm):
        """Stochastically round ``approx_norm`` to ``mi`` or ``ma``.

        Each value becomes ``ma`` with probability
        ``clamp((value - mi) / (ma - mi), 0, 1)`` and ``mi`` otherwise.
        """
        prob = torch.clamp(
            (torch.as_tensor(approx_norm) - mi) / (ma - mi), 0, 1)
        return torch.bernoulli(prob)*(ma - mi) + mi
