import numpy as np
import igraph
# from .pyspice_utils import EqClass

## These classes are copied here to prevent import errors
class EqElt:
    '''
    A single equivalence class of node names.

    ``eq`` holds every name of the class; ``first_elt`` is the name the class was created from and is
    its representative unless the class contains one of the ports 'IN' / 'OUT'.
    '''

    def __init__(self, elt_name):
        self.eq = {elt_name}
        self.first_elt = elt_name

    def eqto(self, b):
        '''Merge every name of the class ``b`` into this class (``b`` itself is left unchanged).'''
        self.eq = self.eq.union(b.eq)

    @property
    def default_name(self):
        '''Representative name: 'IN' if present, else 'OUT' if present, else ``first_elt``.'''
        if 'IN' in self.eq:
            return 'IN'
        if 'OUT' in self.eq:
            return 'OUT'
        return self.first_elt


class EqClass:
    '''
    Collection of equivalence classes indexed by member name.

    ``all_eqs`` maps every name registered through ``add_new_eq`` to the ``EqElt`` of its class; all the
    members of one class share the same ``EqElt`` object.
    '''

    def __init__(self):
        self.all_eqs = {}

    def add_new_eq(self, a, b):
        '''
        Declare ``a`` and ``b`` equivalent, merging their classes.

        When both names already belong to classes, ``a``'s class absorbs ``b``'s and keeps its
        representative. Every member of the merged class is then re-pointed at the merged object, so the
        relation stays transitive.
        '''
        elt_a = self.all_eqs.get(a)
        elt_b = self.all_eqs.get(b)
        if elt_a is not None and elt_b is not None:
            merged = elt_a
            if elt_b is not elt_a:
                merged.eqto(elt_b)
        elif elt_a is not None:
            merged = elt_a
            merged.eqto(EqElt(b))
        elif elt_b is not None:
            merged = elt_b
            merged.eqto(EqElt(a))
        else:
            merged = EqElt(b)
            merged.eqto(EqElt(a))
        # Register b then a first so new keys keep their historical insertion order
        self.all_eqs[b] = merged
        self.all_eqs[a] = merged
        # Re-point every member: names of b's former class would otherwise keep a stale, partial class
        for name in merged.eq:
            self.all_eqs[name] = merged

    @property
    def equivalences(self):
        '''Map every registered name to the representative name of its class.'''
        return {k: v.default_name for k, v in self.all_eqs.items()}


def build_node_equivalences(edge_list, in_idx, out_idx):
    '''
    Group device pins that share an electrical node.

    A directed edge ``(src, dst)`` ties the output pin of ``src`` (``'<src>_out'``) to the input pin of
    ``dst`` (``'<dst>_inpt'``). The port vertices are replaced by their port name, checked in this
    order (first match wins):
      - ``src == in_idx``  : 'IN'  ~ '<dst>_inpt'
      - ``dst == in_idx``  : 'IN'  ~ '<src>_out'
      - ``src == out_idx`` : 'OUT' ~ '<dst>_inpt'
      - ``dst == out_idx`` : 'OUT' ~ '<src>_out'

    Returns
    -------
        EqClass holding the pin equivalences.
    '''
    eq_nodes = EqClass()
    for src, dst in edge_list:
        src_pin = f'{src}_out'
        dst_pin = f'{dst}_inpt'
        if src == in_idx:
            pair = ('IN', dst_pin)
        elif dst == in_idx:
            pair = ('IN', src_pin)
        elif src == out_idx:
            pair = ('OUT', dst_pin)
        elif dst == out_idx:
            pair = ('OUT', src_pin)
        else:
            pair = (dst_pin, src_pin)
        eq_nodes.add_new_eq(*pair)
    return eq_nodes


