import numpy as np
import igraph
# from .pyspice_utils import EqClass

## These classes are copied here to prevent import errors
class EqElt:
    '''
    One equivalence class of electrical node names.

    `eq` is the set of names known to be equivalent; `first_elt` is the name the class was created with.
    '''
    def __init__(self, elt_name):
        self.eq = {elt_name}
        self.first_elt = elt_name

    def eqto(self, b):
        '''Merge the names of equivalence class `b` into this one.'''
        self.eq = self.eq | b.eq

    @property
    def default_name(self):
        '''
        Representative name of the class: 'IN' if it contains the input terminal, otherwise 'OUT' if it
        contains the output terminal, otherwise the name the class was created with.
        '''
        if 'IN' in self.eq:
            return 'IN'
        if 'OUT' in self.eq:
            return 'OUT'
        return self.first_elt


class EqClass:
    '''
    Registry of equivalence classes between node names.

    `all_eqs` maps every name seen so far to the EqElt of its class; all names of one class share the same
    EqElt object, so merging two classes is immediately visible from every member name.
    '''
    def __init__(self):
        self.all_eqs = {}

    def add_new_eq(self, a, b):
        '''Declare names `a` and `b` equivalent, merging their classes when both already exist.'''
        has_a = a in self.all_eqs
        has_b = b in self.all_eqs
        if has_a and has_b:
            keep, absorbed = self.all_eqs[a], self.all_eqs[b]
            if keep is not absorbed:
                keep.eqto(absorbed)
                # Re-point every member of the absorbed class (not only `b`); otherwise those names keep
                # referring to a stale class that is missing the newly merged members.
                for name in absorbed.eq:
                    self.all_eqs[name] = keep
        elif has_a:
            self.all_eqs[a].eqto(EqElt(b))
            self.all_eqs[b] = self.all_eqs[a]
        else:
            if not has_b:
                self.all_eqs[b] = EqElt(b)
            self.all_eqs[b].eqto(EqElt(a))
            self.all_eqs[a] = self.all_eqs[b]

    @property
    def equivalences(self):
        '''Mapping from every known name to the representative (default) name of its class.'''
        return {k: v.default_name for k, v in self.all_eqs.items()}


def build_node_equivalences(edge_list, in_idx, out_idx):
    '''
    Build the equivalences between device pins implied by the edges of a circuit graph.

    Every edge (src, dst) ties the output pin of `src` ('<src>_out') to the input pin of `dst` ('<dst>_inpt').
    Edges touching the input (resp. output) terminal tie the other endpoint's pin to 'IN' (resp. 'OUT') instead.

    Params
    ------
        edge_list (iterable of tuple): directed edges (src, dst).
        in_idx (int): vertex index of the input terminal.
        out_idx (int): vertex index of the output terminal.
    Outputs
    -------
        eq_nodes (EqClass): the resulting pin equivalences.
    '''
    eq_nodes = EqClass()

    for src, dst in edge_list:
        src_pin = f'{src}_out'
        dst_pin = f'{dst}_inpt'
        if src == in_idx:
            pair = ('IN', dst_pin)
        elif dst == in_idx:
            pair = ('IN', src_pin)
        elif src == out_idx:
            pair = ('OUT', dst_pin)
        elif dst == out_idx:
            pair = ('OUT', src_pin)
        else:
            pair = (dst_pin, src_pin)
        eq_nodes.add_new_eq(*pair)

    return eq_nodes


