import json
import os
import re

from tqdm import tqdm

from llm import call_llm
from utils import load_json, save_json

# Relative path of the environment configuration file that holds the LLM settings.
llm_config_path = "configs/environment_config_small.json"
# LLM settings passed to every call_llm invocation in this module.
# NOTE: this load runs at import time, so importing the module requires the
# config file to be readable and to contain an 'llm_config_low_temp' entry.
llm_config = load_json(llm_config_path)['llm_config_low_temp']


def check_claims_consistency_batch(exposed_states, memory):
    """
    Check the consistency of all claims against agent memory using a single LLM call.

    Args:
        exposed_states (dict): Dictionary of claims to verify. Each key is rendered
            into a sentence with its underscores replaced by spaces.
        memory (list): List of memory entries with label, value, and timestamp

    Returns:
        dict: One entry per claim key. On success an entry holds 'claim' (the
            rendered statement), 'consistency' (bool) and 'raw_response' (the
            normalized verdict text). If the LLM call or the parsing of its
            response fails, every entry instead holds 'claim', 'consistency'
            set to the string 'error', and 'error' (the exception message).
            An empty ``exposed_states`` yields an empty dict without calling
            the LLM.
    """
    # Nothing to verify: avoid a pointless LLM round trip.
    if not exposed_states:
        return {}

    # Convert memory list to a readable document format (one bullet per entry)
    document_parts = [
        f"- {entry['label']}: {entry['value']} (recorded: {entry['timestamp']})"
        for entry in memory
    ]
    if document_parts:
        document = "\n".join(document_parts)
    else:
        document = "NO DOCUMENT AVAILABLE (Empty memory)"

    # Render every claim exactly once; the same text is used in the prompt,
    # in the success results and in the error results.
    claim_statements = {
        claim_key: f"The user's {claim_key.replace('_', ' ')} is '{claim_value}'"
        for claim_key, claim_value in exposed_states.items()
    }

    # Format all claims as a numbered list (numbering starts at 1)
    claims_text = "\n".join(
        f"{i}. {statement}"
        for i, statement in enumerate(claim_statements.values(), 1)
    )

    # Batch consistency check prompt
    prompt_template = """Below is a summary of information collected from conversations with a user, followed by multiple claims about their current characteristics or situation.

User's Conversational History Summary:
{document}

Claims about user:
{claims}

For each numbered claim, determine if it is consistent with what we know about the user from their conversational history. Answer "yes" if the claim is supported by the conversational evidence, or "no" if it is not supported or contradicted.

Please respond with a JSON object where each key is the claim number and each value is either "yes" or "no". For example:
{{
  "1": "yes",
  "2": "no",
  "3": "yes"
}}

Response:"""

    # Create the full prompt
    full_prompt = prompt_template.format(
        document=document,
        claims=claims_text
    )

    try:
        # Get LLM response
        messages = [{"role": "user", "content": full_prompt}]
        llm_response = call_llm(messages, llm_config=llm_config, json=True)

        # The response may already be parsed; otherwise parse it as JSON text
        if isinstance(llm_response, dict):
            llm_results = llm_response
        else:
            llm_results = json.loads(llm_response.strip())

        if not isinstance(llm_results, dict):
            raise ValueError(
                "Expected a JSON object keyed by claim number, got "
                f"{type(llm_results).__name__}"
            )

        # Build results dictionary
        results = {}
        for i, (claim_key, statement) in enumerate(claim_statements.items(), 1):
            # A claim missing from the response counts as "no".  The verdict is
            # coerced to text first so a JSON boolean (true/false) or null does
            # not raise and turn every claim of the batch into an error.
            verdict = llm_results.get(str(i), "no")
            consistency_raw = str(verdict).lower().strip()
            consistency_result = 'yes' in consistency_raw or consistency_raw == 'true'

            results[claim_key] = {
                'claim': statement,
                'consistency': consistency_result,
                'raw_response': consistency_raw
            }

        return results

    except Exception as e:
        # Deliberate best-effort boundary: if batch processing fails for any
        # reason, report the same error for all claims instead of raising.
        return {
            claim_key: {
                'claim': statement,
                'consistency': 'error',
                'error': str(e)
            }
            for claim_key, statement in claim_statements.items()
        }


