import ast
import re
import pandas as pd
from collections import defaultdict
from typing import Union, List, Optional
from pathlib import Path
import pickle

def parse_str_to_dict(s):
    """
    Convert the string representation of a Python literal (normally a dict) to the object.

    Args:
        s: A string such as ``"{1: 'unique stable model'}"``. A value that is already a
            dict is returned unchanged (e.g. if the column holds objects rather than strings).

    Returns:
        The evaluated literal, or ``{}`` when ``s`` is not a valid literal string
        (malformed text, ``None``, NaN, ...).
    """
    if isinstance(s, dict):
        return s
    try:
        return ast.literal_eval(s)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # These are the errors literal_eval raises for bad or non-string input; a bare
        # ``except`` would also swallow KeyboardInterrupt/SystemExit.
        return {}
    
def find_non_fact_definite_rules_from_proof(proof):
    """
    Return the rule components of a '|'-separated proof string.

    Each component is whitespace-stripped; components that are empty or that start
    with 'fact:' are dropped. Order is preserved.

    Args:
        proof (str): Proof text such as "fact: p(a) | q(X) :- p(X)".

    Returns:
        list: The remaining (non-fact, non-empty) components; [] for a falsy proof.
    """
    if not proof:
        return []

    stripped_parts = (segment.strip() for segment in proof.split('|'))
    return [
        segment
        for segment in stripped_parts
        if segment and not segment.startswith('fact:')
    ]

def extracting_applied_rules_from_branch(branch_derivation, branch_result):
    """
    Collect the non-fact rules applied in a single branch.

    Args:
        branch_derivation (dict): Maps each atom of the branch to its proof string.
        branch_result (str): 'unique stable model' or 'contradiction'; any other value
            yields an empty list.

    Returns:
        list: Rules taken (via find_non_fact_definite_rules_from_proof) from the proof of
        every atom that does not start with '!=', in iteration order. For a
        'contradiction' branch with at least one atom, a constraint
        ':- atom1, atom2, ....' listing *all* atoms (including '!=' ones) is appended last.
    """
    if branch_result not in ('unique stable model', 'contradiction'):
        return []

    collected = []
    for atom, proof in branch_derivation.items():
        # Inequality atoms are never searched for rules.
        if atom.startswith('!='):
            continue
        collected.extend(find_non_fact_definite_rules_from_proof(proof))

    if branch_result == 'contradiction' and branch_derivation:
        # The constraint names every atom of the branch, in iteration order.
        collected.append(':- ' + ', '.join(branch_derivation) + '.')
    return collected

def extracting_applied_rules_from_example(row):
    """
    Extract all distinct applied rules from an example (row in dataframe).

    Args:
        row (pd.Series): A row from the dataframe containing:
            - derivation_chain (str): String representation of derivation chain dict
            - branch_results (str): String representation of branch results dict
            - graph_complexity_stats (str): String representation of graph stats dict

    Returns:
        list: Distinct applied rules over all primary branches, in first-seen order
        (deterministic across runs). Empty if any of the three fields is missing/unparsable.
    """
    derivation_chain = parse_str_to_dict(row['derivation_chain'])
    branch_results = parse_str_to_dict(row['branch_results'])
    graph_stats = parse_str_to_dict(row['graph_complexity_stats'])

    if not derivation_chain or not branch_results or not graph_stats:
        return []

    # graph_complexity_stats keys are either a branch id or a tuple/list whose first
    # element is the primary branch id. dict.fromkeys de-duplicates, keeping order.
    primary_branches = dict.fromkeys(
        key[0] if isinstance(key, (tuple, list)) else key
        for key in graph_stats
    )

    all_applied_rules = []
    for branch_num in primary_branches:
        branch_derivation = derivation_chain.get(branch_num, {})
        branch_result = branch_results.get(branch_num, '')
        if not branch_derivation or not branch_result:
            continue
        all_applied_rules.extend(
            extracting_applied_rules_from_branch(branch_derivation, branch_result)
        )

    # list(set(...)) would give an order that changes between runs (string hash
    # randomisation); dict.fromkeys de-duplicates while keeping first-seen order.
    return list(dict.fromkeys(all_applied_rules))

