import numpy as np
import networkx as nx
from copy import deepcopy
from time import time
from tqdm import tqdm
import random
import torch
from torch.utils.data import DataLoader
import dgl
from cb.ppo.actor_critic import ActorCritic as ActorCriticCb
from cb.ppo.graph_net import PolicyGraphConvNet as PolicyGraphConvNetCb
from cb.ppo.graph_net import ValueGraphConvNet as ValueGraphConvNetCb
from cb.env import VCP
from mis.ppo.actor_critic import ActorCritic as ActorCriticMis
from mis.ppo.graph_net import PolicyGraphConvNet as PolicyGraphConvNetMis
from mis.ppo.graph_net import ValueGraphConvNet as ValueGraphConvNetMis
from mis.env import MaximumIndependentSetEnv
from pulp import *
import pulp
from random import randint
from itertools import combinations, chain
from Coloring_networkx_addons import ThinGraph, is_coloring_feasible
from timeit import default_timer
import os
from argparse import ArgumentParser
import statistics
import scipy.stats as stats
import glob

# Initialize the argument parser.
parser = ArgumentParser()
# --mode selects which solver `compare` benchmarks:
#   1 -> RL coloring agent, 2 -> repeated-MIS coloring ("vcolmis").
# The default has to be a valid int itself: argparse runs string defaults
# through `type`, so the former default of 'cpu' made the script abort with
# "invalid int value: 'cpu'" whenever --mode was omitted.
parser.add_argument(
    '--mode',
    type=int,
    choices=(1, 2),
    default=1,
    help='1: RL coloring agent, 2: MIS-based coloring (default: %(default)s)',
)

args = parser.parse_args()

# Device string handed to the environments and networks built below.
device = 'cpu'

# --- Environment ---
hamming_reward_coef = 0.1  # forwarded to the MIS environment

# --- Actor-critic networks (same sizes for the MIS and coloring models) ---
num_layers = 4
hidden_dim = 128

# --- Optimization ("hp" presumably marks tunable hyperparameters) ---
max_epi_t = 128  # hp: episode horizon given to the MIS environment
episode_length = 128  # hp: episode horizon given to the coloring environment
num_parallel_graph = 100
sample_mis = 100  # hp

# --- Dataset specific ---
min_num_nodes = 50
max_num_nodes = 100

# Number of colors given to the coloring environment and network.
num_colors = 15

# Construct everything for MIS.
env_mis = MaximumIndependentSetEnv(
    max_epi_t=max_epi_t,
    max_num_nodes=max_num_nodes,
    hamming_reward_coef=hamming_reward_coef,
    device=device,
)

# Construct the MIS actor-critic network.
actor_critic_mis = ActorCriticMis(
    actor_class=PolicyGraphConvNetMis,
    critic_class=ValueGraphConvNetMis,
    max_num_nodes=max_num_nodes,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
    device=device,
)

# Load the saved state dictionary. map_location remaps every stored tensor
# onto `device`; without it a checkpoint that was saved from a GPU run cannot
# be deserialized on a machine without CUDA. This also matches how the
# coloring checkpoint is loaded further down.
model_path_mis = './model_mis.pth'
state_dict_mis = torch.load(model_path_mis, map_location=torch.device(device))

# Load the state dictionary into the model.
actor_critic_mis.load_state_dict(state_dict_mis)

# Construct everything for the coloring ("cb") model.
env_cb = VCP(
    max_epi_t=episode_length,
    device=device,
    num_colors=num_colors,
)

# Actor-critic network for the coloring model.
actor_critic_cb = ActorCriticCb(
    actor_class=PolicyGraphConvNetCb,
    critic_class=ValueGraphConvNetCb,
    max_num_nodes=max_num_nodes,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
    num_colors=num_colors,
    device=device,
)

# Read the checkpoint with all tensors mapped to the CPU, then load it into
# the network.
model_path_cb = './model_vcolrl.pth'
state_dict_cb = torch.load(model_path_cb, map_location=torch.device('cpu'))
actor_critic_cb.load_state_dict(state_dict_cb)

# Results are written to one text file per mode; mode 'w' truncates the file
# left by any previous run with the same mode.
filepath = f'./t2_ci_{args.mode}.txt'
result_file = open(filepath, 'w')


