import os
import numpy as np
import random
import torch.nn as nn
import torch
def print_file(str_, save_file_path=None):
    """Print ``str_`` to stdout and, optionally, append it to a log file.

    Args:
        str_: object to print (rendered with ``print``).
        save_file_path: if not ``None``, the same line is appended to this file.
    """
    print(str_)
    if save_file_path is not None:
        # Context manager guarantees the handle is closed after every call.
        with open(save_file_path, 'a') as f:
            print(str_, file=f)

def negative_sampling_from_node_pairs(edge_index, node_id):
    """Sample one undirected non-edge per positive pair, among the pairs' own nodes.

    Args:
        edge_index: tensor whose transpose yields ``(u, v)`` rows, i.e. shape
            ``(2, E)``; every listed pair (in either direction) counts as an
            existing edge that must not be sampled.
        node_id: tensor of positive pairs, same ``(2, P)`` layout. Only the
            nodes appearing in these pairs are used as sampling candidates.

    Returns:
        LongTensor of shape ``(2, P)`` holding distinct ``(min, max)`` node
        pairs that are neither self-loops nor in ``edge_index``. An empty
        ``(2, 0)`` tensor is returned when ``node_id`` has no pairs.

    Raises:
        ValueError: if fewer than ``P`` candidate non-edges exist among the
            touched nodes (the rejection loop could otherwise never finish).
    """
    existing_edges = set()
    for u, v in edge_index.t().tolist():
        existing_edges.add((min(u, v), max(u, v)))

    pos_edges = node_id.t().tolist()
    num_needed = len(pos_edges)
    if num_needed == 0:
        return torch.empty((2, 0), dtype=torch.long)

    # Same insertion order as before so random.choice sees the same list.
    node_set = set()
    for u, v in pos_edges:
        node_set.add(u)
        node_set.add(v)
    node_list = list(node_set)

    # Count undirected non-self-loop pairs that are still free to sample.
    k = len(node_list)
    blocked = sum(
        1 for u, v in existing_edges
        if u != v and u in node_set and v in node_set
    )
    available = k * (k - 1) // 2 - blocked
    if num_needed > available:
        raise ValueError(
            f"cannot sample {num_needed} negative edges: only {available} "
            f"non-edges exist among the {k} nodes of the positive pairs"
        )

    neg_edges = set()
    while len(neg_edges) < num_needed:
        u = random.choice(node_list)
        v = random.choice(node_list)
        if u == v:
            continue
        edge = (min(u, v), max(u, v))
        if edge not in existing_edges and edge not in neg_edges:
            neg_edges.add(edge)
    return torch.tensor(list(neg_edges)).t()

class Metrictor_PPI:
    """Confusion-matrix counts and derived metrics for predicted vs. true labels.

    In binary mode ``pre_y`` / ``truth_y`` are indexed as ``y[i]`` over
    ``pre_y.shape[0]`` entries; otherwise they are indexed as ``y[i][j]`` over
    ``pre_y.shape == (N, C)``. Values are presumably 0/1 — TODO confirm with
    callers; other values are only partly counted (see ``_tally``).
    """

    def __init__(self, pre_y, truth_y, is_binary=False):
        self.TP = self.FP = self.TN = self.FN = 0

        if not is_binary:
            rows, cols = pre_y.shape
            for r in range(rows):
                for c in range(cols):
                    label = truth_y[r][c]
                    # Multi-label mode: a mismatch is an FP only when the label is 0.
                    self._tally(pre_y[r][c], label, label, 0)
            self.num = rows * cols
            return

        total = pre_y.shape[0]
        for idx in range(total):
            guess = pre_y[idx]
            # Binary mode: a mismatch is an FP only when the prediction is 1.
            self._tally(guess, truth_y[idx], guess, 1)
        self.num = total

    def _tally(self, guess, label, fp_probe, fp_value):
        """Add one (prediction, label) pair to the confusion counts."""
        if guess == label:
            if label == 1:
                self.TP += 1
            else:
                self.TN += 1
        elif label == 1:
            self.FN += 1
        elif fp_probe == fp_value:
            self.FP += 1

    def show_result(self, is_print=False, file=None):
        """Compute Accuracy/Precision/Recall/F1 and optionally log them.

        The results are stored as attributes; a tiny epsilon keeps every
        division finite when a denominator is zero.
        """
        eps = 1e-10
        self.Accuracy = (self.TP + self.TN) / (self.num + eps)
        self.Precision = self.TP / (self.TP + self.FP + eps)
        self.Recall = self.TP / (self.TP + self.FN + eps)
        self.F1 = 2 * self.Precision * self.Recall / (self.Precision + self.Recall + eps)
        if not is_print:
            return
        report = (
            ("Accuracy: {}", self.Accuracy),
            ("Precision: {}", self.Precision),
            ("Recall: {}", self.Recall),
            ("F1-Score: {}", self.F1),
        )
        for template, value in report:
            print_file(template.format(value), file)

