import numpy as np
from utils.calculate_similarity import get_SimilarityCalculator
from builder.synaptic_visualization import visualize_synaptic_net
import os
import argparse


class SynapticNode:
    """Node of the synaptic network, representing one individual (subject).

    A node stores a prototype feature vector and a map of connections to other
    nodes. Connections are created symmetrically (see ``add_connection``) and
    each one carries a ``(similarity, strength)`` pair.
    """

    def __init__(self, name, feature):
        """
        Synaptic network node.
        :param name: node name (individual identifier)
        :param feature: prototype feature vector
        """
        self.name = name
        self.prototype = feature
        self.connections = {}  # {target_node: (similarity, strength)}
        self.sample_path = {'data': [], 'label': []}
        self.model_path = []
        # Decay step used by renormalization_synapses; reset to 1 by consolidate_memory.
        self.step = 1

    def connection_count(self):
        """Return the number of nodes this node is connected to."""
        return len(self.connections)

    def get_average_strength(self):
        """Return the mean strength over all of this node's connections.

        An isolated node (no connections) has an average strength of 0.0, so
        that no NaN (and no empty-mean RuntimeWarning) leaks into callers.
        """
        if not self.connections:
            return 0.0
        return np.mean([strength for _, strength in self.connections.values()])

    def get_target_importance(self, alpha=0.6, beta=0.4):
        """Score every connected node by importance.

        Importance = alpha * similarity + beta * normalized_strength, where
        normalized_strength is the neighbour's average connection strength
        divided by the sum of those averages over all of this node's neighbours.

        :param alpha: weight of the similarity term
        :param beta: weight of the normalized-strength term
        :return: {neighbour_name: {'importance', 'similarity', 'average_strength'}};
                 an empty dict when this node has no connections
        """
        importance = {}
        # No connections: nothing to score.
        if not self.connections:
            return importance

        neighbours = [
            (node, similarity, node.get_average_strength())
            for node, (similarity, _) in self.connections.items()
        ]
        strength_sum = sum(avg_strength for _, _, avg_strength in neighbours)

        for node, similarity, avg_strength in neighbours:
            # Strengths can decay (and underflow) to zero; avoid a 0/0 -> NaN share.
            share = avg_strength / strength_sum if strength_sum > 0 else 0.0
            importance[node.name] = {'importance': alpha * similarity + beta * share,
                                     'similarity': similarity,
                                     'average_strength': avg_strength}
        return importance

    @staticmethod
    def _normalize_weights(pairs):
        """Return new [name, weight] pairs whose weights sum to 1 (unchanged if the total is 0)."""
        total = sum(weight for _, weight in pairs)
        if total == 0:
            return [[name, weight] for name, weight in pairs]
        return [[name, weight / total] for name, weight in pairs]

    def find_importance_knn(self, args):
        """Find the top-K most important connected nodes.

        :param args: config object providing ``k_factor`` (K) and the weights
                     ``alpha_sim`` / ``alpha_str`` forwarded to get_target_importance
        :return: None when the node has no connections; otherwise a list of
                 ``[neighbour_name, weight]`` sorted by descending weight. At most
                 K entries are returned and their weights are renormalized to sum to 1.
        """
        if not self.connections:
            return None
        k = args.k_factor
        importance = self.get_target_importance(alpha=args.alpha_sim, beta=args.alpha_str)
        ranked = [[name, info['importance']] for name, info in importance.items()]
        ranked = sorted(self._normalize_weights(ranked), key=lambda x: x[1], reverse=True)
        if len(ranked) > k:
            # Keep only the top K and renormalize their weights.
            ranked = self._normalize_weights(ranked[:k])
        return ranked

    def add_connection(self, target_node, similarity):
        """Create a symmetric connection between this node and ``target_node``.

        Self-connections and already-existing connections are ignored, so a new
        connection always starts with strength 1 on both sides.
        """
        if target_node is self or target_node in self.connections:
            return
        self.connections[target_node] = (similarity, 1)
        target_node.connections[self] = (similarity, 1)

    def renormalization_synapses(self, decay_factor=20):
        """Decay every connection strength by ``exp(-step / decay_factor)``.

        Forgetting-curve style decay (Ebbinghaus): ``step`` grows by one per
        call until consolidate_memory resets it, so decay gets stronger the
        longer a node goes unconsolidated. Only this node's side of each
        connection is updated.
        """
        decay = np.exp(-self.step / decay_factor)  # Ebbinghaus forgetting curve
        for node, (sim, strength) in self.connections.items():
            # Reassigning existing keys does not change the dict size, so this is safe.
            self.connections[node] = (sim, strength * decay)
        self.step += 1

    def consolidate_memory(self, consolidation_node_list, consolidation_factor=1.3, up_bound=3):
        """Strengthen this node's connections (memory consolidation).

        Every connection whose strength is not exactly 1 (the initial strength set
        by add_connection) is multiplied by ``consolidation_factor`` and capped at
        ``up_bound``. The mirrored strength on the target node is also written,
        unless the target is itself in ``consolidation_node_list`` (it will update
        its own side when it is consolidated). Resets the decay step to 1.
        """
        for node, (sim, strength) in self.connections.items():
            if strength == 1:
                continue
            update_strength = min(strength * consolidation_factor, up_bound)
            self.connections[node] = (sim, update_strength)
            # If the target is not being consolidated itself, update its side now.
            if node not in consolidation_node_list:
                node.connections[self] = (sim, update_strength)
        self.step = 1

    def find_strength_knn(self, k=3):
        """Return the ``k`` strongest connections as ``(node, (similarity, strength))``
        pairs, ties broken by higher similarity."""
        return sorted(self.connections.items(),
                      key=lambda x: (-x[1][1], -x[1][0]))[:k]

    def __repr__(self):
        return f"<SynapticNode {self.name} | Connections: {self.connection_count()}>"


