import argparse
import asyncio
import importlib
import os
import shutil
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

from tqdm.asyncio import tqdm

from experiment.dataset import load_data
from experiment.module import set_module, ModuleMath, ModuleCode, ModuleMhop, get_iter_count
from experiment.utils import (
    duration_formatter,
    load_json,
    save_json,
    get_next_log_file,
    get_file_count,
    clean_code
)
from llm import get_token, get_call_count, set_model

# Template for every log path of a run; `size` is filled with the interval
# label ("full" or "<start>-<end>") computed by ExperimentRunner.
LOG_DIR: str = "log/{dataset}/{model}/{size}"
# Maximum number of saved "direct" result files combined in ensemble mode
# (main() may replace this value at startup).
ENSEMBLE_NUM: int = 5

# Dataset configuration
@dataclass
class DatasetConfig:
    question_key: str
    answer_key: str
    module_type: str
    scoring_function: str
    
    def requires_context(self) -> bool:
        return self.module_type in ["multi-hop", "code"]

# Dataset configuration mapping, built from one row per dataset:
# (name, question field(s), answer field, solver module type, scorer name).
DATASET_CONFIGS = {
    name: DatasetConfig(
        question_key=q_field,
        answer_key=a_field,
        module_type=solver,
        scoring_function=scorer,
    )
    for name, q_field, a_field, solver, scorer in (
        ("gsm8k", "question", "answer", "math", "score_math"),
        ("math", "problem", "solution", "math", "score_math"),
        ("aime", "Problem", "Answer", "math", "score_math"),
        ("math500", "problem", "solution", "math", "score_math"),
        ("bbh", "input", "target", "multi-choice", "score_mc"),
        ("mmlu", ["Question", "A", "B", "C", "D"], "Answer", "multi-choice", "score_mc"),
        ("hotpotqa", "question", "answer", "multi-hop", "score_mh"),
        ("longbench", "input", "answers", "multi-hop", "score_mh"),
        ("mbpp", "text", "test_list", "code", "score_code"),
        ("livecodebench_atcoder", "question_content", "public_test_cases", "code", "score_code"),
        ("livecodebench_leetcode", "question_content", "public_test_cases", "code", "score_code"),
        ("livecodebench", "question_content", "public_test_cases", "code", "score_code"),
    )
}

