# -*- coding: utf-8 -*-
# @Author  : 
# @Desc    : Sequentially evaluate InfiAgent-DABench cases with the DataInterpreter and summarise the results.
import time
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

import fire
from metagpt.logs import logger, define_log_level

from di_project.roles.data_interpreter import DataInterpreter
from InfriAgent_DABench.DABench import DABench

@dataclass()
class EvaluationResult:
    """Record describing how one benchmark case was evaluated.

    Fields are declared in positional-constructor order; do not reorder them.
    """

    # Identifier of the evaluated case.
    id: str
    # Text recorded as the interpreter's answer for the case.
    prediction: str
    # Expected answer for the case, stored as a string.
    label: str
    # Whether the benchmark scorer accepted the answer.
    is_correct: bool
    # Elapsed seconds spent evaluating the case.
    time_cost: float
    # Optional code text associated with the case; empty by default.
    code_str: str = ""
    # Error description when evaluation raised; empty otherwise.
    error_msg: str = ""


async def evaluate_single_case(DA: DABench, case_id: str) -> EvaluationResult:
    """Run the DataInterpreter on a single DABench case and score its output.

    Args:
        DA: Benchmark wrapper; this function calls its ``get_prompt``,
            ``eval`` and ``get_answer`` methods.
        case_id: Identifier of the case, passed unchanged to ``DA``'s methods.

    Returns:
        An ``EvaluationResult``. Ordinary exceptions are not propagated: if
        generating or scoring the answer fails, the error is logged and
        recorded in ``error_msg`` (always non-empty in that case), and
        ``prediction`` is left empty.
    """
    # perf_counter is monotonic, so durations are immune to wall-clock jumps.
    start_time = time.perf_counter()
    try:
        define_log_level(name=f"InfiAgent-{case_id}")
        requirement = DA.get_prompt(case_id)
        logger.info(f"Requirement: {requirement}")

        # Initialize and run interpreter
        di = DataInterpreter()
        result = await di.run(requirement)

        # Score the raw interpreter output. DA.eval's first return value is
        # not stored; the raw content is what gets recorded below.
        _, is_correct = DA.eval(case_id, result.content)

        return EvaluationResult(
            id=case_id,
            prediction=str(result.content),
            label=str(DA.get_answer(case_id)),
            is_correct=bool(is_correct),
            time_cost=time.perf_counter() - start_time,
        )

    except Exception as e:
        logger.error(f"Error processing case {case_id}: {str(e)}")
        # Some exceptions (e.g. a bare TimeoutError()) stringify to "", which
        # would make a non-empty-error_msg check treat this failure as a
        # success; fall back to repr() so the message is never empty.
        error_msg = str(e) or repr(e)

        # The fallback scoring and label lookup can fail too (e.g. the same
        # underlying problem with case_id). Guard them so one bad case
        # cannot escape this handler and abort a whole evaluation run.
        is_correct = False
        try:
            _, is_correct = DA.eval(case_id, "")
        except Exception as eval_err:
            logger.error(f"Fallback scoring failed for case {case_id}: {eval_err!r}")
        try:
            label = str(DA.get_answer(case_id))
        except Exception as answer_err:
            label = ""
            logger.error(f"Could not fetch answer for case {case_id}: {answer_err!r}")

        return EvaluationResult(
            id=case_id,
            prediction="",
            label=label,
            is_correct=bool(is_correct),
            time_cost=time.perf_counter() - start_time,
            error_msg=error_msg,
        )


async def main_sync(
    target_ids: Optional[Union[List, Tuple, str, int]] = None,
    save_name: str = "update",
) -> Dict:
    """Evaluate DABench cases one at a time and summarise the outcome.

    Despite the ``_sync`` suffix this is a coroutine: cases are evaluated
    sequentially, each awaited before the next starts.

    Args:
        target_ids: Case ids to evaluate. ``None`` (the default), ``"all"``
            or ``["all"]`` selects every key of ``DA.answers``. A single
            scalar id is treated as a one-element list. Ids that are not keys
            of ``DA.answers`` are skipped with a warning.
        save_name: Accepted for interface compatibility; currently unused.

    Returns:
        ``{"results": [EvaluationResult, ...], "statistics": {...}}``.
        A case counts as failed when its ``error_msg`` is non-empty.
        ``average_time`` is averaged over successful cases only (0 if none),
        while ``total_time`` sums every case. ``correct_cases``/``accuracy``
        count results whose ``is_correct`` is true (accuracy is 0.0 when no
        case ran).
    """
    DA = DABench()

    # Normalise target_ids into a list. The None sentinel replaces the former
    # mutable ['all'] default; a bare scalar id would otherwise be iterated
    # character-by-character (str) or raise TypeError (int).
    if target_ids is None or target_ids == "all":
        target_ids = ["all"]
    elif isinstance(target_ids, (list, tuple)):
        target_ids = list(target_ids)
    else:
        target_ids = [target_ids]
    if target_ids == ["all"]:
        target_ids = list(DA.answers.keys())
    logger.info(f"Evaluating cases: {target_ids}")

    results = []
    for case_id in target_ids:
        if case_id not in DA.answers:
            logger.warning(f"Skipping unknown case id: {case_id!r}")
            continue
        result = await evaluate_single_case(DA, case_id)
        results.append(result)
        logger.info(f"Completed case {case_id}: {'Success' if not result.error_msg else 'Failed'}")

    successful_cases = [r for r in results if not r.error_msg]
    avg_time = sum(r.time_cost for r in successful_cases) / len(successful_cases) if successful_cases else 0
    correct_cases = sum(1 for r in results if r.is_correct)
    accuracy = correct_cases / len(results) if results else 0.0

    summary = {
        'results': results,
        'statistics': {
            'total_cases': len(results),
            'successful_cases': len(successful_cases),
            'failed_cases': len(results) - len(successful_cases),
            'correct_cases': correct_cases,
            'accuracy': accuracy,
            'average_time': avg_time,
            'total_time': sum(r.time_cost for r in results),
        },
    }

    logger.info(f"Evaluation completed. Average time cost: {avg_time:.2f} seconds")
    logger.info(f"Accuracy: {correct_cases}/{len(results)} ({accuracy:.2%})")
    return summary

if __name__ == "__main__":
    # Command-line entry point: hand main_sync to Fire as the CLI component.
    fire.Fire(component=main_sync)
