import os
import pickle
import re
import traceback

import numpy as np
import pandas as pd

from factors.expression_exec import factor_validation
from factors.expression_lib import init_expression_lib
from factors.register import DynamicFactorRuntime, FactorRegister, run_factor_eval
from llm_client.call_llm import call_llm_with_retry
from llm_client.prompot_bank import generate_factor_generate_prompt
from portfolio.backtest import eval_equal_weight_portfolio_performance


def is_significantly_worse(portfolio, benchmark):
    """Tell whether a portfolio metric falls well below a benchmark metric.

    The cutoff is half of the benchmark when the benchmark is non-negative,
    and one and a half times the benchmark when it is negative. The portfolio
    counts as significantly worse when it is strictly below that cutoff.
    """
    scale = 0.5 if benchmark >= 0 else 1.5
    cutoff = benchmark * scale
    return portfolio < cutoff
    


def update_generated_factors_performance(
    factor_name,
    portfolio_returns_dict,
    rank_ic_list,
    recall_20_list,
    factor_quality,
    portfolio_performance,
):
    """
    Append one factor's metrics to both ranking tables and re-sort them.

    Parameters:
    factor_name (str): Index label used for the new row in both tables.
    portfolio_returns_dict (dict): Must contain "mean_return", "std_return",
        "sharpe_ratio", "max_drawdown" and "final_value"; other keys are ignored.
    rank_ic_list (list): Values averaged into "mean_rankic" (rounded to 6 places).
    recall_20_list (list): Values averaged into "mean_recall@20" (rounded to 6 places).
    factor_quality (pd.DataFrame): Existing quality table; not modified.
    portfolio_performance (pd.DataFrame): Existing performance table; not modified.

    Returns:
    tuple: (quality table sorted by "mean_rankic" descending,
            performance table sorted by "final_value" descending).
    """
    summary_fields = (
        "mean_return",
        "std_return",
        "sharpe_ratio",
        "max_drawdown",
        "final_value",
    )
    summary = {field: portfolio_returns_dict[field] for field in summary_fields}
    performance_row = pd.DataFrame([summary], index=[factor_name])
    updated_performance = pd.concat(
        [portfolio_performance, performance_row]
    ).sort_values(by="final_value", ascending=False)

    quality_row = pd.DataFrame(
        {
            "mean_rankic": round(np.mean(rank_ic_list), 6),
            "mean_recall@20": round(np.mean(recall_20_list), 6),
        },
        index=[factor_name],
    )
    updated_quality = pd.concat([factor_quality, quality_row]).sort_values(
        by="mean_rankic", ascending=False
    )

    return updated_quality, updated_performance


def filter_factor_versions(df, min_versions=1, quality_column="final_value"):
    """
    Keep only the latest and the best-scoring version of every factor family.

    A name ending in ``_v<digits>`` is treated as a version of the base name
    that precedes the suffix; a name without the suffix is version 0 of itself.
    Duplicate index labels are dropped first (first occurrence wins).

    Parameters:
    df (pd.DataFrame): Frame indexed by factor name that contains `quality_column`.
    min_versions (int): Families with at most this many versions are kept whole;
        larger families are reduced to their latest and best versions.
    quality_column (str): Column used to pick the best version and to sort the
        result (e.g., 'final_value' or 'mean_rankic').

    Returns:
    pd.DataFrame: Selected rows sorted by `quality_column`, descending. An input
        without rows is returned as an (empty) copy.

    Raises:
    KeyError: If `df` has rows but no `quality_column`.
    """
    # A frame without rows has nothing to filter and may not even carry
    # `quality_column` (callers pass a bare pd.DataFrame() for missing data);
    # grouping and sorting it would raise KeyError.
    if len(df.index) == 0:
        return df.copy()

    df = df[~df.index.duplicated(keep="first")]

    # Step 1: split every factor name into base name and version number
    version_suffix = re.compile(r"_v(\d+)$")
    factor_data = []
    for name, quality in zip(df.index, df[quality_column]):
        match = version_suffix.search(name)
        if match:
            base_name = name[: match.start()]
            version = int(match.group(1))
        else:
            base_name = name
            version = 0
        factor_data.append(
            {
                "original_name": name,
                "base_name": base_name,
                "version": version,
                "quality": quality,
            }
        )
    name_df = pd.DataFrame(factor_data)

    # Step 2: per family, keep the latest version and the best-scoring version
    selected_names = []
    for _, group in name_df.groupby("base_name"):
        if len(group) <= min_versions:
            selected_names.extend(group["original_name"].tolist())
            continue

        latest = group.loc[group["version"].idxmax(), "original_name"]
        selected_names.append(latest)

        # idxmax on an all-NaN column yields NaN or raises depending on the
        # pandas version; with no comparable score only the latest is kept.
        scored = group["quality"].dropna()
        if scored.empty:
            continue
        best = group.loc[scored.idxmax(), "original_name"]
        if best != latest:
            selected_names.append(best)

    # Step 3: filter and sort the result
    return df.loc[selected_names].sort_values(by=quality_column, ascending=False)


