import argparse
import csv
import inspect
import os
import traceback
from tqdm import tqdm  # for progress bar

import pickle
import matplotlib.dates as mdates
import numpy as np
import pandas as pd
from collections import defaultdict

from factors.expression_exec import factor_validation, safe_exec_factor, str2fun
from factors.expression_lib import init_expression_lib, get_implementation_source
from factors.metric import compute_rank_ic
from factors.register import DynamicFactorRuntime, FactorRegister

from utils.data import load_data

from llm_client.call_llm import call_llm, parse_factors_string
from portfolio.backtest import (
    calculate_performance_metrics,
    eval_equal_weight_portfolio_performance,
)

from utils.core import aggregate_factor, filter_factor_versions
from utils.plot import plot_cw

import logging
from factors.metric import compute_rank_ic, movement_precision_at_n, recall_at_n


# Configure the root logger to emit INFO and above.
# Other available levels: DEBUG, WARNING, ERROR, CRITICAL.
logging.basicConfig(level=logging.INFO)


def normalize_to_neg1_pos1(scores):
    """Min-max scale ``scores`` linearly onto the interval [-1, 1].

    NaN entries (e.g. assets whose factor evaluation failed) are ignored when
    computing the min/max and stay NaN in the output. A single failed asset
    therefore no longer turns the whole cross-section into NaN.

    Parameters:
    - scores: sequence or array of numeric scores (None is coerced to NaN).

    Returns:
    - Float ndarray with the same shape as ``scores``.
      - The smallest non-NaN score maps to -1 and the largest to 1.
      - If all non-NaN scores are equal, they map to 0.
      - Empty or all-NaN input is returned as-is (as floats).
    """
    scores = np.array(scores, dtype=float)  # convert to NumPy array if it's a list
    valid = ~np.isnan(scores)

    if not valid.any():
        # Nothing to scale: empty input or every score is NaN.
        return scores

    min_val = scores[valid].min()
    max_val = scores[valid].max()

    if min_val == max_val:
        # All valid values are the same; map them to 0 and keep NaNs missing.
        return np.where(valid, 0.0, np.nan)

    # Normalize to [0, 1], then scale to [-1, 1]. NaNs propagate unchanged.
    normalized = (scores - min_val) / (max_val - min_val)
    return 2 * normalized - 1


def compute_all_assets_scores(factor, prices, returns, volumes=None):
    """Run ``factor`` on every asset and return the list of raw scores.

    The factor is called with keyword arguments chosen from its own signature:
    - ``prices`` if declared;
    - ``return_`` if declared, otherwise ``returns`` if declared;
    - ``volume`` if declared and ``volumes`` was supplied.

    An exception raised for a single asset is printed, and that asset's score
    is recorded as ``np.nan``.

    Parameters:
    - factor: a ``DynamicFactorRuntime`` (its ``func`` is inspected and its
      ``raw`` text is used in error messages) or any plain callable.
    - prices, returns: per-asset inputs, indexed by asset position.
    - volumes: optional per-asset volumes, same length as ``prices``.

    Returns:
    - List with one entry per asset, in the same order as ``prices``.
    """
    assert len(prices) == len(
        returns
    ), "Input prices and returns should have same shape"
    if volumes is not None:
        assert len(prices) == len(volumes), "Volumes must match prices in length"

    # Inspect the underlying function so only accepted arguments are passed.
    is_runtime = isinstance(factor, DynamicFactorRuntime)
    signature = inspect.signature(factor.func if is_runtime else factor)
    if is_runtime:
        label = factor.raw
    elif hasattr(factor, "__name__"):
        label = factor.__name__
    else:
        label = str(factor)
    accepted = set(signature.parameters)

    # Two naming conventions exist for the returns argument; prefer "return_".
    if "return_" in accepted:
        returns_kw = "return_"
    elif "returns" in accepted:
        returns_kw = "returns"
    else:
        returns_kw = None
    pass_prices = "prices" in accepted
    pass_volume = volumes is not None and "volume" in accepted

    results = []
    for idx in range(len(prices)):
        call_kwargs = {}
        if pass_prices:
            call_kwargs["prices"] = prices[idx]
        if returns_kw is not None:
            call_kwargs[returns_kw] = returns[idx]
        if pass_volume:
            call_kwargs["volume"] = volumes[idx]

        try:
            value = factor(**call_kwargs)
        except Exception as e:
            print(f"[Runtime] [Error] Evaluating factor {label}: {e}")
            value = np.nan
        results.append(value)

    return results


