from assistantsv2 import create_agent
import os
import sys
import numpy as np
import shutil
import json
from argparse import ArgumentParser
from loguru import logger
from tqdm import tqdm
from sample_interactions import sample_session_given_query
from utils import save_json, load_json, parse_json

import random
random.seed(42)  # For reproducibility


# Names of the similarity metrics scored for every answered question; each one is
# passed as the `metric` argument of state_similarity().
METRICS = ["accuracy", "jaccard", "hamming"]


# Template for overall evaluation - asks agent to select best answer from multiple choices.
# Filled with str.format(): {query} is the question text and {choices} the numbered
# list of candidate answers. The doubled braces ({{ }}) are format-escapes that
# render as literal JSON braces in the final prompt.
OVERALL_PROMPT = """\
{query}

Please select the most suitable answer for my current situation from the following options:
(considering my current relevant preferences and state information)

{choices}

Express your choice with a number and output in the following JSON format:
```json
{{
    "answer": int
}}
```
Only keep the JSON format output, do not include any other content.
"""


# Template for utilization evaluation - asks agent to select answer given explicit state information.
# Placeholders: {query}, {state} and {choices}; the doubled braces ({{ }}) render as
# literal JSON braces after str.format().
# NOTE(review): not referenced by any code visible in this file - confirm it is still needed.
UTILIZATION_PROMPT = """\
{query}

Given that my current relevant preferences and state information are as follows:
{state}

Please select the most suitable answer for my current situation from the following options:

{choices}

Express your choice with a number and output in the following JSON format:
```json
{{
    "answer": int
}}
```
Only keep the JSON format output, do not include any other content.
"""


def state_similarity(state1, state2, metric="accuracy"):
    """Score how closely two equally long state sequences agree, position by position.

    Supported metrics (``m`` = number of matching positions, ``n`` = length):
        - "accuracy": 1.0 only when every position matches, otherwise 0.0.
        - "hamming":  fraction of matching positions, ``m / n``.
        - "jaccard":  ``m / (2n - m)``, i.e. intersection over union when each
          state is viewed as a set of (position, value) pairs.

    Two empty states are treated as identical and score 1.0 for every metric
    (previously "hamming" and "jaccard" raised ZeroDivisionError on empty input).

    Args:
        state1: First sequence of state values.
        state2: Second sequence of state values; must have the same length.
        metric: One of "accuracy", "hamming" or "jaccard".

    Returns:
        float: Similarity in the range [0.0, 1.0].

    Raises:
        AssertionError: If the two states differ in length.
        ValueError: If ``metric`` is not a supported metric name.
    """
    assert len(state1) == len(state2), "States must have the same length."

    # Validate the metric up front so an unknown name is reported even for
    # empty states, which return early below.
    if metric not in ("accuracy", "hamming", "jaccard"):
        raise ValueError(f"Invalid metric: {metric}")

    total = len(state1)
    if total == 0:
        # Nothing can differ; also avoids a 0/0 division in the ratio metrics.
        return 1.0

    num_matched = sum(s1 == s2 for s1, s2 in zip(state1, state2))

    if metric == "accuracy":
        return float(num_matched == total)
    if metric == "hamming":
        return num_matched / total
    # jaccard: |intersection| / |union| with |union| = n + n - m
    return num_matched / (total * 2 - num_matched)
    
def _choice_label(index):
    """Return the display label for a 0-based choice index: "A".."Z", then the 1-based number."""
    if 0 <= index < 26:
        return chr(ord("A") + index)
    return str(index + 1)


def _format_question_choices(question, choices, with_choices=True):
    """Render a question (and optionally its lettered choices) as ';'-terminated lines."""
    lines = [f"Question: {question.strip()}"]
    if with_choices:
        lines.extend(
            f"({_choice_label(i)}) {choice.strip()}" for i, choice in enumerate(choices))
    return ";\n".join(lines) + ";"  # semicolon at end of each line