def check_claims_consistency_individual(exposed_states, memory):
    """
    Check the consistency of claims against agent memory using one LLM call per claim.

    Args:
        exposed_states (dict): Dictionary of claims to verify. Each key is rendered
            into a sentence with its underscores replaced by spaces.
        memory (list): List of memory entries with label, value, and timestamp

    Returns:
        dict: One entry per claim key. On success an entry holds 'claim' (the
            rendered statement), 'consistency' (bool) and 'raw_response' (the
            unmodified LLM reply). If the call or the parsing for a claim
            fails, that claim's entry instead holds 'claim', 'consistency' set
            to the string 'error', and 'error' (the exception message); other
            claims are unaffected.
    """
    # Convert memory list to a readable document format (one bullet per entry)
    document_parts = [
        f"- {entry['label']}: {entry['value']} (recorded: {entry['timestamp']})"
        for entry in memory
    ]
    if document_parts:
        document = "\n".join(document_parts)
    else:
        document = "NO DOCUMENT AVAILABLE (Empty memory)"

    # Consistency check prompt template (adapted from MiniCheck)
    prompt_template = """Below is a summary of information collected from conversations with a user, followed by a claim about their current characteristics or situation.

User's Conversational History Summary:
{document}

Claim about user: {claim}

Question: Is this claim consistent with what we know about the user from their conversational history? Answer "yes" if the claim is supported by the conversational evidence, or "no" if it is not supported or contradicted.

Answer: """

    # Match "yes"/"no" only as standalone words.  A plain substring test would
    # treat replies such as "No, as of yesterday ..." as a "yes".
    verdict_pattern = re.compile(r"\b(yes|no)\b")

    results = {}

    # Check each claim against the memory document
    for claim_key, claim_value in tqdm(exposed_states.items()):
        # Format the claim as a readable statement
        claim_statement = f"The user's {claim_key.replace('_', ' ')} is '{claim_value}'"

        # Create the full prompt
        full_prompt = prompt_template.format(
            document=document,
            claim=claim_statement
        )

        try:
            # Get LLM response
            messages = [{"role": "user", "content": full_prompt}]
            llm_response = call_llm(messages, llm_config=llm_config)

            # The first standalone yes/no word decides; a reply containing
            # neither is treated as not consistent.
            verdict = verdict_pattern.search(llm_response.strip().lower())
            consistency_result = verdict is not None and verdict.group(1) == 'yes'

            results[claim_key] = {
                'claim': claim_statement,
                'consistency': consistency_result,
                'raw_response': llm_response
            }

        except Exception as e:
            # Deliberate best-effort boundary: record the failure for this
            # claim and keep going with the remaining ones.
            results[claim_key] = {
                'claim': claim_statement,
                'consistency': 'error',
                'error': str(e)
            }

    return results


def check_claims_consistency(exposed_states, memory, use_batch=True):
    """
    Entry point that dispatches to the batch or the per-claim checker.

    Args:
        exposed_states (dict): Dictionary of claims to verify
        memory (list): List of memory entries with label, value, and timestamp
        use_batch (bool): When true, verify all claims with the batch checker;
            otherwise verify them one at a time with the individual checker
            (kept for compatibility)

    Returns:
        dict: Results showing consistency for each claim
    """
    # Only the selected checker is looked up and invoked.
    checker = (
        check_claims_consistency_batch
        if use_batch
        else check_claims_consistency_individual
    )
    return checker(exposed_states, memory)


