import networkx as nx
import csv
import math
import numpy as np
import torch

from flow import *

def read_net(road_net_filename, lcc=False):
    '''
        Reads a directed network from a CSV edge list, optionally keeping
        only its largest weakly connected component.

        Each row adds one edge from its first field to its second field.
        Node ids are kept as the raw strings read from the file, and any
        extra fields are ignored. Rows with fewer than two fields (e.g.
        blank lines) are skipped instead of raising IndexError.

        Args:
            road_net_filename: path of the CSV edge-list file.
            lcc: if True, restrict the result to the largest weakly
                connected component.

        Returns:
            The nx.DiGraph built from the file. With lcc=True and a
            non-empty graph, returns G.subgraph(...) of the largest weakly
            connected component instead. An empty graph is returned as-is.
    '''
    G = nx.DiGraph()

    # The csv module expects files opened with newline='' so that quoted
    # fields and line endings are handled by the reader itself.
    with open(road_net_filename, 'r', newline='') as file_in:
        for row in csv.reader(file_in):
            if len(row) < 2:
                continue
            G.add_edge(row[0], row[1])

    # weakly_connected_components yields nothing for an empty graph, so
    # there is no component to select in that case.
    if lcc and G.number_of_nodes() > 0:
        # max() keeps the first largest component, like the stable sort did.
        largest = max(nx.weakly_connected_components(G), key=len)
        return G.subgraph(largest)

    return G

def normalize_features(features):
    """Standardize feature vectors with scikit-learn's StandardScaler.

    All values of ``features`` are stacked into one ``(n_items, n_features)``
    float matrix. A StandardScaler is fitted on that matrix, and each
    transformed row is mapped back to its original key.

    Args:
        features: dict mapping a key (e.g. an edge) to a 1-D array-like
            feature vector. All vectors must have the same length.

    Returns:
        A new dict with the same keys, in the same iteration order, whose
        values are the standardized rows as numpy arrays. An empty input
        returns an empty dict.
    """
    # Nothing to fit on; indexing the first key would raise, and the
    # scaler rejects an empty matrix anyway.
    if not features:
        return {}

    from sklearn.preprocessing import StandardScaler

    keys = list(features)
    # Rows follow `keys`, so they can be zipped back to their keys below.
    feat_matrix = np.array([features[k] for k in keys], dtype=float)

    feat_matrix = StandardScaler().fit_transform(feat_matrix)

    return {k: row for k, row in zip(keys, feat_matrix)}


def make_non_neg_norm(G, flows, features):
    '''
        Converts a flow estimation instance into a non-negative, normalized
        one.

        Every edge of G with a negative flow is reversed and given the
        negated (now positive) flow. All known flows are then divided by the
        largest absolute flow, so they lie in [0, 1]. Edges without a flow
        are copied with their original orientation and get no entry in the
        returned flows.

        Args:
            G: directed graph; every edge must be a key of `features`
                (a KeyError is raised otherwise).
            flows: dict mapping edges to signed flow values. It may cover
                only part of G.
            features: dict mapping edges of G to feature vectors.

        Returns:
            (new_G, new_flows, new_feat): the re-oriented nx.DiGraph, the
            normalized non-negative flows, and the features keyed by the
            possibly reversed edges.

        NOTE(review): if G holds both (u, v) and (v, u) and (u, v) has a
        negative flow, the reversed edge collides with (v, u). Whichever is
        visited last wins in new_flows/new_feat; confirm inputs never
        contain such pairs.
    '''
    new_flows = {}
    new_G = nx.DiGraph()
    new_feat = {}

    # Scale by the largest magnitude, not the largest signed value. With
    # the signed max, all-negative flows produced negative outputs and
    # mixed-sign flows could exceed 1. Fall back to 1 when there are no
    # flows or they are all zero, to avoid dividing by zero.
    magnitudes = np.abs(list(flows.values()))
    max_flow = np.max(magnitudes) if magnitudes.size else 0.
    if max_flow == 0:
        max_flow = 1.

    for e in G.edges():
        u, v = e
        if e in flows and flows[e] < 0:
            # Flip the edge so that its flow becomes positive.
            rev = (v, u)
            new_G.add_edge(v, u)
            new_flows[rev] = -flows[e] / max_flow
            new_feat[rev] = features[e]
        else:
            new_G.add_edge(u, v)
            new_feat[e] = features[e]
            if e in flows:
                new_flows[e] = flows[e] / max_flow

    return new_G, new_flows, new_feat

invphi = (math.sqrt(5) - 1) / 2  # 1 / phi: share of the bracket kept per step
invphi2 = (3 - math.sqrt(5)) / 2  # 1 / phi^2: offset of the lower inner probe

def gss(f, args, a, b, tol=1e-5):
    '''Golden section search for the minimum of a unimodal function.

    The objective is evaluated as f(x, args). The search assumes f has a
    single local minimum in [a, b] (the endpoints may be given in either
    order). It returns a sub-interval (lo, hi) of that bracket which
    contains the minimum and whose width is about tol.

    modified from: https://en.wikipedia.org/wiki/Golden-section_search

    Usage: gss(f_gss, [G_reg_2, ups, super_regions, updates_proj, .5, False, recall], 0., 1.)
    '''
    lo, hi = min(a, b), max(a, b)
    width = hi - lo
    if width <= tol:
        return (lo, hi)

    # The bracket shrinks by a factor of invphi per evaluation, so this
    # many steps bring its width down to tol.
    steps = int(math.ceil(math.log(tol / width) / math.log(invphi)))

    # Two interior probes placed at the golden-ratio points of the bracket.
    x1 = lo + invphi2 * width
    x2 = lo + invphi * width
    f1 = f(x1, args)
    f2 = f(x2, args)

    for _ in range(steps - 1):
        width = invphi * width
        if f1 < f2:
            # Minimum is in [lo, x2]: the old lower probe becomes the upper one.
            hi, x2, f2 = x2, x1, f1
            x1 = lo + invphi2 * width
            f1 = f(x1, args)
        else:
            # Minimum is in [x1, hi]: the old upper probe becomes the lower one.
            lo, x1, f1 = x1, x2, f2
            x2 = lo + invphi * width
            f2 = f(x2, args)

    return (lo, x2) if f1 < f2 else (x1, hi)

def sparse_tensor_from_coo_matrix(matrix):
    '''
    Converts a SciPy sparse matrix into a float32 torch sparse COO tensor.

    The tensor is created on the CUDA device when torch reports one as
    available, otherwise on the CPU. Entries are copied as given; no
    coalescing of duplicate (row, col) pairs is done here.

    Args:
        matrix: a scipy.sparse COO matrix, or any object exposing `row`,
            `col`, `data` and `shape`. Objects without `row` but with a
            `tocoo()` method (e.g. CSR/CSC matrices) are converted first.

    Returns:
        A torch sparse COO tensor of dtype float32 and shape `matrix.shape`.
    '''
    if not hasattr(matrix, 'row') and hasattr(matrix, 'tocoo'):
        matrix = matrix.tocoo()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # torch.cuda.LongTensor / torch.sparse.FloatTensor are deprecated legacy
    # constructors; torch.sparse_coo_tensor is the supported replacement.
    indices = torch.as_tensor(
        np.vstack((matrix.row, matrix.col)), dtype=torch.long, device=device
    )
    values = torch.as_tensor(matrix.data, dtype=torch.float32, device=device)

    return torch.sparse_coo_tensor(
        indices, values, torch.Size(matrix.shape), device=device
    )