def _parse_choice_number(response, num_choices):
    """Extract the 1-based choice number from the agent's reply.

    Returns ``(choice_number, json_error)``. When the reply cannot be parsed or the
    number is not an integer in ``1..num_choices``, a random valid number is returned
    and ``json_error`` is True.
    """
    try:
        answer = parse_json(response)["answer"]
    except Exception as exc:  # parse_json's failure modes are not visible here, so stay broad
        logger.warning(f"Could not parse agent answer ({exc!r}); falling back to a random choice.")
    else:
        # bool is a subclass of int, and 0 / negative numbers would silently wrap
        # around via negative indexing - neither is a valid choice number.
        if isinstance(answer, int) and not isinstance(answer, bool) and 1 <= answer <= num_choices:
            return answer, False
        logger.warning(f"Agent answer {answer!r} is not a valid choice number; "
                       f"falling back to a random choice.")
    return random.randint(1, num_choices), True


def _find_golden_choice_index(answer_choices, golden_state):
    """Return the 0-based index of the choice whose state equals ``golden_state``.

    Falls back to the last choice (the historical behaviour) with a warning when
    no choice matches.
    """
    for index, choice in enumerate(answer_choices):
        if choice["state"] == golden_state:
            return index
    logger.warning(f"No answer choice matches golden state {golden_state}; "
                   f"using the last choice as ground truth.")
    return len(answer_choices) - 1


def _build_period_feedback(period, period_results, qas, feedback_types):
    """Assemble the environment feedback handed to the agent's evolution step.

    One entry is produced per answered question; which fields it carries depends on
    ``feedback_types``. With "with_exposed_states" the list is wrapped in a dict that
    also carries the states exposed by this period's sessions.
    """
    feedback = []
    for result, qa in zip(period_results, qas):
        if result is None:
            continue

        choice_texts = [choice["answer"] for choice in qa["answer_choices"]]
        feedback_info = {}
        if "vanilla_question_only" in feedback_types:
            # Only provide memory + question text (no candidate choices & assistant's response)
            feedback_info.update({
                "question": _format_question_choices(qa["query"], choice_texts, with_choices=False),
                "retrieved_memories": result["relevant_memories"],
            })
        elif "vanilla" in feedback_types:
            # Memory + question text with candidate choices & assistant's response.
            # result["response"] is the 1-based number the agent answered with, while
            # the labels are 0-based - hence the "- 1".
            feedback_info.update({
                "question": _format_question_choices(qa["query"], choice_texts, with_choices=True),
                "assistant_response": _choice_label(result["response"] - 1),
                "retrieved_memories": result["relevant_memories"],
            })

        if "with_answer" in feedback_types:
            # result["answer"] is already the 0-based index of the golden choice.
            feedback_info["ground_truth"] = _choice_label(result["answer"])

        feedback.append(feedback_info)

    if "with_exposed_states" in feedback_types:
        # States exposed in this period's sessions
        exposed_states = {}
        for session in period["sessions"]:
            exposed_states.update(session["exposed_states"])
        feedback = {
            "question_answer_history": feedback,
            "user_information_updates": exposed_states,
        }
    return feedback


