import os
import numpy as np
import torch
import networkx as nx

# Torch device for every tensor this module creates; read as a module global by
# get_problem_SK, get_ER_graph and read_Jij_MIS.
device = "cpu"

def get_problem_SK(N, i):
    """Build a reproducible dense +/-1 coupling matrix for an Ising instance.

    NumPy's *global* RNG is reseeded with ``i`` (so later ``np.random`` calls
    made by the caller are affected too). An ``N x N`` standard-normal matrix
    is symmetrised, reduced to its element-wise sign and its diagonal cleared.

    Args:
        N: number of spins.
        i: seed passed to ``np.random.seed``.

    Returns:
        torch.float32 tensor of shape (N, N) on the module-level ``device``;
        symmetric, off-diagonal entries +/-1 (0 only on an exact tie), zero
        diagonal.
    """
    np.random.seed(i)
    gaussian = np.random.randn(N, N)
    couplings = np.sign(gaussian + gaussian.T)  # symmetric sign pattern
    np.fill_diagonal(couplings, 0.0)            # no self-couplings
    return torch.tensor(couplings, dtype=torch.float32, device=device)


def get_ER_graph(N, p, i):
    """Build a reproducible random symmetric 0/1 adjacency matrix with edge fraction ~p.

    NumPy's global RNG is reseeded with ``i``. Symmetric random weights
    ``U + U.T`` (U uniform on [0, 1)) are drawn, the diagonal is zeroed, and
    roughly the largest ``p`` fraction of the N*N entries is kept as edges by
    thresholding at the ``100 * (1 - p)`` percentile. The percentile is taken
    over all N*N entries, including the zeroed diagonal, so the off-diagonal
    density is slightly above ``p``. This is kept as-is so existing seeds
    reproduce the same graphs.

    Args:
        N: number of nodes.
        p: target edge fraction in [0, 1]; ``p == 0`` gives an empty graph.
        i: seed passed to ``np.random.seed``.

    Returns:
        torch.float32 tensor of shape (N, N) on the module-level ``device``;
        symmetric 0/1 entries with a zero diagonal (no self-loops).

    Raises:
        ValueError: if ``p`` is outside [0, 1].
    """
    if not 0 <= p <= 1:
        raise ValueError("p must be in [0, 1], got %r" % (p,))

    np.random.seed(i)
    weights = np.random.rand(N, N)
    weights = weights + weights.T
    np.fill_diagonal(weights, 0.0)

    if p == 0:
        # The percentile threshold would equal the maximum weight and still
        # keep that one (symmetric) edge, so build the empty graph explicitly.
        adjacency = np.zeros((N, N), dtype=int)
    else:
        threshold = np.percentile(weights.reshape(-1), 100 * (1 - p))
        adjacency = 1 * (weights >= threshold)
        # For p close to 1 the threshold can fall to 0 and select the zeroed
        # diagonal; forbid self-loops explicitly.
        np.fill_diagonal(adjacency, 0)

    return torch.tensor(adjacency, device=device, dtype=torch.float32)

def read_Jij_MIS(N_range, i, test = False, sdds = False):
    """Load a stored MIS graph instance and return its dense adjacency matrix.

    The instance is read from
    ``./MIS/[sdds-]rb<lo>-<hi>/{train,test}/GR_<lo>_<hi>_<i>.gpickle``
    (relative to the current working directory), where ``(lo, hi) = N_range``.

    Args:
        N_range: pair ``(lo, hi)`` naming the instance set.
        i: instance index within the set.
        test: read from the ``test`` split instead of ``train``.
        sdds: read from the ``sdds-`` variant of the instance set.

    Returns:
        ``(Jij, N)``: a symmetric 0/1 float tensor of shape (N, N) on the
        module-level ``device`` with a zero diagonal (self-loops are skipped),
        and the number of nodes ``N``.
    """
    import pickle

    prefix = "sdds-rb" if sdds else "rb"
    split = "test" if test else "train"
    instance_set_path = "./MIS/%s%i-%i/%s" % (prefix, N_range[0], N_range[1], split)

    inst_file_common_name = "/GR_%i_%i_%i.gpickle" % (N_range[0], N_range[1], i)
    path = instance_set_path + inst_file_common_name

    # nx.read_gpickle was removed in NetworkX 3.0; it was a thin wrapper around
    # pickle.load, so this reads the same files on every NetworkX version.
    # NOTE(review): unpickling can execute arbitrary code -- only load trusted
    # instance files.
    with open(path, "rb") as fh:
        g = pickle.load(fh)

    N = len(g.nodes())
    Jij = torch.zeros(N, N, device=device)

    # Assumes node labels are integers in [0, N) -- TODO confirm for every
    # instance set (the original per-edge indexing made the same assumption).
    edges = [(u, v) for u, v in g.edges() if u != v]
    if edges:
        idx = torch.tensor(edges, dtype=torch.long, device=device)
        Jij[idx[:, 0], idx[:, 1]] = 1
        Jij[idx[:, 1], idx[:, 0]] = 1

    return Jij, N

