import onnx
import onnx.helper as helper
import onnx.numpy_helper as numpy_helper
import collections
import numpy as np
import torch
import torch.nn as nn
import os
import tqdm
import onnxruntime as ort
from collections import defaultdict


class Mul(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, input1, input2):
        return input1 * input2


class Add(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self,input1,input2):
        return input1 + input2


class Reshape(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input1, input2):
        return torch.reshape(input1,shape=list(input2))


class Concat(nn.Module):
    def __init__(self, dim=[1]):
        super().__init__()
        self.dim = dim
        if not isinstance(self.dim, list):
            self.dim = [self.dim]
        for i, d in enumerate(self.dim):
            if not isinstance(d, int):
                self.dim[i] = int(d)

    def forward(self, *args):
        args = list(args)
        for i,a in enumerate(args):
            args[i] = a.type_as(args[0])
        return torch.cat(args,dim=self.dim[0])

class Shape(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        return torch.IntTensor([input.size(i) for i in range(len(input.size()))])

class Gather(nn.Module):
    def __init__(self,dim=1):
        super().__init__()
        self.dim= int(dim)

    def forward(self, input1, input2):
        return torch.gather(input1,dim=self.dim,index=input2.cpu())

class Unsqueeze(nn.Module):
    def __init__(self, dim=[1]):
        super().__init__()
        self.dim = dim
        if not isinstance(self.dim, list):
            self.dim = [self.dim]
        for i,d in enumerate(self.dim):
            if not isinstance(d,int):
                self.dim[i] = int(d)

    def forward(self, input):
        x = input
        for i in self.dim:
            x = torch.unsqueeze(x,dim=i)
        return x

class TopologyAnalyser:
    def __init__(self):
        '''
        * :ref:`API in English <TopologyAnalyser.__init__-en>`

        .. _TopologyAnalyser.__init__-cn:

        这个类通过onnx分析模型的拓扑结构，方便后续处理
        此处还有更多更好的实现方法，欢迎开发者不断优化

        * :ref:`API in English <TopologyAnalyser.__init__-cn>`

        .. _TopologyAnalyser.__init__-en:

        This class analyzes the topological structure of the model through onnx to facilitate subsequent processing
        There are better implementation methods here, developers are welcome to continue to optimize
        '''
        self.data_nodes = []
        self.module_output = collections.OrderedDict()
        self.module_op = collections.OrderedDict()
        self.module_idx = collections.OrderedDict()
        self.param_idx = collections.OrderedDict()
        self.edge = collections.OrderedDict()
        self.reverse_edge = collections.OrderedDict()  # 快速计算前驱结点

    def add_data_node(self, a):
        if not a in self.data_nodes:
            self.data_nodes.append(a)

    def insert(self, a, b, info=None):
        self.add_data_node(a)
        self.add_data_node(b)
        if a not in self.edge.keys():
            self.edge[a] = [(b, info)]
        else:
            self.edge[a].append((b, info))
        if b not in self.reverse_edge.keys():
            self.reverse_edge[b] = [a]
        else:
            self.reverse_edge[b].append(a)

    def findNext(self, id):
        if isinstance(id, str):
            if id in self.edge.keys():
                return self.edge[id]
            else:
                return []
        elif isinstance(id, list):
            l = []
            for i in id:
                l += self.findNext(i)
            return l

    def findPre(self, id):
        l = []
        if isinstance(id, str):
            for pre_id in self.reverse_edge[id]:
                if pre_id in self.reverse_edge.keys():
                    for pre_pre_id in self.reverse_edge[pre_id]:
                        if pre_pre_id in self.edge.keys():
                            for item in self.edge[pre_pre_id]:
                                if item[0] == pre_id:
                                    l.append(item)
        elif isinstance(id, list):
            for i in id:
                l += self.findPre(i)
        return l

    def find_pre_module(self, module_name):
        if module_name in self.module_output.keys():
            ids = self.module_output[module_name]
            return set(['%s:%s' % (k[1]['op'], k[1]['param_module_name']) for k in self.findPre(ids)])
        else:
            return set()

    def find_next_module(self, module_name):
        if module_name in self.module_output.keys():
            ids = self.module_output[module_name]
            return set(['%s:%s' % (k[1]['op'], k[1]['param_module_name']) for k in self.findNext(ids)])
        else:
            return set()

    def update_module_idx(self, onnx_graph):
        for idx, n in enumerate(onnx_graph.node):
            trainable_input = n.input[1:]
            op = n.op_type
            k = set()
            for i in trainable_input:
                n = self._get_module_name_from_value_name(i)
                if n is not None:
                    k.add(n)
            if len(k) > 1:
                # TODO: sanity check, raise error
                pass
            if len(k) == 1:
                param_module_name = list(k)[0]
                self.module_op[param_module_name] = op
                self.module_idx[param_module_name] = idx

    def analyse(self, onnx_graph):  # 输入的onnx graph需要保证所以常量在已经在initializer中
        # 先把该往initializer放下面的参数，保证下面只有运算没有常量
        for idx, constant in enumerate(onnx_graph.initializer):
            self.param_idx[constant.name] = idx
        for idx, n in enumerate(onnx_graph.node):
            param_module_name = None
            op = n.op_type
            inputs = n.input
            outputs = n.output
            # print(inputs,outputs)
            k = set()
            trainable_input = inputs[1:]
            for i in trainable_input:
                n = self._get_module_name_from_value_name(i)
                if n is not None:
                    k.add(n)
            if len(k) > 1:
                # TODO: sanity check, raise error
                pass
            if len(k) == 1:
                param_module_name = list(k)[0]
                self.module_op[param_module_name] = op
                self.module_idx[param_module_name] = idx
            if op is not None:
                for o in outputs:
                    for i in inputs:
                        self.insert(i, o, {'op': op, 'param_module_name': param_module_name})
                    if param_module_name is not None:
                        if param_module_name not in self.module_output.keys():
                            self.module_output[param_module_name] = [o]
                        else:
                            self.module_output[param_module_name].append(o)
        return self

    @staticmethod
    def _get_module_name_from_value_name(value_name):
        module_name = None
        if len(value_name.split('.')) > 1:
            l = value_name.split('.')[:-1]
            l = '.'.join(l)
            module_name = l  # [1:]
            module_name.replace(' ', '')
        return module_name

def pytorch2onnx_model(model: nn.Module, data, **kargs) -> onnx.ModelProto:
    '''

    * :ref:`API in English <pytorch2onnx_model-en>`

    .. _pytorch2onnx_model-cn:

    :param model: 待转换的PyTorch模型

    :param data: 用于转换的数据（用来确定输入维度）

    :param log_dir: 输出文件夹

    转换PyTorch模型到onnx模型

    * :ref:`API in English <pytorch2onnx_model-cn>`

    .. _pytorch2onnx_model-en:

    :param model: the PyTorch model to be converted

    :param data: The data used for conversion (used to determine the input dimension)

    :param log_dir: output folder

    Convert PyTorch model to onnx model

    '''
    try:
        log_dir = kargs['log_dir']
    except KeyError:
        print('pytorch2onnx_model need argument log_dir!')
    dump_input_size = [data.size(i) for i in range(len(data.size()))]
    dump_input_size[0] = 1
    fname = os.path.join(log_dir,'onnxmodel')
    try:
        dynamic_axes = {'input': {0: 'batch_size'},
                        'output': {0: 'batch_size'}}
        torch.onnx.export(model, torch.ones(dump_input_size), fname,
                          input_names=['input'],
                          output_names=['output'],
                          dynamic_axes=dynamic_axes)
    except BaseException:
        raise NotImplementedError("Models with multiple inputs are not supported yet!")
    return onnx.load(fname)

def onnx2pytorch_model(model: onnx.ModelProto, _converter) -> nn.Module:
    model = _pt_model(model, _converter)
    model = model.reduce()
    return model

def layer_reduction(model: onnx.ModelProto) -> onnx.ModelProto:
    graph = model.graph
    topo_analyser = TopologyAnalyser()
    graph = move_constant_to_initializer(graph)
    topo_analyser.analyse(graph)

    absorb_bn(graph, topo_analyser)
    remove_unreferenced_initializer(graph)
    update_topology(graph)
    print("Finish layer reduction!")
    return model

def rate_normalization(model: onnx.ModelProto, data: torch.Tensor, **kargs) -> onnx.ModelProto:
    '''

    * :ref:`API in English <rate_normalization-en>`

    .. _rate_normalization-cn:

    :param model: ANN模型，类型为onnx.ModelProto

    :param data: 用于转换的数据，类型为torch.Tensor

    :param channelwise: 如果为``True``，则控制激活幅值的统计是channelwise的；否则，控制激活幅值的统计是layerwise的

    :param robust: 如果为``True``，则控制激活幅值的统计是激活的99.9百分位；否则，控制激活幅值的统计是激活的最值

    :param eps: epsilon；未设置值时默认1e-5

    发放率归一化

    * :ref:`API in English <rate_normalization-cn>`

    .. _rate_normalization-en:

    :param model: ANN model, the type is onnx.ModelProto

    :param data: the data used for conversion, the type is torch.Tensor

    :param channelwise: If ``True`` , the statistics that control the activation amplitude are channelwise; otherwise, the statistics that control the activation amplitude are layerwise

    :param robust: If ``True``, the statistic of the control activation amplitude is the 99.9th percentile of activation; otherwise, the statistic of the activation amplitude is the maximum value of activation

    :param eps: epsilon; if no value is set, the default is 1e-5

    normalize the firing rate

    '''

    try:
        channelwise = kargs['channelwise']
    except KeyError:
        channelwise = False
    try:
        robust_norm = kargs['robust']
    except KeyError:
        robust_norm = False
    try:
        eps = kargs['eps']
    except KeyError:
        eps = 1e-5
    topo_analyser = update_topology(model.graph)
    output_debug = {}
    output_statistics = get_intermediate_output_statistics(model, data,
                                                           channelwise=channelwise) # if want debug, debug=output_debug
    model = normalize_model(model, output_statistics, topo_analyser, robust_norm=robust_norm,
                            channelwise=channelwise, eps=eps)
    return model

def save_model(model: onnx.ModelProto, f=None):
    fb = model.SerializeToString()
    if f is not None:
        if hasattr(f, 'write'):
            f.write(fb)
        else:
            with open(f, "wb") as f:
                f.write(fb)
    return fb

def move_constant_to_initializer(graph):
    constant_idx = []
    for idx, n in enumerate(graph.node):
        op = n.op_type
        if op == 'Constant':
            constant_idx.append(idx)
    if len(constant_idx):
        for idx in reversed(constant_idx):
            n = graph.node[idx]
            graph.initializer.append(
                numpy_helper.from_array(numpy_helper.to_array(n.attribute[0].t), n.output[0]))
            graph.node.remove(n)
    return graph

def print_onnx_model(graph):
    print(onnx.helper.printable_graph(graph))

def absorb_bn(graph, topo_analyser):
    print("\nAbsorbing BatchNorm Parameters...\n")
    for mn in tqdm.tqdm(reversed(topo_analyser.module_output.keys())):
        if topo_analyser.module_op[mn] == 'BatchNormalization':
            pre_m = topo_analyser.find_pre_module(mn)
            next_m = topo_analyser.find_next_module(mn)
            bn_weight_idx = topo_analyser.param_idx[graph.node[topo_analyser.module_idx[mn]].input[1]]
            bn_weight = np.array(numpy_helper.to_array(graph.initializer[bn_weight_idx]))
            bn_bias_idx = topo_analyser.param_idx[graph.node[topo_analyser.module_idx[mn]].input[2]]
            bn_bias = np.array(numpy_helper.to_array(graph.initializer[bn_bias_idx]))
            bn_mean_idx = topo_analyser.param_idx[graph.node[topo_analyser.module_idx[mn]].input[3]]
            bn_mean = np.array(numpy_helper.to_array(graph.initializer[bn_mean_idx]))
            bn_var_idx = topo_analyser.param_idx[graph.node[topo_analyser.module_idx[mn]].input[4]]
            bn_var = np.array(numpy_helper.to_array(graph.initializer[bn_var_idx]))
            bn_eps = graph.node[topo_analyser.module_idx[mn]].attribute[0].f
            bn_std = np.sqrt(bn_var + bn_eps)
            if len(pre_m) == 1 and list(pre_m)[0].split(':')[0] in ['Conv', 'Gemm']:
                pre_mn = list(pre_m)[0].split(':')[1]
                weight_idx = topo_analyser.param_idx[graph.node[topo_analyser.module_idx[pre_mn]].input[1]]
                weight = np.array(numpy_helper.to_array(graph.initializer[weight_idx]))
                if len(graph.node[topo_analyser.module_idx[pre_mn]].input) == 2:
                    bias = None
                else:
                    bias_idx = topo_analyser.param_idx[graph.node[topo_analyser.module_idx[pre_mn]].input[2]]
                    bias = np.array(numpy_helper.to_array(graph.initializer[bias_idx]))
                wrsp_args = (-1, 1) if len(weight.shape) == 2 else (-1, 1, 1, 1)

                weight_ = weight * bn_weight.reshape(*wrsp_args) / bn_std.reshape(*wrsp_args)
                bias_ = ((bias if bias is not None else 0) - bn_mean.reshape(-1)) * bn_weight.reshape(
                    -1) / bn_std.reshape(-1) \
                        + bn_bias.reshape(-1)
                assert (list(pre_m)[0].split(':')[0] in ['Conv', 'Gemm'])
                args = {}
                for attr in graph.node[topo_analyser.module_idx[pre_mn]].attribute:
                    args[attr.name] = helper.get_attribute_value(attr)
                new_node = onnx.helper.make_node(
                    list(pre_m)[0].split(':')[0],
                    inputs=[graph.node[topo_analyser.module_idx[pre_mn]].input[0], pre_mn + ".new.weight", pre_mn + ".new.bias"],
                    outputs=[graph.node[topo_analyser.module_idx[mn]].output[0]],
                    **args
                )
                graph.initializer.append(numpy_helper.from_array(weight_.astype(np.float32), pre_mn + ".new.weight"))
                graph.initializer.append(numpy_helper.from_array(bias_.astype(np.float32), pre_mn + ".new.bias"))
                graph.node.remove(graph.node[topo_analyser.module_idx[pre_mn]])
                graph.node.insert(topo_analyser.module_idx[pre_mn], new_node)
                graph.node.remove(graph.node[topo_analyser.module_idx[mn]])
            else:
                weight_ = bn_weight / bn_std
                bias_ = bn_bias - bn_weight * bn_mean / bn_std
                name = graph.initializer[bn_weight_idx].name
                graph.initializer.remove(graph.initializer[bn_weight_idx])
                graph.initializer.insert(bn_weight_idx, numpy_helper.from_array(weight_.astype(np.float32), name))
                name = graph.initializer[bn_bias_idx].name
                graph.initializer.remove(graph.initializer[bn_bias_idx])
                graph.initializer.insert(bn_bias_idx, numpy_helper.from_array(bias_.astype(np.float32), name))
                name = graph.initializer[bn_mean_idx].name
                graph.initializer.remove(graph.initializer[bn_mean_idx])
                graph.initializer.insert(bn_mean_idx,
                                   numpy_helper.from_array(np.zeros_like(bn_mean).astype(np.float32), name))
                name = graph.initializer[bn_var_idx].name
                graph.initializer.remove(graph.initializer[bn_var_idx])
                graph.initializer.insert(bn_var_idx,
                                   numpy_helper.from_array(np.ones_like(bn_var).astype(np.float32), name))

def remove_unreferenced_initializer(graph):
    in_graph = set()
    in_initializer = set()
    for node in graph.node:
        in_graph.update(node.input)
        in_graph.update(node.output)
    for init in graph.initializer:
        in_initializer.add(init.name)
    not_in_graph = in_initializer - in_graph
    l = len(graph.initializer)
    for i in range(l - 1, -1, -1):
        if graph.initializer[i].name in not_in_graph:
            graph.initializer.remove(graph.initializer[i])

def update_topology(graph):
    topo_analyser = TopologyAnalyser()
    move_constant_to_initializer(graph)
    topo_analyser.analyse(graph)
    return topo_analyser

def find_node_by_output(output_name, graph):
    flag = False
    idx, node = None, None
    for idx, node in enumerate(graph.node):
        if output_name in node.output:
            flag = True
            break
    if not flag:
        idx, node = None, None
    return idx, node

def scale_node_weight_bias(topo_analyser, graph, node_idx, scale):
    initializer = graph.initializer
    node = graph.node[node_idx]
    if len(node.input) < 2:
        return
    weight_idx = topo_analyser.param_idx[node.input[1]]
    bias_idx = topo_analyser.param_idx[node.input[2]] if len(node.input) >= 3 else None
    weight = np.array(numpy_helper.to_array(initializer[weight_idx]))
    bias = np.array(numpy_helper.to_array(initializer[bias_idx])) if bias_idx is not None else None

    w_scale = scale.reshape([*scale.shape] + [1 for _ in range(len(weight.shape) - 1)]) \
        if len(scale.shape) == 1 else scale
    b_scale = scale

    weight_ = weight * w_scale
    name = initializer[weight_idx].name
    initializer.remove(initializer[weight_idx])
    initializer.insert(weight_idx, numpy_helper.from_array(weight_.astype(np.float32), name))
    if bias is not None:
        bias_ = bias * b_scale
        name = initializer[bias_idx].name
        initializer.remove(initializer[bias_idx])
        initializer.insert(bias_idx, numpy_helper.from_array(bias_.astype(np.float32), name))

def get_onnx_output(model, numpy_tensor):
    ort_session = ort.InferenceSession(model.SerializeToString())
    outputs = ort_session.run(None, {'input': numpy_tensor})
    return outputs

def get_intermediate_output_statistics(model, numpy_tensor, channelwise=False, debug=None):
    graph = model.graph
    output_needed_module = {}
    output_needed_all_input = {}
    for idx, node in enumerate(graph.node):
        output = node.output
        input = node.input
        if 'input' in node.input:
            for out in output:
                output_needed_module[out] = set([idx])
                output_needed_all_input[out] = set(input)
        else:
            s = set()
            s_i = set()
            for in_ in input:
                s |= (output_needed_module[in_] if in_ in output_needed_module.keys() else set())
                s_i |= (output_needed_all_input[in_] if in_ in output_needed_all_input.keys() else set())
            for out in output:
                output_needed_module[out] = s | set([idx])
                output_needed_all_input[out] = s_i | set(input)

    output_statistics = {}
    if not channelwise:
        statistic = {'shape': numpy_tensor.shape,
                     'min': np.min(numpy_tensor),
                     'max': np.max(numpy_tensor) if np.max(numpy_tensor) > 0 else np.abs(np.min(numpy_tensor)),
                     '99.9': np.percentile(numpy_tensor, 99.9)
                     }
    else:
        axis_args = (0, 2, 3) if len(numpy_tensor.shape) == 4 else (0)
        statistic = {'shape': numpy_tensor.shape,
                     'min': np.min(numpy_tensor, axis=axis_args),
                     'max': np.max(numpy_tensor, axis=axis_args),
                     '99.9': np.percentile(numpy_tensor, 99.9, axis=axis_args)
                     }
    output_statistics['input'] = statistic
    print("\nGetting intermediate output statistics...\n")
    for out in tqdm.tqdm(output_needed_module.keys()):
        keep_nodes = [graph.node[i] for i in list(output_needed_module[out])]
        keep_initializer = [init for init in graph.initializer
                            if init.name in list(output_needed_all_input[out])]
        var_out = []
        value_info = onnx.ValueInfoProto()
        value_info.name = out
        var_out.append(value_info)
        new_graph = onnx.helper.make_graph(keep_nodes, graph.name, graph.input,
                                       var_out, keep_initializer)
        tmp_model = onnx.helper.make_model(new_graph)
        tmp_model.ir_version = model.ir_version
        tmp_model.producer_name = model.producer_name
        tmp_model.producer_version = model.producer_version
        tmp_model.domain = model.domain
        tmp_model.model_version = model.model_version
        tmp_model.doc_string = model.doc_string
        if len(tmp_model.metadata_props) > 0:
            values = {p.key: p.value for p in model.metadata_props}
            onnx.helper.set_model_props(tmp_model, values)
        # fix opset import
        for oimp in model.opset_import:
            op_set = tmp_model.opset_import.add()
            op_set.domain = oimp.domain
            op_set.version = oimp.version

        ort_session = ort.InferenceSession(tmp_model.SerializeToString())
        outputs = ort_session.run(None, {'input': numpy_tensor})
        if debug is not None:
            # print(out,outputs[0].reshape(1,-1)[0,10:20])
            debug[out] = outputs[0]
        if not channelwise:
            statistic = {'shape': outputs[0].shape,
                         'min': np.min(outputs[0]),
                         'max': np.max(outputs[0]) if np.max(outputs[0]) > 0 else np.abs(np.min(outputs[0])),
                         '99.9': np.percentile(outputs[0], 99.9) if np.percentile(outputs[0], 99.9) > 0 else np.abs(np.min(outputs[0]))
                         }
        else:
            axis_args = (0, 2, 3) if len(outputs[0].shape) == 4 else (0)
            statistic = {'shape': outputs[0].shape,
                         'min': np.min(outputs[0], axis=axis_args),
                         'max': np.max(outputs[0], axis=axis_args),
                         '99.9': np.percentile(outputs[0], 99.9, axis=axis_args)
                         }
            # print(np.max(statistic['max']),np.max(outputs[0]))
        output_statistics[out] = statistic
    print("Finished getting intermediate output statistics!")
    if debug is not None:
        return output_statistics,debug
    else:
        return output_statistics

def normalize_model(model, output_statistics, topo_analyser, robust_norm=True, channelwise=False, eps=1e-5):
    nodes = model.graph.node
    graph = model.graph
    initializer = model.graph.initializer
    if robust_norm:
        statistic_key = '99.9'
    else:
        statistic_key = 'max'
    node_scaled_range = {}
    seperate_scale = collections.OrderedDict()
    print("\nNormalizing model...\n")
    for node_idx, node in enumerate(tqdm.tqdm(nodes)):
        output = node.output
        input = node.input
        op = node.op_type
        if input[0] == 'input':  # single input model
            l = output_statistics[input[0]]['shape'][1]
            node_scaled_range[input[0]] = np.ones(l) if channelwise else 1.0 \
                                                                         * output_statistics[input[0]][
                                                                             statistic_key]

        if op in ['Conv', 'Gemm']:
            weight_idx = topo_analyser.param_idx[input[1]]
            bias_idx = topo_analyser.param_idx[input[2]] if len(input) == 3 else None
            weight = np.array(numpy_helper.to_array(initializer[weight_idx]))
            bias = np.array(numpy_helper.to_array(initializer[bias_idx])) if bias_idx is not None else None

            l = output_statistics[output[0]]['shape'][1]
            input_real_range = node_scaled_range[input[0]]
            input_range = output_statistics[input[0]][statistic_key]
            output_range = output_statistics[output[0]][statistic_key]
            demand = np.ones(l) if channelwise else 1.0
            w_scale = demand / (output_range + eps) * (
                        input_range / (input_real_range + eps)) if not channelwise else \
                (demand / (output_range + eps)).reshape(-1, 1).dot(
                    (input_range / (input_real_range + eps)).reshape(1, -1))
            w_scale = w_scale.reshape([*w_scale.shape, 1, 1]) if len(weight.shape) == 4 else w_scale
            b_scale = 1 / (output_range + eps)
            node_scaled_range[output[0]] = demand

            weight_ = weight * w_scale

            name = initializer[weight_idx].name
            initializer.remove(initializer[weight_idx])
            initializer.insert(weight_idx, numpy_helper.from_array(weight_.astype(np.float32), name))
            if bias is not None:
                bias_ = bias * b_scale
                name = initializer[bias_idx].name
                initializer.remove(initializer[bias_idx])
                initializer.insert(bias_idx, numpy_helper.from_array(bias_.astype(np.float32), name))

        elif op == 'BatchNormalization':  # var=1 mean=0
            weight_idx = topo_analyser.param_idx[input[1]]
            bias_idx = topo_analyser.param_idx[input[2]]
            weight = np.array(numpy_helper.to_array(initializer[weight_idx]))
            bias = np.array(numpy_helper.to_array(initializer[bias_idx]))

            # node_scaled_range[output[0]] = node_scaled_range[input[0]] * self.output_statistics[input[0]][statistic_key] / self.output_statistics[output[0]][statistic_key]
            # lamda_last = self.output_statistics[input[0]][statistic_key]
            # lamda = self.output_statistics[output[0]][statistic_key]
            # weight_ = weight * node_scaled_range[output[0]]
            # bias_ = bias / lamda

            print(output_statistics[output[0]])
            input_real_range = node_scaled_range[input[0]]
            input_range = output_statistics[input[0]][statistic_key]
            output_range = output_statistics[output[0]][statistic_key]
            demand = 1.0
            w_scale = demand / (output_range + eps) * (input_range / (input_real_range + eps))
            b_scale = 1 / (output_range + eps)
            node_scaled_range[output[0]] = demand
            weight_ = weight * w_scale
            bias_ = bias * b_scale
            # print(output[0],op,input[0], input_range, output_range, demand, input_real_range, w_scale)

            name = initializer[weight_idx].name
            initializer.remove(initializer[weight_idx])
            initializer.insert(weight_idx, numpy_helper.from_array(weight_.astype(np.float32), name))
            name = initializer[bias_idx].name
            initializer.remove(initializer[bias_idx])
            initializer.insert(bias_idx, numpy_helper.from_array(bias_.astype(np.float32), name))

        elif op == 'Add':
            l = output_statistics[output[0]]['shape'][1]
            demand = np.ones(l) if channelwise else 1.0
            node_scaled_range[output[0]] = demand
            output_range = output_statistics[output[0]][statistic_key]

            # node_scaled_range[output[0]] = 1.0
            # lamda = self.output_statistics[output[0]][statistic_key]
            # lamda_lasts = {}
            for i in input:
                if i in output_statistics.keys():
                    # lamda_lasts[i] = self.output_statistics[i][statistic_key]
                    # scale = lamda_lasts[i] / lamda
                    input_real_range = node_scaled_range[i]
                    input_range = output_statistics[i][statistic_key]
                    scale = demand / (output_range + eps) * (input_range / (input_real_range + eps))

                    # print(output[0], op, i, input_range, output_range, demand, input_real_range, scale)

                    idx, _ = find_node_by_output(i, graph)
                    if idx is not None and nodes[idx].op_type in ['Conv', 'Gemm', 'BatchNormalization']:
                        scale_node_weight_bias(topo_analyser, graph, idx, scale)
                    else:
                        scale = scale.reshape(
                            [1, *scale.shape] + [1 for _ in range(len(output_statistics[i]['shape']) - 2)]) \
                            if len(scale.shape) == 1 else scale
                        initializer.append(numpy_helper.from_array(scale.astype(np.float32), "scale_" + i))
                        if idx not in seperate_scale.keys():
                            # seperate_scale[node_idx] = [(i,"scale_"+i,"scaled_"+i)]
                            seperate_scale[node_idx] = {i: ("scale_" + i, "scaled_" + i)}
                        else:
                            # seperate_scale[node_idx].append((i,"scale_"+i,"scaled_"+i))
                            seperate_scale[node_idx][i] = ("scale_" + i, "scaled_" + i)
            pass
        elif op in ['Gather', 'Unsqueeze', 'Shape', 'Concat']:
            continue
        # elif op == "Concat":
        # raise NotImplementedError("Not supported %s yet!"%(op))
        elif op == "Softmax":
            raise NotImplementedError("Not supported %s yet!" % (op))
        else:  # single input single output module
            # print(op,self.output_statistics[output[0]]['shape'])
            input_range = output_statistics[input[0]][statistic_key]
            output_range = output_statistics[output[0]][statistic_key]
            input_scaled_range = node_scaled_range[input[0]]
            output_scaled_range = input_scaled_range / (input_range + eps) * output_range
            node_scaled_range[output[0]] = output_scaled_range

            # print(output[0], op, input[0], input_range, output_range, output_scaled_range)
            # print(op, node_scaled_range[output[0]],'=',input_scaled_range,'/',input_range,'*',output_range)
        # else:
        #     raise NotImplementedError("Not supported yet! %s"%(op))

    if len(seperate_scale.keys()) != 0:
        print("Making new scale node...")

    for node_idx in reversed(seperate_scale.keys()):
        args = {}
        for attr in nodes[node_idx].attribute:
            args[attr.name] = helper.get_attribute_value(attr)
        input = [str(i) if i not in seperate_scale[node_idx].keys() else seperate_scale[node_idx][i][1] \
                 for i in nodes[node_idx].input]

        output = [str(i) for i in nodes[node_idx].output]

        new_node = onnx.helper.make_node(
            nodes[node_idx].op_type,
            inputs=input,
            outputs=output,
            **args
        )
        nodes.remove(nodes[node_idx])
        nodes.insert(node_idx, new_node)

        for i in seperate_scale[node_idx].keys():
            new_node = onnx.helper.make_node(
                'Mul',
                inputs=[seperate_scale[node_idx][i][0], i],
                outputs=[seperate_scale[node_idx][i][1]]
            )
            nodes.insert(node_idx, new_node)
    print("Finished normalizing model!")
    return model

def _pre_onnx_shape_inference(model:onnx.ModelProto):
    '''
    为了对模型进行shape inference，需要先对onnxmodel运行此函数进行准备

    To perform shape inference for model, need to run this function on onnxmodel to prepare

    This function has referenced code in https://github.com/onnx/onnx/issues/2660#issuecomment-605874784
    '''
    if model.ir_version < 4:
        return

    def add_some_graph_info(graph:onnx.GraphProto):
        inputs = {i.name for i in graph.input}
        vi_dict = {vi.name: vi for vi in graph.value_info}
        for init in graph.initializer:
            if init.name in inputs:
                continue
            vi = vi_dict.get(init.name)
            if vi is None:
                vi = graph.value_info.add()
                vi.name = init.name

            tensor_type = vi.type.tensor_type
            if tensor_type.elem_type == onnx.TensorProto.UNDEFINED:
                tensor_type.elem_type = init.data_type
            if not tensor_type.HasField("shape"):
                tensor_type.shape.dim.extend([])
                for dim in init.dims:
                    tensor_type.shape.dim.add().dim_value = dim

        for node in graph.node:
            for attr in node.attribute:
                if attr.ref_attr_name != "":
                    continue

                if attr.type == onnx.AttributeProto.GRAPH:
                    add_some_graph_info(attr.g)
                if attr.type == onnx.AttributeProto.GRAPHS:
                    for g in attr.graphs:
                        add_some_graph_info(g)
    return add_some_graph_info(model.graph)

class _pt_model(nn.Module):
    def __init__(self, path_or_model, _converter=None):
        super(_pt_model, self).__init__()
        if path_or_model is not None:
            if isinstance(path_or_model, str):
                onnx_model = onnx.load(path_or_model)
            else:
                onnx_model = path_or_model
            self.onnx_model = onnx_model

            self.loaded_weights = load_parameters(self, onnx_model.graph.initializer)
            self.module_list = nn.ModuleList([])
            self.op_tree = {}

            _pre_onnx_shape_inference(onnx_model)
            inferred_model = onnx.shape_inference.infer_shapes(onnx_model)
            self.value_info = inferred_model.graph.value_info
            self.dim_info = {}
            for idx, v in enumerate(self.value_info):
                self.dim_info[v.name] = len(v.type.tensor_type.shape.dim)

            self.graph = defaultdict(list)
            # self.V = set()
            for idx, node in enumerate(onnx_model.graph.node):
                op = node.op_type
                # if op=='MatMul': # TODO temporary
                #     op = 'Gemm'
                (op_idx, inputs, outputs) = getattr(_converter, 'convert_' + op.lower())(node, self)
                for out_seq, output in enumerate(outputs):
                    # self.op_tree[output] = (op_idx, inputs, out_seq)
                    self.op_tree[output] = (op_idx, inputs, out_seq)
                for output in outputs:
                    for input in inputs:
                        self.graph[input].append(output)
                # self.V.update(inputs)
                # self.V.update(outputs)
            # print(self.V)

            self.input_name = [i.name for i in onnx_model.graph.input]
            self.output_name = [i.name for i in onnx_model.graph.output]

            for out in onnx_model.graph.output:
                self.graph[out.name] = []

            def TopologicalSort(G):
                in_degrees = dict((u, 0) for u in G)
                for u in G:
                    for v in G[u]:
                        in_degrees[v] += 1
                Q = [u for u in G if in_degrees[u] == 0]
                res = []
                while Q:
                    u = Q.pop()
                    res.append(u)
                    for v in G[u]:
                        in_degrees[v] -= 1
                        if in_degrees[v] == 0:
                            Q.append(v)
                return res

            self.compute_seq = TopologicalSort(self.graph)
            # print(self.compute_seq)

            # self.tensors = {}

            for k in self.loaded_weights.keys():
                if isinstance(self.loaded_weights[k], torch.FloatTensor):
                    setattr(self, 'P' + k.replace('.', '@'), torch.nn.Parameter(self.loaded_weights[k]))
                else:
                    # print(self.loaded_weights[k])
                    self.register_buffer('P' + k.replace('.', '@'), self.loaded_weights[k])

            self.reserved_tensors_name = list(self.loaded_weights.keys())
            # print(self.reserved_tensors_name)
            # print(self.tensors)

    def refresh_running_tensor(self):
        self.tensors = {}
        for k in set(self.tensors.keys()) | set(self.reserved_tensors_name):
            if k not in self.reserved_tensors_name:
                del self.tensors[k]
            else:
                self.tensors[k] = getattr(self, 'P' + k.replace('.', '@'))

    def forward(self, input):
        self.refresh_running_tensor()
        if not isinstance(input, list) or not isinstance(input, tuple):
            input = [input]
        for i, n in enumerate(self.input_name):
            self.tensors[n] = input[i]
        for name in self.compute_seq:
            if name in self.op_tree.keys():
                op_idx, inputs, out_seq = self.op_tree[name]
                # print(name,op_idx, inputs)
                args = []
                for input in inputs:
                    args.append(self.tensors[input])
                # print(len(args))
                result = self.module_list[op_idx](*args)
                if not isinstance(result, tuple):
                    self.tensors[name] = result
                else:
                    self.tensors[name] = result[out_seq]

        if len(self.output_name) == 1:
            return self.tensors[self.output_name[0]]
        else:
            ret = []
            for output in self.output_name:
                ret.append(self.tensors[output])
            return ret

    def reduce(self):
        import copy
        net = _pt_model(None)
        for k in self.reserved_tensors_name:
            if isinstance(self.loaded_weights[k], torch.FloatTensor):
                setattr(net, 'P' + k.replace('.', '@'),
                        torch.nn.Parameter(getattr(self,'P' + k.replace('.', '@')).data.detach().clone()) )
            else:
                net.register_buffer('P' + k.replace('.', '@'),
                                    getattr(self,'P' + k.replace('.', '@')).data.clone() )

        net.compute_seq = copy.deepcopy(self.compute_seq)

        net.input_name = copy.deepcopy(self.input_name)
        net.output_name = copy.deepcopy(self.output_name)
        net.module_list = copy.deepcopy(self.module_list)
        net.op_tree = copy.deepcopy(self.op_tree)
        net.reserved_tensors_name = copy.deepcopy(self.reserved_tensors_name)
        return net

def load_parameters(model:_pt_model, initializer):
    param_dict = {}
    for init in initializer:
        param_dict[init.name] = torch.from_numpy(numpy_helper.to_array(init).copy())
    return param_dict

class _o2p_converter:
    def __init__(self):
        '''
        * :ref:`API in English <ONNX_Converter.__init__-en>`

        .. _ONNX_Converter.__init__-cn:

        该类主要将onnx模型转换为Pytorch的ANN模型，从而转换为SpikingJelly的SNN模型
        链接中 [＃f1]_ 提供了一个onnx-pytorch转换的主要版本。更复杂的版本可以在这里找到。
        大多数使用过的onnx运算符已在此处定义，但仍然有一些未被覆盖，或没有被完美实现
        用户可以通过添加如下面例子所示的静态方法来定义您的例外情况

        * :ref:`API in English <ONNX_Converter.__init__-cn>`

        .. _ONNX_Converter.__init__-en:

        This class mainly convert an onnx model to Pytorch ANN model, and thus to SpikingJelly SNN model
        The link [#f1]_ has provided a primary version of onnx-pytorch conversion. More complex version can be found here.
        Most used onnx operators has covered here, yet still there are some left, or not being defined perfectly
        User can define your exceptions by adding static method like below

        .. [#f1] https://gist.github.com/qinjian623/6aa777037534c1c1dccbb66f832e93b8
        '''
        pass

    def add_method(self, op_name, func):
        setattr(self, 'convert_'+op_name, staticmethod(func))

    @staticmethod
    def convert_conv(node, model:_pt_model):
        attr_map = {
            "pads": "padding",
            "strides": "stride",
            "kernel_shape": "kernel_size",
            "group": "groups",
            "dilations": "dilation"
        }
        assert len(node.output) == 1
        with_bias = False
        if len(node.input) == 3:
            with_bias = True
            bias = model.loaded_weights[node.input[2]]
            del model.loaded_weights[node.input[2]]
        weight = model.loaded_weights[node.input[1]]
        del model.loaded_weights[node.input[1]]
        in_channels = weight.shape[1]
        out_channels = weight.shape[0]
        kwargs = {}
        for att in node.attribute:
            kwargs[attr_map[att.name]] = list(att.ints) if att.name != 'group' else att.i
        if 'padding' in kwargs:
            assert(kwargs["padding"][0]==kwargs["padding"][2] and kwargs["padding"][1]==kwargs["padding"][3])
            kwargs["padding"] = kwargs["padding"][0],kwargs["padding"][1]
        groups = 1 if 'groups' not in kwargs else kwargs['groups']
        in_channels *= groups
        conv = nn.Conv2d(in_channels, out_channels, **kwargs, bias=with_bias)
        conv.weight.data = weight
        if with_bias:
            conv.bias.data = bias
        model.module_list.append(conv)
        return len(model.module_list)-1, node.input[:1], node.output

    @staticmethod
    def convert_relu(node, model:_pt_model):
        relu = nn.ReLU()
        model.module_list.append(relu)
        return len(model.module_list)-1, node.input, node.output

    @staticmethod
    def convert_prelu(node, model:_pt_model):
        weight = model.loaded_weights[node.input[1]]
        del model.loaded_weights[node.input[1]]
        prelu = nn.PReLU()
        prelu.weight.data = weight
        model.module_list.append(prelu)
        return len(model.module_list) - 1, node.input[:-1], node.output

    @staticmethod
    def convert_shape(node, model:_pt_model):
        shape = Shape()
        model.module_list.append(shape)
        return len(model.module_list) - 1, node.input, node.output

    @staticmethod
    def convert_gather(node, model:_pt_model):
        attr_map = {
            "axis": "dim"
        }
        kwargs = {}
        for att in node.attribute:
            if att.name in attr_map:
                kwargs[attr_map[att.name]] = att.f
        gather = Gather(**kwargs)
        model.module_list.append(gather)
        return len(model.module_list) - 1, node.input, node.output

    @staticmethod
    def convert_unsqueeze(node, model:_pt_model):
        attr_map = {
            "axes": "dim"
        }
        kwargs = {}
        for att in node.attribute:
            if att.name in attr_map:
                kwargs[attr_map[att.name]] = att.f
        unsqueeze = Unsqueeze(**kwargs)
        model.module_list.append(unsqueeze)
        return len(model.module_list) - 1, node.input, node.output

    @staticmethod
    def convert_concat(node, model:_pt_model):
        attr_map = {
            "axis": "dim"
        }
        kwargs = {}
        for att in node.attribute:
            if att.name in attr_map:
                kwargs[attr_map[att.name]] = att.f

        concat = Concat(**kwargs)
        model.module_list.append(concat)
        return len(model.module_list) - 1, node.input, node.output

    @staticmethod
    def convert_reshape(node, model:_pt_model):
        reshape = Reshape()
        model.module_list.append(reshape)
        return len(model.module_list) - 1, node.input, node.output

    # @staticmethod
    # def convert_matmul(node, model:_pt_model):
    #     class MatMul(nn.Module):
    #         def __init__(self):
    #             super().__init__()
    #         def forward(self,input1,input2):
    #             return input1 @ input2
    #     mul = MatMul()
    #     model.module_list.append(mul)
    #     return len(model.module_list)-1, node.input, node.output


    @staticmethod
    def convert_batchnormalization(node, model:_pt_model):
        attr_map = {
            "epsilon": "eps",
            "momentum": "momentum"
        }
        assert len(node.input) == 5
        assert len(node.output) == 1
        weight = model.loaded_weights[node.input[1]]
        bias = model.loaded_weights[node.input[2]]
        running_mean = model.loaded_weights[node.input[3]]
        running_var = model.loaded_weights[node.input[4]]
        del model.loaded_weights[node.input[1]]
        del model.loaded_weights[node.input[2]]
        del model.loaded_weights[node.input[3]]
        del model.loaded_weights[node.input[4]]
        dim = weight.shape[0]
        kwargs = {}
        # _check_attr(node.attribute, rebuild_batchnormalization.bn_attr_map)
        for att in node.attribute:
            if att.name in attr_map:
                kwargs[attr_map[att.name]] = att.f
        bn = None
        if model.dim_info[node.output[0]] == 5:
            bn = nn.BatchNorm3d(num_features=dim)
        elif model.dim_info[node.output[0]] == 4:
            bn = nn.BatchNorm2d(num_features=dim)
        elif model.dim_info[node.output[0]] == 2 or model.dim_info[node.output[0]] == 3:
            bn = nn.BatchNorm1d(num_features=dim)
        bn.weight.data = weight
        bn.bias.data = bias
        bn.running_mean.data = running_mean
        bn.running_var.data = running_var
        model.module_list.append(bn)
        return len(model.module_list)-1, node.input[:1], node.output

    @staticmethod
    def convert_add(node, model:_pt_model):
        add = Add()
        model.module_list.append(add)
        return len(model.module_list)-1, node.input, node.output

    @staticmethod
    def convert_mul(node, model:_pt_model):
        mul = Mul()
        model.module_list.append(mul)
        return len(model.module_list)-1, node.input, node.output

    @staticmethod
    def convert_averagepool(node, model:_pt_model):
        attr_map = {
            "pads": "padding",
            "strides": "stride",
            "kernel_shape": "kernel_size",
        }
        kwargs = {}
        for att in node.attribute:
            kwargs[attr_map[att.name]] = list(att.ints)
        if 'padding' in kwargs:
            assert (kwargs["padding"][0] == kwargs["padding"][2] and kwargs["padding"][1] == kwargs["padding"][3])
            kwargs["padding"] = kwargs["padding"][0], kwargs["padding"][1]
        ap = nn.AvgPool2d(**kwargs)
        model.module_list.append(ap)
        return len(model.module_list)-1, node.input, node.output

    @staticmethod
    def convert_globalaveragepool(node, model:_pt_model):
        gap = nn.AdaptiveAvgPool2d((1, 1))
        model.module_list.append(gap)
        model.module_list.append(gap)
        return len(model.module_list) - 1, node.input, node.output

    @staticmethod
    def convert_maxpool(node, model:_pt_model):
        attr_map = {
            "pads": "padding",
            "strides": "stride",
            "kernel_shape": "kernel_size",
        }
        kwargs = {}
        for att in node.attribute:
            kwargs[attr_map[att.name]] = list(att.ints)
        if 'padding' in kwargs:
            assert (kwargs["padding"][0] == kwargs["padding"][2] and kwargs["padding"][1] == kwargs["padding"][3])
            kwargs["padding"] = kwargs["padding"][0], kwargs["padding"][1]
        ap = nn.MaxPool2d(**kwargs)
        model.module_list.append(ap)
        return len(model.module_list) - 1, node.input, node.output

    @staticmethod
    def convert_flatten(node, model:_pt_model):
        if len(node.attribute) == 0:
            axis = 1
        else:
            axis = node.attribute[0].i
        if axis==1:
            flatten = nn.Flatten()
            model.module_list.append(flatten)
            return len(model.module_list)-1, node.input, node.output
        else:
            raise NotImplementedError("Not Implemented yet!")

    @staticmethod
    def convert_gemm(node, model:_pt_model):
        weight = model.loaded_weights[node.input[1]]
        bias = model.loaded_weights[node.input[2]]
        del model.loaded_weights[node.input[2]]
        del model.loaded_weights[node.input[1]]
        in_features = weight.shape[1]
        out_features = weight.shape[0]
        linear = nn.Linear(in_features=in_features, out_features=out_features)
        linear.weight.data = weight
        linear.bias.data = bias
        model.module_list.append(linear)
        return len(model.module_list)-1, node.input[:1], node.output

    @staticmethod
    def convert_pad(node, model:_pt_model):
        mode = node.attribute[0].s
        pads = list(node.attribute[1].ints)
        value = node.attribute[2].f
        try:
            assert(mode == b'constant')
            assert(sum(pads[:4]) == 0)
        except AssertionError:
            print("Now only support converting to nn.ConstantPad2d")
        pad = nn.ConstantPad2d([*pads[2:4],*pads[3:5]],value)
        model.module_list.append(pad)
        return len(model.module_list)-1, node.input, node.output
