
import torch


def chain_batchmult(l):
    """Return the left-to-right batched matrix product l[0] @ l[1] @ ... @ l[-1].

    l : non-empty sequence of 3-D tensors whose shapes are compatible for
        consecutive ``torch.bmm`` calls.

    A single-element sequence returns that element itself (no copy). An empty
    sequence raises IndexError.

    Implemented as an iterative fold rather than recursion. The recursive
    version copied the list at every level and exceeded Python's recursion
    limit for chains of a few thousand matrices. Because the multiplications
    are now grouped strictly left to right, results can differ from the old
    grouping by floating-point rounding only.
    """
    result = l[0]  # IndexError on empty input, as before
    for mat in l[1:]:
        result = torch.bmm(result, mat)
    return result

def cs_to_matrix(cs, dim):
    """Build a batch of dim x dim matrices as an ordered product of plane rotations.

    There is one rotation per index pair (i, k) with i < k, enumerated with
    i in the outer loop and k in the inner loop. Rotation number p acts in
    the (i, k) plane with entries [i, i] = [k, k] = cos, [i, k] = sin and
    [k, i] = -sin, where cos = cs[0][:, p] and sin = cs[1][:, p]. The values
    are used as given; they are not normalized here. When every (cos, sin)
    pair lies on the unit circle, the product is a rotation matrix.

        cs  : pair (cos, sin) of tensors, each of shape
              (batch_size, dim * (dim - 1) // 2).
        dim : size of the output matrices.

    Returns a (batch_size, dim, dim) tensor on cs[0]'s device and with
    cs[0]'s dtype. For dim <= 1 there are no parameters, and a batch of
    identity matrices is returned.
    """
    params = dim * (dim - 1) // 2
    cos, sin = cs[0], cs[1]
    assert cos.shape == sin.shape
    batch_size = cos.shape[0]
    assert cos.shape[1] == params
    assert sin.shape[1] == params

    # Build the identity scaffolding on the input's device/dtype. Previously it
    # was always float32 on the CPU, which downcast float64 inputs and
    # mismatched CUDA inputs.
    eye = torch.eye(dim, device=cos.device, dtype=cos.dtype)
    if params == 0:
        # No rotations to chain (chain_batchmult([]) would raise).
        return eye.repeat(batch_size, 1, 1)

    rotations = []
    p = 0
    for i in range(dim - 1):
        for k in range(i + 1, dim):
            rotation = eye.repeat(batch_size, 1, 1)
            rotation[:, i, i] = cos[:, p]
            rotation[:, i, k] = sin[:, p]
            rotation[:, k, i] = -sin[:, p]
            rotation[:, k, k] = cos[:, p]
            rotations.append(rotation)
            p += 1
    return chain_batchmult(rotations)

def eye_like(batch):
    """Return identity matrices matching *batch*'s batch size, size, device and dtype.

    batch : tensor of shape (batch_size, dim, dim).

    The result is an expanded view of a single dim x dim identity, so all
    batch entries share storage. Call ``.clone()`` before writing to it in
    place.
    """
    assert batch.shape[1] == batch.shape[2]
    batch_size, dim = batch.shape[0], batch.shape[1]
    # Allocate directly on the target device (no CPU round-trip) and match the
    # batch's dtype. The old version always produced float32, which breaks
    # e.g. bmm with float64 batches.
    eye = torch.eye(dim, device=batch.device, dtype=batch.dtype)
    return eye.unsqueeze(0).expand(batch_size, -1, -1)


# Added to the norm of each (cos, sin) pair in SO_n_product.normalize_cs so
# that an all-zero pair does not cause a division by zero.
epsilon = 1e-9


