import numpy as np
import torch
from numpy.linalg import inv


def get_matrix(device, dtype):
    """Draw a random 3x3 blinding transform and its matching gradient matrices.

    Samples fresh parameters from NumPy's global RNG, in this order:
    ``gamma ~ N(0, 1)^3``, ``t ~ N(0, 1)``, ``t2 ~ U[0, 1)^2``. It then builds::

        T = [[1 - t2[0],   -t2[1], -t],
             [  -t2[0], 1 - t2[1], -t],
             [   t2[0],     t2[1],  t]]
        blind_mat   = diag(gamma)^-1 @ T
        unblind_mat = blind_mat^-1
        grad_mat    = diag(gamma) @ [[1, 0, 0], [0, 1, 0], [1, 1, 0]]

    ``det(T) == t``. So ``blind_mat`` is singular only if ``t`` or some
    ``gamma[i]`` is exactly zero, and values near zero make it
    ill-conditioned.

    Args:
        device: torch device on which every returned tensor is allocated.
        dtype: torch dtype of every returned tensor.

    Returns:
        Tuple ``(blind_mat, unblind_mat, sliced_grad_mat, inv_grad_mat)``:
        - ``blind_mat`` and ``unblind_mat`` are 3x3 and inverses of each other.
        - ``sliced_grad_mat`` is the first two columns of ``grad_mat`` (3x2).
        - ``inv_grad_mat`` is 3x3. It holds the inverse of ``grad_mat[:2, :2]``
          in its top-left 2x2 block and zeros elsewhere.
    """
    # Keep this draw order: it makes results reproducible under np.random.seed.
    gamma = np.random.randn(3)
    t = np.random.randn()
    t2 = np.random.rand(2)

    # The upper-left 2x2 block is I - ones(2, 1) @ t2^T, i.e. delta_ij - t2[j].
    T = np.array([
        [1 - t2[0], -t2[1], -t],
        [-t2[0], 1 - t2[1], -t],
        [t2[0], t2[1], t],
    ])
    gamma_diag = np.diag(gamma)
    beta = inv(gamma_diag) @ T

    blind_mat = torch.tensor(beta, dtype=dtype, device=device)
    unblind_mat = blind_mat.inverse()

    merge_grad_mat = torch.tensor(
        [[1., 0., 0.], [0., 1., 0.], [1., 1., 0.]], dtype=dtype, device=device
    )
    grad_mat = torch.tensor(gamma_diag, dtype=dtype, device=device).matmul(merge_grad_mat)

    first_two = torch.tensor([0, 1], device=device)
    # index_select copies, so the returned slice is a standalone contiguous tensor.
    sliced_grad_mat = torch.index_select(grad_mat, 1, first_two)
    # Taking rows 0-1 of the column slice gives grad_mat[:2, :2].
    partial_grad_mat = torch.index_select(sliced_grad_mat, 0, first_two)
    partial_inv_grad = partial_grad_mat.inverse()

    # Pass dtype explicitly. A bare torch.zeros would use the global default
    # dtype and could promote the result away from the requested dtype.
    inv_grad_mat = torch.zeros((3, 3), dtype=dtype, device=device)
    inv_grad_mat[:2, :2] = partial_inv_grad
    return blind_mat, unblind_mat, sliced_grad_mat, inv_grad_mat

def identity_matrix(device, dtype):
    """Return identity stand-ins with the same shapes as ``get_matrix``'s outputs.

    Args:
        device: torch device for every returned tensor.
        dtype: torch dtype for every returned tensor.

    Returns:
        Tuple ``(blind_mat, unblind_mat, sliced_grad_mat, inv_grad_mat)``:
        - ``blind_mat``, ``unblind_mat`` and ``inv_grad_mat`` are independent
          3x3 identity tensors.
        - ``sliced_grad_mat`` is the first two columns of the 3x3 identity,
          ``[[1, 0], [0, 1], [0, 0]]``.

    NOTE(review): here ``inv_grad_mat`` is the full identity. ``get_matrix``
    instead zeros its last row and column. Confirm this difference is intended.
    """
    def _eye3():
        # A fresh allocation on every call, so the returned tensors never alias.
        return torch.eye(3, dtype=dtype, device=device)

    blind_mat, unblind_mat, inv_grad_mat = _eye3(), _eye3(), _eye3()
    sliced_grad_mat = torch.eye(3, 2, dtype=dtype, device=device)
    return blind_mat, unblind_mat, sliced_grad_mat, inv_grad_mat