import torch

from qtorch import CTYPE

# Single-qubit gate matrices keyed by their conventional one-letter name.
# Every entry is created with an explicit dtype=CTYPE so the matrices can be
# stacked or multiplied together without dtype mismatches.
STANDARD_GATES_DICT:dict[str,torch.Tensor] = {
    'I': torch.eye(2,dtype=CTYPE),
    'X': torch.tensor([[0,1],
                       [1,0]],dtype=CTYPE),
    'Y': torch.tensor([[0,-1j],
                       [1j,0]],dtype=CTYPE),
    'Z': torch.tensor([[1,0],
                       [0,-1]],dtype=CTYPE),
    # Divide by a Python float: the 1/sqrt(2) factor is then applied at the
    # precision of CTYPE instead of first being rounded to the default float
    # dtype, as torch.sqrt(torch.tensor(2)) would do.
    'H': torch.tensor([[1,1],
                       [1,-1]],dtype=CTYPE)/2**0.5,
    'S': torch.tensor([[1,0],
                       [0,1j]],dtype=CTYPE),
    # Phase exp(i*pi/4), evaluated in CTYPE and converted to a Python complex
    # so the enclosing tensor is built from plain scalars with dtype=CTYPE.
    'T': torch.tensor([[1,0],
                       [0,torch.exp(torch.tensor(1j*torch.pi/4,dtype=CTYPE)).item()]],
                      dtype=CTYPE),
}

# The four Pauli-basis matrices stacked along a new leading axis in the order
# I, X, Y, Z, so PAULIS[k] is the k-th matrix of that sequence.
PAULIS:torch.Tensor = torch.stack(
    tuple(STANDARD_GATES_DICT[label] for label in ('I','X','Y','Z')),
    dim=0,
)

def copyStandardGatesTo(device:str|torch.device,
                        gate_list:list[str]|None=None,
                        gate_dict:dict[str,torch.Tensor]|None=None
                        )->dict[str,torch.Tensor]:
    '''
    Sends copies of the matrices in gate_dict to the specified
    device, call this prior to sending the models to device

    Arguments
    ---------
    device: str | torch.device
        The device to send the matrices to
    gate_list: list[str], optional
        A list of the keys of the standard gates you wish to send to the device
        by default all of the gates are sent.
    gate_dict: dict[str,torch.Tensor], optional
        The dictionary of matrices you want to copy from, by default
        STANDARD_GATES_DICT is selected.
    
    Returns
    -------
    dict[str,torch.Tensor]
        Dictionary of the Pauli matrices sent to the specified device
    '''
    new_dict = {}
    if gate_dict is None:
        gate_dict = STANDARD_GATES_DICT
    
    if gate_list is None:
        gate_list = gate_dict.keys()

    for key in gate_list:
        new_dict[key] = gate_dict[key].to(device=device)
    
    return new_dict