def search_missing_edges(eq_nodes, edgelist, postprocess=False):
    '''
    Associativity has been assumed to construct the equivalence between electric nodes. Hence look for missing
    edges in the equivalence lists: an output pin `i_out` sharing a net with an input pin `o_inpt` although the
    edge (i, o) is not in `edgelist`. Missing edges originate from devices that belong to a feedback loop.

    Params
    ------
        eq_nodes (EqClass): pin equivalences, as returned by build_node_equivalences.
        edgelist (list of tuple): directed edges (i, o) of the graph.
        postprocess (bool): if True, edges on feedback loops have already been reversed; this time missing edges
            are directed toward feedback loop components, so the input-side components `o` are reported instead
            of `i`, and nets that contain the IN terminal are skipped.
    Outputs
    -------
        nodes_to_reverse (list): sorted, de-duplicated node indices which belong to a feedback loop.
    '''
    # Set membership is O(1); the check below runs once per (output pin, input pin) pair of every net.
    edge_set = set(edgelist)

    nodes_to_reverse = []
    for name, net in eq_nodes.all_eqs.items():
        if '_out' not in name:
            continue
        int_i = int(name.replace('_out', ''))
        if postprocess and 'IN' in net.eq:
            continue

        # Components whose input pin is on this net but that receive no edge from int_i.
        missing = []
        for o in net.eq:
            if '_inpt' in o:
                int_o = int(o.replace('_inpt', ''))
                if (int_i, int_o) not in edge_set:
                    missing.append(int_o)

        if not missing:
            continue
        if postprocess:
            nodes_to_reverse.extend(missing)
        else:
            nodes_to_reverse.append(int_i)

    return np.unique(nodes_to_reverse).tolist()


def search_components_in_series(feedb_idx, eq_nodes):
    '''
    Find components in series with feedback-loop components (max 1 before, 1 after).

    For each feedback component: if its output net holds exactly one other pin and that pin is another
    component's input, that component is reported. Symmetrically, if its input net holds exactly one other pin
    and that pin is another component's output, that component is reported. Pins that are terminals ('IN'/'OUT')
    or of the wrong kind, and components without an entry for a pin, are skipped instead of raising.

    Params
    ------
        feedb_idx (iterable of int): indices of feedback-loop components.
        eq_nodes (EqClass): pin equivalences, as returned by build_node_equivalences.
    Outputs
    -------
        nodes_to_reverse (list of int): indices of the series components found (may contain duplicates).
    '''
    def lone_neighbour(pin):
        # The single other pin sharing `pin`'s net, or None if the pin is unknown or the net is not a 2-pin net.
        net = eq_nodes.all_eqs.get(pin)
        if net is None:
            return None
        others = [e for e in net.eq if e != pin]
        return others[0] if len(others) == 1 else None

    nodes_to_reverse = []
    for feedb_node in feedb_idx:
        after = lone_neighbour(f'{feedb_node}_out')
        if after is not None and after.endswith('_inpt'):
            nodes_to_reverse.append(int(after.replace('_inpt', '')))
        before = lone_neighbour(f'{feedb_node}_inpt')
        if before is not None and before.endswith('_out'):
            nodes_to_reverse.append(int(before.replace('_out', '')))

    return nodes_to_reverse


