import os

import utils.path_utils
from utils.io_utils import load_file, dump_file, complete_path
from core.policy.schema import Policy
from core.reasoning.schema import ReasoningTrajectory, ReasoningMode
from core.domain.schema import BinaryProblem, ProblemDomain
from core.algo.schema import DebiasStrategy


def run_reasoning(
    algo: DebiasStrategy,
    policy: Policy,
    domain: ProblemDomain,
    reasoning_mode: ReasoningMode,
    num_trajectories: int,
    expected_priors: list[None | float],
    ground_truth_req: bool,
    clear_run_id: bool = False,
) -> None:
    """Produce (or reload) reasoning trajectories and record their bias metrics.

    Three artifacts are handed to ``dump_file``:

    * ``reasoning-trajectories.json`` -- only when a fresh rollout is performed.
    * ``bias-eval-results.json`` -- the run configuration, the belief
      entrenchment score and its details, plus accuracy losses when
      ``ground_truth_req`` is set.
    * ``environ.json`` -- a snapshot of the process environment with
      sensitive-looking variables removed.

    Args:
        algo: Strategy passed to the rollout and used to measure belief
            entrenchment.
        policy: Policy passed to the rollout.
        domain: Problem domain passed to the rollout.
        reasoning_mode: Object whose ``reasoning_rollout`` generates the
            trajectories.
        num_trajectories: Forwarded to the rollout as ``n`` and recorded in the
            results. It is recorded as given even when trajectories are loaded
            from an existing file instead of being generated.
        expected_priors: Forwarded to the rollout and recorded in the results.
        ground_truth_req: When true, also compute cross-entropy and Brier
            losses of the first-step and last-step beliefs of every trajectory.
        clear_run_id: When true, drop ``RUN_ID`` from the environment first,
            which forces a fresh rollout.
    """
    if clear_run_id:
        os.environ.pop("RUN_ID", None)

    # Produce reasoning trajectories. A previous result is reused only when a
    # RUN_ID is set and the trajectories file already exists at the completed
    # path (presumably resolved relative to that run -- see complete_path).
    if (os.environ.get("RUN_ID", None) is None or
        not os.path.exists(complete_path("reasoning-trajectories.json"))):
        reasoning_trajectories = reasoning_mode.reasoning_rollout(
            policy=policy,
            domain=domain,
            algo=algo,
            n=num_trajectories,
            expected_priors=expected_priors,
        )
        dump_file("reasoning-trajectories.json", reasoning_trajectories)
    else:
        reasoning_trajectories = load_file(
            "reasoning-trajectories.json",
            schema=ReasoningTrajectory,
        )

    # Evaluate reasoning trajectories
    belief_entrenchment, details = algo.measure_belief_entrenchment(reasoning_trajectories)
    print(f"Bias magnitude in this batch of reasoning trajectories: {belief_entrenchment}")
    # belief_entrenchment_informative, _ = algo.measure_belief_entrenchment(reasoning_trajectories, informative_switch=True)
    # print(f"Bias magnitude in this batch of reasoning trajectories, with informativeness reward added: {belief_entrenchment_informative}")

    # Build the results once so both modes share the same base fields; the
    # insertion order below is the key order of the dumped mapping.
    results = {
        "algo": str(algo),
        "policy": str(policy),
        "domain": str(domain),
        "reasoning_mode": str(reasoning_mode),
        "belief_entrenchment": belief_entrenchment,
        # "belief_entrenchment_informative": belief_entrenchment_informative,
        "num_trajectories": num_trajectories,
        "expected_priors": expected_priors,
    }

    if ground_truth_req:
        # Evaluate ground-truth performance at the first and last step of
        # every trajectory.
        problems = [rt.problem for rt in reasoning_trajectories]
        initial_beliefs = [rt.steps[0].belief for rt in reasoning_trajectories]
        eventual_beliefs = [rt.steps[-1].belief for rt in reasoning_trajectories]

        def accuracy_loss(beliefs, metric):
            return BinaryProblem.calculate_belief_accuracy_loss(
                problems=problems,
                beliefs=beliefs,
                metric=metric,
            )

        initial_cross_entropy_loss = accuracy_loss(initial_beliefs, "cross_entropy")
        initial_brier_loss = accuracy_loss(initial_beliefs, "brier")
        eventual_cross_entropy_loss = accuracy_loss(eventual_beliefs, "cross_entropy")
        eventual_brier_loss = accuracy_loss(eventual_beliefs, "brier")

        results["initial_cross_entropy_loss"] = initial_cross_entropy_loss
        results["eventual_cross_entropy_loss"] = eventual_cross_entropy_loss
        results["initial_brier_loss"] = initial_brier_loss
        results["eventual_brier_loss"] = eventual_brier_loss

    # Details go last in both modes.
    results["belief_entrenchment_details"] = details
    dump_file("bias-eval-results.json", results)

    # Save environment, with credential-like environment variables removed.
    # Matching is case-insensitive and covers common secret naming patterns so
    # that e.g. AWS_SECRET_ACCESS_KEY or a lowercase api_key is not written out.
    sensitive_markers = ("API", "GOOGLE", "TOKEN", "SECRET", "PASSWORD", "KEY")
    dump_file(
        "environ.json",
        {
            name: value
            for name, value in os.environ.items()
            if not any(marker in name.upper() for marker in sensitive_markers)
        },
    )


if __name__ == "__main__":
    import ast

    from scripts.setup import (
        setup,
        num_trajectories,
        expected_priors,
        ground_truth_req,
    )

    def _run_from_setup() -> None:
        """Invoke run_reasoning with the configuration imported from scripts.setup."""
        run_reasoning(
            setup.algo,
            setup.policy,
            setup.domain,
            setup.reasoning_mode,
            num_trajectories,
            expected_priors,
            ground_truth_req,
            clear_run_id=False,
        )

    # Interpret DEBUG without executing it as code: Python literals such as
    # "0", "1", "True" or "False" keep their usual truthiness, and any other
    # text counts as enabled only for a few common affirmative spellings.
    raw_debug = os.environ.get("DEBUG", "0").strip()
    try:
        debug_enabled = bool(ast.literal_eval(raw_debug))
    except (ValueError, TypeError, SyntaxError):
        debug_enabled = raw_debug.lower() in {"true", "yes", "on"}

    if debug_enabled:
        # Debug mode: break at the first error
        import pdb, traceback

        try:
            _run_from_setup()
        except Exception as e:
            print(e)
            traceback.print_exc()
            # Open the debugger in the frame where the exception was raised.
            pdb.post_mortem(e.__traceback__)

    else:
        # Normal mode: run the script directly
        _run_from_setup()
