import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
import random
import math
from time import time
from utils import  node_mapping, node_mapping_kernel, core_decomposition, FindLBIS, FindClq
import os
from tqdm import tqdm
import random
import glob

# Time budget for one FastColor run, in seconds: the main loop stops once the
# elapsed time() difference exceeds this value.
cutoff_time=60




# def recolor(G,v, coloring):
#     new_color = coloring[v]
#     color_map={i:set() for i in range(1, new_color)}
#     for i in G.nodes():
#         if coloring[i] != 0 and coloring[i] != new_color:
#             color_map[coloring[i]].add(i)
    
#     neigh_v=set(G.neighbors(v))
#     conflict=[]
#     for i in range(1, new_color):
#         if len(neigh_v.intersection(color_map[i]))==1:
#             conflict.append((i,neigh_v.intersection(color_map[i]).pop()))
    
    
#     for i, t in conflict:
#         neig_t=set(G.neighbors(t))
#         for j in range(i+1, new_color):
#             if len(neig_t.intersection(color_map[j]))==0:
#                 coloring[v]=coloring[t]
#                 coloring[t]=j
#                 return True, coloring
    
#     return False, coloring 

# def recolor(G,v,coloring):
#     new_color=coloring[v]
#     conflicts=[0]*(new_color+1)
#     used=[-1]*(new_color+1) 
#     for w in G.neighbors(v):
#         if coloring[w]!=0:
#             conflicts[coloring[w]]+=1
#             used[coloring[w]]=w 
    
#     for i in range(1,new_color):
#         if conflicts[i]==1:
#             w=used[i]
#             for u in G.neighbors(w):
#                 if coloring[u]!=0:
#                     used[coloring[u]]=w     
#             c=0 
#             for j in range(i+1,new_color):
#                 if (used[j]!=w):
#                     c=j
#                     break
                
#             if c!=0:
#                 coloring[v]=coloring[w]
#                 coloring[w]=c
#                 return True, coloring

#     return False, coloring


def recolor(G,v,coloring):
    """Try to move ``v`` out of its current (highest) colour by a one-step swap.

    ``v`` is expected to already hold colour ``coloring[v]``.  For every lower
    colour that exactly one coloured neighbour of ``v`` uses, the function
    checks whether that neighbour can be shifted to a colour strictly between
    its own colour and ``coloring[v]`` that none of its coloured neighbours
    uses.  On the first success ``v`` inherits the neighbour's old colour and
    the neighbour takes the free one.

    Args:
        G: graph object exposing ``neighbors(node)``.
        v: the vertex to recolour.
        coloring: list indexed by vertex; ``0`` means "not coloured yet".
            Modified in place on success.

    Returns:
        ``(True, coloring)`` if a swap was performed, ``(False, coloring)``
        otherwise.  The list returned is the same object that was passed in.
    """
    top = coloring[v]
    hits = [0] * (top + 1)      # hits[c]: how many coloured neighbours of v use c
    holder = [-1] * (top + 1)   # holder[c]: the last neighbour of v seen with c
    for nbr in G.neighbors(v):
        colour = coloring[nbr]
        if colour != 0:
            hits[colour] += 1
            holder[colour] = nbr

    for low in range(1, top):
        # Only a colour blocked by a single neighbour can be freed by one swap.
        if hits[low] != 1:
            continue
        blocker = holder[low]

        # Mark every colour already present around the blocking neighbour.
        taken = [False] * (top + 1)
        for nbr in G.neighbors(blocker):
            if coloring[nbr] != 0:
                taken[coloring[nbr]] = True

        # First colour in (low, top) the blocker could move to; 0 means none.
        replacement = next(
            (cand for cand in range(low + 1, top) if not taken[cand]), 0
        )
        if replacement != 0:
            coloring[v] = coloring[blocker]
            coloring[blocker] = replacement
            return True, coloring

    return False, coloring
  
            
