"""Euclidean operations utils functions."""

from readline import parse_and_bind
import torch
import torch.nn as nn


def euc_sqdistance(x, y, eval_mode=False):
    """Squared Euclidean distance between two batches of vectors.

    Relies on the expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2 <x, y>, so
    no explicit difference tensor is materialised.

    Args:
        x: torch.Tensor of shape (N1 x d)
        y: torch.Tensor of shape (N2 x d)
        eval_mode: if True, every row of x is compared with every row of y;
            otherwise rows are compared one-to-one, which requires N1 == N2

    Returns:
        torch.Tensor of shape (N1 x N2) with all-pairs squared distances when
        eval_mode is True, else torch.Tensor of shape (N1 x 1) holding the
        squared distance between x[i] and y[i]
    """
    sq_norm_x = (x * x).sum(dim=-1, keepdim=True)
    sq_norm_y = (y * y).sum(dim=-1, keepdim=True)
    if not eval_mode:
        # Row-wise mode: both batches must line up one-to-one.
        assert x.shape[0] == y.shape[0]
        inner = (x * y).sum(dim=-1, keepdim=True)
        return sq_norm_x + sq_norm_y - 2 * inner
    # All-pairs mode: (N1 x 1) broadcasts against (1 x N2).
    inner = x @ y.t()
    return sq_norm_x + sq_norm_y.t() - 2 * inner



def givens_rotations(r, x, transpose=False):
    """Rotate consecutive coordinate pairs of x by 2x2 Givens blocks.

    Both inputs are viewed as (N, d/2, 2). Each pair of r is scaled to unit
    length and used as the (cos, sin) entries of a 2x2 rotation that acts on
    the matching pair of x.

    Args:
        r: torch.Tensor of shape (N x d), rotation parameters
        x: torch.Tensor of shape (N x d), points to rotate
        transpose: if True, apply the transposed (i.e. inverse) rotation

    Returns:
        torch.Tensor of shape (N x d) representing rotation of x by r
    """
    n_rows = r.shape[0]
    pairs = r.view((n_rows, -1, 2))
    # clamp_min keeps an all-zero pair from producing a division by zero.
    unit = pairs / torch.norm(pairs, p=2, dim=-1, keepdim=True).clamp_min(1e-15)
    cos_part = unit[:, :, 0:1]
    sin_part = unit[:, :, 1:]

    points = x.view((n_rows, -1, 2))
    # Each pair (x0, x1) turned by 90 degrees: (-x1, x0).
    perpendicular = torch.cat((-points[:, :, 1:], points[:, :, 0:1]), dim=-1)

    if transpose:
        rotated = cos_part * points - sin_part * perpendicular
    else:
        rotated = cos_part * points + sin_part * perpendicular
    return rotated.view((n_rows, -1))

def givens_rotations_reverse(r, x):
    """Rotate consecutive coordinate pairs of x by minus the angle encoded in r.

    Uses the same (cos, sin) parameterisation as givens_rotations but with
    the sign of the sine term flipped, which is mathematically the same as
    givens_rotations(r, x, transpose=True).

    Args:
        r: torch.Tensor of shape (N x d), rotation parameters
        x: torch.Tensor of shape (N x d), points to rotate

    Returns:
        torch.Tensor of shape (N x d) representing the inverse rotation of x
        by r
    """
    batch = r.shape[0]
    rot = r.view((batch, -1, 2))
    # clamp_min keeps an all-zero pair from producing a division by zero.
    rot = rot / torch.norm(rot, p=2, dim=-1, keepdim=True).clamp_min(1e-15)
    x = x.view((batch, -1, 2))

    first = x[:, :, 0:1]
    second = x[:, :, 1:]
    # (x1, -x0): the forward rotation's (-x1, x0) with its sign reversed.
    turned = torch.cat((second, -first), dim=-1)

    x_rot = rot[:, :, 0:1] * x + rot[:, :, 1:] * turned
    return x_rot.view((batch, -1))



def rotation_scaling_to(r, x):
    """Apply a combined rotation and scaling to consecutive pairs of x.

    Unlike givens_rotations, the pairs of r are NOT normalised, so each pair
    (a, b) maps (x0, x1) to (a*x0 - b*x1, a*x1 + b*x0) -- i.e. complex
    multiplication of x0 + i*x1 by a + i*b.

    Args:
        r: torch.Tensor of shape (N x d), rotation/scaling parameters
        x: torch.Tensor of shape (N x d), points to transform

    Returns:
        torch.Tensor of shape (N x d) with the transformed points
    """
    n_rows = r.shape[0]
    coeffs = r.view((n_rows, -1, 2))
    points = x.view((n_rows, -1, 2))
    real_part = coeffs[:, :, 0:1]
    imag_part = coeffs[:, :, 1:]
    # Each pair (x0, x1) turned by 90 degrees: (-x1, x0).
    perpendicular = torch.cat((-points[:, :, 1:], points[:, :, 0:1]), dim=-1)
    transformed = real_part * points + imag_part * perpendicular
    return transformed.view((n_rows, -1))


