import collections
import copy
import logging
import os
import uuid
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import Dict, List

import numpy as np
import pandas as pd
from scipy.linalg import solve
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio

from LLM_call import LLMModel
from LLM_TM import LLMTM
from LLM_CG import LLMCG, CodeRunner
from Evaluator import Evaluator

# Module logger. Its records propagate to the root logger, which receives the
# file and console handlers via basicConfig below.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Shared log-line format. The date format must live on the Formatter itself:
# basicConfig only applies its own `datefmt` to handlers that do not already
# have a formatter, and both handlers below get this one explicitly.
LOG_DATEFMT = "%Y-%m-%d %H:%M:%S"
formatter = logging.Formatter(
    '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt=LOG_DATEFMT,
)

# File handler: one timestamped file per process under ./log, DEBUG and above.
log_dir = Path("log")
log_dir.mkdir(exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_filename = f"{timestamp}.log"
log_path = log_dir / log_filename
# Explicit UTF-8 so non-ASCII log messages do not depend on the platform locale.
file_handler = logging.FileHandler(log_path, encoding="utf-8")
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(formatter)

# Console handler: INFO and above.
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)

# Attach both handlers to the ROOT logger (not only `logger`), so records from
# every logger that propagates to root reach both outputs.
# NOTE(review): basicConfig does nothing if the root logger already has
# handlers (e.g. configured by an imported module) — confirm that cannot happen.
logging.basicConfig(
    level=logging.DEBUG,
    datefmt=LOG_DATEFMT,
    handlers=[
        file_handler,     # file output
        console_handler,  # console output
    ],
)