def eval_factors(assets_scores, target_return):
    """Score one factor's cross-sectional predictions against realized returns.

    Parameters:
    - assets_scores: 1D array of factor scores, one per asset.
    - target_return: 1D array of realized returns for the same assets
      (the backtest loop passes ``target_return - 1``).

    Returns:
    - Dict with keys, in this order:
      - ``rank_ic``;
      - ``recall_N`` and ``precision_N`` for N in 20, 10, 5.
    """
    metrics = {"rank_ic": compute_rank_ic(assets_scores, target_return)}
    for top_n in (20, 10, 5):
        metrics[f"recall_{top_n}"] = recall_at_n(
            assets_scores, target_return, n=top_n
        )
        metrics[f"precision_{top_n}"] = movement_precision_at_n(
            assets_scores, target_return, n=top_n
        )
    return metrics


def build_arg_parser():
    """Build the command-line parser for the factor-portfolio backtest.

    Note: argparse %-formats help strings, so a literal percent sign must be
    written as ``%%``.

    Returns:
    - argparse.ArgumentParser with all backtest options registered.
    """
    parser = argparse.ArgumentParser(
        description="Stock Portfolio Optimization Parameters"
    )

    # NOTE(review): --win_size is not read anywhere in this script
    # (load_data is called without it); confirm whether it is still needed.
    parser.add_argument(
        "--win_size",
        type=int,
        default=30,
        help="Window size in trading days (default: 30)",
    )
    parser.add_argument(
        "--m",
        type=int,
        default=10,
        help="Number of stocks to hold at a time (default: 10)",
    )
    parser.add_argument(
        "--n_factors",
        type=int,
        default=5,
        help="Number of top-ranked factors combined at each step (default: 5)",
    )
    parser.add_argument("--search_steps", type=int, default=5, help="Search step")

    parser.add_argument(
        "--data_path",
        type=str,
        default="benchmark_dataset/csv/FF25new.csv",
        help="Path to the dataset CSV file (default: benchmark_dataset/csv/FF25new.csv)",
    )

    parser.add_argument(
        "--search_ckp",
        type=str,
        default="searching_records_FF25new.bin",
        help="Single searching checkpoint (used when --multi_ckp_mode is False)",
    )
    parser.add_argument(
        "--multi_ckp_mode",
        action="store_true",
        help="Enable if using multiple search checkpoints combined from --search_ckp_list",
    )
    parser.add_argument(
        "--search_ckp_list",
        type=str,
        default="",
        help="Comma-separated list of checkpoint files (used only when --multi_ckp_mode is True)",
    )

    parser.add_argument(
        "--trans_cost",
        type=float,
        default=0.0,
        help="Transaction cost rate per unit of turnover (e.g., 0.002 for 0.2%%)",
    )

    parser.add_argument(
        "--use_lowch",
        action="store_true",
        help="Reduce turnover by reusing the weights chosen at each search-frequency checkpoint (default: False)",
    )

    parser.add_argument(
        "--use_w2s",
        action="store_true",
        help="Weight holdings by their non-negative combined scores instead of equally (default: False)",
    )

    parser.add_argument(
        "--record_n",
        type=int,
        default=0,
        help="Only use the first n searching records (default: 0, i.e. all)",
    )

    parser.add_argument(
        "--record_ratio",
        type=float,
        default=1,
        help="Only use this leading fraction of the searching records (default: 1, i.e. all)",
    )

    parser.add_argument(
        "--exp_name",
        type=str,
        default="default",  # e.g. "gpt-4.1"
        help="Experiment prefix in output file names",
    )

    parser.add_argument(
        "--lag_one",
        action="store_true",
        help="If set, apply one-period lag to the data.",
    )
    return parser