def compute_stats(data, confidence=0.95):
    """Summarize a sample: mean, spread and confidence-interval margins.

    Args:
        data: Sequence of numbers; must hold at least two values.
        confidence: Two-sided confidence level used for both margins.

    Returns:
        Tuple ``(mean, std_dev, std_err, t_margin, z_margin)``. ``std_dev`` is
        the sample standard deviation, ``std_err`` is ``std_dev / sqrt(n)``,
        and the two margins are the half-widths of the Student-t interval
        (``n - 1`` degrees of freedom) and of the normal interval.

    Raises:
        ValueError: If ``data`` has fewer than two values.
    """
    sample_size = len(data)
    if sample_size < 2:
        raise ValueError("At least two data points are required.")

    center = statistics.mean(data)
    spread = statistics.stdev(data)
    standard_error = spread / (sample_size ** 0.5)

    # Both intervals are two-sided, so they share the same upper quantile.
    upper_quantile = (1 + confidence) / 2

    # Student-t margin (appropriate for small samples).
    t_half_width = stats.t.ppf(upper_quantile, df=sample_size - 1) * standard_error
    # Normal margin (appropriate for large samples).
    z_half_width = stats.norm.ppf(upper_quantile) * standard_error

    return center, spread, standard_error, t_half_width, z_half_width

# define evaluate function for mis
def evaluate_mis(g, actor_critic):
    """Roll out `actor_critic` on `g` in the MIS environment until it is done.

    The graph is registered with the module-level `env_mis` using a single
    sample, and the policy is stepped (without gradients) until every entry of
    the environment's `done` flag is set.

    Args:
        g: DGL graph to solve; its node initializer is set to zeros here.
        actor_critic: Policy exposing `eval()` and `act(ob, g)`.

    Returns:
        Tuple ``(solution_total, node_state)``: the per-row maximum of
        ``info['sol']`` along dim 1, summed and moved to the CPU, and index 0
        of dim 2 of the final observation, flattened.
        # NOTE(review): presumably one state entry per node, with 1 marking
        # nodes chosen for the independent set -- confirm against mis.env.
    """
    actor_critic.eval()
    g.set_n_initializer(dgl.init.zero_initializer)
    ob = env_mis.register(g, num_samples=1)

    finished = False
    while not finished:
        with torch.no_grad():
            action = actor_critic.act(ob, g)
        ob, reward, done, info = env_mis.step(action)
        finished = torch.all(done).item()

    # Score only the terminal step, as reported by the environment.
    solution_total = 0.0 + info['sol'].max(dim=1)[0].sum().cpu()
    node_state = ob.select(2, 0)
    return solution_total, node_state.flatten()



def colors_used_mis(graph):
    """Count the colors used when coloring `graph` by repeated MIS extraction.

    Each round runs `evaluate_mis` on what is left of the graph. If the round
    marks any node with state 1, those nodes share one new color and the next
    round continues on the subgraph of nodes whose state is 0. If a round
    marks no node, every state-0 node is charged its own color and the loop
    ends.

    Args:
        graph: DGL graph to color.

    Returns:
        Total number of colors charged.
    """
    colors = 0
    remaining = graph
    while remaining.number_of_nodes() != 0:
        _, ob = evaluate_mis(remaining, actor_critic_mis)
        unassigned = ob == 0

        if not torch.any(ob == 1):
            # The policy selected nothing: fall back to one color per node.
            colors += torch.sum(unassigned).item()
            break

        colors += 1
        remaining = remaining.subgraph(unassigned)

    return colors

