import pandas as pd
import numpy as np
import networkx as nx
import pickle
import logging
logging.basicConfig(level=logging.DEBUG)
import argparse

import our_gloabls
import get_delays
import util
import score_all_edges
import model
import dataloader as dl
import find_topological_order

def introduce_cycle(edges, edge_cand):
    """Report whether adding ``edge_cand`` makes a cycle reachable from its cause.

    Args:
        edges: 2-D array of the current edges, indexed ``[cause, effect]``;
            an entry is ``None`` where no edge exists.
        edge_cand: Candidate exposing ``cause`` and ``effect`` node indices.

    Returns:
        ``True`` if, once the candidate edge is inserted, networkx finds a
        directed cycle when searching from ``edge_cand.cause``; ``False``
        otherwise.
    """
    # Collapse the edge objects into a 0/1 adjacency matrix for networkx.
    # ``is None`` avoids going through a custom ``__eq__`` on edge objects, and
    # ``otypes`` fixes the output dtype up front so an empty matrix is accepted.
    adjacency = np.vectorize(
        lambda entry: 0 if entry is None else 1, otypes=[int]
    )(edges)
    graph = nx.DiGraph(adjacency)
    graph.add_edge(edge_cand.cause, edge_cand.effect)
    try:
        # Only the existence of a cycle matters, not the cycle itself.
        nx.find_cycle(graph, edge_cand.cause, orientation='original')
    except nx.NetworkXNoCycle:
        return False
    return True
    

def set_causal_prior_edges(causal_prior, delays, candidates, tested_candidates):
    """Force every untested prior edge to the front of the candidate ranking.

    For each ``(cause, effect)`` position where ``causal_prior`` equals 1
    (diagonal entries included), the matching row of ``delays`` is looked up
    and, unless that candidate is already flagged in ``tested_candidates``,
    its ``gain_est`` is overwritten with ``-inf`` so that an ``argmin`` over
    the candidates selects it first.

    Args:
        causal_prior: 2-D array; entries equal to 1 mark prior edges.
        delays: DataFrame with ``cause`` and ``effect`` columns. Its index
            labels are used to address ``candidates`` and
            ``tested_candidates``.
        candidates: Indexable collection of objects with a ``gain_est``
            attribute; modified in place.
        tested_candidates: Boolean flags, one per candidate.

    Raises:
        IndexError: If a prior edge has no NaN-free row in ``delays``.
    """
    prior_hits = np.where(causal_prior == 1)
    for src, dst in zip(prior_hits[0], prior_hits[1]):
        pair_mask = (delays['cause'] == src) & (delays['effect'] == dst)
        # Exactly one row is expected per (cause, effect) pair.
        row_label = delays[pair_mask].dropna().index[0]
        if tested_candidates[row_label]:
            continue
        # -inf acts as a priority marker rather than a real gain estimate.
        candidates[row_label].gain_est = -np.inf


