import math
import torch

def repeat_id(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every entry along axis 1 of a 3-D tensor ``n_rep`` times.

    The copies of one entry are placed next to each other, so an input with
    rows ``[r0, r1]`` and ``n_rep=2`` yields ``[r0, r0, r1, r1]`` along axis 1.

    Args:
        hidden_states: 3-D tensor of shape ``[batch, heads, slen]``.
        n_rep: number of copies of each row along axis 1.

    Returns:
        Tensor of shape ``[batch, heads * n_rep, slen]``. When ``n_rep == 1``
        the input tensor object itself is returned unchanged.
    """
    # Unpack first so a non-3-D input is rejected even when n_rep == 1.
    b, heads, seq_len = hidden_states.shape
    if n_rep != 1:
        # Insert a repeat axis right after `heads`, broadcast it, then fold it in.
        widened = hidden_states.unsqueeze(2).expand(b, heads, n_rep, seq_len)
        return widened.reshape(b, heads * n_rep, seq_len)
    return hidden_states

def indexing(x, indices):
    """Gather rows of ``x`` along its second-to-last axis.

    inputs:
        - x: tensor with shape [..., n, d], typically 4-D [b, h, n, d]
        - indices: int64 tensor with shape [..., s] (typically [b, h, s]) whose
          leading dims equal x's leading dims; each entry must be in [0, n-1]
    output:
        - out: tensor with shape [..., s, d] where
          out[..., k, :] = x[..., indices[..., k], :]

    Equivalent naive implementation for the 4-D case:
        b, h, s = indices.shape
        d = x.shape[-1]
        out = torch.zeros(b, h, s, d, dtype=x.dtype)
        for i in range(b):
            for j in range(h):
                out[i, j] = x[i, j][indices[i, j], :]
        return out
    """
    # Broadcast each index across the feature axis so gather copies whole rows.
    # Using *indices.shape (instead of a fixed number of -1s) lets this work for
    # any rank, not only 4-D x with 3-D indices.
    index = indices.unsqueeze(-1).expand(*indices.shape, x.shape[-1])
    return x.gather(-2, index)


class AngularLSH(torch.nn.Module):
    """Sign-random-projection (angular) LSH mapping vectors to integer bucket ids.

    Each input vector is projected onto ``num_projs`` Gaussian random
    directions. The signs of the projections form a ``num_projs``-bit code,
    which is then relabelled through a reflected-binary Gray-code table
    (``perm``).

    Args:
        num_projs: number of random projections (bits per hash); must be > 0.
            ``perm`` holds ``2 ** num_projs`` int64 entries, so memory grows
            exponentially with this value.
        dim: shape of the projection directions without the trailing
            projection axis. The last entry must equal the feature size of
            the vectors passed to :meth:`hash`. Leading entries broadcast
            against the leading dims of the input (e.g. ``(1, 1, d)`` for
            ``[b, h, n, d]`` inputs). A bare ``int`` is treated as ``(dim,)``.
        rng: optional ``torch.Generator`` used to draw the projections.

    Raises:
        ValueError: if ``num_projs`` is not positive.
    """

    def __init__(self, num_projs, dim, rng=None):
        super().__init__()
        if num_projs <= 0:
            raise ValueError("Invalid value for num_projs")
        self.num_projs = num_projs

        # Accept a bare int; torch.randn needs a full shape tuple.
        dim = (dim,) if isinstance(dim, int) else tuple(dim)

        # Non-persistent buffers follow .to(device) but are not saved in state_dict.
        self.register_buffer('proj_dir', torch.randn(dim + (num_projs,), generator=rng), persistent=False)
        self.register_buffer('perm', self._unit_hamming_distance_array(num_projs), persistent=False)
        # Powers of two 1, 2, 4, ... used to pack the sign bits into one integer.
        # Shape (1, 1, 1, num_projs) is kept for backward compatibility.
        self.register_buffer('enc_vec', 2 ** torch.arange(num_projs).view(1, 1, 1, -1), persistent=False)

    def _unit_hamming_distance_array(self, size_n):
        """Return the ``size_n``-bit reflected binary Gray code as an int64 tensor.

        Entry ``i`` equals ``i ^ (i >> 1)``, so consecutive entries differ in
        exactly one bit. This closed form yields the same sequence as the
        recursive "concatenate with the mirrored copy plus the top bit"
        construction, without the intermediate concatenations.
        """
        codes = torch.arange(2 ** size_n)
        return codes ^ (codes >> 1)

    def hash(self, mat):
        """Return the Gray-code-relabelled bucket id of every row vector in ``mat``.

        Args:
            mat: tensor whose last axis has size ``proj_dir.shape[-2]``; its
                leading axes must broadcast against ``proj_dir.shape[:-2]``.

        Returns:
            int64 tensor of bucket ids in ``[0, 2 ** num_projs)`` whose shape is
            ``mat``'s shape without its last axis (after broadcasting with
            ``proj_dir``'s leading axes).
        """
        # One bit per projection: is the vector on the positive side of the direction?
        bits = torch.einsum('...nd,...dr -> ...nr', mat, self.proj_dir) > 0
        # Flatten enc_vec so it broadcasts against any rank; the 4-D buffer
        # would otherwise add spurious leading singleton dims for rank < 4 inputs.
        bin_ids = (bits * self.enc_vec.reshape(-1)).sum(-1)
        # NOTE(review): perm[b] is the Gray code *of* b, not its inverse. If the
        # intent is that buckets adjacent in sorted hash order differ by one bit,
        # the inverse Gray code would be required (for 3 bits, sorting by perm
        # visits bins 0,1,3,2,7,6,4,5 and 2 -> 7 differs in two bits). Left as is
        # to preserve existing hash values.
        return self.perm[bin_ids]

    def __repr__(self):
        return f"AngularLSH(num_proj={self.num_projs}, proj_dir.shape={self.proj_dir.shape})"


if __name__ == "__main__":
    # Smoke test: build a 24-bit hasher and hash a batch of random vectors.
    # NOTE: perm holds 2**24 int64 entries (128 MiB), so construction is not instant.
    # Typical use is sorting sequence positions by bucket, e.g.
    #   _, sort_idx = torch.sort(lsh.hash(query_states), dim=2, stable=True)
    torch.manual_seed(0)
    lsh = AngularLSH(num_projs=24, dim=(1, 1, 128))
    print(lsh.proj_dir.shape)
    print(lsh.perm[:20])
    print(lsh.enc_vec)

    d = 128  # Dimension of vectors
    nb = 100  # Number of vectors in the dataset
    device = 'cpu'

    data = torch.randn(1, 1, nb, d).float().to(device)
    out = lsh.hash(data)
    print(out)