import torch

from functools import lru_cache

from qtorch import CTYPE
from qtorch.unitaries.standard_gate_matrices import PAULIS

def validateInput(qs:torch.Tensor, total_qubits:int, target_qubits:int|list[int], 
                  control_qubits:int|list[int]|None=None,
                  num_targets:int|None=None,
                  num_controls:int|None=None)->None:
    assert qs.dim() < 3 and qs.dim() > 0, '`qs` must be a 1 or 2 dimensional tensor!'
    assert total_qubits > 0, '`total_qubits` must be greater than 0'

    if isinstance(target_qubits,int):
        target_qubits = [target_qubits]
    for tq in target_qubits:
        assert tq < total_qubits and tq >= 0, f'Invalid target qubit index ({tq}) for total qubits {total_qubits}'
    if num_targets is not None:
        assert len(target_qubits) == num_targets, f'Expected {num_targets} target qubits, recieved {len(target_qubits)}'
    
    if control_qubits is None:
        control_qubits = []
    if isinstance(control_qubits,int):
        control_qubits = [control_qubits]
    for cq in control_qubits:
        assert cq >= 0 and cq < total_qubits, f'Invalid control qubit index ({cq}) for total qubits {total_qubits}'
        assert cq not in target_qubits, f'Control qubit ({cq}) cannot also be a target qubit!'
    if num_controls is not None:
        assert len(control_qubits) == num_controls, f'Expected {num_controls} target qubits, recieved {len(control_qubits)}'

    if qs.dim() == 1:
        assert qs.shape[0] == 2**total_qubits, f'Expected statevector of size {2**total_qubits}, recieved {qs.shape}'
    elif qs.dim() == 2:
        assert qs.shape[0] == qs.shape[1], 'Density Matrix must be a square matrix'
        assert qs.shape[0] == 2**total_qubits, f'Expected density matrix of size {(2**total_qubits,2**total_qubits)}, recieved {qs.shape}'

@lru_cache
def getFullMatrix(total_qubits:int, gateMatrix:torch.Tensor, target_qubits:int|tuple[int], control_qubits:int|tuple[int]|None=None)->torch.Tensor:
    '''
    Returns the full matrix of the gate acting on a system of the specified
    total number of qubits given the target and control qubit indices.

    Qubit ``j`` is the ``j``-th Kronecker factor counted from the right, i.e.
    qubit 0 is the least-significant index of the returned matrix.

    Args:
        total_qubits: number of qubits in the full system.
        gateMatrix: square ``2**k x 2**k`` gate for ``k`` target qubits
            (only ``k == 1`` is currently supported).
        target_qubits: index (or tuple of indices) the gate acts on.
        control_qubits: optional index (or tuple of indices) of control
            qubits. The gate is applied only on the subspace where every
            control qubit is |1>; elsewhere the target sees the identity.

    Returns:
        A ``2**total_qubits x 2**total_qubits`` tensor on ``gateMatrix.device``.

    Raises:
        AssertionError: on out-of-range, duplicate or overlapping qubit
            indices, or a gate whose shape does not match the target count.
        NotImplementedError: if more than one target qubit is given.

    Note:
        Results are memoised with ``functools.lru_cache``, so every argument
        must be hashable (pass tuples, not lists). ``torch.Tensor`` hashes by
        object identity, so modifying ``gateMatrix`` in place after a call does
        NOT invalidate the cached result. The returned tensor is shared between
        callers and must not be modified in place.
    '''
    if isinstance(target_qubits, int):
        target_qubits = (target_qubits,)
    if control_qubits is None:
        control_qubits = ()
    if isinstance(control_qubits, int):
        control_qubits = (control_qubits,)

    for t in target_qubits:
        assert 0 <= t < total_qubits
    assert len(set(target_qubits)) == len(target_qubits)
    for c in control_qubits:
        assert c not in target_qubits
        assert 0 <= c < total_qubits
    # A repeated control index would be counted twice in the number of
    # summed terms below but only contribute one projector, which
    # double-counts terms and produces a wrong matrix.
    assert len(set(control_qubits)) == len(control_qubits)

    assert gateMatrix.dim() == 2
    assert gateMatrix.shape[0] == gateMatrix.shape[1]
    assert gateMatrix.shape[0] == 2**len(target_qubits)

    if len(target_qubits) > 1:
        raise NotImplementedError('Does not support porting multi-qubit gates')

    device = gateMatrix.device
    target = target_qubits[0]

    if not control_qubits:
        # I (left qubits) (x) U (x) I (right qubits). torch.eye(1) is [[1]],
        # which is neutral under kron, so no special case is needed for 0.
        num_right = target
        num_left = total_qubits - target - 1
        right_ids = torch.eye(2**num_right, dtype=CTYPE, device=device)
        left_ids = torch.eye(2**num_left, dtype=CTYPE, device=device)
        return torch.kron(left_ids, torch.kron(gateMatrix, right_ids))

    identity = torch.eye(2, dtype=CTYPE, device=device)
    proj0 = torch.tensor([[1, 0], [0, 0]], dtype=CTYPE, device=device)
    proj1 = torch.tensor([[0, 0], [0, 1]], dtype=CTYPE, device=device)
    # Seed for the Kronecker chain; must live on the same device as the
    # factors, otherwise torch.kron fails for non-CPU gates.
    one = torch.tensor(1.0+0.0j, dtype=CTYPE, device=device)
    control_set = set(control_qubits)
    num_terms = 2**len(control_qubits)

    fullMatrix = torch.zeros([2**total_qubits, 2**total_qubits], dtype=CTYPE, device=device)
    # Sum over every basis assignment of the control qubits: bit b of
    # `assignment` selects |0><0| or |1><1| for the b-th control (controls are
    # numbered in increasing qubit order). The gate is placed on the target
    # only for the all-ones assignment; every other term uses identity there.
    for assignment in range(num_terms):
        term = one
        c_index = 0
        for j in range(total_qubits):
            if j == target:
                factor = gateMatrix if assignment == num_terms - 1 else identity
            elif j in control_set:
                factor = proj1 if assignment & (1 << c_index) else proj0
                c_index += 1
            else:
                factor = identity
            term = torch.kron(factor, term)
        fullMatrix += term

    return fullMatrix

