import torch


from qtorch.quantumstate.measurements import getExpectation

from ..metrics import anglePenalty,normalized_mean_entropy

def loss(qs: torch.Tensor,
         angles: torch.Tensor,
         probs: torch.Tensor,
         hamiltonian: torch.Tensor,
         energy_normalization_factor: float,
         iteration: int,
         entropy_schedule: torch.Tensor,
         entropy_penalty_str: float,
         angle_penalty_str: float,
         batched: bool = True) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute the combined energy / entropy / angle-penalty loss.

    The returned loss is::

        energy / energy_normalization_factor
        + entropy_schedule[iteration] * entropy * entropy_penalty_str
        + angle_penalty * angle_penalty_str

    Args:
        qs: State(s) passed to ``getExpectation``. When ``batched`` is true,
            its leading dimension is mapped over with ``torch.vmap``.
        angles: Angles forwarded to ``anglePenalty``.
        probs: Probabilities forwarded to ``normalized_mean_entropy``.
        hamiltonian: Hamiltonian; not mapped over, so it is shared by every
            entry of ``qs`` in the batched path.
        energy_normalization_factor: Divisor applied to the energy term.
        iteration: Index into ``entropy_schedule`` selecting the current
            entropy weight.
        entropy_schedule: Per-iteration weights for the entropy term.
        entropy_penalty_str: Constant multiplier for the entropy term.
        angle_penalty_str: Constant multiplier for the angle penalty.
        batched: Whether to evaluate ``getExpectation`` through ``torch.vmap``.

    Returns:
        A tuple ``(loss_val, components)``, where ``components`` is
        ``torch.stack([energy, entropy, angle_penalty], dim=0)``, i.e. the
        unweighted terms that make up the loss.
    """
    if batched:
        # NOTE(review): ``[0]`` keeps only the first entry of the vmapped
        # result; confirm this is intended when ``qs`` holds more than one state.
        energy = torch.vmap(getExpectation, (0, None))(qs, hamiltonian)[0]
    else:
        energy = getExpectation(qs, hamiltonian)
    angle_penalty = anglePenalty(angles)
    entropy = normalized_mean_entropy(probs)

    loss_val = (
        (energy / energy_normalization_factor)
        + (entropy_schedule[iteration] * entropy * entropy_penalty_str)
        + (angle_penalty * angle_penalty_str)
    )

    # torch.stack requires energy, entropy and angle_penalty to share a shape.
    return loss_val, torch.stack([energy, entropy, angle_penalty], dim=0)

def batch_loss(batch_size: int,
               qs: torch.Tensor,
               angles: torch.Tensor,
               probs: torch.Tensor,
               hamiltonians: torch.Tensor,
               max_cut_values: torch.Tensor,
               iteration: int,
               entropy_schedule: torch.Tensor,
               entropy_penalty_str: float,
               angle_penalty_str: float,
               mse_flag: bool
               ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the loss over a batch of problem instances.

    Dimension symbols:
        B: batch size
        Q: number of qubits
        E: length of ``entropy_schedule`` (indexed by ``iteration``)
        L: leading dimension of ``angles`` / ``probs`` (presumably the number
           of circuit layers -- TODO confirm)

    Expected tensor shapes:
        qs:               [B, 2**Q] or [B, 2**Q, 2**Q]
        angles:           [L, Q, B]
        probs:            [L, Q, Q+3]
        hamiltonians:     [B, 2**Q]
        max_cut_values:   [B] (integer)
        entropy_schedule: [E]

    The energy term depends on ``mse_flag``:
        * False: ``sum(energies / max_cut_values) / batch_size``.
        * True:  ``mse_loss(energies, -max_cut_values)``, with the targets
          cast to the dtype of ``energies``.

    Two penalties are added to it:
    ``entropy_penalty_str * entropy_schedule[iteration] * entropy`` and
    ``angle_penalty_str * angle_penalty / batch_size``.

    Returns:
        ``(loss_val, energies, entropy, angle_penalty)``. Here ``energies`` holds
        one expectation value per batch entry (both ``qs`` and
        ``hamiltonians`` are mapped over dim 0). The two penalties are returned
        unweighted.
    """
    per_instance_expectation = torch.vmap(getExpectation, in_dims=(0, 0), out_dims=0)
    energies = per_instance_expectation(qs, hamiltonians)
    entropy = normalized_mean_entropy(probs)
    angle_penalty = anglePenalty(angles)

    if mse_flag:
        target = -max_cut_values.to(energies.dtype)
        energy_term = torch.nn.functional.mse_loss(energies, target)
    else:
        energy_term = (energies / max_cut_values).sum() / batch_size

    entropy_term = entropy_penalty_str * entropy_schedule[iteration] * entropy
    angle_term = angle_penalty_str * angle_penalty / batch_size

    loss_val = energy_term + entropy_term + angle_term
    return loss_val, energies, entropy, angle_penalty