def search_missing_edges(eq_nodes, edgelist, postprocess=False):
    '''
    Look for edges implied by the node equivalences but absent from the edge list.

    Associativity has been assumed to construct the equivalence between electric nodes. An output pin
    ``'<i>_out'`` sharing a node with an input pin ``'<o>_inpt'`` implies an edge ``(i, o)``.
    Missing edges originate from devices that belong to a feedback loop.

    Params
    ------
        eq_nodes (EqClass): pin equivalences, e.g. from ``build_node_equivalences``.
        edgelist (iterable of (int, int)): directed edges of the graph.
        postprocess (bool): if True, edges on feedback loops have already been reversed. This time
            missing edges are directed toward feedback loop components, so the destination index ``o``
            is reported instead of the source ``i``. Output pins sharing a node with 'IN' are skipped.
    Outputs
    -------
        nodes_to_reverse (list): sorted, de-duplicated indices of nodes which belong to a feedback loop.
    '''
    # Set lookup: the membership test runs once per (output pin, input pin) pair
    edge_set = {tuple(edge) for edge in edgelist}

    nodes_to_reverse = []
    for pin in eq_nodes.all_eqs.keys():
        if '_out' not in pin:
            continue
        int_i = int(pin.replace('_out', ''))
        members = eq_nodes.all_eqs[pin].eq
        if postprocess and 'IN' in members:
            continue
        dest_nodes = [int(o.replace('_inpt', '')) for o in members if '_inpt' in o]
        missing_dest = [int_o for int_o in dest_nodes if (int_i, int_o) not in edge_set]
        if postprocess:
            nodes_to_reverse.extend(missing_dest)
        elif missing_dest:
            nodes_to_reverse.append(int_i)

    return np.unique(nodes_to_reverse).tolist()


def search_components_in_series(feedb_idx, eq_nodes):
    '''
    Find components connected in series with the given feedback loop components.

    A component is in series with a feedback component when its pin is the only other pin on one of
    the feedback component's electrical nodes. That is either the single input pin on the node driven
    by the feedback component's output, or the single output pin on the node feeding its input. Port
    names ('IN' / 'OUT') are not components and are ignored. So is a side with no connection at all.

    Params
    ------
        feedb_idx (iterable of int): indices of feedback loop components.
        eq_nodes (EqClass): pin equivalences, e.g. from ``build_node_equivalences``.
    Outputs
    -------
        nodes_to_reverse (list): indices of the components in series (may contain duplicates).
    '''
    nodes_to_reverse = []

    for feedb_node in feedb_idx:
        # (pin of the feedback component, suffix the lone neighbouring pin must carry)
        sides = ((f'{feedb_node}_out', '_inpt'), (f'{feedb_node}_inpt', '_out'))
        for own_pin, other_suffix in sides:
            elt = eq_nodes.all_eqs.get(own_pin)
            if elt is None:
                # No connection on this side of the component
                continue
            others = [e for e in elt.eq if e != own_pin]
            # Only a lone device pin counts; a port name would otherwise crash the int() conversion
            if len(others) == 1 and others[0].endswith(other_suffix):
                nodes_to_reverse.append(int(others[0][:-len(other_suffix)]))

    return nodes_to_reverse


def reverse_feedback_loops(dataset):
    '''
    Recover the feedback loops of OCB subcircuit graphs, then break every subcircuit into its elements.

    For each graph of ``dataset`` (vertex attribute 'type' holding subcircuit types; type 0 is the
    input port and type 1 the output port):
      1. components are flagged as feedback components by subcircuit type, by missing edges
         (``search_missing_edges``), and by being in series with a flagged component;
      2. every edge touching a flagged component is reversed;
      3. missing edges are searched once more on the re-oriented edges. The edges of the newly found
         components (and of those in series with them) are reversed as well.
    Each input graph is modified in place: its edges are replaced and a boolean vertex attribute
    'feedback' is added.

    Returns the list of element-level graphs produced by ``subckt_to_elements``.
    '''
    in_type, out_type = 0, 1
    feedb_gm_types = [8, 9, 12, 13, 16, 17, 20, 21, 24, 25]

    def _with_flags(flags, node_indices):
        # Copy of ``flags`` where every index listed in ``node_indices`` is set to True
        marked = set(node_indices)
        return [True if idx in marked else flag for idx, flag in enumerate(flags)]

    new_dataset = []
    for graph_i in dataset:
        vertex_types = graph_i.vs['type']
        # Identify I/O nodes (first vertex of the input / output port type)
        in_idx = [idx for idx, t in enumerate(vertex_types) if t == in_type][0]
        out_idx = [idx for idx, t in enumerate(vertex_types) if t == out_type][0]
        edgelist = graph_i.get_edgelist()

        ## 1st step - flag the feedback loop components
        # ... by subcircuit type
        feedback = [t in feedb_gm_types for t in vertex_types]
        # ... by missing out edges
        eq_nodes = build_node_equivalences(edgelist, in_idx, out_idx)
        feedback = _with_flags(feedback, search_missing_edges(eq_nodes, edgelist))
        # ... by being in series with a flagged component (max 1 before, 1 after)
        feedback = _with_flags(feedback, search_components_in_series(np.where(feedback)[0], eq_nodes))

        ## 2nd step - reverse feedback loop edges; edges can now be directed to IN or from OUT
        new_edges = [(dst, src) if (feedback[src] or feedback[dst]) else (src, dst)
                     for (src, dst) in edgelist]

        ## 3rd step - search missing edges once more to discover additional feedback loops
        eq_nodes = build_node_equivalences(new_edges, in_idx, out_idx)
        late_nodes = [n for n in search_missing_edges(eq_nodes, new_edges, postprocess=True)
                      if not feedback[n]]
        late_nodes += search_components_in_series(late_nodes, eq_nodes)

        # Reverse those components' edges one last time and write the result back into the graph
        late_set = set(late_nodes)
        new_edges = [(dst, src) if (src in late_set or dst in late_set) else (src, dst)
                     for (src, dst) in new_edges]
        graph_i.delete_edges()
        graph_i.add_edges(new_edges)
        graph_i.vs['feedback'] = _with_flags(feedback, late_nodes)

        # Map from subckt to low level
        new_dataset.append(subckt_to_elements(graph_i))

    return new_dataset


