import argparse
import random
from random import sample
import numpy as np
import pandas as pd
from scipy.stats import bernoulli
import os

"""
Example usage
python ./data_generator/generate_data.py -output-path ./data/syn_data/simple/ -alarms alarms.csv -true-graph true_graph.npy -nodes 7 -edges 10 -T 10000 -delay-dist exp -delay-params-names beta -delay-params-lower 20 -delay-params-upper 50 -skip-prob-range 0 0.01 -root-count 1000 -add-device True -g-add-percentage 0.01 -m-add-percentage 0.01 -m-dest-percentage 0.01
"""

def set_random_seed(seed):
    """Seed Python's ``random`` module and NumPy's global generator.

    seed: value passed unchanged to ``random.seed`` and ``np.random.seed``,
        in that order.
    """
    for seeder in (random.seed, np.random.seed):
        seeder(seed)


class Simulation(object):

    @staticmethod
    def sample_delay(distribution, params, min_delay=0): #TODO: maybe instead of checking manually, use rv_continous and pass name and parameters to distribution from scipy.stats
        if distribution == 'exp':
            delay = np.random.exponential(params['beta']) + min_delay #starts at 0 (>0, but we later take int of delay, so any delay between 0 and 1 will be mapped to 0)
        elif distribution == 'normal': #sample
            #transform loc and scale, sample from lognormal to obtain only positive, 
            delay = np.random.normal(params['loc'], params['scale']) #can generate instant effects if min_delay is 0, not sure if I need to shift
        elif distribution == 'uniform':
            delay = np.random.uniform(params['low'], params['high']) #can generate instant effects if min_delay is 0, not sure if I need to shift.
            #delay = np.random.uniform(max(params['low'], min_delay), params['high'])
        elif distribution == 'poisson':
            delay = np.random.poisson(params['lam']) + min_delay #starts at 0
        elif distribution == 'geom':
            delay = np.random.geometric(params['p']) + min_delay - 1 #starts at 1
        else:
            delay = min_delay
        # Add more distributions as needed
        return delay  # Default delay

    @staticmethod
    def _random_permutation(M): #permute node order (taken from gcastle library class DAG)
        # np.random.permutation permutes first axis only
        P = np.random.permutation(np.eye(M.shape[0]))
        return P.T @ M @ P
    
    @staticmethod
    def generate_dag(n_nodes=5, n_edges=5):
        G = np.zeros((n_nodes, n_nodes)) #adjacency matrix
        #G = np.identity(n_nodes)
        #Upper traingular matrix
        #randomly select E edges out of upper triangular indexes
        indices = np.transpose(np.triu_indices(n_nodes, 1))
        edges = np.random.choice(list(range(len(indices))), n_edges, replace=False)
        G[indices[edges].T[0], indices[edges].T[1]] = 1
        #permute
        #G = Simulation._random_permutation(G)
        return G.astype(int)
    
    @staticmethod
    def generate_dag_many_to_one(n_nodes=5, n_edges=5):
        #last node is the effect
        G = np.zeros((n_nodes, n_nodes)) #adjacency matrix
        #G = np.identity(n_nodes)
        G[:, -1] = np.ones(n_nodes)
        G[-1, -1] = 0
        return G.astype(int)

    @staticmethod
    def bfs_get_depths(self, G, depths=None):
        from collections import deque
        if depths is None:
            depths=np.zeros(G.shape[0]) #make inf at start, and take min of depth and other depths
        root_nodes=np.nonzero(np.sum(G, axis=0)==0)[0]
        q = deque(root_nodes)
        d=1
        visited=set()
        while len(q)>0:
            next_nodes=list()
            while len(q)>0:
                n = q.pop()
                if n in visited:
                    continue
                visited.add(n)
                depths[n] = d
                next_nodes.extend(np.nonzero(G[n])[0]) #edges outgoing from n
            #print(next_nodes)
            q=deque(next_nodes)
            d=d+1
        return depths
    
    @staticmethod
    def get_proportions(D):
        #input D
        #output array of relative frequency of occurences (|alarm types|)
        return D['alarm_id'].value_counts(normalize=True)
    
    @staticmethod
    def get_effects(graph, cause_alarm_id):
        return np.nonzero(graph[cause_alarm_id])[0]
    
    @staticmethod #TODO: make not static?
    def add_event(data, T, event_type, event_timestamp, cause_index): #adds event to last index
        if event_timestamp < T: #check that sampled event is less than end timestamp
            #using .loc fails silently?
            #data.loc[len(data)]={'alarm_id':int(event_type), 'start_timestamp':int(event_timestamp), 'cause_index':cause_index}
            new_row_df = pd.DataFrame({'alarm_id':[int(event_type)], 'start_timestamp':[int(event_timestamp)], 'cause_index':[cause_index]})
            data = pd.concat([data, new_row_df], ignore_index=True)
        return data

    @staticmethod
    def propogate_event(data, T, event_index, graph, skip_probs, delay_dist, delay_params, min_delay):
        """
        iteratively propogates event at event_index
        """
        from collections import deque
        #get neighbors
        event_type, timestamp, _ = data.loc[event_index]
        event_type = int(event_type)
        Q = deque()
        Q.appendleft((event_type, timestamp, event_index))
        while len(Q)>0:
            cause_event_type, cause_timestamp, cause_index = Q.pop()
            cause_event_type = int(cause_event_type)
            effects = Simulation.get_effects(graph, cause_event_type)
            for effect_event_type in effects: #for each effect node
                skip = bernoulli.rvs(skip_probs[cause_event_type, effect_event_type]) #sample skip
                if skip: #if skip is True then don't add effect event
                    continue
                # to add an event sample delay using delay dist and delay param, then add timestamp
                delay=Simulation.sample_delay(delay_dist, delay_params[effect_event_type][cause_event_type], min_delay)
                if delay < min_delay:
                    continue
                effect_timestamp = int(cause_timestamp+delay)
                #add event
                data=Simulation.add_event(data, T, effect_event_type, effect_timestamp, cause_index)
                if effect_timestamp < T:
                    Q.append((effect_event_type, effect_timestamp, len(data)-1))
        return data
    
    @staticmethod
    def g_add(data, T, g_add_percentage, graph, skip_probs, delay_dist, delay_params, min_delay):
        proportions = Simulation.get_proportions(data)
        for alarm_id,proportion in proportions.items():
            #determine count to add
            count = int(proportion*len(data)*g_add_percentage) #can use count instead of proportion vs len(data)?
            timestamps = np.random.uniform(0, T, count)
            for t in timestamps:
                data=Simulation.add_event(data, T, alarm_id, t, -1) #add event in last index
                data=Simulation.propogate_event(data, T, len(data)-1, graph, skip_probs, delay_dist, delay_params ,min_delay)
        return data
    
    @staticmethod
    def m_add(data, T, D, proportions, m_add_percentage):
        for alarm_id,proportion in proportions.items():
            #determine count to add
            count = int(proportion*D*m_add_percentage) #can use count instead of proportion vs len(data)?
            timestamps = np.random.uniform(0, T, count).astype(int)
            for t in timestamps:
                data=Simulation.add_event(data, T, alarm_id, t, -1) #add event in last index
                #don't propogate
        return data
    
    @staticmethod
    def m_dest(data, D, proportions, m_dest_percentage):
        deletion_indices = []
        for alarm_id,proportion in proportions.items():
            #determine count to delete
            count = int(proportion*D*m_dest_percentage) #can use count instead of proportion vs len(data)?
            #select indexes to delete
            indices = np.array(data.index[data['alarm_id']==alarm_id].tolist())
            #randomly select count indexes
            indices = np.random.choice(indices, count, replace=False)
            deletion_indices.extend(indices)
        #adjust cause_indexes first
        deletion_indices.sort(reverse=True)
        for index in deletion_indices:
            #does this work or need loc
            data.loc[data['cause_index'] == index,'cause_index'] = -1
            data.loc[data['cause_index'] > index,'cause_index'] -= 1

        #create new dataframe from non deleted rows
        data = data[~data.index.isin(deletion_indices)]
        data = data.reset_index(drop=True)
        return data
    
    @staticmethod
    def balanced_skip(data, proportions=None):
        #balanced skip, may affect variables with small skip probability more than they affect variables with large skip probability. (add option to sample based on distribution of proportions?)
        deletion_indices = []
        df = data.copy()
        df['index'] = df.index
        grouped = df.groupby(['alarm_id', 'start_timestamp'])['index'].apply(list).reset_index()
        #print("grouped:", grouped)
        ilists = grouped[grouped['index'].apply(lambda x: len(x) > 1)]['index']
        #print("ilists", ilists)
        for ilist in ilists:
            indices = np.random.choice(ilist, len(ilist)-1, replace=False)
            deletion_indices.extend(indices)
        #adjust cause_indexes first
        print("Number of same event type collided in a timestamp:", len(deletion_indices))
        deletion_indices.sort(reverse=True)
        for index in deletion_indices:
            #does this work or need loc
            data.loc[data['cause_index'] == index,'cause_index'] = -1
            data.loc[data['cause_index'] > index,'cause_index'] -= 1
        #create new dataframe from non deleted rows
        data = data[~data.index.isin(deletion_indices)]
        data = data.reset_index(drop=True)
        return data
    
    def delete_null_event_types(data, graph, skip_probs, delay_params):
        #get events that didn't occur
        null_events = list( set(list(range(graph.shape[0]))) - set(data['alarm_id'].unique()) )
        if len(null_events)==0:
            return data, graph, skip_probs, delay_params, 0
        if len(null_events)==graph.shape[0]:#empty graph and data
            return data, np.zeros_like(graph), skip_probs, delay_params, graph.shape[0] #?
        #delete rows and columns
        graph=np.delete(graph, null_events, 0)
        graph=np.delete(graph, null_events, 1)
        #shift event_type_indices
        null_events.sort(reverse=True) #list of event types to delete
        for event_type in null_events:
            #event_type
            #no event should have type event_type
            assert len(data.loc[data['alarm_id']==event_type]) == 0
            data.loc[data['alarm_id']>event_type, 'alarm_id'] -= 1
        #delete null events from skip probs and delay params
        skip_probs=np.delete(skip_probs, null_events, 0)
        skip_probs=np.delete(skip_probs, null_events, 1)
        delay_params = [[delay_params[j][i] for i in range(len(delay_params)) if i not in set(null_events)] for j in range(len(delay_params[0])) if j not in set(null_events)]
        return data, graph, skip_probs, delay_params, len(null_events)
    

    def __init__(self, N=7, E=10, T=1000, delay_dist = 'exp', delay_params_range={'beta':(20, 50)}, skip_prob_range=(0,1), root_count=10, g_add_percentage=0.1, m_add_percentage=0.1, m_dest_percentage=0.1, add_device=True, graph_type=None, dag=None, min_delay = 0) -> None:
        """
        N: number of event types (nodes in graph)
        E: number of edges
        T: maximum timestamp
        delay_dist: distribution of delays between nodes
        delay_params_range: dictionary with key:parameter name, value: tuple with range
        skip_prob_range: dictionary with key:parameter name, value: tuple with range
        root_count: number of root events
        return graph, data (N, timestamp, cause_id)
        data is list of nodes, each containing list of tuples (timestamp, id)
        causes is list of nodes, each containing list of tuples (timestamp, id)
        """
        #TODO: assert ranges of delay_params_range and skip_prob_range
        #TODO: add parameter for additive generative noise? instead of non root count
        #TODO: it seems data frame extension is expensive timewise because it creates a copy each time with concat?
        data = pd.DataFrame(columns=['alarm_id', 'start_timestamp', 'cause_index'])
        #sample graph
        if graph_type=="M1":
            graph = Simulation.generate_dag_many_to_one(N, E)
        else:
            graph = Simulation.generate_dag(N, E) if dag is None else dag
        #Define delay distribution parameters for each edge

        delay_params = [[{parameter:np.random.uniform(r[0], r[1]) for parameter, r in delay_params_range.items()} if graph[i,j] else None for i in range(N)] for j in range(N)]

        #sample skip probabilities
        skip_probs = np.random.uniform(skip_prob_range[0], skip_prob_range[1], (N,N)) * (graph)
        #skip_probs[graph==0] = 1 ?
        
        
        #sample root events Uniform(0, T, root_count) for each root node in graph
        #get root
        root_nodes = np.nonzero(np.sum(graph, axis=0)==0)[0] #indegree of node, axis 0 or 1?
        #root_nodes = list(range(graph.shape[0])) #if using graph with self edges
        #Sample root events
        for node in root_nodes:
            timestamps = np.random.randint(0, T, root_count)
            for t in timestamps:#TODO: convert to batch operation
                data=Simulation.add_event(data, T, node, t, -1)
                data=Simulation.propogate_event(data, T, len(data)-1, graph, skip_probs, delay_dist, delay_params, min_delay)
        
        data = Simulation.balanced_skip(data)
        #add generative noise
        g_add_count = len(data)
        data=Simulation.g_add(data, T, g_add_percentage, graph, skip_probs, delay_dist, delay_params, min_delay)    
        g_add_count = len(data) - g_add_count
        #TODO: delete measurement noise events, call m_dest
        D=len(data)
        proportions = Simulation.get_proportions(data)
        data = Simulation.m_dest(data, D, proportions, m_dest_percentage)
        m_dest_count = D - len(data)
        m_add_count = len(data)
        #TODO: add measurement noise events, call m_add
        data = Simulation.m_add(data, T, D, proportions, m_add_percentage)
        m_add_count = len(data) - m_add_count

        if add_device:
            data['device_id']=0
        if True: #TODO: add end_timestamp paremeter
            data['end_timestamp'] = data['start_timestamp']

        ##SORT by timestamp preserving cause indices
        data = data.sort_values('start_timestamp', ignore_index=False)
        indexmap = np.argsort(data.index)
        data.index = indexmap[data.index]
        data.loc[data['cause_index']!=-1,'cause_index'] = indexmap[data.loc[data['cause_index']!=-1,'cause_index'].tolist()]


        data, graph, skip_probs, delay_params, deleted = Simulation.delete_null_event_types(data, graph, skip_probs, delay_params)
        print(data.head())
        print("Total number of events:", len(data))
        print("Number of generative additive events:", g_add_count)
        print("Number of measurement additive events:", m_add_count)
        print("Number of measurement destructive events:", m_dest_count)
        print("proportion of events:", Simulation.get_proportions(data))
        print("Number of deleted non occuring events types:", deleted)

        assert len( set(list(range(graph.shape[0]))) - set(data['alarm_id'].unique()) ) == 0
        #TODO: order dataframe columns?

        self.data=data
        self.graph=graph
        self.skip_probs=skip_probs
        self.delay_params=delay_params
        return None
    