def full_givens_rotations(theta, x):
    """Apply a batched (N x d x d) rotation matrix to x.

    The composition of per-plane Givens rotations from ``theta`` is disabled
    (see the note in the body), so the rotation matrix is currently the
    identity and ``theta`` is not read.

    The identity is created on the same device and with the same dtype as
    ``x`` so the function works on CPU as well as on any CUDA device.

    Args:
        theta: torch.Tensor of shape (N x d-1), rotation angle parameters
            (currently unused)
        x: torch.Tensor of shape (N x d), points to rotate

    Returns:
        torch.Tensor of shape (N x d) representing rotation of x; equal in
        value to x while the Givens composition is disabled
    """
    N, d = x.size(0), x.size(1)

    # Batched identity; expand() creates a broadcast view without copying.
    rotation_matrix = torch.eye(d, device=x.device, dtype=x.dtype)
    rotation_matrix = rotation_matrix.unsqueeze(0).expand(N, d, d)

    # NOTE: the Givens composition is disabled. The algorithm it implemented,
    # for i in range(d - 1), builds a batched identity G_i with
    #     G_i[:, i, i]     = G_i[:, i+1, i+1] = cos(theta[:, i])
    #     G_i[:, i, i+1]   = -sin(theta[:, i])
    #     G_i[:, i+1, i]   =  sin(theta[:, i])
    # and accumulates rotation_matrix = torch.bmm(G_i, rotation_matrix).

    # Treat every point as a column vector for the batched mat-vec product.
    x_rot = rotation_matrix @ x.unsqueeze(-1)
    return x_rot.squeeze(-1)

def schmidt_orth(Tensor_alpha: torch.Tensor):
    """Batched classical Gram-Schmidt orthonormalisation of matrix columns.

    Column i of every matrix has its projections onto the previously
    orthogonalised columns removed (projection coefficients are computed
    from the ORIGINAL column i, i.e. classical rather than modified
    Gram-Schmidt); all columns are then scaled to unit length.

    Args:
        Tensor_alpha (torch.Tensor): original tensor of shape
            (batch_size, rank, rank); ``Tensor_alpha[b, :, i]`` is the i-th
            vector of batch element b.

    Returns:
        torch.Tensor of shape (batch_size, rank, rank) whose columns are the
        orthonormalised vectors. A column whose orthogonal residual is zero
        (e.g. an exact duplicate of an earlier column) comes out as a zero
        column instead of turning later columns into NaN.
    """
    eps = 1e-15  # guards every division against a zero-length vector
    rank = Tensor_alpha.shape[1]

    ortho_cols = []     # orthogonal, not yet normalised, columns
    ortho_sq_norms = []  # their squared norms, computed once and clamped

    for i in range(rank):
        alpha_col = Tensor_alpha[:, :, i]  # batch_size x rank

        # Subtract the projection of the original column on each earlier one.
        beta_col = alpha_col
        for prev_col, prev_sq_norm in zip(ortho_cols, ortho_sq_norms):
            coeff = torch.sum(alpha_col * prev_col, dim=-1) / prev_sq_norm
            beta_col = beta_col - coeff.unsqueeze(-1) * prev_col

        ortho_cols.append(beta_col)
        # Clamping makes the projection onto a zero column 0 / eps = 0
        # rather than 0 / 0 = NaN.
        ortho_sq_norms.append(
            torch.sum(beta_col * beta_col, dim=-1).clamp_min(eps)
        )

    # Re-assemble the columns along the last dimension.
    Tensor_Beta = torch.stack(ortho_cols, dim=-1)

    # dim=1 runs over the components of each column vector.
    return Tensor_Beta / (Tensor_Beta.norm(dim=1, keepdim=True) + eps)
                                