class UnionFindSet(object):
    """Disjoint-set (union-find) structure over the integers ``0 .. m-1``.

    ``find`` uses path compression and ``union`` uses union by rank.

    Attributes:
        roots: parent pointer of each element; ``i`` is a root when
            ``roots[i] == i``.
        rank: upper bound on the height of the tree rooted at each element
            (only meaningful for roots).
        count: current number of disjoint sets.
    """

    def __init__(self, m):
        self.roots = list(range(m))
        self.rank = [0] * m
        self.count = m

    def find(self, member):
        """Return the root of ``member``'s set, pointing the whole path at it."""
        path = []
        while member != self.roots[member]:
            path.append(member)
            member = self.roots[member]
        for node in path:
            self.roots[node] = member
        return member

    def union(self, p, q):
        """Merge the sets containing ``p`` and ``q``; no-op if already joined."""
        root_p = self.find(p)
        root_q = self.find(q)
        if root_p == root_q:
            return
        if self.rank[root_p] > self.rank[root_q]:
            self.roots[root_q] = root_p
        elif self.rank[root_p] < self.rank[root_q]:
            self.roots[root_p] = root_q
        else:
            self.roots[root_q] = root_p
            # Joining two equal-rank trees makes the result one level taller.
            self.rank[root_p] += 1
        self.count -= 1


def get_bfs_sub_graph(ppi_list, node_num, node_to_edge_index, sub_graph_size,
                      max_start_degree=5):
    """Collect a connected set of edge indices by breadth-first search.

    A start node with at most ``max_start_degree`` incident edges is drawn
    uniformly at random; nodes are then expanded in BFS order, and every
    not-yet-selected edge of an expanded node is taken. Expansion stops once
    at least ``sub_graph_size`` edges are selected (the last node's edges are
    all kept, so the result may overshoot), or when the start node's
    connected component has been exhausted.

    Args:
        ppi_list: indexable of edges; ``ppi_list[e][0]`` / ``ppi_list[e][1]``
            are the endpoints of edge ``e``.
        node_num: nodes are drawn from ``range(node_num)``.
        node_to_edge_index: ``node_to_edge_index[n]`` lists the indices of the
            edges incident to node ``n``.
        sub_graph_size: minimum number of edges wanted.
        max_start_degree: largest degree allowed for the random start node.

    Returns:
        List of selected edge indices in the order they were taken.

    Raises:
        ValueError: if no node in ``range(node_num)`` is eligible to start.

    Node ids and edge indices are assumed hashable (plain ints) — TODO
    confirm with callers.
    """
    from collections import deque

    if not any(len(node_to_edge_index[n]) <= max_start_degree for n in range(node_num)):
        raise ValueError(
            f"no node has at most {max_start_degree} incident edges; "
            "cannot choose a BFS start node"
        )

    random_node = random.randint(0, node_num - 1)
    while len(node_to_edge_index[random_node]) > max_start_degree:
        random_node = random.randint(0, node_num - 1)

    candidates = deque([random_node])
    seen_nodes = {random_node}  # every node ever queued (queued or expanded)
    selected_edge_index = []
    selected_edge_set = set()

    while len(selected_edge_index) < sub_graph_size and candidates:
        cur_node = candidates.popleft()
        for edge_index in node_to_edge_index[cur_node]:
            if edge_index in selected_edge_set:
                continue
            selected_edge_set.add(edge_index)
            selected_edge_index.append(edge_index)

            u, v = ppi_list[edge_index][0], ppi_list[edge_index][1]
            end_node = v if u == cur_node else u
            if end_node not in seen_nodes:
                seen_nodes.add(end_node)
                candidates.append(end_node)

    return selected_edge_index