def evaluate_cb(graph, actor_critic):
    """Run one coloring episode of `actor_critic` on `graph`.

    The graph is registered with the module-level `env_cb` and the policy is
    stepped (without gradients) until every entry of the environment's `done`
    flag is set.

    Args:
        graph: DGL graph to color; its node initializer is set to zeros here.
        actor_critic: Policy exposing `eval()` and `act(ob, graph)`.

    Returns:
        Tuple ``(satisfied, state, num_colors_used, solution_time)``:
        percentage of nodes whose final state is non-zero (rounded to two
        decimals), the final per-node state as a tensor, the number of
        distinct non-zero state values, and the elapsed wall-clock seconds.
    """
    actor_critic.eval()
    num_nodes = graph.number_of_nodes()

    # The context manager closes the bar when the episode ends; this function
    # is called many times per benchmark and previously left every bar open.
    with tqdm(total=episode_length, desc="Validation Progress", unit="iteration") as progress_bar:
        t1 = time()
        graph.set_n_initializer(dgl.init.zero_initializer)
        ob = env_cb.register(episode_length, graph, 1)
        ob = ob.to(device)
        while True:
            with torch.no_grad():
                action = actor_critic.act(ob, graph).to(device)
            # Actions are shifted by one before stepping; a state of 0 is
            # counted below as "not colored".
            ob, _, done = env_cb.step(action + 1)
            progress_bar.update(1)
            if torch.all(done).item():
                break

    # flatten() (not squeeze()) keeps a 1-D list even for a single-node graph,
    # where squeeze() collapses to a scalar and tolist() returns a bare int
    # that has no .count(). Assumes one sample per node, as registered above
    # -- TODO confirm against cb.env.
    state = ob.select(2, 0).int().flatten().tolist()
    satisfied = round(100 - state.count(0) * 100 / num_nodes, 2)
    my_soln = len({x for x in state if x != 0})

    solution_time = time() - t1
    return satisfied, torch.tensor(state), my_soln, solution_time

    


def load_dimacs_col_file(file_path):
    """Load a DIMACS ``.col`` file into a zero-indexed ``networkx.Graph``.

    Edge lines (``e u v``) carry 1-based vertex labels, which are shifted to
    0-based. The vertex count declared on the problem line
    (``p <format> N M``) is honored, so isolated vertices numbered above the
    largest edge endpoint are kept instead of being silently dropped. Any
    other line (comments, blanks) is ignored, and extra tokens after the two
    endpoints of an edge line are tolerated.

    Args:
        file_path: Path of the ``.col`` file.

    Returns:
        Graph with nodes ``0 .. n-1`` where ``n`` is the larger of the
        declared vertex count and the largest endpoint seen. As before, a
        file without any edge or problem line still yields the single node 0.

    Raises:
        ValueError: If an edge line has fewer than two endpoints or an
            endpoint is not an integer.
    """
    G = nx.Graph()
    max_node = 0
    declared_nodes = 0
    with open(file_path, 'r') as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue
            if fields[0] == 'e':  # 'e' lines represent edges
                if len(fields) < 3:
                    raise ValueError(
                        f'Malformed edge line in {file_path!r}: {line.strip()!r}'
                    )
                # Subtract 1 from node labels to make them zero-indexed.
                node1, node2 = int(fields[1]) - 1, int(fields[2]) - 1
                G.add_edge(node1, node2)
                max_node = max(node1, node2, max_node)
            elif fields[0] == 'p' and len(fields) >= 3 and fields[2].isdigit():
                declared_nodes = int(fields[2])

    # Adding an existing node is a no-op, so this only fills in the isolated
    # vertices (in ascending order, after the nodes introduced by the edges).
    G.add_nodes_from(range(max(declared_nodes, max_node + 1)))
    return G



def _run_rl_agent(dgl_graph_main, file_name, num_trials, max_rounds=10):
    """Benchmark the RL coloring agent on one graph and log its statistics."""
    print('RL agent.......')
    color = []
    timee = []
    for i in range(num_trials):
        # Every trial starts from a fresh copy of the graph.
        dgl_graph = deepcopy(dgl_graph_main)
        agent_time_begin = time()
        soln_cb_sat, ob_best, soln_cb_colors, _ = evaluate_cb(dgl_graph, actor_critic_cb)

        # Repair loop: re-run the agent on the nodes still at state 0 and add
        # the colors it uses there, for at most `max_rounds` agent runs.
        # NOTE(review): if the rounds run out while soln_cb_sat < 100, the
        # nodes still uncolored are not charged any color -- confirm intended.
        rounds = 1
        while soln_cb_sat < 100 and rounds < max_rounds:
            rounds += 1
            dgl_graph = dgl_graph.subgraph(ob_best == 0)
            if dgl_graph.num_nodes() == 1:
                # A single leftover node just takes one more color.
                soln_cb_colors += 1
                break
            soln_cb_sat, ob_best, extra_colors, _ = evaluate_cb(dgl_graph, actor_critic_cb)
            soln_cb_colors += extra_colors

        agent_time = time() - agent_time_begin
        color.append(soln_cb_colors)
        timee.append(agent_time)
        print(f'RL agent done with {i+1} trial with {soln_cb_colors} colors')

    best = min(color)
    result_file.write(f'{file_name}\n')
    # Best color count together with the time of the first trial reaching it.
    result_file.write(f'{best} {timee[color.index(best)]}\n')
    mean, std_dev, std_err, t_margin, z_margin = compute_stats(color)
    result_file.write(f'{mean} {std_dev} {std_err} {t_margin} {z_margin}\n')
    mean, std_dev, std_err, t_margin, z_margin = compute_stats(timee)
    result_file.write(f'{mean} {std_dev} {std_err} {t_margin} {z_margin}\n\n')
    result_file.flush()
    print('RL agent done')


