import pandas as pd

import numpy as np
import pickle
from collections import defaultdict

from factors.expression_lib import get_implementation_source, init_expression_lib
from factors.metric import compute_rank_ic, recall_at_n
from portfolio.backtest import eval_scores_topn_portfolio_performance
import inspect

import warnings


def compute_all_assets_scores(factor, prices, returns, volumes=None):
    """Apply ``factor`` to every asset and collect one score per asset.

    Which keyword arguments ``factor`` receives is decided from its signature:

    * ``prices``  -> ``prices[i]``
    * ``return_`` -> ``returns[i]`` (checked first), otherwise ``returns`` -> ``returns[i]``
    * ``volume``  -> ``volumes[i]``, only when ``volumes`` is not None

    For a ``DynamicFactorRuntime`` the wrapped ``func`` is inspected and its
    ``raw`` text is used as the name in log messages.

    A per-asset call that raises is printed and recorded as ``np.nan``, so a
    single failing asset does not abort the remaining ones.

    Returns:
        list: one score per asset, in the same order as ``prices``.
    """
    assert len(prices) == len(
        returns
    ), "Input prices and returns should have same shape"
    if volumes is not None:
        assert len(prices) == len(volumes), "Volumes must match prices in length"

    # Resolve which parameters the factor accepts, plus a readable label.
    if isinstance(factor, DynamicFactorRuntime):
        param_names = set(inspect.signature(factor.func).parameters)
        label = factor.raw
    else:
        param_names = set(inspect.signature(factor).parameters)
        label = factor.__name__ if hasattr(factor, "__name__") else str(factor)

    # (keyword, per-asset sequence) pairs to index for every asset.
    bindings = []
    if "prices" in param_names:
        bindings.append(("prices", prices))
    if "return_" in param_names:
        bindings.append(("return_", returns))
    elif "returns" in param_names:
        bindings.append(("returns", returns))
    if volumes is not None and "volume" in param_names:
        bindings.append(("volume", volumes))

    scores = []
    for idx in range(len(prices)):
        call_kwargs = {key: series[idx] for key, series in bindings}
        try:
            value = factor(**call_kwargs)
        except Exception as err:
            print(f"[Runtime] [Error] Evaluating factor {label}: {err}")
            value = np.nan
        scores.append(value)

    return scores


