import numpy as np
import pandas as pd

# ------------------Functions to learn shapelets --------------------
def USIDL(X, y, lambda_, K, q, c, epsilon, maxIter, maxInnerIter, runid):
    """
    Learn K shapelet bases by unsupervised shift-invariant dictionary learning.

    Alternates between (1) updating the sparse coefficients and matching
    offsets with the bases fixed (``update_A_par``) and (2) updating the bases
    with coefficients and offsets fixed (``update_S``). Stops when the relative
    change of the objective drops below ``epsilon`` or after ``maxIter`` outer
    iterations.

    X: n x p, training data, n time series with length p
    y: binary label (-1, 1); not used by this model (kept for API compatibility)
    lambda_: regularization parameter for the l1 norm of the coefficients
    K: number of bases
    q: length of each basis (q <= p, otherwise the offset initialization fails)
    c: bound on the squared L2-norm of each basis, i.e., ||s_k||^2 <= c
    epsilon: relative-change tolerance on the objective used to stop early
    maxIter: maximum outer iterations
    maxInnerIter: maximum inner iterations of each sub-problem
    runid: run identifier for plotting; not used in this function

    Returns:
        S: K x q learned bases
        A: n x K coefficients for the training data
        Offsets: n x K matched start locations of the bases
        F_obj: list of objective values, one per completed outer iteration
    """
    n, p = X.shape

    S = np.random.randn(K, q)  # random initial bases
    A = np.random.randn(n, K)  # random initial coefficients
    # random initial offsets; each basis must fit entirely inside a series
    Offsets = np.random.randint(0, p - q + 1, (n, K))

    F_obj = []

    for _ in range(int(maxIter)):
        # update coefficients and matching offsets (bases fixed); the inner
        # objective value is recomputed below after the basis update
        A, Offsets = update_A_par(X, S, A, Offsets, lambda_, maxInnerIter, epsilon)[:2]

        # update bases (coefficients and offsets fixed)
        S = update_S(X, S, A, Offsets, lambda_, c, maxInnerIter, epsilon)

        # check convergence
        F_obj.append(unsup_obj(X, S, A, Offsets, lambda_))
        # Relative-change test written without a division so that a zero
        # previous objective cannot cause a divide-by-zero.
        if len(F_obj) > 1 and abs(F_obj[-1] - F_obj[-2]) < epsilon * abs(F_obj[-2]):
            print('Converged!')
            return S, A, Offsets, F_obj

    print('Maximum Iteration Reached!')
    return S, A, Offsets, F_obj


def op_shift(S, offsets, target_dim):
    """
    Shift each basis function to its offset inside a zero row of length target_dim.

    Parameters:
    S (numpy.ndarray): Basis functions (K x q)
    offsets (numpy.ndarray): 1-D integer array of K start positions, one per basis
    target_dim (int): Length of each output row

    Returns:
    numpy.ndarray: Shifted basis functions (K x target_dim). Row i is zero
    except for S[i] placed at columns offsets[i] .. offsets[i] + q - 1.

    Raises:
    ValueError: if an offset is negative, or if a shifted basis would extend
    past target_dim.
    """
    K, q = S.shape

    res = np.zeros((K, target_dim))
    for i in range(K):
        offset = offsets[i]
        # Negative offsets must be rejected explicitly. Python slicing would
        # otherwise wrap them to the end of the row, silently misplacing the
        # basis or failing with an opaque broadcasting error.
        if offset < 0:
            raise ValueError(f"Offset {offset} is negative")
        if offset + q > target_dim:
            raise ValueError(f"Offset {offset} with q {q} exceeds target_dim {target_dim}")
        res[i, offset:offset + q] = S[i]

    return res