def get_problem(problem_type, instance_type, N, p, N_range, i):
    """Dispatch to the generator/reader for the requested problem instance.

    Args:
        problem_type: 1 = MIS, 0 = Ising.
        instance_type: for MIS, 0 = seeded random graph (``get_ER_graph``) and
            1 = stored graph (``read_Jij_MIS``); for Ising, 0 = seeded
            instance from ``get_problem_SK``.
        N: number of nodes/spins for generated instances.
        p: edge fraction for ``get_ER_graph``.
        N_range: instance-set identifier for ``read_Jij_MIS``.
        i: seed / instance index.

    Returns:
        For MIS: a tuple ``(Jij, N)``.
        For Ising: the coupling tensor alone (not a tuple). Callers depend on
        this asymmetry, so it is preserved.

    Raises:
        ValueError: for an unsupported ``(problem_type, instance_type)`` pair
            (previously this silently returned ``None``).
    """
    if problem_type == 1:  # MIS
        if instance_type == 0:
            return get_ER_graph(N, p, i), N
        if instance_type == 1:
            Jij, N_ = read_Jij_MIS(N_range, i)
            return Jij, N_
    elif problem_type == 0:  # Ising
        if instance_type == 0:
            return get_problem_SK(N, i)
    raise ValueError(
        "unsupported problem_type=%r / instance_type=%r" % (problem_type, instance_type)
    )


def get_GE(instance_path, N, i):
    """Return the stored best-known objective value for instance ``i``.

    Values are kept one per line in ``instance_path + "sol.txt"``.
    ``instance_path`` is concatenated directly with the file name, so it
    should end with a path separator. If the file has fewer than ``i + 1``
    entries, it is zero-padded and rewritten, and 0 is returned for the new
    entry. The directory is created if missing.

    Args:
        instance_path: directory prefix holding ``sol.txt``.
        N: unused; kept for interface compatibility.
        i: instance index.

    Returns:
        The stored value for instance ``i`` (0 if none was recorded yet).
    """
    file = "sol.txt"
    sol_path = instance_path + file

    # exist_ok avoids FileExistsError when several processes race to create it.
    os.makedirs(instance_path, exist_ok=True)

    sol = np.zeros(0)
    if os.path.exists(sol_path):
        # np.loadtxt returns a 0-d array for a single-value file.
        sol = np.atleast_1d(np.loadtxt(sol_path))

    if i >= len(sol):
        padded = np.zeros(i + 1)
        padded[:len(sol)] = sol
        sol = padded
        np.savetxt(sol_path, sol)
    return sol[i]

def save_GE(instance_path, N, i, GE, objective_type):
    """Record ``GE`` for instance ``i`` if it improves on the stored value.

    The values live one per line in ``instance_path + "sol.txt"``.
    ``instance_path`` is concatenated directly with the file name, so it
    should end with a path separator. Missing entries are treated as 0. This
    means the first value saved for a "max" objective must be > 0, and for a
    "min" objective it must be < 0. The file is rewritten only on improvement,
    and the directory is created if missing.

    Args:
        instance_path: directory prefix holding ``sol.txt``.
        N: unused; kept for interface compatibility.
        i: instance index.
        GE: candidate objective value.
        objective_type: ``"max"`` or ``"min"``.

    Raises:
        ValueError: if ``objective_type`` is neither ``"max"`` nor ``"min"``
            (previously such calls silently never saved anything).
    """
    if objective_type not in ("max", "min"):
        raise ValueError("objective_type must be 'max' or 'min', got %r" % (objective_type,))

    file = "sol.txt"
    sol_path = instance_path + file

    # exist_ok avoids FileExistsError when several processes race to create it.
    os.makedirs(instance_path, exist_ok=True)

    sol = np.zeros(0)
    if os.path.exists(sol_path):
        # np.loadtxt returns a 0-d array for a single-value file.
        sol = np.atleast_1d(np.loadtxt(sol_path))

    if i >= len(sol):
        padded = np.zeros(i + 1)
        padded[:len(sol)] = sol
        sol = padded

    improved = GE > sol[i] if objective_type == "max" else GE < sol[i]
    if improved:
        print("updated sol", i, sol[i], GE)
        sol[i] = GE
        np.savetxt(sol_path, sol)





