import numpy as np
import networkx as nx
from copy import deepcopy
from time import time
from tqdm import tqdm
import random
import torch
from torch.utils.data import DataLoader
import dgl
from cb.ppo.actor_critic import ActorCritic as ActorCriticCb
from cb.ppo.graph_net import PolicyGraphConvNet as PolicyGraphConvNetCb
from cb.ppo.graph_net import ValueGraphConvNet as ValueGraphConvNetCb
from cb.env import VCP
from mis.ppo.actor_critic import ActorCritic as ActorCriticMis
from mis.ppo.graph_net import PolicyGraphConvNet as PolicyGraphConvNetMis
from mis.ppo.graph_net import ValueGraphConvNet as ValueGraphConvNetMis
from mis.env import MaximumIndependentSetEnv
from pulp import *
import pulp
from random import randint
from itertools import combinations, chain
from Coloring_networkx_addons import ThinGraph, is_coloring_feasible
from timeit import default_timer
import os


# Device handed to both environments and both actor-critic networks.
device = 'cpu'

# env
# Passed to MaximumIndependentSetEnv as `hamming_reward_coef`.
hamming_reward_coef = 0.1

# actor critic
# Shared by the MIS and the colouring actor-critic networks.
num_layers = 4
hidden_dim = 128

# optimization
max_epi_t = 64 #hp  (episode horizon given to the MIS environment)
episode_length= 64 #hp  (episode horizon given to the colouring environment)
num_parallel_graph=400 #hp  (number of graph replicas batched per colouring evaluation)
# NOTE(review): gurobi_limit is not referenced in the visible code; the
# benchmark passes its own time limit to the ILP solver.
gurobi_limit=5

# dataset specific
# NOTE(review): min_num_nodes is not referenced in the visible code.
min_num_nodes = 50 
max_num_nodes = 100 

# Size of the colour palette given to the colouring environment and network.
num_colors=15

# Construct the maximum-independent-set (MIS) environment.
env_mis = MaximumIndependentSetEnv(
    max_epi_t = max_epi_t,
    max_num_nodes = max_num_nodes,
    hamming_reward_coef = hamming_reward_coef,
    device = device
    )

# Construct the MIS actor-critic network.
actor_critic_mis = ActorCriticMis(
    actor_class = PolicyGraphConvNetMis,
    critic_class = ValueGraphConvNetMis, 
    max_num_nodes = max_num_nodes, 
    hidden_dim = hidden_dim,
    num_layers = num_layers,
    device = device
    )

# Load the saved state dictionary. `map_location` remaps the stored tensors
# onto the configured device, so a checkpoint written on a GPU machine still
# loads on a CPU-only host (the colouring checkpoint is loaded the same way).
model_path_mis = './model_mis.pth'
state_dict_mis = torch.load(model_path_mis, map_location=torch.device(device))

# Load the state dictionary into the model
actor_critic_mis.load_state_dict(state_dict_mis)

# Build the colouring ("cb") environment.
env_cb = VCP(
    max_epi_t=episode_length,
    device=device,
    num_colors=num_colors,
)

# Build the colouring actor-critic network.
actor_critic_cb = ActorCriticCb(
    actor_class=PolicyGraphConvNetCb,
    critic_class=ValueGraphConvNetCb,
    max_num_nodes=max_num_nodes,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
    num_colors=num_colors,
    device=device,
)

# Restore the trained weights, remapping stored tensors onto the CPU.
model_path_cb = './model_300.pth'
state_dict_cb = torch.load(
    model_path_cb,
    map_location=torch.device('cpu'),
)
actor_critic_cb.load_state_dict(state_dict_cb)

# Every benchmark record is appended to this module-level handle by compare().
# Mode 'w' truncates any results left over from a previous run.
filepath='./results_benchmark_temp.txt'
result_file=open(filepath,'w')