def find_saturation_degree(G, coloring):
    """Compute, for every vertex, how many distinct colours its neighbours use.

    Args:
        G: graph exposing ``number_of_nodes()``, ``nodes()`` and
            ``neighbors(node)``; vertices are used as indices into the result.
        coloring: list indexed by vertex; ``0`` marks an uncoloured vertex and
            is not counted.

    Returns:
        A new list with one saturation value per vertex.
    """
    degrees = [0] * G.number_of_nodes()
    for vertex in G.nodes():
        seen = {coloring[adj] for adj in G.neighbors(vertex) if coloring[adj] != 0}
        degrees[vertex] = len(seen)
    return degrees

def update_saturation_degree(G, coloring, node, sat_degree):
    """Refresh the saturation degree of every neighbour of ``node``.

    Each neighbour's entry is recomputed from scratch as the number of
    distinct non-zero colours among its own neighbours.  Entries of vertices
    that are not adjacent to ``node`` are left as they are.

    Args:
        G: graph exposing ``neighbors(node)``.
        coloring: list indexed by vertex; ``0`` means uncoloured.
        node: the vertex whose neighbourhood should be refreshed.
        sat_degree: list indexed by vertex, updated in place.

    Returns:
        The same ``sat_degree`` list that was passed in.
    """
    for adjacent in G.neighbors(node):
        palette = {
            coloring[other]
            for other in G.neighbors(adjacent)
            if coloring[other] != 0
        }
        sat_degree[adjacent] = len(palette)
    return sat_degree

def ColorKernel(G_main, is_colored=False):
    """Greedily colour the (relabelled) kernel graph and return the colouring.

    The graph is first passed through ``node_mapping_kernel`` and the result
    through ``node_mapping``; the colouring returned is indexed by the vertex
    labels of the graph produced by ``node_mapping_kernel``
    (NOTE(review): presumably 0..n-1 - confirm against ``utils``).

    Two vertex orders are supported:

    * ``is_colored == False``: vertices are visited in descending order of the
      values returned by ``core_decomposition``.
    * otherwise: repeatedly pick, among the vertices not yet processed, one
      with the largest saturation degree, breaking ties with
      ``random.choice``.

    In both modes a vertex gets the smallest colour not used by its coloured
    neighbours; when that colour exceeds every colour accepted so far,
    ``recolor`` is given a chance to avoid opening the new colour class.

    Args:
        G_main: the graph to colour.
        is_colored: selects the vertex order as described above.

    Returns:
        A list of positive colour numbers, one per vertex.
    """
    G, _ = node_mapping_kernel(G_main)
    G_core, _ = node_mapping(G)
    total = G.number_of_nodes()
    coloring = [0] * total
    max_color_used = 0

    def smallest_free_color(vertex):
        # Lowest colour in 1..total that no coloured neighbour holds (1 if none found).
        taken = {coloring[u] for u in G.neighbors(vertex) if coloring[u] != 0}
        return next((c for c in range(1, total + 1) if c not in taken), 1)

    if is_colored == False:
        core_numbers = core_decomposition(G_core)
        visit_order = sorted(
            range(len(core_numbers)),
            key=lambda idx: core_numbers[idx],
            reverse=True,
        )
        for v in visit_order:
            candidate = smallest_free_color(v)
            coloring[v] = candidate
            if candidate > max_color_used:
                # A new colour class would open; try a local swap first.
                swapped, coloring = recolor(G, v, coloring)
                if not swapped:
                    max_color_used = candidate
        return coloring

    pending = list(G.nodes())
    saturation = [0] * total
    while pending:
        # Candidates are all pending vertices sharing the top saturation value,
        # kept in their order of appearance so random.choice sees the same list.
        peak = max(saturation[u] for u in pending)
        candidates = [u for u in pending if saturation[u] == peak]
        v = random.choice(candidates)
        pending.remove(v)

        candidate = smallest_free_color(v)
        coloring[v] = candidate
        if candidate > max_color_used:
            swapped, coloring = recolor(G, v, coloring)
            if not swapped:
                max_color_used = candidate
        saturation = update_saturation_degree(G, coloring, v, saturation)

    return coloring

def check_color(G, coloring):
    """Tell whether no edge of ``G`` joins two vertices with equal entries.

    Args:
        G: graph exposing ``edges()``; each edge is indexable as
            ``edge[0]``/``edge[1]``.
        coloring: list indexed by vertex.

    Returns:
        ``True`` when the two endpoints of every edge have different entries
        in ``coloring``, ``False`` as soon as one edge violates this.
    """
    return all(
        coloring[edge[0]] != coloring[edge[1]]
        for edge in G.edges()
    )