def get_params_range_dict(delay_names, delay_lower, delay_upper):
    """Zip parallel lists into a ``{name: (lower, upper)}`` dictionary.

    delay_names: parameter names.
    delay_lower, delay_upper: bounds for each name, in the same order.
    Returns a dict mapping each name to its (lower, upper) pair.
    Raises ValueError when the three sequences differ in length, instead of
    failing with an IndexError or silently dropping the surplus values.
    """
    #TODO: what if parameter is not numeric?
    if not (len(delay_names) == len(delay_lower) == len(delay_upper)):
        raise ValueError(
            "delay parameter names, lower bounds and upper bounds must have "
            "the same length, got %d, %d and %d"
            % (len(delay_names), len(delay_lower), len(delay_upper))
        )
    return {
        name: (low, high)
        for name, low, high in zip(delay_names, delay_lower, delay_upper)
    }


if __name__=="__main__":
    #delay_params = [[{'beta':np.random.uniform(2, 100)} i in range(7)] for j in range(7)]
    argparser = argparse.ArgumentParser()
    argparser.add_argument("-seed", type=int, help="random seed", default=123)
    argparser.add_argument("-output-path", type=str, help="path to save alarms and graph files", default="./data/syn_data/simple/")
    argparser.add_argument("-alarms", type=str, help="name of csv file", default="alarms.csv")
    argparser.add_argument("-true-graph", type=str, help="name of true graph file", default="true_graph.npy")
    argparser.add_argument("-nodes", type=int, help="number of nodes (event types)", default=7)
    argparser.add_argument("-edges", type=int, help="number of edges", default=10)
    argparser.add_argument("-T", type=int, help="maximum timestamp", default=10000)
    argparser.add_argument('-delay-dist', type=str, help="delay dist. accepts one of following (exp, normal, uniform, geom)", default='exp')
    argparser.add_argument('-delay-params-names', type=str, nargs='+', help="names of parameters input the following parameters for distributions [exp:(beta), normal:(loc, scale), uniform:(low, high),].", default=['beta'])
    argparser.add_argument('-delay-params-lower', type=float, nargs='+', help="lower range of param. specify min beta for exp dist.", default=[100])
    argparser.add_argument('-delay-params-upper', type=float, nargs='+', help="upper range of param. specify max beta for exp dist. U(lower, upper)", default=[250])
    argparser.add_argument('-skip-prob-range', type=float, nargs='+', help="lower and upper end of range of skip probability", default=[0.05, 0.05])
    argparser.add_argument('-root-count', type=int, help="number of root events", default=1000)
    argparser.add_argument('--add-device', action='store_true', help="add device id to alarms", default=True)
    argparser.add_argument('-g-add-percentage', type=float, help="percentage of additive generative noise", default=0.05)
    argparser.add_argument('-m-add-percentage', type=float, help="percentage of measurement additive noise", default=0.05)
    argparser.add_argument('-m-dest-percentage', type=float, help="percentage of measurement destructive noise", default=0.05)
    argparser.add_argument('-graph-type', type=str, help="takes values M1, or any string (generates random dag). (many to 1)", default="random")
    argparser.add_argument('-dag', type=str, help="path to dag file", default=None)
    argparser.add_argument('-min-delay', type=int, help="minimum delay, skip if sampled delay is less than it", default=0)

    args = argparser.parse_args()

    #make args params into a dictionary
    delay_params_range=get_params_range_dict(args.delay_params_names, args.delay_params_lower, args.delay_params_upper)


    if args.seed is not None:
        set_random_seed(args.seed)

    dag = np.load(args.dag) if args.dag is not None else None
    if args.dag is not None:
        args.nodes = dag.shape[0]
        args.edges = np.sum(dag)
    
    simulation=Simulation(N=args.nodes, E=args.edges, T=args.T, delay_dist = args.delay_dist, delay_params_range=delay_params_range, skip_prob_range=args.skip_prob_range, root_count=args.root_count, g_add_percentage=args.g_add_percentage, m_add_percentage=args.m_add_percentage, m_dest_percentage=args.m_dest_percentage, add_device=args.add_device, graph_type = args.graph_type, dag=dag, min_delay = args.min_delay)
    #print(simulation.data)
    #print(simulation.graph)
    #save data
    os.makedirs(args.output_path, exist_ok=True)
    
    simulation.data.to_csv(args.output_path+args.alarms, index=False)
    np.save(args.output_path+args.true_graph, simulation.graph)
