import pickle
import scipy as sp
import numpy as np
from ortools.sat.python import cp_model
import networkx as nx
import matplotlib.pyplot as plt
from envs.util import ErdosRenyiGraphGenerator
from collections import Counter

def graph_color(g, color_num=3, return_colors=False):
    """Check whether graph ``g`` admits a proper coloring with ``color_num`` colors.

    Builds a CP-SAT model with one integer variable per node (domain
    ``0 .. color_num - 1``) and a ``!=`` constraint between every pair of
    nodes joined by a nonzero adjacency entry, then solves it.

    Args:
        g: networkx graph. Nodes are indexed in the order used by
            ``nx.to_numpy_array(g)`` (i.e. ``g``'s node order).
        color_num: number of colors available.
        return_colors: if True, also return the per-node color assignment.

    Returns:
        ``color_num`` if the solver reports OPTIMAL or FEASIBLE, otherwise 0.
        If ``return_colors`` is True, a tuple ``(result, colors)`` instead,
        where ``colors`` lists each node's color index in node order (an
        empty list when no coloring was found).
    """
    model = cp_model.CpModel()
    adj = nx.to_numpy_array(g)
    node_num = adj.shape[0]
    nodes = [
        model.NewIntVar(0, color_num - 1, 'x%i' % i) for i in range(node_num)
    ]

    # One constraint per unordered node pair: (i, j) and (j, i) collapse to
    # the same key. A self-loop produces nodes[i] != nodes[i], which is
    # unsatisfiable, just as in a pairwise double loop over the matrix.
    rows, cols = np.nonzero(adj)
    edges = sorted(
        {(min(i, j), max(i, j)) for i, j in zip(rows.tolist(), cols.tolist())}
    )
    for i, j in edges:
        model.Add(nodes[i] != nodes[j])

    solver = cp_model.CpSolver()
    status = solver.Solve(model)
    if status in (cp_model.OPTIMAL, cp_model.FEASIBLE):
        result = color_num
        colors = [solver.Value(v) for v in nodes] if return_colors else []
    else:
        # Any non-solution status (e.g. INFEASIBLE) is reported as 0.
        result = 0
        colors = []

    if return_colors:
        return result, colors
    return result

def draw(colors, g):
    """Plot ``g`` with each node filled according to its color index, then show it.

    Args:
        colors: sequence of per-node color indices, each in ``0..5``, given in
            the same order as ``g``'s nodes.
        g: networkx graph to draw (spring layout, labelled nodes).

    Raises:
        ValueError: if any entry of ``colors`` has no palette color. Skipping
            such entries would leave fewer colors than nodes, which makes
            ``nx.draw`` fail later with a less obvious error.
    """
    palette = {
        0: 'red',
        1: 'blue',
        2: 'green',
        3: 'yellow',
        4: 'pink',
        5: 'purple',
    }
    unknown = [c for c in colors if c not in palette]
    if unknown:
        raise ValueError(
            'draw() supports color indices 0..%d, got %r'
            % (len(palette) - 1, sorted(set(unknown)))
        )
    color_map = [palette[c] for c in colors]

    pos = nx.spring_layout(g)
    nx.draw(g, node_color=color_map, with_labels=True, pos=pos)
    plt.show()

def get_opt(graph_save_loc, min_color_num=2):
    """Find the smallest feasible color count for every pickled test graph.

    Loads a pickled sequence of adjacency matrices from ``graph_save_loc``.
    It then tries ``min_color_num, min_color_num + 1, ...`` colors, recording
    for each graph the first count at which ``graph_color`` succeeds.

    Args:
        graph_save_loc: path to a pickle file holding a sequence of adjacency
            matrices (anything ``np.array`` turns into a square matrix).
        min_color_num: first color count tried. It defaults to 2, so a graph
            that needs fewer colors (e.g. one with no edges) is reported as 2.

    Returns:
        Float numpy array with one entry per graph: the minimum color count
        found, or 0 for a graph that was never colorable.

    Raises:
        ValueError: if ``min_color_num`` is less than 1.
    """
    if min_color_num < 1:
        raise ValueError('min_color_num must be >= 1, got %r' % (min_color_num,))

    # NOTE(review): pickle.load can execute arbitrary code; only use this
    # on trusted, locally generated files.
    with open(graph_save_loc, 'rb') as f:
        graphs_test = pickle.load(f)

    # Convert once up front instead of on every pass of the search.
    graphs = [nx.from_numpy_array(np.array(t)) for t in graphs_test]
    opt_color_nums = np.zeros(len(graphs))
    pending = set(range(len(graphs)))

    # Any graph without self-loops on n nodes can be colored with n colors,
    # so past this bound the remaining graphs are uncolorable. Stopping
    # here keeps the loop from running forever.
    max_color_num = max([min_color_num] + [g.number_of_nodes() for g in graphs])

    color_num = min_color_num
    while pending and color_num <= max_color_num:
        for j in sorted(pending):
            opt_color_num = graph_color(graphs[j], color_num)
            if opt_color_num:
                opt_color_nums[j] = opt_color_num
                pending.discard(j)
        color_num += 1

    print(Counter(opt_color_nums))
    return opt_color_nums