def construct_solution(G_k, G_m, e_m, alpha, lb_tracker, num_nodes):
    """Extend a colouring of the reduced graph to the vertices removed earlier.

    The reductions are undone in reverse order (latest first).  For each
    removed vertex set the stored vertices and edges are added back to a copy
    of ``G_k`` and the vertices are coloured:

    * if the colouring built so far uses fewer colours than the lower bound
      recorded for that reduction, every vertex of the set receives one new
      colour;
    * otherwise each vertex receives the smallest colour not used by its
      already-coloured neighbours.  If every existing colour is taken, a new
      colour is opened (the previous fallback of colour 1 could clash with a
      neighbour).

    This construction assumes ``alpha`` is the best solution for ``G_k``.

    Args:
        G_k: the reduced graph; only a copy is modified.
        G_m: list of removed vertex sets, oldest first.
        e_m: list of removed edge sets, aligned with ``G_m``.
        alpha: colours of the vertices of ``G_k``, in ``G_k.nodes()`` order.
        lb_tracker: lower bounds aligned with ``G_m``; extra trailing entries
            are ignored.
        num_nodes: number of vertices of the full graph (length of the result).

    Returns:
        ``alpha`` itself when ``G_m`` is empty, otherwise a new list of length
        ``num_nodes`` indexed by original vertex id.

    Raises:
        AssertionError: if ``G_m``, ``e_m`` and the (trimmed) ``lb_tracker``
            differ in length.
    """
    if len(G_m)==0: # no reduction happened, alpha already is the solution
        return alpha

    # Drop lower-bound entries that have no matching reduction (slicing
    # builds a new list, so the caller's tracker is untouched).
    bounds = lb_tracker
    if len(G_m)<len(bounds):
        bounds=bounds[:len(G_m)]

    assert len(G_m)==len(e_m)==len(bounds), f"length of G_m and e_m lb_tracker should be same ====> {len(G_m)} {len(e_m)} {len(bounds)}"

    # Only the graph is mutated below, so it is the only thing that is copied.
    G_k=G_k.copy()

    # Seed the result with the kernel colouring; alpha follows G_k.nodes() order.
    alpha_plus=[0]*num_nodes
    for position, node in enumerate(G_k.nodes()):
        alpha_plus[node]=alpha[position]

    # Undo the reductions from the most recent one back to the first.
    for node_set, edge_set, lower_bound_prev in zip(
        reversed(G_m), reversed(e_m), reversed(bounds)
    ):
        chromatic_num_current=max(alpha_plus)

        # Rebuild the graph as it was before this reduction.
        G_k.add_nodes_from(node_set)
        G_k.add_edges_from(edge_set)

        if chromatic_num_current<lower_bound_prev:
            # Room below the lower bound: give the whole set one new colour.
            new_color=chromatic_num_current+1
            for node in node_set:
                alpha_plus[node]=new_color
        else:
            for node in node_set:
                colors_used_by_neighbors={
                    alpha_plus[neighbor]
                    for neighbor in G_k.neighbors(node)
                    if alpha_plus[neighbor]!=0
                }
                # Smallest free existing colour; open a new one when all are
                # taken so the result never clashes with a neighbour.
                alpha_plus[node]=next(
                    (
                        color
                        for color in range(1, chromatic_num_current+1)
                        if color not in colors_used_by_neighbors
                    ),
                    chromatic_num_current+1,
                )

    return alpha_plus


