import jax
import jax.numpy as jnp
import numpy as np
from sklearn.linear_model import LinearRegression


def expected_simplified_loss_no_ds_repetition(B, T, d):
    """Build the expected simplified loss for the case without dataset repetition.

    Args:
        B, T, d: Scalars baked into the returned closure. ``T - B`` is used
            as a divisor, so ``T`` must differ from ``B``.

    Returns:
        A callable ``loss(params, _)``. ``params`` is a mapping with the keys
        ``"w"`` and ``"delta_a"``; the second argument is ignored.
    """

    def _loss(params, _):
        weight = params["w"]
        shift = params["delta_a"]

        # alpha = 1 / ((T - B) * exp(-delta_a) + B)
        alpha = 1 / ((T - B) * jnp.exp(-shift) + B)
        target_norm = jnp.sqrt(d)

        # First term: squared gap between B * alpha * w and sqrt(d), scaled by 1/d.
        gap_term = (B * alpha * weight - target_norm) ** 2 / d
        # Second term: quadratic in w, weighted by (1 - B * alpha)^2 / (T - B) / d.
        quad_term = (1 - B * alpha) ** 2 / (T - B) / d * weight**2

        return 0.5 * (gap_term + quad_term)

    return _loss


def initial_simplified_loss_no_ds_repetition(B, T, d):
    """Build the simplified loss ``0.5 - lambd * w - lambd * w * delta_a``.

    Here ``lambd = B / T / sqrt(d)``.

    Args:
        B, T, d: Scalars baked into the returned closure.

    Returns:
        A callable ``loss(params, _)``. ``params`` is a mapping with the keys
        ``"w"`` and ``"delta_a"``; the second argument is ignored.
    """

    def _loss(params, _):
        weight = params["w"]
        shift = params["delta_a"]
        lambd = B / T / jnp.sqrt(d)
        linear = lambd * weight
        return 0.5 - linear - linear * shift

    return _loss


def w_inf(B, T, d, alpha):
    """Evaluate ``B * alpha / denom / sqrt(d)`` for a given ``alpha``.

    ``denom`` is ``(1 - B * alpha)^2 / d / (T - B) + B^2 * alpha^2 / d``.
    ``T - B`` is a divisor, so ``T`` must differ from ``B``.
    """
    quad_part = (1 - B * alpha) ** 2 / d / (T - B)
    gap_part = B**2 * alpha**2 / d
    denom = quad_part + gap_part
    return B * alpha / denom / jnp.sqrt(d)


def expected_simplified_loss_ds_repetition(B, p, T, d):
    """Build the expected simplified loss for the case with dataset repetition.

    Args:
        B, p, T, d: Scalars baked into the returned closure. ``T - B`` is
            used as a divisor, so ``T`` must differ from ``B``.

    Returns:
        A callable ``loss(params, _)``. ``params`` is a mapping with the keys
        ``"w0"``, ``"w"`` and ``"delta_a"``; the second argument is ignored.
    """

    def _loss(params, _):
        weight0 = params["w0"]
        weight = params["w"]
        shift = params["delta_a"]

        # alpha = 1 / ((T - B) * exp(-delta_a) + B)
        alpha = 1 / ((T - B) * jnp.exp(-shift) + B)

        # Gap of B * alpha * w0 to 1, weighted by p + (1 - p) / d.
        w0_term = (B * alpha * weight0 - 1) ** 2 * (p + (1 - p) / d)
        # Gap of B * alpha * w to sqrt(d - 1), weighted by (1 - p) / d.
        w_term = (B * alpha * weight - jnp.sqrt(d - 1)) ** 2 * (1 - p) / d
        # Quadratic term shared by both weights.
        quad_term = (1 - B * alpha) ** 2 / (T - B) / d * (weight**2 + weight0**2)

        return 0.5 * (w0_term + w_term + quad_term)

    return _loss


def sampled_loss(params, X, y):
    """Placeholder for a sampled loss.

    Not implemented: the arguments are ignored and ``None`` is returned.
    """
    return None