class ExperimentRunner:
    """Run one solving mode over a slice of a dataset, score it, and write logs.

    All logs live under ``LOG_DIR`` formatted with the dataset, the model name
    and an interval label (``"full"`` or ``"<start>-<end>"``): one
    sub-directory per mode holding the per-run result files, plus a shared
    ``score.json`` summarising the runs.
    """

    # Modes this runner knows how to execute and score.
    SUPPORTED_MODES = ("atom", "direct", "decompose", "ensemble")

    def __init__(self, dataset: str, model: str, start: int = 0, end: int = -1, mode: str = "atom"):
        """Configure a run.

        Args:
            dataset: Key of ``DATASET_CONFIGS``.
            model: Model name passed to ``set_model`` and used in log paths.
            start: First dataset index (inclusive).
            end: Slice end index (exclusive); ``-1`` means "to the end".
            mode: One of ``SUPPORTED_MODES``.

        Raises:
            ValueError: If the dataset or mode is unsupported, or no solver
                module exists for the dataset's module type (e.g. the
                "multi-choice" datasets).
        """
        self.dataset = dataset
        self.model = model
        self.start = start
        self.end = None if end == -1 else end
        self.interval = "full" if self.end is None else f"{start}-{end}"
        self.timestamp = time.time()
        self.mode = mode
        # `method` duplicates `mode`; both attributes are read by the methods below.
        self.method = mode

        if dataset not in DATASET_CONFIGS:
            raise ValueError(f"Unsupported dataset: {dataset}")
        if mode not in self.SUPPORTED_MODES:
            raise ValueError(f"Unsupported mode: {mode}")
        self.config = DATASET_CONFIGS[dataset]

        # Only the solver this dataset needs is instantiated.
        module_classes = {
            "math": ModuleMath,
            "code": ModuleCode,
            "multi-hop": ModuleMhop,
        }
        module_class = module_classes.get(self.config.module_type)
        if module_class is None:
            raise ValueError(
                f"No solver module for module type '{self.config.module_type}' "
                f"(dataset: {dataset})"
            )
        self.module = module_class()
        set_model(model)

    # ------------------------------------------------------------------ helpers

    def _log_dir(self, method: str = "") -> str:
        """Return this run's log directory, or its ``method`` sub-directory."""
        base = LOG_DIR.format(dataset=self.dataset, model=self.model, size=self.interval)
        return f"{base}/{method}" if method else base

    def _question_text(self, item: Dict[str, Any]) -> str:
        """Extract the question text of ``item`` according to the dataset config."""
        key = self.config.question_key
        if isinstance(key, list):
            return self._format_question_from_keys(item, key)
        return item[key]

    def _build_context(self, item: Dict[str, Any]) -> Any:
        """Return the extra context the solver needs for ``item``, or None.

        Multi-hop datasets get the output of the multihop prompter's
        ``contexts``; code datasets get the first element of the answer field.
        """
        if self.config.module_type == "multi-hop":
            from experiment.prompter.multihop import contexts
            return contexts(item, self.dataset)
        if self.config.module_type == "code":
            return item[self.config.answer_key][0]
        return None

    @staticmethod
    def _summarize(method: str, raw: Dict[str, Any]) -> Dict[str, Any]:
        """Build the ``{"method", "response", "answer"}`` summary of a solver result."""
        return {
            "method": method,
            "response": raw.get("response", ""),
            "answer": raw.get("answer", ""),
        }

    @staticmethod
    def _find_entry_for_question(entries: Any, question: str) -> Optional[Dict[str, Any]]:
        """Return the first logged entry whose "problem" equals ``question``.

        ``entries`` is a previously saved results list (or None); returns None
        when it is missing/empty or has no matching entry.
        """
        if not entries:
            return None
        return next((entry for entry in entries if entry.get("problem") == question), None)

    def _load_run_log(self, method: str) -> Any:
        """Load ``<start>.json`` from the ``method`` log directory, or None if absent."""
        path = os.path.join(self._log_dir(method), f"{self.start}.json")
        return load_json(path) if os.path.exists(path) else None

    def _load_ensemble_logs(self) -> List[Any]:
        """Load up to ``ENSEMBLE_NUM`` saved direct-result files.

        Files are taken in sorted name order so the selection is reproducible
        (``os.listdir`` order is arbitrary). DAG logs are skipped.

        Raises:
            ValueError: If the direct log directory is missing or holds no
                result files.
        """
        direct_log_dir = self._log_dir("direct")
        if not os.path.exists(direct_log_dir):
            raise ValueError(f"Direct results directory not found: {direct_log_dir}")

        direct_files = sorted(
            name for name in os.listdir(direct_log_dir)
            if name.endswith('.json') and not name.endswith('dag.json')
        )[:ENSEMBLE_NUM]
        if not direct_files:
            raise ValueError(f"No direct result files found in {direct_log_dir}")

        return [load_json(os.path.join(direct_log_dir, name)) for name in direct_files]

    @staticmethod
    def _dag_log_path(log_file: str) -> str:
        """Return the DAG-log path paired with ``log_file`` (``X.json`` -> ``Xdag.json``).

        Only the final extension is stripped, so dots elsewhere in the path
        (e.g. a model name like ``Qwen2.5-7B``) are preserved.
        """
        return os.path.splitext(log_file)[0] + "dag.json"

    def _archive_dag_log(self, log_file: str) -> None:
        """Move the working-directory ``dag.json`` next to ``log_file``.

        A missing ``dag.json`` is reported instead of aborting, so the score
        log is still updated for results that were already saved.
        """
        dag_source = "dag.json"
        if not os.path.exists(dag_source):
            print(f"Warning: {dag_source} not found; no DAG log saved for {log_file}")
            return
        shutil.move(dag_source, self._dag_log_path(log_file))

    # --------------------------------------------------------------- pipeline

    async def gather_results(self, testset: List[Dict[str, Any]]) -> List[Any]:
        """Solve every item of ``testset`` concurrently with the current mode.

        Returns one dict per item, keyed by the mode name. For direct,
        decompose and ensemble the value is a ``(summary, raw)`` tuple; for
        atom it is whatever ``self.module.atom`` returns.

        Raises:
            ValueError: In ensemble mode, if no saved direct results exist or
                an item has no matching saved direct result.
        """
        set_module(self.config.module_type)

        # Saved logs are read once here instead of once per item.
        ensemble_logs: List[Any] = []
        direct_log = decompose_log = None
        if self.method == "ensemble" and testset:
            ensemble_logs = self._load_ensemble_logs()
        elif self.method == "atom":
            direct_log = self._load_run_log("direct")
            decompose_log = self._load_run_log("decompose")

        async def process_single_item(item):
            question = self._question_text(item)
            context = self._build_context(item)
            results = {}

            if self.method == "direct":
                direct_result = await self.module.direct(question=question, contexts=context)
                results["direct"] = (self._summarize("direct", direct_result), direct_result)

            elif self.method == "decompose":
                # Unlike direct, decompose only receives `contexts` when one exists.
                decompose_args = {"contexts": context} if context else {}
                decompose_result = await self.module.decompose(question=question, **decompose_args)
                results["decompose"] = (self._summarize("decompose", decompose_result), decompose_result)

            elif self.method == "ensemble":
                direct_results = []
                for entries in ensemble_logs:
                    match = self._find_entry_for_question(entries, question)
                    if match is not None:
                        direct_results.append({
                            "response": match.get("response", ""),
                            "answer": match.get("answer", ""),
                        })
                if not direct_results:
                    raise ValueError(f"No matching direct results found for question: {question[:50]}...")

                ensemble_result = await self.module.ensemble(question=question, solutions=direct_results)
                results["ensemble"] = (self._summarize("ensemble", ensemble_result),
                                       {"direct_results": direct_results})

            elif self.method == "atom":
                atom_kwargs = {
                    "question": question,
                    "contexts": context,
                    # Saved results for *this* question (None when no log/match exists).
                    "direct_result": self._find_entry_for_question(direct_log, question),
                    "decompose_result": self._find_entry_for_question(decompose_log, question),
                }
                if self.config.module_type == "code":
                    atom_kwargs["test_cases"] = item.get(self.config.answer_key)
                results["atom"] = await self.module.atom(**atom_kwargs)

            return results

        tasks = [process_single_item(item) for item in testset]
        return await tqdm.gather(*tasks, desc=f"Processing {self.dataset} tasks")

    def _format_question_from_keys(self, item: Dict[str, Any], keys: List[str]) -> str:
        """Join ``"key: value"`` lines for every key in ``keys`` present in ``item``."""
        return "\n".join(f"{key}: {item[key]}" for key in keys if key in item)

    def construct_entry(self, result: Tuple[Dict[str, Any], Any], data: Dict[str, Any], method: str) -> Dict[str, Any]:
        """Build and score the log entry for one item.

        Args:
            result: ``(result_data, log)``; for atom mode on code datasets a
                third element carries test cases that replace the groundtruth.
            data: The original dataset item.
            method: Mode name selecting which extra fields are recorded.

        Returns:
            Dict with problem, groundtruth, response, answer, score, plus
            "sub-questions" (decompose), "log" (atom) or "direct_results"
            (ensemble).
        """
        test_cases = None
        if self.mode == "atom" and self.config.module_type == "code":
            result_data, log, test_cases = result
        else:
            result_data, log = result

        groundtruth = data[self.config.answer_key]
        if test_cases:
            groundtruth = test_cases

        entry = {
            "problem": self._question_text(data),
            "groundtruth": groundtruth,
            "response": result_data.get("response"),
        }
        # Field order matches the historical JSON layout of each mode.
        if method == "decompose":
            entry["sub-questions"] = log.get("sub-questions")
        entry["answer"] = result_data.get("answer")
        if method == "atom":
            entry["log"] = log
        elif method == "ensemble":
            entry["direct_results"] = log.get("direct_results")

        scorer_name = self.config.scoring_function
        scoring_function = getattr(importlib.import_module("experiment.utils"), scorer_name)

        # Scorers take different extra arguments.
        if scorer_name == "score_math":
            entry["score"] = scoring_function(entry["answer"], groundtruth, self.dataset)
        elif scorer_name == "score_code":
            entry["score"] = scoring_function(entry["answer"], groundtruth, entry["problem"])
        else:
            entry["score"] = scoring_function(entry["answer"], groundtruth)
        return entry

    def update_score_log(self, accuracies: Dict[str, float]) -> None:
        """Record this run's accuracy, token usage and call count in score.json.

        The entry is stored under ``[method][str(count)]``, where ``count`` is
        what ``get_file_count`` reports for the method's log directory.
        """
        score_log_file = self._log_dir() + "/score.json"
        os.makedirs(os.path.dirname(score_log_file), exist_ok=True)

        existing_log = load_json(score_log_file) if os.path.exists(score_log_file) else {}

        count = get_file_count(
            self._log_dir(self.method),
            self.interval,
            self.dataset,
            exclude_score=True
        )

        tokens = get_token()
        log_entry = {
            "start": self.start,
            "end": self.end,
            "token": {"prompt": tokens[0], "completion": tokens[1]},
            "call_count": get_call_count(),
            "accuracy": accuracies[self.method],
        }
        existing_log.setdefault(self.method, {})[str(count)] = log_entry

        save_json(score_log_file, existing_log)

    async def run(self) -> float:
        """Run the experiment end to end and return its accuracy.

        Raises:
            ValueError: If the requested slice of the test set is empty.
        """
        print(f"Running {self.mode} experiment on {self.dataset} dataset from index {self.start} to {self.end}")

        testset = load_data(self.dataset, "test")[self.start:self.end]
        if len(testset) == 0:
            raise ValueError(
                f"No test items for dataset {self.dataset} in range {self.start}:{self.end}"
            )
        results = await self.gather_results(testset)

        # Extract and score the results for the current method.
        method_results = [result[self.method] for result in results]
        json_obj = [self.construct_entry(result, data, self.method)
                    for result, data in zip(method_results, testset)]
        accuracy = sum(entry["score"] for entry in json_obj) / len(json_obj)
        accuracies = {self.method: accuracy}

        # Create the method's log directory.
        method_dir = self._log_dir(self.method)
        os.makedirs(method_dir, exist_ok=True)

        # Get the next log file name and save the entries.
        log_file = get_next_log_file(
            method_dir,
            self.interval,
            self.dataset
        )
        save_json(log_file, json_obj)
        if self.mode == "atom":
            self._archive_dag_log(log_file)

        # Update the shared score.json.
        self.update_score_log(accuracies)

        # Print summary
        print(f"\n{self.method.upper()} Results:")
        print(f"Unsolved: {round((1-accuracy) * len(testset))}")
        print(f"Accuracy: {accuracy:.4f}")
        if self.mode == "atom":
            print(f"Iteration_count: {get_iter_count()}")

        print(f"\nTime taken: {duration_formatter(time.time() - self.timestamp)}")

        return accuracy  # Return accuracy for the specified mode

