import torch

from typing import Literal

from ..penalty_scheduler import oscillating_scheduler

def GHZ_State(num_qubits:int, device:torch.device|str|None=None)->torch.Tensor:
    '''Returns the GHZ state of the desired number of qubits on the specified 
    device'''
    assert num_qubits > 0
    psi = torch.zeros(2**num_qubits,dtype=torch.complex64,device=device)
    psi[0] = 1.0
    psi[-1] = 1.0
    return psi / torch.linalg.norm(psi)

def W_State(num_qubits:int, device:torch.device|str|None=None)->torch.Tensor:
    '''Return the W state of the desired number of qubits on the specified 
    device'''
    assert num_qubits > 1
    psi = torch.zeros(2**num_qubits,dtype=torch.complex64,device=device)
    for i in range(num_qubits):
        psi[2**i] = 1.0
    return psi/torch.linalg.norm(psi)

def getEntropySchedule(N:int, start_val:float, start_time:int,
                       duration:Literal['full','half'], 
                       num_oscillations:int)->torch.Tensor:
    '''Build a length-``N`` time axis and pass it to ``oscillating_scheduler``.

    The time axis has three consecutive segments:

    * ``start_time`` zeros,
    * a ``torch.linspace`` ramp from 0.0 to 1.0,
    * ones for whatever is left of the ``N`` steps.

    The ramp covers all ``N - start_time`` remaining steps, or half of them
    (integer division) when ``duration`` is ``'half'``.

    Args:
        N: Total number of steps in the schedule.
        start_val: Passed to ``oscillating_scheduler`` as its second
            argument; must be below 1.0.
        start_time: Number of leading steps held at zero; must satisfy
            ``0 <= start_time < N``.
        duration: ``'half'`` halves the ramp length; any other value leaves
            the ramp spanning all remaining steps.
        num_oscillations: Passed to ``oscillating_scheduler`` as its fourth
            argument; must be positive.

    Returns:
        Whatever ``oscillating_scheduler(t, start_val, 1.0, num_oscillations)``
        returns for the time axis ``t`` described above.
        NOTE(review): presumably one schedule value per step -- confirm
        against ``oscillating_scheduler``.
    '''
    assert start_val < 1.0
    assert 0 <= start_time < N
    assert num_oscillations > 0

    remaining = N - start_time
    ramp_len = remaining // 2 if duration == 'half' else remaining
    # Steps left after the ramp are held at 1.0.
    tail_len = N - (start_time + ramp_len)

    segments = (
        torch.zeros(start_time),
        torch.linspace(0.0, 1.0, ramp_len),
        torch.ones(tail_len),
    )
    t = torch.cat(segments)
    return oscillating_scheduler(t, start_val, 1.0, num_oscillations)
