import numpy as np
from causalnex.structure.structuremodel import StructureModel
import pandas as pd
import torch

def abs_(x: torch.Tensor) -> torch.Tensor:
    """Return the element-wise absolute value of ``x``.

    Elements satisfying ``x >= 0`` are passed through unchanged (so zero stays
    on the identity branch); every other element is negated.

    The branches are selected with ``torch.where`` rather than by multiplying
    with boolean masks: a masked product evaluates ``inf * 0``, which is NaN,
    so infinite inputs used to come back as NaN instead of ``+inf``.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the same shape and dtype as ``x``.
    """
    return torch.where(x >= 0, x, -x)

def lin(x):
    """Identity map: hand back ``x`` exactly as it was received."""
    return x

def get_solve_matrix(sm: StructureModel, p: int, d_vars: int):
    """Turn the edges of a structure model into a lagged 0/1 adjacency tensor.

    Every node label is expected to look like ``"<index>_lag<k>"`` with an
    integer ``<index>``. The lag is taken from the *source* node of each edge;
    whatever follows ``_lag`` in the target label is ignored.

    Args:
        sm: Object exposing ``edges()``, which yields ``(source, target)``
            label pairs.
        p: Maximum lag; the leading axis of the result has ``p + 1`` entries.
        d_vars: Number of variables (size of the two trailing axes).

    Returns:
        Array of shape ``(p + 1, d_vars, d_vars)`` whose entry
        ``[lag, source, target]`` is 1 for every edge and 0 elsewhere.
    """
    adjacency = np.zeros((p + 1, d_vars, d_vars))
    for edge in sm.edges():
        source_tokens = edge[0].split('_lag')
        target_tokens = edge[1].split('_lag')
        # The lag is parsed first, then the two variable indices.
        lag_index = int(source_tokens[1])
        source_index = int(source_tokens[0])
        target_index = int(target_tokens[0])
        adjacency[lag_index, source_index, target_index] = 1
    return adjacency

def get_gt_matrix(links_infos: dict, p: int) -> np.ndarray:
    """Build the ground-truth lagged 0/1 adjacency tensor from a links dict.

    ``links_infos`` maps each target variable index to a sequence of link
    entries. Every entry starts with a ``(start_node, lag)`` pair; anything
    after that pair (a coefficient, optionally a link function, ...) is
    ignored, so both ``((start, lag), coeff)`` and
    ``((start, lag), coeff, func)`` entries are accepted.

    Args:
        links_infos: Mapping ``end_node -> [((start_node, lag), ...), ...]``.
            The number of keys fixes the number of variables.
        p: Maximum lag; the leading axis of the result has ``p + 1`` entries.

    Returns:
        Array of shape ``(p + 1, n, n)`` with ``n = len(links_infos)``, whose
        entry ``[abs(lag), start_node, end_node]`` is 1 for every listed link.
    """
    n_vars = len(links_infos)
    gt = np.zeros((p + 1, n_vars, n_vars))
    for end_node, node_edges in links_infos.items():
        for link in node_edges:
            # Only the leading (start_node, lag) pair is needed; indexing
            # instead of fixed-length unpacking tolerates extra trailing items.
            start_node, lag = link[0]
            # Lags may be given as negative offsets, hence abs().
            gt[abs(lag), int(start_node), int(end_node)] = 1
    return gt

def _edge_metrics(pred, truth):
    """Return (precision, recall, f1, shd) for two same-shaped edge indicators."""
    n_correct = np.sum(np.logical_and(pred, truth))
    n_pred = np.sum(pred)
    n_truth = np.sum(truth)
    precision = n_correct / n_pred if n_pred > 0 else 0
    recall = n_correct / n_truth if n_truth > 0 else 0
    pr_sum = precision + recall
    f1 = 2 * recall * precision / pr_sum if pr_sum != 0 else 0
    # SHD here is the number of mismatching entries (false positives plus
    # false negatives).
    shd = np.sum(pred != truth)
    return precision, recall, f1, shd


def _weak_prior_counts(w_est, a_est, gt, prior_mask, w_threshold, p, d_vars,
                       tolerance):
    """Count prior edges whose raw weight lies within ``tolerance`` of the threshold.

    Returns ``(correct, total)``: ``total`` counts weak entries that are also
    prior edges, ``correct`` those among them that are set in ``gt``.
    """
    # Stack the raw weights as ((p + 1) * d_vars, d_vars): the first d_vars
    # rows come from w_est, the remaining rows from a_est.
    stacked = np.zeros(((p + 1) * d_vars, d_vars), dtype=float)
    stacked[:d_vars, :] = w_est
    if p > 0:
        stacked[d_vars:, :] = a_est
    magnitude = np.abs(stacked)

    lower_bound = max(w_threshold - tolerance, 0)
    is_weak = (magnitude < w_threshold + tolerance) & (magnitude > lower_bound)
    # One (d_vars, d_vars) slice per block of d_vars rows, matching gt's layout.
    is_weak = is_weak.reshape(p + 1, d_vars, d_vars)

    # prior_mask is broadcast across the leading (lag) axis.
    weak_prior = is_weak & prior_mask
    correct = int(np.sum(weak_prior & (gt != 0)))
    total = int(np.sum(weak_prior))
    return correct, total