def _run_vcolmis(dgl_graph_main, file_name, num_trials):
    """Benchmark the MIS-based coloring on one graph and log its statistics."""
    print('vcolmis.......')
    color = []
    timee = []
    for _ in range(num_trials):
        start = time()
        soln_mis = colors_used_mis(dgl_graph_main)
        timee.append(time() - start)
        color.append(soln_mis)

    mean, std_dev, _, _, _ = compute_stats(color)
    mean_t, std_dev_t, _, _, _ = compute_stats(timee)
    best = min(color)
    result_file.write(
        f'{file_name} {best} {timee[color.index(best)]} {mean} {std_dev} {mean_t} {std_dev_t}\n'
    )
    result_file.flush()
    print('vcolmis completed')


def compare(graph, file_name, tabucol_init=15, num_trials=100):
    """Benchmark the solver selected by ``args.mode`` on one graph.

    Results are appended to the module-level `result_file`:
    mode 1 (RL agent) writes the file name, the best color count with its
    time, and the statistics of color counts and times on separate lines;
    mode 2 (vcolmis) writes a single summary line.

    Args:
        graph: networkx graph to color.
        file_name: Label written to the result file.
        tabucol_init: Unused; kept so existing callers keep working.
        num_trials: Number of repeated runs; must be at least 2 because the
            statistics need two data points.

    Returns:
        0 on completion.

    Raises:
        ValueError: If ``args.mode`` is neither 1 nor 2 (previously this was
            silently ignored and nothing was written).
    """
    dgl_graph_main = dgl.from_networkx(graph)

    if args.mode == 1:
        _run_rl_agent(dgl_graph_main, file_name, num_trials)
    elif args.mode == 2:
        _run_vcolmis(dgl_graph_main, file_name, num_trials)
    else:
        raise ValueError(
            f'Unsupported mode {args.mode!r}; expected 1 (RL agent) or 2 (vcolmis).'
        )
    return 0
            


    #logging to file

    

# Directory that holds the DIMACS benchmark files.
directory_path = 'benchmarks'

# Hand-picked benchmark sets, keyed by file name. Each assignment replaces the
# previous one, so only the last binding of `benchmarks` is in effect; the
# driver further down rebinds the name once more to a glob result.
# NOTE(review): the tuple is presumably (initial number of colors, known or
# target number of colors), with 1 as a placeholder -- confirm with the author.

# Full set.
benchmarks = {
    'ash331GPIA.col': (15, 4),
    'ash608GPIA.col': (15, 4),
    'ash958GPIA.col': (15, 4),
    '1-FullIns_3.col': (15, 4),
    '1-FullIns_4.col': (15, 5),
    '1-FullIns_5.col': (15, 6),
    '2-FullIns_3.col': (15, 5),
    '2-FullIns_4.col': (15, 6),
    '2-FullIns_5.col': (15, 7),
    '3-FullIns_3.col': (15, 6),
    '3-FullIns_4.col': (15, 1),
    '4-FullIns_3.col': (15, 7),
    '4-FullIns_4.col': (15, 8),
    '4-FullIns_5.col': (15, 9),
    '5-FullIns_3.col': (15, 1),
    '1-Insertions_4.col': (15, 4),
    '1-Insertions_5.col': (15, 1),
    '1-Insertions_6.col': (15, 7),
    '2-Insertions_3.col': (15, 4),
    '2-Insertions_4.col': (15, 1),
    '2-Insertions_5.col': (15, 1),
    '3-Insertions_3.col': (15, 4),
    '3-Insertions_4.col': (15, 1),
    '3-Insertions_5.col': (15, 6),
    '4-Insertions_3.col': (15, 1),
    '4-Insertions_4.col': (15, 1),
    'le450_5a.col': (15, 5),
    'le450_5b.col': (15, 5),
    'le450_5c.col': (15, 5),
    'le450_5d.col': (15, 5),
    'mug88_1.col': (15, 4),
    'mug88_25.col': (15, 4),
    'mug100_1.col': (15, 4),
    'mug100_25.col': (15, 4),
    'myciel3.col': (15, 4),
    'myciel4.col': (15, 5),
    'myciel5.col': (15, 6),
    'myciel6.col': (15, 7),
    'myciel7.col': (15, 8),
    'queen5_5.col': (15, 5),
    'queen6_6.col': (15, 7),
    'queen7_7.col': (15, 7),
    'queen8_8.col': (15, 9),
    'DSJC125.1.col': (15, 5),
    'DSJC250.1.col': (15, 8),
    'will199GPIA.col': (15, 1),
}

