"""
Validation Interface for Schema-ICL Framework

This module provides a validation interface that allows you to ensure a specific problem
has been run the target number of iterations with a given configuration. If the problem
has been run fewer times than the target, it will execute additional iterations and
append the results to the CSV file.

The interface automatically:
1. Reads the results CSV to count existing iterations
2. Generates appropriate configuration based on the parameters
3. Runs missing iterations using the Schema-ICL pipeline
4. Appends new results to the CSV file with the same structure
"""


import base64
import hashlib

import os
import sys
import json
import pandas as pd
from typing import Dict, Any
import importlib

# Put this file's directory first on ``sys.path`` so the ``src`` package that sits
# next to it can be imported regardless of the working directory the script is
# launched from.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# The keys of the four tables below are the human-readable labels that also appear
# in results file names and in the configuration columns of the results CSV.

# Subject -> dataset settings overlaid onto the run configuration.
subject_to_config = {
    "Chemistry": {
        "dataset_names": [["GPQA", "Synthetic"]],
        "problem_names_sizes": [["GPQA", 96]],
        "mapping_path": "mappings/rerank_mapping_chemistry.json",
    },
    "Physics": {
        "dataset_names": [["GPQAPhysics", "SyntheticPhysics"]],
        "problem_names_sizes": [["GPQAPhysics", 100]],
        "mapping_path": "mappings/rerank_mapping_physics.json",
    },
}

# Knowledge level -> retrieval settings. The first three levels only choose a
# retrieval approach; the last three all use "rag_rerank" over the similarity
# range [0, 1.0] and differ only in ``top_k`` (presumably the number of retrieved
# entries). NOTE: "retreiver_approach" (sic) is the key the pipeline reads.
knowledge_level_to_config = {
    "Essentially Same": {"retreiver_approach": "paraphrase"},
    "Similar": {"retreiver_approach": "new_question"},
    "Different": {"retreiver_approach": "new_question_exam"},
    **{
        level: {
            "retreiver_approach": "rag_rerank",
            "rerank_similarity_start": 0,
            "rerank_similarity_end": 1.0,
            "top_k": top_k,
        }
        for level, top_k in (("High", 1), ("Medium", 5), ("Low", 9))
    },
}

# Solver setup label -> solver class name used by the pipeline.
solver_type_to_config = {
    "Baseline": {"solver_type": "BaselineSolver"},
    "Schema Only": {"solver_type": "SchemaSolver"},
    "One-Shot": {"solver_type": "ExampleSolver"},
    "One-Shot + Schema": {"solver_type": "ExampleSchemaSolver"},
    "Example Schema Only": {"solver_type": "ExampleSchemaNoActivationSolver"},
}

# Model label -> class name, resolved later from module src.entity.models.<name>.
solver_model_to_config = {
    "Claude": "Claude",
    "GPT-4o": "GPT4o",
    "GPT-5": "GPT5",
    "GPT-4o Mini": "GPT4oMini",
    "Gemini": "Gemini",
    "Llama-3.1": "Llama3",
    "Ministral": "Ministral",
    "MistralSmall": "MistralSmall",
    "Qwen-3": "Qwen3",
}