def search(alarms_df:pd.DataFrame, topology_matrix, causal_prior, no_topo=False):
    """Greedily add edges to a fresh model, lowest estimated gain first.

    All cause/effect pairs are scored once. The loop then repeatedly takes the
    candidate with the smallest ``gain_est`` that is still admissible, tests
    it on the model and rescores the candidates that share its effect. The
    search ends as soon as the selected candidate no longer has a negative
    ``gain_est``.

    A candidate is admissible when the model has no edge for that pair yet, it
    has not been tested, adding it would not create a cycle, and (unless
    ``no_topo`` is set) its effect comes after its cause in the topological
    order derived from the initial candidates.

    Args:
        alarms_df: Alarm records; assumed sorted by start time so that the
            delay computation does not have to sort.
        topology_matrix: Device topology, forwarded to the graph builder and
            to ``get_delays.get_event_delays``.
        causal_prior: Square matrix over alarm types. Entries equal to 1 are
            tried first; off-diagonal entries equal to 0 are never tested.
        no_topo: If true, the topological-order constraint is skipped.

    Returns:
        The model after the search has finished.
    """
    base_cost = util.get_base_alarm_cost(alarms_df)
    n_alarms = causal_prior.shape[0]

    G, device_dict = dl.all_data_graph(alarms_df, topology_matrix)
    # NOTE(review): the other entry points in this module construct the model
    # without ``base_cost`` and read ``m.base_cost`` instead -- confirm this
    # call still matches the current ``model.Model`` signature.
    m = model.Model(G, n_alarms, base_cost, len(alarms_df), device_dict)

    # assume sorted by start times, to avoid sorting each time in get_delays
    delays = get_delays.get_event_delays(m.all_alarms, topology_matrix, len(causal_prior), len(topology_matrix), delta_t=our_gloabls.max_delay, long_format=True, save_path=None)
    candidates = score_all_edges.score_all_edges(delays, base_cost, m.alarm_counts)
    topological_order = find_topological_order.find_topological_order_from_candidates(candidates)
    tested_candidates = np.zeros(len(candidates), dtype=bool)

    set_causal_prior_edges(causal_prior, delays, candidates, tested_candidates)

    # Pairs the prior rules out (entry == 0, off-diagonal) are flagged as
    # tested up front so the selection loop never proposes them.
    non_causal_edges = np.where(causal_prior == 0)
    for cause, effect in zip(non_causal_edges[0], non_causal_edges[1]):
        if cause != effect:
            i = delays.where((delays['effect'] == effect) & (delays['cause'] == cause)).dropna().index[0]
            tested_candidates[i] = True

    while True:
        # Selection: take the current minimum; if it is inadmissible but still
        # negative, neutralise it (gain_est = 1) and look again.
        # Using simple min instead of a heap, since we would have to re-heapify
        # after each edge addition anyway.
        while True:
            edgeCand_pt = candidates.argmin()
            cand = candidates[edgeCand_pt]

            if (
                m.edges[cand.cause, cand.effect] is None
                and not tested_candidates[edgeCand_pt]
                and not introduce_cycle(m.edges, cand)
                # Check the flag first so the two list scans are skipped
                # entirely when the ordering constraint is disabled.
                and (
                    no_topo
                    or topological_order.index(cand.effect) > topological_order.index(cand.cause)
                )
            ):
                break
            elif cand.gain_est > 0:
                # Nothing with a negative estimate is left.
                break
            else:
                cand.gain_est = 1 #mark as already added #TODO with alreaded tested we can ignore that

        edgeCand = candidates[edgeCand_pt]
        logging.debug(f"testing edge number {tested_candidates.sum()} / {tested_candidates.size}")
        if edgeCand.gain_est < 0:
            edge = edgeCand.get_edge()
            m.test_add_edge(edge)

            # Only delays towards the touched effect are recomputed and rescored.
            new_delays = get_delays.get_event_delays(m.all_alarms, topology_matrix, len(causal_prior), len(topology_matrix), delta_t=our_gloabls.max_delay, long_format=True, save_path=None, effect_event_types=[edge.effect])
            new_delays = new_delays.where((new_delays['effect'] == edge.effect)).dropna()
            delays.loc[new_delays.index] = new_delays
            update_index = new_delays.index
            candidates[update_index] = score_all_edges.score_all_edges(delays.loc[update_index], base_cost, m.alarm_counts)

            # Index into ``candidates`` again instead of reusing ``edgeCand``:
            # the rescoring above may have replaced the object in this slot.
            candidates[edgeCand_pt].gain_est = 1 # mark as already tested #TODO will reset after another edge with same effect is added, maybe this makes sense?
            tested_candidates[edgeCand_pt] = True
            set_causal_prior_edges(causal_prior, delays, candidates, tested_candidates) #TODO not a fan of resetting them every iteration
        else:
            break

    return m