# define evaluate function for mis
def evaluate_mis(g, actor_critic):
    """Roll the MIS agent out on `g` until the environment reports done.

    Args:
        g: DGL graph to solve; it is registered with the module-level
            `env_mis` using a single sample.
        actor_critic: policy whose `act(ob, g)` produces the actions. It is
            switched to eval mode and queried without gradient tracking.

    Returns:
        A pair `(solution_value, node_state)`:
        * `solution_value` - `info['sol']` reduced with a max over dim 1 and
          then summed, moved to the CPU.
        * `node_state` - index 0 along dim 2 of the final observation,
          flattened. Presumably one entry per node, with 1 marking a node
          placed in the independent set and 0 a node left out - TODO confirm
          against the environment.
    """
    actor_critic.eval()

    g.set_n_initializer(dgl.init.zero_initializer)
    ob = env_mis.register(g, num_samples=1)

    # Step until every sample is finished; `info` from the last step carries
    # the solution that is reported.
    while True:
        with torch.no_grad():
            action = actor_critic.act(ob, g)
        ob, _, done, info = env_mis.step(action)
        if torch.all(done).item():
            break

    # Adding to 0.0 keeps the promotion to floating point that the running
    # total historically applied to the result.
    solution_value = 0.0 + info['sol'].max(dim=1)[0].sum().cpu()
    return solution_value, ob.select(2, 0).flatten()



def colors_used_mis(graph):
    """Count colours obtained by repeatedly peeling off an MIS solution.

    Each round runs `evaluate_mis` on the remaining graph. Nodes labelled 1
    share one new colour and the nodes labelled 0 form the next round's
    graph. If a round labels no node with 1, every node labelled 0 receives
    its own colour and the process stops.

    Args:
        graph: DGL graph to colour.

    Returns:
        The number of colours used.
    """
    colors = 0
    remaining = graph
    while remaining.number_of_nodes() != 0:
        _, labels = evaluate_mis(remaining, actor_critic_mis)
        left_out = labels == 0

        if not torch.any(labels == 1):
            # No progress this round: give each leftover node its own colour.
            colors += torch.sum(left_out).item()
            break

        colors += 1
        remaining = remaining.subgraph(left_out)

    return colors

def evaluate_cb(graph, actor_critic):
    """Run the colouring agent on many replicas of `graph` and keep the best.

    `num_parallel_graph` clones of `graph` are batched and rolled out
    together in the module-level `env_cb`. Each replica is scored by the
    percentage of its nodes with a non-zero final state, with ties broken by
    the smaller number of distinct non-zero states.

    Args:
        graph: DGL graph to colour.
        actor_critic: policy whose `act(ob, g)` produces the actions. Actions
            are shifted by +1 before being passed to the environment, so
            state 0 presumably means "uncoloured" - TODO confirm.

    Returns:
        `(sat_best, ob_best, solution_best, solution_time)` where `sat_best`
        is the best percentage of nodes with a non-zero state, `ob_best` is a
        tensor with that replica's per-node states, `solution_best` is the
        number of distinct non-zero states it uses, and `solution_time` is
        the elapsed wall-clock time in seconds.
    """
    actor_critic.eval()
    progress_bar = tqdm(total=episode_length, desc="Validation Progress", unit="iteration")
    t1 = time()
    try:
        replicas = [graph.clone() for _ in range(num_parallel_graph)]
        g = dgl.batch(replicas)
        g.set_n_initializer(dgl.init.zero_initializer)
        ob = env_cb.register(episode_length, g, 1)
        ob = ob.to(device)
        while True:
            with torch.no_grad():
                action = actor_critic.act(ob, g).to(device)
            ob, _, done = env_cb.step(action + 1)
            progress_bar.update(1)
            if torch.all(done).item():
                break
    finally:
        # Always release the bar so repeated calls do not leak tqdm instances.
        progress_bar.close()

    state = ob.select(2, 0).int().squeeze().tolist()

    sat_best = 0
    solution_best = 10000
    ob_best = []
    # Walk the flat state list with a running offset instead of deleting the
    # consumed prefix for every replica.
    offset = 0
    for replica in dgl.unbatch(g):
        num_nodes = replica.number_of_nodes()
        local_state = state[offset:offset + num_nodes]
        offset += num_nodes

        satisfied = round(100 - local_state.count(0) * 100 / num_nodes, 2)
        my_soln = len(set(x for x in local_state if x != 0))

        if satisfied > sat_best:
            sat_best = satisfied
            solution_best = my_soln
            ob_best = local_state
        elif satisfied == sat_best and my_soln < solution_best:
            solution_best = my_soln
            ob_best = local_state

    solution_time = time() - t1
    return sat_best, torch.tensor(ob_best), solution_best, solution_time

    