def FastColor(G, time_limit=None):
    """Iteratively tighten lower and upper bounds on the colours needed for ``G``.

    Each round
      1. asks ``FindClq`` for a lower bound on the current reduced graph,
      2. asks ``FindLBIS`` for a vertex set to remove and shrinks the graph,
      3. colours the reduced graph with ``ColorKernel`` and extends that
         colouring to the full graph with ``construct_solution``,
      4. keeps the colouring if it is the first one or uses fewer colours
         than the best so far.

    The loop ends when the best colouring matches the best lower bound or
    when the time budget is exhausted.

    Args:
        G: the graph to colour; vertices index the colouring list
            (NOTE(review): presumably labelled 0..n-1 - confirm with caller).
        time_limit: budget in seconds; ``None`` uses the module-level
            ``cutoff_time``.

    Returns:
        A 5-tuple ``(status_string, colours_used, elapsed_seconds, is_proper,
        number_of_reductions)`` where ``status_string`` is
        ``'proved optimality'`` or ``'from outside while'`` and ``is_proper``
        is the result of ``check_color`` on the best colouring.
    """
    if time_limit is None:
        time_limit = cutoff_time

    n = G.number_of_nodes()
    G_k = G.copy()
    G_m = []                 # removed vertex sets, oldest first
    e_m = []                 # edges removed together with each vertex set
    lb_G = 0                 # best lower bound seen on any reduced graph
    lb_k = 0                 # lower bound of the current reduced graph
    lb_tracker = [0]         # one entry per reduction plus one for the current graph
    ub_G = n
    ub_tracker = [ub_G]
    alpha_best = [0] * n
    have_solution = False    # alpha_best is only a placeholder until this is True
    isColored = False
    t = 1
    time_start = time()

    while True:
        bms_param_adjustment_required, lb_k = FindClq(G_k, lb=lb_k, t=t)
        if bms_param_adjustment_required:
            t = 2 * t
        if t > 64:
            t = 1
        if lb_k > lb_G:
            lb_G = lb_k

        # Overwrite the entry of the current reduced graph with its latest bound.
        lb_tracker[-1] = lb_k

        I = FindLBIS(G_k, lb_k)
        kernel_before = G_k
        remaining_nodes = set(G_k.nodes()).difference(I)
        # The subgraph keeps the original vertex labels (no renumbering).
        G_k = G_k.subgraph(remaining_nodes)

        # A non-empty set means the graph was reduced: record what was removed
        # and restart the per-kernel state.
        if len(I) != 0:
            # Only edges of the pre-reduction kernel can disappear, so there is
            # no need to scan the full original graph on every round.
            surviving_edges = G_k.edges()
            removed_edges = {
                ed for ed in kernel_before.edges() if ed not in surviving_edges
            }
            lb_k = 0
            isColored = False
            G_m.append(I)
            e_m.append(removed_edges)
            # New slot for the lower bound of the freshly reduced graph.
            # NOTE(review): the upper bound is not touched at this point; the
            # original author flagged this as possibly problematic.
            lb_tracker.append(lb_k)

        alpha = ColorKernel(G_k, is_colored=isColored)
        alpha_plus = construct_solution(G_k, G_m, e_m, alpha, lb_tracker, n)
        isColored = True

        # Always accept the first colouring: with ub_G starting at n, a strict
        # comparison alone would reject a colouring that needs exactly n
        # colours and leave the all-zero placeholder as the "best" solution.
        colors_used = max(alpha_plus)
        if not have_solution or colors_used < ub_G:
            alpha_best = alpha_plus.copy()
            ub_G = colors_used
            ub_tracker.append(colors_used)
            have_solution = True

        time_elapsed = time() - time_start
        if ub_G == lb_G:
            status = check_color(G, alpha_best)
            print(ub_tracker, min(ub_tracker), lb_k,lb_G)
            return 'proved optimality', max(alpha_best), time_elapsed, status, len(G_m)

        if time_elapsed > time_limit:
            status = check_color(G, alpha_best)
            break

    print(ub_tracker, min(ub_tracker), lb_k,lb_G)
    return 'from outside while', max(alpha_best), time()-time_start, status, len(G_m)
            
            

def load_dimacs_col_file(filepath):
    """
    Build an undirected NetworkX graph from a DIMACS ``.col`` file.

    Lines are dispatched on their first character:

    * ``c`` - comment, skipped;
    * ``p`` - problem line; its third field is the vertex count and vertices
      ``0 .. count-1`` are created up front;
    * ``e`` - edge line; the two 1-based endpoints are shifted to 0-based
      before the edge is added.

    Any other line is ignored.
    """
    graph = nx.Graph()
    with open(filepath, 'r') as handle:
        for raw in handle:
            if raw.startswith('c'):
                continue
            if raw.startswith('p'):
                declared = int(raw.split()[2])
                # Create the vertices explicitly so isolated ones are kept.
                graph.add_nodes_from(range(declared))
            elif raw.startswith('e'):
                fields = raw.split()
                tail = int(fields[1]) - 1
                head = int(fields[2]) - 1
                graph.add_edge(tail, head)
    return graph