def process_world_rules(world_rules_file):
    """
    Index the world rules of a text file by line number, body predicates and head predicates.

    Each non-blank line is one rule, in one of three forms:
    constraint (``:- body.``), normal rule (``head :- body.``) or fact (``head.``).
    Predicate names are obtained with ``_extract_predicates``; the trailing '.' is
    removed before extraction.

    Args:
        world_rules_file (str | os.PathLike): Path to a UTF-8 text file, one rule per line.

    Returns:
        tuple: (dict_index_to_wr, dict_body_to_wr, dict_head_to_wr)
            dict_index_to_wr: line_number -> world_rule (stripped str)
            dict_body_to_wr: body_predicates_tuple -> list of line numbers (defaultdict(list))
            dict_head_to_wr: head_predicates_tuple -> list of line numbers (defaultdict(list))
        Line numbers are 1-based physical lines; blank lines are skipped but still counted.
        Constraints have head tuple ``()``; facts have body tuple ``()``.
    """
    dict_index_to_wr = {}
    dict_body_to_wr = defaultdict(list)
    dict_head_to_wr = defaultdict(list)

    # Explicit encoding so parsing does not depend on the platform's default locale.
    with open(world_rules_file, 'r', encoding='utf-8') as f:
        for line_num, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                continue

            dict_index_to_wr[line_num] = line

            if line.startswith(':-'):
                # Constraint: no head, body follows ':-'.
                head_predicates = tuple()
                body_predicates = _extract_predicates(line[2:].rstrip('.').strip())
            elif ':-' in line:
                # Normal rule with head and body.
                head_part, body_part = line.split(':-', 1)
                head_predicates = _extract_predicates(head_part.strip())
                body_predicates = _extract_predicates(body_part.rstrip('.').strip())
            else:
                # Fact: head only.
                head_predicates = _extract_predicates(line.rstrip('.').strip())
                body_predicates = tuple()

            dict_body_to_wr[body_predicates].append(line_num)
            dict_head_to_wr[head_predicates].append(line_num)

    return dict_index_to_wr, dict_body_to_wr, dict_head_to_wr


def _extract_predicates(part):
    """
    Return the sorted predicate names appearing in a rule head or body.

    The text is cut at commas that are not nested inside parentheses. Each piece is
    stripped and reduced to a name by ``_process_predicate_string`` (any piece
    containing '!=' becomes just '!='). Empty names are discarded.

    Args:
        part (str): Head or body text; a leading ':-' is tolerated and removed.

    Returns:
        tuple: Predicate names sorted alphabetically (duplicates kept).
    """
    if not part:
        return tuple()

    if part.startswith(':-'):
        part = part[2:].strip()

    # Cut positions are commas at nesting depth zero.
    segments = []
    depth = 0
    seg_start = 0
    for pos, ch in enumerate(part):
        if ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
        elif ch == ',' and depth == 0:
            segments.append(part[seg_start:pos])
            seg_start = pos + 1
    segments.append(part[seg_start:])

    names = (_process_predicate_string(seg.strip()) for seg in segments)
    return tuple(sorted(name for name in names if name))

def _process_predicate_string(pred_str):
    """
    Process a single predicate string to extract the predicate name.
    Special handling for != which should be returned as '!='.
    """
    if '!=' in pred_str:
        return '!='
    return pred_str.split('(')[0]