def reverse_feedback_loops(dataset):
    '''
    Recover feedback loops in the OCB subcircuit dataset and break subcircuits into their elements.

    For each graph, components are flagged as belonging to a feedback loop in three ways:
    1. their subcircuit type is in `feedb_gm_types`;
    2. search_missing_edges finds a missing edge from their output pin;
    3. they are in series with a flagged component (search_components_in_series).
    Every edge touching a flagged component is then reversed. The missing-edge search is run once more on the
    reversed edges (postprocess mode) to catch additional loops, and edges touching those components are
    reversed too. Finally the graph is expanded with subckt_to_elements.

    Params
    ------
        dataset (iterable): subcircuit-level graphs (presumably igraph.Graph) with a 'type' vertex attribute.
    Outputs
    -------
        new_dataset (list): one graph per input graph, as returned by subckt_to_elements.

    Side effects: each input graph is modified in place. Its edges are replaced by the re-directed ones, and a
    boolean 'feedback' vertex attribute is added.
    NOTE(review): subckt_to_elements builds a new graph that only carries a 'type' vertex attribute, so
    'feedback' is not present on the returned graphs. Confirm this is intended.
    Raises IndexError if a graph has no vertex of type 0 (IN) or 1 (OUT).
    '''

    # Subcircuit types of the input / output terminals (SUBG_NODE: 0 -> ['In'], 1 -> ['Out'])
    in_type, out_type = 0, 1
    # Subcircuit types whose SUBG_NODE entry contains a '+gm-' or '-gm-' element; flagged as feedback up front
    feedb_gm_types = [8, 9, 12, 13, 16, 17, 20, 21, 24, 25]

    new_dataset = []

    for graph_i in dataset:

        n_elt = range(len(graph_i.vs['type']))
        # Identify I/O nodes (first vertex of each terminal type)
        in_idx, out_idx = ([i for i in n_elt if graph_i.vs['type'][i] == in_type][0],
                        [i for i in n_elt if graph_i.vs['type'][i] == out_type][0])
        edgelist = graph_i.get_edgelist()

        ## 1st step - Leverage the subcircuits to identify feedback loops

        # Flag components belonging to feedback loops using their type (one bool per vertex)
        feedb_idx = [t in feedb_gm_types for t in graph_i.vs['type']]
        
        # Identify components with missing out edges
        eq_nodes = build_node_equivalences(edgelist, in_idx, out_idx)
        nodes_to_reverse = search_missing_edges(eq_nodes, edgelist)
        feedb_idx = [True if idx in nodes_to_reverse else fb_t for (idx, fb_t) in enumerate(feedb_idx)]

        # Search for components in series with feedback loop components (max 1 before, 1 after)
        nodes_to_reverse = search_components_in_series(np.where(feedb_idx)[0], eq_nodes)
        feedb_idx = [True if idx in nodes_to_reverse else fb_t for (idx, fb_t) in enumerate(feedb_idx)]
        ##
        
        ## 2nd step - Reverse feedback loops edges, edges can now be directed to IN or from OUT
        new_edges = [(o, i) if (feedb_idx[i] or feedb_idx[o]) else (i, o) for (i, o) in graph_i.get_edgelist()]
        ##

        ## 3rd step - Search for missing edges one more time to discover additional feedback loops (this concerns 166 original OCB graphs)
        eq_nodes = build_node_equivalences(new_edges, in_idx, out_idx)
        nodes_to_reverse = search_missing_edges(eq_nodes, new_edges, postprocess=True)
        # Only keep components that were not already flagged (their edges were reversed in step 2)
        nodes_to_reverse = [e for e in nodes_to_reverse if not feedb_idx[e]]

        # Search for components in series with feedback loop components (max 1 before, 1 after)
        nodes_to_reverse += search_components_in_series(nodes_to_reverse, eq_nodes)

        # Reverse feedback loops edges one last time - after inspection, this does not concern nodes that were already assigned to feedback loops
        new_edges = [(o, i) if (i in nodes_to_reverse or o in nodes_to_reverse) else (i, o) for (i, o) in new_edges]
        # Replace all edges of the input graph in place (mutates the caller's graph)
        graph_i.delete_edges()
        graph_i.add_edges(new_edges)
        
        # Add a feedback attribute to the graph (True for every component flagged in any step)
        feedb_idx = [True if idx in nodes_to_reverse else fb_t for (idx, fb_t) in enumerate(feedb_idx)]
        graph_i.vs['feedback'] = feedb_idx

        ###
        
        # Map from subckt to low level
        new_dataset.append(subckt_to_elements(graph_i))

    return new_dataset