def minimum_graph_coloring(graph,best_time_cb=5,max_colors=15):
    """Solve minimum graph colouring as an ILP through PuLP and Gurobi.

    Binary `x[i][c]` says vertex `i` takes colour `c`; binary `y[c]` says
    colour `c` is used. The objective minimises the sum of `y`.

    Args:
        graph: graph exposing `number_of_nodes()` and `edges()`. Node labels
            are used directly as indices into `x`, so they must be the
            integers `0 .. n-1`.
        best_time_cb: value forwarded to Gurobi as its `timeLimit` option.
        max_colors: size of the colour palette offered to the model.

    Returns:
        `(status, colors)` where `status` is the PuLP status string and
        `colors` is the objective value rounded to the nearest integer, or 0
        when the solver produced no usable objective value.
    """
    # Number of vertices in the graph
    num_vertices = graph.number_of_nodes()
    palette = range(max_colors)

    # Create the ILP problem
    problem = pulp.LpProblem("MinimumGraphColoring", pulp.LpMinimize)

    # Variables
    x = pulp.LpVariable.matrix("x", (range(num_vertices), palette), cat="Binary")
    y = pulp.LpVariable.matrix("y", palette, cat='Binary')

    # Objective: Minimize the number of colors used
    problem += pulp.lpSum(y), "MinimizeColors"

    # Each vertex must receive exactly one color
    for i in range(num_vertices):
        problem += pulp.lpSum(x[i]) == 1, f"ColorAssignment_{i}"

    # No two adjacent vertices can share the same color, and a color given to
    # either endpoint is marked as used through y.
    touched = set()
    for i, j in graph.edges():
        touched.add(i)
        touched.add(j)
        for c in palette:
            problem += x[i][c] + x[j][c] <= y[c], f"AdjacentColorConflict_{i}_{j}_{c}"

    # Vertices without any edge are not covered by the constraints above;
    # link them to y explicitly so their colour is counted (otherwise an
    # edgeless graph would be reported as needing 0 colours).
    for i in range(num_vertices):
        if i not in touched:
            for c in palette:
                problem += x[i][c] <= y[c], f"IsolatedColorUsed_{i}_{c}"

    # Solve the problem
    problem.solve(pulp.GUROBI_CMD(msg=0,options=[('timeLimit',best_time_cb)]))
    status = pulp.LpStatus[problem.status]
    print(status)

    colors_used = pulp.value(problem.objective)
    try:
        # round() guards against solver values such as 3.9999999 that a plain
        # int() would truncate to 3.
        return status, int(round(colors_used))
    except (TypeError, ValueError, OverflowError):
        # No objective value (None) or a non-finite one.
        return status, 0



#tabucol
def Tabucol_opt(G, k, C_L=6, C_lambda=0.6, C_maxiter=100000, verbose=False):
    '''Search downwards from k colors for the smallest feasible coloring.

    Tabucol is called with a shrinking number of colors. After every feasible
    result the next attempt uses one color fewer than that result needed; the
    first infeasible result ends the search.

    Returns (best_colors, best_niter): the last feasible coloring as a
    node -> color dict (empty if none was found) and the iteration count of
    the run that produced it.'''
    assert len(G) > 0

    best_colors = {}
    best_niter = 0
    # Lower bound on the number of colors; no clique computation is done here.
    lower_bound = 1
    ncolors = k

    while ncolors >= lower_bound:
        colors, niter = Tabucol(G, ncolors, C_L, C_lambda, C_maxiter)

        if not is_coloring_feasible(G, colors):
            if not best_colors:
                print('Tabucol_opt did not find a solution with %d colors' % k)
            break

        best_colors = dict(colors)
        best_niter = niter
        # Colors are 0-based, so the highest index is one less than the count
        # of colors just used: that is the target for the next attempt.
        highest = max(colors.values())
        ncolors = highest
        if verbose:
            print('Tabucol_opt found a solution with %d colors' % (highest+1))

    return best_colors, best_niter


