import os
import sys
import csv
import json
import argparse
from typing import Dict, List
from tqdm import tqdm

sys.path.append('../src')
from data_tools import load_feedbackqa, load_helpsteer2, load_ultrafeedback

def load_judge_programs(storage_dir: str = "synthesized_program_judges") -> Dict[str, Dict]:
    """
    Load all judge programs and their metadata from the storage directory.

    The directory is expected to hold a ``programs_metadata.json`` file that
    maps program IDs to metadata objects, plus one ``<prog_id>.py`` file per
    program. Problems are reported on stdout instead of being raised, so the
    caller always receives a (possibly empty) dictionary.

    Args:
        storage_dir (str): Directory where judge programs and metadata are stored.

    Returns:
        Dict[str, Dict]: Program IDs mapped to a dict with the keys ``code``
        (full text of the program file), ``description``, ``function_name``,
        ``criteria`` and ``file_path``. Entries whose program file is missing,
        or whose metadata is not a JSON object, are skipped with a warning.
    """
    programs: Dict[str, Dict] = {}
    metadata_file = os.path.join(storage_dir, 'programs_metadata.json')

    # Open directly rather than testing os.path.exists() first: one fewer
    # filesystem round trip and no window between the check and the open.
    # UTF-8 is requested explicitly so decoding does not depend on the locale.
    try:
        with open(metadata_file, 'r', encoding='utf-8') as f:
            metadata = json.load(f)
    except (FileNotFoundError, NotADirectoryError):
        print(f"Error: Metadata file {metadata_file} not found.")
        return programs
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        print(f"Error: Metadata file {metadata_file} is corrupted or invalid.")
        return programs

    # Valid JSON that is not an object (e.g. a list) cannot map IDs to metadata.
    if not isinstance(metadata, dict):
        print(f"Error: Metadata file {metadata_file} is corrupted or invalid.")
        return programs

    for prog_id, meta in metadata.items():
        prog_file = os.path.join(storage_dir, f"{prog_id}.py")

        if not isinstance(meta, dict):
            print(f"Warning: Metadata for {prog_id} is not a JSON object; skipping.")
            continue

        try:
            with open(prog_file, 'r', encoding='utf-8') as f:
                code = f.read()
        except (FileNotFoundError, NotADirectoryError):
            print(f"Warning: Program file {prog_file} for {prog_id} not found.")
            continue

        programs[prog_id] = {
            'code': code,
            'description': meta.get('description', 'No description'),
            'function_name': meta.get('function_name', ''),
            'criteria': meta.get('criteria', 'Unknown'),
            'file_path': meta.get('file_path', prog_file),
        }

    return programs

def test_judge_program(program: Dict, query: str, response: str) -> Dict:
    """
    Test a single judge program with the given query and response.

    The program's source is executed in a fresh namespace, the function named
    by ``program['function_name']`` is looked up there and called as
    ``judge_function(query, response)``.

    SECURITY: the program source is run with ``exec`` and is therefore able to
    do anything this process can do. Only pass code from a trusted origin.

    Args:
        program (Dict): Program metadata; reads the keys ``code``,
            ``function_name`` and ``criteria``.
        query (str): The input query to test.
        response (str): The response to evaluate.

    Returns:
        Dict: On success, a copy of the judge function's result (which must be
        a dict containing ``score``, ``reasoning`` and ``criteria``) with
        ``executable`` set to True. On any failure, a dict with ``score`` 0.0,
        an explanatory ``reasoning``, the program's ``criteria`` and
        ``executable`` set to False.
    """
    # Looked up defensively: every failure path reports the criteria, and a
    # missing key must not raise from inside the exception handler below.
    criteria = program.get('criteria', 'Unknown')

    def failure(reasoning: str) -> Dict:
        # Uniform shape for all non-executable outcomes.
        return {
            'score': 0.0,
            'reasoning': reasoning,
            'criteria': criteria,
            'executable': False
        }

    try:
        namespace: Dict = {}
        exec(program['code'], namespace)  # see SECURITY note in the docstring
        function_name = program['function_name']
        judge_function = namespace.get(function_name)

        if not judge_function:
            return failure(f"Function {function_name} not found in program code.")

        result = judge_function(query, response)
        required_keys = ('score', 'reasoning', 'criteria')
        if not (isinstance(result, dict) and all(key in result for key in required_keys)):
            return failure(
                "Invalid output format from judge function. Expected dict with keys: score, reasoning, criteria."
            )

        # Copy before annotating so an object owned by the judge code is not mutated.
        outcome = dict(result)
        outcome['executable'] = True
        return outcome

    except Exception as e:
        # Deliberately broad: judge programs are arbitrary code and any failure
        # is reported as a non-executable result instead of aborting the run.
        return failure(f"Error executing program: {str(e)}")