def find_world_rules_used_for_example(row, dict_index_to_wr, dict_body_to_wr, dict_head_to_wr):
    """
    Find the world rule that each applied (ground) rule of a row instantiates.

    Every applied rule is reduced to a (head predicates, body predicates) signature with
    ``_extract_predicates``. World rules sharing both signatures are candidates. A single
    candidate is used directly. With several, ``find_matching_world_rule`` picks one by
    variable substitution, trying candidates in ascending line order.

    Args:
        row: DataFrame row containing example data
        dict_index_to_wr: line_number -> world_rule
        dict_body_to_wr: body_tuple -> list of line numbers
        dict_head_to_wr: head_tuple -> list of line numbers

    Returns:
        row with two new columns:
            ground_to_world: dict mapping applied rules to world rules
            set_of_world_used: set of line numbers of world rules used

    Raises:
        ValueError: if no world rule matches an applied rule.
    """
    applied_rules = extracting_applied_rules_from_example(row)
    ground_to_world = {}
    set_of_world_used = set()

    for p_rule in applied_rules:
        rule_text = p_rule.strip()
        # A trailing '.' is removed in every form (as process_world_rules does for
        # world rules) so a last zero-arity atom "q." compares equal to "q".
        if rule_text.startswith(':-'):
            # Constraint: no head, body follows ':-'.
            head_predicates = tuple()
            body_predicates = _extract_predicates(rule_text[2:].rstrip('.').strip())
        elif ':-' in rule_text:
            head_part, body_part = rule_text.split(':-', 1)
            head_predicates = _extract_predicates(head_part.strip())
            body_predicates = _extract_predicates(body_part.rstrip('.').strip())
        else:
            # Fact: head only.
            head_predicates = _extract_predicates(rule_text.rstrip('.').strip())
            body_predicates = tuple()

        matching_body_lines = dict_body_to_wr.get(body_predicates, [])
        matching_head_lines = dict_head_to_wr.get(head_predicates, [])
        # Sorted so that ties always resolve to the earliest line of the world file.
        candidate_lines = sorted(set(matching_body_lines) & set(matching_head_lines))

        if not candidate_lines:
            raise ValueError(
                f"No matching world rule found for applied rule:\n{p_rule}\n"
                f"Head predicates: {head_predicates}\n"
                f"Body predicates: {body_predicates}"
            )

        if len(candidate_lines) == 1:
            line_num = candidate_lines[0]
        else:
            candidates = [(ln, dict_index_to_wr[ln]) for ln in candidate_lines]
            line_num = find_matching_world_rule(p_rule, candidates)
            # Check for None before the message below indexes dict_index_to_wr with it;
            # otherwise a failed match surfaces as KeyError instead of this ValueError.
            if line_num is None:
                raise ValueError(
                    f"No matching world rule found among candidates for:\n{p_rule}\n"
                    f"Candidates were:\n" + "\n".join(dict_index_to_wr[ln] for ln in candidate_lines)
                )
            print(f"Multiple candidate world rules found for applied rule: {p_rule} \n OPTIONS ARE {candidates}. \n we will use {dict_index_to_wr[line_num]}\n")

        ground_to_world[p_rule] = dict_index_to_wr[line_num]
        set_of_world_used.add(line_num)

    row['ground_to_world'] = ground_to_world
    row['set_of_world_used'] = set_of_world_used
    return row

def find_matching_world_rule(p_rule, candidates):
    """
    Find which candidate world rule the ground rule ``p_rule`` instantiates.

    Head and body predicates are compared position by position. One substitution
    (variable -> ground term) is shared across the whole rule, so a variable must bind
    to the same term everywhere it appears.

    Args:
        p_rule (str): Ground (applied) rule, e.g. "p(a) :- q(a), r(a)".
        candidates (list[tuple[int, str]]): (line_number, world_rule) pairs, tried in order.

    Returns:
        int | None: Line number of the first matching candidate, or None if none match.
    """
    p_head_preds, p_body_preds = _split_rule_into_predicates(p_rule)

    for line_num, candidate in candidates:
        c_head_preds, c_body_preds = _split_rule_into_predicates(candidate)

        if len(p_head_preds) != len(c_head_preds) or len(p_body_preds) != len(c_body_preds):
            continue

        # Fresh substitution per candidate; head is checked before body, and all()
        # stops at the first mismatch.
        substitution = {}
        pairs = zip(p_head_preds + p_body_preds, c_head_preds + c_body_preds)
        if all(_predicates_match(p_pred, c_pred, substitution) for p_pred, c_pred in pairs):
            return line_num

    return None


def _split_rule_into_predicates(rule):
    """Split a rule into (head predicates, body predicates), ignoring a trailing '.'.

    World rules end with '.'; without removing it the last predicate (e.g. "r(X).")
    would be parsed with the argument "X)".
    """
    text = rule.strip().rstrip('.').strip()
    if ':-' in text:
        head_part, body_part = text.split(':-', 1)
        return _split_predicates(head_part.strip()), _split_predicates(body_part.strip())
    return _split_predicates(text), []

