import pandas as pd
import numpy as np


def calculate_performance_metrics(daily_returns_list, portfolio_value_list):
    """Compute Sharpe ratio, 95% CVaR and maximum drawdown from daily returns.

    Args:
        daily_returns_list: Sequence of simple (net) daily returns, e.g. 0.01
            for +1%. Wealth is compounded as ``cumprod(1 + r)``.
        portfolio_value_list: Accepted for interface compatibility; it is not
            used. Drawdown is derived from ``daily_returns_list`` instead.

    Returns:
        dict with keys:
            "sharpe_ratio": mean daily excess return / std of daily returns
                (risk-free rate 0); 0.0 when the std is 0.
            "annu_sharpe_ratio": ``sharpe_ratio * sqrt(252)``.
            "cvar_95": negated mean of the worst 5% of returns (at least one
                observation is always used); positive means a loss.
            "max_drawdown": largest peak-to-trough decline of compounded
                wealth, as a POSITIVE fraction. The initial wealth of 1 counts
                as the first peak.

    Raises:
        ValueError: if ``daily_returns_list`` is empty.
    """
    daily_returns = np.asarray(daily_returns_list, dtype=float)
    if daily_returns.size == 0:
        raise ValueError("daily_returns_list must contain at least one return")

    # 1. Sharpe ratio (risk-free rate assumed zero; adjust if needed).
    risk_free_rate = 0.0
    std = np.std(daily_returns)
    # Guard against a zero-volatility series, matching the other evaluators.
    sharpe_ratio = np.mean(daily_returns - risk_free_rate) / std if std > 0 else 0.0
    annualized_sharpe = sharpe_ratio * np.sqrt(252)  # 252 trading days per year

    # 2. CVaR (expected shortfall) at the 95% confidence level.
    confidence_level = 0.95
    sorted_returns = np.sort(daily_returns)
    # Keep at least one observation: with < 20 returns the 5% tail would be
    # empty and the mean would be NaN.
    tail_size = max(1, int((1 - confidence_level) * len(sorted_returns)))
    cvar = -np.mean(sorted_returns[:tail_size])

    # 3. Maximum drawdown on the compounded wealth curve (starting wealth 1).
    wealth = np.cumprod(1 + daily_returns)
    # Include the starting wealth as the first peak so a first-day loss counts.
    running_peak = np.maximum(np.maximum.accumulate(wealth), 1.0)
    drawdown = (running_peak - wealth) / running_peak
    max_drawdown = np.max(drawdown)

    return {
        "sharpe_ratio": sharpe_ratio,
        "annu_sharpe_ratio": annualized_sharpe,
        "cvar_95": cvar,
        "max_drawdown": max_drawdown,
    }


def round_all_results(results_dict, ndigits=5):
    """Return a copy of ``results_dict`` with every float rounded.

    Floats nested inside lists or dicts are rounded recursively. Both Python
    floats and numpy floating scalars are rounded (``np.float32`` is not a
    ``float`` subclass and would otherwise slip through). Other values, such
    as ints and strings, are returned unchanged.

    Args:
        results_dict: Mapping of result names to values.
        ndigits: Number of decimal places to keep (default 5).

    Returns:
        A new dict with the same keys and rounded values.
    """

    def round_nested(value):
        if isinstance(value, list):
            return [round_nested(v) for v in value]
        if isinstance(value, dict):
            return {k: round_nested(v) for k, v in value.items()}
        if isinstance(value, (float, np.floating)):
            return round(value, ndigits)
        return value

    return {k: round_nested(v) for k, v in results_dict.items()}


def eval_scores_topn_portfolio_performance(
    assets_scores_lists, target_return_lists, n=5, init_value=100, reverse=False
):
    """
    Evaluate an equal-weighted portfolio of the n best- (or worst-) scoring assets.

    Each day the selected assets are held with equal weight. The portfolio
    value is multiplied by their average gross return.

    Args:
        assets_scores_lists (List[np.ndarray]): Per-asset scores for each day.
        target_return_lists (List[np.ndarray]): Per-asset GROSS returns
            (1 + r) for each day, aligned with ``assets_scores_lists``.
        n (int): Number of assets to hold. Must be >= 1. If it exceeds the
            number of assets on a day, all assets are held.
        init_value (float): Starting portfolio value.
        reverse (bool): If False (default), hold the top-n highest-scoring
            assets. If True, hold the bottom-n lowest-scoring assets.

    Returns:
        dict (floats rounded to 5 decimals): {
            'portfolio_values': List of portfolio values after each day,
            'daily_returns': List of daily net portfolio returns (gross - 1),
            'mean_return': Float,
            'std_return': Float,
            'sharpe_ratio': Float (daily; 0 when std_return is 0),
            'max_drawdown': Float, min of (value - running peak) / running peak
                with init_value counted as the first peak (always <= 0),
            'selected_indices': List of lists, each inner list contains the
                selected asset indices for that day (best-ranked first),
            'final_value': Float, portfolio value after the last day,
        }

    Raises:
        AssertionError: if scores and returns have different shapes.
        ValueError: if ``n < 1`` or there are no days to evaluate.
    """
    scores_shape = np.shape(assets_scores_lists)
    returns_shape = np.shape(target_return_lists)
    assert (
        scores_shape == returns_shape
    ), f"Wrong shape scores: {scores_shape} returns: {returns_shape} doesn't match"
    if n < 1:
        # A slice of [-0:] would silently select every asset.
        raise ValueError(f"n must be >= 1, got {n}")

    portfolio_values_list = []
    daily_returns_list = []
    selected_indices_list = []
    cum_value = init_value

    for scores, target_return in zip(assets_scores_lists, target_return_lists):
        order = np.asarray(scores).argsort()  # ascending by score
        k = min(n, len(order))  # cannot hold more assets than exist
        if reverse:
            selected = order[:k]  # bottom-k, lowest score first
        else:
            selected = order[-k:][::-1]  # top-k, highest score first
        selected_indices_list.append(selected.tolist())

        selected_returns = np.asarray(target_return)[selected]
        weights = np.ones(k) / k  # equal weight
        portfolio_return = np.dot(selected_returns, weights)

        cum_value *= portfolio_return
        daily_returns_list.append(portfolio_return - 1)
        portfolio_values_list.append(cum_value)

    if not portfolio_values_list:
        raise ValueError("target_return_lists must contain at least one day")

    daily_returns_array = np.array(daily_returns_list)
    mean_return = daily_returns_array.mean()
    std_return = daily_returns_array.std()
    sharpe_ratio = mean_return / std_return if std_return > 0 else 0

    portfolio_array = np.array(portfolio_values_list)
    # Count the starting capital as the first peak so a loss on day one is a drawdown.
    peak = np.maximum(np.maximum.accumulate(portfolio_array), init_value)
    drawdown = (portfolio_array - peak) / peak
    max_drawdown = drawdown.min()

    results = {
        "portfolio_values": portfolio_values_list,
        "daily_returns": daily_returns_list,
        "mean_return": mean_return,
        "std_return": std_return,
        "sharpe_ratio": sharpe_ratio,
        "max_drawdown": max_drawdown,
        "selected_indices": selected_indices_list,
        "final_value": cum_value,
    }

    return round_all_results(results)