def Tabucol(G, k, C_L=6, C_lambda=0.6, C_maxiter=100000, verbose=False):
    '''Search for a vertex k-coloring of G with a tabu search scheme.

    Tabucol features are inspired from "A survey of local search methods for
    graph coloring" by Philippe Galinier and Alain Hertz.

    Starting from a random assignment of the k colors, the search repeatedly
    moves one node involved in a violated edge to another color, choosing at
    random among the non-tabu moves with the lowest change in the number of
    violated edges.

    Parameters:
        G: graph supporting G.nodes(), G.edges() and G[node] (neighbors).
        k: number of colors available (color indices 0 .. k-1).
        C_L, C_lambda: tabu tenure parameters; a reverted move stays tabu
            until iteration niter + int(C_L + C_lambda*F), where F is the
            current number of violated edges.
        C_maxiter: maximum number of iterations.
        verbose: print diagnostic messages.

    Returns:
        (colors, niter): colors maps each node to a color index and niter is
        the number of iterations performed. The coloring still contains
        violated edges when the loop stopped because C_maxiter was reached,
        so callers must check feasibility themselves.
    '''

    def is_tabu_allowed(tabu_d, nc, n_iter):
        '''Return True when move nc = (node, color) is not tabu, i.e. it is
        absent from tabu_d or its stored expiry iteration is before n_iter.'''
        return (nc not in tabu_d) or (nc in tabu_d and tabu_d[nc] < n_iter)

    # generate random coloring and compute color classes
    colors = {i: randint(0, k-1) for i in G.nodes()}
    color_classes = {col: set(i for i in G.nodes()
                     if colors[i] == col) for col in range(k)}
    # compute actual violations and number of violated edges as F
    viol_edges = {col: [edge for edge in combinations(col_set, 2)
                        if edge in G.edges()]
                  for col, col_set in color_classes.items()}
    #  viol_nodes contains nodes involved in violated edges, a node
    #  being counted as many times it appears in viol_edges
    viol_nodes = {col: list(chain.from_iterable(viol_edges[col]))
                  for col in color_classes.keys()}
    # F is the total number of violations (violated edges)
    F = sum(len(v) for v in viol_edges.values())

    # initiate local search in 1-move neighborhood
    # create tabu dictionary with (node, col) as key and niter as value
    tabu = {}
    niter = 0
    # restrictive only ensures the "too restrictive" message is printed once
    restrictive = False
    while F > 0 and niter < C_maxiter:

        # generate candidates with violation variation
        # delta maps each allowed move (node, new color) to the change in
        # violated edges: conflicts gained in the new class minus the number
        # of times the node appears in its current class' violations
        delta = {}
        for col, node_list in viol_nodes.items():
            for node in node_list:
                old_count = viol_nodes[col].count(node)
                for col_cand in range(k):
                    nc = (node, col_cand)
                    if col_cand != col and is_tabu_allowed(tabu, nc, niter):
                        new_count = len(set(G[node]).intersection(color_classes[col_cand]))
                        delta[nc] = new_count-old_count
        if not delta:
            # skip current iteration: every candidate move is tabu; niter
            # still advances below, so tabu entries eventually expire
            if not restrictive:
                if verbose:
                    print('tabu scheme is probably too restrictive')
                restrictive = True
        else:
            # select a candidate among the ones with lowest delta
            delta_c = min(delta.values())
            final_cand = [(n, c) for (n, c), value in delta.items()
                          if value == delta_c]
            # choose a candidate with lowest value at random
            (node_c, col_c) = final_cand[randint(0, len(final_cand)-1)]
            # update tabu dictionary: drop expired entries first
            deleting_tabu_list = [key_t for key_t, iter_t in tabu.items()
                                  if iter_t <= niter]
            for key_t in deleting_tabu_list:
                del tabu[key_t]
            old_col = colors[node_c]
            # forbid moving the node back to its old color for a tenure that
            # grows with the current number of violations
            tabu[(node_c, old_col)] = niter + int(C_L + C_lambda*F)
            # update number of violations
            F += delta_c
            # update modified violation edge and node classes
            viol_edges[old_col] = [edge for edge in viol_edges[old_col]
                                   if node_c not in edge]
            viol_nodes[old_col] = list(chain.from_iterable(viol_edges[old_col]))
            new_viol_edges = ((node_c, u) for u in set(G[node_c]).intersection(color_classes[col_c]))
            viol_edges[col_c].extend(new_viol_edges)
            viol_nodes[col_c] = list(chain.from_iterable(viol_edges[col_c]))
            # update colors and color classes
            colors[node_c] = col_c
            color_classes[old_col] -= {node_c}
            color_classes[col_c] = color_classes[col_c].union({node_c})

        niter += 1

    if verbose and niter == C_maxiter:
        print('exiting loop as max iterations exceeded')

    # get rid of empty color classes to speed up next search: when at least
    # one class is empty, renumber the non-empty classes 0, 1, 2, ...
    if not all(col_class for col_class in color_classes.values()):
        colors = {}
        count = 0
        for col_class in color_classes.values():
            if col_class:
                for n in col_class:
                    colors[n] = count
                count += 1

    return colors, niter