class LCDP:
    """LLM-driven code generation with plan search and a GP-guided candidate filter.

    ``run`` drives the whole pipeline: optional task-description refinement,
    plan generation, weighted test-case generation, per-plan code generation
    with self-refinement, then an iterative loop in which code scores and test
    weights are re-estimated and a Gaussian-process surrogate over code
    embeddings selects which new candidate codes get evaluated.
    """

    # GP surrogate hyper-parameters. Class-level defaults keep the GP helpers
    # usable on instances that skip __init__; __init__ overrides them.
    l = None        # RBF length scale; None -> unscaled exp(-d^2) kernel
    sigma_n = 1e-6  # jitter added to the kernel diagonal for numerical stability

    def __init__(self, api_key, model="gpt-3.5-turbo", base_url="https://api.openai.com/v1/", max_workers=5, ignore_advice=False, use_web_search=False, gp_length_scale=None, gp_noise=1e-6):
        """Build the LLM-backed helpers and empty per-run state.

        Args:
            api_key, model, base_url, use_web_search: forwarded to ``LLMModel``.
            max_workers: forwarded to ``CodeRunner``.
            ignore_advice: when True, the plan-refinement loops are skipped.
            gp_length_scale: RBF length scale of the GP surrogate (None = unscaled).
            gp_noise: diagonal jitter added to the GP kernel matrix.
        """
        self.llm_model = LLMModel(api_key, model, base_url=base_url, use_web_search=use_web_search)
        self.llmtm = LLMTM(self.llm_model)
        self.llmcg = LLMCG(self.llm_model)
        self.code_runner = CodeRunner(max_workers=max_workers)
        self.evaluator = Evaluator()
        self.ignore_advice = ignore_advice
        self.l = gp_length_scale
        self.sigma_n = gp_noise
        self.initialize()

    def initialize(self):
        """Reset all per-run state (called at the start of every ``run``)."""
        self.task_description = None
        self.current_plan = {}
        self.test_cases = {}
        self.codes = {}
        self.test_timeout = None
        self.gp_model = None
        self.plan_summary = None
        self.current_best_plan = None
        # Exploration weight of the UCB acquisition; decayed every iteration.
        self.k = 1

    def run(self,
            task_description,
            max_iterations=3,
            stop_t=0.8,
            num_plans=3,
            num_tests=5,
            num_codes=5,
            num_codes_select=3,
            refine_rounds=3,
            code_refine_rounds=3,
            test_timeout=None,
            min_tests=15,
            max_tests=20,
            knowledge_refine=False,
            best_only=False,
            error_test_num=3,
            record_all_results=False,
            forced_test_cases=None):
        """Run the full generate / evaluate / search pipeline and return the best code.

        Args:
            task_description: natural-language description of the task.
            max_iterations: maximum number of Phase-4 search iterations.
            stop_t: stop early once the best code score reaches this value.
            num_plans: number of plans requested in Phase 1.
            num_tests: number of test cases requested in Phase 2.
            num_codes: codes generated per plan (Phases 3 and 4).
            num_codes_select: candidates selected for evaluation per iteration.
            refine_rounds: plan-refinement rounds (skipped when ``ignore_advice``).
            code_refine_rounds: evaluate -> refine cycles per batch of codes.
            test_timeout: forwarded to ``Evaluator.evaluate_codes``.
            min_tests: the test pool is topped up to this size after filtering.
            knowledge_refine: refine the task description with the LLM first.
            error_test_num: forwarded to ``refine_codes`` (presumably how many
                failing tests are shown to the LLM — confirm in LLMCG).
            forced_test_cases: initial tests passed to ``generate_tests`` as
                ``original_test_cases``; deep-copied so the caller's dict is
                not mutated by the weight updates.
            max_tests, best_only, record_all_results: accepted for
                compatibility; not used by this method.

        Returns:
            The stored record (dict) of the highest-scoring code, or None if no
            code was generated.
        """
        self.initialize()
        self.task_description = task_description
        self.test_timeout = test_timeout

        if knowledge_refine:
            print("\n########################################################################")
            print("### Phase 0: Refine Task Description")
            self.task_description = self.llmtm.task_knowledge_refinement(task_description)

        # ---- Phase 1: plans --------------------------------------------------
        print("\n########################################################################")
        print("### Phase 1: Plan Generation and Refinement")
        plan = self.llmtm.get_plan(self.task_description, num_plans=num_plans)

        if self.ignore_advice:
            # Skip plan refinement: use the generated plans as-is.
            self.current_plan = plan
        else:
            # Iterative (user-guided) refinement to improve/select the plans.
            self.current_plan = self.llmtm.plan_refinement_loop(plan, refine_rounds)

        self.plan_summary = self.llmtm.summarize_plan(self.current_plan)

        # ---- Phase 2: test cases ---------------------------------------------
        print("\n########################################################################")
        print("### Phase 2: Test Case Generation and Weighting")

        if forced_test_cases is not None:
            # Deep copy: test weights are later updated in place.
            self.test_cases = copy.deepcopy(forced_test_cases)

        print("initialize test cases and weights...")
        self.test_cases, debug_info = self.llmcg.generate_tests(num_tests, self.task_description, plan=self.plan_summary, original_test_cases=self.test_cases, debug=True)

        # ---- Phase 3: first code generation ----------------------------------
        print("\n########################################################################")
        print("### Phase 3: First generation and evaluation")

        for plan_id, plan_i in self.current_plan.items():
            print("generate code for plan: ", plan_id)
            generated_codes = self.llmcg.generate_codes(num_codes, self.task_description, plan_id, plan_i, self.test_cases, use_task_description=True)

            # Self-reflection / error-analysis refinement. The returned results
            # belong to the returned (final) codes, not a pre-refinement copy.
            generated_codes, code_results = self._evaluate_and_refine(generated_codes, code_refine_rounds, error_test_num, verbose=True)

            static_analysis_results = self.evaluator.static_code_analyzer(generated_codes)
            self.codes = self.update_codes(self.codes, generated_codes, code_results, static_analysis_results)

        # Attach derived code information (AST, embedding, etc.).
        self.codes = self.llmcg.preprocess_codes(self.codes)

        # ---- Phase 4: iterative search ---------------------------------------
        print("\n########################################################################")
        print("### Phase 4: Iterative Code Generation and Evaluation")
        for iteration in range(max_iterations):
            # Re-estimate code scores and test weights from current results.
            self.test_cases, self.codes = self.update_weights(self.codes, self.test_cases)

            # Drop weakly discriminative tests and low-scoring codes.
            self.test_cases, self.codes = self.filter_by_weights(self.test_cases, self.codes, test_threshold=0.2, code_threshold=0.2)

            # Top the test pool back up if filtering shrank it too much.
            if len(self.test_cases) < min_tests:
                print(f"INFO: Test cases below threshold ({len(self.test_cases)} < {min_tests}). Generating more.")
                required_tests = min_tests - len(self.test_cases)
                self.test_cases, debug_info = self.llmcg.generate_tests(required_tests, self.task_description, plan=self.plan_summary, original_test_cases=self.test_cases, debug=True)

            # Derive new plans from the best-performing one.
            self.current_best_plan = self.select_best_plan(self.current_plan, self.codes)
            new_plans = self.llmtm.generate_new_plan(self.current_best_plan, self.task_description, self.codes, self.test_cases)
            if not self.ignore_advice:
                new_plans = self.llmtm.plan_refinement_loop(new_plans, refine_rounds)
            self.current_plan = new_plans

            # Generate candidate codes for every new plan, then preprocess them.
            all_candidate_codes = {}
            for plan_id, plan_i in new_plans.items():
                candidate_codes = self.llmcg.generate_codes(num_codes, self.task_description, plan_id, plan_i, self.test_cases, use_task_description=True)
                all_candidate_codes.update(candidate_codes)
            all_candidate_codes = self.llmcg.preprocess_codes(all_candidate_codes)

            # Exploration weight starts at 1 and is multiplied by 0.75 each iteration.
            self.k *= 0.75
            if self.codes:
                self.gp_model = self.build_gp_model(self.codes)
                selected_ids = self.select_codes_with_acquisition(all_candidate_codes, num_code_select=num_codes_select, k=self.k)
            else:
                # Nothing survived filtering: no data to fit the surrogate on.
                selected_ids = list(all_candidate_codes)[:num_codes_select]
            codes_to_evaluate = {cid: all_candidate_codes[cid] for cid in selected_ids}
            if not codes_to_evaluate:
                print("WARNING: No candidate codes to evaluate in this iteration.")
                continue

            # Evaluate (and refine) ONLY the selected codes.
            print(f"INFO: Evaluating {len(codes_to_evaluate)} selected codes...")
            codes_to_evaluate, final_results = self._evaluate_and_refine(codes_to_evaluate, code_refine_rounds, error_test_num)
            static_analysis_results = self.evaluator.static_code_analyzer(codes_to_evaluate)

            self.codes = self.update_codes(self.codes, codes_to_evaluate, final_results, static_analysis_results)
            # New records need the same derived information as the Phase-3 ones
            # (the GP model reads their "embedding" next iteration).
            self.codes = self.llmcg.preprocess_codes(self.codes)

            # Score every code, including the ones just added, before the stop check.
            self._score_codes(self.codes, self.test_cases)
            best_score = max((c.get('score', 0) for c in self.codes.values()), default=0.0)
            print(f"INFO: Best score after iteration {iteration + 1} is {best_score:.2f}")
            if best_score >= stop_t:
                print(f"INFO: Stopping condition met. Score {best_score:.2f} >= {stop_t}")
                break

        # ---- Final output ----------------------------------------------------
        if not self.codes:
            print("ERROR: No code was generated.")
            return None

        # Ensure every record has a score (e.g. when max_iterations == 0).
        self._score_codes(self.codes, self.test_cases)
        best_code = max(self.codes.values(), key=lambda x: x.get('score', 0))
        print("\n########################################################################")
        print("### Process Finished")
        print(f"Best code found with score: {best_code.get('score', 0):.2f}")
        print("Content:")
        print(best_code.get('code_str'))
        print("########################################################################")

        return best_code

    def _evaluate_and_refine(self, codes, refine_rounds, error_test_num, verbose=False):
        """Apply ``refine_rounds`` evaluate -> refine cycles, then evaluate once more.

        Returns ``(codes, results)`` where ``results`` come from evaluating
        exactly the returned ``codes``; both are defined even when
        ``refine_rounds`` is 0.
        """
        for _ in range(refine_rounds):
            if verbose:
                print("\n--- Code Evaluation ---")
            results = self.evaluator.evaluate_codes(codes, self.test_cases, self.test_timeout)
            if verbose:
                print("\n--- Code Refinement ---")
            codes = self.llmcg.refine_codes(codes, results, self.test_cases, error_test_num)
        final_results = self.evaluator.evaluate_codes(codes, self.test_cases, self.test_timeout)
        return codes, final_results

    def update_codes(self, existing_codes, generated_codes, code_results, static_analysis_results=None):
        """Return a copy of ``existing_codes`` extended with one record per generated code.

        Each new record is stored under a fresh UUID4 string key; the original
        id is kept in ``source_id`` so batches never collide.

        Args:
            existing_codes: current code repository (not modified).
            generated_codes: ``{code_id: code_info}`` of newly generated codes.
            code_results: ``{code_id: {test_id: passed}}``; a missing or empty
                entry means "no results" and a 0% pass rate.
            static_analysis_results: optional ``{code_id: analysis}``.

        Returns:
            The new repository dict.
        """
        updated_codes = existing_codes.copy()
        code_results = code_results or {}
        static_analysis_results = static_analysis_results or {}

        for code_id, code_info in generated_codes.items():
            new_unique_id = str(uuid.uuid4())
            test_results = code_results.get(code_id) or {}

            # Truthiness (as in update_weights), so numpy booleans count as passes.
            total_count = len(test_results)
            passed_count = sum(1 for result in test_results.values() if result)
            pass_rate = (passed_count / total_count) * 100 if total_count else 0.0

            updated_codes[new_unique_id] = {
                'source_id': code_id,  # original (batch-local) id
                'code_str': code_info.get('code_str'),
                'main_func_name': code_info.get('main_func_name'),
                'reasoning': code_info.get('reasoning'),
                'plan_id': code_info.get('plan_id'),
                'test_results': test_results,
                'pass_rate_percent': round(pass_rate, 2),
                'static_analysis': static_analysis_results.get(code_id, {}),
            }

        return updated_codes

    def select_best_plan(self, plans, codes):
        """Return the id of the plan whose codes have the highest mean ``score``.

        Codes without a ``plan_id`` are ignored. With no scored codes, falls back
        to the first key of ``plans``, or ``"default_plan"`` if ``plans`` is empty.
        """
        print("INFO: Selecting the best performing plan...")
        plan_scores = {}
        for code_info in codes.values():
            plan_id = code_info.get("plan_id")
            # `is not None` so falsy ids such as 0 or "" are still counted.
            if plan_id is not None:
                plan_scores.setdefault(plan_id, []).append(code_info.get("score", 0))

        if not plan_scores:
            return next(iter(plans), "default_plan")

        avg_scores = {p_id: sum(s) / len(s) for p_id, s in plan_scores.items()}
        best_plan_id = max(avg_scores, key=avg_scores.get)
        print(f"INFO: Best plan is '{best_plan_id}' with average score {avg_scores[best_plan_id]:.2f}")
        return best_plan_id

    def _score_codes(self, codes, test_cases):
        """Set ``codes[cid]['score']`` to the weighted pass rate, in place.

        S(x) = sum(I(x passes Ti) * C(Ti)) / sum(C(Ti)), with the denominator
        summed over all of ``test_cases`` (missing weights count as 0). Results
        for tests not in ``test_cases`` are ignored. If the weights sum to 0,
        every score is 0.0. Returns ``codes``.
        """
        sum_of_all_weights = sum(t_info.get('weight', 0) for t_info in test_cases.values())
        for code_info in codes.values():
            if sum_of_all_weights == 0:
                code_info['score'] = 0.0
                continue
            weighted_passes = sum(
                test_cases[test_id].get('weight', 0)
                for test_id, passed in code_info.get('test_results', {}).items()
                if passed and test_id in test_cases
            )
            code_info['score'] = round(weighted_passes / sum_of_all_weights, 4)
        return codes

    def update_weights(self, codes, test_cases, alpha=0.9):
        """Re-score codes, then re-weight test cases (both dicts updated in place).

        Test weights follow
        C(Ti)_new = (1 - alpha) * C(Ti)_old + alpha * (avg_score_pass - avg_score_fail),
        where the averages are over codes with a result for Ti (0.0 when a side
        is empty) and a missing old weight counts as 1.0.

        Returns:
            ``(test_cases, codes)``.
        """
        self._score_codes(codes, test_cases)

        for test_id, test_info in test_cases.items():
            passing_code_scores = []
            failing_code_scores = []
            for code_info in codes.values():
                results = code_info.get('test_results', {})
                if test_id not in results:
                    continue  # this code was never run on this test
                bucket = passing_code_scores if results[test_id] else failing_code_scores
                bucket.append(code_info.get('score', 0))

            avg_pass_score = sum(passing_code_scores) / len(passing_code_scores) if passing_code_scores else 0.0
            avg_fail_score = sum(failing_code_scores) / len(failing_code_scores) if failing_code_scores else 0.0

            discriminative_power = avg_pass_score - avg_fail_score
            old_weight = test_info.get('weight', 1.0)
            new_weight = (1 - alpha) * old_weight + alpha * discriminative_power
            test_info['weight'] = round(new_weight, 4)

        return test_cases, codes

    def filter_by_weights(self, test_cases, codes, test_threshold=0.2, code_threshold=0.2):
        """Keep tests with weight >= ``test_threshold`` and codes with score >= ``code_threshold``.

        Surviving codes are returned as shallow copies whose ``test_results``
        only mention kept tests; the input dicts are not modified.

        Returns:
            ``(filtered_test_cases, filtered_codes)``.
        """
        filtered_test_cases = {
            test_id: test_info
            for test_id, test_info in test_cases.items()
            if test_info.get('weight', 0) >= test_threshold
        }
        kept_test_ids = set(filtered_test_cases)

        filtered_codes = {}
        for code_id, code_info in codes.items():
            if code_info.get('score', 0) < code_threshold:
                continue
            cleaned_results = {
                test_id: result
                for test_id, result in code_info.get('test_results', {}).items()
                if test_id in kept_test_ids
            }
            filtered_codes[code_id] = {**code_info, 'test_results': cleaned_results}

        return filtered_test_cases, filtered_codes

    def _rbf_kernel(self, X1, X2):
        """RBF kernel on cosine distance d = 1 - cos(x1, x2).

        Returns exp(-d^2 / (2 l^2)), or exp(-d^2) when ``self.l`` is None,
        as a matrix of shape (len(X1), len(X2)).
        """
        # 1. Pairwise cosine similarity.
        sim_matrix = cosine_similarity(X1, X2)
        # 2. Squared cosine distance.
        dist_sq = np.square(1 - sim_matrix)
        # 3. RBF kernel.
        if self.l is not None:
            return np.exp(-dist_sq / (2 * self.l**2))
        return np.exp(-dist_sq)

    def build_gp_model(self, codes):
        """Fit a zero-mean GP on code embeddings -> scores and store it in ``self.gp_model``.

        Reads ``codes[cid]["embedding"]`` and ``codes[cid]["score"]``. Precomputes
        alpha = (K + sigma_n I)^-1 y and (K + sigma_n I)^-1 for prediction.

        Raises:
            ValueError: if ``codes`` is empty.
        """
        if not codes:
            raise ValueError("Input 'codes' dictionary cannot be empty.")

        # 1. Extract training data.
        code_ids = list(codes.keys())
        X_train = np.array([codes[cid]["embedding"] for cid in code_ids])
        y_train = np.array([codes[cid]["score"] for cid in code_ids]).reshape(-1)

        if X_train.ndim == 1:
            # Scalar embeddings: one row per code, so rows line up with y_train.
            X_train = X_train.reshape(-1, 1)

        # 2. Training kernel matrix plus diagonal jitter for numerical stability.
        K = self._rbf_kernel(X_train, X_train)
        K_stable = K + self.sigma_n * np.eye(len(X_train))

        # 3. Precompute prediction terms; `solve` is more stable than an explicit inverse.
        try:
            alpha = solve(K_stable, y_train, assume_a='pos')
            K_inv = solve(K_stable, np.eye(len(X_train)), assume_a='pos')
        except np.linalg.LinAlgError:
            print("Warning: Kernel matrix is singular. Using pseudo-inverse.")
            pseudo_inv = np.linalg.pinv(K_stable)
            alpha = pseudo_inv @ y_train
            K_inv = pseudo_inv

        self.gp_model = {
            "X_train": X_train,
            "y_train": y_train,
            "code_ids": code_ids,
            "alpha": alpha,  # (K + sigma_n I)^-1 y
            "K_inv": K_inv,  # (K + sigma_n I)^-1
        }
        print(f"GP model built successfully with {len(code_ids)} data points.")
        return self.gp_model

    def select_codes_with_acquisition(self,
                                      all_candidate_codes: dict,
                                      num_code_select: int,
                                      k: float):
        """Rank candidates by GP-UCB (mu + k * sigma) and return the top ids.

        Args:
            all_candidate_codes: ``{code_id: info}``; each info needs "embedding".
            num_code_select: maximum number of ids to return.
            k: exploration weight on the predictive standard deviation.

        Returns:
            List of up to ``num_code_select`` candidate ids, best first
            (empty list if there are no candidates).

        Raises:
            RuntimeError: if ``build_gp_model`` has not been called.
        """
        if self.gp_model is None:
            raise RuntimeError("GP model has not been built. Please call 'build_gp_model' first.")
        if not all_candidate_codes:
            return []

        X_train = self.gp_model["X_train"]
        alpha = self.gp_model["alpha"]
        K_inv = self.gp_model["K_inv"]

        candidate_ids = list(all_candidate_codes.keys())
        X_candidate = np.array([all_candidate_codes[cid]["embedding"] for cid in candidate_ids])
        if X_candidate.ndim == 1:
            # Scalar embeddings: same row layout as in build_gp_model.
            X_candidate = X_candidate.reshape(-1, 1)

        # Cross-kernel between candidates and training points.
        k_star = self._rbf_kernel(X_candidate, X_train)

        # Predictive mean: mu = k_* @ alpha.
        mu_s = k_star @ alpha

        # Predictive variance: k_** - diag(k_* K_inv k_*^T). k_** = 1 because
        # each candidate has distance 0 to itself and exp(0) = 1.
        k_star_star = 1
        var_s = k_star_star - np.einsum('ij,jk,ik->i', k_star, K_inv, k_star)
        # Clamp tiny negative values caused by round-off.
        var_s = np.maximum(var_s, 1e-8)
        std_s = np.sqrt(var_s)

        ucb_scores = mu_s + k * std_s

        selected_indices = np.argsort(-ucb_scores)[:num_code_select]
        return [candidate_ids[i] for i in selected_indices]