def get_result_matrix(solve: np.ndarray,
                      gt: np.ndarray,
                      eloss: float,
                      exist_edges_mask: np.ndarray,
                      w_est: np.ndarray,
                      a_est: np.ndarray,
                      w_threshold: float,
                      p: int,
                      d_vars: int
                     ):
    """Score an estimated lagged graph against the ground truth.

    Each row of the result holds eight values::

        [accuracy, recall, f1, shd, eloss, path_recovery,
         weak_prior_correct_count, weak_prior_count]

    where "accuracy" is ``correct / number of predicted edges``. The last four
    values are identical in every row. Rows are:

    * row 0     -- all lags together, entries compared position by position;
    * row 1     -- lags collapsed with ``any`` along the leading axis;
    * row 2 + i -- slice ``i`` of the leading axis on its own.

    Args:
        solve: Estimated 0/1 adjacency, indexed ``[lag, i, j]``.
        gt: Ground-truth 0/1 adjacency with the same shape as ``solve``.
        eloss: Loss value copied verbatim into column 4.
        exist_edges_mask: ``(d_vars, d_vars)`` prior-edge mask; any non-zero
            entry marks a prior edge.
        w_est: Raw weights filling the first ``d_vars`` rows of the stacked
            ``((p + 1) * d_vars, d_vars)`` weight matrix, or ``None``.
        a_est: Raw weights filling the remaining rows; only required when
            ``p > 0``.
        w_threshold: Threshold around which a weight counts as "weak"
            (within ``1e-4`` of it and strictly above zero).
        p: Maximum lag.
        d_vars: Number of variables.

    Returns:
        Float array of shape ``(2 + solve.shape[0], 8)``.
    """
    WEAK_EDGE_TOLERANCE = 1e-4

    # A prior edge is any non-zero mask entry. The same boolean view is used
    # for every prior-related quantity so that non-binary masks are counted
    # consistently.
    prior_mask = np.asarray(exist_edges_mask).astype(bool)

    if w_est is not None and (p == 0 or a_est is not None):
        weak_prior_correct_count, weak_prior_count = _weak_prior_counts(
            w_est, a_est, gt, prior_mask, w_threshold, p, d_vars,
            WEAK_EDGE_TOLERANCE,
        )
    else:
        weak_prior_correct_count = 0
        weak_prior_count = 0
        print("Warning: Raw weights (w_est/a_est) not fully available for weak edge calculation.")

    # True where at least one lag connects i -> j.
    solve_any = np.any(solve, axis=0)
    gt_any = np.any(gt, axis=0)

    # Fraction of prior edges recovered at any lag. The denominator counts
    # non-zero mask entries (not their sum) so it matches the numerator.
    n_prior = np.count_nonzero(prior_mask)
    path_recovery = np.sum(solve_any & prior_mask) / n_prior if n_prior > 0 else 0

    shared_tail = [eloss, path_recovery, weak_prior_correct_count, weak_prior_count]

    n_lags = solve.shape[0]
    result_matrix = np.zeros((2 + n_lags, 8))
    result_matrix[0] = [*_edge_metrics(solve, gt), *shared_tail]
    result_matrix[1] = [*_edge_metrics(solve_any, gt_any), *shared_tail]
    for lag in range(n_lags):
        result_matrix[lag + 2] = [*_edge_metrics(solve[lag], gt[lag]), *shared_tail]

    return result_matrix



def read_markdown_table(file_path):
    """Load a pipe-delimited markdown results table into a DataFrame.

    The first two lines of the file (header row and separator row) are
    skipped. Every remaining non-blank line is split on ``|`` after removing
    its outer pipes, and each cell is stripped of surrounding whitespace.
    Blank lines (for example a trailing newline at the end of the file) are
    ignored rather than turned into empty rows.

    Args:
        file_path: Path of the markdown file to read.

    Returns:
        DataFrame with a string ``name`` column followed by eight float
        columns: ``accuracy``, ``recall``, ``f1``, ``shd``, ``edge_loss``,
        ``edge_recovery``, ``weak_correct_edge_num`` and ``weak_edge_num``.

    Raises:
        ValueError: If a cell in a numeric column cannot be parsed as a float.
    """
    columns = ['name', 'accuracy', 'recall', 'f1', 'shd', 'edge_loss',
               "edge_recovery", "weak_correct_edge_num", "weak_edge_num"]
    numeric_columns = columns[1:]

    with open(file_path, 'r') as f:
        body_lines = f.readlines()[2:]

    rows = [
        [cell.strip() for cell in line.strip().strip('|').split('|')]
        for line in body_lines
        if line.strip()
    ]
    df = pd.DataFrame(rows, columns=columns)
    df[numeric_columns] = df[numeric_columns].astype(float)
    return df