def qdarts_arch_loss(qs: torch.Tensor,
                     probs: torch.Tensor,
                     hamiltonian: torch.Tensor,
                     energy_normalization_factor: float,
                     iteration: int,
                     entropy_schedule: torch.Tensor,
                     entropy_penalty_str: float,
                     batched: bool = True
                     ) -> tuple[torch.Tensor, torch.Tensor]:
    """Loss made of the normalized energy plus a scheduled entropy penalty.

    Computes::

        energy / energy_normalization_factor
        + entropy_schedule[iteration] * entropy * entropy_penalty_str

    Unlike ``loss``, no angle penalty is included.

    Args:
        qs: State(s) passed to ``getExpectation``; vmapped over dim 0 when
            ``batched`` is true.
        probs: Probabilities forwarded to ``normalized_mean_entropy``.
        hamiltonian: Hamiltonian shared by every mapped state (not vmapped).
        energy_normalization_factor: Divisor applied to the energy.
        iteration: Index into ``entropy_schedule``.
        entropy_schedule: Per-iteration entropy weights.
        entropy_penalty_str: Constant multiplier for the entropy term.
        batched: Whether to evaluate ``getExpectation`` through ``torch.vmap``.

    Returns:
        ``(loss_val, torch.stack([energy, entropy], dim=0))``.
    """
    # NOTE(review): the batched path keeps only element [0] of the vmapped
    # result -- confirm this is intended for multi-state batches.
    energy = (
        torch.vmap(getExpectation, (0, None))(qs, hamiltonian)[0]
        if batched
        else getExpectation(qs, hamiltonian)
    )
    entropy = normalized_mean_entropy(probs)

    weighted_entropy = entropy_schedule[iteration] * entropy * entropy_penalty_str
    loss_val = energy / energy_normalization_factor + weighted_entropy

    return loss_val, torch.stack([energy, entropy], dim=0)

def qdarts_angle_loss(qs: torch.Tensor,
                      angles: torch.Tensor,
                      hamiltonian: torch.Tensor,
                      energy_normalization_factor: float,
                      angle_penalty_str: float,
                      batched: bool = True
                      ) -> torch.Tensor:
    """Loss made of the normalized energy plus a weighted angle penalty.

    Computes::

        energy / energy_normalization_factor + anglePenalty(angles) * angle_penalty_str

    No entropy term is included.

    Args:
        qs: State(s) passed to ``getExpectation``; vmapped over dim 0 when
            ``batched`` is true.
        angles: Angles forwarded to ``anglePenalty``.
        hamiltonian: Hamiltonian shared by every mapped state (not vmapped).
        energy_normalization_factor: Divisor applied to the energy.
        angle_penalty_str: Constant multiplier for the angle penalty.
        batched: Whether to evaluate ``getExpectation`` through ``torch.vmap``.

    Returns:
        The loss tensor only; the individual components are not returned.
    """
    if not batched:
        energy = getExpectation(qs, hamiltonian)
    else:
        mapped_expectation = torch.vmap(getExpectation, in_dims=(0, None))
        # NOTE(review): only element [0] of the vmapped result is kept.
        energy = mapped_expectation(qs, hamiltonian)[0]

    normalized_energy = energy / energy_normalization_factor
    weighted_angle_penalty = anglePenalty(angles) * angle_penalty_str
    return normalized_energy + weighted_angle_penalty