def print_consistency_results(results):
    """
    Write a human-readable report of consistency results to stdout.

    Args:
        results (dict): Results from check_claims_consistency
    """
    header_rule = "=" * 50
    entry_rule = "-" * 30

    print("Claims Consistency Check Results:")
    print(header_rule)

    for name, entry in results.items():
        print(f"\nClaim: {name}")
        print(f"Statement: {entry['claim']}")
        print(f"Consistency: {entry['consistency']}")

        # An error message takes precedence over a raw LLM response.
        if 'error' not in entry:
            if 'raw_response' in entry:
                print(f"LLM Response: {entry['raw_response']}")
        else:
            print(f"Error: {entry['error']}")

        print(entry_rule)


def compute_factual_consistency(folder, env_data_fn="data/environment_data/data-small.json", n_profiles=20):
    """Process all memory & state factual consistency for a run.

    For each of the first ``n_profiles`` profiles in the environment data, the
    claims exposed in every period are checked against the agent memory saved
    for that period, and the per-period results are written to
    ``<folder>/<profile id>/memory_factual_consistency.json``.  A profile is
    skipped when that file already exists or when its ``overall_metrics.json``
    is missing (i.e. the profile has not finished running).

    Args:
        folder (str): In-context run results folder, e.g.
            "eval-output/in-context/gpt-4.1-mini-in-context-bsz2-local4"
        env_data_fn (str): Path to the environment data JSON file; it is
            indexed as a sequence of profiles, each with 'id' and 'periods'.
        n_profiles (int): Maximum number of leading profiles to process.  If
            the file holds fewer profiles, all of them are processed.
    """
    all_env_data = load_json(env_data_fn)

    # Slice instead of indexing range(n_profiles) so a data file with fewer
    # profiles than requested does not raise IndexError.  max() keeps a
    # negative count meaning "process nothing", as range() did.
    for data_id, env_data in enumerate(all_env_data[:max(n_profiles, 0)]):
        output_dir = os.path.join(folder, env_data['id'])
        results_fn = os.path.join(output_dir, "memory_factual_consistency.json")
        profile_finished = os.path.exists(
            os.path.join(output_dir, "overall_metrics.json"))

        # Only process when the profile has finished running AND factual
        # consistency has not been calculated yet.
        if os.path.exists(results_fn) or not profile_finished:
            print(
                f"Profile not finished OR memory_factual_consistency.json already exists for: {env_data['id']}")
            continue

        print(data_id, env_data['id'])
        memory_factual_consistency = []
        for period_id, period in enumerate(tqdm(env_data['periods'])):
            # Load environmental states; later sessions override earlier ones
            # when they expose the same key.
            exposed_states = {}
            for session in period['sessions']:
                exposed_states.update(session['exposed_states'])

            # Load agent memory info (directories are period_00, period_01, ...)
            state_dir = os.path.join(
                output_dir, "agent_states", f"period_{period_id:02d}")
            memory = load_json(os.path.join(state_dir, "in_context_memory.json"))

            # Run the consistency check
            memory_factual_consistency.append(
                check_claims_consistency(exposed_states, memory, use_batch=True))

        # Save results
        save_json(results_fn, memory_factual_consistency)


if __name__ == "__main__":
    # Command-line entry point: parse the arguments and run the factual
    # consistency computation for the given results folder.
    import argparse

    parser = argparse.ArgumentParser(description="Compute factual consistency for agent evolution data")

    parser.add_argument("--env-data", type=str, default="data/environment_data/data-small.json", help="Path to environment data JSON file")
    parser.add_argument("--folder", type=str, required=True, help="Path to folder containing agent evolution data")
    parser.add_argument("--n-profiles", type=int, default=20, help="Number of profiles to process (default: 20)")

    args = parser.parse_args()

    # Compute factual consistency (Memory -> States).  The environment data is
    # loaded inside compute_factual_consistency, so it is not read here too.
    compute_factual_consistency(args.folder, args.env_data, n_profiles=args.n_profiles)