def evaluate_item_overall(item, agent, output_dir, env_config):
    """
    Perform overall evaluation of an agent on a single evaluation item.

    For every period the agent first goes through that period's sessions (or reloads
    the state saved by an earlier run), then answers each multiple-choice question in
    ``item["qas"]``. Each answer is scored against the choice whose state matches the
    period's golden state, and results are written to ``overall_results.json`` after
    every question so an interrupted run can resume. After the questions of a period,
    the agent's policy is evolved when the evolution cadence is "period". Once all
    periods are done, the per-period/per-question scores for every metric in METRICS
    are written to ``overall_metrics.json``.

    Args:
        item (dict): Evaluation item; reads "qas", "periods", "start_time",
            "user_profile" and "state_schema".
        agent (object): Agent under evaluation; must provide load_state, save_state,
            answer_question, _evolve_policy, evolution_config and evolution_history.
        output_dir (str): Directory for results, metrics and saved agent states.
        env_config (dict): Environment settings; reads "llm_config_low_temp",
            "num_rounds_init" and "num_rounds_update".

    Raises:
        ValueError: If the agent's evolution cadence is neither "period" nor
            "no_evolution".
    """
    qas, periods = item["qas"], item["periods"]
    num_questions, num_periods = len(qas), len(periods)

    # EARLY EXIT: Check if evaluation is already complete
    metric_path = os.path.join(output_dir, "overall_metrics.json")
    if os.path.exists(metric_path):
        logger.info(f"Overall evaluation already complete for {output_dir}. Skipping.")
        return

    results_path = os.path.join(output_dir, "overall_results.json")
    if os.path.exists(results_path):
        results = load_json(results_path)
    else:
        results = [[None for _ in range(num_questions)] for _ in range(num_periods)]

    for pi, period in enumerate(tqdm(periods, desc="Evaluating overall questions", ncols=80)):
        agent_state_dir = os.path.join(output_dir, f"agent_states/period_{pi:02d}")
        if os.path.exists(agent_state_dir):
            # Resume: a saved state means the sessions of this period were already played.
            period_complete = all(results[pi][qi] is not None for qi in range(num_questions))
            if period_complete:
                logger.info(f"Period {pi} already complete. Loading state and continuing.")
            agent.load_state(agent_state_dir)
            if period_complete:
                # NOTE(review): evolution is not re-run for a completed period; the saved
                # state is presumed to already include it. If a run was interrupted between
                # the last result save and the final save_state below, that step is lost.
                continue
        else:
            max_rounds = env_config["num_rounds_init"] if pi == 0 else env_config["num_rounds_update"]
            for session in tqdm(period["sessions"], ncols=80, leave=False, desc="On-policy interactions"):
                query_with_time = f"[Current Time: {session['session_time']}]\n" + session["query"]
                # Called for its effect on the agent; the returned messages are not needed.
                sample_session_given_query(
                    env_config["llm_config_low_temp"], query_with_time, agent, item["start_time"],
                    item["user_profile"], period["period_end"], session["exposed_states"],
                    item["state_schema"], hist=None, max_rounds=max_rounds
                )

            agent.save_state(agent_state_dir)   # save evolution state info

        for qi, qa in enumerate(tqdm(qas, desc="Asking questions", ncols=80, leave=False)):
            if results[pi][qi] is not None:
                continue
            answer_choices = qa["answer_choices"]
            choices_text = "\n".join(
                f"{i + 1}: {choice['answer']}" for i, choice in enumerate(answer_choices))
            query = OVERALL_PROMPT.format(query=qa["query"], choices=choices_text)
            retrieved_memories, (response, usage_statistics) = agent.answer_question(query)

            response_answer, json_error = _parse_choice_number(response, len(answer_choices))
            response_choice = answer_choices[response_answer - 1]

            # retrieve golden answer
            golden_state = [period["state"][info_type] for info_type in qa["required_info"]]
            ci = _find_golden_choice_index(answer_choices, golden_state)
            golden_choice = answer_choices[ci]

            scores = {
                metric: state_similarity(golden_choice["state"], response_choice["state"], metric)
                for metric in METRICS
            }

            results[pi][qi] = {
                "query": qa["query"],
                "answer": ci,  # 0-based index of the golden choice
                "answer_state": golden_choice["state"],
                "answer_choice": golden_choice["answer"],
                "raw_response": response,
                "response": response_answer,  # 1-based number, as shown in the prompt
                "response_state": response_choice["state"],
                "response_choice": response_choice["answer"],
                "json_error": json_error,
                "llm_usage_statistics": usage_statistics,
                "scores": scores,
                "relevant_memories": retrieved_memories  # Add retrieved memories
            }
            # Persist after every question so an interrupted run can resume here.
            save_json(results_path, results)

        # >>>>> BEGIN ====== EVOLUTION POINT ======
        # Now we have fresh Q&A results and can make informed evolution decisions
        cadence = agent.evolution_config["cadence"]
        if cadence == "period":
            logger.info(f"Triggering evolution after period {pi} evaluation")

            feedback = _build_period_feedback(
                period, results[pi], qas, agent.evolution_config["feedback_types"])

            # Trigger evolution with feedback
            changes_made = agent._evolve_policy(feedback)

            # Log evolution step
            agent.evolution_history.append({
                "step_id": len(agent.evolution_history),
                "feedback": feedback,
                "changes": changes_made
            })
        elif cadence == "no_evolution":
            logger.info(f"NO evolution (cadence == 'no_evolution') after period {pi} evaluation")
        else:
            raise ValueError(f"Unexpected cadence value in evolution config: {cadence}")

        agent.save_state(agent_state_dir)   # save evolution state info
        # <<<<< END

    # Re-check before writing so an existing metrics file is never overwritten.
    if os.path.exists(metric_path):
        return
    overall_metrics = {metric: np.zeros((num_periods, num_questions)) for metric in METRICS}
    num_json_errors = 0
    for pi in range(num_periods):
        for qi in range(num_questions):
            result = results[pi][qi]
            if result["json_error"]:
                num_json_errors += 1
            for metric in METRICS:
                overall_metrics[metric][pi, qi] = result["scores"][metric]
    logger.info(f"Answers replaced by a random fallback: {num_json_errors}")
    for metric, scores in overall_metrics.items():
        logger.info(f"Overall metric: {metric} {scores.mean()} {scores.mean(axis=1)}")
        overall_metrics[metric] = scores.tolist()
    save_json(metric_path, overall_metrics)