def validate_and_run_problem(
    target_iterations: int,
    problem_id: str,
    subject: str,
    knowledge_level: str,
    solver_type: str,
    solver_model: str,
    reattempts: int
) -> None:
    """
    Ensure a problem has at least ``target_iterations`` recorded results, running more if needed.

    Results are read from and appended to
    ``raw_data/{subject}_{knowledge_level}_{solver_type}_{solver_model}.csv`` (relative to the
    current working directory). Rows matching this exact configuration are counted; if there
    are fewer than the target, the missing iterations are run through the Schema-ICL pipeline
    and appended to the same file.

    Args:
        target_iterations: Number of iterations the problem should have in total (>= 0).
        problem_id: The specific problem ID, e.g. a 36-character hex hash such as
            ``_generate_hash`` produces.
        subject: Subject label (e.g. "Chemistry" or "Physics").
        knowledge_level: Knowledge level label (e.g. "Similar", "Different").
        solver_type: Solver type label (e.g. "Baseline", "Schema Only").
        solver_model: Solver model label (e.g. "Gemini", "GPT-4o Mini").
        reattempts: Error budget forwarded to ``_run_problem_iterations``, which uses it to
            bound how many failing iterations are retried.

    Raises:
        ValueError: If ``target_iterations`` is negative or ``problem_id`` is empty/blank, or
            (only when more iterations are needed) if a label is unknown.
    """
    if target_iterations < 0:
        raise ValueError(f"target_iterations must be non-negative, got {target_iterations}")

    if not problem_id or not problem_id.strip():
        raise ValueError("problem_id cannot be empty")

    # One results file per configuration; the path is relative to the CWD.
    results_file = os.path.join('raw_data', f"{subject}_{knowledge_level}_{solver_type}_{solver_model}.csv")

    print(f"Starting validation for problem {problem_id}")
    print(f"Target iterations: {target_iterations}")
    print(f"Configuration: {subject}, {knowledge_level}, {solver_type}, {solver_model}")

    # Step 1: Read current results and count existing iterations
    current_iterations = _count_current_iterations(
        results_file, problem_id, subject, knowledge_level, solver_type, solver_model
    )

    print(f"Current iterations found: {current_iterations}")

    if current_iterations >= target_iterations:
        print(f"Problem already has {current_iterations} iterations (>= {target_iterations}). No additional runs needed.")
        return

    # Step 2: Generate config for missing iterations
    needed_iterations = target_iterations - current_iterations
    print(f"Need to run {needed_iterations} additional iterations")

    config = _generate_config(subject, knowledge_level, solver_type, solver_model, needed_iterations)

    # Step 3: Run the missing iterations
    print("Running missing iterations...")
    new_results = _run_problem_iterations(problem_id, config, needed_iterations, reattempts)

    # Step 4: Append results to CSV (logs without a correctness value are skipped there)
    print(f"Appending {len(new_results)} new results to {results_file}")
    _append_results_to_csv(results_file, new_results, subject, knowledge_level, solver_type, solver_model)

    print("Validation completed successfully!")

def _encode_string_to_id(input_string):
    """Return ``input_string`` (UTF-8 encoded) as URL-safe base64 text."""
    raw_bytes = input_string.encode("utf-8")
    encoded_bytes = base64.urlsafe_b64encode(raw_bytes)
    return encoded_bytes.decode("utf-8")

def _decode_id_to_string(encoded_id):
    """Decode URL-safe base64 text back into the original UTF-8 string."""
    decoded_bytes = base64.urlsafe_b64decode(encoded_id.encode("utf-8"))
    return decoded_bytes.decode("utf-8")


def _count_current_iterations(
    results_file: str,
    problem_id: str,
    subject: str,
    knowledge_level: str,
    solver_type: str,
    solver_model: str
) -> int:
    """Count result rows already recorded for one problem under one configuration.

    A row counts only if ``problem_id``, ``subject``, ``knowledge_level``,
    ``solver_type`` and ``solver_model`` all match exactly and ``knowledge_base``
    is ``'dense'`` (the only value this module writes).

    This is best effort: a missing, empty or unreadable file, or one lacking any of
    the filter columns, is reported on stdout and counted as 0 iterations.

    Returns:
        Number of matching rows.
    """
    try:
        df = pd.read_csv(results_file)
    except FileNotFoundError:
        print(f"Results file {results_file} not found. Starting with 0 iterations.")
        return 0
    except pd.errors.EmptyDataError:
        # A zero-byte file (created but never written) has no header to parse.
        print(f"Results file {results_file} is empty. Starting with 0 iterations.")
        return 0
    except Exception as e:
        print(f"Error reading results file: {e}")
        return 0

    print(f"Read {len(df)} rows from {results_file}")

    try:
        # Filter for exact match of all parameters
        matches = (
            (df['problem_id'] == problem_id) &
            (df['subject'] == subject) &
            (df['knowledge_base'] == 'dense') &  # Always dense as specified
            (df['knowledge_level'] == knowledge_level) &
            (df['solver_type'] == solver_type) &
            (df['solver_model'] == solver_model)
        )
    except KeyError as e:
        print(f"Error reading results file: missing column {e}")
        return 0

    return int(matches.sum())


