import torch
import torch.nn.functional as F
import numpy as np
import random
import math

from utils.Jorth_CSinitial import Jorth_CSinitial
from utils.OrthProj import OrthProj
from utils.Jexc import exc_operator
from utils.U_proj import U_proj
from utils.initial_x import getU, getV, getH
def generate_allsparse_triple(entity_num, relation_num, sparsity):
    """Sample a sparse random set of (head, tail, relation) triples.

    Every entity pair ``(h, t)`` with ``h < t`` is assigned one relation id
    drawn uniformly from ``[0, relation_num)``. Then a random subset of
    ``int(sparsity * num_pairs)`` of those rows is kept, in shuffled order.

    Args:
        entity_num: number of entities; ids are ``0 .. entity_num - 1``.
        relation_num: number of relation ids to draw from.
        sparsity: fraction of all pairs to keep (truncated with ``int``).

    Returns:
        Integer tensor of shape ``(int(sparsity * num_pairs), 3)`` whose
        columns are head, tail and relation.
    """
    heads, tails = np.triu_indices(entity_num, 1)
    pairs = torch.tensor(np.stack((heads, tails), axis=1))
    num_pairs = pairs.shape[0]

    # One random relation per pair, drawn before subsampling.
    rel_ids = torch.randint(0, relation_num, (num_pairs, 1))
    candidates = torch.cat((pairs, rel_ids), dim=1)

    keep = int(sparsity * num_pairs)
    chosen = torch.randperm(num_pairs)[:keep]
    return candidates[chosen]

def split_train_test(triple, rate_train, rate_test):
    """Randomly split the rows of ``triple`` into disjoint train and test sets.

    Args:
        triple: tensor whose first dimension indexes the triples.
        rate_train: fraction of rows used for training; the count is
            ``round(rate_train * n)``.
        rate_test: fraction of rows used for testing; the count is
            ``round(rate_test * n)``.

    Returns:
        ``(train, test)``, both row subsets of ``triple``. No row appears in
        both. If ``rate_train + rate_test`` asks for more rows than exist, the
        test set is truncated to whatever rows training left over.
    """
    n = triple.size(0)

    num_train = round(rate_train * n)
    num_test = round(rate_test * n)

    # One shuffled permutation yields both splits: the first ``num_train``
    # indices train, the next ``num_test`` test. Drawing the test rows from
    # the leftover indices (not from all of range(n)) keeps them out of the
    # training set. Slicing also always yields a long index tensor, even
    # when a split is empty.
    indices = torch.randperm(n)
    index_train = indices[:num_train]
    index_test = indices[num_train:num_train + num_test]

    train = triple[index_train]
    test = triple[index_test]

    return train, test

def UltraE_initial(entity_num, relation_num, p, d, beta):
    """Randomly initialise entity, relation and bias parameters.

    Args:
        entity_num: number of entities (rows of ``vec_entity``/``vec_bias``).
        relation_num: number of relations (leading dim of the relation params).
        p: forwarded to ``U_proj`` and ``getH``; ``d - p`` is the width of ``mu``.
        d: embedding dimension.
        beta: forwarded to ``U_proj``.

    Returns:
        ``(vec_entity, vec_relation, vec_bias, theta, mu, ksi)`` where
        ``vec_entity`` is ``(entity_num, d)``, ``vec_bias`` is
        ``(entity_num,)``, ``theta`` and ``ksi`` are
        ``(relation_num, int(d / 2))``, ``mu`` is ``(relation_num, d - p)``,
        and ``vec_relation`` is the batched product ``U @ (H @ V)`` built from
        those parameters (presumably ``(relation_num, d, d)`` — depends on
        ``getU``/``getH``/``getV``; TODO confirm).
    """
    # Entities: each row is a standard-normal draw passed through U_proj.
    vec_entity = torch.zeros(entity_num, d)
    for i in range(entity_num):
        vec_entity[i, :] = U_proj(torch.randn(d), p, beta)

    # Per-relation raw parameters for the three factors U, H, V.
    theta = torch.randn(relation_num, int(d / 2))
    mu = torch.randn(relation_num, int(d - p))
    ksi = torch.randn(relation_num, int(d / 2))

    U = torch.stack([getU(theta_row) for theta_row in theta])
    # NOTE(review): H is built from mu clamped to [-1, 1], but the unclamped
    # mu is what gets returned — confirm this asymmetry is intended.
    H = torch.stack([getH(d, p, mu_row) for mu_row in torch.clamp(mu, -1, 1)])
    V = torch.stack([getV(ksi_row) for ksi_row in ksi])
    vec_relation = torch.bmm(U, torch.bmm(H, V))

    # One standard-normal bias per entity, drawn one at a time (keeps the
    # original RNG consumption pattern).
    vec_bias = torch.zeros(entity_num)
    for i in range(entity_num):
        vec_bias[i] = torch.randn(1)

    return vec_entity, vec_relation, vec_bias, theta, mu, ksi