def run_factor_eval(
    factor_name,
    factor_expr,
    data,
    step,
    min_steps=30,
    T=60,
    TOL=1e-8,
    back_test_asset_num=10,
):
    """Evaluate one factor over the rolling window that ends just before ``step``.

    The window is ``range(max(step - T, min_steps), step)``. For each step the
    factor is scored across all assets and compared with that step's
    ``target_return``. A step is recorded as a failure (rank IC 0.0, recall 0.0,
    all-zero scores) when scoring raises, emits a ``RuntimeWarning``, produces
    scores all within ``TOL`` of the first score, or yields a NaN rank IC or a
    recall@20 of zero.

    Args:
        factor_name: label used in log / error messages.
        factor_expr: callable passed to ``compute_all_assets_scores``.
        data: step-indexed records exposing ``"input_prices"``,
            ``"input_returns"`` and ``"target_return"``.
        step: exclusive end of the evaluation window.
        min_steps: earliest step the window may start at.
        T: maximum window length.
        TOL: spread below which the scores are treated as constant.
        back_test_asset_num: number of top-scored assets in the backtest.

    Returns:
        ``(rank_ic_list, recall_20_list, daily_scores, error_count,
        portfolio_perf)``. The three lists have one entry per window step.
        ``portfolio_perf`` is the result of
        ``eval_scores_topn_portfolio_performance`` with ``init_value=100``.
    """
    rank_ic_list = []
    recall_20_list = []
    daily_scores = []
    error_count = 0

    # At most T steps ending just before ``step``, never earlier than min_steps.
    window = range(max(step - T, min_steps), step)

    # Cache the target return list for performance eval
    target_return_lists = [data[i]["target_return"] for i in window]

    for i in window:
        record = data[i]
        input_prices = record["input_prices"]
        input_returns = np.array(record["input_returns"])
        target_return = record["target_return"]

        try:
            with warnings.catch_warnings(record=True) as caught:
                warnings.simplefilter("always", category=RuntimeWarning)
                scores = compute_all_assets_scores(
                    factor_expr, input_prices, input_returns
                )

                # Numerical warnings (overflow, divide-by-zero, ...) invalidate the step.
                for warning in caught:
                    if issubclass(warning.category, RuntimeWarning):
                        raise RuntimeWarning(
                            f"Warning in factor {factor_name}: {warning.message}"
                        )

            # compute_all_assets_scores returns a plain list; ``list - float``
            # raises TypeError when the first score is a Python float, so
            # convert to a float array before doing arithmetic on it.
            scores = np.asarray(scores, dtype=float)
            if np.all(np.abs(scores - scores[0]) < TOL):
                raise ValueError(f"Factor {factor_name} has nearly constant scores")

            rank_ic = compute_rank_ic(scores, target_return)
            recall_20 = recall_at_n(scores, target_return, n=20)

            if np.isnan(rank_ic) or recall_20 == 0:
                raise ValueError(
                    f"Invalid metric for {factor_name} (NaN rank IC or recall@20 = 0)"
                )

        except Exception as e:
            print(f"[WARNING] Factor {factor_name} failed on step {i}: {e}")
            error_count += 1
            rank_ic = 0.0
            recall_20 = 0.0
            # Force float zeros even if target_return holds integers.
            scores = np.full_like(target_return, 0.0, dtype=float)

        rank_ic_list.append(rank_ic)
        recall_20_list.append(recall_20)
        daily_scores.append(scores)

    # Evaluate portfolio performance from top-N selection
    portfolio_perf = eval_scores_topn_portfolio_performance(
        daily_scores, target_return_lists, n=back_test_asset_num, init_value=100
    )

    return rank_ic_list, recall_20_list, daily_scores, error_count, portfolio_perf


class DynamicFactorRuntime:
    """Callable wrapper that pairs a factor callable with its source text.

    ``func`` is what actually runs when the wrapper is called. ``raw`` holds the
    ``func_str`` passed at construction, i.e. the textual form of the factor.
    """

    def __init__(self, func, func_str):
        self.raw = func_str
        self.func = func

    def __call__(self, *positional, **named):
        """Forward all arguments unchanged to the wrapped callable."""
        wrapped = self.func
        return wrapped(*positional, **named)