def tabu_main(G, initial):
    """Run Tabucol_opt on `G` and report the colors used and the run time.

    Args:
        G: graph to color.
        initial: number of colors Tabucol_opt starts its search from.

    Returns:
        `(num_colors, elapsed_seconds)` for the best coloring found, or
        `(-1, -1)` when Tabucol_opt returned no feasible coloring.
    """
    start_time = default_timer()
    colors, _ = Tabucol_opt(G,
                            initial,
                            C_L=6,
                            C_lambda=0.6,
                            C_maxiter=100000,
                            verbose=False)
    elapsed = default_timer() - start_time

    # An empty dict means no feasible coloring was found.
    if not colors:
        return -1, -1
    # Colors are 0-based, hence the +1.
    return max(colors.values()) + 1, elapsed



def greedy_coloring(G):
    """Color `G` greedily and return the number of colors used.

    Vertices are visited in the order 0, 1, ..., V-1 and each receives the
    smallest color not already used by one of its colored neighbors.

    Args:
        G: graph exposing `nodes` and `neighbors(u)`. Vertices must be
            labelled with the integers `0 .. V-1`.

    Returns:
        The number of colors used; 0 for a graph without vertices.
    """
    V = len(G.nodes)
    if V == 0:
        # Nothing to color; indexing the first vertex would fail otherwise.
        return 0

    # -1 marks a vertex that has not been colored yet.
    result = [-1] * V

    # Assign the first color to the first vertex
    result[0] = 0

    # Assign colors to the remaining V-1 vertices
    for u in range(1, V):
        # Colors already taken by colored neighbors of u
        taken = {result[v] for v in G.neighbors(u) if result[v] != -1}

        # Smallest color not in use around u
        cr = 0
        while cr in taken:
            cr += 1
        result[u] = cr

    # +1 because colors start from 0
    return max(result) + 1



def load_dimacs_col_file(file_path):
    """Read a DIMACS `.col` file into an undirected networkx graph.

    Only two line types are interpreted: `e U V` lines add an edge, and the
    `p FORMAT NODES EDGES` problem line supplies the declared vertex count so
    that isolated vertices numbered above the largest edge endpoint are kept.
    Every other line is ignored.

    Args:
        file_path: path of the file to read.

    Returns:
        An `nx.Graph` with zero-indexed vertices `0 .. n-1`, where `n` is the
        larger of the declared vertex count and the highest vertex seen in an
        edge plus one. Vertex 0 is always present, even for a file without
        edges.
    """
    G = nx.Graph()
    max_node = 0
    declared_nodes = 0
    with open(file_path, 'r') as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue
            if fields[0] == 'e':  # 'e' lines represent edges
                # Subtract 1 from node labels to make them zero-indexed
                u = int(fields[1]) - 1
                v = int(fields[2]) - 1
                G.add_edge(u, v)
                max_node = max(u, v, max_node)
            elif fields[0] == 'p' and len(fields) >= 3 and fields[2].isdigit():
                declared_nodes = int(fields[2])

    # add_nodes_from leaves existing vertices untouched and only adds the
    # missing (isolated) ones.
    G.add_nodes_from(range(max(max_node + 1, declared_nodes)))
    return G



