import torch
import torch.nn as nn

from .batch_micro_gate_ensemble_rho import BatchMicroGateEnsembleRho

class RhoDARTSMicroBatched(nn.Module):
    def __init__(self, total_qubits:int, num_subcircuit_qubits:int,
                 num_layers:int, super_circuit_structure:torch.Tensor,
                 psi0:torch.Tensor|None=None,
                 thread_block_size:int=16):
        super().__init__()
        assert total_qubits >= 0
        self.total_qubits = total_qubits
        
        assert (num_subcircuit_qubits <= total_qubits 
                and num_subcircuit_qubits > 0)
        self.num_subcircuit_qubits = num_subcircuit_qubits
        
        assert num_layers > 0
        self.num_layers = num_layers
        
        assert thread_block_size > 0
        self.thread_block_size = thread_block_size

        if super_circuit_structure.dim() != 2:
            raise ValueError('`super_circuit_structure` must be a 2-dimensional'
                             ' tensor.')
        if super_circuit_structure.shape[1] != num_subcircuit_qubits:
            raise ValueError('Second dimension of `super_circuit_structure` '
            'must be equal to the number of qubits in the subcircuit. Expected '
            f'{num_subcircuit_qubits}, recieved '
            f'{super_circuit_structure.shape[1]}.')
        if not (  (super_circuit_structure < total_qubits) 
                & (super_circuit_structure >= 0) ).all(): 
            raise ValueError(
            '`super_circuit_structure` must contain qubit indices '
            f' 0 <= index < `total_qubits`={total_qubits}.'
            )
        for support in super_circuit_structure:
            if support.unique().shape[0] < num_subcircuit_qubits:
                raise ValueError(
                    f'Invalid qubit support: {support} in '
                    '`super_circuit_structure`, contains duplicate qubit '
                    'indices.'
                )
        self.register_buffer('super_circuit_structure', super_circuit_structure)
        self.num_subcircuits = super_circuit_structure.shape[0]

        if psi0 is not None:
            if psi0.shape != (2**self.total_qubits,):
                raise ValueError('`psi0` must have shape [2^Q] for Q qubits')
            rho = torch.outer(psi0, psi0.conj())
        else:
            rho = torch.zeros([2**self.total_qubits,2**self.total_qubits], 
                              dtype=torch.complex64)
            rho[0,0] = 1.0
        self.register_buffer('rho0', rho)
        
    
    def forward(self, logits:torch.Tensor, angles:torch.Tensor, 
                softmax_temperature:float=1.0, skipValidation:bool=True,
                psi0:torch.Tensor|None=None
                )->torch.Tensor:
        if psi0 is not None:
            assert psi0.dim() == 2, '`psi0` must be a batch of state vectors'
            if psi0.size(1) != 2**self.total_qubits:
                raise ValueError('Statevector size of `psi0` does not match qubit count.')
            batch_size = psi0.size(0)
            rho = torch.vmap(torch.outer)(psi0,psi0.conj()).contiguous()
        else:
            batch_size = 1
            rho = self.rho0.unsqueeze(0).expand(batch_size,-1,-1).contiguous()
        
        if not skipValidation:
            if logits.shape != (self.num_layers, self.num_subcircuit_qubits,
                                self.num_subcircuit_qubits+3):
                raise ValueError(
                    '`logits` must have shape [L, Q, Q+3] for '
                    'L layers and Q qubits in the subcircuit.\n'
                    'Expected '
                    f'({self.num_layers}, {self.num_subcircuit_qubits}, '
                    f'{self.num_subcircuit_qubits+3}), recieved: '
                    f'{logits.shape}.')
            if angles.shape != (self.num_subcircuits, self.num_layers,
                                self.num_subcircuit_qubits):
                raise ValueError(
                    '`angles` must have shape [M, L, Q] for M subcircuits,'
                    ' L layers, and Q qubits in the subcircuit.\nExpected '
                    f'(B,{self.num_subcircuits}, {self.num_layers}, '
                    f'{self.num_subcircuit_qubits}), recieved: '
                    f'{logits.shape}.'
                )
            assert softmax_temperature > 0.0
        
        h = torch.softmax(logits/softmax_temperature,dim=-1)

        for i in range(self.num_subcircuits):
            for layer in range(self.num_layers):
                for target_qubit_id in range(self.num_subcircuit_qubits):
                    rho = torch.sum(
                        h[layer, target_qubit_id,:,None,None]
                        * BatchMicroGateEnsembleRho.apply(
                            rho, angles[i,layer,target_qubit_id], 
                            self.total_qubits, self.super_circuit_structure[i],
                            target_qubit_id, self.thread_block_size
                        ),
                        dim=1
                    )
        return rho
