import torch

from qtorch.quantumstate.fidelity import fidelity

from ..metrics import anglePenalty, normalized_mean_entropy

def loss(qs: torch.Tensor,
         angles: torch.Tensor,
         probs: torch.Tensor,
         ref_state: torch.Tensor,
         iteration: int,
         entropy_schedule: torch.Tensor,
         entropy_penalty_str: float,
         angle_penalty_str: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Combine infidelity, a scheduled entropy term and an angle penalty.

    The objective is::

        (1 - fid)
        + entropy_schedule[iteration] * entropy * entropy_penalty_str
        + angle_penalty * angle_penalty_str

    where ``fid`` is ``fidelity`` mapped over dim 0 of ``qs`` against
    ``ref_state`` (first mapped entry only), ``entropy`` is
    ``normalized_mean_entropy(probs)`` and ``angle_penalty`` is
    ``anglePenalty(angles)``.

    Args:
        qs: States compared against ``ref_state``; mapped over dim 0.
        angles: Input to ``anglePenalty``.
        probs: Input to ``normalized_mean_entropy``.
        ref_state: Reference state, shared (not mapped) across dim 0 of ``qs``.
        iteration: Index into ``entropy_schedule`` selecting this step's weight.
        entropy_schedule: Per-iteration multipliers for the entropy term.
        entropy_penalty_str: Constant strength of the entropy term.
        angle_penalty_str: Constant strength of the angle term.

    Returns:
        A tuple ``(loss_val, metrics)`` where ``metrics`` is
        ``torch.stack([fid, entropy, angle_penalty], dim=0)``.
    """
    # in_dims=(0, None): map over dim 0 of qs, pass ref_state unmapped.
    # NOTE(review): only entry [0] of the mapped output is kept, so any further
    # entries of qs are computed and discarded — confirm qs is expected to hold
    # a single state along dim 0 (or whether a reduction such as mean was meant).
    fid = torch.vmap(fidelity, (0, None))(qs, ref_state)[0]
    angle_penalty = anglePenalty(angles)
    entropy = normalized_mean_entropy(probs)

    loss_val = (
        (1.0 - fid)
        + (
            entropy_schedule[iteration]
            * entropy
            * entropy_penalty_str
        )
        + (angle_penalty * angle_penalty_str)
    )

    # torch.stack requires fid, entropy and angle_penalty to share one shape.
    return loss_val, torch.stack([fid, entropy, angle_penalty], dim=0)

def qdarts_angle_loss(qs: torch.Tensor,
                      angles: torch.Tensor,
                      ref_state: torch.Tensor,
                      angle_penalty_str: float) -> torch.Tensor:
    """Return infidelity plus a weighted angle penalty.

    Computes ``(1 - fid) + angle_penalty_str * anglePenalty(angles)``. Here
    ``fid`` is the first entry of ``fidelity`` mapped over dim 0 of ``qs``
    against an unmapped ``ref_state``.

    Args:
        qs: States compared against ``ref_state``; mapped over dim 0.
        angles: Input to ``anglePenalty``.
        ref_state: Reference state, shared across dim 0 of ``qs``.
        angle_penalty_str: Constant strength of the angle term.

    Returns:
        The scalar-valued loss tensor (no metrics are returned here).
    """
    batched_fidelity = torch.vmap(fidelity, in_dims=(0, None))
    # NOTE(review): only entry [0] of the mapped fidelities is used.
    first_fid = batched_fidelity(qs, ref_state)[0]
    penalty = anglePenalty(angles)

    infidelity = 1.0 - first_fid
    return infidelity + angle_penalty_str * penalty

def qdarts_arch_loss(qs: torch.Tensor,
                     probs: torch.Tensor,
                     ref_state: torch.Tensor,
                     iteration: int,
                     entropy_schedule: torch.Tensor,
                     entropy_penalty_str: float,
                     ) -> tuple[torch.Tensor, torch.Tensor]:
    """Return infidelity plus a scheduled entropy term, along with metrics.

    The objective is
    ``(1 - fid) + entropy_schedule[iteration] * entropy * entropy_penalty_str``.
    Here ``fid`` is the first entry of ``fidelity`` mapped over dim 0 of
    ``qs`` against an unmapped ``ref_state``, and ``entropy`` is
    ``normalized_mean_entropy(probs)``.

    Args:
        qs: States compared against ``ref_state``; mapped over dim 0.
        probs: Input to ``normalized_mean_entropy``.
        ref_state: Reference state, shared across dim 0 of ``qs``.
        iteration: Index into ``entropy_schedule`` selecting this step's weight.
        entropy_schedule: Per-iteration multipliers for the entropy term.
        entropy_penalty_str: Constant strength of the entropy term.

    Returns:
        A tuple ``(loss_val, metrics)`` with ``metrics`` being ``fid`` and
        ``entropy`` stacked along a new dim 0.
    """
    mapped_fids = torch.vmap(fidelity, in_dims=(0, None))(qs, ref_state)
    # NOTE(review): only entry [0] of the mapped fidelities is used.
    fid = mapped_fids[0]
    entropy = normalized_mean_entropy(probs)

    infidelity = 1.0 - fid
    entropy_term = entropy_schedule[iteration] * entropy * entropy_penalty_str
    loss_val = infidelity + entropy_term

    metrics = torch.stack((fid, entropy), dim=0)
    return loss_val, metrics