def get_loss(B, T, d, p=0.0, type="expectation_simplified"):
    """Return the loss callable selected by ``type`` and ``p``.

    Args:
        B, T, d: Scalars forwarded to the loss factory.
        p: Selects the variant; ``p == 0.0`` picks the no-repetition losses.
        type: One of ``"expectation_simplified"``, ``"initial_simplified"``
            (only available for ``p == 0.0``) or ``"sampled"``.

    Returns:
        A loss callable. The simplified losses take ``(params, _)``.
        ``"sampled"`` returns ``sampled_loss`` itself, which takes
        ``(params, X, y)``.

    Raises:
        ValueError: If ``type`` is unknown, or if ``"initial_simplified"`` is
            requested with ``p != 0.0``.
    """
    if type == "expectation_simplified":
        if p == 0.0:
            # no dataset repetition
            return expected_simplified_loss_no_ds_repetition(B, T, d)
        # dataset repetition
        return expected_simplified_loss_ds_repetition(B, p, T, d)

    if type == "initial_simplified":
        if p == 0.0:
            return initial_simplified_loss_no_ds_repetition(B, T, d)
        # The type itself is valid here; only the p value is unsupported.
        raise ValueError(
            f"type 'initial_simplified' is only implemented for p == 0.0, got p={p}"
        )

    if type == "sampled":
        # Hand back the callable like every other branch does. Calling
        # sampled_loss(B, p) here would always raise TypeError, because
        # sampled_loss expects (params, X, y).
        return sampled_loss

    raise ValueError(f"Invalid type: {type}")


def initialize_params(key, B, p, T, type="expectation_simplified"):
    """Create the all-zero starting parameters for the simplified losses.

    ``key``, ``B`` and ``T`` are accepted but not used here.

    Args:
        p: ``p == 0.0`` gives ``{"w", "delta_a"}``; any other value gives
            ``{"w0", "w", "delta_a"}``.
        type: Only ``"expectation_simplified"`` is supported.

    Returns:
        A ``(params, teacher)`` tuple, where ``teacher`` is always ``None``.

    Raises:
        ValueError: If ``type`` is not ``"expectation_simplified"``.
    """
    if type != "expectation_simplified":
        raise ValueError(f"Invalid type: {type}")

    teacher = None
    shared = {"w": 0.0, "delta_a": 0.0}
    if p == 0.0:
        return shared, teacher
    # The repetition case carries an extra "w0" entry.
    return {"w0": 0.0, **shared}, teacher


def optimize(losses, params, data, n_steps, lr):
    """Run ``n_steps`` of plain gradient descent inside ``jax.lax.scan``.

    Args:
        losses: Pair ``(train_loss, eval_loss)``, each called as
            ``loss(params, data)``. Gradients are taken of ``train_loss``
            with respect to ``params``.
        params: Initial parameter pytree.
        data: Passed unchanged to both losses at every step.
        n_steps: Number of update steps.
        lr: Step size.

    Returns:
        The stacked per-step outputs ``(train_losses, eval_losses, params)``.
        Each entry is recorded before that step's update is applied.
    """
    train_loss, eval_loss = losses
    loss_and_grad = jax.value_and_grad(train_loss)

    def _descend(leaf, leaf_grad):
        return leaf - leaf_grad * lr

    def _step(current, _):
        train_value, grads = loss_and_grad(current, data)
        eval_value = eval_loss(current, data)
        updated = jax.tree.map(_descend, current, grads)
        return updated, (train_value, eval_value, current)

    _, history = jax.lax.scan(_step, params, jnp.arange(n_steps))
    return history