def split_graph(graph, attributes=True):
    '''
    Add intermediate circuit nodes (nets) between all electrical components. Those are given the type "10".

    Edges touching the input (type 8) or output (type 9) terminal are kept unchanged. Every other edge (i, o)
    becomes i -> net -> o, where all pins on the same electrical node share one net vertex. A net that
    contains a terminal is represented by that terminal's vertex. Net vertices are appended after the
    existing vertices, ordered by the sorted representative names of their nets.

    Params
    ------
        graph (igraph.Graph): graph with a 'type' vertex attribute (and 'feat'/'vid' when `attributes`).
        attributes (bool): if True, also extend the 'feat' and 'vid' vertex attributes (nets get 0.0 / 0).
    Outputs
    -------
        attrs (dict): 'type', 'feat' and 'vid' lists for the split graph ('feat'/'vid' are None when
            `attributes` is False).
        new_edgelist (list of tuple): edges of the split graph. Routed edges are de-duplicated.
    Raises IndexError if the graph has no vertex of type 8 or 9.
    '''

    # Identify I/O nodes (fetch the type list once instead of once per vertex)
    node_types = graph.vs['type']
    n_elt = len(node_types)
    in_idx = [i for i in range(n_elt) if node_types[i] == 8][0]
    out_idx = [i for i in range(n_elt) if node_types[i] == 9][0]

    # Identify common device I/O nodes of the circuit
    edge_list = graph.get_edgelist()
    eq_nodes = build_node_equivalences(edge_list, in_idx, out_idx)
    # Computed once: the property rebuilds the whole name -> representative mapping on every access.
    equivalences = eq_nodes.equivalences

    # Node name to idx dict
    net_names = np.unique([e for e in equivalences.values() if e not in ['IN', 'OUT']])
    additional_nodes = {str(node): n_elt + i for i, node in enumerate(net_names)}
    additional_nodes.update({'IN': in_idx, 'OUT': out_idx})
    n_nets = len(additional_nodes) - 2

    # Update type list, nets are given the index "10"
    new_types = node_types + n_nets * [10]
    new_attributes = (graph.vs['feat'] + n_nets * [0.0]) if attributes else None
    new_vid = (graph.vs['vid'] + n_nets * [0]) if attributes else None

    # Update edge list. `seen` mirrors the content of new_edgelist for O(1) duplicate checks.
    terminals = (in_idx, out_idx)
    new_edgelist = []
    seen = set()
    for (i, o) in edge_list:
        if i in terminals or o in terminals:
            new_edgelist.append((i, o))
            seen.add((i, o))
        else:
            interm_node = additional_nodes[equivalences[f'{i}_out']]
            for edge in ((i, interm_node), (interm_node, o)):
                if edge not in seen:
                    seen.add(edge)
                    new_edgelist.append(edge)

    return {'type': new_types, 'feat': new_attributes, 'vid': new_vid}, new_edgelist


# Split every graph of a list by inserting net nodes (see split_graph)
def split_collection(graph_list, attributes):
    '''
    Apply split_graph to every graph of `graph_list`.

    Each input graph is copied, so the originals are left untouched. The copy's vertices are rebuilt from the
    attributes returned by split_graph, and its edges are the split edge list.

    Params
    ------
        graph_list (iterable): graphs to split.
        attributes (bool): forwarded to split_graph (whether 'feat'/'vid' are extended).
    Outputs
    -------
        list: the split graphs, in input order.
    '''

    def rebuild(source):
        # Work on a copy so the caller's graph is left as is.
        rebuilt = source.copy()
        node_attrs, edges = split_graph(rebuilt, attributes)
        rebuilt.delete_vertices()
        rebuilt.add_vertices(len(node_attrs['type']), attributes=node_attrs)
        rebuilt.add_edges(edges)
        return rebuilt

    return [rebuild(graph) for graph in graph_list]


