import numpy as np
import pandas as pd
import plotly.express as px
import torch as t
from scipy.optimize import curve_fit


def generate_log_log_plot(df: pd.DataFrame, output_path: str = "feature_hit_times.html") -> None:
    """Write a log-log line chart of the ``hit_times`` column to an HTML file.

    Args:
        df: Frame containing a ``hit_times`` column to plot on the y axis.
        output_path: Destination of the generated HTML file. Defaults to
            ``feature_hit_times.html`` in the current working directory.
    """
    # NOTE(review): no x column is passed, so the x axis presumably falls back
    # to the frame's index; a zero index value cannot be drawn on a log axis —
    # confirm that callers pass a frame whose index starts at 1 if that matters.
    fig = px.line(df, y="hit_times", title="Feature hit times", log_y=True, log_x=True)
    fig.write_html(output_path)


# The hit counts look roughly Zipf-distributed up to the high-rank end, where they
# may simply be undesirably low (somewhat like dead features).
# Let's get the exact distribution by fitting with curve_fit.


def zipf_func(x: np.ndarray, bias: float, exponent: float, norm_constant: float) -> np.ndarray:
    """Evaluate the shifted power law ``norm_constant / (x + bias) ** exponent``.

    Args:
        x: Values at which to evaluate the curve.
        bias: Offset added to ``x`` before exponentiation.
        exponent: Power applied to the shifted values.
        norm_constant: Numerator that scales the whole curve.

    Returns:
        The curve evaluated element-wise at ``x``.
    """
    shifted = x + bias
    denominator = shifted**exponent
    return norm_constant / denominator


def fit_zipf_curve(features_hit_df: pd.DataFrame) -> tuple[np.ndarray, float]:
    """Fit ``zipf_func`` to the hit counts and report the goodness of fit.

    Args:
        features_hit_df: Frame with a ``features`` column (used as x) and a
            ``hit_times`` column (used as y).

    Returns:
        A pair ``(params, r_squared)``: the fitted ``(bias, exponent,
        norm_constant)`` values in the order ``zipf_func`` takes them, and the
        coefficient of determination of the fit on the untransformed values.
    """
    ranks = features_hit_df["features"]
    observed = features_hit_df["hit_times"]

    # All three parameters start at 1.0 and are kept strictly positive:
    # bounded below by a small epsilon and unbounded above.
    n_params = 3
    min_param = 1e-4

    params, _ = curve_fit(
        f=zipf_func,
        xdata=ranks,
        ydata=observed,
        p0=[1.0] * n_params,
        bounds=([min_param] * n_params, [np.inf] * n_params),
        full_output=False,
    )

    predicted = zipf_func(ranks, *params)  # type: ignore

    # R^2 = 1 - SS_res / SS_tot, where SS_res sums the squared residuals and
    # SS_tot sums the squared deviations of the observations from their mean.
    ss_res = np.sum((observed - predicted) ** 2)
    ss_tot = np.sum((observed - np.mean(observed)) ** 2)  # type: ignore
    r_squared = 1 - (ss_res / ss_tot)

    return params, r_squared


def _load_ranked_hit_times(path: str) -> pd.DataFrame:
    """Load the hit-count tensor at ``path`` and rank it in descending order.

    Returns a frame with a ``features`` column holding ranks starting at 1 and
    a ``hit_times`` column holding the sorted values.
    """
    # map_location="cpu" lets a tensor that was saved from a GPU load on a
    # CPU-only machine. NOTE: torch.load unpickles the file, so only point
    # this at trusted artifacts.
    features_hit_times_F = t.load(path, map_location="cpu")
    features_hit_times_F = t.sort(features_hit_times_F, descending=True).values
    features_hit_times_F_np = features_hit_times_F.cpu().numpy()

    return pd.DataFrame(
        {
            "features": np.arange(1, len(features_hit_times_F_np) + 1),
            "hit_times": features_hit_times_F_np,
        }
    )


def _build_zipf_comparison_figure(df: pd.DataFrame, params: np.ndarray):
    """Build the log-log figure comparing the observed values with the fitted curve.

    Returns the Plotly figure without displaying it.
    """
    df_plot = pd.DataFrame(
        {
            "Rank": df["features"],
            "Zipf Distribution": zipf_func(df["features"].to_numpy(), *params),
            "Feature Density Before Training": df["hit_times"],
        }
    )

    # Long format: one row per (rank, series) so that each series becomes a line.
    df_melted = df_plot.melt(
        id_vars="Rank",
        value_vars=["Zipf Distribution", "Feature Density Before Training"],
        var_name="Distribution",
        value_name="Value",
    )

    fig = px.line(
        df_melted,
        x="Rank",
        y="Value",
        color="Distribution",
        log_x=True,
        log_y=True,
        labels={"Rank": "Rank", "Value": "", "Distribution": ""},
        title="Feature Density At Initialisation & The Fitted Zipf Curve",
    )

    # White background, black axis lines, no grid; the y axis carries no ticks
    # or tick labels, and the legend sits in a bordered box at the lower left.
    fig.update_layout(
        plot_bgcolor="white",
        xaxis=dict(
            showline=True,
            showgrid=False,
            linecolor="black",
            ticks="outside",
            tickfont=dict(size=20),
            title_font=dict(size=20),
        ),
        yaxis=dict(
            showline=True,
            showgrid=False,
            linecolor="black",
            ticks="",
            tickfont=dict(size=12),
            title_font=dict(size=14),
            showticklabels=False,
        ),
        legend=dict(
            title="",
            font=dict(size=18),
            bordercolor="black",
            borderwidth=1,
            xanchor="left",
            yanchor="bottom",
            x=0.1,
            y=0.1,
            bgcolor="rgba(255,255,255,0.5)",
        ),
        title_font=dict(
            size=30,
        ),
    )
    return fig


def main() -> None:
    """Fit the Zipf curve to the saved hit counts, print the fit, and show the chart."""
    df = _load_ranked_hit_times("paper_charts/features_hit_times_mc_before_training.pt")

    params, r_squared = fit_zipf_curve(df)
    print("Optimized parameters:", params)
    print("R^2 value:", r_squared)

    _build_zipf_comparison_figure(df, params).show()


if __name__ == "__main__":
    main()