def evaluate_item(item, agent, output_dir, env_config):
    """
    Evaluate an agent on a single evaluation item.

    This is a thin entry point: it currently runs only the overall evaluation by
    delegating to ``evaluate_item_overall`` with the same arguments.

    Args:
        item (dict): Evaluation item, passed through unchanged.
        agent (object): Agent instance to evaluate.
        output_dir (str): Directory where the evaluation output for this item is written.
        env_config (dict): Environment configuration, passed through unchanged.
    """
    # agent.reset()  # Reset the agent for a fresh new life cycle
    evaluate_item_overall(item, agent, output_dir, env_config)


def compute_metrics(item, output_path):
    """Derive per-period and per-fact metrics from a stored evaluation result file.

    The file at ``output_path`` is loaded, must contain "period_results" and
    "utilization_results", and is rewritten with an added "metrics" entry.

    For every period and every question in ``item["qas"]`` the following are recorded:
        - read_score: whether the state predicted in this period matches the period's
          true state on the question's required info (accuracy metric),
        - write_score: share of required info types whose value predicted at the period
          of their latest update equals the current true value,
        - utility_score: score looked up from the utilization results for the true state,
        - overall_score: score taken from the period's overall results.
    For every period and every info type in the state schema, a fact-level record notes
    whether the predicted value is correct, and when it was last written / read.

    Args:
        item (dict): Original evaluation item with ground truth data
            (reads "state_schema", "periods" and "qas").
        output_path (str): Path to the JSON result file to read and update in place.

    Raises:
        AssertionError: If the utilization results of one question contain the same
            answer state twice.
    """
    output_data = load_json(output_path)

    period_results = output_data["period_results"]  # w/o golden state
    utilization_results = output_data["utilization_results"]  # w/ golden state

    # One lookup table per question: golden state (as a tuple) -> utility score.
    state_to_scores = []
    for per_question_results in utilization_results:
        score_by_state = {}
        for entry in per_question_results:
            state = tuple(entry["answer_state"])
            score = entry["score"]
            assert state not in score_by_state, f"State {state} already exists in state_to_scores."
            score_by_state[state] = score
        state_to_scores.append(score_by_state)

    # Index of the period in which each info type was most recently updated.
    state_write_pos = dict.fromkeys(item["state_schema"].keys(), 0)
    fact_results = []
    period_metrics = [[] for _ in item["periods"]]

    for pos, (period_result, period) in enumerate(zip(period_results, item["periods"])):
        # Everything updated in this period was last written now.
        state_write_pos.update(dict.fromkeys(period["updates"].keys(), pos))

        response_state = period_result["response_state"]
        predicted = response_state["state"]
        truth = period["state"]

        # query-granularity evaluation
        for qi, question in enumerate(item["qas"]):
            required = question["required_info"]

            # read score
            predicted_state = tuple(predicted[info_type] for info_type in required)
            period_state = tuple(truth[info_type] for info_type in required)
            read_score = state_similarity(predicted_state, period_state)

            # write score
            num_written_correctly = sum(
                1
                for info_type in required
                if period_results[state_write_pos[info_type]]["response_state"]["state"][info_type]
                == truth[info_type]
            )
            write_score = num_written_correctly / len(required)

            # utility score
            utility_score = state_to_scores[qi][period_state]

            period_metrics[pos].append({
                "query": question["query"],
                "read_score": read_score,
                "write_score": write_score,
                "utility_score": utility_score,
                "overall_score": period_result["overall_results"][qi]["score"],
                "response_state": response_state["state"],
                "period_state": period["state"],
            })

        # fact-granularity evaluation
        fact_results.extend(
            {
                "info_type": info_type,
                "correct": predicted[info_type] == truth[info_type],
                "write_pos": state_write_pos[info_type],
                "read_pos": pos,
            }
            for info_type in item["state_schema"].keys()
        )

    output_data["metrics"] = {
        "period_metrics": period_metrics,
        "fact_results": fact_results,
    }
    save_json(output_path, output_data)