# The "test_" prefix would make pytest collect this helper as a test case when
# it is imported into a test module; opt out explicitly.
test_judge_program.__test__ = False

def save_results_to_csv(results: List[Dict], output_file: str) -> None:
    """
    Save judge results to a CSV file, one row per (question, answer) pair.

    The columns are ``question``, ``answer`` and then one column per program
    ID (sorted); each cell holds that program's score for the pair. A pair
    that has no result for some program gets 0.0 in that column. If the same
    (question, answer, prog_id) combination occurs more than once, the last
    score wins.

    Args:
        results (List[Dict]): Result dictionaries; reads the keys ``question``,
            ``answer``, ``prog_id`` and ``score``.
        output_file (str): Path to the output CSV file. Missing parent
            directories are created.
    """
    # os.path.dirname() is '' for a bare filename and os.makedirs('') raises
    # FileNotFoundError, so only create a directory when one is named.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Sorted so the column order is deterministic.
    prog_ids = sorted({result['prog_id'] for result in results})

    # (question, answer) -> {prog_id: score}; dict order keeps first-seen row order.
    entry_results: Dict = {}
    for result in results:
        key = (result['question'], result['answer'])
        entry_results.setdefault(key, {})[result['prog_id']] = result['score']

    # newline='' is what the csv module requires to control line endings itself.
    with open(output_file, 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=['question', 'answer'] + prog_ids)
        writer.writeheader()

        for (question, answer), scores in entry_results.items():
            row = {'question': question, 'answer': answer}
            row.update({prog_id: scores.get(prog_id, 0.0) for prog_id in prog_ids})
            writer.writerow(row)

def main():
    """
    Command-line entry point: run every stored judge program on one dataset.

    Parses the dataset name from the command line, loads the judge programs
    and the dataset, scores every (question, answer) pair with every program,
    and writes the scores to
    ``./program_judge_outputs/<dataset>_program_judge_results.csv``.
    All paths are relative, so they resolve against the current working
    directory.
    """
    # Dataset name -> (loader, data file). Also drives the CLI choices, so a
    # dataset only has to be registered in one place.
    dataset_sources = {
        "feedbackqa": (load_feedbackqa, "../data/feedbackqa/feedback_train.json"),
        "helpsteer2": (load_helpsteer2, "../data/helpsteer2/helpsteer2_valid.json"),
        "ultrafeedback": (load_ultrafeedback, "../data/ultrafeedback_sampled.csv"),
    }

    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Run judge programs on a specified dataset.")
    parser.add_argument('dataset', choices=list(dataset_sources),
                        help="Dataset to process: feedbackqa, helpsteer2, or ultrafeedback")
    args = parser.parse_args()
    dataset_name = args.dataset

    # Load all judge programs
    programs = load_judge_programs()
    if not programs:
        print("No judge programs found. Please generate programs first.")
        return

    # Load the specified dataset
    loader, data_path = dataset_sources[dataset_name]
    ratings = loader(data_path)

    questions = ratings["question"]
    answers = ratings["answer"]

    print(f"\nProcessing dataset: {dataset_name}")
    print("-" * 50)

    results = []

    # Process each query-response pair with tqdm progress bar
    pairs = tqdm(zip(questions, answers), total=len(questions), desc=f"Processing {dataset_name}")
    for query, response in pairs:
        for prog_id, program in programs.items():
            # tqdm.write() instead of print(): plain prints inside the loop
            # interleave with the progress bar and garble it.
            tqdm.write(f"  Testing {prog_id}: {program['description']} (Criteria: {program['criteria']})")
            result = test_judge_program(program, query, response)

            tqdm.write(f"    Executable: {'✅' if result['executable'] else '🚫'}")
            tqdm.write(f"    Score: {result['score']}")

            results.append({
                'question': query,
                'answer': response,
                'prog_id': prog_id,
                'score': result['score'],
                'criteria': result['criteria'],
                'executable': result['executable']
            })

    # Save results to CSV
    program_judge_output_dir = "./program_judge_outputs"
    output_file = os.path.join(program_judge_output_dir, f"{dataset_name}_program_judge_results.csv")
    save_results_to_csv(results, output_file)
    print(f"\nResults saved to {output_file}")

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()