@lru_cache
def _pauliStringMatrixCached(total_qubits:int, pauli_string:str, target_qubits:tuple[int, ...], pauli_matrices:torch.Tensor)->torch.Tensor:
    '''Memoised worker for getPauliStringMatrix; all arguments must be hashable.'''
    assert len(pauli_string) == len(target_qubits)
    assert len(target_qubits) <= total_qubits
    assert len(set(target_qubits)) == len(target_qubits)
    for t in target_qubits:
        assert 0 <= t < total_qubits
    for p in pauli_string:
        assert p in 'IXYZ'

    # Position of each label along the first axis of `pauli_matrices`.
    # Assumes the stack is ordered I, X, Y, Z (e.g. PAULIS); verify for
    # custom inputs.
    label_index = {p: i for i, p in enumerate('IXYZ')}
    qubit_label = dict(zip(target_qubits, pauli_string))

    mat = torch.tensor(1.0+0.j, device=pauli_matrices.device)
    for q in range(total_qubits):
        # Untargeted qubits get the identity (label 'I', index 0).
        label = qubit_label.get(q, 'I')
        # kron(new, mat) places qubit q to the left, so qubit 0 ends up as
        # the right-most (least-significant) factor.
        mat = torch.kron(pauli_matrices[label_index[label]], mat)
    return mat

def getPauliStringMatrix(total_qubits:int, pauli_string:str, target_qubits:list[int]|tuple[int, ...], pauli_matrices:torch.Tensor=PAULIS)->torch.Tensor:
    '''
    Return the full ``2**total_qubits``-dimensional matrix of a Pauli string.

    Args:
        total_qubits: number of qubits in the system.
        pauli_string: one of ``'I'``, ``'X'``, ``'Y'``, ``'Z'`` per target qubit.
        target_qubits: distinct qubit indices, one per character of
            ``pauli_string`` (``pauli_string[k]`` acts on ``target_qubits[k]``).
            Lists are accepted and converted to tuples for caching.
        pauli_matrices: stacked single-qubit matrices indexed I, X, Y, Z.

    Returns:
        Kronecker product with identity on every qubit not in ``target_qubits``.

    Raises:
        AssertionError: on mismatched lengths, duplicate or out-of-range
            indices, or characters outside ``'IXYZ'``.

    Results are memoised (see ``cache_info`` / ``cache_clear``). The returned
    tensor is shared between callers and must not be modified in place.
    '''
    return _pauliStringMatrixCached(total_qubits, pauli_string, tuple(target_qubits), pauli_matrices)

# Keep the lru_cache management API that callers of the previously decorated
# function may rely on.
getPauliStringMatrix.cache_info = _pauliStringMatrixCached.cache_info
getPauliStringMatrix.cache_clear = _pauliStringMatrixCached.cache_clear

def applyGate(quantumstate:torch.Tensor, gateMatrix:torch.Tensor, total_qubits:int, target_qubits:int|list[int], control_qubits:int|list[int]|None=None)->torch.Tensor:
    '''
    Apply a (possibly controlled) gate to a statevector or density matrix.

    Args:
        quantumstate: 1-D statevector ``psi`` or 2-D density matrix ``rho`` of
            size ``2**total_qubits`` along every axis.
        gateMatrix: single-qubit gate matrix, passed to ``getFullMatrix``.
        total_qubits: number of qubits in the system.
        target_qubits: target index, or list/tuple of indices.
        control_qubits: optional control index, or list/tuple of indices.

    Returns:
        ``U @ psi`` for a statevector, or ``U @ rho @ U^dagger`` for a density
        matrix, where ``U`` is the full-system matrix from ``getFullMatrix``.

    Raises:
        ValueError: if ``quantumstate`` is not 1- or 2-dimensional.
        AssertionError: if the state size does not match ``total_qubits``.
    '''
    # Check rank before indexing shape[0], which fails on a 0-dim tensor.
    if quantumstate.dim() not in (1, 2):
        raise ValueError('`quantumstate` must be a 1 or 2 dimensional tensor')
    assert all(s == 2**total_qubits for s in quantumstate.shape)

    # getFullMatrix is wrapped in functools.lru_cache, so every argument must
    # be hashable; lists would raise TypeError, hence the tuple conversion.
    if not isinstance(target_qubits, int):
        target_qubits = tuple(target_qubits)
    if control_qubits is not None and not isinstance(control_qubits, int):
        control_qubits = tuple(control_qubits)

    fullMatrix = getFullMatrix(total_qubits, gateMatrix, target_qubits, control_qubits)
    if quantumstate.dim() == 1:
        return fullMatrix @ quantumstate
    return fullMatrix @ quantumstate @ fullMatrix.mH
