import traceback
import numpy as np
from src.pcgym import make_env

from params import get_running_params, get_env_params
from utils import py2str, str2py, py2func

from sub_agents.Coder import Coder
from sub_agents.Evaluator import Evaluator
from sub_agents.Debugger import Debugger

# Module-level configuration, resolved once at import time.
# `system` and `env_params` are read by ce_by_policy.
running_params = get_running_params()
system = running_params['system']
env, env_params = get_env_params(system)

# %%
def ce_by_policy(t_begin, t_end, policy, message, team_conversation, max_retries, horizon, use_debugger=True):
    """
    Contrastive analysis of future trajectories, according to rule-based policies.
    The policies are generated by successive interactions with coder and evaluator agent.
    i.e.) "What would the trajectory change if I use the on-off controller instead of the current RL policy?"
          "Would you compare the predicted trajectory between our RL policy and bang-bang controller after t=300?"

    The Coder agent writes a contrastive policy from ``message``. The code is saved to
    ``./policies/[<system>] ce_policy.py``, loaded as ``CE_policy`` and rolled out through
    ``env.get_rollouts`` with ``ce_settings``; the Evaluator agent then checks the queried
    window. Any exception raised along the way triggers a refinement of the code
    (optionally guided by the Debugger agent) and another attempt, up to ``max_retries``.

    Side effects:
        - Appends progress entries to ``team_conversation`` in place.
        - Sets ``env_params['noise'] = False`` on the module-level ``env_params`` dict.
        - Passes the generated code and the policy file path to ``str2py``.

    Args:
        t_begin (Union[float, int]): First time step within the simulation interval to be interpreted
        t_end (Union[float, int]): Last time step within the simulation interval to be interpreted
        policy (BaseAlgorithm): Trained RL actor, using stable-baselines3
        message (str): Input message raised by Coordinator agent, about policy behavior or constraints
        team_conversation (list): Conversation history between agents (mutated in place)
        max_retries (int): Maximum number of attempts allowed for obtaining a working contrastive policy
        horizon (int): Length of future horizon to be explored
        use_debugger (bool): Whether to use the debugger for refining the code
    Returns:
        figures (list): Single-element list with the result of ``evaluator.plot_data``,
            or None if no working contrastive policy was obtained
        evaluator.data (dict): Forward rollout data of actual and contrastive scenarios,
            or None if no working contrastive policy was obtained
    """

    # Translate queried timesteps to indices
    begin_index = int(np.round(t_begin / env_params['delta_t']))
    end_index = int(np.round(t_end / env_params['delta_t']))
    len_indices = end_index - begin_index + 1
    horizon += len_indices  # Re-adjusting horizon so that it also covers the queried window

    # Regenerating trajectory data with noise disabled (for reproducibility).
    # NOTE: this writes to the module-level env_params dict, so the change persists after the call.
    env_params['noise'] = False
    env = make_env(env_params)

    evaluator, data = env.get_rollouts({'Actual': policy}, reps=1, get_Q=True)

    # Initializing Coder, Debugger and Evaluator agent
    generator = Coder()
    debugger = Debugger()
    ev = Evaluator()

    # Generate initial policy using Coder
    code = generator.generate(message, policy)
    print("[Coder] Initial contrastive policy generated")
    team_conversation.append({"agent": "Coder",
                              "content": "Initial policy generated",
                              "code_length": len(code)
                              })

    # The target file does not change between attempts; each attempt overwrites it.
    file_path = f'./policies/[{system}] ce_policy.py'

    success = False
    trial = 0
    data_ce = None

    # Iterate until no errors or hallucinations are detected
    while not success and trial < max_retries:
        try:
            str2py(code, file_path=file_path)
            CE_policy = py2func(file_path, 'CE_policy')(env, policy)

            # Obtain rollout data from contrastive policy trajectories
            ce_settings = {
                'CE_mode': 'policy',
                'begin_index': begin_index,
                'end_index': end_index,
                'CE_policy': CE_policy
            }
            _, data_ce = env.get_rollouts({'New policy': policy}, reps=1, get_Q=False,
                                          ce_settings=ce_settings)

            # Restrict every recorded signal to the queried window before evaluation.
            # NOTE(review): the slice excludes end_index while len_indices above treats the
            # window as inclusive -- confirm which convention the Evaluator expects.
            data_interval = {k: v[:, begin_index:end_index, :]
                             for k, v in data_ce['New policy'].items()}

            # Evaluate the policy with Evaluator agent
            ev.evaluate(data_interval, message=message)

            success = True

        # If errors or hallucinations are detected, refine the code.
        except Exception as e:
            trial += 1
            error_message = traceback.format_exc()
            error_type = type(e).__name__
            print(f"[Debugger] Error during rollout (trial {trial}):\n{str(e)}")
            team_conversation.append({"agent": "Debugger",
                                      "content": f"[Trial {trial}] Error during rollout",
                                      "error_message": str(e),
                                      "error_type": error_type
                                      })

            if use_debugger:
                guidance = debugger.debug(code, error_message)
                code = generator.refine_with_guidance(error_message, guidance)  # Use guidance from debugger agent
            else:
                code = generator.refine_with_error(error_message)  # Just use the error message

            team_conversation.append({"agent": "Coder",
                                      "content": f"[Trial {trial}] Refined policy generated.",
                                      "code_length": len(code)
                                      })

    log = "[Coder] Code successfully generated. Rollout complete." if success \
        else "[Coder] Failed after multiple attempts."
    team_conversation.append({"agent": "Coder",
                              "content": log,
                              "status_message": log,
                              "status": 'success' if success else 'failure'
                              })

    print(log)

    # Attach the latest code to the conversation. The wording reflects the actual outcome,
    # so a failed run is no longer recorded as "successfully generated".
    if success:
        code_note = f"[Trial {trial}] Code successfully generated."
    else:
        code_note = f"[Trial {trial}] Code generation failed."
    team_conversation.append(
        {"agent": "Coder",
         "content": code_note,
         "Code": f"{generator.prev_codes[-1]}"
         })

    if not success:
        return None, None

    # Append contrastive results to evaluator object
    evaluator.n_pi += 1
    evaluator.policies['New policy'] = policy
    evaluator.data = data | data_ce

    # Interval to watch the control results: one step before the queried window up to the
    # end of the (re-adjusted) horizon. The lower bound is clamped at 0 so that a query
    # starting at t=0 does not produce a negative index.
    interval = [max(begin_index - 1, 0), begin_index + horizon]
    figures = [evaluator.plot_data(evaluator.data, interval=interval)]

    return figures, evaluator.data
