import torch
import torch.nn as nn

from typing import Callable

from .batch_micro_gate_ensemble_psi import BatchMicroGateEnsemblePsi

class QDARTSMicroBatched(nn.Module):
    def __init__(self, total_qubits:int, num_subcircuit_qubits:int,
                 num_layers:int, super_circuit_structure:torch.Tensor,
                 gumbel_temp:float=0.5,gumbel_hard:bool=True,
                 psi0:torch.Tensor|None=None, thread_block_size:int=16):
        super().__init__()
        assert total_qubits >= 0
        self.total_qubits = total_qubits
        
        assert (num_subcircuit_qubits <= total_qubits 
                and num_subcircuit_qubits > 0)
        self.num_subcircuit_qubits = num_subcircuit_qubits
        
        assert num_layers > 0
        self.num_layers = num_layers
        
        assert thread_block_size > 0
        self.thread_block_size = thread_block_size

        assert gumbel_temp > 0.0
        self.gumbel_temp = gumbel_temp
        self.gumbel_hard = gumbel_hard

        if super_circuit_structure.dim() != 2:
            raise ValueError('`super_circuit_structure` must be a 2-dimensional'
                             ' tensor.')
        if super_circuit_structure.shape[1] != num_subcircuit_qubits:
            raise ValueError('Second dimension of `super_circuit_structure` '
            'must be equal to the number of qubits in the subcircuit. Expected '
            f'{num_subcircuit_qubits}, recieved '
            f'{super_circuit_structure.shape[1]}.')
        if not (  (super_circuit_structure < total_qubits) 
                & (super_circuit_structure >= 0) ).all(): 
            raise ValueError(
            '`super_circuit_structure` must contain qubit indices '
            f' 0 <= index < `total_qubits`={total_qubits}.'
            )
        for support in super_circuit_structure:
            if support.unique().shape[0] < num_subcircuit_qubits:
                raise ValueError(
                    f'Invalid qubit support: {support} in '
                    '`super_circuit_structure`, contains duplicate qubit '
                    'indices.'
                )
        self.register_buffer('super_circuit_structure', super_circuit_structure)
        self.num_subcircuits = super_circuit_structure.shape[0]

        if psi0 is not None:
            if psi0.shape != (2**self.total_qubits,):
                raise ValueError('`psi0` must have shape [2^Q] for Q qubits')
        else:
            psi0 = torch.zeros([2**self.total_qubits], dtype=torch.complex64)
            psi0[0] = 1.0
        self.register_buffer('psi0', psi0)
    
    def forward(self, logits:torch.Tensor, angles:torch.Tensor,
                angles_optimizer:torch.optim.Optimizer, num_iter:int,
                angle_loss_fn:Callable[[torch.Tensor,torch.Tensor],torch.Tensor],
                softmax_temperature:float=1.0, skipValidation:bool=True,
                psi0:torch.Tensor|None = None
                )->torch.Tensor:
        
        if psi0 is not None:
            assert psi0.dim() == 2, '`psi0` must be a batch of state vectors.'
            if psi0.size(1) != 2**self.total_qubits:
                raise ValueError('Statevector size of `psi0` does not match qubit count.')
            batch_size = psi0.size(0)
        else:
            batch_size = 1

        if not skipValidation:
            if logits.shape != (self.num_layers, self.num_subcircuit_qubits,
                                self.num_subcircuit_qubits+3):
                raise ValueError(
                    '`logits` must have shape [L, Q, Q+3] for '
                    ' L layers and Q qubits in the subcircuit.\nExpected '
                    f'({self.num_layers}, {self.num_subcircuit_qubits}, '
                    f'{self.num_subcircuit_qubits+3}), recieved: '
                    f'{logits.shape}.')
            if angles.shape != (self.num_subcircuits, self.num_layers,
                                self.num_subcircuit_qubits):
                raise ValueError(
                    '`angles` must have shape [M, L, Q] for M subcircuits,'
                    ' L layers, and Q qubits in the subcircuit.\nExpected '
                    f'({self.num_subcircuits}, {self.num_layers}, '
                    f'{self.num_subcircuit_qubits}), recieved: '
                    f'{logits.shape}.'
                )
            assert softmax_temperature > 0.0

        h = nn.functional.gumbel_softmax(logits/softmax_temperature, 
                                         self.gumbel_temp, self.gumbel_hard)
        def qsim(h:torch.Tensor, angles:torch.Tensor):
            if psi0 is None:
                psi = self.psi0.unsqueeze(0).expand(batch_size,-1).contiguous()
            else:
                psi = psi0.clone()
            
            for i in range(self.num_subcircuits):
                for layer in range(self.num_layers):
                    for target_qubit_id in range(self.num_subcircuit_qubits):
                        psi = torch.sum(
                            h[layer, target_qubit_id,:,None]
                            * BatchMicroGateEnsemblePsi.apply(
                                psi, angles[i,layer,target_qubit_id], 
                                self.total_qubits, self.super_circuit_structure[i],
                                target_qubit_id, self.thread_block_size
                            ),
                            dim=1
                        )
            return psi

        for iter in range(num_iter):
            angles_optimizer.zero_grad()
            psi = qsim(h.detach(), angles)
            angle_loss = angle_loss_fn(psi, angles)
            angle_loss.backward()
            angles_optimizer.step()
        
        psi = qsim(h, angles.detach())
        return psi
