import torch


from qtorch.quantumstate.measurements import getExpectation

from examples.qas.metrics import anglePenalty,normalized_mean_entropy


def loss(qs: torch.Tensor,
         angles: torch.Tensor,
         probs: torch.Tensor,
         hamiltonian: torch.Tensor,
         energy_shift: float,
         iteration: int,
         entropy_schedule: torch.Tensor,
         entropy_penalty_str: float,
         angle_penalty_str: float,
         batched: bool = True) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute the combined search loss together with its individual components.

    The scalar objective is::

        (energy - energy_shift)
        + entropy_schedule[iteration] * entropy * entropy_penalty_str
        + angle_penalty * angle_penalty_str

    Args:
        qs: State(s) passed to ``getExpectation``. When ``batched`` is True,
            ``getExpectation`` is vmapped over the leading dimension of ``qs``.
        angles: Passed to ``anglePenalty`` to obtain the angle penalty term.
        probs: Passed to ``normalized_mean_entropy`` to obtain the entropy term.
        hamiltonian: Observable for ``getExpectation``; shared (not vmapped)
            across the batch.
        energy_shift: Constant subtracted from the energy.
        iteration: Index into ``entropy_schedule`` selecting the current
            entropy weight.
        entropy_schedule: Indexable sequence of entropy weights, one per
            iteration.
        entropy_penalty_str: Additional scale applied to the entropy term.
        angle_penalty_str: Scale applied to the angle penalty term.
        batched: Whether ``qs`` carries a leading batch dimension.

    Returns:
        A ``(loss_val, components)`` tuple, where ``components`` is
        ``torch.stack([energy, entropy, angle_penalty], dim=0)``.
        ``torch.stack`` requires the three terms to share a shape
        (presumably scalars — confirm against the helpers).
    """
    if batched:
        # NOTE(review): `[0]` keeps only the first entry of the vmapped output.
        # If getExpectation returns a single tensor, this is the first batch
        # element's energy rather than an aggregate over the batch — confirm
        # this is intended.
        energy = torch.vmap(getExpectation, (0, None))(qs, hamiltonian)[0]
    else:
        energy = getExpectation(qs, hamiltonian)
    angle_penalty = anglePenalty(angles)

    entropy = normalized_mean_entropy(probs)
    loss_val = (
        (energy - energy_shift)
        + (
            entropy_schedule[iteration]
            * entropy
            * entropy_penalty_str
        )
        + (angle_penalty * angle_penalty_str)
    )

    return loss_val, torch.stack([energy, entropy, angle_penalty], dim=0)

def qdarts_arch_loss(qs: torch.Tensor,
                     probs: torch.Tensor,
                     hamiltonian: torch.Tensor,
                     energy_shift: float,
                     iteration: int,
                     entropy_schedule: torch.Tensor,
                     entropy_penalty_str: float,
                     batched: bool = True,
                     ) -> tuple[torch.Tensor, torch.Tensor]:
    """Loss for the architecture-parameter step: shifted energy plus scheduled entropy.

    The scalar objective is::

        (energy - energy_shift)
        + entropy_schedule[iteration] * entropy * entropy_penalty_str

    Args:
        qs: State(s) passed to ``getExpectation``. When ``batched`` is True,
            ``getExpectation`` is vmapped over the leading dimension of ``qs``.
        probs: Passed to ``normalized_mean_entropy`` to obtain the entropy term.
        hamiltonian: Observable for ``getExpectation``; shared across the batch.
        energy_shift: Constant subtracted from the energy.
        iteration: Index into ``entropy_schedule`` for the current weight.
        entropy_schedule: Indexable sequence of entropy weights.
        entropy_penalty_str: Additional scale applied to the entropy term.
        batched: Whether ``qs`` carries a leading batch dimension.

    Returns:
        ``(loss_val, torch.stack([energy, entropy], dim=0))``.
    """
    if not batched:
        energy = getExpectation(qs, hamiltonian)
    else:
        # Map over dim 0 of `qs`; the hamiltonian is broadcast unchanged.
        # NOTE(review): `[0]` keeps only the first entry of the vmapped output
        # — confirm whether this is meant to be the first batch element.
        expectation_over_batch = torch.vmap(getExpectation, in_dims=(0, None))
        energy = expectation_over_batch(qs, hamiltonian)[0]

    entropy = normalized_mean_entropy(probs)
    shifted_energy = energy - energy_shift
    entropy_term = entropy_schedule[iteration] * entropy * entropy_penalty_str

    diagnostics = torch.stack([energy, entropy], dim=0)
    return shifted_energy + entropy_term, diagnostics

def qdarts_angle_loss(qs: torch.Tensor,
                      angles: torch.Tensor,
                      hamiltonian: torch.Tensor,
                      energy_shift: float,
                      angle_penalty_str: float,
                      batched: bool = True,
                      ) -> torch.Tensor:
    """Loss for the angle-parameter step: shifted energy plus weighted angle penalty.

    The objective is ``(energy - energy_shift) + anglePenalty(angles) * angle_penalty_str``.

    Args:
        qs: State(s) passed to ``getExpectation``. When ``batched`` is True,
            ``getExpectation`` is vmapped over the leading dimension of ``qs``.
        angles: Passed to ``anglePenalty`` to obtain the penalty term.
        hamiltonian: Observable for ``getExpectation``; shared across the batch.
        energy_shift: Constant subtracted from the energy.
        angle_penalty_str: Scale applied to the angle penalty term.
        batched: Whether ``qs`` carries a leading batch dimension.

    Returns:
        The loss value only (no component breakdown).
    """
    # NOTE(review): in the batched path `[0]` keeps only the first entry of the
    # vmapped output — confirm whether this is meant to be the first batch element.
    energy = (
        torch.vmap(getExpectation, in_dims=(0, None))(qs, hamiltonian)[0]
        if batched
        else getExpectation(qs, hamiltonian)
    )
    weighted_penalty = anglePenalty(angles) * angle_penalty_str
    return (energy - energy_shift) + weighted_penalty