if __name__ == "__main__":

    parser = build_arg_parser()

    # Parse arguments
    args = parser.parse_args()

    # ---- Load the return matrix ---------------------------------------------
    data = pd.read_csv(args.data_path)
    if "date" in data.columns:
        return_matrix = data.drop(columns=["date"]).values
        # Drop "date" by name so tickers stay aligned with return_matrix
        # columns even when "date" is not the first column.
        tickers = data.columns.drop("date")
        daily_date = data["date"]
    else:
        return_matrix = data.values
        tickers = data.columns
        daily_date = list(range(return_matrix.shape[0]))

    print(return_matrix.shape)

    data_batch = load_data(
        return_matrix=return_matrix, daily_date=daily_date, lag_oneday=args.lag_one
    )

    print(f"Cost: {args.trans_cost}")

    # ---- Load the search records --------------------------------------------
    # NOTE: pickle.load can execute arbitrary code; only load trusted checkpoints.
    if not args.multi_ckp_mode:
        with open(args.search_ckp, "rb") as f:
            search_record = pickle.load(f)
    else:
        # NOTE(review): record_n / record_ratio are passed to aggregate_factor
        # and then applied again below; confirm the double application is intended.
        search_record = aggregate_factor(
            args.search_ckp_list, args.record_n, args.record_ratio
        )

    print(f"Origin search records: {len(search_record)}")

    # Optionally keep only a prefix of the records (by count, then by fraction).
    if args.record_n > 0:
        search_record = search_record[: args.record_n]
    if args.record_ratio < 1:
        search_record = search_record[: int(len(search_record) * args.record_ratio)]

    num_search_record = len(search_record)
    print(f"Used search records: {num_search_record}")

    # ---- Backtest state -----------------------------------------------------
    warmup_steps = 30  # steps held as equal weight before factors are used
    init_portfolio_value = 1
    selected_indices_list = []
    base_portfolio_value_list = []
    daily_returns_list = []

    portfolio_value_list = []
    portfolio_value = init_portfolio_value
    baseline_value = init_portfolio_value
    factor_perd_list = []

    search_frequency = args.search_steps
    n_assets = args.m

    if args.use_w2s:
        print("Using weighted scores")

    backtest_record = []
    prev_indices = None
    prev_weights = None

    # Holdings frozen between checkpoints when --use_lowch is set.
    lowch_fixed_weights = None
    lowch_fixed_indices = None

    for step in range(len(data_batch) - 1):
        target_return = data_batch[step]["target_return"]
        input_prices = data_batch[step]["input_prices"]
        input_returns = data_batch[step]["input_returns"]
        # target_return is compounded multiplicatively below (value *= return,
        # daily return = return - 1), i.e. it is treated as gross returns.
        avg_daily_return = target_return.mean()

        # The equal-weight baseline is updated on every step.
        baseline_value *= avg_daily_return
        base_portfolio_value_list.append(baseline_value)

        # Warm-up: hold the equal-weight portfolio before factors are used.
        if step <= warmup_steps:
            portfolio_value *= avg_daily_return
            portfolio_value_list.append(portfolio_value)
            daily_returns_list.append(np.array(avg_daily_return) - 1)
            continue

        # ---- Pick the search record that drives this step -------------------
        if step == warmup_steps + 1:
            current_record = search_record[0]
            print(f"Step {step} using record index 0")

        if (step - 1) % search_frequency == 0:
            # Advance one record per search_frequency steps, clamped to the last.
            record_index = min(
                (step - 1 - warmup_steps) // search_frequency, num_search_record - 1
            )
            current_record = search_record[record_index]
            print(f"Step {step} using record index {record_index}")

        # ---- Per-step state -------------------------------------------------
        # Reset every field that feeds the daily report. This ensures:
        # - a failed step never reports values left over from an earlier step;
        # - the first strategy step cannot hit a NameError building the report.
        top_factor_names = []
        topn_factor_score = []
        assets_scores = np.array([])
        top_indices = np.array([], dtype=int)
        top_returns = np.array([])
        weights = np.array([])
        cost = 0.0
        portfolio_return = avg_daily_return
        fresh_selection = False  # True when top_indices were ranked this step
        step_succeeded = False

        try:
            filtered_performance = filter_factor_versions(
                current_record["portfolop_performance"],
                quality_column="final_value",
            )

            llm_gen_factor_dict = dict(current_record["current_gen_factors"])
            # NOTE(review): ``^`` is a symmetric difference. It keeps:
            # - generated factors that are not in init_expression_lib;
            # - library factors that were not generated.
            # If "generated minus library" was intended, this should be ``-``.
            valid_factors = set(llm_gen_factor_dict.keys()) ^ set(
                init_expression_lib.keys()
            )
            if len(valid_factors) == 0:
                valid_factors = set(llm_gen_factor_dict.keys())

            # Keep pool rows whose index (factor name) is valid, in the order
            # filter_factor_versions returned them, and take the first n_factors.
            filtered_result = filtered_performance.loc[
                filtered_performance.index.intersection(valid_factors)
            ]
            top_factor_names = list(filtered_result.index[: max(args.n_factors, 0)])

            # Score every asset with each selected factor, scaled to [-1, 1].
            for factor_name in top_factor_names:
                if (
                    factor_name in llm_gen_factor_dict
                    and factor_name not in init_expression_lib
                ):
                    expression = llm_gen_factor_dict[factor_name]
                    parsed_factor = DynamicFactorRuntime(
                        str2fun(expression), expression
                    )
                else:
                    parsed_factor = init_expression_lib[factor_name]
                raw_scores = [
                    safe_exec_factor(
                        parsed_factor,
                        prices=input_prices[i],
                        returns=input_returns[i],
                    )
                    for i in range(len(input_prices))
                ]
                topn_factor_score.append(normalize_to_neg1_pos1(raw_scores))

            # Combined score: per-asset mean of the normalized factor scores.
            if topn_factor_score:
                assets_scores = np.nan_to_num(
                    np.array(topn_factor_score).mean(0),
                    nan=0.0,
                    posinf=0.0,
                    neginf=0.0,
                )

            # With --use_lowch, weights chosen on a checkpoint step
            # (step % search_frequency == 0) are reused until the next one.
            # NOTE(review): records refresh on (step - 1) % search_frequency == 0,
            # one step after this checkpoint. A newly loaded record therefore does
            # not affect lowch holdings until the following checkpoint; confirm
            # this offset is intended.
            reuse_lowch = (
                args.use_lowch
                and step % search_frequency != 0
                and lowch_fixed_weights is not None
                and lowch_fixed_indices is not None
            )
            if reuse_lowch:
                top_indices = lowch_fixed_indices
                weights = lowch_fixed_weights
                print(
                    f"Step {step} reusing lowch-fixed weights from step {step - (step % search_frequency)}"
                )
            else:
                if len(topn_factor_score) == 0:
                    raise ValueError("topn_factor_score is empty")
                if len(assets_scores) < n_assets:
                    raise ValueError("Number of assets in score < n_assets")

                # Highest combined scores first.
                top_indices = np.argsort(assets_scores)[-n_assets:][::-1]
                fresh_selection = True

                weights = np.ones(n_assets) / n_assets
                if args.use_w2s:
                    # Weights proportional to the positive part of the selected
                    # scores; equal weights if none is positive.
                    selected_scores = np.nan_to_num(
                        np.maximum(assets_scores[top_indices], 0), nan=0.0
                    )
                    score_total = np.sum(selected_scores)
                    if score_total > 0:
                        weights = selected_scores / score_total

            # Transaction cost: trans_cost times the L1 turnover between the
            # previous and current holdings over the full asset universe.
            if (
                prev_weights is not None
                and prev_indices is not None
                and args.trans_cost > 0
            ):
                prev_weight_vec = np.zeros(len(target_return))
                curr_weight_vec = np.zeros(len(target_return))
                prev_weight_vec[prev_indices] = prev_weights
                curr_weight_vec[top_indices] = weights
                turnover = np.sum(np.abs(curr_weight_vec - prev_weight_vec))
                cost = args.trans_cost * turnover

            top_returns = target_return[top_indices]
            portfolio_return = np.dot(top_returns, weights) * (1 - cost)

        except Exception as e:
            error_trace = traceback.format_exc()
            logging.error(f"Error in step {step}: {e} {error_trace}")
            # Fall back to the equal-weight portfolio for this step and drop
            # any partially computed selection.
            top_indices = np.array([], dtype=int)
            top_returns = np.array([])
            weights = np.array([])
            cost = 0.0
            portfolio_return = avg_daily_return
            selected_indices_list.append([])
        else:
            # Commit state only after the whole step succeeded.
            step_succeeded = True
            prev_indices = top_indices
            prev_weights = weights
            if fresh_selection:
                selected_indices_list.append(top_indices.tolist())
                if args.use_lowch and (step % search_frequency == 0):
                    lowch_fixed_indices = top_indices.copy()
                    lowch_fixed_weights = weights.copy()
                    print(f"Step {step} saved new lowch-fixed weights.")

        # Apply exactly one return per step, whichever path produced it.
        portfolio_value *= portfolio_return
        portfolio_value_list.append(portfolio_value)
        daily_returns_list.append(np.array(portfolio_return) - 1)

        # Per-factor diagnostics are evaluated outside the trading try-block,
        # so a metric failure cannot apply the day's return twice.
        if step_succeeded:
            try:
                factor_perd_list.append(
                    [
                        eval_factors(scores, target_return - 1)
                        for scores in topn_factor_score
                    ]
                )
            except Exception as e:
                logging.error(
                    f"Error evaluating factors in step {step}: {e} {traceback.format_exc()}"
                )

        daily_report = {
            "Date": data_batch[step]["output_date"],
            "Portfolio Value": portfolio_value,
            "Baseline Value": baseline_value,
            "Daily Return": portfolio_return,
            "Best Factor": top_factor_names,
            "Selected Indices": top_indices.tolist(),
            "Selected Tickers": tickers[top_indices],
            "Selected Assets Returns": top_returns,
            "Portfolio Weight": weights,
            "All Scores": assets_scores,
        }

        backtest_record.append(daily_report)

        print(
            f"Step {step} Date {data_batch[step]['output_date']} Value {portfolio_value} Baseline {baseline_value} Daily Return {portfolio_return} Cost {cost}"
        )

    # ---- Persist raw results ------------------------------------------------
    # Shared suffix of every output file name, computed once.
    dataset_stem = os.path.basename(args.data_path).split(".")[0]
    result_stem = (
        f"{args.n_factors}_result_EFS_{dataset_stem}_{n_assets}_{args.exp_name}"
    )

    factor_data_name = f"factor_{result_stem}.bin"
    with open(factor_data_name, "wb") as f:
        pickle.dump(factor_perd_list, f)

    report_data_name = f"backtest_{result_stem}.bin"
    with open(report_data_name, "wb") as f:
        pickle.dump(backtest_record, f)

    # ---- Portfolio metrics --------------------------------------------------
    perf = calculate_performance_metrics(daily_returns_list, portfolio_value_list)
    perf["CW"] = portfolio_value_list[-1]  # final portfolio value

    # ---- Factor diagnostics -------------------------------------------------
    # Flatten per-step lists of per-factor metric dicts.
    all_test_results = [res for step_results in factor_perd_list for res in step_results]

    # NaN rank ICs count as 0, pulling the average toward zero.
    all_rank_ic = [
        0 if np.isnan(res["rank_ic"]) else res["rank_ic"] for res in all_test_results
    ]
    n_ic = len(all_rank_ic)
    mean_rank_ic = np.mean(all_rank_ic) if n_ic > 0 else np.nan
    std_rank_ic = np.std(all_rank_ic, ddof=1) if n_ic > 1 else np.nan

    # ICIR scaled by sqrt of the observation count, capped at 252 (one year).
    # It is undefined (NaN) with fewer than two ICs or zero dispersion.
    t_annualized = min(n_ic, 252)
    if n_ic > 1 and std_rank_ic > 0:
        icir = (mean_rank_ic / std_rank_ic) * np.sqrt(t_annualized)
    else:
        icir = np.nan

    perf["test_avg_rank_ic"] = mean_rank_ic
    perf["test_icir"] = icir
    for metric in (
        "recall_20",
        "precision_20",
        "recall_10",
        "precision_10",
        "recall_5",
        "precision_5",
    ):
        values = [res[metric] for res in all_test_results]
        perf[f"test_{metric}"] = np.mean(values) if values else np.nan

    print(perf)

    # ---- Write summary CSV and cumulative-wealth plot -----------------------
    output_file = f"factor_{result_stem}.csv"
    with open(output_file, "w", newline="") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["Metric", "Value"])
        writer.writerows(perf.items())

    figure_name = f"factor_{result_stem}.png"
    plot_cw(
        data_batch,
        portfolio_value_list,
        base_portfolio_value_list,
        figure_name=figure_name,
    )
