import os
import utils.path_utils
from utils.io_utils import get_list_of_runs, load_file
from utils.setup_utils import Setup
import pandas as pd
from statsmodels.api import OLS


# Column order of the dataframe produced by build_dataframe.
_COLUMNS = (
    "identifier",
    "prior",
    "delta",
    "gt",
    "domain",
    "reasoning_mode",
    "model",
    "prompt",
    "algo",
)


def _policy_factors(policy: str) -> tuple[str, str]:
    """
    Split a policy name into its ``(prompt, model)`` factors.

    A trailing ``-confirmatory`` or ``-critical`` selects that system prompt and is
    removed from the model name; a policy without such a suffix gets prompt ``"none"``.
    ``-`` and ``.`` in the model name are replaced with ``_`` (presumably so the dummy
    columns derived from it are valid formula identifiers).

    Raises:
        ValueError: if the policy uses the legacy ``-direction`` system prompt.
    """
    for suffix, prompt in (("-confirmatory", "confirmatory"), ("-critical", "critical")):
        if policy.endswith(suffix):
            # removesuffix only strips the trailing marker, never an inner occurrence.
            model = policy.removesuffix(suffix)
            break
    else:
        if policy.endswith("-direction"):
            raise ValueError("Legacy system prompt: single-direction")
        prompt, model = "none", policy
    return prompt, model.replace("-", "_").replace(".", "_")


def _setup_rows(setup, result) -> list[dict]:
    """
    Turn one run's ``(prior, delta, gt)`` triples into dataframe rows (dicts keyed by _COLUMNS).

    The run-level factors are computed once and copied into every row.

    Raises:
        TypeError / ValueError: if a triple is malformed or the policy uses a legacy
            prompt; the caller then skips the whole run.
    """
    prompt, model = _policy_factors(setup.policy)
    factors = {
        "identifier": setup.identifier.replace("-", "_"),
        "domain": setup.domain.replace("-", "_"),
        "reasoning_mode": setup.reasoning_mode.replace("-", "_"),
        "model": model,
        "prompt": prompt,
        "algo": setup.algo.replace("-", "_"),
    }

    rows = []
    for prior, delta, gt in result:
        # The prior may be stored as a list; only its first element is used.
        if isinstance(prior, list):
            prior = prior[0]
        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if not isinstance(prior, float):
            raise TypeError(f"prior must be a float, got {prior!r}")
        if not isinstance(delta, float):
            raise TypeError(f"delta must be a float, got {delta!r}")
        if gt not in (0, 1, None):
            raise ValueError(f"gt must be 0, 1 or None, got {gt!r}")
        rows.append({"prior": prior, "delta": delta, "gt": gt, **factors})
    return rows


def build_dataframe(
    setups: list[Setup],
    result_file: str = "prior_delta_pairs_per_traj.json",
    add_dummies: bool = True,
) -> pd.DataFrame:
    """
    Build a per-trajectory dataframe from the results of the given runs.

    Each row is one ``(prior, delta, gt)`` triple read from ``result_file`` for a run,
    together with that run's factors: ``domain``, ``reasoning_mode``, ``model``,
    ``prompt`` and ``algo`` (plus the run ``identifier``). ``delta`` is the dependent
    variable used by the regression; ``prior`` and ``gt`` are carried as covariates.

    Args:
        setups: Runs to include. ``str(setup)`` is exported as the ``RUN_ID``
            environment variable before ``load_file`` is called.
        result_file: Name of the per-run results file passed to ``load_file``.
        add_dummies: If True, one-hot encode the categorical factors. ``domain``,
            ``reasoning_mode``, ``model`` and ``algo`` drop their first level
            (reference category); ``prompt`` keeps all levels (``prompt_none`` etc.).

    Returns:
        The assembled dataframe. A run whose results cannot be loaded or validated is
        reported on stdout and skipped as a whole, so all columns stay aligned.
    """
    df_data: dict[str, list] = {col: [] for col in _COLUMNS}

    for setup in setups:
        # NOTE(review): presumably load_file uses RUN_ID to locate the run directory —
        # confirm in utils.io_utils.
        os.environ["RUN_ID"] = str(setup)
        try:
            setup_rows = _setup_rows(setup, load_file(result_file))
        except Exception as e:
            # Best-effort: a broken or legacy run is reported and skipped. Rows are
            # only committed once the whole run has parsed, so a failure halfway
            # through can no longer leave the columns with unequal lengths.
            print(f"Error processing {setup}: {e!r}")
            continue
        for row in setup_rows:
            for col, values in df_data.items():
                values.append(row[col])

    df = pd.DataFrame(df_data)

    if add_dummies:
        df = pd.get_dummies(
            df,
            columns=["domain", "reasoning_mode", "model", "algo"],
            drop_first=True,
        )
        # Keep every prompt level; the regression drops `prompt_none` itself so
        # "no system prompt" acts as the reference category.
        df = pd.get_dummies(
            df,
            columns=["prompt"],
            drop_first=False,
        )

    return df