class SynapticNetwork:
    """A network of SynapticNode objects keyed by name.

    Edges are created between nodes whose prototype similarity exceeds a
    threshold; connection strengths then decay (global_renormalization_synapses)
    and are consolidated when new nodes are added (add_node).
    """

    # Channel count passed to get_SimilarityCalculator for each supported dataset.
    _DATASET_CHANNELS = {'ISRUC': 6, 'FACED': 32, 'BCI2000': 64}

    def __init__(self, args):
        """
        :param args: config object; ``args.dataset`` selects the similarity
                     calculator. The object is also read later for ``decay_factor``
                     and ``consolidation_factor``.
        :raises ValueError: if ``args.dataset`` is not one of the supported datasets
        """
        self.nodes = {}
        self.args = args
        n_channels = self._DATASET_CHANNELS.get(self.args.dataset)
        if n_channels is None:
            # Fail early instead of leaving sim_calculator unset (later AttributeError).
            raise ValueError(
                f"Unsupported dataset {self.args.dataset!r}; "
                f"expected one of {sorted(self._DATASET_CHANNELS)}")
        self.sim_calculator = get_SimilarityCalculator(args=self.args, n_channels=n_channels)

    def net_initialization(self, initial_info, initial_model_path, seed, similarity_threshold=0.6):
        """Create the source-domain nodes and connect every sufficiently similar pair.

        :param initial_info: [[sub_name, sub_prototype, sub_data_path, sub_label_path]]
        :param initial_model_path: prefix of the initial model files (.pkl); the file
                                   names are appended directly, without a separator
        :param seed: seed embedded in the model file names
        :param similarity_threshold: pairs with similarity strictly above this are connected
        :return: None
        """
        # Instantiate one node per source-domain individual.
        for subject in initial_info:
            sub_name, sub_prototype = subject[0], subject[1]
            sub_data_path, sub_label_path = subject[2], subject[3]
            new_node = SynapticNode(sub_name, sub_prototype)
            new_node.sample_path = {'data': sub_data_path, 'label': sub_label_path}
            new_node.model_path.extend(
                initial_model_path + f'{part}_parameter_{seed}.pkl'
                for part in ('feature_extractor', 'feature_encoder', 'sleep_classifier'))
            self.nodes[sub_name] = new_node

        # Compare each unordered pair once; self.nodes does not change inside the loop.
        all_nodes = list(self.nodes.values())
        for idx, node in enumerate(all_nodes):
            for other_node in all_nodes[idx + 1:]:
                sim = self.sim_calculator.domain_weighted_cosine(node.prototype, other_node.prototype)
                if sim > similarity_threshold:
                    node.add_connection(other_node, sim)

    def add_node(self, name, feature, args, similarity_threshold=0.6):
        """Add a node, connect it to similar existing nodes, then consolidate.

        :param name: name of the new node (replaces any node stored under the same key)
        :param feature: prototype feature vector of the new node
        :param args: config passed to find_importance_knn (``k_factor``, ``alpha_sim``,
                     ``alpha_str``)
        :param similarity_threshold: similarity strictly above this creates a connection
        """
        new_node = SynapticNode(name, feature)
        self.nodes[name] = new_node

        for exist_node in self.nodes.values():
            if exist_node.name == name:
                continue
            sim = self.sim_calculator.domain_weighted_cosine(exist_node.prototype, feature)
            if sim > similarity_threshold:
                exist_node.add_connection(new_node, sim)

        # Consolidate the connection strengths of the new node's top-K important neighbours.
        if new_node.connection_count():
            top_k = new_node.find_importance_knn(args)
            consolidation_nodes = [self.nodes[neighbour_name] for neighbour_name, _ in top_k]
            for connected_node in consolidation_nodes:
                connected_node.consolidate_memory(consolidation_nodes,
                                                  consolidation_factor=self.args.consolidation_factor)

    def global_renormalization_synapses(self):
        """Apply forgetting-curve decay to every node's connections (uses ``args.decay_factor``)."""
        for node in self.nodes.values():
            node.renormalization_synapses(decay_factor=self.args.decay_factor)

    def visualize_network(self):
        """Delegate drawing of the network to visualize_synaptic_net."""
        visualize_synaptic_net(self.nodes, self.args)