def _load_model_classes_and_templates(config: Dict[str, Any]) -> None:
    """Resolve class and template names in ``config`` into Python objects, in place.

    Mirrors the loading logic of config.py:

    * every model slot present in ``config`` (``solver_model``, ``memory_model``,
      ``schema_generator_model``, ``schema_activator_model``) holds a class name
      ``X`` and is replaced by attribute ``X`` of module ``src.entity.models.X``;
    * if ``schema_template`` is present, module
      ``src.entity.schema_templates.<schema_template>`` is imported and its
      ``Response``, ``SCHEMA_PROMPT``, ``SCHEMA_SAMPLE_QUESTION``,
      ``SCHEMA_SAMPLE_RESPONSE`` and ``SCHEMA_SOLVER_PROMPT`` attributes are copied
      into the config (``Response`` is stored as ``SCHEMA_RESPONSE_CLASS``).

    The name strings are overwritten, so calling this a second time on the same
    dict would try to import modules named after class objects and fail.
    """
    model_slots = ("solver_model", "memory_model", "schema_generator_model", "schema_activator_model")
    for slot in model_slots:
        if slot not in config:
            continue
        class_name = config[slot]
        models_module = importlib.import_module(f'src.entity.models.{class_name}')
        config[slot] = getattr(models_module, class_name)

    if "schema_template" not in config:
        return

    template_module = importlib.import_module(f'src.entity.schema_templates.{config["schema_template"]}')
    # (config key, attribute of the template module), copied in this order.
    template_exports = (
        ("SCHEMA_RESPONSE_CLASS", "Response"),
        ("SCHEMA_PROMPT", "SCHEMA_PROMPT"),
        ("SCHEMA_SAMPLE_QUESTION", "SCHEMA_SAMPLE_QUESTION"),
        ("SCHEMA_SAMPLE_RESPONSE", "SCHEMA_SAMPLE_RESPONSE"),
        ("SCHEMA_SOLVER_PROMPT", "SCHEMA_SOLVER_PROMPT"),
    )
    for config_key, attribute_name in template_exports:
        config[config_key] = getattr(template_module, attribute_name)