def _split_predicates(predicate_str):
    """Split a string of predicates into individual predicates while handling nested parentheses"""
    predicates = []
    current = []
    paren_level = 0
    
    for char in predicate_str:
        if char == '(':
            paren_level += 1
        elif char == ')':
            paren_level -= 1
        
        if char == ',' and paren_level == 0:
            predicates.append(''.join(current).strip())
            current = []
        else:
            current.append(char)
    
    if current:
        predicates.append(''.join(current).strip())
    
    return predicates

def _predicates_match(p_pred, c_pred, substitution):
    """
    Check whether ground predicate ``p_pred`` is an instance of ``c_pred``.

    Names and arities must agree. For each argument of ``c_pred``:
    - the anonymous variable '_' matches anything without binding;
    - a variable (first character uppercase or '_', the ASP convention) is bound in
      ``substitution`` on first use and must equal that binding afterwards;
    - anything else is a constant and must equal the ground argument exactly.

    Args:
        p_pred (str): Ground predicate, e.g. "parent(alice, bob)".
        c_pred (str): Predicate from a world rule, e.g. "parent(Person, Child)".
        substitution (dict): Variable -> ground term; updated in place (bindings made
            before a mismatch are left in place).

    Returns:
        bool: True if the predicates match under ``substitution``.
    """
    p_name, p_args = _split_predicate(p_pred)
    c_name, c_args = _split_predicate(c_pred)

    if p_name != c_name or len(p_args) != len(c_args):
        return False

    for p_arg, c_arg in zip(p_args, c_args):
        if c_arg == '_':
            continue
        # str.isupper() would misclassify "Person" (constant) and '"ABC"' (variable),
        # so the first character decides, as in ASP.
        if c_arg[:1].isupper() or c_arg.startswith('_'):
            if substitution.setdefault(c_arg, p_arg) != p_arg:
                return False
        elif c_arg != p_arg:
            return False

    return True

def _split_predicate(predicate):
    """
    Split a predicate string into ``(name, [args])``.

    - Infix inequalities such as "X != Y" become ``('!=', ['X', 'Y'])``, the same
      shape as the prefix form "!=(X, Y)", so their arguments can be matched.
    - A predicate without '(' has no arguments.
    - Arguments end at the last ')', so trailing text such as a rule's final '.' is ignored.
    """
    predicate = predicate.strip()

    if '!=' in predicate and not predicate.startswith('!='):
        lhs, rhs = predicate.split('!=', 1)
        return '!=', [lhs.strip(), rhs.strip()]

    if '(' not in predicate:
        return predicate, []

    name = predicate.split('(')[0]
    close = predicate.rfind(')')
    # Fall back to the end of the string if the closing parenthesis is missing.
    end = close if close > len(name) else len(predicate)
    args_str = predicate[len(name) + 1:end]
    args = [arg.strip() for arg in _split_args(args_str)]
    return name, args

def _split_args(args_str):
    """Split arguments while handling nested structures"""
    args = []
    current = []
    paren_level = 0
    
    for char in args_str:
        if char == '(':
            paren_level += 1
        elif char == ')':
            paren_level -= 1
        
        if char == ',' and paren_level == 0:
            args.append(''.join(current).strip())
            current = []
        else:
            current.append(char)
    
    if current:
        args.append(''.join(current).strip())
    
    return args