def topological_search(alarms_df:pd.DataFrame, topology_matrix, causal_prior, init_with_self_loops=False):
    '''
    Build a causal model by visiting alarm types one node at a time, in the
    order handed out by ``find_topological_order.get_next_node``, and testing
    the outgoing edge candidates of each visited node.

    TODO causal prior is ignored for now (only its size is read below).

    Args:
        alarms_df: Alarm records; passed to ``dl.all_data_graph`` and, via its
            length, to ``model.Model``.
        topology_matrix: Passed to the graph builder and to
            ``get_delays.get_event_delays``.
        causal_prior: Only ``shape[0]`` / ``len()`` are used, to fix the
            number of alarm types.
        init_with_self_loops: If true, start from the model returned by
            ``set_matrix_as_model`` for an identity matrix instead of a fresh
            ``model.Model``.

    Returns:
        The model ``m`` after every node has been processed.
    '''
    
    # One integer id per alarm type: 0 .. causal_prior.shape[0] - 1.
    unique_alarms = np.arange(causal_prior.shape[0])

    G, device_dict = dl.all_data_graph(alarms_df, topology_matrix)
    if init_with_self_loops:
        # The identity matrix has ones only on the diagonal, i.e. one (i, i)
        # entry per alarm type. G / device_dict from above are not used here.
        m = set_matrix_as_model(alarms_df, topology_matrix, np.identity(len(unique_alarms)))
    else:
        m = model.Model(G,len(unique_alarms), len(alarms_df),  device_dict)

    delays = get_delays.get_event_delays(m.all_alarms, topology_matrix, len(causal_prior), len(topology_matrix), delta_t=our_gloabls.max_delay, long_format=True, save_path=None)
    candidates = score_all_edges.score_all_edges(delays,m.base_cost,m.alarm_counts)
    # Cause / effect of every candidate as plain arrays, so the selections
    # below can be written as vectorised masks. These are computed once and
    # are assumed to stay valid when candidates are rescored -- TODO confirm
    # that rescoring keeps the (cause, effect) of each slot.
    candidate_cause  = np.vectorize(lambda candidate: candidate.cause)(candidates)
    candidate_effect = np.vectorize(lambda candidate: candidate.effect)(candidates)

    # NOTE(review): allocated but never read in this function.
    tested_candidates = np.zeros(len(candidates), dtype=bool) #TODO still needed? 
    
    added_nodes = set()
    # Runs until as many nodes were added as there are alarm types; assumes
    # get_next_node returns a not-yet-added node each time -- TODO confirm.
    while len(added_nodes) < len(unique_alarms): 
        next_node = find_topological_order.get_next_node(candidates, added_nodes)
        m.test_remove_effect_edges(next_node)
        added_nodes.add(next_node)
        # Candidates leaving next_node whose effect has not been added yet.
        # next_node is already in added_nodes here, so its self-loop candidate
        # is excluded as well.
        next_node_candidates_pt = np.where((candidate_cause == next_node) & (np.vectorize(lambda x: x not in added_nodes)(candidate_effect)))
        # Presumably a NumPy object array: the slice then holds references to
        # the same candidate objects, so gain_est writes are shared.
        next_node_candidates = candidates[next_node_candidates_pt]
        # Never set to True in this function; the loop below ends only via
        # ``break`` (its length condition does not change while it runs).
        its_time_to_stop = False
        while len(next_node_candidates) > 0 and not its_time_to_stop:
            
            # Best remaining outgoing candidate of next_node.
            edgeCand_pt = next_node_candidates.argmin()
            edgeCand = next_node_candidates[edgeCand_pt]
            edge = edgeCand.get_edge()
            #TODO right candidate used etc?? 
            
            # gain_est == inf is the "already tested" marker used below; for
            # such candidates no edges are removed.
            removed_edges = m.remove_effect_edges(edgeCand.effect) if edgeCand.gain_est < np.inf else [] 

            if len(removed_edges) > 0:
                # Edges into this effect were removed: recompute the delays
                # towards that effect and rescore the affected candidates.
                new_delays = get_delays.get_event_delays(m.all_alarms, topology_matrix, len(causal_prior), len(topology_matrix), delta_t=our_gloabls.max_delay, long_format=True, save_path=None, effect_event_types=[edge.effect])
                new_delays = new_delays.where((new_delays['effect'] == edge.effect)).dropna()
                delays.loc[new_delays.index] = new_delays
                # update_index = delays.where((delays['effect'] == edge.effect)).dropna().index
                update_index = new_delays.index
                candidates[update_index] = score_all_edges.score_all_edges(delays.loc[update_index],m.base_cost, m.alarm_counts)
                # Re-slice so the local view reflects the rescored entries.
                next_node_candidates = candidates[next_node_candidates_pt]
                # Candidates into the same effect whose cause is listed in
                # removed_edges (presumably the causes of the removed edges --
                # verify against model.remove_effect_edges).
                iteration_candidates_pt = np.where((np.isin(candidate_cause,removed_edges)) & (candidate_effect == edgeCand.effect))
                iteration_candidates = candidates[iteration_candidates_pt]
                seen_causes = set()
                # The length never changes inside; this loop ends via ``break``.
                while len(iteration_candidates) > 0:
                    
                    # Choose between the best removed-edge candidate and the
                    # current next_node candidate, skipping causes that were
                    # already tried in this round.
                    while True:
                        iteration_edgeCand_pt = iteration_candidates.argmin()
                        if iteration_candidates[iteration_edgeCand_pt].gain_est < next_node_candidates[edgeCand_pt].gain_est:
                            iteration_edgeCand = iteration_candidates[iteration_edgeCand_pt]
                        else:
                            iteration_edgeCand = next_node_candidates[edgeCand_pt]
                        if iteration_edgeCand.gain_est == np.inf:
                            # Everything left is marked as tested.
                            break
                        elif iteration_edgeCand.cause not in seen_causes:
                            seen_causes.add(iteration_edgeCand.cause)
                            break
                        else:
                            # Cause already tried: mark and pick again.
                            iteration_edgeCand.gain_est = np.inf

                    if iteration_edgeCand.gain_est < 0:
                        iteration_edge = iteration_edgeCand.get_edge()
                        m.test_add_edge(iteration_edge)

                        # Same recompute / rescore step as above, again keyed
                        # on edge.effect.
                        new_delays = get_delays.get_event_delays(m.all_alarms, topology_matrix, len(causal_prior), len(topology_matrix), delta_t=our_gloabls.max_delay, long_format=True, save_path=None, effect_event_types=[edge.effect])
                        new_delays = new_delays.where((new_delays['effect'] == edge.effect)).dropna()
                        delays.loc[new_delays.index] = new_delays
                        # update_index = delays.where((delays['effect'] == edge.effect)).dropna().index
                        update_index = new_delays.index
                        candidates[update_index] = score_all_edges.score_all_edges(delays.loc[update_index],m.base_cost, m.alarm_counts)
                        next_node_candidates = candidates[next_node_candidates_pt]
                        iteration_candidates = candidates[iteration_candidates_pt]
                    else:
                        # No improving candidate left for this effect.
                        break

                    # NOTE(review): if the rescoring above put new objects into
                    # ``candidates``, this marks the previous object only;
                    # ``seen_causes`` is then what prevents re-picking -- confirm.
                    iteration_edgeCand.gain_est = np.inf # mark as already tested
                next_node_candidates[edgeCand_pt].gain_est = np.inf  # mark as already tested
            elif edgeCand.gain_est < 0:
                # Nothing was removed and the estimate is improving: test the
                # edge directly. The returned value is not used afterwards.
                alarms = m.test_add_edge(edge)
                
                new_delays = get_delays.get_event_delays(m.all_alarms, topology_matrix, len(causal_prior), len(topology_matrix), delta_t=our_gloabls.max_delay, long_format=True, save_path=None, effect_event_types=[edge.effect])
                new_delays = new_delays.where((new_delays['effect'] == edge.effect)).dropna()
                delays.loc[new_delays.index] = new_delays
                # update_index = delays.where((delays['effect'] == edge.effect)).dropna().index
                update_index = new_delays.index
                candidates[update_index] = score_all_edges.score_all_edges(delays.loc[update_index],m.base_cost, m.alarm_counts)
                next_node_candidates = candidates[next_node_candidates_pt]
                next_node_candidates[edgeCand_pt].gain_est = np.inf # mark as already tested #TODO will reset after another edge with same effect is added, maybe this makes sense?

            else:
                # Best remaining candidate is not improving (or is already
                # marked as tested): done with this node.
                break
    return m

