import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
import random
import math
from time import time
from utils import  node_mapping, node_mapping_kernel, core_decomposition, FindLBIS, FindClq
import os
from tqdm import tqdm
import random
import glob
import statistics

cutoff_time=60


def recolor(G,v,coloring):
    new_color=coloring[v]
    conflicts=[0]*(new_color+1)
    used=[-1]*(new_color+1) 
    for w in G.neighbors(v):
        if coloring[w]!=0:
            conflicts[coloring[w]]+=1
            used[coloring[w]]=w 
    
    for i in range(1,new_color):
        if conflicts[i]==1:
            w=used[i]
            for u in G.neighbors(w):
                if coloring[u]!=0:
                    used[coloring[u]]=w     
            c=0 
            for j in range(i+1,new_color):
                if (used[j]!=w):
                    c=j
                    break
                
            if c!=0:
                coloring[v]=coloring[w]
                coloring[w]=c
                return True, coloring

    return False, coloring

            
def find_saturation_degree(G, coloring):
    sat_degree = [0] * G.number_of_nodes()
    for node in G.nodes():
        color_neighbors=set()
        for neighbor in G.neighbors(node):
            if coloring[neighbor] != 0:
                color_neighbors.add(coloring[neighbor])
        sat_degree[node] = len(color_neighbors)
    return sat_degree

def update_saturation_degree(G, coloring, node, sat_degree):
    # Update the saturation degree of the neighbors of the given nodemax_color_used
    for neighbor in G.neighbors(node):
        color_used=set()
        for nodes in G.neighbors(neighbor):
            if coloring[nodes] != 0:
                color_used.add(coloring[nodes])
        sat_degree[neighbor] = len(color_used)  
    return sat_degree

def ColorKernel(G_main, is_colored=False):
    G,_=node_mapping_kernel(G_main)
    G_core,_=node_mapping(G)
    coloring= [0]*G.number_of_nodes()
    max_color_used=0
    #G_core,_=node_mapping(G)
    if is_colored==False:
        cores=core_decomposition(G_core)
        sorted_cores = sorted(range(len(cores)), key=lambda i: cores[i], reverse=True)
        for v in sorted_cores:
            #find the value i>0 such that no neighbor of u has taken color i 
            ct=1
            neighbor_coloring=set()
            for u in G.neighbors(v):
                if coloring[u]!=0:
                    neighbor_coloring.add(coloring[u])
            for i in range(1, G.number_of_nodes()+1):
                if i not in neighbor_coloring:
                    ct=i
                    break
                
            coloring[v]=ct
            if ct>max_color_used:
                status,coloring=recolor(G,v,coloring)
                if status==False:
                    max_color_used=ct
        
    else:
        V_set=list(G.nodes())
        saturation_degree=[0]*G.number_of_nodes()
        while len(V_set)>0:
            #saturation_degree=find_saturation_degree(G, coloring)
            max_saturaion_deg=0
            max_indices=[]
            for i in V_set:
                if saturation_degree[i]>max_saturaion_deg:
                    max_saturaion_deg=saturation_degree[i]
                    max_indices=[]
                    max_indices.append(i)
                elif saturation_degree[i]==max_saturaion_deg:
                    max_indices.append(i)
                     
            v=random.choice(max_indices)
            V_set.remove(v)

            ct=1
            neighbor_coloring=set()
            for u in G.neighbors(v):
                if coloring[u]!=0:
                    neighbor_coloring.add(coloring[u])
            for i in range(1, G.number_of_nodes()+1):
                if i not in neighbor_coloring:
                    ct=i
                    break
                
            coloring[v]=ct   
            if ct>max_color_used:
                status,coloring=recolor(G,v,coloring)
                if status==False:
                    max_color_used=ct
            saturation_degree=update_saturation_degree(G,coloring,v,saturation_degree)
                    
    return coloring

# def check_color(G, coloring):
#     for i in G.edges():
#         if coloring[i[0]]==coloring[i[1]]:
#             return False
#     return True


def construct_solution(G_k, G_m, e_m, alpha, lb_tracker, num_nodes):
    #this constructrion takes assumption that alpha is the best solution for G_k
    if len(G_m)==0: #if there is no reduction alpha will be the constructed solution
        return alpha
    
    alpha_plus=[0]*num_nodes #initializing the solution
    #making copies sot that the original values are not changed
    alpha=alpha.copy()
    lb_tracker=lb_tracker.copy()
    G_k=G_k.copy()
    G_m=G_m.copy()
    e_m=e_m.copy()
    
    #reoving the last entry of the lower bound tracker
    if len(G_m)<len(lb_tracker):
        lb_tracker=lb_tracker[:len(G_m)]
    
    assert len(G_m)==len(e_m)==len(lb_tracker), f"length of G_m and e_m lb_tracker should be same ====> {len(G_m)} {len(e_m)} {len(lb_tracker)}"
    
    for i in G_k.nodes():
        alpha_plus[i]=alpha.pop(0)
        
    for _ in range(len(G_m)):
        
        node_set = G_m.pop()
        edge_set = e_m.pop()
        lower_bound_prev=lb_tracker.pop()
        
        chromatic_num_current=max(alpha_plus)
        
        #construction of graph
        G_k.add_nodes_from(node_set)
        G_k.add_edges_from(edge_set)
        
        #assert chromatic_num_current<=lower_bound_prev, f"chromatic number of current graph should be less than or equal to the lower bound of previous graph"
        
        if chromatic_num_current<lower_bound_prev:
            new_color=max(alpha_plus)+1
            for node in node_set:
                alpha_plus[node]=new_color
                
        elif chromatic_num_current>=lower_bound_prev:
            for node in node_set:
                color_ued_by_neighbors=set()
                for neighbor in G_k.neighbors(node):
                    if alpha_plus[neighbor]!=0:
                        color_ued_by_neighbors.add(alpha_plus[neighbor])
            
                c=1
                
                for color_iter in range(1, chromatic_num_current+1):
                    if color_iter not in color_ued_by_neighbors:
                        c=color_iter
                        break
                alpha_plus[node]=c
        
    return alpha_plus