def align_world_rules_to_example(
    input_data: Union[str, Path, pd.DataFrame],
    world_file: Union[str, Path],
    column_list: Optional[List[str]] = None,
    output_dir: Optional[Union[str, Path]] = None
) -> pd.DataFrame:
    """
    Aligns world rules to examples in a DataFrame and saves processing artifacts to a directory.

    Parameters:
    -----------
    input_data : Union[str, Path, pd.DataFrame]
        A pandas DataFrame (copied, not modified), or a path to a file. Paths ending in
        '.pkl' are unpickled; any other path is read with ``pd.read_csv``.
        Only load pickle files you trust: unpickling can execute arbitrary code.
    world_file : Union[str, Path]
        Path to file containing world rules (one rule per line)
    column_list : Optional[List[str]]
        Columns to retain from input. If None (or empty), keeps all columns.
    output_dir : Optional[Union[str, Path]]
        Directory to save artifacts. Creates if nonexistent. If None, no disk saving.
        Saved files:
        - world_rule_index.pkl: dict_ind_to_wr (line_num → rule)
        - world_rule_body_index.pkl: dict_body_2_wr (body_tuple → line_nums)
        - world_rule_head_index.pkl: dict_hd_2_wr (head_tuple → line_nums)
        - derivations_to_query.pkl: Annotated DataFrame

    Returns:
    --------
    pd.DataFrame
        Annotated DataFrame with:
        - ground_to_world: {applied_rule: world_rule}
        - set_of_world_used: {line_nums}
        Both columns are present even when the input has no rows.

    Example:
    -------
    >>> df = align_world_rules_to_example(
            "data/examples.pkl",
            "rules/world_rules.txt",
            output_dir="output/alignment"
        )
    """
    import os  # local import: only needed to recognise path-like inputs

    # Load input data
    if isinstance(input_data, (str, os.PathLike)):
        input_path = os.fspath(input_data)
        if input_path.endswith('.pkl'):
            # SECURITY: pickle.load executes code embedded in the file; trusted inputs only.
            with open(input_path, 'rb') as f:
                df = pickle.load(f)
        else:
            df = pd.read_csv(input_path)
    else:
        df = input_data.copy()

    # Filter columns if specified
    if column_list:
        df = df[column_list]

    # Process world rules
    dict_ind_to_wr, dict_body_2_wr, dict_hd_2_wr = process_world_rules(world_file)

    # Annotate DataFrame
    df = df.apply(
        lambda row: find_world_rules_used_for_example(
            row, dict_ind_to_wr, dict_body_2_wr, dict_hd_2_wr
        ),
        axis=1
    )
    # Row-wise apply may not add the new columns when there are no rows; add them
    # explicitly so callers can rely on the output schema.
    for column in ('ground_to_world', 'set_of_world_used'):
        if column not in df.columns:
            df[column] = pd.Series(dtype=object, index=df.index)

    # Save artifacts if output_dir specified
    if output_dir:
        out_path = Path(output_dir)
        out_path.mkdir(parents=True, exist_ok=True)

        artifacts = (
            ('world_rule_index.pkl', dict_ind_to_wr),
            ('world_rule_body_index.pkl', dict_body_2_wr),
            ('world_rule_head_index.pkl', dict_hd_2_wr),
        )
        for filename, obj in artifacts:
            with open(out_path / filename, 'wb') as f:
                pickle.dump(obj, f)

        df.to_pickle(out_path / 'derivations_to_query.pkl')

    return df

# Usage
if __name__ == "__main__":
    # Print DataFrames in full: no truncation of rows, columns, width or cell contents.
    for _display_option in (
        'display.max_rows',
        'display.max_columns',
        'display.width',
        'display.max_colwidth',
    ):
        pd.set_option(_display_option, None)

    # Example workflow (disabled). `world_rule_file` must be set to the path of the
    # world-rules text file before enabling it.
    # desired_OPEC_level = 5
    # # The pickle files must be created beforehand.
    # pickle_file_path = f'/dataset_{desired_OPEC_level}_off_path_nodes.pkl'
    # df = pd.read_pickle(pickle_file_path)
    # df = df[df['correct_implied_alternatives'].apply(len) == 0]
    # df['OPEC'] = df[['max_non_path_contradictions', 'max_non_path_atoms']].max(axis=1)
    # print(df['OPEC'].unique())
    # column_list = ['story_edges', 'edge_types', 'query_edge', 'query_relation', 'story_index',
    #                'derivation_chain', 'branch_results', 'unique_rules', 'unique_facts',
    #                'branch_cluttr_hops', 'total_cluttr_hops', 'max_cluttr_hops',
    #                'graph_complexity_stats', 'OPEC', 'max_rules_to_nodes_ratio',
    #                'max_rule_chain_len', 'max_non_path_contradictions', 'max_non_path_atoms',
    #                'source_file', 'original_row_index']
    #
    # for chain_len in sorted(df['max_rule_chain_len'].unique()):
    #     p_df = df[df['max_rule_chain_len'] == chain_len]
    #     print(f'num examples with chain len {chain_len} and OPEC {desired_OPEC_level} is {p_df.shape[0]}')
    #     out_dir = f'OPEC{desired_OPEC_level}_examples/chain_len{chain_len}'
    #     df_train = align_world_rules_to_example(p_df, world_rule_file, column_list, output_dir=out_dir)