class SO_n_product():
    """Map a flat parameter tensor to one rotation-matrix batch per entry of dim_list.

    For each ``dim`` in ``dim_list``, the input holds ``2 * dim * (dim - 1) // 2``
    consecutive columns. The first half are cosines and the second half are
    sines (see ``normalize_cs``).

    Attributes:
        dim_list    : matrix sizes, one per group.
        param_list  : number of plane rotations per group, dim * (dim - 1) // 2.
        latent_num  : total number of input columns expected by ``parse``.
        decoder_num : sum of dim ** 2 over dim_list (not used inside this class).
        device      : device the output matrices are moved to.
    """

    def __init__(self, dim_list, device):
        self.dim_list = dim_list
        self.param_list = [dim * (dim - 1) // 2 for dim in dim_list]
        self.latent_num = 2 * sum(self.param_list)
        self.decoder_num = sum(dim ** 2 for dim in self.dim_list)
        self.device = device

    def list_cs_to_matrix(self, cs_list):
        """Convert each (cos, sin) pair in cs_list to a matrix batch on self.device."""
        # Use the configured device. The old hard-coded .cuda() ignored
        # self.device and failed on machines without CUDA.
        return [
            cs_to_matrix(cs, dim).to(self.device)
            for cs, dim in zip(cs_list, self.dim_list)
        ]

    def parse(self, cs_tensor):
        """Split a (batch_size, latent_num) tensor into one column slice per group.

        Slice g has 2 * param_list[g] columns. The slices are views of cs_tensor.
        """
        assert cs_tensor.shape[1] == self.latent_num
        parsed = []
        start = 0
        for param in self.param_list:
            width = 2 * param
            parsed.append(cs_tensor[:, start:start + width])
            start += width
        return parsed

    def normalize_cs(self, cs):
        """Turn a (batch_size, 2 * p) slice into unit-norm (cos, sin) tensors.

        Columns [0, p) are read as cosines and [p, 2p) as sines. Each pair
        (cos_k, sin_k) is divided by its Euclidean norm plus ``epsilon``.
        Returns two tensors of shape (batch_size, p).
        """
        batch_size = cs.shape[0]
        # An explicit batch size (instead of -1) keeps the reshape valid when
        # p == 0 (dim <= 1); reshape(-1, 2, 0) is ambiguous and raises.
        pairs = cs.reshape(batch_size, 2, cs.shape[1] // 2)
        pairs = pairs / (torch.linalg.norm(pairs, dim=1, keepdim=True) + epsilon)
        return pairs[:, 0], pairs[:, 1]

    def forward(self, cs_tensor):
        """Return (list of normalized (cos, sin) pairs, list of matrix batches)."""
        normalized_cs_list = [self.normalize_cs(part) for part in self.parse(cs_tensor)]
        return normalized_cs_list, self.list_cs_to_matrix(normalized_cs_list)

class SO_n_Representation():
    """A batch of dim x dim rotation parameters, held as angles and/or (cos, sin) pairs.

    There are dim * (dim - 1) // 2 plane-rotation parameters per batch element.
    The matrix is built lazily by ``get_matrix`` and cached until the
    parameters change.
    """

    def __init__(self, dim = 3):
        self.dim = dim
        self.params = self.dim * (self.dim - 1) // 2
        self.batch_size = None
        # Declared up front so they exist before set_cs / set_thetas is called.
        self.theta = None
        self.cs = None
        self.clear_matrix()

    def cossin(self, theta):
        """Return (cos(theta), sin(theta))."""
        return torch.cos(theta), torch.sin(theta)

    def set_cs(self, cs):
        """Set the (cos, sin) pair directly; each is (batch_size, params). Invalidates the cache."""
        self.cs = cs
        batch_size, params = self.cs[0].shape
        assert self.params == params
        self.batch_size = batch_size
        self.clear_matrix()

    def init_random(self, batch_size):
        """Draw angles uniformly from [-pi, pi) as a leaf tensor that requires grad."""
        # torch.autograd.Variable is deprecated; requires_grad_() on a freshly
        # created tensor yields the same grad-tracking leaf.
        thetas = (2 * torch.rand(batch_size, self.params) - 1) * torch.pi
        self.set_thetas(thetas.requires_grad_())

    def set_thetas(self, thetas):
        """Set angles of shape (batch_size, params), derive (cos, sin) and invalidate the cache."""
        batch_size, params = thetas.shape
        # Validate and record the batch size, consistent with set_cs (which
        # already did both).
        assert self.params == params
        self.theta = thetas
        self.cs = self.cossin(self.theta)
        self.batch_size = batch_size
        self.clear_matrix()

    def clear_matrix(self):
        """Drop the cached matrix so the next get_matrix call rebuilds it."""
        self.matrix = None

    def get_matrix(self):
        """Return the (batch_size, dim, dim) matrix batch, building it on first use."""
        if self.cs is None:
            # Same exception type as before, when self.cs did not exist yet.
            raise AttributeError(
                "SO_n_Representation has no parameters; call set_cs, set_thetas or init_random first")
        if self.matrix is None:
            self.matrix = cs_to_matrix(self.cs, self.dim)
        return self.matrix