import torch
from torch.autograd import Variable

def rot_matrix(angle):
    """Return the 2x2 matrix ``[[cos a, sin a], [-sin a, cos a]]``.

    In the usual counter-clockwise convention this is the rotation by
    ``-angle``. The matrix is orthogonal with determinant +1.

    Args:
        angle: Scalar angle in radians. May be a 0-d tensor (possibly
            requiring grad) or a plain Python/NumPy number.

    Returns:
        A (2, 2) tensor. It is allocated with torch's default dtype and
        device, so ``cos``/``sin`` values are cast on assignment. The slice
        assignments are recorded by autograd, so gradients flow back to
        ``angle`` when it requires grad.
    """
    # as_tensor returns an existing tensor unchanged (autograd history kept)
    # and lets plain numbers through, which torch.cos would otherwise reject.
    angle = torch.as_tensor(angle)
    c = torch.cos(angle)
    s = torch.sin(angle)
    R = torch.zeros(2, 2)
    R[0, 0] = c
    R[0, 1] = s
    R[1, 0] = -s
    R[1, 1] = c
    return R

def ref_matrix(angle):
    """Return the 2x2 matrix ``[[-cos a, sin a], [sin a, cos a]]``.

    The matrix is symmetric and orthogonal with determinant -1, i.e. a
    reflection; applying it twice gives the identity.

    Args:
        angle: Scalar angle in radians. May be a 0-d tensor (possibly
            requiring grad) or a plain Python/NumPy number.

    Returns:
        A (2, 2) tensor. It is allocated with torch's default dtype and
        device, so ``cos``/``sin`` values are cast on assignment. The slice
        assignments are recorded by autograd, so gradients flow back to
        ``angle`` when it requires grad.
    """
    # as_tensor returns an existing tensor unchanged (autograd history kept)
    # and lets plain numbers through, which torch.cos would otherwise reject.
    angle = torch.as_tensor(angle)
    c = torch.cos(angle)
    s = torch.sin(angle)
    R = torch.zeros(2, 2)
    R[0, 0] = -c
    R[0, 1] = s
    R[1, 0] = s
    R[1, 1] = c
    return R

def getU(d, p, theta):
    """Return the block-diagonal matrix of ``rot_matrix`` blocks for ``theta``.

    Args:
        d: Not used by this function. Presumably callers pass
            ``d == 2 * len(theta)``; this is not checked here.
        p: Not used by this function.
        theta: Iterable of angles. One 2x2 block ``rot_matrix(angle)`` is
            placed on the diagonal per element, in iteration order.

    Returns:
        Tensor of shape ``(2 * len(theta), 2 * len(theta))`` with
        ``requires_grad=True``. If the angles require grad, the result stays
        connected to them in the autograd graph. Otherwise the result is a
        fresh leaf tensor that requires grad.

    Raises:
        RuntimeError: if ``theta`` is empty.
    """
    blocks = [rot_matrix(angle) for angle in theta]
    if not blocks:
        # torch.block_diag with no arguments returns an empty tensor instead of
        # failing; keep rejecting empty input with a RuntimeError as before.
        raise RuntimeError("getU: theta must contain at least one angle")
    U = torch.block_diag(*blocks)
    # No-op when U already requires grad through theta; otherwise turns U
    # into a trainable leaf.
    U.requires_grad_(True)
    return U

def getH(d, p, mu):
    """Build the d x d hyperbolic-rotation matrix parameterised by ``mu``.

    With ``q = d - p``, the index ranges ``0:q`` and ``p:d`` are paired and
    coupled by ``[[cosh mu, sinh mu], [sinh mu, cosh mu]]`` entrywise. The
    middle range ``q:p`` is left as the identity. Because
    ``cosh**2 - sinh**2 == 1``, the result satisfies ``H.T @ J @ H == J`` for
    ``J = diag(I_p, -I_q)``.

    Args:
        d: Matrix size (int).
        p: Size of the leading index range. Must satisfy ``p >= d - p``;
            otherwise ``torch.eye`` receives a negative size and raises.
        mu: 1-D tensor with ``d - p`` entries, which the diagonal q x q
            assignments require.

    Returns:
        A (d, d) tensor in torch's default dtype/device. Gradients flow back
        to ``mu`` through the slice assignments when it requires grad.
    """
    q = int(d - p)
    cosh_block = torch.diag(torch.cosh(mu))
    sinh_block = torch.diag(torch.sinh(mu))
    head, tail = slice(0, q), slice(p, d)

    H = torch.zeros([d, d])
    # The three index ranges are disjoint when p >= q, so order is irrelevant.
    H[head, head] = cosh_block
    H[tail, tail] = cosh_block
    H[head, tail] = sinh_block
    H[tail, head] = sinh_block
    H[q:p, q:p] = torch.eye(p - q)
    return H

def getV(d, p, ksi):
    """Return the block-diagonal matrix of ``ref_matrix`` blocks for ``ksi``.

    Args:
        d: Not used by this function. Presumably callers pass
            ``d == 2 * len(ksi)``; this is not checked here.
        p: Not used by this function.
        ksi: Iterable of angles. One 2x2 block ``ref_matrix(angle)`` is
            placed on the diagonal per element, in iteration order.

    Returns:
        Tensor of shape ``(2 * len(ksi), 2 * len(ksi))`` with
        ``requires_grad=True``. If the angles require grad, the result stays
        connected to them in the autograd graph. Otherwise the result is a
        fresh leaf tensor that requires grad.

    Raises:
        RuntimeError: if ``ksi`` is empty.
    """
    blocks = [ref_matrix(angle) for angle in ksi]
    if not blocks:
        # torch.block_diag with no arguments returns an empty tensor instead of
        # failing; keep rejecting empty input with a RuntimeError as before.
        raise RuntimeError("getV: ksi must contain at least one angle")
    V = torch.block_diag(*blocks)
    # No-op when V already requires grad through ksi; otherwise turns V
    # into a trainable leaf.
    V.requires_grad_(True)
    return V