import torch

def solve_by_creating_edge(a, b, orig_source_list, orig_target_list, all_edges, source_list, target_list, edge_attrs, processed_ids, dict_orig_to_ide_graph, new_weight):
    """Insert the directed edge (a, b), or raise its weight if it already exists.

    Self-loops (``a == b``) are ignored and nothing is modified.

    For a new edge:
      * ``b`` is appended to ``orig_target_list`` and ``a`` to ``orig_source_list``;
      * ``(a, b)`` is appended to ``all_edges``;
      * the re-indexed positions of ``a`` / ``b`` are appended to
        ``source_list`` / ``target_list``. A node's position is its index in
        ``processed_ids`` when it appears there, otherwise
        ``dict_orig_to_ide_graph[node]``;
      * ``new_weight`` is appended to ``edge_attrs`` when it is not ``None``.

    For an existing edge, the entry of ``edge_attrs`` at the edge's position in
    ``all_edges`` is replaced by ``new_weight`` if ``new_weight`` is larger.

    NOTE(review): the existing-edge update assumes ``edge_attrs`` is aligned
    index-for-index with ``all_edges``. That only holds if every edge was
    inserted with a non-None weight; confirm callers never mix the two.

    All list arguments are mutated in place and also returned.

    Returns:
        tuple: ``(orig_source_list, orig_target_list, all_edges, source_list,
        target_list, edge_attrs, flag_inserted)``, where ``flag_inserted`` is
        True only when a new edge was appended.
    """

    def _node_index(node):
        # Prefer the node's position in processed_ids; fall back to the id mapping.
        if node in processed_ids:
            return processed_ids.index(node)
        return dict_orig_to_ide_graph[node]

    flag_inserted = False
    if a != b:
        if (a, b) not in all_edges:
            orig_target_list.append(b)
            orig_source_list.append(a)
            all_edges.append((a, b))
            flag_inserted = True

            # Downstream consumers require integer node indices in [0, len).
            source_list.append(_node_index(a))
            target_list.append(_node_index(b))
            if new_weight is not None:
                edge_attrs.append(new_weight)
        elif new_weight is not None:
            # Edge already present: keep the maximum weight seen so far.
            pos_edge = all_edges.index((a, b))
            if edge_attrs[pos_edge] < new_weight:
                edge_attrs[pos_edge] = new_weight

    return orig_source_list, orig_target_list, all_edges, source_list, target_list, edge_attrs, flag_inserted



def get_threshold(input_tensor, degree_std=1, unbiased=True, type="mean", mode="local"):
    """Compute max/mean/std statistics of ``input_tensor`` and a filtering threshold.

    Args:
        input_tensor: tensor of scores. In ``"local"`` mode the statistics are
            reduced over dim 1, giving one value per row (presumably a 2-D
            matrix; confirm with callers). In ``"global"`` mode the tensor is
            flattened and reduced to scalars.
        degree_std: number of standard deviations used to offset the threshold.
        unbiased: whether ``torch.std`` uses the unbiased (N - 1) estimator.
            Honored in both modes.
        type: ``"mean"`` gives threshold = mean + degree_std * std;
            ``"max"`` gives threshold = max - degree_std * std;
            any other value gives no threshold (``None``).
        mode: ``"global"`` or ``"local"``.

    Returns:
        tuple: ``(max, mean, std, threshold)``. ``threshold`` is ``None`` when
        ``type`` is neither ``"mean"`` nor ``"max"``.

    Raises:
        ValueError: if ``mode`` is not ``"global"`` or ``"local"``.
    """
    if mode == "global":
        flat = torch.reshape(input_tensor, (-1,))
        max_v = torch.max(flat)
        mean = torch.mean(flat)
        std = torch.std(flat, unbiased=unbiased)
    elif mode == "local":
        max_v = torch.max(input_tensor, dim=1).values
        mean = torch.mean(input_tensor, dim=1)
        std = torch.std(input_tensor, dim=1, unbiased=unbiased)
    else:
        raise ValueError(f"mode must be 'global' or 'local', got {mode!r}")

    if type == "mean":
        # The threshold sits degree_std standard deviations ABOVE the mean
        # (formerly written as ``mean - std * -degree_std``), making the
        # filter stricter than a plain mean cut-off.
        min_value = mean + std * degree_std
    elif type == "max":
        min_value = max_v - std * degree_std
    else:
        min_value = None

    return max_v, mean, std, min_value



def filtering_matrix(doc_att, valid_sents, degree_std=0.5, with_filtering=True, filtering_type="mean", granularity="local"):
    """Crop an attention matrix to its valid sentences and zero out weak weights.

    The matrix is cropped to ``doc_att[:valid_sents, :valid_sents]``. When
    ``with_filtering`` is True, every entry strictly below the threshold from
    ``get_threshold`` is set to 0.

    Args:
        doc_att: attention matrix; indexed as 2-D here.
        valid_sents: number of leading rows/columns to keep.
        degree_std: std multiplier forwarded to ``get_threshold``.
        with_filtering: if False, the cropped matrix is returned unfiltered.
        filtering_type: ``"mean"`` or ``"max"`` (forwarded as ``type``).
        granularity: ``"local"`` uses one threshold per row (statistics over
            dim 1); ``"global"`` uses one threshold for the whole matrix.

    Returns:
        torch.Tensor: without filtering, the cropped slice of ``doc_att``.
        With ``"local"`` filtering, a new tensor with the same dtype and device
        as ``doc_att``. With ``"global"`` filtering, a new float64 tensor
        (the input is cast with ``.double()``).

    Raises:
        ValueError: if filtering is requested and ``filtering_type`` produces
            no threshold.
    """
    cropped_matrix = doc_att[:valid_sents, :valid_sents]

    if not with_filtering:
        return cropped_matrix

    _, _, _, threshold_min = get_threshold(cropped_matrix, degree_std, type=filtering_type, mode=granularity)
    if threshold_min is None:
        raise ValueError(f"filtering_type must be 'mean' or 'max', got {filtering_type!r}")

    if granularity == "local":
        # threshold_min holds one value per row; unsqueeze to (rows, 1) so it
        # broadcasts across columns. This replaces a per-row loop that wrote
        # into a float32 CPU tensor, discarding the input's dtype and device.
        return torch.where(cropped_matrix < threshold_min.unsqueeze(-1), 0., cropped_matrix)

    # Single scalar threshold for the whole matrix (result is float64).
    return torch.where(cropped_matrix < threshold_min, 0., cropped_matrix.double())
