from numpy import inner
import torch

def percentile(t: torch.Tensor, q: float):
    """
    SOURCE https://gist.github.com/sailfish009/28b54c8aa6398148a6358b8f03c0b611

    Return the ``q``-th percentile of the flattened input tensor's data.

    CAUTION:
     * Needs PyTorch >= 1.1.0, as ``torch.kthvalue()`` is used.
     * Values are not interpolated, which corresponds to
       ``numpy.percentile(..., interpolation="nearest")``.
     * ``t`` must contain at least one element; for an empty tensor the
       computed rank has no element to select, so ``kthvalue`` is expected
       to raise.

    :param t: Input tensor of any shape. It is flattened before selection and
        does not need to be contiguous.
    :param q: Percentile to compute, which must be between 0 and 100 inclusive.
    :return: Resulting value as a Python scalar (via ``Tensor.item()``).
    """
    # ``kthvalue()`` is one-based, i.e. the smallest value is k=1, not k=0.
    # ``float(q)`` makes ``round()`` return a Python int even when q is a
    # NumPy scalar such as np.float32.
    k = 1 + round(.01 * float(q) * (t.numel() - 1))
    # ``reshape`` (unlike ``view``) also accepts non-contiguous tensors, e.g.
    # transposes or strided slices, copying only when it has to.
    return t.reshape(-1).kthvalue(k).values.item()

  
def percentile_of_val(values, query):
    """
    Return the fraction of elements of ``values`` that are strictly below ``query``.

    :param values: Tensor of reference values, of any shape; every element
        is counted.
    :param query: Scalar (or tensor broadcastable against ``values``) to rank.
    :return: 0-d floating-point tensor. For a scalar ``query`` it lies in
        [0, 1]. An empty ``values`` yields NaN (0 / 0).
    """
    # Divide by the total element count rather than ``shape[0]``. Otherwise a
    # multi-dimensional input gives a "fraction" above 1, and a 0-d input
    # fails on ``shape[0]``. Both forms agree for 1-D input.
    result = torch.sum(values < query) / values.numel()
    return result

def get_divisors(n):
    '''
    Return the positive divisors of ``n`` in ascending order.

    Trial division only runs up to sqrt(n). Each divisor ``i <= sqrt(n)`` is
    paired with its co-divisor ``n // i >= sqrt(n)``, so the work is
    O(sqrt(n)) instead of O(n).

    :param n: Integer whose divisors are wanted. Values below 1 yield ``[]``.
    :return: List of ints, e.g. ``get_divisors(12) == [1, 2, 3, 4, 6, 12]``.
    '''
    low, high = [], []
    i = 1
    while i * i <= n:
        if n % i == 0:
            low.append(i)
            partner = n // i
            # For perfect squares, sqrt(n) pairs with itself; list it once.
            if partner != i:
                high.append(partner)
        i += 1
    # ``high`` was collected in descending order; reversing it keeps the
    # whole result ascending.
    return low + high[::-1]

def reconstruct_json(inner_template, node_values, jgraph, node_info, task='gridworld'):
    """
    Write per-node values back into a nested template and return it.

    Each entry ``jgraph.node_table[i]`` is either a plain string (skipped) or a
    sequence of keys forming a path into ``inner_template``. Node ``i`` is
    handled according to the group it belongs to:

     * categorical (in ``jgraph.cat_ranges['nodes']``): the slot is set to
       ``jgraph.node_table[int(node_values[i])]``;
     * numerical (in ``node_info['nume_nodes']``): set to ``node_values[i][0]``;
     * binary (in ``node_info['bin_nodes']``): set to
       ``bool(node_values[i] == 1)``;
     * any other non-string entry: the current value is read and written back
       unchanged. A path missing from the template therefore raises (e.g.
       ``KeyError`` for a dict).

    Every path key is converted with ``str()`` before indexing, so a key like
    ``0`` addresses the string key ``'0'``. Keys are used as data, never
    evaluated as code, so quotes or backslashes inside a key are harmless.

    :param inner_template: Nested container to fill; it is mutated in place.
    :param node_values: Per-node values, indexed by node position.
    :param jgraph: Object exposing ``cat_ranges`` (a mapping with a ``'nodes'``
        entry) and ``node_table`` (sequence of key paths or strings).
    :param node_info: Mapping with ``'nume_nodes'`` and ``'bin_nodes'`` entries.
    :param task: Currently unused; kept for backward compatibility.
    :return: ``inner_template`` itself, after mutation.
    """

    def _walk(keys):
        # Follow all keys but the last; return (parent container, final key).
        node = inner_template
        for key in keys[:-1]:
            node = node[key]
        return node, keys[-1]

    def _get(path):
        keys = [str(key) for key in path]
        if not keys:
            return inner_template
        parent, last = _walk(keys)
        return parent[last]

    def _set(path, value):
        keys = [str(key) for key in path]
        if not keys:
            # An empty path names the template itself; nothing to assign.
            return
        parent, last = _walk(keys)
        parent[last] = value

    cat_nodes = jgraph.cat_ranges['nodes']
    nume_nodes = node_info['nume_nodes']
    bin_nodes = node_info['bin_nodes']
    for i, path in enumerate(jgraph.node_table):
        if i in cat_nodes:
            # The node's value is an index back into node_table; the entry
            # found there is stored (presumably the category label -- confirm).
            _set(path, jgraph.node_table[int(node_values[i])])
        elif i in nume_nodes:
            _set(path, node_values[i][0])
        elif i in bin_nodes:
            _set(path, bool(node_values[i] == 1))
        elif not isinstance(path, str):
            # Round-trips the existing value; effectively asserts the path exists.
            _set(path, _get(path))
    return inner_template