def _generate_config(
    subject: str,
    knowledge_level: str,
    solver_type: str,
    solver_model: str,
    repeat_num: int
) -> Dict[str, Any]:
    """Build the complete pipeline configuration for one experiment setting.

    Starts from fixed defaults, then overlays, in order, the subject,
    knowledge-level and solver-type table entries; later overlays win on
    conflicting keys (e.g. a knowledge level's ``top_k`` replaces the default 1).
    Finally the solver, schema-activator and schema-generator model slots are set
    to the requested model's class name, while the memory model stays
    ``"GPT4oMini"``.

    Table entries are deep-copied, so callers may mutate the returned config,
    including nested lists such as ``dataset_names``, without altering the
    module-level tables that later calls in the same process reuse.

    Args:
        subject: Key of ``subject_to_config``.
        knowledge_level: Key of ``knowledge_level_to_config``.
        solver_type: Key of ``solver_type_to_config``.
        solver_model: Key of ``solver_model_to_config``.
        repeat_num: Stored as ``config["repeat_num"]``.

    Returns:
        A new configuration dict. Model slots hold class *names* (strings) until
        ``_load_model_classes_and_templates`` resolves them.

    Raises:
        ValueError: If any of the four keys is missing from its table (checked in
            the order subject, knowledge_level, solver_type, solver_model).
    """
    import copy

    # Start with base config (using chemistry as template)
    # NOTE(review): "schema_template" stays "ChemistrySchema" for every subject,
    # including Physics - confirm that is intended.
    base_config = {
        "repeat_num": repeat_num,
        "shuffle": True,
        "embedding_model": "Alibaba-NLP/gte-base-en-v1.5",
        "exclude_self": True,
        "start_similarity": 0,
        "end_similarity": 1,
        "fallback_mode": False,
        "schema_activator_type": "NormalSchemaActivator",
        "memory_type": "semantic",
        "memory_including_answer": True,
        "context_length": 4096,
        "past_knowledge_mode": "question+answer",
        "top_k": 1,
        "random_from_k": False,
        "schema_generator_type": "IdentitySchemaGenerator",
        "schema_template": "ChemistrySchema"
    }

    overlays = (
        ("subject", subject, subject_to_config),
        ("knowledge_level", knowledge_level, knowledge_level_to_config),
        ("solver_type", solver_type, solver_type_to_config),
    )
    for label, key, table in overlays:
        if key not in table:
            raise ValueError(f"Unknown {label}: {key}")
        # Deep copy: the table values contain nested lists shared by every call.
        base_config.update(copy.deepcopy(table[key]))

    if solver_model not in solver_model_to_config:
        raise ValueError(f"Unknown solver_model: {solver_model}")
    model_class_name = solver_model_to_config[solver_model]
    base_config.update({
        "solver_model": model_class_name,
        "schema_activator_model": model_class_name,
        "schema_generator_model": model_class_name,
        "memory_model": "GPT4oMini"  # Keep memory model as GPT4oMini as in base configs
    })

    return base_config


def _run_problem_iterations(problem_id: str, config: Dict[str, Any], iterations: int, reattempts: int) -> list:
    """Run one problem ``iterations`` times through the Schema-ICL pipeline and return the logs.

    Side effects: resolves the model/template names in ``config`` in place (via
    ``_load_model_classes_and_templates``) and registers ``config`` as
    ``sys.modules['config']`` before importing the ``src`` pipeline modules.

    The knowledge base and problem set named by the first ``dataset_names`` /
    ``problem_names_sizes`` entries are loaded, given mental representations and
    embeddings, and the problem whose ``id`` equals ``problem_id`` is then solved
    repeatedly with ``_handle_single_problem``.

    Args:
        problem_id: ID of the problem to run.
        config: Configuration dict as built by ``_generate_config``.
        iterations: Number of successful iterations wanted.
        reattempts: Number of retries allowed across the whole run. With ``0`` the
            first error stops the run; with ``N`` the ``N+1``-th error does. When
            the run stops early, the logs collected so far are returned.

    Returns:
        List of per-iteration log dicts (possibly shorter than ``iterations``).

    Raises:
        ValueError: If no loaded problem has ID ``problem_id``.
        Exception: Errors while loading modules, models or data are printed and
            re-raised.
    """
    try:
        # Turn the model/template names in the config into classes (mutates config).
        _load_model_classes_and_templates(config)

        # Expose this config as an importable ``config`` module. This has to happen
        # before the ``src`` imports below (presumably they read it at import time).
        import types
        config_module = types.ModuleType('config')
        config_module.config = config
        sys.modules['config'] = config_module

        from src.usecase.datasets.controller.LoadingKnowledgeBaseController import LoadingKnowledgeBaseController
        from src.usecase.datasets.controller.LoadingProblemsController import LoadingProblemsController
        from src.usecase.datasets.controller.GenerateMentalRepresentationController import GenerateMentalRepresentationController
        from src.entity.embedder.STEmbedder import STEmbedder

        # Only the first dataset / problem-set entry of the config is used.
        dataset_names_list = config['dataset_names'][0]
        knowledge_base = LoadingKnowledgeBaseController(dataset_names_list)
        problems = LoadingProblemsController(config['problem_names_sizes'][0][0], config['problem_names_sizes'][0][1])

        # Model slots already hold classes at this point.
        context_length = config['context_length']
        memory_model = config['memory_model'](ctx_len=context_length)
        solver_model = config['solver_model'](ctx_len=context_length)
        schema_activator_model = config['schema_activator_model'](ctx_len=context_length)

        embedder = STEmbedder(config['embedding_model'])

        # NOTE(review): every call rebuilds the representations of the whole dataset
        # even though only one problem is run; consider caching for batch runs.
        problems = GenerateMentalRepresentationController(problems, config['memory_type'], memory_model, including_answer=False, embedder=embedder)
        knowledge_base.knowledges = GenerateMentalRepresentationController(
            knowledge_base.knowledges,
            config['memory_type'],
            memory_model,
            including_answer=config['memory_including_answer'],
            embedder=embedder
        )
        knowledge_base.add_embeddings()

        target_problem = next((problem for problem in problems if problem.id == problem_id), None)
        if target_problem is None:
            raise ValueError(f"Problem with ID {problem_id} not found in datasets {dataset_names_list}")

        results = []
        error_count = 0
        while len(results) < iterations:
            try:
                log = _handle_single_problem(
                    target_problem, knowledge_base, config, solver_model, schema_activator_model
                )
            except Exception as e:
                error_count += 1
                # ``reattempts`` is the number of retries allowed, so only the error
                # that goes beyond it stops the run.
                if error_count > reattempts:
                    print(f"Too many errors encountered, stopping further iterations {e}.")
                    break
                print(f"Error in iteration {len(results) + 1}: {e}. Re trying...")
                continue
            results.append(log)
            print(f"Completed iteration {len(results)}/{iterations}")

        return results

    except Exception as e:
        print(f"Error in _run_problem_iterations: {e}")
        raise