def split_graph(graph, attributes=True):
    '''
    Add intermediate circuit nodes between all electrical components. Those are given the type "10".

    One net vertex is created per electrical node (pin equivalence class from
    ``build_node_equivalences``). Each component-to-component edge is routed through the net vertex of
    its source's output pin. Edges touching the 'In' (type 8) or 'Out' (type 9) vertex are kept
    unchanged, because those vertices already stand for their own net.

    Params
    ------
        graph: igraph graph with vertex attribute 'type' (plus 'feat' and 'vid' when ``attributes``).
        attributes (bool): also extend 'feat' (with 0.0) and 'vid' (with 0) to the new vertices.
    Outputs
    -------
        (dict, list): new vertex attributes {'type', 'feat', 'vid'} and the new edge list. 'feat' and
        'vid' are None when ``attributes`` is False.
    '''
    # Identify I/O nodes
    node_types = graph.vs['type']
    n_elt = len(node_types)
    in_idx, out_idx = ([i for i in range(n_elt) if node_types[i] == 8][0],
                       [i for i in range(n_elt) if node_types[i] == 9][0])

    # Identify common device I/O nodes of the circuit (same rules as everywhere else in this module)
    edge_list = graph.get_edgelist()
    eq_nodes = build_node_equivalences(edge_list, in_idx, out_idx)
    # Computed once: the property rebuilds the whole dict on every access
    equivalences = eq_nodes.equivalences

    # Node name to idx dict
    net_names = np.unique([e for e in equivalences.values() if e not in ['IN', 'OUT']])
    additional_nodes = {str(node): n_elt + i for i, node in enumerate(net_names)}
    additional_nodes.update({'IN': in_idx, 'OUT': out_idx})

    # Update type list, nets are given the index "10"
    n_new = len(additional_nodes) - 2
    new_types = node_types + n_new * [10]
    new_attributes = (graph.vs['feat'] + n_new * [0.0]) if attributes else None
    new_vid = (graph.vs['vid'] + n_new * [0]) if attributes else None

    # Update edge list; ``seen`` mirrors new_edgelist for O(1) duplicate checks
    new_edgelist = []
    seen = set()
    for (i, o) in edge_list:
        if (i in (in_idx, out_idx)) or (o in (in_idx, out_idx)):
            new_edgelist.append((i, o))
            seen.add((i, o))
            continue
        interm_node = additional_nodes[equivalences[f'{i}_out']]
        for edge in ((i, interm_node), (interm_node, o)):
            if edge not in seen:
                new_edgelist.append(edge)
                seen.add(edge)

    return {'type': new_types, 'feat': new_attributes, 'vid': new_vid}, new_edgelist


# Split graphs in a list
def split_collection(graph_list, attributes):
    '''
    Apply ``split_graph`` to every graph of ``graph_list``.

    Each graph is copied. The copy's vertices are replaced by the split vertex set, carrying the
    'type', 'feat' and 'vid' attributes returned by ``split_graph``. The split edge list is then added.
    Returns the list of new graphs, in input order.
    '''
    def _rebuild(graph):
        rebuilt = graph.copy()
        vertex_attrs, split_edges = split_graph(rebuilt, attributes)
        rebuilt.delete_vertices()
        rebuilt.add_vertices(len(vertex_attrs['type']), attributes=vertex_attrs)
        rebuilt.add_edges(split_edges)
        return rebuilt

    return [_rebuild(graph) for graph in graph_list]