class FactorRegister:
    """Pool of factors that can be evaluated over a rolling window of ``data``.

    The pool has two libraries, both dicts mapping a factor name to a callable:
    ``init_factor`` (the built-in factors passed at construction) and
    ``gen_factor_lib`` (factors added later via ``add_llm_gen_factor``).
    ``data`` is indexed by step, and each record is read for
    ``"input_prices"``, ``"input_returns"`` and ``"target_return"``.
    """

    def __init__(
        self, data, init_factors, min_steps, warmup_steps, max_stepsmax_gen_size=10
    ):
        # NOTE(review): ``max_stepsmax_gen_size`` looks like two names
        # (``max_steps`` / ``max_gen_size``) fused by a typo. It is not used
        # anywhere in this class, but the name is kept so existing callers work.
        self.data = data
        self.init_factor = init_factors
        self.min_steps = min_steps
        self.warmup_steps = warmup_steps  # stored; not read by this class
        self.gen_factor_lib = {}

        self.pool_size = len(self.init_factor)

    def _refresh_pool_size(self):
        """Recompute ``pool_size`` after either library changed."""
        self.pool_size = len(self.gen_factor_lib) + len(self.init_factor)

    def add_llm_gen_factor(self, factor_name, factor_object):
        """Register a generated factor under ``factor_name``.

        Names already used by either library are rejected. Per-factor
        evaluation results are keyed by name only, so a generated factor that
        shared a name with an init factor would get its results mixed in with
        the init factor's.
        """
        if factor_name in self.gen_factor_lib or factor_name in self.init_factor:
            print(f"Factor: {factor_name} is already in pool, pass.")
            return
        self.gen_factor_lib[factor_name] = factor_object
        self._refresh_pool_size()

    def clean_factor_pool_by_performance(
        self, factor_quality, portfolio_performance, max_size=200, keep_top_n=150
    ):
        """Shrink the generated library once it exceeds ``max_size`` factors.

        A generated factor survives if its name is among the first
        ``keep_top_n`` rows of either ``factor_quality`` or
        ``portfolio_performance``. Both frames are assumed to be sorted
        best-first, as ``_gen_factor_performance_summary(sorted=True)``
        produces them. Init factors are never removed here.
        """
        if len(self.gen_factor_lib) <= max_size:
            print(
                f"Current generated factor pool size {len(self.gen_factor_lib)} is within limit {max_size}, no cleaning needed."
            )
            return
        print(
            f"Current generated factor pool size {len(self.gen_factor_lib)} exceeds limit {max_size}, cleaning needed."
        )

        # Union of the top-``keep_top_n`` names by quality and by portfolio result.
        factors_to_keep = set(factor_quality.index[:keep_top_n]) | set(
            portfolio_performance.index[:keep_top_n]
        )

        to_delete = set(self.gen_factor_lib) - factors_to_keep
        for name in to_delete:
            del self.gen_factor_lib[name]
        self._refresh_pool_size()

        print(
            f"Cleaned {len(to_delete)} factors. New generated factor pool size: {len(self.gen_factor_lib)}."
        )

    def _gen_factor_performance_summary(
        self,
        rankic_performance_dict,
        portfolio_performance_dict,
        recall_performance_dict=None,
        sorted=True,
    ):
        """Build per-factor summary tables from raw evaluation results.

        Args:
            rankic_performance_dict: ``{"init": {...}, "gen": {...}}`` mapping a
                factor name to its list of per-step rank ICs.
            portfolio_performance_dict: same grouping, mapping a factor name to
                a backtest dict with ``mean_return``, ``std_return``,
                ``sharpe_ratio``, ``max_drawdown`` and ``portfolio_values``.
            recall_performance_dict: optional, same shape as
                ``rankic_performance_dict``, holding per-step recall@20 values.
            sorted: when True, order ``quality_df`` by ``mean_rankic`` and
                ``portfolio_df`` by ``final_value``, both descending. The name
                shadows the builtin but is kept for API compatibility.

        Returns:
            ``(quality_df, portfolio_df)``, each with one row per factor.
        """
        all_quality = defaultdict(dict)
        for group in ("init", "gen"):
            for factor_name, rankic_list in rankic_performance_dict[group].items():
                all_quality[factor_name]["mean_rankic"] = np.mean(rankic_list)

        if recall_performance_dict is not None:
            for group in ("init", "gen"):
                for factor_name, recall_list in recall_performance_dict[group].items():
                    all_quality[factor_name]["mean_recall@20"] = np.mean(recall_list)

        quality_df = pd.DataFrame(dict(all_quality)).T  # factors as rows
        if sorted and "mean_rankic" in quality_df.columns:
            quality_df = quality_df.sort_values(by="mean_rankic", ascending=False)

        all_perf = {}
        for group in ("init", "gen"):
            for factor_name, perf_dict in portfolio_performance_dict[group].items():
                values = perf_dict["portfolio_values"]
                # ``len`` instead of truthiness so NumPy arrays are accepted too.
                has_values = values is not None and len(values) > 0
                all_perf[factor_name] = {
                    "mean_return": perf_dict["mean_return"],
                    "std_return": perf_dict["std_return"],
                    "sharpe_ratio": perf_dict["sharpe_ratio"],
                    "max_drawdown": perf_dict["max_drawdown"],
                    "final_value": values[-1] if has_values else np.nan,
                }

        portfolio_df = pd.DataFrame(all_perf).T
        if sorted and "final_value" in portfolio_df.columns:
            portfolio_df = portfolio_df.sort_values(by="final_value", ascending=False)

        return quality_df, portfolio_df

    def get_factor_desc(self, factor_names):
        """Return ``(name, source)`` pairs for the requested factors.

        Init factors are resolved with ``get_implementation_source`` against
        ``init_expression_lib``. Generated factors return their ``.raw``
        attribute. Names found in neither library are printed and skipped.
        """
        factor_str_lists = []
        for factor in factor_names:
            if factor in self.init_factor:
                source = get_implementation_source(factor, init_expression_lib)
            elif factor in self.gen_factor_lib:
                source = self.gen_factor_lib[factor].raw
            else:
                print(
                    f"Doesn't support factor {factor}, it may be removed or not in pool."
                )
                continue
            factor_str_lists.append((factor, source))
        return factor_str_lists

    @staticmethod
    def _score_factor_on_step(
        factor_name, factor_expr, input_prices, input_returns, target_return, tol
    ):
        """Score one factor on one step.

        Returns ``(rank_ic, recall_20, scores)``. Raises when the step is
        invalid: a ``RuntimeWarning`` was emitted, the scores are all within
        ``tol`` of the first one, the rank IC is NaN, or recall@20 is zero.
        """
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always", category=RuntimeWarning)
            scores = compute_all_assets_scores(factor_expr, input_prices, input_returns)

            for warning in caught:
                if issubclass(warning.category, RuntimeWarning):
                    raise RuntimeWarning(
                        f"Warning in factor {factor_name}: {warning.message}"
                    )

        # compute_all_assets_scores returns a plain list; ``list - float``
        # raises TypeError when the first score is a Python float.
        scores = np.asarray(scores, dtype=float)
        if np.all(np.abs(scores - scores[0]) < tol):
            raise ValueError(f"Factor {factor_name} has nearly constant scores")

        rank_ic = compute_rank_ic(scores, target_return)
        # n=20 so the value matches its recall@20 label (and run_factor_eval).
        recall_20 = recall_at_n(scores, target_return, n=20)

        if np.isnan(rank_ic) or recall_20 == 0:
            raise ValueError(
                f"Invalid metric for {factor_name} (NaN rank IC or recall@20 = 0)"
            )
        return rank_ic, recall_20, scores

    @staticmethod
    def _degenerate_flags(daily_scores, atol=1e-5):
        """Return ``(has_nan, is_constant)`` for a factor's per-step scores.

        ``is_constant`` is only checked when there is no NaN and there are at
        least two steps. It is True when every step's scores are close to the
        first step's scores (``np.allclose`` with ``atol``; absolute difference
        ``< atol`` for scalar entries).
        """
        has_nan = any(
            np.isnan(score).any() if hasattr(score, "__iter__") else np.isnan(score)
            for score in daily_scores
        )
        is_constant = False
        if not has_nan and len(daily_scores) > 1:
            first = daily_scores[0]
            if hasattr(first, "__iter__"):
                is_constant = all(
                    np.allclose(score, first, atol=atol) for score in daily_scores
                )
            else:
                is_constant = all(np.abs(score - first) < atol for score in daily_scores)
        return has_nan, is_constant

    def get_current_factor_pool_performance(self, step, T=60, back_test_asset_num=10):
        """Evaluate every pooled factor before ``step``, prune bad ones, summarise.

        The window is ``range(max(step - T, self.min_steps), step)``. A step
        counts as a failure for a factor (rank IC / recall 0.0, all-zero
        scores) when ``_score_factor_on_step`` raises for it.

        Pruning mutates ``self.init_factor`` / ``self.gen_factor_lib``:

        * A factor failing on more than 5% of steps is removed from its library.
        * A generated factor whose daily scores contain NaN, never change, or
          whose backtest raises ``AssertionError`` is removed from the
          generated library.
        * Init factors with NaN or constant scores stay in the pool but get no
          portfolio entry.

        Factors removed from the pool are left out of both returned summaries.

        Returns:
            ``(quality_df, portfolio_df)`` from ``_gen_factor_performance_summary``.
        """
        rankic_performance_dict = {"init": {}, "gen": {}}
        recall_performance_dict = {"init": {}, "gen": {}}
        portfolio_performance_dict = {"init": {}, "gen": {}}
        daily_assets_score = {}

        # Both libraries are evaluated identically; results stay grouped.
        all_factors = {"init": self.init_factor, "gen": self.gen_factor_lib}

        print(
            "Current factor pool size:",
            len(all_factors["gen"]) + len(all_factors["init"]),
        )

        for group, factor_lib in all_factors.items():
            for factor_name in factor_lib:
                rankic_performance_dict[group][factor_name] = []
                recall_performance_dict[group][factor_name] = []
                daily_assets_score[factor_name] = []

        TOL = 1e-5
        error_counter = defaultdict(int)
        total_counter = defaultdict(int)

        # At most T steps ending just before ``step``, never earlier than min_steps.
        window = range(max(step - T, self.min_steps), step)

        for i in window:
            record = self.data[i]
            input_prices = record["input_prices"]
            input_returns = np.array(record["input_returns"])
            target_return = record["target_return"]

            for group, factor_lib in all_factors.items():
                for factor_name, factor_expr in factor_lib.items():
                    total_counter[factor_name] += 1
                    try:
                        rank_ic, recall_20, scores = self._score_factor_on_step(
                            factor_name,
                            factor_expr,
                            input_prices,
                            input_returns,
                            target_return,
                            TOL,
                        )
                    except Exception as e:
                        print(f"[WARNING] Factor {factor_name} failed on step {i}: {e}")
                        error_counter[factor_name] += 1
                        rank_ic, recall_20 = 0.0, 0.0
                        scores = np.full_like(target_return, 0.0, dtype=float)

                    rankic_performance_dict[group][factor_name].append(rank_ic)
                    recall_performance_dict[group][factor_name].append(recall_20)
                    daily_assets_score[factor_name].append(scores)

        # Names dropped from the pool during this call; excluded from the summaries.
        removed = set()

        # Post-processing: remove factors failing on more than 5% of steps.
        for factor_name, n_errors in error_counter.items():
            error_rate = n_errors / total_counter[factor_name]
            if error_rate > 0.05:
                print(
                    f"[NOTICE] Removed factor {factor_name} due to high error rate ({error_rate:.1%})"
                )
                self.gen_factor_lib.pop(factor_name, None)
                self.init_factor.pop(factor_name, None)
                removed.add(factor_name)
            else:
                print(
                    f"[INFO] Factor {factor_name} had {n_errors} failures ({error_rate:.1%}) — tolerated."
                )

        # Evaluate portfolio performance
        target_return_lists = [self.data[i]["target_return"] for i in window]

        for factor_name, daily_scores in daily_assets_score.items():
            if factor_name in removed:
                # Already pruned; backtesting it would misfile it under "gen".
                continue

            has_nan, is_constant = self._degenerate_flags(daily_scores)
            if has_nan or is_constant:
                if has_nan:
                    print(f"{factor_name} has NaN values and will be removed!")
                if is_constant:
                    print(
                        f"{factor_name} produces constant values and will be removed!"
                    )
                # Only generated factors are removed for this reason.
                if factor_name in self.gen_factor_lib:
                    del self.gen_factor_lib[factor_name]
                    removed.add(factor_name)
                continue

            try:
                perf = eval_scores_topn_portfolio_performance(
                    daily_scores,
                    target_return_lists,
                    n=back_test_asset_num,
                    init_value=100,
                )
            except AssertionError as e:
                print(
                    f"Error when computing performance on {factor_name} error {e}, it will be removed!"
                )
                if factor_name in self.gen_factor_lib:
                    del self.gen_factor_lib[factor_name]
                    removed.add(factor_name)
                continue

            group = "init" if factor_name in self.init_factor else "gen"
            portfolio_performance_dict[group][factor_name] = perf

        self._refresh_pool_size()

        # Keep the quality summary consistent with the pruned pool.
        for results in (rankic_performance_dict, recall_performance_dict):
            for per_factor in results.values():
                for factor_name in removed:
                    per_factor.pop(factor_name, None)

        return self._gen_factor_performance_summary(
            rankic_performance_dict,
            portfolio_performance_dict,
            recall_performance_dict=recall_performance_dict,
            sorted=True,
        )