# The original OCB node types
# NOTE(review): NODE_TYPE is redefined with identical contents later in this module (in the section for
# simulation metrics); that later binding is the one in effect once the module has been imported.
NODE_TYPE = {'R': 0, 'C': 1, '+gm+':2, '-gm+':3, '+gm-':4, '-gm-':5, 'sudo_in':6, 'sudo_out':7, 'In': 8, 'Out':9}
# The OCB_v2 node types: gm pins are explicit ('gm_in'/'gm_out'), and 'n' (type 10) is the net node type that
# split_graph inserts between components
extended_node_type = {'R': 0, 'C': 1, '+gm+':2, '-gm+':3, '+gm-':4, '-gm-':5, 'gm_in':6, 'gm_out':7, 'In': 8, 'Out':9, 'n': 10}
# Node type id -> names (keys of extended_node_type) of the nodes it expands into in ocb_graph_with_split_gm;
# gm types 2-5 expand into ['gm_in', <gm>, 'gm_out'], every other type into a single node
extended_subg_node = {0: ['R'], 1: ['C'], 2: ['gm_in', '+gm+', 'gm_out'], 3: ['gm_in', '-gm+', 'gm_out'], 4: ['gm_in', '+gm-', 'gm_out'], 
                      5: ['gm_in', '-gm-', 'gm_out'], 8: ['In'], 9: ['Out'], 10: ['n']}


def ocb_graph_with_split_gm(g):
    '''
    Takes an OCB graph as input and further splits gm nodes into 3 nodes in series, making the input and output
    gm nodes apparent.

    Every vertex of type t is replaced by the nodes listed in extended_subg_node[t] (typed through
    extended_node_type), chained in order (gm_in -> gm -> gm_out for gm types). Incoming edges of the original
    vertex attach to its first new node, and outgoing edges leave from its last new node.

    Params
    ------
        g (igraph.Graph): graph whose 'type' vertex attribute uses the keys of extended_subg_node.
    Outputs
    -------
        g_o (igraph.Graph): new directed graph carrying only a 'type' vertex attribute.
    '''

    new_types = []
    new_edges = []
    first_node = {}  # original vertex -> index of its first new node (receives incoming edges)
    last_node = {}   # original vertex -> index of its last new node (emits outgoing edges)
    node_counter = 0

    # First, assign new idx and types
    for i, t_i in enumerate(g.vs['type']):
        curr_types = [extended_node_type[t] for t in extended_subg_node[t_i]]
        new_types.extend(curr_types)

        first_node[i] = node_counter
        last_node[i] = node_counter + len(curr_types) - 1
        # Chain the new nodes in order; works for any expansion length (no edge for single-node types)
        new_edges.extend((k, k + 1) for k in range(first_node[i], last_node[i]))

        node_counter += len(curr_types)

    # Then complete missing edges
    new_edges.extend((last_node[n_i], first_node[n_o]) for n_i, n_o in g.get_edgelist())

    g_o = igraph.Graph(directed=True)
    g_o.add_vertices(len(new_types), {'type': new_types})
    g_o.add_edges(new_edges)

    return g_o


def retrieve_features(data_graphs, processed_data_graphs):
    '''
    Loop through the original dataset and its processed version (with directed feedback loops), and match
    nodes and edges whenever possible to retrieve the features from the original data. When the matching
    fails, features are drawn randomly. This allows recovering ~40% of the features.

    Params
    ------
        data_graphs (list): one entry per split. Each entry is a sequence of pairs whose second item is the
            original graph (only `e[1]` is read).
        processed_data_graphs (list): per split, the processed graphs, index-aligned with `data_graphs`.
    Outputs
    -------
        processed_data_graphs_with_feats (list): per split, copies of the processed graphs with 'feat' and
            'vid' vertex attributes set.
        fails (list): per split, the indices of the graphs whose features were drawn randomly.
    '''

    fails = []
    processed_data_graphs_with_feats = []
    for split_i, split_pairs in enumerate(data_graphs):
        data = [e[1] for e in split_pairs]
        processed_data = processed_data_graphs[split_i]
        split_out = []
        fails_split = []
        for i, original_graph in enumerate(data):
            data_graph = original_graph.copy()
            processed_graph = processed_data[i].copy()
            try:
                res, permutation = complete_node_permutation(data_graph, processed_graph, 0, list())
            except Exception:
                # Best effort: any error raised by the matching heuristic falls back to random features.
                # KeyboardInterrupt / SystemExit are no longer swallowed (the former bare `except:` did).
                res = False
            if res:
                feats = np.array(data_graph.vs['feat'])[permutation].tolist()
            else:
                feats = np.round(np.random.randn(len(processed_graph.vs['type'])) * 25 + 50, 0).clip(min=1.0, max=100.0).tolist()
                fails_split.append(i)
            processed_graph.vs['feat'] = feats
            processed_graph.vs['vid'] = [int(e) for e in feats]

            split_out.append(processed_graph)

        processed_data_graphs_with_feats.append(split_out)
        fails.append(fails_split)

    return processed_data_graphs_with_feats, fails


