import torch
import torch.nn as nn

from functools import lru_cache

from qtorch import CTYPE
from qtorch.config import QTORCH_CONFIG


@lru_cache
def build_CNOT_matrix(num_qubits:int, target_qubit:int, control_qubit:int,
                      device:torch.device='cpu')->torch.Tensor:
    '''
    Returns the CNOT matrix on a system of n qubits given the target and control
    qubits

    Qubit ``k`` corresponds to bit ``k`` of the computational-basis index
    (qubit 0 is the least significant bit). Column ``i`` of the result has a
    single 1 in row ``i`` with the target bit flipped when the control bit of
    ``i`` is set, and in row ``i`` otherwise.

    The result is memoised with ``functools.lru_cache``: repeated calls with
    the same arguments return the *same* tensor object, so callers must not
    modify the returned tensor in place.

    Arguments
    ---------
    num_qubits: int
        The total number of qubits
    target_qubit: int
        The index of the target qubit
    control_qubit: int
        The index of the control qubit
    device: torch.device
        The device on which to store the CNOT matrix (a device string such
        as 'cpu' is also accepted, as in the default)
    
    Returns
    -------
    torch.Tensor
        The [2**num_qubits, 2**num_qubits] CNOT gate matrix with dtype CTYPE
    
    Raises
    ------
    TypeError:
        - When num_qubits, target_qubit or control_qubit are not ints
    ValueError:
        - num_qubits < 2
        - invalid target or control qubit indices
        - target and control qubits are the same
    '''
    if not QTORCH_CONFIG['skipValidation']:
        if not isinstance(num_qubits,int):
            raise TypeError('`num_qubits` must be an int')
        if not isinstance(target_qubit,int):
            raise TypeError('`target_qubit` must be an int')
        if not isinstance(control_qubit,int):
            raise TypeError('`control_qubit` must be an int')
        if num_qubits < 2:
            raise ValueError('`num_qubits` must be at least 2')
        if target_qubit < 0 or target_qubit >= num_qubits:
            raise ValueError(f'Invalid `target_qubit` index ({target_qubit})')
        if control_qubit < 0 or control_qubit >= num_qubits:
            raise ValueError(f'Invalid `control_qubit` index ({control_qubit})')
        if target_qubit == control_qubit:
            raise ValueError('Target and control qubits must be different')

    dim = 2**num_qubits
    target_mask = 1 << target_qubit
    control_mask = 1 << control_qubit

    # Compute the whole permutation at once instead of writing 2**n entries
    # one by one from Python (each element write is its own tensor op, which
    # is particularly slow on accelerator devices).
    cols = torch.arange(dim, device=device)
    controlled = (cols & control_mask) != 0
    rows = torch.where(controlled, cols ^ target_mask, cols)

    cx = torch.zeros([dim, dim], dtype=CTYPE, device=device)
    cx[rows, cols] = 1.0
    return cx

@lru_cache
def build_rotation_generators(num_qubits:int, target_qubit:int, 
                              paulis:torch.Tensor)->torch.Tensor:
    '''
    Returns the Pauli X,Y and Z gates acting on the specified target qubit

    Each generator is ``I_left (x) P (x) I_right``, where ``I_right`` spans
    the ``target_qubit`` qubits below the target and ``I_left`` the remaining
    qubits above it. This means qubit ``k`` corresponds to bit ``k`` of the
    basis index.

    The result is memoised with ``functools.lru_cache``. A tensor argument is
    hashed by identity, so mutating ``paulis`` in place after a call is not
    reflected in later cached results. The returned tensor is shared between
    calls and must not be modified in place.

    Arguments
    ---------
    num_qubits: int
        The total number of qubits
    target_qubit: int
        The target qubit index
    paulis: torch.Tensor
        A [4,2,2] tensor containing the Pauli I,X,Y,Z matrices in that order
    
    Returns
    -------
    torch.Tensor
        A [3,2**num_qubits,2**num_qubits] tensor with dtype CTYPE containing
        the Pauli X,Y,Z matrices acting on the target qubit. This tensor is
        stored on paulis.device
    
    Raises
    ------
    TypeError:
        - when num_qubits, target_qubit aren't ints
        - when paulis is not a torch.Tensor
    ValueError:
        - num_qubits < 1
        - invalid target qubit index
        - paulis is not [4,2,2]
    '''
    if not QTORCH_CONFIG['skipValidation']:
        if not isinstance(num_qubits,int):
            raise TypeError('`num_qubits` must be an int')
        if not isinstance(target_qubit,int):
            raise TypeError('`target_qubit` must be an int')
        if not isinstance(paulis,torch.Tensor):
            raise TypeError('`paulis` must be a torch.Tensor')
        if num_qubits < 1:
            raise ValueError('`num_qubits` must be at least 1')
        if target_qubit < 0 or target_qubit >= num_qubits:
            raise ValueError(f'Invalid target qubit index ({target_qubit})')
        if paulis.shape != (4,2,2):
            raise ValueError('`paulis` must have shape [4,2,2]')

    num_right = target_qubit
    num_left = num_qubits - target_qubit - 1

    # torch.eye(2**0) is the 1x1 identity, so no special case is needed when
    # the target sits at either end of the register.
    right_ids = torch.eye(2**num_right, dtype=CTYPE, device=paulis.device)
    left_ids = torch.eye(2**num_left, dtype=CTYPE, device=paulis.device)

    # paulis[1:] are X, Y, Z; cast the stacked result to CTYPE so the output
    # dtype does not depend on the dtype of `paulis`.
    gens = torch.stack([torch.kron(torch.kron(left_ids, pauli), right_ids)
                        for pauli in paulis[1:]])
    return gens.to(CTYPE)