def compare(graph,file_name, tabucol_init=15):
    """Run every coloring method on one graph and log the results.

    The RL agent, the Gurobi ILP, greedy coloring, the MIS-based heuristic
    and Tabucol are run in that order. A summary of each is printed and one
    record is appended to the module-level `result_file`.

    Args:
        graph: networkx graph with vertices labelled `0 .. n-1`.
        file_name: benchmark name written as the first line of the record.
        tabucol_init: number of colors Tabucol starts its search from.

    Returns:
        None.
    """
    nx_graph = graph
    dgl_graph = dgl.from_networkx(nx_graph)
    # Separate copy for the MIS heuristic, taken before the RL agent runs.
    dgl_graph_mis = deepcopy(dgl_graph)

    # --- RL agent ---
    # The agent may leave vertices with state 0. It is re-run on the subgraph
    # of those vertices, and the colors of each round are added up, until all
    # vertices are covered or 10 rounds have been used.
    agent_time_begin = time()
    soln_cb_sat, ob_best, soln_cb_colors, _ = evaluate_cb(dgl_graph, actor_critic_cb)
    print(soln_cb_sat,soln_cb_colors)
    Flag = 1  # number of agent rounds
    while soln_cb_sat < 100:
        Flag += 1
        if Flag > 10:
            break
        dgl_graph = dgl_graph.subgraph(ob_best == 0)
        if dgl_graph.num_nodes() == 1:
            # A single leftover vertex simply takes one more color.
            soln_cb_colors += 1
            break
        soln_cb_sat, ob_best, extra_colors, _ = evaluate_cb(dgl_graph, actor_critic_cb)
        soln_cb_colors += extra_colors
        print(soln_cb_sat,soln_cb_colors)
    agent_time = time() - agent_time_begin
    best_soln_cb = (soln_cb_colors, agent_time, Flag)
    print('RL Agent:',best_soln_cb)

    # --- Gurobi (through PuLP) ---
    pulp_begin = time()
    status, color_pulp = minimum_graph_coloring(nx_graph, 900)
    pulp_time = time() - pulp_begin

    # --- greedy ---
    greedy_begin = time()
    color_greedy = greedy_coloring(nx_graph)
    greedy_time = time() - greedy_begin

    # --- MIS heuristic ---
    mis_begin = time()
    soln_mis = colors_used_mis(dgl_graph_mis)
    time_mis = time() - mis_begin

    # --- Tabucol ---
    channel_tabucol, tabucol_time = tabu_main(nx_graph, tabucol_init)
    best_soln_tabucol = (100, channel_tabucol, tabucol_time)
    print('tabucol',best_soln_tabucol)

    best_soln_pulp = (100 if status == 'Optimal' else 0, color_pulp, pulp_time)
    print('Pulp',best_soln_pulp)

    best_soln_greedy = (color_greedy, greedy_time)
    print('Greedy:',best_soln_greedy)

    best_soln_mis = (soln_mis, time_mis)
    print('MIS:',best_soln_mis)

    # --- log the record ---
    result_file.write(file_name+'\n')
    result_file.write(f"{nx_graph.number_of_nodes()} {nx_graph.number_of_edges()}\n")
    result_file.write(f"{best_soln_pulp[0]} {best_soln_pulp[1]} {best_soln_pulp[2]}\n")
    result_file.write(f"{best_soln_tabucol[0]} {best_soln_tabucol[1]} {best_soln_tabucol[2]:.8f}\n")
    result_file.write(f"{best_soln_greedy[0]} {best_soln_greedy[1]:.8f}\n")
    result_file.write(f"{best_soln_mis[0]} {best_soln_mis[1]}\n")
    result_file.write(f"{best_soln_cb[0]} {best_soln_cb[1]} {best_soln_cb[2]}\n")
    result_file.write('\n')
    # Flush after the separator line so the whole record reaches disk even if
    # a later graph aborts the run.
    result_file.flush()

    