def complete_node_permutation(split_graph, graph, depth=0, permutation=None):
    '''
    Recursively search a node permutation aligning `split_graph` with `graph`, backtracking over candidates.

    Nodes are visited in index order starting at `depth`. Node j is accepted as-is when its type matches and its
    incoming edges agree. For `split_graph` these are its edges normalised to (min, max) order, with sources
    relabelled through `permutation`. Otherwise, every unused `split_graph` node with the same type and compatible
    incoming edges is tried in turn.

    Params
    ------
        split_graph, graph: graphs exposing `get_edgelist()`, `vcount()` and a 'type' vertex attribute.
        depth (int): first node index to process.
        permutation (list or None): assignment built so far for nodes 0..depth-1. The caller's list is not
            modified; None (default) starts from an empty assignment.
    Outputs
    -------
        (success, permutation): (True, completed permutation) or (False, []) when no consistent assignment exists.
        It can raise (e.g. IndexError when a source node is not assigned yet) on graphs the heuristic cannot
        handle; retrieve_features treats that as a failure.
    '''
    # A fresh list per call: the former mutable default (`list()`) was appended to and shared across calls.
    permutation = [] if permutation is None else list(permutation)

    sorted_split_edges = [(min(i, o), max(i, o)) for (i, o) in split_graph.get_edgelist()]
    edges = graph.get_edgelist()
    split_types = np.array(split_graph.vs['type'])

    for j in range(depth, split_graph.vcount()):
        set_split_graph = {(permutation[i], o) for (i, o) in sorted_split_edges if o == j}
        set_graph = {(i, o) for (i, o) in edges if o == j}
        if (set_split_graph != set_graph) or (split_graph.vs['type'][j] != graph.vs['type'][j]):
            # Node type agreement
            candidates_idx = [e for e in np.where(split_types == graph.vs['type'][j])[0].tolist() if e not in permutation]
            # Edge agreement (or candidate without any incoming edge)
            candidates_idx = [e for e in candidates_idx
                              if set_graph == {(permutation[i], j) for (i, o) in sorted_split_edges if o == e}
                              or not any(o == e for (_, o) in sorted_split_edges)]
            # If the list is empty then a parent component is wrongly placed, leave the current loop
            if not candidates_idx:
                return False, []
            # Else, explore the different branches
            for candidate_idx in candidates_idx:
                success, candidate_permut = complete_node_permutation(split_graph, graph, j + 1, permutation + [candidate_idx])
                if success:
                    return True, candidate_permut
            # No branch could be completed
            return False, []
        permutation.append(j)

    return True, permutation


### Utility functions for simulation metrics calculation

# Element name -> element-level node type id (same contents as the NODE_TYPE defined earlier in this module;
# this binding replaces it when the module is imported)
NODE_TYPE = {
    'R': 0,
    'C': 1,
    '+gm+':2,
    '-gm+':3,
    '+gm-':4,
    '-gm-':5,
    'sudo_in':6,
    'sudo_out':7,
    'In': 8,
    'Out':9
}

