import torch
import numpy as np

def generate_corr_vectors(v, c, d, num_vectors=10):
    """Sample random unit vectors whose cosine similarity with ``v`` is ``c``.

    Args:
        v: Reference vector with ``d`` elements (anything that reshapes to
            ``(1, d)``); it is normalized internally.
        c: Target cosine similarity; must lie in ``[-1, 1]``.
        d: Dimensionality of the vectors.
        num_vectors: Number of vectors to draw.

    Returns:
        Tensor of shape ``(num_vectors, d)``. Each row is
        ``c * v_hat + sqrt(1 - c**2) * w``, where ``v_hat`` is ``v``
        normalized and ``w`` is a random unit vector orthogonal to
        ``v_hat``. Every row therefore has unit norm and cosine ``c``
        with ``v``.

    Raises:
        ValueError: If ``c`` is outside ``[-1, 1]``. Otherwise the square
            root below would produce NaN vectors.
    """
    if not -1 <= c <= 1:
        raise ValueError(f"c must be in [-1, 1], got {c}")
    v = (v / torch.norm(v)).reshape(1, d)
    u = torch.randn(num_vectors, d)
    u = u / torch.norm(u, dim=-1, keepdim=True)
    # Gram-Schmidt step: drop each sample's component along v, then
    # renormalize so v_orth is a unit vector orthogonal to v.
    v_orth = u - (u * v).sum(-1, keepdim=True) * v
    v_orth = v_orth / torch.norm(v_orth, dim=-1, keepdim=True)
    return c * v + np.sqrt(1 - c**2) * v_orth


def get_ij_vec_from_k(k, ik, jk):
    '''
    Build vectors i and j by shifting k along the directions ik and jk.

    k is flattened to a single row of shape (1, d). Each direction row
    ``dir`` produces ``k + (-1 / (2 * <k, dir>)) * dir``, so the
    difference from k is parallel to ``dir``. The same rule is applied to
    ik and to jk. Rows of ik/jk broadcast against k, so an (n, d)
    direction gives an (n, d) result.

    k, ik should be normalized. A direction orthogonal to k makes the
    scale factor divide by zero.

    Returns the tuple (i, j).
    '''
    row_k = k.reshape(1, -1)

    def _shift(direction):
        # Per-row inner product with k, kept as a column for broadcasting.
        dots = (row_k * direction).sum(-1, keepdim=True)
        return row_k + (-1 / (2 * dots)) * direction

    shifted_i = _shift(ik)
    shifted_j = _shift(jk)
    return shifted_i, shifted_j

def mc_sim(p, num_items, d, num_q=100000):
    """Monte-Carlo estimate of P(q.j > q.k | q.i > q.k for every i).

    Setup: ``jk`` is a random unit vector. The ``num_items`` rows of ``ik``
    have cosine ``cos((1 - p) * pi)`` with ``jk``. ``k`` is fixed to the
    unit vector e_1 (without loss of generality). ``i`` and ``j`` are then
    derived from ``k``, ``ik`` and ``jk`` by ``get_ij_vec_from_k``. Queries
    ``q`` are normalized standard-Gaussian draws, i.e. uniform on the unit
    sphere.

    Args:
        p: Controls the cosine between ``jk`` and each ``ik`` row via
            ``c = cos((1 - p) * pi)``.
        num_items: Number of ``i`` vectors.
        d: Dimensionality.
        num_q: Number of sampled queries.

    Returns:
        0-dim float tensor: the fraction of queries with ``q.j > q.k``
        among those with ``q.i > q.k`` for all ``i``. NaN (0/0) if no
        sampled query meets the condition.
    """
    c = np.cos((1 - p) * np.pi)
    jk = torch.randn(1, d)
    jk = jk / torch.norm(jk, dim=-1, keepdim=True)
    ik = generate_corr_vectors(jk, c, d, num_vectors=num_items)
    # WLOG take k = e_1, a unit vector as get_ij_vec_from_k requires.
    # Index [0, 0] sets only the first element; indexing with [0] would
    # fill the whole row with ones, giving norm sqrt(d).
    k = torch.zeros(1, d)
    k[0, 0] = 1
    i, j = get_ij_vec_from_k(k, ik, jk)  # (num_items, d), (1, d)
    q = torch.randn(num_q, d)
    q = q / torch.norm(q, dim=-1, keepdim=True)
    qj = torch.matmul(q, j.T).reshape(-1, 1)  # (num_q, 1)
    qi = torch.matmul(q, i.T)  # (num_q, num_items)
    qk = torch.matmul(q, k.T).reshape(-1, 1)  # (num_q, 1)
    # Condition: every i scores above k for this query.
    cond_mask = (qi > qk).all(-1, keepdim=True)  # (num_q, 1)
    post_mask = qj > qk  # (num_q, 1)
    return (cond_mask & post_mask).sum() / cond_mask.sum()


    

    