# The original OCB node types
NODE_TYPE = {
    'R': 0, 'C': 1,
    '+gm+': 2, '-gm+': 3, '+gm-': 4, '-gm-': 5,
    'sudo_in': 6, 'sudo_out': 7,
    'In': 8, 'Out': 9,
}
# The OCB_v2 node types: 'gm_in' / 'gm_out' take indices 6 / 7, and 'n' (10) is an intermediate net
extended_node_type = {
    'R': 0, 'C': 1,
    '+gm+': 2, '-gm+': 3, '+gm-': 4, '-gm-': 5,
    'gm_in': 6, 'gm_out': 7,
    'In': 8, 'Out': 9,
    'n': 10,
}
# Original node type -> chain of OCB_v2 node names it is expanded into
extended_subg_node = {
    0: ['R'],
    1: ['C'],
    2: ['gm_in', '+gm+', 'gm_out'],
    3: ['gm_in', '-gm+', 'gm_out'],
    4: ['gm_in', '+gm-', 'gm_out'],
    5: ['gm_in', '-gm-', 'gm_out'],
    8: ['In'],
    9: ['Out'],
    10: ['n'],
}


def ocb_graph_with_split_gm(g):
    '''
    Takes a OCB graph as input and further split gm nodes into 3 nodes in series, making the input and output gm nodes apparent.

    Each vertex of type ``t`` is expanded into the chain of vertices named in ``extended_subg_node[t]``
    (typed through ``extended_node_type``). Consecutive vertices of a chain are linked in order. An
    original edge (u, v) becomes an edge from the last vertex of u's chain to the first vertex of v's.

    Returns a new directed igraph graph carrying only the 'type' vertex attribute.
    '''
    new_types = []
    new_edges = []
    first_of, last_of = {}, {}
    node_counter = 0

    # First, assign new idx and types
    for i, t_i in enumerate(g.vs['type']):
        curr_types = [extended_node_type[t] for t in extended_subg_node[t_i]]
        new_types.extend(curr_types)
        first_of[i] = node_counter
        last_of[i] = node_counter + len(curr_types) - 1
        # Link the chain (gm_in -> gm -> gm_out); derived from its length instead of assuming 3 nodes
        new_edges.extend((k, k + 1) for k in range(first_of[i], last_of[i]))
        node_counter += len(curr_types)

    # Then complete missing edges
    for n_i, n_o in g.get_edgelist():
        new_edges.append((int(last_of[n_i]), int(first_of[n_o])))

    g_o = igraph.Graph(directed=True)
    g_o.add_vertices(len(new_types), {'type': new_types})
    g_o.add_edges(new_edges)

    return g_o


def retrieve_features(data_graphs, processed_data_graphs):
    '''
    Loops through the original dataset and the processed version with directed feedback loops, and match nodes and edges whenever possible
    to retrieve the features from the original data.

    Params
    ------
        data_graphs (list): one entry per split. Each entry is a sequence whose items hold the original
            graph at index 1 (``item[1]``).
        processed_data_graphs (list): per split, the processed graphs aligned with ``data_graphs``.
    Outputs
    -------
        processed_data_graphs_with_feats (list): per split, copies of the processed graphs with the
            'feat' and 'vid' vertex attributes set.
        fails (list): per split, indices of the graphs whose node matching failed. Those graphs get random
            features drawn from N(50, 25^2), rounded and clipped to [1, 100].
    '''
    fails = []
    processed_data_graphs_with_feats = []
    for split_i, split_data in enumerate(data_graphs):
        data, processed_data = [e[1] for e in split_data], processed_data_graphs[split_i]
        split_out = []
        fails_split = []
        for i, orig_graph in enumerate(data):
            data_graph = orig_graph.copy()
            processed_graph = processed_data[i].copy()
            try:
                res, permutation = complete_node_permutation(data_graph, processed_graph)
            except Exception:
                # Best effort: any failure of the matching falls back to random features below
                res = False
            if res:
                processed_graph.vs['feat'] = np.array(data_graph.vs['feat'])[permutation].tolist()
            else:
                processed_graph.vs['feat'] = np.round(np.random.randn(len(processed_graph.vs['type'])) * 25 + 50, 0).clip(min=1.0, max=100.0).tolist()
                fails_split.append(i)
            processed_graph.vs['vid'] = [int(e) for e in processed_graph.vs['feat']]

            split_out.append(processed_graph)

        processed_data_graphs_with_feats.append(split_out)
        fails.append(fails_split)

    return processed_data_graphs_with_feats, fails


