import os
from collections.abc import Sequence

import pandas as pd
from statsmodels.api import OLS

import utils.path_utils
from utils.io_utils import get_list_of_runs, load_file
from utils.setup_utils import Setup


def _split_policy(policy_name: str) -> tuple[str, str]:
    """
    Split a policy name into its model part and its system-prompt variant.

    A trailing "-confirmatory" or "-critical" marks the prompt variant and is
    stripped from the model name; any other name gets the "none" prompt.

    Raises:
        ValueError: If the name ends with "-direction" (legacy single-direction
            system prompt).
    """
    for prompt in ("confirmatory", "critical"):
        suffix = "-" + prompt
        if policy_name.endswith(suffix):
            # removesuffix (unlike str.replace) strips only the trailing marker.
            return policy_name.removesuffix(suffix), prompt
    if policy_name.endswith("-direction"):
        raise ValueError("Legacy system prompt: single-direction")
    return policy_name, "none"


def _unwrap_float(value, label: str) -> float:
    """
    Return *value* as a float, taking the first element if it is a non-empty list.

    Raises:
        AssertionError: If the (unwrapped) value is not a float; the message
            starts with *label*.
    """
    if isinstance(value, list) and value:
        value = value[0]
    assert isinstance(value, float), f"{label} is not a float: {value}"
    return value


def build_dataframe(
    setups: list[Setup],
    result_file: str = "bias-eval-results.json",
    accuracy_field: str = "eventual_brier_loss",
    bias_field: Sequence[str] = ("belief_entrenchment_details", "linear_regression", "coef"),
    add_dummies: bool = False,
) -> pd.DataFrame:
    """
    Build a dataframe with one row per run setup.

    Each row holds the setup's factors (domain, reasoning mode, model, prompt,
    algo) as independent variables, plus the bias (dependent variable) and the
    accuracy read from that run's result file.

    Args:
        setups: Run setups to include, one row each.
        result_file: Name of the per-run results file passed to ``load_file``.
        accuracy_field: Top-level key of the results holding the accuracy
            metric. Runs whose results lack it get ``None`` as accuracy.
        bias_field: Path of nested keys leading to the bias value in the results.
        add_dummies: If True, one-hot encode the factor columns, dropping the
            first level of each.

    Returns:
        A dataframe with columns ``identifier``, ``bias``, ``accuracy``,
        ``domain``, ``reasoning_mode``, ``model``, ``prompt`` and ``algo``
        (the factor columns are replaced by dummy columns if ``add_dummies``).
        Hyphens (and, for ``model``, dots) in the values are replaced by
        underscores.

    Raises:
        ValueError: If a setup uses the legacy "-direction" system prompt.
        KeyError: If a key in ``bias_field`` is missing from the results.
        AssertionError: If the accuracy or bias is not a float.
    """
    df_data: dict[str, list] = {
        "identifier": [],
        "bias": [],
        "accuracy": [],
        "domain": [],
        "reasoning_mode": [],
        "model": [],
        "prompt": [],
        "algo": [],
    }

    for setup in setups:
        # Independent variables taken directly from the setup.
        df_data["identifier"].append(setup.identifier.replace("-", "_"))
        df_data["domain"].append(setup.domain.replace("-", "_"))
        df_data["reasoning_mode"].append(setup.reasoning_mode.replace("-", "_"))
        df_data["algo"].append(setup.algo.replace("-", "_"))

        # The policy name encodes both the model and the system-prompt variant.
        model_name, prompt = _split_policy(setup.policy)
        df_data["prompt"].append(prompt)
        df_data["model"].append(model_name.replace("-", "_").replace(".", "_"))

        # Select this run for load_file. Presumably load_file resolves
        # result_file inside the run named by RUN_ID (TODO confirm in
        # utils.io_utils). The variable stays set after this function returns.
        os.environ["RUN_ID"] = str(setup)
        result = load_file(result_file)

        # Accuracy is optional: runs without the field get None.
        accuracy = (
            _unwrap_float(result[accuracy_field], "Accuracy")
            if accuracy_field in result
            else None
        )

        # Bias is required: follow the nested key path down the results.
        bias = result
        for field_name in bias_field:
            bias = bias[field_name]
        bias = _unwrap_float(bias, "Bias")

        df_data["accuracy"].append(accuracy)
        df_data["bias"].append(bias)

    df = pd.DataFrame(df_data)

    if add_dummies:
        df = pd.get_dummies(
            df,
            columns=["domain", "reasoning_mode", "model", "algo", "prompt"],
            drop_first=True,
        )

    return df