def FastColor(G):
    
    G_k=G.copy()
    G_m=[]
    e_m=[]
    lb_G=0
    lb_tracker=[]
    ub_G=G.number_of_nodes()
    lb_tracker.append(0)
    alpha_best=[0]*G.number_of_nodes()
    lb_k=0
    isColored=False
    t=1
    time_start=time()
    ub_tracker=[]
    ub_tracker.append(ub_G)
    while 1:
        bms_param_adjustment_required,lb_k=FindClq(G_k, lb=lb_k, t=t)
        if bms_param_adjustment_required:
            t=2*t
        if t>64:
            t=1
        if lb_k>lb_G:
            lb_G=lb_k
            
        lb_tracker[-1]=lb_k #we are overriding the current reduced graph's lower bound, this can be a debug point
            
        I=FindLBIS(G_k, lb_k)
        # print(I)
        # print(len(I))
        #a=input()
        edges_before_reduction=G.edges()
        removed_edges=set() #keeps track for removed edges in this iteration
        remaining_nodes=set(G_k.nodes()).difference(I)
        G_k=G_k.subgraph(remaining_nodes) #this is new Gk but the nodes will not be renumbered, if you remove 0 then there is no 9 in the node set
        #print(G_k.number_of_nodes(), G_k.number_of_edges())
        #a=input()
        for ed in edges_before_reduction:
            if ed not in G_k.edges():
                removed_edges.add(ed)
                
        #if there is an independent set independednt set, then we need to reset some values as the graph is reduces 
        if len(I)!=0:
            lb_k=0
            isColored=False
            G_m.append(I)
            e_m.append(removed_edges)
            lb_tracker.append(lb_k) #adding a new entry for the lower bound of reduced graph #reduce kia to pehle waala daalna padega, yaha tak upper bound change nahi kia hai, may be problematic
        
        alpha=ColorKernel(G_k, is_colored=isColored) #this is problematic as it's lebgth is dynamic
        alpha_plus=construct_solution(G_k, G_m, e_m, alpha,lb_tracker,G.number_of_nodes()) #this construction is problematic
        isColored=True
       
        if max(alpha_plus)<ub_G:  
            alpha_best=alpha_plus.copy()
            ub_G=max(alpha_plus)
            ub_tracker.append(max(alpha_plus))
            
            
        time_elapsed=time()-time_start
        if ub_G==lb_G: #or lb_k+len(G_m)==ub_G: # the second condition will work for triavial case in which length of gm is 1 (construction based method needs to be implemented)
            return max(alpha_best), time_elapsed
        
        if time_elapsed>cutoff_time:
            break
        
    return max(alpha_best), time()-time_start
            
            

def load_dimacs_col_file(filepath):
    """
    Reads a .col file in DIMACS format and returns a NetworkX graph.
    """
    G = nx.Graph()
    with open(filepath, 'r') as f:
        for line in f:
            if line.startswith('c'):
                continue  # comment
            elif line.startswith('p'):
                parts = line.strip().split()
                num_nodes = int(parts[2])
                # optionally add nodes explicitly
                G.add_nodes_from(range(num_nodes))
            elif line.startswith('e'):
                #print('count', count)
                parts = line.strip().split()
                u = int(parts[1])-1
                v = int(parts[2])-1
                G.add_edge(u, v)
    return G


result_file= open('fastcolor_ci.txt', 'w')
graph_count=0
benchmarks=glob.glob(os.path.join('benchmarks', '*.col'))
#benchmarks=glob.glob(os.path.join('benchmarks', '*karate.col'))
num_graphs=len(benchmarks)
for file_name in benchmarks:
    graph_count+=1
    nx_graph = load_dimacs_col_file(file_name)
    nx_graph,_=node_mapping_kernel(nx_graph)
    nodes=nx_graph.number_of_nodes()
    edges=nx_graph.number_of_edges()

    print(f'\nProcessing******* {graph_count}/{num_graphs} ****************{file_name}********{nodes}********{edges}*******************************************************')
    solution=[]
    timee=[]
    for _ in range(50):
        sol, time_exec=FastColor(nx_graph)
        solution.append(sol)
        timee.append(time_exec)
    
    best_solution=min(solution)
    best_time_index=solution.index(best_solution)
    best_time=timee[best_time_index]
    mean_sol=statistics.mean(solution)
    stdev_sol=statistics.stdev(solution)
    mean_time=statistics.mean(timee)
    stdev_time=statistics.stdev(timee)
    result_file.write(f'{file_name.split("/")[-1]} {nodes} {edges} {best_solution} {best_time} {mean_sol} {stdev_sol} {mean_time} {stdev_time}\n')
    result_file.flush()