# Subcircuit type id -> names of the elements it contains (each looked up in NODE_TYPE by subckt_to_elements).
# NOTE: entries 18-21 and 22-25 list the same elements; they only differ by their wiring in SUBG_CON.
SUBG_NODE = {
    0: ['In'],
    1: ['Out'],
    2: ['R'],
    3: ['C'],
    4: ['R','C'],
    5: ['R','C'],
    6: ['+gm+'],
    7: ['-gm+'],
    8: ['+gm-'],
    9: ['-gm-'],
    10: ['C', '+gm+'],
    11: ['C', '-gm+'],
    12: ['C', '+gm-'],
    13: ['C', '-gm-'],
    14: ['R', '+gm+'],
    15: ['R', '-gm+'],
    16: ['R', '+gm-'],
    17: ['R', '-gm-'],
    18: ['C', 'R', '+gm+'],
    19: ['C', 'R', '-gm+'],
    20: ['C', 'R', '+gm-'],
    21: ['C', 'R', '-gm-'],
    22: ['C', 'R', '+gm+'],
    23: ['C', 'R', '-gm+'],
    24: ['C', 'R', '+gm-'],
    25: ['C', 'R', '-gm-']
}

# Subcircuit type id -> how its elements are wired, as interpreted by subckt_to_elements:
#   None     -> single element;
#   'parral' -> (sic, "parallel") all elements share the same input and output;
#   'series' -> the first element feeds the second; for 3-element entries the last element is also on the
#               input side and the second one also on the output side.
# NOTE: 'parral' is the literal value compared in subckt_to_elements; keep both spellings in sync.
SUBG_CON = {
    0: None,
    1: None,
    2: None,
    3: None,
    4: 'series',
    5: 'parral',
    6: None,
    7: None,
    8: None,
    9: None,
    10: 'parral',
    11: 'parral',
    12: 'parral',
    13: 'parral',
    14: 'parral',
    15: 'parral',
    16: 'parral',
    17: 'parral',
    18: 'parral',
    19: 'parral',
    20: 'parral',
    21: 'parral',
    22: 'series',
    23: 'series',
    24: 'series',
    25: 'series'
}


def subckt_to_elements(g):
    '''
    Expand every subcircuit vertex of `g` into its elementary components.

    Each vertex type is looked up in SUBG_NODE (element names, typed through NODE_TYPE) and in SUBG_CON (how
    those elements are wired together). Every original edge (u, v) is then replaced by edges from each output
    element of u to each input element of v.

    Params
    ------
        g (igraph.Graph): subcircuit-level graph with a 'type' vertex attribute.
    Outputs
    -------
        g_o (igraph.Graph): new directed element-level graph carrying only a 'type' vertex attribute.
    '''

    element_types = []
    edges = []
    in_pins = {}   # original vertex -> element indices that receive its incoming edges
    out_pins = {}  # original vertex -> element indices that emit its outgoing edges
    offset = 0

    for vid, subckt_type in enumerate(g.vs['type']):
        names = SUBG_NODE[subckt_type]
        count = len(names)
        element_types += [NODE_TYPE[name] for name in names]
        last = offset + count - 1
        wiring = SUBG_CON[subckt_type]

        if wiring == 'parral':
            # Parallel: every element is both an input and an output of the subcircuit
            shared = list(range(offset, offset + count))
            in_pins[vid] = shared
            out_pins[vid] = shared
        elif wiring == 'series':
            # The first element feeds the second one
            edges.append((offset, offset + 1))
            if count > 2:
                # Three-element case: the last element also sits on the input side and the
                # second element also on the output side
                in_pins[vid] = [offset, last]
                out_pins[vid] = [last, offset + 1]
            else:
                in_pins[vid] = [offset]
                out_pins[vid] = [last]
        else:
            in_pins[vid] = out_pins[vid] = [offset]

        offset += count

    # Then complete missing edges
    edges.extend(
        (int(src), int(dst))
        for n_i, n_o in g.get_edgelist()
        for src in out_pins[n_i]
        for dst in in_pins[n_o]
    )

    g_o = igraph.Graph(directed=True)
    g_o.add_vertices(len(element_types), {'type': element_types})
    g_o.add_edges(edges)

    return g_o