def get_dfs_sub_graph(ppi_list, node_num, node_to_edge_index, sub_graph_size,
                      max_start_degree=5):
    """Collect a connected set of edge indices by depth-first search.

    A start node with at most ``max_start_degree`` incident edges is drawn
    uniformly at random. On a node's first visit all of its not-yet-selected
    edges are taken; on later visits the search descends into the first
    unvisited neighbour (in ``node_to_edge_index`` order) or backtracks.
    It stops once at least ``sub_graph_size`` edges are selected (may
    overshoot by the last node's edges), or when the start node's connected
    component has been exhausted.

    Args:
        ppi_list: indexable of edges; ``ppi_list[e][0]`` / ``ppi_list[e][1]``
            are the endpoints of edge ``e``.
        node_num: nodes are drawn from ``range(node_num)``.
        node_to_edge_index: ``node_to_edge_index[n]`` lists the indices of the
            edges incident to node ``n``.
        sub_graph_size: minimum number of edges wanted.
        max_start_degree: largest degree allowed for the random start node.

    Returns:
        List of selected edge indices in the order they were taken.

    Raises:
        ValueError: if no node in ``range(node_num)`` is eligible to start.

    Node ids and edge indices are assumed hashable (plain ints) — TODO
    confirm with callers.
    """
    if not any(len(node_to_edge_index[n]) <= max_start_degree for n in range(node_num)):
        raise ValueError(
            f"no node has at most {max_start_degree} incident edges; "
            "cannot choose a DFS start node"
        )

    random_node = random.randint(0, node_num - 1)
    while len(node_to_edge_index[random_node]) > max_start_degree:
        random_node = random.randint(0, node_num - 1)

    stack = [random_node]
    selected_node = set()
    selected_edge_index = []
    selected_edge_set = set()

    while len(selected_edge_index) < sub_graph_size and stack:
        cur_node = stack[-1]
        if cur_node not in selected_node:
            # First visit: claim the node and every edge incident to it.
            selected_node.add(cur_node)
            for edge_index in node_to_edge_index[cur_node]:
                if edge_index not in selected_edge_set:
                    selected_edge_set.add(edge_index)
                    selected_edge_index.append(edge_index)
            continue

        # Revisit: descend into the first unvisited neighbour, else backtrack.
        for edge_index in node_to_edge_index[cur_node]:
            u, v = ppi_list[edge_index][0], ppi_list[edge_index][1]
            end_node = v if u == cur_node else u
            if end_node not in selected_node:
                stack.append(end_node)
                break
        else:
            stack.pop()

    return selected_edge_index