async def main():
    """Parse command-line options and run a single experiment.

    ``--ensemble_num`` replaces the module-level ``ENSEMBLE_NUM``, which caps
    how many saved direct-result files ensemble mode combines.
    """
    # Declared before ENSEMBLE_NUM is read below (required by Python).
    global ENSEMBLE_NUM

    parser = argparse.ArgumentParser(description='Run experiments on various datasets')
    parser.add_argument('--dataset', type=str, default='math',
                        choices=list(DATASET_CONFIGS.keys()),
                        help='Dataset to run experiment on')
    parser.add_argument('--start', type=int, default=0,
                        help='Start index of the dataset')
    parser.add_argument('--end', type=int, default=500,
                        help='End index of the dataset (-1 for all)')
    parser.add_argument('--model', type=str, default='gpt-4o-mini',
                        help='Model to use for the experiment')
    parser.add_argument('--mode', type=str,
                        choices=['atom', 'direct', 'decompose', 'ensemble'], default='direct',
                        help='Mode: atom (standard experiment), direct (direct solving), '
                             'decompose (decomposition only), or ensemble (combine saved direct results)')
    parser.add_argument('--ensemble_num', type=int, default=ENSEMBLE_NUM,
                        help='Maximum number of saved direct-result files combined in ensemble mode')

    args = parser.parse_args()
    if args.ensemble_num < 1:
        parser.error("--ensemble_num must be at least 1")
    ENSEMBLE_NUM = args.ensemble_num
    print(args.dataset, args.model, args.start, args.end, args.mode)

    # Mode validity is enforced by argparse `choices` and by ExperimentRunner.
    runner = ExperimentRunner(
        dataset=args.dataset,
        model=args.model,
        start=args.start,
        end=args.end,
        mode=args.mode
    )
    await runner.run()

if __name__ == "__main__":
    asyncio.run(main())