def run_regression(df: pd.DataFrame):
    """
    Fit an OLS regression of ``bias`` on every other explanatory column of *df*.

    The regressors are all columns except ``bias``, ``identifier`` and
    ``accuracy``, which must therefore be present in *df*. The formula is
    printed, followed by the fitted model's summary.

    Args:
        df: Dataframe with a ``bias`` column and the candidate regressors.

    Returns:
        The fitted statsmodels regression results.
    """
    regressors = df.columns.drop(["bias", "identifier", "accuracy"])
    # Fall back to an intercept-only model if no regressors remain; an empty
    # right-hand side would be an invalid formula.
    formula = "bias ~ " + (" + ".join(regressors) or "1")
    print(formula)

    results = OLS.from_formula(formula, df).fit()
    print(results.summary())
    return results


def run_regression_with_accuracy(
    df: pd.DataFrame,
    formula: str = "accuracy ~ bias + domain + reasoning_mode"
):
    """
    Fit an OLS regression of accuracy on bias plus controls, given as a formula.

    The formula is printed, followed by the summary of the fitted model.

    Args:
        df: Dataframe holding every column named in *formula*.
        formula: Model formula string, e.g. ``"accuracy ~ bias + domain"``.

    Returns:
        The fitted statsmodels regression results.
    """
    print(formula)

    results = OLS.from_formula(formula, df).fit()
    print(results.summary())
    return results
    
def causal_attribution():
    """
    Regress accuracy on absolute bias for all runs in the data directory.

    Builds the dataframe for every run returned by ``get_list_of_runs``,
    replaces ``bias`` by its absolute value, prints summary statistics and
    saves the data to ``data/tmp/causal-attribution-data.csv``. It then fits a
    series of ``accuracy ~ bias + controls`` regressions, waiting for Enter
    between consecutive ones so each summary can be read.

    To regress bias on the setup factors instead, call ``run_regression(df)``.
    """
    run_ids = get_list_of_runs()
    setups = [Setup.from_str(run_id) for run_id in run_ids]

    df = build_dataframe(setups)
    # Only the magnitude of the bias matters here, not its direction.
    df["bias"] = df["bias"].abs()
    print(df.describe())

    output_path = "data/tmp/causal-attribution-data.csv"
    # to_csv does not create missing parent directories.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    df.to_csv(output_path, index=False)

    formulas = [
        "accuracy ~ bias",
        "accuracy ~ bias + domain",
        "accuracy ~ bias + reasoning_mode",
        "accuracy ~ bias + domain + reasoning_mode",
        "accuracy ~ bias + domain + reasoning_mode + model",
        "accuracy ~ bias + domain + reasoning_mode + prompt",
    ]
    for index, formula in enumerate(formulas):
        if index:
            input()  # pause so the previous summary can be read
        run_regression_with_accuracy(df, formula)
    


if __name__ == "__main__":
    # Run the analysis; on failure, report the error and open a post-mortem
    # debugger session at the frame that raised.
    import pdb
    import sys
    import traceback

    try:
        causal_attribution()
    except Exception as err:
        print(err)
        failing_tb = sys.exc_info()[2]
        traceback.print_exc()
        pdb.post_mortem(failing_tb)