def _handle_single_problem(problem, knowledge_base, config, solver_model, schema_activator_model):
    """Solve ``problem`` once and return a log dict describing the attempt.

    Adapted from sdl.py for the validation context. Paths:

    * ``config['solver_type'] == "BaselineSolver"``: solve directly, without
      retrieved knowledge.
    * Otherwise retrieve knowledge with ``GetRelevantKnowledgesController``:
      - if it returns ``None`` for the knowledges or the similarity, fall back to
        the baseline solver when ``config['fallback_mode']`` is set, else record
        ``correct = None``;
      - otherwise format the knowledge, run schema activation (skipped only for
        ``"ExampleSolver"``) and solve with ``config['solver_type']``.

    Args:
        problem: Problem object; this function reads ``id``,
            ``get_ground_truth()`` and (activation path) ``mental_representation``.
        knowledge_base: Knowledge base handed to the retriever.
        config: Run configuration. Reads ``solver_type``, ``top_k``,
            ``retreiver_approach``, ``exclude_self``, ``random_from_k``,
            ``fallback_mode``, ``past_knowledge_mode``, ``schema_activator_type`` and
            the optional similarity-range and ``mapping_path`` keys.
        solver_model: Instantiated solver model.
        schema_activator_model: Instantiated model used for schema activation.

    Returns:
        dict: Always contains ``problem``, ``problem_id``, ``ground_truth`` and
        ``correct`` (``final_answer == ground_truth``, or ``None`` when nothing was
        retrieved and fallback is off); other keys depend on the path taken.

    NOTE: the insertion order of the log keys matters - ``_append_results_to_csv``
    derives each row's ``process_id`` from ``str(log)``, so reordering the
    assignments below changes the stored IDs.
    """
    # Imported lazily: these src modules are only importable after the caller has
    # registered the ``config`` module in sys.modules.
    from src.usecase.retreival.GetRelevantKnowledgesController import GetRelevantKnowledgesController
    from src.usecase.datasets.controller.FormatPastKnowledgeController import FormatPastKnowledgeController
    from src.usecase.schema_activation.SchemaActivationController import SchemaActivationController
    from src.usecase.solver.SolveController import SolveController

    log = {}
    log["problem"] = str(problem)
    log["problem_id"] = problem.id  # Store the actual problem ID (used for the CSV problem_id column)
    log["ground_truth"] = problem.get_ground_truth()
    
    if config['solver_type'] == "BaselineSolver":
        # Baseline: no schema and no retrieved knowledge.
        response_content, final_answer = SolveController(
            problem, None, "BaselineSolver", solver_model, None, None
        )
        log["baseline"] = True
        log["final_answer"] = final_answer
        log["solver_response"] = response_content
        log["correct"] = final_answer == problem.get_ground_truth()
    else:
        # EXACT same parameter extraction as sdl.py (lines 70-74)
        start_similarity = config.get('start_similarity', None)
        end_similarity = config.get('end_similarity', None)
        rerank_similarity_start = config.get('rerank_similarity_start', None)
        rerank_similarity_end = config.get('rerank_similarity_end', None)
        mapping_path = config.get('mapping_path', None)

        # EXACT same GetRelevantKnowledgesController call as sdl.py (lines 76-80)
        relevant_knowledges, similarity = GetRelevantKnowledgesController(problem, knowledge_base, config['top_k'], approach=config['retreiver_approach'],
                                                                          start_similarity=start_similarity, end_similarity=end_similarity,
                                                                          exclude_self=config['exclude_self'], random_from_k=config['random_from_k'],
                                                                          rerank_similarity_start=rerank_similarity_start, rerank_similarity_end=rerank_similarity_end,
                                                                          mapping_path=mapping_path)
        if relevant_knowledges is None or similarity is None:
            print("No relevant knowledge found")
            if config['fallback_mode']:
                # call fallback mode - EXACT same as sdl.py (lines 85-90)
                response_content, final_answer = SolveController(problem, None, "BaselineSolver", solver_model, None, None)
                log["baseline"] = True
                log["final_answer"] = final_answer
                log["solver_response"] = response_content
                log["correct"] = final_answer == problem.get_ground_truth()
            else:
                # No answer produced; _append_results_to_csv skips logs whose
                # "correct" is None.
                log["correct"] = None

        else:
            log["baseline"] = False
            formatted_knowledges = FormatPastKnowledgeController(relevant_knowledges, similarity, config['past_knowledge_mode'])

            log["relevant_knowledges"] = [str(knowledge) for knowledge in relevant_knowledges]
            log["formatted_knowledges"] = formatted_knowledges

            log["similarity"] = similarity

            # ExampleSolver works from the examples alone, so no schema is activated.
            # NOTE(review): every other solver type runs activation here, including
            # ExampleSchemaNoActivationSolver - confirm this matches sdl.py.
            schema_for_solver = None
            if config['solver_type'] != "ExampleSolver":
                response_content, schema_for_solver = SchemaActivationController(problem, formatted_knowledges,
                                                                      schema_activator_type=config['schema_activator_type'],
                                                                      schema_activator_model=schema_activator_model)
                log["activated_schema"] = schema_for_solver
                log["original_sensory_memory"] = problem.mental_representation
                log["schema_activation_response"] = response_content

            response_content, final_answer = SolveController(problem, schema_for_solver, config['solver_type'], solver_model,
                                                             relevant_knowledges, formatted_knowledges)
            log["final_answer"] = final_answer
            log["solver_response"] = response_content
            log["correct"] = final_answer == problem.get_ground_truth()
    
    return log