def householder_transformation(x: torch.Tensor):
    """Build, per batch element, the product of matrices I - v v^T.

    Every column ``x[b, :, k]`` is normalised to a unit vector v_k and turned
    into the matrix ``I - v_k v_k^T``; the ``reflection_times`` matrices of a
    batch element are then multiplied left to right.

    The identity is created on the device and with the dtype of ``x``, and
    the normalisation is guarded so an all-zero column yields the identity
    instead of NaN.

    NOTE(review): a textbook Householder reflection is ``I - 2 v v^T``; the
    factor 2 is absent here, so each factor is a projection rather than a
    reflection. The formula is kept as-is to preserve existing outputs --
    confirm whether this is intended.

    Args:
        x (torch.Tensor): tensor of shape (batch_size, rank, reflection_times);
            reflection_times must be at least 1.

    Returns:
        torch.Tensor of shape (batch_size, rank, rank) holding the product
        of the per-column matrices.
    """
    batch_size, rank, reflection_times = x.shape

    # Build all matrices in one batched call instead of looping over columns.
    vectors = x.transpose(1, 2)  # batch_size x reflection_times x rank
    vectors = vectors / torch.norm(vectors, dim=-1, keepdim=True).clamp_min(1e-15)
    parallel_v = vectors.reshape(-1, rank).unsqueeze(-1)  # batch_size * reflection_times x rank x 1

    # (1 x rank x rank) identity broadcasts over the flattened batch.
    identity = torch.eye(rank, device=x.device, dtype=x.dtype).unsqueeze(0)
    householder_matrices = identity - torch.bmm(parallel_v, parallel_v.transpose(1, 2))
    householder_matrices = householder_matrices.reshape(batch_size, reflection_times, rank, rank)

    # Multiply the matrices together, left to right.
    # TODO: this could be reduced pairwise (N -> log N steps), but for small
    # reflection_times the sequential loop is unlikely to be the bottleneck.
    multiplied_matrices = householder_matrices[:, 0, :, :]  # batch_size x rank x rank
    for i in range(1, reflection_times):
        multiplied_matrices = torch.bmm(multiplied_matrices, householder_matrices[:, i, :, :])

    return multiplied_matrices

    


def givens_rotations_h(r, x):
    """Hyperbolic (cosh/sinh) counterpart of the Givens rotation.

    Both inputs are viewed as (N, d/2, 2). Only the first entry of every
    pair of r is read and used as sinh; cosh is derived from it through
    cosh^2 = 1 + sinh^2. Each pair (x0, x1) of x is mapped to
    (cosh*x0 + sinh*x1, cosh*x1 + sinh*x0).

    Args:
        r: torch.Tensor of shape (N x d), rotation parameters (the second
            entry of each pair is ignored)
        x: torch.Tensor of shape (N x d), points to rotate

    Returns:
        torch.Tensor of shape (N x d) representing rotation of x by r
    """
    n_rows = r.shape[0]
    params = r.view((n_rows, -1, 2))
    sinh_part = params[:, :, 0:1]
    cosh_part = (1 + sinh_part.pow(2)).pow(0.5)

    points = x.view((n_rows, -1, 2))
    # Each pair with its two coordinates exchanged: (x1, x0).
    swapped = torch.cat((points[:, :, 1:], points[:, :, 0:1]), dim=-1)

    rotated = cosh_part * points + sinh_part * swapped
    return rotated.view((n_rows, -1))

def givens_reflection(r, x):
    """Reflect consecutive coordinate pairs of x with 2x2 Givens reflections.

    Both inputs are viewed as (N, d/2, 2). Each pair of r is scaled to unit
    length (c, s) and maps the matching pair (x0, x1) of x to
    (c*x0 + s*x1, s*x0 - c*x1).

    Args:
        r: torch.Tensor of shape (N x d), rotation parameters
        x: torch.Tensor of shape (N x d), points to reflect

    Returns:
        torch.Tensor of shape (N x d) representing reflection of x by r
    """
    n_rows = r.shape[0]
    unit = r.view((n_rows, -1, 2))
    # clamp_min keeps an all-zero pair from producing a division by zero.
    unit = unit / torch.norm(unit, p=2, dim=-1, keepdim=True).clamp_min(1e-15)

    points = x.view((n_rows, -1, 2))
    x0 = points[:, :, 0:1]
    x1 = points[:, :, 1:]
    flipped = torch.cat((x0, -x1), dim=-1)  # (x0, -x1)
    swapped = torch.cat((x1, x0), dim=-1)   # (x1, x0)

    reflected = unit[:, :, 0:1] * flipped + unit[:, :, 1:] * swapped
    return reflected.view((n_rows, -1))



# class CircleLoss(nn.Module):
#     def __init__(self, m: float, gamma: float) -> None:
#         super(CircleLoss, self).__init__()
#         self.m = m
#         self.gamma = gamma
#         self.soft_plus = nn.Softplus()

#     def forward(self, sp: Tensor, sn: Tensor) -> Tensor:
#         ap = torch.clamp_min(- sp.detach() + 1 + self.m, min=0.)
#         an = torch.clamp_min(sn.detach() + self.m, min=0.)

#         delta_p = 1 - self.m
#         delta_n = self.m

#         logit_p = - ap * (sp - delta_p) * self.gamma
#         logit_n = an * (sn - delta_n) * self.gamma

#         loss = self.soft_plus(torch.logsumexp(logit_n, dim=0) + torch.logsumexp(logit_p, dim=0))

#         return loss