def update_A_par(X, S, A, Offsets, lambda_, maxIter, epsilon):
    """
    Update coefficients and matching offsets with the bases held fixed.

    Runs coordinate descent. For every series x_i and every basis k (visited
    in random order):
      - the residual of x_i with basis k removed is correlated with s_k at
        every admissible offset;
      - the offset with the largest absolute correlation is selected;
      - the coefficient is the soft-thresholded correlation divided by ||s_k||^2.
    A and Offsets are modified in place and also returned.

    X: n x p
    S: K x q
    A: n x K
    Offsets: n x K
    lambda_: l1 regularization weight (soft-threshold level)
    maxIter: maximum number of passes over all series
    epsilon: relative-change tolerance on the objective used to stop early

    Returns:
        A: updated coefficients for training data
        Offsets: updated matched location of the basis
        F_all: final objective value (still computed when maxIter < 1)
    """
    n, p = X.shape
    KK, q = S.shape
    # Row t of seg_idx holds the indices t .. t+q-1 of the window starting at t.
    seg_idx = np.add.outer(np.arange(p - q + 1), np.arange(q))
    # S is not modified in this function, so ||s_k||^2 is computed only once.
    base_norms2 = [np.linalg.norm(S[k, :]) ** 2 for k in range(KK)]

    F_obj = []
    F_all = None
    for _ in range(int(maxIter)):
        for i in range(n):  # compute activation and matching offset for X_i
            x = X[i, :]
            shifted_S = op_shift(S, Offsets[i, :], p)

            # compute for base k
            for k in np.random.permutation(KK):
                base = S[k, :]
                temp_a = A[i, :].copy()
                temp_a[k] = 0  # exclude alpha_k

                x_residue = x - np.dot(temp_a, shifted_S)

                # correlation of the residual with s_k at every admissible offset
                dot_prods = np.dot(x_residue[seg_idx], base)

                M_idx = np.argmax(np.abs(dot_prods))
                M_dp = dot_prods[M_idx]

                if np.abs(M_dp) <= lambda_:
                    a_k_star = 0
                else:
                    # soft-thresholding of the best correlation
                    a_k_star = np.sign(M_dp) * (np.abs(M_dp) - lambda_) / base_norms2[k]
                    t_k_star = M_idx

                A[i, k] = a_k_star
                if a_k_star != 0:
                    # move basis k to its new offset for the remaining coordinates
                    shifted_S[k, :] = 0
                    shifted_S[k, t_k_star: t_k_star + q] = base
                    Offsets[i, k] = t_k_star

        F_all = unsup_obj(X, S, A, Offsets, lambda_)
        F_obj.append(F_all)
        # Relative-change test written without a division so that a zero
        # previous objective cannot cause a divide-by-zero.
        if len(F_obj) > 1 and abs(F_obj[-1] - F_obj[-2]) < epsilon * abs(F_obj[-2]):
            return A, Offsets, F_all

    if F_all is None:
        # maxIter < 1: the loop never ran, so evaluate the unchanged state
        F_all = unsup_obj(X, S, A, Offsets, lambda_)
    return A, Offsets, F_all


def update_S(X, S, A, Offsets, lambda_, c, maxIter, epsilon):
    """
    Update the basis functions with coefficients and offsets held fixed.

    Runs block coordinate descent over the bases:
      - s_k is set to sum_i a_ik * r_ik / sum_i a_ik^2, where r_ik is the
        length-q window of x_i's residual (all other bases removed) at
        offset t_ik;
      - if that vector violates ||s_k||^2 <= c, it is rescaled to have squared
        norm exactly c.
    Bases with an all-zero coefficient column are left unchanged.
    S is modified in place and also returned.

    X: n x p
    S: K x q
    A: n x K
    Offsets: n x K
    lambda_: l1 weight, used only to evaluate the objective for early stopping
    c: bound on the squared L2-norm of each basis
    maxIter: maximum number of passes over all bases
    epsilon: relative-change tolerance on the objective used to stop early

    Returns:
        S: updated basis functions
    """
    n, p = X.shape
    K, q = S.shape
    sqrt_c = np.sqrt(c)

    F_obj = []

    for _ in range(int(maxIter)):
        for k in range(K):  # optimize s_k
            M_k = np.linalg.norm(A[:, k]) ** 2
            if M_k == 0:  # inactive bases, no need to update
                continue

            s_k = np.zeros(q)

            for i in range(n):
                temp_a = A[i, :].copy()
                temp_a[k] = 0
                shifted_S = op_shift(S, Offsets[i, :], p)
                xi_residue = X[i, :] - np.dot(temp_a, shifted_S)

                t_ik = Offsets[i, k]
                s_k += A[i, k] * xi_residue[t_ik:t_ik + q]

            # ||s_k / M_k|| > sqrt(c) is tested as ||s_k|| > M_k * sqrt(c).
            # This avoids dividing by sqrt(c), which is zero when c == 0.
            s_norm = np.linalg.norm(s_k)
            if s_norm > M_k * sqrt_c:
                s_k = sqrt_c / s_norm * s_k  # project onto ||s_k||^2 == c
            else:
                s_k = s_k / M_k

            S[k, :] = s_k

        F_all = unsup_obj(X, S, A, Offsets, lambda_)
        F_obj.append(F_all)
        # Relative-change test written without a division so that a zero
        # previous objective cannot cause a divide-by-zero.
        if len(F_obj) > 1 and abs(F_obj[-1] - F_obj[-2]) < epsilon * abs(F_obj[-2]):
            return S

    return S



def unsup_obj(X, S, A, Offsets, lambda_):
    """
    Evaluate the unsupervised objective.

    F = sum_i [ 0.5 * ||x_i - sum_k a_ik * shift(s_k, t_ik)||^2 + lambda_ * ||a_i||_1 ]

    X: n x p
    S: K x q
    A: n x K
    Offsets: n x K
    lambda_: l1 regularization weight

    Returns:
        F: objective value (0 when X has no rows)
    """
    n, p = X.shape

    F = 0
    for i in range(n):
        shifted_S = op_shift(S, Offsets[i, :], p)
        # reconstruction of x_i from the shifted bases, computed once
        residual = X[i, :] - np.dot(A[i, :], shifted_S)
        F += 0.5 * np.linalg.norm(residual) ** 2 + lambda_ * np.linalg.norm(A[i, :], 1)

    return F