def hessian_init(p, d, B, T):
    """Assemble the symmetric 3x3 matrix ``H`` for ``p > 0``.

    The rows/columns presumably follow the parameter order
    ``(w0, w, delta_a)`` -- TODO confirm against the derivation.

    Raises:
        ValueError: If ``p`` is not strictly positive.
    """
    if not p > 0.0:
        raise ValueError("p == 0 not implemented yet.")

    sigma = (1 - p) / d
    sigma_tilde = (1 - p) / d + p

    # Diagonal entries of the first two rows.
    diag_0 = -(T - B) / T**2 / d - B**2 * sigma_tilde / T**2
    diag_1 = -(T - B) / T**2 / d - B**2 * sigma / T**2
    # Couplings of the first two rows with the third; H is symmetric.
    cross_0 = B * sigma_tilde * (T - B) / T**2
    cross_1 = B * sigma * (T - B) * jnp.sqrt(d - 1) / T**2

    rows = [
        [diag_0, 0, cross_0],
        [0, diag_1, cross_1],
        [cross_0, cross_1, 0],
    ]
    return jnp.array(rows)


def plateau_length(p, d, B, T):
    """Return the reciprocal of the last eigenvalue of ``hessian_init(p, d, B, T)``.

    ``jnp.linalg.eigvalsh`` returns eigenvalues in ascending order, so the
    last one is the largest.
    """
    spectrum = jnp.linalg.eigvalsh(hessian_init(p, d, B, T))
    top_eigenvalue = spectrum[-1]
    return 1 / top_eigenvalue


def compute_scaling_law(df, x_col, y_col, c_col, extra_cols=None):
    """Fit ``y = c * x^a * c_col^b * ...`` by linear regression in log space.

    The data is first averaged over ``y_col`` within each group of
    ``[x_col, c_col] + extra_cols``. Then ``log(y)`` is regressed on the logs
    of the grouping columns, and the fitted parameters are printed.

    Args:
        df: DataFrame containing all the referenced columns.
        x_col: Name of the first predictor column (exponent ``a``).
        y_col: Name of the response column.
        c_col: Name of the second predictor column (exponent ``b``).
        extra_cols: Optional names of further predictor columns.

    Returns:
        ``(a, b, c)`` when there are no extra columns. Otherwise
        ``(c, coef)``, where ``coef`` holds the fitted exponents in the order
        ``[x_col, c_col, *extra_cols]``.

    Raises:
        ValueError: If no rows remain after aggregation, so nothing can be fit.
    """
    # A None default avoids sharing one mutable list object across calls.
    extra_cols = [] if extra_cols is None else list(extra_cols)
    predictors = [x_col, c_col] + extra_cols

    df = df[[x_col, y_col, c_col] + extra_cols]

    # 1. Aggregate: mean of y_col within each combination of predictor values.
    agg_data = df.groupby(predictors)[y_col].mean().reset_index()

    # Without this guard the fit results would be unbound at the return below.
    if agg_data.empty:
        raise ValueError("Cannot fit a scaling law: no data left after aggregation.")

    # 2. Prepare data for log-log regression.
    for col in [x_col, y_col, c_col] + extra_cols:
        agg_data["log_" + col] = np.log(agg_data[col])

    # 3. Fit the linear model in log-log space.
    X = agg_data[["log_" + col for col in predictors]]
    y = agg_data["log_" + y_col]

    model = LinearRegression()
    model.fit(X, y)

    # 4. Extract parameters; the intercept is log(c).
    a, b = model.coef_[0], model.coef_[1]
    c = np.exp(model.intercept_)

    # 5. Print the results.
    print(f"Fitted power law: {y_col} = c * ({x_col}^a) * ({c_col}^b)")
    print(f"  a (exponent for {x_col}) = {a:.4f}")
    print(f"  b (exponent for {c_col}) = {b:.4f}")
    print(f"  c (constant) = {c:.4f}")
    # Extra-column exponents follow the first two coefficients.
    for i, col in enumerate(extra_cols):
        print(f"  {col} (exponent) = {model.coef_[i+2]:.4f}")
    print(f"  R^2 score: {model.score(X, y):.4f}")

    if not extra_cols:
        return a, b, c
    return c, model.coef_