def read_Jij(inst_num, N):
    """Load BENCHSKL instance ``inst_num`` of size ``N`` as a dense, negated coupling matrix.

    The file ``../../InstanceSets/BENCHSKL/BENCHSKL_<N>_100/BENCHSKL_<N>_100_<inst_num + 1>``
    is read relative to the current working directory, so ``inst_num`` is
    0-based. It holds one ``i j value`` row per coupling, with 1-based
    indices. Each value is written to both ``(i, j)`` and ``(j, i)``; later
    rows overwrite earlier ones.

    Returns:
        float64 ndarray of shape (N, N) holding the NEGATED couplings (``-J``).
    """
    instance_set_path = "../../InstanceSets/BENCHSKL/"
    instance_file_dir = "BENCHSKL_%i_100/" % N
    inst_file_common_name = "BENCHSKL_%i_100_" % N
    path = instance_set_path + instance_file_dir + inst_file_common_name + str(inst_num + 1)

    # ndmin=2 keeps a one-line file as shape (1, 3). Without it, loadtxt
    # squeezes the file to shape (3,) and column indexing fails.
    entries = np.loadtxt(path, ndmin=2)

    Jij = np.zeros((N, N))
    for i, j, val in entries:
        row, col = int(i) - 1, int(j) - 1
        Jij[row, col] = Jij[col, row] = val
    return -Jij

def read_GS(inst_num, N):
    """Load the ``_GS`` file of BENCHSKL instance ``inst_num`` (0-based) of size ``N``.

    Reads
    ``../../InstanceSets/BENCHSKL/BENCHSKL_<N>_100/BENCHSKL_<N>_100_<inst_num + 1>_GS``
    relative to the current working directory and returns ``np.loadtxt``'s
    array unchanged. The file presumably holds the ground-state
    configuration -- TODO confirm against the instance-set docs.
    """
    base = "../../InstanceSets/BENCHSKL/"
    stem = "BENCHSKL_%i_100" % N
    path = "%s%s/%s_%s_GS" % (base, stem, stem, inst_num + 1)
    return np.loadtxt(path)
    
def read_ground_E(inst_num, N):
    """Compute the reference ground energy of BENCHSKL instance ``inst_num`` (0-based).

    Reads ``../../InstanceSets/BENCHSKL/BENCHSKL_<N>_100/BENCHSKL_<N>_100_SOL``
    (one value per instance, relative to the current working directory). It
    then returns ``-0.5 * sum(J) - 2 * SOL[inst_num]``, where ``J`` is the
    coupling matrix exactly as stored in the instance file, with both
    triangles filled.

    NOTE(review): the formula suggests SOL stores a cut-style value that is
    converted to an Ising energy -- confirm against the instance-set docs.
    """
    instance_set_path = "../../InstanceSets/BENCHSKL/"
    instance_file_dir = "BENCHSKL_%i_100/" % N
    inst_file_common_name = "BENCHSKL_%i_100_" % N
    path = instance_set_path + instance_file_dir + inst_file_common_name + "SOL"

    # ndmin=1: a SOL file with a single value would otherwise load as a 0-d
    # array, which cannot be indexed.
    sol = np.loadtxt(path, ndmin=1)

    # read_Jij returns -J, so -0.5 * sum(J) == 0.5 * sum(read_Jij(...)) exactly.
    negated_J = read_Jij(inst_num, N)
    return np.sum(negated_J) * 0.5 - 2 * sol[inst_num]