# ---------------------------------------------------------------------------
# Benchmark driver: run FastColor on every .col file of the data set and log
# one line per attempt to the result file.
# ---------------------------------------------------------------------------

directory_path='benchmark_dataset'

# NOTE(review): this table is replaced by the glob result further down and is
# therefore not used by the run; the meaning of the value tuples is not
# visible in this file. It is kept as reference data.
benchmarks={'ash331GPIA.col':(15,4), 'ash608GPIA.col':(15,4), 'ash958GPIA.col':(15,4) , '1-FullIns_3.col':(15,4), '1-FullIns_4.col':(15,5),  '1-FullIns_5.col':(15,6) , 
            '2-FullIns_3.col':(15,5), '2-FullIns_4.col':(15,6), '2-FullIns_5.col':(15,7) ,'3-FullIns_3.col':(15,6), '3-FullIns_4.col':(15,1), '4-FullIns_3.col':(15,7), 
            '4-FullIns_4.col':(15,8), '4-FullIns_5.col':(15,9), '5-FullIns_3.col':(15,1) , '1-Insertions_4.col': (15,4) , '1-Insertions_5.col': (15,1), 
            '1-Insertions_6.col': (15,7), '2-Insertions_3.col': (15,4), '2-Insertions_4.col': (15,1), '2-Insertions_5.col': (15,1), '3-Insertions_3.col': (15,4), 
            '3-Insertions_4.col': (15,1), '3-Insertions_5.col': (15,6), '4-Insertions_3.col': (15,1), '4-Insertions_4.col': (15,1),
            'le450_5a.col':(15,5),  'le450_5b.col':(15,5),  'le450_5c.col':(15,5), 'le450_5d.col':(15,5), 'mug88_1.col': (15,4) , 'mug88_25.col': (15,4), 
            'mug100_1.col': (15,4), 'mug100_25.col': (15,4), 'myciel3.col':(15,4), 'myciel4.col':(15,5), 'myciel5.col':(15,6), 'myciel6.col':(15,7), 'myciel7.col':(15,8), 
            'queen5_5.col':(15,5), 'queen6_6.col':(15,7), 'queen7_7.col':(15,7), 'queen8_8.col':(15,9), 'DSJC125.1.col':(15,5), 'DSJC250.1.col':(15,8),'will199GPIA.col': (15,1) }

num_graphs = len(benchmarks.keys())
graph_count=0
# NOTE(review): directory_path is reassigned here but the glob below reads
# from 'benchmark_dataset' - confirm which directory is intended.
directory_path='colbenchmarks'

# The actual work list: every .col file found in the benchmark directory.
benchmarks=glob.glob(os.path.join('benchmark_dataset', '*.col'))
num_graphs=len(benchmarks)
# benchmarks=glob.glob(os.path.join('colbenchmarks', '*.col'))

# The context manager guarantees the result file is closed even when a run
# raises part-way through the benchmark set.
with open('fastcolor_def_m_3_recol.txt', 'w') as result_file:
    for file_name in benchmarks:
        graph_count+=1
        nx_graph = load_dimacs_col_file(file_name)
        nx_graph,_=node_mapping_kernel(nx_graph)
        nodes=nx_graph.number_of_nodes()
        edges=nx_graph.number_of_edges()

        print(f'\nProcessing******* {graph_count}/{num_graphs} ****************{file_name}********{nodes}********{edges}*******************************************************')

        # Up to 50 attempts per graph; stop early once optimality is proved.
        for _ in range(50):
            status_string, sol, time_exec, correct_status, num_reductions=FastColor(nx_graph)
            result_file.write(f'{file_name} {status_string} {sol} {time_exec} {correct_status} {num_reductions}\n')
            print(f'{file_name} {status_string} {sol} {time_exec}')
            # Flush after every attempt so partial results survive an abort.
            result_file.flush()
            if status_string=='proved optimality':
                break

        # Blank line separates the attempts of consecutive graphs.
        result_file.write('\n')
    