def eval_equal_weight_portfolio_performance(target_return_lists, init_value=100):
    """
    Evaluate an equal-weighted portfolio of all assets.

    Each day the portfolio value is multiplied by the average gross return
    across assets.

    Args:
        target_return_lists (List[np.ndarray]): Per-asset GROSS returns
            (1 + r) for each day. Any array-like row is accepted.
        init_value (float): Starting value of the portfolio.

    Returns:
        dict (floats rounded to 5 decimals): {
            'portfolio_values': List of portfolio values after each day,
            'daily_returns': List of daily net portfolio returns (avg gross - 1),
            'mean_return': Float,
            'std_return': Float,
            'sharpe_ratio': Float (daily; 0 when std_return is 0),
            'max_drawdown': Float, min of (value - running peak) / running peak
                with init_value counted as the first peak (always <= 0),
            'final_value': Float, portfolio value after the last day,
        }

    Raises:
        ValueError: if there are no days to evaluate.
    """
    portfolio_values_list = []
    daily_returns_list = []
    cum_value = init_value

    for target_return in target_return_lists:
        avg_daily_return = np.mean(target_return)  # also accepts plain lists
        cum_value *= avg_daily_return
        daily_returns_list.append(avg_daily_return - 1)
        portfolio_values_list.append(cum_value)

    if not portfolio_values_list:
        raise ValueError("target_return_lists must contain at least one day")

    daily_returns_array = np.array(daily_returns_list)
    mean_return = daily_returns_array.mean()
    std_return = daily_returns_array.std()
    sharpe_ratio = mean_return / std_return if std_return > 0 else 0

    portfolio_array = np.array(portfolio_values_list)
    # Count the starting capital as the first peak so a loss on day one is a drawdown.
    peak = np.maximum(np.maximum.accumulate(portfolio_array), init_value)
    drawdown = (portfolio_array - peak) / peak
    max_drawdown = drawdown.min()

    results = {
        "portfolio_values": portfolio_values_list,
        "daily_returns": daily_returns_list,
        "mean_return": mean_return,
        "std_return": std_return,
        "sharpe_ratio": sharpe_ratio,
        "max_drawdown": max_drawdown,
        "final_value": cum_value,
    }

    return round_all_results(results)


if __name__ == "__main__":
    # Demo 1: equal-weight portfolio over synthetic gross returns
    # (30 days x 5 assets, each drawn around 1.001 with 1% noise).
    np.random.seed(42)
    days, assets = 30, 5
    target_return_lists = [
        1 + np.random.normal(loc=0.001, scale=0.01, size=assets)
        for _ in range(days)
    ]

    result = eval_equal_weight_portfolio_performance(
        target_return_lists, init_value=100
    )
    for label, key in (
        ("Portfolio Values:", "portfolio_values"),
        ("Mean Daily Return:", "mean_return"),
        ("Std of Daily Return:", "std_return"),
        ("Sharpe Ratio (daily):", "sharpe_ratio"),
        ("Max Drawdown:", "max_drawdown"),
    ):
        print(label, result[key])

    # Demo 2: top-5 portfolio chosen by random scores (30 days x 10 assets).
    # Scores are drawn before returns so the RNG stream matches across runs.
    np.random.seed(0)
    days, assets = 30, 10
    assets_scores_lists = [np.random.rand(assets) for _ in range(days)]
    target_return_lists = [
        1 + np.random.normal(loc=0.001, scale=0.01, size=assets)
        for _ in range(days)
    ]

    result = eval_scores_topn_portfolio_performance(
        assets_scores_lists, target_return_lists, n=5
    )
    for label, key in (
        ("Mean Return:", "mean_return"),
        ("Std Return:", "std_return"),
        ("Sharpe Ratio:", "sharpe_ratio"),
        ("Max Drawdown:", "max_drawdown"),
    ):
        print(label, result[key])
    print("Selected Indices (first 5 days):", result["selected_indices"][:5])
