import torch

from typing import Literal

from ..penalty_scheduler import oscillating_scheduler

def generateRandomGraph(num_vertices:int, edge_probability:float)->torch.Tensor:
    '''
    Sample an Erdos-Renyi random graph G(n, p).

    Each unordered pair of distinct vertices is joined independently with
    probability ``edge_probability``. The graph is stored as a strictly
    upper-triangular adjacency matrix: entry (i, j) with i < j is 1 when the
    edge exists, and every entry on or below the diagonal is 0.

    Arguments
    ---------
    num_vertices: int
        The number of vertices in the graph
    edge_probability: float
        The probability that any given edge is present

    Returns
    -------
    torch.Tensor
        An int32 tensor of shape (num_vertices, num_vertices)
    '''
    # One uniform draw in [0, 1) per matrix cell. The diagonal and lower
    # triangle are discarded below, so each unordered pair is decided by
    # exactly one draw.
    uniform_draws = torch.rand((num_vertices, num_vertices))
    is_edge = (uniform_draws < edge_probability).int()
    return is_edge.triu(diagonal=1)

def getMaxCutHamiltonian(adj_matrix:torch.Tensor)->torch.Tensor:
    r'''
    Calculates the diagonal elements of the max-cut Hamiltonian for the given
    graph. It is assumed that the graph is undirected, unweighted and no edges
    start and end at the same vertex.

    $$H = - \sum_{(i,j) \in E}\frac{( I - Z_i Z_j )}{2}$$

    Basis state ``k`` places vertex ``i`` on the side given by bit ``i`` of
    ``k`` (bit 0 is the least significant). The term ``(I - Z_i Z_j)/2`` is 1
    when bits ``i`` and ``j`` differ and 0 otherwise. So ``H[k]`` is minus
    the number of edges cut by that assignment.

    Only the strictly upper-triangular part of ``adj_matrix`` is read. Every
    nonzero entry there counts as one edge; weights are ignored.

    Arguments
    ---------
    adj_matrix: torch.Tensor
        The adjacency matrix for the graph, shape (n, n)

    Returns
    -------
    torch.Tensor
        Float tensor of length 2**n holding the diagonal of H
    '''
    n = adj_matrix.shape[0]
    K = torch.arange(2**n)
    H = torch.zeros(2**n)
    edges = adj_matrix.triu(diagonal=1).nonzero()
    for i, j in edges.tolist():
        # Bit i XOR bit j of every basis index: 1 exactly when the edge (i, j)
        # is cut, which equals (1 - Z_i Z_j)/2 on that basis state.
        H -= (((K >> i) ^ (K >> j)) & 1).to(H.dtype)
    return H

def getEntropySchedule(N:int, start_val:float, start_time:int,
                       duration:Literal['full','half'], 
                       num_oscillations:int)->torch.Tensor:
    '''
    Build a length-N progress profile and pass it through
    ``oscillating_scheduler``.

    The progress profile ``t`` has three parts:
    - 0 for the first ``start_time`` steps;
    - a linear ramp from 0 to 1 over the next ``delta_t`` steps;
    - 1 for any remaining steps.
    ``delta_t`` is ``N - start_time`` for ``duration='full'``, and half of
    that (rounded down) for ``duration='half'``.

    Arguments
    ---------
    N: int
        Total number of steps in the progress profile
    start_val: float
        Forwarded to ``oscillating_scheduler`` along with the value 1.0;
        must be below 1.0
    start_time: int
        Number of leading steps before the ramp begins; must satisfy
        ``0 <= start_time < N``
    duration: 'full' or 'half'
        Whether the ramp spans all remaining steps or only the first half
        of them
    num_oscillations: int
        Forwarded to ``oscillating_scheduler``; must be positive

    Returns
    -------
    torch.Tensor
        Whatever ``oscillating_scheduler`` returns for the profile
        (presumably one value per step -- confirm against the scheduler)

    Raises
    ------
    AssertionError
        If ``start_val``, ``start_time`` or ``num_oscillations`` is out of
        range
    ValueError
        If ``duration`` is neither ``'full'`` nor ``'half'``
    '''
    assert start_val < 1.0
    assert start_time >= 0 and start_time < N
    assert num_oscillations > 0
    # Previously any unrecognised value silently behaved like 'full'.
    if duration not in ('full', 'half'):
        raise ValueError(f"duration must be 'full' or 'half', got {duration!r}")
    delta_t = N - start_time
    if duration == 'half':
        delta_t //= 2
    t = torch.cat([torch.zeros(start_time),
                   torch.linspace(0.0,1.0,delta_t),
                   torch.ones(N-(start_time+delta_t))])
    return oscillating_scheduler(t, start_val, 1.0, num_oscillations)

def graphToSuperCircuitStructure(adj_matrix:torch.Tensor)->torch.Tensor:
    '''
    List the edges of a graph as index pairs.

    Only entries strictly above the diagonal that compare equal to 1 are
    reported. Pairs appear in the order ``torch.where`` produces them
    (row-major).

    Arguments
    ---------
    adj_matrix: torch.Tensor
        The adjacency matrix for the graph

    Returns
    -------
    torch.Tensor
        int64 tensor of shape (num_edges, 2); each row is ``[i, j]`` with
        ``i < j``
    '''
    upper_edges = adj_matrix.triu(1) == 1
    # One index tensor per dimension; stack them as columns.
    return torch.stack(torch.where(upper_edges), dim=1)