def generate_backtest_report():
    """Placeholder for backtest report generation; currently does nothing."""
    return None

def aggregate_factor(search_ckp_list, record_n, record_ratio):
    """
    Merge several pickled search checkpoints, step by step, into one record list.

    Every checkpoint is a list of per-step records. All lists are clipped to a
    common length, then for each step the records of all checkpoints are merged:
    per factor name the highest "final_value" wins, together with the rank IC
    seen alongside it; on equal "final_value" the larger rank IC is kept.

    Parameters:
    search_ckp_list (str): Comma-separated checkpoint paths; blank entries
        (e.g. from a trailing comma) are skipped.
    record_n (int): Upper bound on the number of steps; applied only when > 1.
    record_ratio (float): Fraction of the common length to keep; applied only
        when < 1.

    Returns:
    list[dict]: One dict per step with keys "portfolop_performance" (frame with
        a "final_value" column), "factor_quality" (frame with a "mean_rankic"
        column) and "current_gen_factors" (factor name -> expression).

    Raises:
    ValueError: If `search_ckp_list` contains no path.
    """
    ckp_paths = [p.strip() for p in search_ckp_list.split(",") if p.strip()]
    if not ckp_paths:
        raise ValueError("search_ckp_list does not contain any checkpoint path")

    # SECURITY NOTE: pickle.load can execute arbitrary code embedded in the
    # file; only load checkpoints that come from a trusted source.
    all_records_list = []
    for path in ckp_paths:
        with open(path, "rb") as f:
            all_records_list.append(pickle.load(f))

    # Align to the shortest length, then apply the optional caps
    min_len = min(len(r) for r in all_records_list)
    if record_n > 1:
        min_len = min(min_len, record_n)
    if record_ratio < 1:
        min_len = min(min_len, int(min_len * record_ratio))

    # Clip records
    all_records_list = [r[:min_len] for r in all_records_list]

    # Perform offline merging step-by-step
    merged_search_record = []
    for step in range(min_len):
        merged_perf = {}
        merged_quality = {}
        merged_expr = {}

        for record_list in all_records_list:
            record = record_list[step]

            perf_source = record["portfolop_performance"]
            if len(perf_source.index) == 0:
                # No factors recorded by this checkpoint at this step.
                continue
            perf_df = filter_factor_versions(perf_source, quality_column="final_value")

            # A record may lack quality data; filter_factor_versions cannot
            # handle a frame without the quality column, so substitute an
            # empty table and let every rank IC fall back to -inf below.
            qual_source = record.get("factor_quality")
            if qual_source is None or len(qual_source.index) == 0:
                qual_df = pd.DataFrame(columns=["mean_rankic"])
            else:
                qual_df = filter_factor_versions(
                    qual_source, quality_column="mean_rankic"
                )

            # Accepts both a list of (name, expression) pairs and a dict.
            expr_dict = dict(record["current_gen_factors"])

            for factor_name in perf_df.index:
                final_val = perf_df.loc[factor_name, "final_value"]
                rank_ic = (
                    qual_df.loc[factor_name, "mean_rankic"]
                    if factor_name in qual_df.index
                    else -np.inf
                )

                # Not seen yet, or strictly better performance: replace all
                if (factor_name not in merged_perf) or (
                    final_val > merged_perf[factor_name]
                ):
                    merged_perf[factor_name] = final_val
                    merged_quality[factor_name] = rank_ic
                    if factor_name in init_expression_lib:
                        merged_expr[factor_name] = init_expression_lib[factor_name]
                    else:
                        merged_expr[factor_name] = expr_dict.get(factor_name, "")

                elif (
                    final_val == merged_perf[factor_name]
                    and rank_ic > merged_quality[factor_name]
                ):
                    merged_quality[factor_name] = rank_ic  # update IC only

        merged_search_record.append(
            {
                "portfolop_performance": pd.DataFrame.from_dict(
                    merged_perf, orient="index", columns=["final_value"]
                ),
                "factor_quality": pd.DataFrame.from_dict(
                    merged_quality, orient="index", columns=["mean_rankic"]
                ),
                "current_gen_factors": merged_expr,
            }
        )

    return merged_search_record