def set_matrix_as_model(alarms_df:pd.DataFrame, topology_matrix, causal_matrix, skip_reassign=False):
    """Create a model and install the edges flagged by ``causal_matrix``.

    The matrix is handed to ``set_causal_prior_edges``, which marks the
    candidates of its 1-entries with ``gain_est == -inf``. Every candidate
    carrying that marker is then written straight into ``m.edges``, without
    any gain test.

    Args:
        alarms_df: Alarm records used to build the data graph and the model.
        topology_matrix: Passed to the graph builder and the delay
            computation.
        causal_matrix: Square matrix; entries equal to 1 become edges.
        skip_reassign: If true, return right after the edges are set, without
            reassigning/refitting and without printing the model length.

    Returns:
        The populated model.
    """
    n_alarms = causal_matrix.shape[0]
    graph, devices = dl.all_data_graph(alarms_df, topology_matrix)
    m = model.Model(graph, n_alarms, len(alarms_df), devices)

    delays = get_delays.get_event_delays(
        m.all_alarms,
        topology_matrix,
        len(causal_matrix),
        len(topology_matrix),
        delta_t=our_gloabls.max_delay,
        long_format=True,
        save_path=None,
    )
    candidates = score_all_edges.score_all_edges(delays, m.base_cost, m.alarm_counts)
    nothing_tested = np.zeros(len(candidates), dtype=bool)
    set_causal_prior_edges(causal_matrix, delays, candidates, nothing_tested)

    # Drain the -inf markers: each pass installs the current minimum and then
    # lifts its estimate to 1 so the next argmin moves on.
    best = candidates[candidates.argmin()]
    while best.gain_est == -np.inf:
        forced_edge = best.get_edge()
        m.edges[forced_edge.cause, forced_edge.effect] = forced_edge
        best.gain_est = 1
        best = candidates[candidates.argmin()]

    if skip_reassign:
        return m

    m.reassign_all_alrams_and_refit()
    print("Length under set causal prior: ", m.compute_length())
    return m



# Deprecated entry point: start the search via main.py instead.
if __name__ == '__main__':

    topology_matrix = np.load(r'./HuaweiVirus/sample/topology.npy')
    empty_topology = np.zeros(topology_matrix.shape, dtype=int)
    # All-ones topology with a zero diagonal.
    full_topology = np.ones(topology_matrix.shape, dtype=int)
    np.fill_diagonal(full_topology, 0)

    alarms = pd.read_csv(r'./HuaweiVirus/sample/alarm.csv')
    causal_prior = np.load(r'./HuaweiVirus/sample/causal_prior.npy')
    m = search(alarms, full_topology, causal_prior) #TODO test without causal prior
    # The context manager closes the file even if pickling raises.
    with open('model_full_topo.pkl', 'wb') as model_file:
        pickle.dump(m, model_file) #TODO properly save model

    # empty topology approach:
    #   assume empty topology
    #   score all edges
    #   take best edge add to model
    #   add all edges to topology that improve score
    #   take next best edge add to model
    #   add all edges to topology that improve score
    #   etc.
    