import os
import copy
import logging
import pandas as pd
from tqdm.asyncio import tqdm_asyncio
from pathlib import Path
from datetime import datetime
from LLM_call import LLMModel
from LLM_TM import LLMTM
from LLM_CG import LLMCG, CodeRunner
from Evaluator import Evaluator
from PR_predictor import PassRatePredictor
from collections import defaultdict
from typing import Dict, List
from tqdm import tqdm

# Create the module-level Logger instance.
# NOTE(review): the methods visible in this file call `logging.info(...)` on the
# root logger rather than this `logger` object — confirm whether it is used elsewhere.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Define the log record format shared by both handlers below.
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# File handler: writes DEBUG and above to log/<timestamp>.log.
# The "log" directory is created at import time if it does not exist.
log_dir = Path("log")
log_dir.mkdir(exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_filename = f"{timestamp}.log"
log_path = log_dir / log_filename
file_handler = logging.FileHandler(log_path)
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(formatter)

# Console handler: writes INFO and above to the default stream.
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)

# Attach both handlers to the root logger.
# NOTE(review): both handlers already carry `formatter`, and basicConfig only
# assigns its own formatter to handlers that lack one, so `datefmt` here
# presumably has no visible effect — confirm against the logging docs.
logging.basicConfig(
    level=logging.DEBUG,
    datefmt="%Y-%m-%d %H:%M:%S",
    handlers=[
        file_handler,   # file output
        console_handler # console output
    ]
)