# Directory that holds the DIMACS .col benchmark files listed below.
directory_path='benchmark_dataset'


# Benchmark file name -> (a, b).
# `a` is passed to compare() as the number of colors Tabucol starts from.
# `b` is not read anywhere in the visible code; presumably it records the
# known or expected chromatic number, with 1 as a placeholder - TODO confirm.
benchmarks={'ash331GPIA.col':(15,4), 'ash608GPIA.col':(15,4), 'ash958GPIA.col':(15,4) , '1-FullIns_3.col':(15,4), '1-FullIns_4.col':(15,5),  '1-FullIns_5.col':(15,6) , 
            '2-FullIns_3.col':(15,5), '2-FullIns_4.col':(15,6), '2-FullIns_5.col':(15,7) ,'3-FullIns_3.col':(15,6), '3-FullIns_4.col':(15,1), '4-FullIns_3.col':(15,7), 
            '4-FullIns_4.col':(15,8), '4-FullIns_5.col':(15,9), '5-FullIns_3.col':(15,1) , '1-Insertions_4.col': (15,4) , '1-Insertions_5.col': (15,1), 
            '1-Insertions_6.col': (15,7), '2-Insertions_3.col': (15,4), '2-Insertions_4.col': (15,1), '2-Insertions_5.col': (15,1), '3-Insertions_3.col': (15,4), 
            '3-Insertions_4.col': (15,1), '3-Insertions_5.col': (15,6), '4-Insertions_3.col': (15,1), '4-Insertions_4.col': (15,1),
            'le450_5a.col':(15,5),  'le450_5b.col':(15,5),  'le450_5c.col':(15,5), 'le450_5d.col':(15,5), 'mug88_1.col': (15,4) , 'mug88_25.col': (15,4), 
            'mug100_1.col': (15,4), 'mug100_25.col': (15,4), 'myciel3.col':(15,4), 'myciel4.col':(15,5), 'myciel5.col':(15,6), 'myciel6.col':(15,7), 'myciel7.col':(15,8), 
            'queen5_5.col':(15,5), 'queen6_6.col':(15,7), 'queen7_7.col':(15,7), 'queen8_8.col':(15,9),  'queen9_9.col':(15,10), 'DSJC125.1.col':(15,5),  
            'DSJC125.5.col':(30,17),  'DSJC125.9.col':(45,44), 'DSJC250.1.col':(15,8),  'DSJC250.5.col':(30,28), 'DSJC250.9.col':(75,72), 'will199GPIA.col': (15,1) }

# Run the comparison on every benchmark graph. A failure while comparing one
# graph is reported and the run moves on to the next graph.
num_graphs = len(benchmarks)
graph_count = 0
try:
    for file_name, values in benchmarks.items():
        graph_count += 1
        print(f'\nProcessing********************{graph_count}/{num_graphs} ****************{file_name}*********************************************************************************************')
        file_path = os.path.join(directory_path, file_name)
        nx_graph = load_dimacs_col_file(file_path)
        nodes = nx_graph.number_of_nodes()
        edges = nx_graph.number_of_edges()
        print(nodes,edges)

        if nodes > 5000 or edges > 100000:
            print("Too big for processing")
            continue

        try:
            compare(nx_graph, file_name, values[0])
        except Exception as exc:
            # Catch Exception rather than everything so that Ctrl-C still
            # stops the run, and say what went wrong instead of hiding it.
            print("Some error occured:", repr(exc))
            continue
finally:
    # Close the results file even when the loop is interrupted.
    result_file.close()