def Corrupt(triple, entity_num, k):
    """Create ``k`` corrupted copies of every triple for negative sampling.

    For each (row, copy) position, a fair coin decides whether the head
    (column 0) or the tail (column 1) is replaced. All other columns are
    copied unchanged. Replacement ids come from
    ``generate_random_except_vectorized``, with the slot's current entity as
    the excluded value.

    Args:
        triple: 2-D tensor, one triple per row (presumably integer ids).
        entity_num: number of entities; passed as ``n`` to the sampler.
        k: number of corrupted copies per input row.

    Returns:
        Tensor of shape ``(triple.shape[0], k, triple.shape[1])``.
    """
    rows = triple.shape[0]
    # True -> corrupt the head of that copy, False -> corrupt the tail.
    swap_head = torch.rand(rows, k, device=triple.device) > 0.5

    negatives = triple.unsqueeze(1).repeat(1, k, 1)
    # Heads first, then tails. Candidates are drawn for every position, and
    # the mask picks which ones are actually used.
    for col, mask in ((0, swap_head), (1, ~swap_head)):
        current = negatives[:, :, col]
        candidates = generate_random_except_vectorized(
            entity_num, current.reshape(-1)
        ).reshape(rows, k)
        negatives[:, :, col] = torch.where(mask, candidates, current)
    return negatives

def generate_random_except_vectorized(n, exclude):
    """Element-wise, draw an integer uniformly from ``range(n)`` minus ``exclude``.

    Args:
        n: size of the integer range. Must be at least 2 so that some value
            other than the excluded one exists.
        exclude: integer tensor of values to avoid, assumed to lie in
            ``[0, n)``. Output element ``i`` never equals ``exclude[i]``.

    Returns:
        int64 tensor with the same shape and device as ``exclude``.

    Raises:
        ValueError: if ``n < 2``.
    """
    if n < 2:
        raise ValueError(f"need n >= 2 to exclude a value, got n={n}")
    # Draw directly as integers from the n - 1 admissible values. Scaling a
    # float32 uniform by (n - 1) only has 24 bits of resolution, which leaves
    # many ids unreachable once n exceeds ~2**24.
    samples = torch.randint(0, n - 1, exclude.shape, device=exclude.device)
    # Shift every draw at or above the excluded value up by one to skip it.
    samples += (samples >= exclude).long()
    return samples

def generate_random_except(n, m):
    """Draw one integer uniformly from ``range(n)`` excluding ``m``.

    Args:
        n: size of the integer range.
        m: one-element tensor holding the value to exclude. If it is not in
            ``range(n)``, nothing is excluded.

    Returns:
        A Python ``int``.

    Raises:
        RuntimeError: from ``torch.randint`` when no admissible value exists
            (e.g. ``n == 1`` and ``m == 0``).
    """
    excluded = m.item()
    # The candidate pool shrinks by one only when m actually lies in range(n).
    has_hole = excluded in range(n)
    idx = torch.randint(0, n - 1 if has_hole else n, (1,)).item()
    # Return the idx-th surviving value without materialising arange(n):
    # values at or past the hole are shifted up by one.
    if has_hole and idx >= excluded:
        idx += 1
    return int(idx)