class GateLoss(nn.Module):
    """Squared error between summed probabilities and summed labels per sample.

    ``prob`` is a flat tensor of ``N * num_classes`` values, regrouped to
    ``(N, num_classes)``; each row's sum is compared with the matching row
    sum of ``label`` (presumably a ``(N, num_classes)`` multi-hot tensor —
    TODO confirm with callers).

    Args:
        reduction: ``'mean'`` or ``'sum'`` reduce over samples; any other value
            returns the per-sample loss of shape ``(N,)``.
        num_classes: group size used to regroup ``prob`` (default 7, the value
            previously hard-coded).
    """

    def __init__(self, reduction='mean', num_classes=7):
        super().__init__()
        self.reduction = reduction
        self.num_classes = num_classes

    def forward(self, prob, label):
        # reshape(-1, C) == view(N // C, C) for contiguous input, and also
        # accepts non-contiguous tensors; a size not divisible by C still errors.
        prob_grouped = prob.reshape(-1, self.num_classes)
        prob_sums = prob_grouped.sum(dim=1, keepdim=False)
        label_sums = label.sum(dim=1, keepdim=False)

        loss = (prob_sums - label_sums) ** 2
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss
        

def print_file_color(format_str, save_file_path=None, color_dict=None, **kwargs):
    """Print a ``str.format`` template with selected fields wrapped in ANSI colors.

    The colored line goes to stdout; if ``save_file_path`` is given, the same
    line *without* color codes is appended to that file.

    Args:
        format_str: template with named ``str.format`` fields, e.g.
            ``"epoch {epoch} acc={acc:.4f}"``. Format specs, conversions
            (``!r``), attribute/index access and escaped braces (``{{ }}``)
            are handled.
        save_file_path: optional path of a log file to append the plain line to.
        color_dict: optional mapping of field name -> color name (a key of
            ``color_map`` below, case-insensitive). Unknown names stay uncolored.
        **kwargs: values for the template's named fields.

    Raises:
        ValueError: if the template references a field not supplied in
            ``kwargs``, or the template itself is malformed.
    """
    import string

    formatter = string.Formatter()

    color_map = {
        'green': '\033[32m',
        'red': '\033[31m',
        'yellow': '\033[33m',
        'blue': '\033[34m',
        'magenta': '\033[35m',
        'cyan': '\033[36m',
        'white': '\033[37m',
        'bold': '\033[1m',
        'underline': '\033[4m',
    }

    if color_dict is None:
        color_dict = {}

    def _base_name(field_name):
        # "a.b" / "a[0]" -> "a": only this part is looked up in kwargs.
        return field_name.split('.', 1)[0].split('[', 1)[0]

    def _collect(template, names):
        # Walk the template (and nested specs like "{x:{width}}") for field names.
        for _, field_name, format_spec, _ in formatter.parse(template):
            if field_name is None:
                continue
            names.add(_base_name(field_name))
            if format_spec:
                _collect(format_spec, names)
        return names

    required_keys = _collect(format_str, set())
    missing_keys = required_keys - set(kwargs.keys())
    if missing_keys:
        raise ValueError(f"Missing format parameters: {missing_keys}")

    # Rebuild the template, wrapping whole replacement fields (spec included)
    # in color codes so "{acc:.4f}" is colored just like "{acc}".
    pieces = []
    for literal_text, field_name, format_spec, conversion in formatter.parse(format_str):
        # parse() unescapes "{{"/"}}"; re-escape so format() sees literals again.
        pieces.append(literal_text.replace('{', '{{').replace('}', '}}'))
        if field_name is None:
            continue
        field = '{' + field_name
        if conversion:
            field += '!' + conversion
        if format_spec:
            field += ':' + format_spec
        field += '}'
        color_name = color_dict.get(field_name, color_dict.get(_base_name(field_name)))
        color_code = color_map.get(color_name.lower(), '') if color_name else ''
        if color_code:
            field = f"{color_code}{field}\033[0m"
        pieces.append(field)
    colored_str = ''.join(pieces)

    try:
        formatted_colored = colored_str.format(**kwargs)
        formatted_plain = format_str.format(**kwargs)
    except KeyError as e:
        missing_key = str(e).strip("'")
        raise ValueError(f"Format string contains unprovided variable: {missing_key}") from e

    print(formatted_colored)

    if save_file_path is not None:
        with open(save_file_path, 'a') as f:
            print(formatted_plain, file=f)