def setup_logging(item_dir):
    """Reset loguru and log to stdout (INFO) and to ``<item_dir>/logs/evaluate.log`` (TRACE).

    All previously registered sinks are removed first, so each call redirects the
    file log to the given item directory. The log file rotates at 10 MB.

    Args:
        item_dir (str): Output directory of the item being evaluated.
    """
    log_file = os.path.join(item_dir, "logs", "evaluate.log")

    # Drop every existing sink before installing the two for this item.
    logger.remove()
    logger.add(sys.stdout, level="INFO")
    logger.add(log_file, level="TRACE", rotation="10 MB")


if __name__ == "__main__":
    # ---- Command line ----
    cli = ArgumentParser(
        description="Evaluate agents using environment data and generated conversations.")
    cli.add_argument(
        "--env_data", type=str, default="data/environment_data/data-small.json",
        help="Environment data file")
    cli.add_argument(
        "--env_config", type=str, default="configs/environment_config_small.json",
        help="Environment configuration file")
    cli.add_argument(
        "--agent_config", type=str, required=True,
        help="Agent configuration file")
    # TODO: EVOLUTION CONFIG
    cli.add_argument(
        "--output_dir", type=str, default="eval-output/evolution",
        help="Output directory for evaluation results")
    cli.add_argument(
        "--reset", action="store_true",
        help="Overwrite output")
    args = cli.parse_args()

    # ---- Configurations ----
    agent_config = load_json(args.agent_config)
    env_config = load_json(args.env_config)

    # ---- Output directory: one sub-directory per agent name ----
    args.output_dir = os.path.join(args.output_dir, agent_config["name"])
    if args.reset and os.path.exists(args.output_dir):
        # --reset starts from scratch by deleting this agent's previous output.
        shutil.rmtree(args.output_dir)
    os.makedirs(args.output_dir, exist_ok=True)
    # Keep a copy of the agent configuration next to the results.
    shutil.copy(args.agent_config, args.output_dir)

    # ---- Evaluate every item with a freshly created agent ----
    for item in load_json(args.env_data):
        item_dir = os.path.join(args.output_dir, item["id"])

        # An existing metrics file marks the item as finished.
        already_evaluated = os.path.exists(os.path.join(item_dir, "overall_metrics.json"))
        if already_evaluated and not args.reset:
            logger.info(f"Item {item['id']} already evaluated. Skipping.")
            continue

        agent = create_agent(agent_config, output_dir=item_dir)
        setup_logging(item_dir)
        evaluate_item(item, agent, item_dir, env_config)