def _append_results_to_csv(
    results_file: str,
    new_results: list,
    subject: str,
    knowledge_level: str,
    solver_type: str,
    solver_model: str
) -> None:
    """Append run logs to ``results_file``, one CSV row per scored iteration.

    Logs whose ``"correct"`` entry is ``None`` are skipped. Each written row holds
    ``process_id`` (``_encode_string_to_id(str(log))``), ``problem_id``,
    ``correctness`` as 0/1, and the configuration columns, with
    ``knowledge_base`` fixed to ``'dense'``.

    The parent directory is created if missing. If the file already has content,
    rows are appended without a header. When that header lists the same columns in
    a different order, the new rows are reordered to match it, so values land under
    the right headings. Write errors are printed rather than raised.
    """
    csv_rows = []
    for result in new_results:
        if result.get("correct") is None:
            continue  # nothing was scored for this iteration
        csv_rows.append({
            'process_id': _encode_string_to_id(str(result)),
            'problem_id': _extract_problem_id_from_log(result),
            'correctness': int(result["correct"]),
            'subject': subject,
            'knowledge_base': 'dense',  # Always dense as specified
            'knowledge_level': knowledge_level,
            'solver_type': solver_type,
            'solver_model': solver_model
        })

    if not csv_rows:
        print("No valid results to append.")
        return

    new_df = pd.DataFrame(csv_rows)

    try:
        parent_dir = os.path.dirname(results_file)
        if parent_dir:
            # Without this, a missing output directory (e.g. raw_data/) makes the
            # write below fail and the freshly computed results are lost.
            os.makedirs(parent_dir, exist_ok=True)

        if os.path.exists(results_file) and os.path.getsize(results_file) > 0:
            new_df = _align_to_existing_header(results_file, new_df)
            # Ensure existing file ends with newline before appending
            with open(results_file, 'r+b') as f:
                f.seek(-1, os.SEEK_END)
                if f.read(1) != b'\n':
                    f.write(b'\n')
            # Append without header
            new_df.to_csv(results_file, mode='a', header=False, index=False)
        else:
            # Write with header
            new_df.to_csv(results_file, mode='w', header=True, index=False)
    except Exception as e:
        print(f"Error writing to CSV: {e}")