def factor_searching_enhance(
    data,
    min_steps=20,
    warmup_teps=30,
    search_freq=5,
    max_search_ratio=1.0,
    data_filename=None,
    model="deepseek-chat",
    prompt_config="all",
    exp_name="",
    log_dir="./logs",
    gen_num=10,
):
    """
    Run the iterative LLM factor search and checkpoint the records to disk.

    At every search step the current factor pool is evaluated and pruned, a
    prompt is built from the top factors, the LLM is queried several times,
    the returned factors are validated and evaluated, and the best new ones
    are added to the pool. After each successful step the accumulated records
    are pickled into `log_dir`, and a per-step text log is written to
    `log_dir/frame_record`. An exception inside a step is caught and printed;
    the search then moves on to the next step.

    Parameters:
    data: Indexable sequence of per-step records; the function reads
        `data[i]["target_return"]` and `data[step]["output_date"]`.
    min_steps (int): First index used for the benchmark and factor evaluation.
    warmup_teps (int): First search step.
    search_freq (int): Stride between search steps.
    max_search_ratio (float): Search stops before `int(ratio * len(data))`.
    data_filename: Inserted into the output file name.
    model (str): Passed to the LLM client; also part of the output file name.
    prompt_config (str): "quality" hides the performance table, "performance"
        hides the quality table, "value" sends only ranked factor names; any
        value other than "all" is appended to the output file name.
    exp_name (str): Optional suffix of the output file name.
    log_dir (str): Directory for step logs and the pickled records.
    gen_num (int): Maximum number of new factors accepted per step; also
        determines the number of LLM calls (`gen_num // 5 + 2`).

    Returns:
    None
    """
    print(f"Using {len(init_expression_lib)} init factors.")
    Factor_Manager = FactorRegister(data, init_expression_lib, min_steps, warmup_teps)
    searching_record_lists = []
    win_size = 30

    max_search_steps = int(max_search_ratio * len(data))
    print(
        f"Start search to generate factors, max search steps: {max_search_steps}, "
        f"warmup steps: {warmup_teps}, search steps frequency: {search_freq}, "
    )
    for step in range(warmup_teps, max_search_steps, search_freq):
        try:
            searching_record = {}
            searching_record["search_step"] = step

            print(
                f"Search step: {step}, current latest backtest date: {data[step]['output_date']}"
            )

            # Benchmark: equal-weight portfolio over all evaluated steps so far
            target_return_lists = [
                data[i]["target_return"] for i in range(min_steps, step)
            ]

            benchmark_result = eval_equal_weight_portfolio_performance(
                target_return_lists, init_value=100
            )

            # Evaluate the pool up to the current step, then prune it
            factor_quality, portfolio_performance = Factor_Manager.get_current_factor_pool_performance(
                step  # evaluate until last step
            )
            Factor_Manager.clean_factor_pool_by_performance(
                factor_quality, portfolio_performance, max_size=500, keep_top_n=400
            )

            filtered_performance = filter_factor_versions(
                portfolio_performance, quality_column="final_value"
            )
            factor_quality = filter_factor_versions(
                factor_quality, quality_column="mean_rankic"
            )

            given_information_count = max(gen_num, 10)

            # Control ablation study information
            recent_performance_topn = (
                "Please use factor quality as reference."
                if prompt_config == "quality"
                else filtered_performance[:given_information_count].to_string()
            )
            recent_factor_quanlity_topn = (
                "Please use factor portfolio performance as reference."
                if prompt_config == "performance"
                else factor_quality[:given_information_count].to_string()
            )

            # Describe every factor that appears in either top-n table
            combined_indices = filtered_performance[
                :given_information_count
            ].index.union(factor_quality[:given_information_count].index)
            factor_desc = Factor_Manager.get_factor_desc(list(combined_indices))

            if prompt_config == "value":
                # Only the ranked names are sent, not the metric values
                recent_performance_topn_names = filtered_performance[:10].index
                recent_factor_quanlity_topn_names = factor_quality[:10].index
                prompts_gen = generate_factor_generate_prompt(
                    factor_desc,
                    "From high to low: " + str(list(recent_performance_topn_names)),
                    "From high to low: " + str(list(recent_factor_quanlity_topn_names)),
                    n=5,
                )
            else:
                prompts_gen = generate_factor_generate_prompt(
                    factor_desc,
                    recent_performance_topn,
                    recent_factor_quanlity_topn,
                    n=5,
                )

            # For robustness the same prompt is sent several times
            need_threads = gen_num // 5 + 2

            frame_name = str(step) + "_factors.log"
            frame_save_dir = os.path.join(log_dir, "frame_record")
            os.makedirs(frame_save_dir, exist_ok=True)
            log_path = os.path.join(frame_save_dir, frame_name)

            added_factors_candidates = []

            # Success of the step is judged over ALL calls: a failure of the
            # final call must not discard factors produced by earlier calls.
            any_llm_success = False
            llm_result = None  # response of the most recent successful call

            for i in range(need_threads):
                llm_expression, call_result = call_llm_with_retry(
                    prompts_gen, model, max_retries=5, min_functions=2
                )

                if llm_expression is None:
                    print(
                        f"Failed to generate factors in step {step}, retrying... ({i+1}/{need_threads})"
                    )
                    continue

                any_llm_success = True
                llm_result = call_result

                for factor_name, factor_data in llm_expression.items():
                    factor = DynamicFactorRuntime(factor_data["func"], factor_data["raw"])

                    if not factor_validation(factor, win_size):
                        print(f"Factor {factor_name} doesn't pass test, ignored.")
                        continue

                    # Compute the performance for generated factors
                    rank_ic_list, recall_20_list, daily_scores, error_count, portfolio_returns_dict = run_factor_eval(
                        factor_name=factor_name,
                        factor_expr=factor_data["func"],
                        data=data,
                        step=step,
                        min_steps=min_steps,
                    )

                    if is_significantly_worse(
                        portfolio=portfolio_returns_dict["mean_return"],
                        benchmark=benchmark_result["mean_return"],
                    ):
                        print(
                            f"Factor {factor_name} has worse performance than benchmark, ignored."
                        )
                        continue

                    # An empty evaluation would make the error ratio divide by
                    # zero and abort the whole step; drop only this factor.
                    if len(rank_ic_list) == 0:
                        print(
                            f"Factor {factor_name} produced no evaluation results, ignored."
                        )
                        continue

                    if error_count / len(rank_ic_list) > 0.1:
                        print(f"Factor {factor_name} has too many errors, ignored.")
                        continue

                    added_factors_candidates.append(
                        (factor_name, factor, portfolio_returns_dict, rank_ic_list, recall_20_list)
                    )

            print(f"Accepted generated factors candidates total: {len(added_factors_candidates)}")

            added_count = 0
            # Best candidates first, ranked by the portfolio's mean return
            sorted_added_factors_candidates = sorted(
                added_factors_candidates,
                key=lambda x: x[2]["mean_return"],
                reverse=True,
            )

            accepted_llm_expression = []
            accepted_factors = []
            # Scan the factor candidates and select the best k factors
            for factor_name, factor, portfolio_returns_dict, rank_ic_list, recall_20_list in sorted_added_factors_candidates:
                if added_count >= gen_num:
                    break
                # Names already in the generated pool are skipped
                if factor_name not in Factor_Manager.gen_factor_lib:
                    factor_quality, portfolio_performance = update_generated_factors_performance(
                        factor_name,
                        portfolio_returns_dict,
                        rank_ic_list,
                        recall_20_list,
                        factor_quality,
                        portfolio_performance,
                    )

                    Factor_Manager.add_llm_gen_factor(
                        factor_name, factor
                    )  # We may need to evaluate these new factors
                    accepted_llm_expression.append(factor.raw)
                    accepted_factors.append(factor_name)
                    added_count += 1

            if not any_llm_success:
                print("Failed to generate sufficient valid factors, skip this step")
                searching_record["success"] = False
                continue

            print(f"Successfully generated {len(accepted_llm_expression)} factors")
            searching_record["success"] = True

            with open(log_path, "w", encoding="utf-8") as f:
                f.write("<<<<<Information Provide to LLMs>>>>>>\n")
                f.write(str(prompts_gen) + "\n")
                f.write("<<<<<LLM Response>>>>>>\n")
                for exp in accepted_llm_expression:
                    f.write(str(exp) + "\n")
                f.write(
                    f"In this stage, accepted factors: {accepted_factors}, length: {len(accepted_factors)}\n"
                )

            searching_record["llm_raw_prompt"] = prompts_gen
            searching_record["llm_raw_response"] = llm_result
            searching_record["factor_quality"] = factor_quality
            searching_record["portfolop_performance"] = portfolio_performance
            searching_record["current_gen_factors"] = [
                (factor_name, factor_data.raw)
                for factor_name, factor_data in Factor_Manager.gen_factor_lib.items()
            ]

            searching_record_lists.append(searching_record)

            output_name = (
                f"searching_records_{data_filename}_{model}_{exp_name}.bin"
                if len(exp_name) > 0
                else f"searching_records_{data_filename}_{model}.bin"
            )

            if prompt_config != "all":
                output_name = output_name.replace(".bin", f"_{prompt_config}.bin")

            # The full record list is rewritten after every successful step so
            # an interrupted run still leaves a usable checkpoint.
            output_path = os.path.join(log_dir, output_name)
            with open(output_path, "wb") as f:
                pickle.dump(searching_record_lists, f)

        except Exception as e:
            # Best effort: report the failing step and continue with the next
            error_trace = traceback.format_exc()
            print(f"Search step {step} has error {e}\nTraceback:\n{error_trace}pass")

    print("Search finished, total steps: ", len(searching_record_lists))