class LCDP():
    def __init__(self, api_key, model="gpt-3.5-turbo", max_workers=5, ignore_advice=False, use_pr_predictor=True, use_web_search=False):
        """Create the pipeline's collaborators and blank per-run state.

        Args:
            api_key: Forwarded to ``LLMModel``.
            model: Model name forwarded to ``LLMModel``.
            max_workers: Forwarded to ``CodeRunner``.
            ignore_advice: Stored on the instance; when true, ``run`` skips
                the interactive plan refinement and user-feedback prompts.
            use_pr_predictor: When true a ``PassRatePredictor`` is created
                and handed to the ``Evaluator``; otherwise the evaluator
                receives ``None``.
            use_web_search: Forwarded to ``LLMModel``.
        """
        # Long-lived collaborators.
        self.llm_model = LLMModel(api_key, model, use_web_search=use_web_search)
        self.code_runner = CodeRunner(max_workers=max_workers)

        # Optional pass-rate predictor and the evaluator built on top of it.
        self.use_pr_predictor = use_pr_predictor
        self.pass_rate_predictor = PassRatePredictor() if use_pr_predictor else None
        self.evaluator = Evaluator(self.pass_rate_predictor)

        # Behaviour switch consulted by `run`.
        self.ignore_advice = ignore_advice

        # Per-run state; `initialize` resets these again at the start of `run`.
        self.task_description = None
        self.current_plan = None
        self.test_timeout = None
        self.test_weights = {}
        self.test_cases = {}

    def initialize(self):
        """Reset the per-run state so the instance can be reused by ``run``.

        When the predictor is enabled, a fresh ``PassRatePredictor`` and a
        fresh ``Evaluator`` wrapping it are created. When it is disabled the
        predictor is set to ``None`` and the existing evaluator is left as is.
        """
        if self.use_pr_predictor:
            fresh_predictor = PassRatePredictor()
            self.pass_rate_predictor = fresh_predictor
            self.evaluator = Evaluator(fresh_predictor)
        else:
            self.pass_rate_predictor = None

        # Forget everything that belonged to the previous task.
        self.task_description = self.current_plan = self.test_timeout = None
        self.test_weights = {}
        self.test_cases = {}

    async def run(self, task_description, max_iterations=3, stop_t=0.8, example_dataset=None,
                 num_plans=3, num_tests=5, num_codes=5, refine_rounds=3, use_pass_rate_for_train=False, test_timeout=None, min_tests=15, max_tests=20, use_async_generation=True, ai4s=False, use_llm_for_refine=False, best_only=False, error_test_num=3, use_data_format_extract=[], record_all_results=False, forced_test_cases=None):
        """Run the full pipeline: plan, generate tests, then iterate on code.

        The method resets per-run state, builds the task manager (``LLMTM``)
        and code generator (``LLMCG``), and then works through three phases:

        1. Plan generation, optional interactive refinement, and format
           normalisation.
        2. Test case generation (or use of ``forced_test_cases``) followed by
           weight calculation.
        3. Up to ``max_iterations`` rounds of code generation, evaluation,
           optional predictor training and top-code selection.

        Args:
            task_description: Task text given to the LLM helpers.
            max_iterations: Upper bound on Phase 3 rounds.
            stop_t: The loop stops once the first best code's ``'pass_rate'``
                is at least this value.
            example_dataset: Forwarded to ``_calculate_test_weights``.
            num_plans: Not referenced in this method body.
            num_tests: Number of test-generation calls in Phase 2.
            num_codes: Number of code-generation calls per iteration.
            refine_rounds: Maximum rounds for ``_plan_refinement_loop``.
            use_pass_rate_for_train: Forwarded to the predictor's ``add_data``.
            test_timeout: Stored on ``self.test_timeout``.
            min_tests: Below this many test cases, up to 3 more generation
                calls are made in an iteration.
            max_tests: Above this many test cases, only the highest-weighted
                ``max_tests`` are kept.
            use_async_generation: Selects the async or the sync generation
                helpers for both tests and codes.
            ai4s: When true, the task description is refined via
                ``refine_plan_by_data_analysis`` before planning; also
                forwarded to ``_evaluate_codes``.
            use_llm_for_refine: Forwarded to ``refine_plan_by_data_analysis``.
            best_only: Forwarded to the code generation helpers.
            error_test_num: Forwarded to the code generation helpers.
            use_data_format_extract: File paths whose column names are
                appended to the task description (only consulted when
                ``ai4s`` is true). The default list is never mutated here.
            record_all_results: When true, return a dict mapping iteration
                index to that iteration's best codes instead of only the
                final best codes.
            forced_test_cases: When not ``None``, used as the test cases
                as-is; generation and ``_filter_test_cases`` are skipped.

        Returns:
            The result of the last ``_select_top_codes`` call (an empty dict
            if no iteration ran), or the per-iteration dict when
            ``record_all_results`` is true.

        Raises:
            Exception: Re-raised after printing diagnostics if the best
                code's ``'pass_rate'`` cannot be read.
        """
        
        self.initialize()
        # debug
        self.test_cases = {}
        # print("debug0 ####################################################################")
        # print(self.test_cases)

        self.test_timeout = test_timeout
        self.task_description = task_description
        
        # Initialize LLM Task Manager
        self.llmtm = LLMTM(task_description, self.llm_model)
        self.llmcg = LLMCG(task_description, self.llm_model)

        # result_dict only exists when record_all_results is true; every later
        # use is guarded by the same flag.
        if record_all_results:
            result_dict = {}

        if ai4s:
            # refine the task description
            if use_data_format_extract != []:
                data_format_info = {}
                for file_path in use_data_format_extract:
                    column_names = self.get_column_names(file_path)
                    # Files that could not be read are silently left out.
                    if column_names is not None:
                        data_format_info[file_path] = column_names
                if data_format_info:
                    data_format_text = self.format_file_columns_text(data_format_info)
                    self.task_description = f"{self.task_description}\n\n{data_format_text}"
            self.task_description = self.llmtm.refine_plan_by_data_analysis(self.task_description, use_llm=use_llm_for_refine)
            # print("debug")
            # print(self.task_description)
            # Push the refined description into both helpers.
            self.llmtm.task_description = self.task_description
            self.llmcg.task_description = self.task_description
        
        # Phase 1: Plan Generation and Refinement
        # print("########################################################################")
        # print("### Phase 1: Plan Generation and Refinement")
        logging.info("########################################################################")
        logging.info("### Phase 1: Plan Generation and Refinement")
        plan, plan_raw = self.llmtm.get_plan()
        if self.ignore_advice:
            self.current_plan = plan
        else:
            self.current_plan = await self._plan_refinement_loop(self.llmtm, plan_raw, refine_rounds)
        self.current_plan = self._plan_format_refinement(self.current_plan)

        # if ai4s:
        #     return self.current_plan
        
        # Phase 2: Test Case Generation and Weighting
        # print("\n########################################################################")
        # print("### Phase 2: Test Case Generation and Weighting")
        logging.info("\n########################################################################")
        logging.info("### Phase 2: Test Case Generation and Weighting")
        # self.test_cases = await self._generate_tests(self.llmtm, num_tests)
        # debug
        # if not ai4s:
        # NOTE(review): test generation receives the `task_description` argument,
        # not the possibly refined `self.task_description` — confirm this is intended.
        if forced_test_cases is not None:
            self.test_cases = forced_test_cases
        else:
            if use_async_generation:
                self.test_cases = await self._generate_tests_async(num_tests, task_description, use_example=False, original_test_cases={})
            else:
                self.test_cases = self._generate_tests(num_tests, task_description, use_example=False, original_test_cases={})
            self.test_cases = self._filter_test_cases(self.test_cases)

        # print("Calculating test weights...")
        logging.info("Calculating test weights...")
        self.test_weights = self._calculate_test_weights(self.test_cases, example_dataset)
        
        # Phase 3: Iterative Code Generation
        # print("\n########################################################################")
        # print("### Phase 3: Iterative Code Generation")
        logging.info("\n########################################################################")
        logging.info("### Phase 3: Iterative Code Generation")
        best_codes = {}

        # if ai4s:
        #     prompt = self._generate_codes(10, self.current_plan, None, best_codes, best_only=best_only, error_test_num=error_test_num, prompt_only=False)
        #     return prompt

        for iteration in range(max_iterations):
            # print(f"\n=== Iteration {iteration+1}/{max_iterations} ===")
            logging.info(f"\n=== Iteration {iteration+1}/{max_iterations} ===")

            # TODO: based on best code recordings, refine the plan
            # self.current_plan = await self._plan_refinement_loop(self.llmtm, plan_raw, refine_rounds)

            # print("debug1 ####################################################################")
            # print(self.test_cases)

            # if the test_cases are less than min_tests, generate more tests
            if len(self.test_cases) < min_tests:
                # At most 3 extra generation calls per iteration.
                num_tests_to_gen = min((min_tests - len(self.test_cases)), 3)
                if use_async_generation:
                    self.test_cases = await self._generate_tests_async(num_tests_to_gen, task_description, use_example=False, original_test_cases=self.test_cases)
                else:
                    self.test_cases = self._generate_tests(num_tests_to_gen, task_description, use_example=False, original_test_cases=self.test_cases)

                self.test_cases = self._filter_test_cases(self.test_cases)
                logging.info(f"Generated {num_tests_to_gen} new test cases.")
                # debug
                test_case_number = len(self.test_cases)
                print(f"Generated {num_tests_to_gen} new test cases, total {test_case_number} test cases.")
                # Recalculate test weights
                logging.info("Recalculating test weights...")
                self.test_weights = self._calculate_test_weights(self.test_cases, example_dataset, test_results=None, previous_weights=self.test_weights)

            # print("debug2 ####################################################################")
            # print(self.test_cases)

            if len(self.test_cases) > max_tests:
                # only keep the top max_tests test cases
                sorted_test_cases = sorted(self.test_cases.items(), key=lambda x: self.test_weights[x[0]], reverse=True)
                self.test_cases = dict(sorted_test_cases[:max_tests])
                self.test_weights = {k: self.test_weights[k] for k in self.test_cases.keys()}
                logging.info(f"Filtered test cases to {max_tests}.")
                # debug
                print(f"Filtered test cases to {max_tests}.")
            
            # Generate new codes
            # new_codes = await self._generate_codes(num_codes)
            # print("debug3 ####################################################################")
            # print(self.test_cases)

            if use_async_generation:
                new_codes = await self._generate_codes_async(num_codes, self.current_plan, self.test_cases, best_codes, best_only=best_only, error_test_num=error_test_num)
            else:
                new_codes = self._generate_codes(num_codes, self.current_plan, self.test_cases, best_codes, best_only=best_only, error_test_num=error_test_num)

            # Evaluate codes
            logging.info("Evaluating codes...")
            scored_codes, filtered_test_result = self._evaluate_codes(new_codes, ai4s=ai4s)
            # remove the test cases that are not in the filtered_test_result
            self.test_cases = {k: v for k, v in self.test_cases.items() if k in list(filtered_test_result.keys())}

            logging.info("training pass_rate_predictor...")
            # Training is skipped entirely when no predictor is configured.
            if self.pass_rate_predictor is None:
                pass
            else:
                self.pass_rate_predictor.add_data(scored_codes, use_pass_rate=use_pass_rate_for_train)
                self.pass_rate_predictor.train_model(epochs=50, batch_size=32, lr=0.001)
            
            # Update best codes
            best_codes = self._select_top_codes(scored_codes, top_k=3)

            if record_all_results:
                result_dict[iteration] = best_codes
            
            # User feedback
            # A generic feedback string is stored when advice is ignored or when
            # _get_user_feedback returns a falsy value.
            if self.ignore_advice:
                self.current_plan['user_feedback'] = "Based on previous outputs, please improve the code quality."
            elif not await self._get_user_feedback(best_codes):
                self.current_plan['user_feedback'] = "Based on previous outputs, please improve the code quality."

            # check the best code's score (first one)
            try:
                best_code_score = best_codes[list(best_codes.keys())[0]]['pass_rate']
            except Exception as e:
                # Dump everything useful for diagnosing an empty or malformed
                # best_codes before propagating the error.
                print(f"Error: {e}")
                print(best_codes)
                print("########################################################################")
                print(new_codes)
                print("########################################################################")
                print(scored_codes)
                raise e
            if best_code_score >= stop_t:
                # print("Best code's score is high enough, stopping iterations.")
                logging.info("Best code's score is high enough, stopping iterations.")
                break
        if record_all_results:
            return result_dict
        return best_codes

    def get_column_names(self, file_path):
        """Return the column names of a CSV or Excel file.

        Only the header is requested from pandas (``nrows=0``), so the data
        rows of a large file are not loaded just to learn its schema.

        Args:
            file_path: Path to a ``.csv``, ``.xls`` or ``.xlsx`` file.

        Returns:
            A list of column names, or ``None`` when the file is missing,
            has an unsupported extension, is empty, or cannot be read. The
            reason is printed in every failure case.
        """
        try:
            # Check if the file exists
            if not os.path.exists(file_path):
                print(f"Error: File not found at '{file_path}'")
                return None

            # Dispatch on the lower-cased extension
            _, file_extension = os.path.splitext(file_path)
            file_extension = file_extension.lower()

            if file_extension == '.csv':
                header_only = pd.read_csv(file_path, nrows=0)
            elif file_extension in ('.xls', '.xlsx'):
                header_only = pd.read_excel(file_path, nrows=0)
            else:
                print(f"Error: Unsupported file type '{file_extension}'. "
                    "This function supports '.csv', '.xls', and '.xlsx' files.")
                return None

            return header_only.columns.tolist()

        except pd.errors.EmptyDataError:
            print(f"Error: The file '{file_path}' is empty.")
            return None
        except pd.errors.ParserError:
            print(f"Error: Could not parse the file '{file_path}'. It might be corrupted or not a valid {file_extension} file.")
            return None
        except Exception as e:
            # Best-effort helper: any other failure is reported, not raised.
            print(f"An unexpected error occurred: {e}")
            return None

    def format_file_columns_text(self, file_to_columns_map):
        """Describe which files share which set of columns, one line per set.

        Files with the same set of column names (order ignored, ``None``
        treated as no columns) are listed together on one line. Paths within
        a line, columns within a line, and the lines themselves are sorted.

        Args:
            file_to_columns_map: Mapping of file path to a list of column
                names (or ``None``).

        Returns:
            The multi-line description, or a fixed notice when the mapping
            is empty.
        """
        if not file_to_columns_map:
            return "No file information provided."

        # Group file paths by their (unordered) column set.
        files_by_column_set = {}
        for path, column_names in file_to_columns_map.items():
            signature = frozenset([] if column_names is None else column_names)
            files_by_column_set.setdefault(signature, []).append(path)

        descriptions = [
            f"{', '.join(sorted(paths))} contains following columns: {sorted(signature)}"
            for signature, paths in files_by_column_set.items()
        ]
        descriptions.sort()
        return "\n".join(descriptions)

    async def _plan_refinement_loop(self, llmtm, initial_plan_raw, max_rounds):
        """Interactively refine a plan for at most ``max_rounds`` rounds.

        Each round shows the current overall plan (logged and printed), asks
        on stdin whether to refine it, and if so forwards the typed feedback
        to ``llmtm.refine_plan``. Any answer other than ``y`` ends the loop.

        Args:
            llmtm: Object providing ``extract_plan`` and ``refine_plan``.
            initial_plan_raw: Raw plan to start from.
            max_rounds: Maximum number of refinement rounds.

        Returns:
            The plan extracted from the latest raw plan.
        """
        raw_plan = initial_plan_raw
        plan = llmtm.extract_plan(raw_plan)

        for _round in range(max_rounds):
            # Show the plan as it currently stands.
            rendered = "Current Plan:\n" + self.plan_json_to_str(plan["overall_plan"])
            logging.info(rendered)
            print(rendered)

            # Ask whether another refinement round is wanted.
            wants_refinement = input("Refine plan? (y/n): ").lower() == 'y'
            if not wants_refinement:
                notice = "Skipping plan refinement."
                logging.info(notice)
                print(notice)
                break

            feedback = input("Enter refinement feedback: ")
            echoed = f"User feedback: {feedback}"
            logging.info(echoed)
            print(echoed)
            plan, raw_plan = llmtm.refine_plan(feedback, raw_plan)

        # Always re-extract from the final raw plan.
        return llmtm.extract_plan(raw_plan)

    def plan_json_to_str(self, plan):
        """Render an overall-plan dict as a human-readable multi-line string.

        Args:
            plan: Dict with ``input_format`` and ``output_format`` (iterables
                of ``(dtype, shape)`` pairs), ``components``, ``plan`` and
                ``test_case_generation_advise`` (iterables of strings).

        Returns:
            The formatted text, ending with two newline characters.
        """
        def bullet_rows(entries, label):
            # One "- <label> <n>: <dtype> with <shape>" row per format entry.
            rows = []
            for position, (dtype, shape) in enumerate(entries, 1):
                shape_text = "no fixed shape" if shape is None else f"shape={shape}"
                rows.append(f"- {label} {position}: {dtype} with {shape_text}")
            return rows

        input_section = "Input Format:\n" + "\n".join(bullet_rows(plan["input_format"], "Argument"))
        output_section = "Output Format:\n" + "\n".join(bullet_rows(plan["output_format"], "Output"))

        # Assemble the report top to bottom.
        report = ["=== Current Plan ===", input_section, output_section]
        report.append(f"Components Order: {', '.join(plan['components'])}")
        report.append("Plan Steps:")
        report.extend(f"- {step}" for step in plan["plan"])
        report.append("Overall Test Case Advice:")
        report.extend(f"- {advice}" for advice in plan["test_case_generation_advise"])
        report.append("\n")

        return "\n".join(report)

    def _plan_format_refinement(self, plan_dict):
        """Return a deep copy of the plan whose format fields are lists of lists.

        The ``input_format`` / ``output_format`` entries of every component
        and of ``overall_plan`` (when present and non-empty) are wrapped in
        an outer list unless they already are a list containing only lists.
        The input dict is not modified.
        """
        format_keys = ("input_format", "output_format")

        def as_list_of_lists(value):
            # Already a list whose items are all lists: keep it untouched.
            nested = isinstance(value, list) and all(isinstance(item, list) for item in value)
            return value if nested else [value]

        refined = copy.deepcopy(plan_dict)

        # Every component first, then the overall plan if there is one.
        sections = list(refined["components"].values())
        overall = refined.get("overall_plan")
        if overall:
            sections.append(overall)

        for section in sections:
            for key in format_keys:
                if key in section:
                    section[key] = as_list_of_lists(section[key])

        return refined
    
    async def _generate_tests_async(self, num_tests, task_description, use_example=True, original_test_cases=None):
        """Generate test cases concurrently and merge them into one dict.

        ``num_tests`` generation requests are issued via
        ``self.llmtm.get_test_cases_async`` and merged as they complete.
        A generated key that already exists gets a numeric suffix
        (``key_1``, ``key_2``, ...) so nothing is overwritten.

        Args:
            num_tests: Number of generation requests to issue.
            task_description: Forwarded to the task manager.
            use_example: Forwarded to the task manager.
            original_test_cases: Existing test cases to extend. A dict passed
                here is updated in place and returned. When omitted, a fresh
                dict is created for this call; the previous ``{}`` default
                was shared between calls and leaked test cases from one call
                into the next.

        Returns:
            The merged test case dict.
        """
        test_cases = {} if original_test_cases is None else original_test_cases
        task_list = [
            self.llmtm.get_test_cases_async(
                self.current_plan['overall_plan'],
                task_description=task_description,
                use_example=use_example,
            )
            for _ in range(num_tests)
        ]

        for task in tqdm_asyncio.as_completed(task_list, total=num_tests, desc="Generating async tests"):
            test = await task
            for key, value in test.items():
                # Find a key name that is not taken yet.
                new_key = key
                suffix = 1
                while new_key in test_cases:
                    new_key = f"{key}_{suffix}"
                    suffix += 1
                test_cases[new_key] = value

        return test_cases
    
    def _generate_tests(self, num_tests, task_description, use_example=True, original_test_cases=None):
        """Generate test cases sequentially and merge them into one dict.

        ``self.llmtm.get_test_cases`` is called ``num_tests`` times. A
        generated key that already exists gets a numeric suffix (``key_1``,
        ``key_2``, ...) so nothing is overwritten.

        Args:
            num_tests: Number of generation calls to make.
            task_description: Forwarded to the task manager.
            use_example: Forwarded to the task manager.
            original_test_cases: Existing test cases to extend. A dict passed
                here is updated in place and returned. When omitted, a fresh
                dict is created for this call; the previous ``{}`` default
                was shared between calls and leaked test cases from one call
                into the next.

        Returns:
            The merged test case dict.
        """
        test_cases = {} if original_test_cases is None else original_test_cases

        for _ in range(num_tests):
            test = self.llmtm.get_test_cases(self.current_plan['overall_plan'], task_description=task_description, use_example=use_example)
            for key, value in test.items():
                # Find a key name that is not taken yet.
                new_key = key
                suffix = 1
                while new_key in test_cases:
                    new_key = f"{key}_{suffix}"
                    suffix += 1
                test_cases[new_key] = value

        return test_cases

    # def _filter_test_cases(self, dataset):
    #     # print(dataset)
    #     runnable_entries = {}
    #     for code_id, attributes in dataset.items():
    #         test_code = attributes.get("test_function", "")
    #         try:
    #             # Attempt to compile the code string to check for syntax errors.
    #             compile(test_code, "<string>", "exec")
    #             # If no exception is raised, consider the code as runnable.
    #             runnable_entries[code_id] = attributes
    #         except Exception as error:
    #             # If an exception is raised, skip this entry.
    #             continue
    #     return runnable_entries
    
    def _filter_test_cases(self, dataset):
        # print(dataset)
        runnable_entries = {}
        for code_id, attributes in dataset.items():
            test_code = attributes.get("test_function", "")
            try:
                # Attempt to compile the code string to check for syntax errors.
                # compile(test_code, "<string>", "exec")
                local_vars = {}
                exec(test_code, local_vars)
                for name, obj in local_vars.items():
                    if callable(obj):
                        runnable_entries[code_id] = attributes
                        break
                    else:
                        # If the object is not callable, skip this entry.
                        continue
                # If no exception is raised, consider the code as runnable.
                runnable_entries[code_id] = attributes
            except Exception as error:
                # If an exception is raised, skip this entry.
                continue
        return runnable_entries
            
    def _calculate_test_weights(self, test_cases, example_dataset, test_results=None, previous_weights=None):
        """Compute a weight for every test case.

        With no example dataset, no test results and no previous weights,
        every test case gets weight 1.0. Otherwise results are obtained
        (by running ``example_dataset`` through the code runner when none
        were supplied), each reported test case is re-weighted through
        ``update_test_case_weight``, and any test case still lacking a
        weight defaults to 1.0.

        Args:
            test_cases: Mapping of test case id to its info dict.
            example_dataset: Codes to run against the tests, or ``None``.
            test_results: Optional precomputed per-test results.
            previous_weights: Optional weights to update; when given, this
                dict is modified in place and returned.

        Returns:
            Mapping of test case id to weight.
        """
        if not (example_dataset or test_results or previous_weights):
            return dict.fromkeys(test_cases, 1.0)

        # Obtain results from the example dataset when none were supplied.
        if test_results is None and example_dataset is not None:
            _, test_results = self.code_runner.run_all_tests(example_dataset, test_cases)

        weights = {} if previous_weights is None else previous_weights
        type_ratio = self.compute_test_type_ratios(test_cases)

        if test_results is not None:
            for tid, results in test_results.items():
                weights[tid] = self.update_test_case_weight(
                    weights.get(tid, 1.0), results, test_cases[tid], type_ratio
                )

        # Anything without results yet starts from the neutral weight.
        for tid in test_cases:
            weights.setdefault(tid, 1.0)

        return weights

    def update_test_case_weight(self, previous_weight, current_test_result, test_case_info, type_ratio):
        """Re-weight one test case from the latest round of results.

        The new weight is ``previous_weight`` multiplied by three factors:
        a pass-rate multiplier (rewarding tests that separate good from bad
        codes), a type-rarity modifier (rarer ``test_type`` scores higher)
        and a streak punishment for tests that keep passing or failing for
        every code. The result is floored at 0.001.

        Side effect: the ``all_pass_times`` / ``all_fail_times`` counters in
        ``test_case_info`` are updated in place.

        Args:
            previous_weight: Weight before this round.
            current_test_result: Mapping of code id to a result dict with a
                ``"success"`` entry.
            test_case_info: Info dict of the test case; must contain
                ``test_type`` and the two streak counters.
            type_ratio: Mapping of test type to its share of all test cases.

        Returns:
            The updated weight, or ``previous_weight`` unchanged when no
            code was run against this test case.
        """
        FEW_PASS_THRESHOLD = 0.2   # below: test may be flawed or too hard
        MOST_PASS_THRESHOLD = 0.8  # above: test may be too trivial
        MIN_WEIGHT = 0.001         # weights never drop below this floor

        num_codes = len(current_test_result)
        if not num_codes:
            return previous_weight

        num_passed = sum(1 for outcome in current_test_result.values() if outcome["success"])
        pass_rate = num_passed / num_codes
        everyone_passed = pass_rate == 1.0
        nobody_passed = pass_rate == 0.0

        # Track consecutive all-pass / all-fail rounds.
        if everyone_passed:
            test_case_info['all_pass_times'] += 1
            test_case_info['all_fail_times'] = 0
        elif nobody_passed:
            test_case_info['all_fail_times'] += 1
            test_case_info['all_pass_times'] = 0
        else:
            test_case_info['all_pass_times'] = 0
            test_case_info['all_fail_times'] = 0

        # Streak punishment: only kicks in from the second consecutive round.
        if test_case_info['all_pass_times'] > 1:
            punishment_factor = 0.5 ** (test_case_info['all_pass_times'] - 1)
        elif test_case_info['all_fail_times'] > 1:
            punishment_factor = 0.3 ** (test_case_info['all_fail_times'] - 1)
        else:
            punishment_factor = 1.0

        # Pass-rate multiplier.
        if num_codes == 1:
            # A single code cannot show discrimination in this round.
            pass_rate_multiplier = 0.2
        elif everyone_passed or nobody_passed:
            first_occurrence = (
                test_case_info['all_pass_times'] <= 1
                and test_case_info['all_fail_times'] <= 1
            )
            pass_rate_multiplier = 0.5 if first_occurrence else 0.1
        elif pass_rate < FEW_PASS_THRESHOLD:
            # Ramps from 0.1 (near 0% pass) to 0.5 (at the lower threshold).
            pass_rate_multiplier = 0.1 + (pass_rate / FEW_PASS_THRESHOLD) * 0.4
        elif pass_rate > MOST_PASS_THRESHOLD:
            # Ramps from 0.5 (at the upper threshold) down to 0.1 (near 100%).
            pass_rate_multiplier = 0.1 + ((1.0 - pass_rate) / (1.0 - MOST_PASS_THRESHOLD)) * 0.4
        else:
            # Discriminating band: 2.0 at the lower threshold, 1.0 at the upper.
            band_position = (pass_rate - FEW_PASS_THRESHOLD) / (MOST_PASS_THRESHOLD - FEW_PASS_THRESHOLD)
            pass_rate_multiplier = 2.0 - band_position * 1.0

        # Rarer test types get a larger modifier; unknown types count as 0.2.
        ratio_value = type_ratio.get(test_case_info['test_type'], 0.2)
        type_ratio_modifier = 1.0 / (ratio_value + 0.3)

        updated_weight = previous_weight * pass_rate_multiplier * type_ratio_modifier
        updated_weight *= punishment_factor

        return max(MIN_WEIGHT, updated_weight)
    
    def compute_test_type_ratios(self, test_cases):
        """Return the share of test cases belonging to each ``test_type``.

        Args:
            test_cases: Mapping of test case id to an info dict; a missing
                ``test_type`` is counted under the key ``None``.

        Returns:
            Mapping of test type to ``count / total``; an empty dict when
            there are no test cases.
        """
        tally = {}
        for case in test_cases.values():
            kind = case.get('test_type')
            tally[kind] = tally.get(kind, 0) + 1

        total = sum(tally.values())
        if total == 0:
            return {}

        return {kind: occurrences / total for kind, occurrences in tally.items()}

    async def _generate_codes_async(self, num_codes, current_plan, test_cases, best_codes=None, best_only=False, error_test_num=3):
        """Generate candidate codes concurrently.

        ``num_codes`` requests are issued via ``self.llmcg.get_code_async``
        and collected as they complete. As in the synchronous
        ``_generate_codes``, a ``None`` result is skipped instead of being
        stored as a candidate, so keys stay contiguous.

        Args:
            num_codes: Number of generation requests to issue.
            current_plan: Forwarded as ``extracted_plan``.
            test_cases: Forwarded to the code generator.
            best_codes: Forwarded to the code generator.
            best_only: Forwarded to the code generator.
            error_test_num: Forwarded to the code generator.

        Returns:
            Dict mapping ``"code_<n>"`` to each generated code, numbered in
            completion order.
        """
        codes = {}
        task_list = [
            self.llmcg.get_code_async(
                extracted_plan=current_plan,
                test_cases=test_cases,
                best_codes=best_codes,
                best_only=best_only,
                error_test_num=error_test_num,
            )
            for _ in range(num_codes)
        ]
        for task in tqdm_asyncio.as_completed(task_list, total=num_codes, desc="Generating async codes"):
            code = await task
            if code is None:
                # Mirror the sync path: a failed generation is not a candidate.
                continue
            codes[f"code_{len(codes)}"] = code
        return codes
    
    def _generate_codes(self, num_codes, current_plan, test_cases, best_codes=None, best_only=False, error_test_num=3, prompt_only=False):
        """Generate candidate codes one after another.

        Args:
            num_codes: Number of generation calls to make.
            current_plan: Forwarded as ``extracted_plan``.
            test_cases: Forwarded to the code generator.
            best_codes: Forwarded to the code generator.
            best_only: Forwarded to the code generator.
            error_test_num: Forwarded to the code generator.
            prompt_only: Forwarded to the code generator; when true, the
                result of the first call is returned immediately.

        Returns:
            Dict mapping ``"code_<n>"`` to each non-``None`` generated code,
            or the first call's raw result when ``prompt_only`` is true.
        """
        generated = {}
        for _ in tqdm(range(num_codes), desc="Generating codes"):
            candidate = self.llmcg.get_code(
                extracted_plan=current_plan,
                test_cases=test_cases,
                best_codes=best_codes,
                best_only=best_only,
                error_test_num=error_test_num,
                prompt_only=prompt_only,
            )
            if prompt_only:
                return candidate
            # Failed generations come back as None and are dropped.
            if candidate is not None:
                generated[f"code_{len(generated)}"] = candidate
        return generated

    def transform_test_perspective(self, test_results):
        """Pivot results from test-major to code-major order.

        Args:
            test_results: ``{test_case_id: {code_id: result}}``.

        Returns:
            ``{code_id: {test_case_id: result}}`` with the same result
            objects.
        """
        by_code = {}
        for test_case_id, per_code in test_results.items():
            for code_id, outcome in per_code.items():
                by_code.setdefault(code_id, {})[test_case_id] = outcome
        return by_code

    def _filter_test_cases_by_weight(self, test_results, threshold=0.2):
        """Drop test cases whose weight is not above ``threshold``.

        ``self.test_cases`` and ``self.test_weights`` are replaced by their
        filtered versions. The log line reports how many test cases remain
        out of the number of weights held *before* filtering; previously the
        total was read after ``self.test_weights`` had already been
        filtered, so it always showed the surviving count.

        Args:
            test_results: ``{test_case_id: {code_id: result}}``.
            threshold: Weights must be strictly greater than this to survive.

        Returns:
            A ``(code_major_results, test_major_results)`` tuple restricted
            to the surviving test cases.
        """
        total_before = len(self.test_weights)
        # Set membership keeps the three filters below linear.
        kept_ids = {
            test_case_id
            for test_case_id, weight in self.test_weights.items()
            if weight > threshold
        }

        self.test_cases = {k: v for k, v in self.test_cases.items() if k in kept_ids}
        self.test_weights = {k: v for k, v in self.test_weights.items() if k in kept_ids}
        filtered_test_results = {k: v for k, v in test_results.items() if k in kept_ids}
        logging.info(f"Filtered test cases: {len(self.test_cases)} out of {total_before}")

        filtered_fun_results = self.transform_test_perspective(filtered_test_results)
        return filtered_fun_results, filtered_test_results

    def _filter_test_cases_by_pass_rate(self, test_results, threshold=0.05):
        filtered_test_case_list = []
        filtered_test_results = {}
        test_case_length = len(test_results)
        for test_case_id, results in test_results.items():
            total = len(results)
            passed = sum(1 for item in results.values() if item["success"])
            if total == 0:
                continue
            # passed = sum(results.values())
            pass_rate = passed / total
            if pass_rate > threshold:
                filtered_test_case_list.append(test_case_id)
                filtered_test_results[test_case_id] = results
        self.test_cases = {k: v for k, v in self.test_cases.items() if k in filtered_test_case_list}
        logging.info(f"Filtered test cases: {len(self.test_cases)} out of {test_case_length}")
        # debug
        print(f"Filtered test cases: {len(self.test_cases)} out of {test_case_length}")

        filtered_fun_results = self.transform_test_perspective(filtered_test_results)

        return filtered_fun_results, filtered_test_results

    def _evaluate_codes(self, codes, timeout=None, ai4s=False, filter_threshold=0.05):
        if timeout is None:
            timeout = self.test_timeout
        print(f"Evaluating codes on {len(self.test_cases)} test cases...")
        fun_results, test_results = self.code_runner.run_all_tests(codes, self.test_cases, timeout=timeout)
        # print("###############################################################")
        # print(fun_results)
        # print("###############################################################")
        # print(test_results)
        # print("###############################################################")
        self.test_weights = self._calculate_test_weights(self.test_cases, example_dataset=None, test_results=test_results, previous_weights=self.test_weights)
        # filtered_fun_results, filtered_test_results = self._filter_test_cases_by_pass_rate(test_results, threshold=0.05)
        if ai4s:
            filter_threshold = 0.0
        filtered_fun_results, filtered_test_results = self._filter_test_cases_by_weight(test_results, threshold=filter_threshold)
        
        input_data = {}
        for code_id, results in filtered_fun_results.items():
            input_data[code_id] = {
                'code': codes[code_id]['code'],
                'test_results': results,
                'test_weights': self.test_weights
            }
        # Calculate scores
        output_scores, full_score_dict = self.evaluator.calculate_batch_scores(input_data)
        # Combine scores with code data
        output_results = {}
        for code_id in codes.keys():
            if code_id not in output_scores:
                continue
            else:
                code_info = {
                    'code': codes[code_id]['code'],
                    'plan': codes[code_id]['plan'],
                    'main_function_name': codes[code_id]['main_function_name'],
                    'score': output_scores[code_id],
                    'pass_rate': full_score_dict[code_id]['pass_rate'],
                    'pass_rate_score': full_score_dict[code_id]['pass_rate_score'],
                    'prediction_score': full_score_dict[code_id]['prediction_score'],
                    'pylint_score': full_score_dict[code_id]['pylint_score'],
                    'radon_score': full_score_dict[code_id]['radon_score'],
                    'test_case_results': filtered_fun_results[code_id],
                    'test_weights': self.test_weights,
                    }
                output_results[code_id] = code_info
        # output_results = {
        #     code_id: {
        #         'code': codes[code_id]['code'],
        #         'plan': codes[code_id]['plan'],
        #         'main_function_name': codes[code_id]['main_function_name'],
        #         'score': output_scores[code_id],
        #         'pass_rate_score': full_score_dict[code_id]['pass_rate_score'],
        #         'prediction_score': full_score_dict[code_id]['prediction_score'],
        #         'pylint_score': full_score_dict[code_id]['pylint_score'],
        #         'radon_score': full_score_dict[code_id]['radon_score'],
        #         'test_case_results': filtered_fun_results[code_id],
        #         'test_weights': self.test_weights,
        #     }
        #     for code_id in codes.keys()
        # }
        return output_results, filtered_test_results
        # return {
        #     code_id: {
        #         'code': codes[code_id]['code'],
        #         'plan':codes[code_id]['plan'],
        #         'main_function_name':codes[code_id]['main_function_name'],
        #         'score': self.evaluator.calculate_score(codes[code_id]['code'] ,results, self.test_weights)
        #     }
        #     for code_id, results in fun_results.items()
        # }

    def _select_top_codes(self, scored_codes, top_k=3):
        return dict(sorted(scored_codes.items(), 
                          key=lambda x: x[1]['score'], 
                          reverse=True)[:top_k])

    async def _get_user_feedback(self, top_codes):

        logging.info("\nTop Performing Codes:")
        for cid, data in top_codes.items():
            logging.info(f"{cid} [Score: {data['score']:.2f}]:")
            logging.info("Code workflow:")
            logging.info(data['plan'])
            logging.info("Partial Code:")
            logging.info(data['code'][:500] + "...\n")
        
        if input("Provide feedback? (y/n): ").lower() == 'y':
            feedback = input("Enter your feedback: ")
            logging.info(f"User feedback: {feedback}")
            # Store feedback for next generation cycle
            self.current_plan['user_feedback'] = feedback
            return True
        return False