# benchmarks={'tech-p2p-gnutella.mtx':(15,4)}

# Subset of the full set.
benchmarks = {
    'ash331GPIA.col': (15, 4),
    'ash608GPIA.col': (15, 4),
    'ash958GPIA.col': (15, 4),
    '2-FullIns_5.col': (15, 7),
    '4-FullIns_5.col': (15, 9),
    '3-Insertions_5.col': (15, 6),
    'le450_5a.col': (15, 5),
    'le450_5b.col': (15, 5),
    'le450_5c.col': (15, 5),
    'le450_5d.col': (15, 5),
    'DSJC250.1.col': (15, 8),
}

# The remaining entries of the full set.
benchmarks = {
    '1-FullIns_3.col': (15, 4),
    '1-FullIns_4.col': (15, 5),
    '1-FullIns_5.col': (15, 6),
    '2-FullIns_3.col': (15, 5),
    '2-FullIns_4.col': (15, 6),
    '3-FullIns_3.col': (15, 6),
    '3-FullIns_4.col': (15, 1),
    '4-FullIns_3.col': (15, 7),
    '4-FullIns_4.col': (15, 8),
    '5-FullIns_3.col': (15, 1),
    '1-Insertions_4.col': (15, 4),
    '1-Insertions_5.col': (15, 1),
    '1-Insertions_6.col': (15, 7),
    '2-Insertions_3.col': (15, 4),
    '2-Insertions_4.col': (15, 1),
    '2-Insertions_5.col': (15, 1),
    '3-Insertions_3.col': (15, 4),
    '3-Insertions_4.col': (15, 1),
    '4-Insertions_3.col': (15, 1),
    '4-Insertions_4.col': (15, 1),
    'mug88_1.col': (15, 4),
    'mug88_25.col': (15, 4),
    'mug100_1.col': (15, 4),
    'mug100_25.col': (15, 4),
    'myciel3.col': (15, 4),
    'myciel4.col': (15, 5),
    'myciel5.col': (15, 6),
    'myciel6.col': (15, 7),
    'myciel7.col': (15, 8),
    'queen5_5.col': (15, 5),
    'queen6_6.col': (15, 7),
    'queen7_7.col': (15, 7),
    'queen8_8.col': (15, 9),
    'DSJC125.1.col': (15, 5),
    'will199GPIA.col': (15, 1),
}

# Benchmark every DIMACS .col file found in the benchmark directory. The list
# is sorted because glob returns files in arbitrary order, which would make
# the order of the result file differ between runs and machines.
benchmarks = sorted(glob.glob(os.path.join(directory_path, '*.col')))
print(benchmarks)
num_graphs = len(benchmarks)
try:
    for graph_count, file_name in enumerate(benchmarks, start=1):
        print(f'\nProcessing********************{graph_count}/{num_graphs} ****************{file_name}*********************************************************************************************')
        # Loading and benchmarking both sit inside the try so that one bad
        # file is reported and skipped instead of aborting the whole run
        # (the call used to sit outside the try, leaving the handler dead).
        try:
            nx_graph = load_dimacs_col_file(file_name)
            nodes = nx_graph.number_of_nodes()
            edges = nx_graph.number_of_edges()
            print(nodes, edges)
            # basename (rather than splitting on '/') also works with
            # Windows path separators.
            compare(nx_graph, os.path.basename(file_name), 15)
        except Exception as exc:
            print(f"Some error occured while processing {file_name}: {exc!r}")
finally:
    # Close the result file even when the run is interrupted.
    result_file.close()