def complete_node_permutation(split_graph, graph, depth=0, permutation=None):
    '''
    Search, by depth-first backtracking, a node ordering of ``split_graph`` matching ``graph``.

    Nodes of ``graph`` are visited in index order from ``depth``. Node ``j`` is accepted as-is when its
    type and incoming edges agree with node ``j`` of ``split_graph``. Otherwise every unused
    ``split_graph`` node with the required type and compatible incoming edges is tried in turn. Edges of
    ``split_graph`` are normalised to (min, max) pairs before comparison, while those of ``graph`` are
    used as given.

    Params
    ------
        split_graph, graph: graphs exposing ``get_edgelist()``, ``vcount()`` and ``vs['type']``.
        depth (int): first node index to place (used by the recursion).
        permutation (list or None): ``split_graph`` indices already matched to nodes ``0..depth-1``;
            the caller's list is not modified.
    Outputs
    -------
        (bool, list): whether a complete matching was found, and the permutation
        (``permutation[j]`` is the ``split_graph`` index matched to node ``j`` of ``graph``), or [].
    '''
    # Fresh list per call: the former mutable default was appended to below and leaked the previous
    # top-level call's result into the next one.
    permutation = [] if permutation is None else list(permutation)

    sorted_split_edges = [(min(i, o), max(i, o)) for (i, o) in split_graph.get_edgelist()]
    edges = graph.get_edgelist()
    split_types = split_graph.vs['type']
    graph_types = graph.vs['type']

    for j in range(depth, split_graph.vcount()):
        set_split_graph = set([(permutation[i], o) for (i, o) in sorted_split_edges if o == j])
        set_graph = set([(i, o) for (i, o) in edges if o == j])
        if (set_split_graph != set_graph) or (split_types[j] != graph_types[j]):
            # Node type agreement
            candidates_idx = [e for e in np.where(np.array(split_types) == graph_types[j])[0].tolist() if e not in permutation]
            # Edge agreement: the candidate's incoming edges map onto node j's, or it has none
            candidates_idx = [e for e in candidates_idx if
                              set_graph == set([(permutation[i], j) for (i, o) in sorted_split_edges if o == e])
                              or len(set([(i, o) for (i, o) in sorted_split_edges if o == e])) == 0]
            # If the list is empty then a parent component is wrongly placed, leave the current loop
            if len(candidates_idx) == 0:
                return False, list()
            # Else, explore the different branches
            for candidate_idx in candidates_idx:
                success, candidate_permut = complete_node_permutation(split_graph, graph, j + 1, permutation + [candidate_idx])
                if success:
                    return True, candidate_permut
            return False, list()
        else:
            permutation.append(j)

    return True, permutation


from itertools import permutations

def get_permutations(type_i, type_f, nnode_types=10):
    '''
    List & rank all possible permutations from the initial type list type_i to the destination list type_f

    For every node type in ``0..nnode_types-1``, the positions holding that type in ``type_f`` are
    filled, in every possible order, with the positions holding it in ``type_i``. Positions of other
    types keep their own index. Each such type must occur equally often in both inputs, otherwise the
    index assignment fails.

    Params
    ------
        type_i, type_f (array-like of int): 1-D node type sequences (lists or numpy arrays).
        nnode_types (int): number of node types to permute.
    Outputs
    -------
        (list of np.ndarray, np.ndarray): the permutations sorted by increasing number of positions that
        differ from the identity, and those numbers.
    '''
    # Accept plain lists: the element-wise ``==`` below needs arrays
    type_i = np.asarray(type_i)
    type_f = np.asarray(type_f)

    node_type_permuts = []
    for node_type in np.arange(nnode_types):
        orig_idx = np.where(type_f == node_type)[0].tolist()
        dest_idx = np.where(type_i == node_type)[0].tolist()
        node_type_permuts.append((orig_idx, list(permutations(dest_idx))))

    ordering = np.arange(len(type_i))
    all_permutations = [ordering]
    for positions, type_permutations in node_type_permuts:
        expanded = []
        for p in type_permutations:
            for permutation in all_permutations:
                permutation_copy = permutation.copy()
                permutation_copy[positions] = p
                expanded.append(permutation_copy)
        all_permutations = expanded

    # Rank permutations in increasing order of the number of reorderings (computed once)
    n_moved = np.array([(p != ordering).sum() for p in all_permutations])
    order = np.argsort(n_moved)
    return [all_permutations[e] for e in order], n_moved[order]