def _align_to_existing_header(results_file: str, new_df: pd.DataFrame) -> pd.DataFrame:
    """Return ``new_df`` with columns ordered like the header of ``results_file``.

    Appending without a header is positional, so columns must follow the existing
    order. If the header cannot be read or names a different set of columns,
    ``new_df`` is returned unchanged and a warning is printed.
    """
    try:
        existing_columns = list(pd.read_csv(results_file, nrows=0).columns)
    except Exception as e:
        print(f"Warning: could not read header of {results_file} ({e}); appending columns in default order.")
        return new_df
    if sorted(existing_columns) == sorted(new_df.columns):
        return new_df[existing_columns]
    print(f"Warning: header of {results_file} does not match result columns {list(new_df.columns)}; appending columns in default order.")
    return new_df


def _extract_problem_id_from_log(result: dict) -> str:
    """Return the ``problem_id`` stored in a run log, or ``""`` if the log has none."""
    if "problem_id" in result:
        return result["problem_id"]
    return ""

def _generate_hash(input_string):
    """
    Return a 36-character lowercase hexadecimal ID derived from ``input_string``.

    The ID is the first 36 hex digits of the SHA-256 digest of the UTF-8 encoded
    string. SHA-256 yields 64 hex digits, so this always truncates. Equal inputs
    always map to equal IDs.

    Args:
        input_string (str): The string to hash.

    Returns:
        str: A 36-character string of characters 0-9 and a-f.
    """
    return hashlib.sha256(input_string.encode("utf-8")).hexdigest()[:36]
def _load_problem_ids_from_raw_csv(filename: str) -> list:
    """Return the non-missing values of the ``id`` column of a raw dataset CSV, as strings.

    ``filename`` is resolved inside ``src/entity/datasets/raw_files`` next to this module.
    """
    project_root = os.path.dirname(os.path.abspath(__file__))
    csv_path = os.path.join(project_root, 'src', 'entity', 'datasets', 'raw_files', filename)
    # Read IDs as text so pandas cannot coerce an all-numeric column into int/float
    # (a float column would stringify "123" as "123.0").
    df = pd.read_csv(csv_path, dtype={'id': str})
    return [str(pid) for pid in df['id'] if pd.notna(pid)]


def _load_chemistry_problem_ids() -> list:
    """Return the problem IDs listed in the GPQA chemistry raw file (gpqa_chemistry.csv)."""
    return _load_problem_ids_from_raw_csv('gpqa_chemistry.csv')


