import numpy as np


def analytical_svs(t, s, lamb, a_0, step_size, batch_size):
    """
    Evaluates the closed-form trajectory of a singular value (mode strength)
    of the network with bias input:

        a(t) = (s / lamb) / (1 - (1 - s / (lamb * a_0)) * exp(-2 * s * t / tau))

    with tau = 1 / (batch_size * step_size). The expression equals a_0 at
    t = 0 and tends to s / lamb for large t when s > 0.

    Args:
        t (np.array): The time points.
        s (float): the singular values of the input-output correlation matrix.
        lamb (float): the singular values of the input-input correlation matrix.
        a_0 (float): The initial mode strength (must be non-zero, it is divided by).
        step_size (float): The step size.
        batch_size (int): The batchsize used in SGD, we assume full batch here.

    Returns:
        np.array: The mode strength evaluated at each entry of t.
    """
    time_constant = 1 / (batch_size * step_size)
    fixed_point = s / lamb
    decay = np.exp((-2 * s / time_constant) * t)
    # Offset set by the initial condition; it is what makes a(0) == a_0.
    initial_gap = 1 - s / (lamb * a_0)
    return fixed_point / (1 - initial_gap * decay)

def analytical_svs_shallow(t, s, lamb, a_0, step_size, batch_size):
    """
    Evaluates the closed-form trajectory of a singular value (mode strength)
    of the shallow network with bias input:

        a(t) = (s / lamb) * (1 - exp(-t * lamb / tau)) + a_0 * exp(-t * lamb / tau)

    with tau = 1 / (batch_size * step_size), i.e. an exponential relaxation
    from a_0 (at t = 0) towards s / lamb.

    Args:
        t (np.array): The time points.
        s (float): the singular values of the input-output correlation matrix.
        lamb (float): the singular values of the input-input correlation matrix.
        a_0 (float): The initial mode strength
        step_size (float): The step size.
        batch_size (int): The batchsize used in SGD, we assume full batch here.

    Returns:
        np.array: The mode strength evaluated at each entry of t.
    """
    time_constant = 1 / (batch_size * step_size)
    # Both terms share the same exponential factor, so name it once.
    decay = np.exp(-t * lamb / time_constant)
    return s / lamb * (1 - decay) + a_0 * decay

def ana_weights(u, ana_sol_svs, vt):
    """
    Calculates the analytical weights over time from the left singular vectors,
    the singular values at each time point and the right singular vectors.
    Note that the result is the multiplication of the left and right weight
    matrices (W2 @ W1), not the individual layers.

    Args:
        u(np.array): The left singular vectors of the analytical solution,
            shape (n_out, k).
        ana_sol_svs(np.array): The singular values of the analytical solution
            over time, shape (k, timepoints) (each column is a time point).
        vt(np.array): The right singular vectors of the analytical solution,
            shape (k, n_in).

    Returns:
        np.array: The analytical weights W2 @ W1 over time, shape
            (timepoints, n_out, n_in). With zero time points the result is an
            empty array that still has this three-dimensional shape.
    """
    u = np.asarray(u)
    vt = np.asarray(vt)
    svs_by_time = np.asarray(ana_sol_svs).T  # (timepoints, k)
    # u @ diag(s) only scales column j of u by s[j]. Broadcasting applies that
    # scaling for every time point at once, without building a dense k x k
    # diagonal matrix per time point.
    scaled_u = u[np.newaxis, :, :] * svs_by_time[:, np.newaxis, :]
    # matmul broadcasts the 2-D vt across the leading time axis.
    return scaled_u @ vt

def squared_loss(ana_outputs, train_labels):
    """
    Calculates half the sum of squared differences between the analytical
    outputs and the training labels, summed over every element.

    Args:
        ana_outputs: The predicted outputs from the analytical solution.
        train_labels: The true labels from the training data.

    Returns:
        float: 0.5 * sum((ana_outputs - train_labels) ** 2).
    """
    residual = ana_outputs - train_labels
    return 0.5 * np.sum(residual ** 2)

def ana_outputs(u, ana_sol_svs, vt, train_inputs):
    """
    Calculates the analytical outputs over time by applying the analytical
    weights of every time point to the training inputs.

    Args:
        u(np.array): The left singular vectors of the analytical solution.
        ana_sol_svs(np.array): The singular values of the analytical solution over time (each column is a time point)
        vt(np.array): The right singular vectors of the analytical solution.
        train_inputs(np.array): The training inputs, right-multiplied onto the weights.
    Returns:
        np.array: The analytical outputs, one entry per time point along the first axis.
    """
    weights_over_time = ana_weights(u, ana_sol_svs, vt)
    # Iterating the stacked weights walks the leading (time) axis.
    return np.array([w_t @ train_inputs for w_t in weights_over_time])

def ana_loss(u, ana_sol_svs, vt, train_inputs, train_labels):
    """
    Calculates the analytical squared loss at every time point, based on the
    analytical outputs for the training inputs and the training labels.

    Args:
        u(np.array): The left singular vectors of the analytical solution.
        ana_sol_svs(np.array): The singular values of the analytical solution over time
                                (each column is a time point)
        vt(np.array): The right singular vectors of the analytical solution.
        train_inputs(np.array): The training inputs.
        train_labels(np.array): The training labels.

    Returns:
        np.array: The analytical loss, one value per time point.
    """
    outputs_over_time = ana_outputs(u, ana_sol_svs, vt, train_inputs)
    # Iterating the stacked outputs walks the leading (time) axis.
    losses = [squared_loss(out_t, train_labels) for out_t in outputs_over_time]
    return np.array(losses)

def svs(A, U, VT):
    """
    Gets singular values of A using A's singular vectors U and VT
    (keeps singular values aligned for plotting, unlike np.linalg.svd).

    The values are read off the diagonal of U.T @ A @ VT.T, so they follow the
    ordering of the supplied vectors instead of being re-sorted. They are only
    true singular values of A when U and VT really are its singular vectors.

    Args:
    - A: The matrix A.
    - U: The left singular vectors of A.
    - VT: The right singular vectors of A.

    Returns:
    - U: The left singular vectors that were passed in, unchanged.
    - s: 1-D array with the diagonal of U.T @ A @ VT.T; its length is the
      smaller of that product's two dimensions.
    - VT: The right singular vectors that were passed in, unchanged.
    """
    S = np.dot(U.T, np.dot(A, VT.T))
    # np.diagonal returns a read-only view; copy so callers get a fresh,
    # writable array as before.
    s = np.diagonal(S).copy()
    return U, s, VT