### Utility functions for simulation metrics calculation

# Element-level node types (same mapping as the NODE_TYPE defined earlier in this module)
NODE_TYPE = {
    'R': 0, 'C': 1,
    '+gm+': 2, '-gm+': 3, '+gm-': 4, '-gm-': 5,
    'sudo_in': 6, 'sudo_out': 7,
    'In': 8, 'Out': 9,
}

# Subcircuit type -> names of the elements it contains
SUBG_NODE = {
    0: ['In'], 1: ['Out'],
    2: ['R'], 3: ['C'],
    4: ['R', 'C'], 5: ['R', 'C'],
    6: ['+gm+'], 7: ['-gm+'], 8: ['+gm-'], 9: ['-gm-'],
    10: ['C', '+gm+'], 11: ['C', '-gm+'], 12: ['C', '+gm-'], 13: ['C', '-gm-'],
    14: ['R', '+gm+'], 15: ['R', '-gm+'], 16: ['R', '+gm-'], 17: ['R', '-gm-'],
    18: ['C', 'R', '+gm+'], 19: ['C', 'R', '-gm+'], 20: ['C', 'R', '+gm-'], 21: ['C', 'R', '-gm-'],
    22: ['C', 'R', '+gm+'], 23: ['C', 'R', '-gm+'], 24: ['C', 'R', '+gm-'], 25: ['C', 'R', '-gm-'],
}

# Subcircuit type -> how its elements are wired: None (single element), 'series' or 'parral'
# (the 'parral' spelling is the value compared against by subckt_to_elements)
SUBG_CON = {
    0: None, 1: None, 2: None, 3: None,
    4: 'series', 5: 'parral',
    6: None, 7: None, 8: None, 9: None,
    10: 'parral', 11: 'parral', 12: 'parral', 13: 'parral',
    14: 'parral', 15: 'parral', 16: 'parral', 17: 'parral',
    18: 'parral', 19: 'parral', 20: 'parral', 21: 'parral',
    22: 'series', 23: 'series', 24: 'series', 25: 'series',
}


def subckt_to_elements(g):
    '''
    Expand every subcircuit vertex of ``g`` into its elementary components.

    Element names come from ``SUBG_NODE`` (typed via ``NODE_TYPE``); the wiring comes from ``SUBG_CON``:
      - 'parral': every element is both an input and an output of the subcircuit;
      - 'series': an edge links the first element to the second; the first element is the input and the
        last one the output. With three elements, the last element is also an input and the second one
        also an output;
      - anything else: a single element serves as input and output.
    Each original edge (u, v) becomes edges from every output element of u to every input element of v.

    Returns a new directed igraph graph carrying only the 'type' vertex attribute.
    '''
    new_types = []
    new_edges = []
    inputs, outputs = {}, {}
    node_counter = 0

    # First, assign new idx and types
    for i, t_i in enumerate(g.vs['type']):
        elem_types = [NODE_TYPE[name] for name in SUBG_NODE[t_i]]
        n_elems = len(elem_types)
        new_types.extend(elem_types)
        first, last = node_counter, node_counter + n_elems - 1

        wiring = SUBG_CON[t_i]
        if wiring == 'parral':
            # All elements share the same input and output
            ins = list(range(first, first + n_elems))
            outs = ins
        elif wiring == 'series':
            ins, outs = [first], [last]
            # Link between the first and second elements
            new_edges.append((first, first + 1))
            if n_elems > 2:
                # Third element in parallel with the series pair
                ins.append(last)
                outs.append(first + 1)
        else:
            ins = [first]
            outs = ins

        inputs[i], outputs[i] = ins, outs
        node_counter += n_elems

    # Then complete missing edges
    for n_i, n_o in g.get_edgelist():
        new_edges.extend((int(src), int(dst)) for src in outputs[n_i] for dst in inputs[n_o])

    g_o = igraph.Graph(directed=True)
    g_o.add_vertices(len(new_types), {'type': new_types})
    g_o.add_edges(new_edges)

    return g_o