def _load_physics_problem_ids() -> list:
    """Return the problem IDs listed in the GPQA diamond physics raw file (gpqa_diamond_physics.csv)."""
    return _load_problem_ids_from_raw_csv('gpqa_diamond_physics.csv')

# Command-line entry point.
#
# Without arguments this runs a single hard-coded smoke test (one chemistry problem,
# one iteration). With arguments it validates every problem of the chosen subject
# for one configuration, e.g.:
#
#   python3 validation.py --target_iterations 3 --subject Physics \
#       --knowledge_level "Essentially Same" --solver_type "Example Schema Only" \
#       --solver_model "GPT-4o Mini" --reattempts 3
if __name__ == "__main__":
    if len(sys.argv) == 1:
        # Smoke test only for the no-argument case; running it unconditionally (and
        # exiting) would make the argument-driven mode below unreachable.
        validate_and_run_problem(
            target_iterations=1,
            problem_id="2662eff7a6231613f45042d2ef5df63caaeb",
            subject="Chemistry",
            knowledge_level="Essentially Same",
            solver_type="One-Shot",
            solver_model="GPT-5",
            reattempts=1
        )
        sys.exit(0)

    import argparse
    parser = argparse.ArgumentParser(description="Validate and run problem iterations.")
    parser.add_argument('--target_iterations', type=int, required=True, help='Target number of iterations')
    parser.add_argument('--subject', type=str, required=True, choices=subject_to_config.keys(), help='Subject type (Chemistry or Physics)')
    parser.add_argument('--knowledge_level', type=str, required=True, choices=knowledge_level_to_config.keys(), help='Knowledge level (Similar, Different, etc.)')
    parser.add_argument('--solver_type', type=str, required=True, choices=solver_type_to_config.keys(), help='Solver type (Baseline, Schema Only, etc.)')
    parser.add_argument('--solver_model', type=str, required=True, choices=solver_model_to_config.keys(), help='Solver model (Gemini, GPT-4o Mini, etc.)')
    parser.add_argument('--reattempts', type=int, default=20, help='Number of reattempts for failed iterations (default: 20)')

    args = parser.parse_args()

    problem_ids = _load_chemistry_problem_ids() if args.subject == "Chemistry" else _load_physics_problem_ids()

    # Completed problems are detected from
    # <this file's dir>/subsets/{subject}_{knowledge_level}_{solver_type}_{solver_model}.csv.
    # NOTE(review): validate_and_run_problem reads and appends to raw_data/ (relative
    # to the CWD) instead - confirm the two locations are meant to differ.
    csv_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'subsets', f"{args.subject}_{args.knowledge_level}_{args.solver_type}_{args.solver_model}.csv")

    try:
        dataset = pd.read_csv(csv_path)
    except FileNotFoundError:
        print(f"Subset file {csv_path} not found; treating every problem as incomplete.")
        dataset = pd.DataFrame(columns=['problem_id', 'subject', 'knowledge_level', 'solver_type', 'solver_model'])

    dataset = dataset[(dataset['subject'] == args.subject) & (dataset['knowledge_level'] == args.knowledge_level) & (dataset['solver_type'] == args.solver_type) & (dataset['solver_model'] == args.solver_model)]
    # Problems that already appear at least target_iterations times need no work.
    counts = dataset['problem_id'].value_counts()
    complete_problem_ids = set(counts[counts >= args.target_iterations].index)

    print(f"Total good problems: {len(complete_problem_ids)}")
    problem_ids = [pid for pid in problem_ids if pid not in complete_problem_ids]
    for problem_id in problem_ids:
        validate_and_run_problem(
            target_iterations=args.target_iterations,
            problem_id=problem_id,
            subject=args.subject,
            knowledge_level=args.knowledge_level,
            solver_type=args.solver_type,
            solver_model=args.solver_model,
            reattempts=args.reattempts
        )