def _build_formula(df: pd.DataFrame) -> str:
    """
    Build the patsy formula regressing ``delta`` on the dataframe's columns.

    Every column except ``delta``, ``identifier`` and ``prompt_none`` is a main
    effect, and every such column other than ``prior`` is also interacted with
    ``prior`` (``col * prior``). Excluded columns that are absent are simply ignored.
    """
    excluded = {"delta", "identifier", "prompt_none"}
    main_effects = [col for col in df.columns if col not in excluded]
    # patsy expands `a * b` to `a + b + a:b` and de-duplicates repeated terms, so
    # listing the main effects again is harmless.
    interactions = [f"{col} * prior" for col in main_effects if col != "prior"]
    terms = main_effects + interactions
    # An intercept-only model keeps the formula valid when no predictors remain.
    return "delta ~ " + (" + ".join(terms) if terms else "1")


def run_regression(df: pd.DataFrame):
    """
    Fit an OLS regression of ``delta`` on the dataframe's columns and print it.

    ``prompt_none`` is left out, presumably so it serves as the reference level
    for the prompt dummies. Each remaining predictor also gets an interaction
    with ``prior``.

    NOTE(review): ``gt`` may contain missing values; patsy/statsmodels formula handling
    drops rows with NA by default — confirm that is intended.

    Returns:
        The fitted statsmodels results object (its summary is also printed).
    """
    formula = _build_formula(df)
    print(formula)

    results = OLS.from_formula(formula, df).fit()
    print(results.summary())
    return results
    
    
def causal_attribution(
    output_csv: str = "data/tmp/causal-attribution-data-per-trajectory.csv",
):
    """
    Run the causal-attribution regression over every run returned by ``get_list_of_runs``.

    Builds the per-trajectory dataframe for those runs, saves it to ``output_csv`` and
    fits/prints an OLS regression of ``delta`` on the run factors (domain, reasoning
    mode, model, prompt, algorithm) and their interactions with ``prior``.

    Args:
        output_csv: Where to write the dataframe. Missing parent directories are
            created; a relative path resolves against the current working directory.
    """
    run_ids = get_list_of_runs()
    setups = [Setup.from_str(run_id) for run_id in run_ids]

    df = build_dataframe(setups)

    # to_csv does not create directories, so make sure the target folder exists.
    out_dir = os.path.dirname(output_csv)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    df.to_csv(output_csv, index=False)

    run_regression(df)


if __name__ == "__main__":
    import pdb
    import sys
    import traceback

    try:
        causal_attribution()
    except Exception as exc:
        # Report the failure, then open a post-mortem debugger at the failing frame.
        print(exc)
        failing_tb = sys.exc_info()[2]
        traceback.print_exc()
        pdb.post_mortem(failing_tb)