import numpy as np
import torch
from operators import *
from sklearn.datasets import make_spd_matrix, make_sparse_spd_matrix
from scipy.linalg import subspace_angles
from sklearn.linear_model import LinearRegression
from CNN_functions import FCN, CNN


def sigmoid(x):
    """Logistic function 1 / (1 + exp(-x)), applied element-wise."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator


def tanh(x):
    """Hyperbolic tangent, applied element-wise (thin wrapper over np.tanh)."""
    result = np.tanh(x)
    return result


def relu(x):
    """Rectified linear unit: element-wise max(0, x)."""
    floor = 0
    return np.maximum(floor, x)


def leaky_relu(x):
    """Leaky ReLU with slope 0.01 on the negative side: element-wise max(0.01 * x, x)."""
    negative_branch = 0.01 * x
    return np.maximum(negative_branch, x)


def softmax(x):
    """
    Softmax along axis 0: over the entries of a 1-D array, or down each column of a 2-D array.

    Each slice along axis 0 is shifted by its own maximum before exponentiating. Softmax is
    invariant to a per-slice constant shift, and using the slice's own max keeps that slice's
    largest exponent at exp(0) = 1. As a result, no slice can underflow to all zeros, which
    would produce 0/0 = nan. A single global max is not enough when slices live on very
    different scales.
    """
    shifted = x - np.max(x, axis=0, keepdims=True)
    e_x = np.exp(shifted)
    return e_x / e_x.sum(axis=0, keepdims=True)


def x_plus_sin(x, a=1, b=1):
    """Return the element-wise map a * x + b * sin(x)."""
    linear_part = a * x
    periodic_part = b * np.sin(x)
    return linear_part + periodic_part


def sine_distance(x, y):
    """
    Return |sin(theta)|, where theta is the angle between vectors x and y.

    Computed as sqrt(|1 - cos^2(theta)|). The absolute value guards against tiny negative
    values from rounding when x and y are (anti)parallel.
    """
    norm_product = np.linalg.norm(x) * np.linalg.norm(y)
    cos_theta = x.dot(y) / norm_product
    return np.sqrt(np.abs(1 - cos_theta ** 2))


def compute_distance(A, B):
    """
    Least-squares fit of A on B without intercept; return the norm of the residual.

    Uses sklearn LinearRegression(fit_intercept=False) to find H with A ~= B @ H, then
    returns np.linalg.norm(A - B @ H). That norm is Frobenius for 2-D A and Euclidean
    for 1-D A.
    """
    regressor = LinearRegression(fit_intercept=False).fit(B, A)
    residual = A - B.dot(regressor.coef_.T)
    return np.linalg.norm(residual)


def unit(x, ord=None):
    """Scale x to unit norm; `ord` is forwarded to np.linalg.norm (None = its default)."""
    magnitude = np.linalg.norm(x, ord)
    return x / magnitude


def standardize(x):
    """Center x by its mean and divide by its (population) standard deviation."""
    centered = x - x.mean()
    return centered / x.std()


def generate_Cov(p, mode='decay', corr_strength=1.0, rho=0.9, alpha=0.99, random_state=0):
    """
    Build a p x p covariance matrix according to `mode`.

    Modes:
      'decay'            -- entry (i, j) is (1 - |i - j| / p) * corr_strength; diagonal reset to 1.
      'exp_decay'        -- entry (i, j) is rho ** |i - j|.
      'make_spd'         -- sklearn make_spd_matrix(p, random_state=random_state).
      'make_sparse_spd'  -- sklearn make_sparse_spd_matrix(p, alpha=alpha, random_state=random_state).
      'sparse_precision' -- inverse of generate_Precision(p, mode='sparse').
      'dense_precision'  -- inverse of generate_Precision(p, mode='dense').

    :raises ValueError: for any other mode.
    """
    # |i - j| for every index pair, shared by the two banded modes (replaces O(p^2) Python loops).
    idx = np.arange(p)
    lag = np.abs(idx[:, None] - idx[None, :])

    if mode == 'decay':
        W = (1 - lag / p) * corr_strength
        np.fill_diagonal(W, 1)
    elif mode == 'exp_decay':
        # Float base so an integer rho yields a float matrix and cannot overflow int64.
        W = np.power(np.float64(rho), lag)
    elif mode == 'make_spd':
        W = make_spd_matrix(p, random_state=random_state)
    elif mode == 'make_sparse_spd':
        W = make_sparse_spd_matrix(p, alpha=alpha, random_state=random_state)
    elif mode == 'sparse_precision':
        P = generate_Precision(p, mode='sparse')
        W = np.linalg.inv(P)
    elif mode == 'dense_precision':
        P = generate_Precision(p, mode='dense')
        W = np.linalg.inv(P)
    else:
        raise ValueError("mode must be 'decay', 'exp_decay', 'make_spd', 'make_sparse_spd', "
                         "'sparse_precision', or 'dense_precision'.")
    return W


def generate_Precision(p, mode='sparse'):
    """
    Build a p x p precision matrix.

    'sparse' -> tridiagonal: 1 on the diagonal, 0.5 on the first super- and sub-diagonal.
    'dense'  -> identity plus an all-ones matrix (2 on the diagonal, 1 elsewhere).
    Any other mode returns the identity matrix.
    """
    base = np.eye(p)
    if mode == 'dense':
        return base + 1
    if mode == 'sparse':
        neighbours = np.eye(p, k=1) + np.eye(p, k=-1)
        return base + 0.5 * neighbours
    return base


def truncate(x, tau):
    """Clip the magnitude of each entry of x to at most tau, keeping its sign."""
    magnitude = np.minimum(np.abs(x), tau)
    return np.sign(x) * magnitude


def phi(x, side='both'):
    """
    Element-wise influence function clamped between -log(1 - x + x^2/2) and log(1 + x + x^2/2).

    The lower clamp is applied first. The upper bound is then evaluated on the already
    lower-clamped values, exactly in that order. For x >= 0 this yields log(1 + x + x^2/2);
    for x < 0 it yields -log(1 - x + x^2/2).

    side='positive' zeroes out negative outputs, side='negative' zeroes out positive outputs,
    and any other value leaves them unchanged.
    """
    lower = -np.log(1 - x + x ** 2 / 2)
    clamped = np.maximum(x, lower)
    upper = np.log(1 + clamped + clamped ** 2 / 2)
    clamped = np.minimum(clamped, upper)

    if side == 'positive':
        return np.maximum(clamped, 0)
    if side == 'negative':
        return np.minimum(clamped, 0)
    return clamped


def psi(A, side='both'):
    """
    Apply `phi` to the spectrum of a rectangular matrix A via its symmetric dilation.

    A (P x D) is embedded in the (P + D) x (P + D) symmetric matrix [[0, A], [A^T, 0]].
    That matrix is diagonalised with eigh, `phi(., side)` is applied to its eigenvalues,
    and the matrix is reassembled. The upper-right P x D block is returned.
    """
    rows, cols = A.shape
    size = rows + cols
    dilation = np.zeros((size, size))
    dilation[:rows, rows:] = A
    dilation[rows:, :rows] = A.T

    evals, evecs = np.linalg.eigh(dilation)
    rebuilt = evecs @ np.diag(phi(evals, side)) @ evecs.T
    return rebuilt[:rows, rows:]


def Index_model(input, link='Linear', noise_level=0.1, seed=0):
    """
    Generate standardized, noisy responses from an index model over the given samples.

    :param input: numpy array of samples indexed along the first axis (input[i] is one sample).
    :param link: 'Linear', 'Non-linear', 'FCN' or 'CNN'.
        'Linear'     -> y_i = <A, X_i>
        'Non-linear' -> y_i = <A, X_i + 3 sin(X_i)>
        'FCN'/'CNN'  -> y_i = network(X_i); the network is freshly constructed here.
    :param noise_level: std of the Gaussian noise added after standardization.
    :param seed: seed of the NumPy RandomState used to draw A and the noise.
        NOTE(review): FCN/CNN weights are initialised by the network constructors and are
        not controlled by this seed.
    :return: (Y, model). Y is reshaped to (-1, 1), i.e. (n, 1) assuming one output per
        sample. model is the link name for 'Linear'/'Non-linear', or the torch network.
    :raises ValueError: for an unknown link.
    """
    rng = np.random.RandomState(seed)
    # Drawn for every link, so the noise below always comes from the same RNG position.
    A = rng.normal(size=input[0].shape)
    if link == 'Linear':
        Y = np.array([np.vdot(A, Xi) for Xi in input])
        model = link
    elif link == 'Non-linear':
        Y = np.array([np.vdot(A, x_plus_sin(Xi, a=1, b=3)) for Xi in input])
        model = link
    elif link == 'FCN':
        input = torch.from_numpy(input.reshape(len(input), -1)).float()
        model = FCN(input_dim=input.shape[1])
        # Only the forward output is needed; don't build an autograd graph.
        with torch.no_grad():
            Y = model(input)
        Y = Y.detach().cpu().numpy()
    elif link == 'CNN':
        if A.ndim == 2:
            # Single-channel images: add the channel axis the CNN expects.
            model = CNN(input_shape=A.shape, in_channel=1)
            input = torch.from_numpy(np.expand_dims(input, axis=1)).float()
        else:
            # Leading axis of each sample is treated as the channel axis.
            model = CNN(input_shape=A.shape[1:], in_channel=A.shape[0])
            input = torch.from_numpy(input).float()
        with torch.no_grad():
            Y = model(input)
        Y = Y.detach().cpu().numpy()
    else:
        raise ValueError('Invalid link function')

    Y = (Y - Y.mean()) / Y.std()
    Y = Y + rng.normal(scale=noise_level, size=Y.shape)
    return Y.reshape(-1, 1), model


def estimate_Theta(RX, Y, score_func, R=1, params=None, truncate=False, k=1, side='both'):
    '''
    Estimate the leading right singular directions of the score-weighted response matrix.

    Parameters
    :param RX: array of shape (n, p, d); each sample is a p x d matrix.
    :param Y: responses, one per sample (e.g. shape (n, 1) as produced by Index_model).
    :param score_func: callable applied to the flattened samples (shape (n, p*d)).
        Presumably returns an (n, p*d) score array; each row is reshaped to (p, d).
    :param R: number of right singular vectors to return.
    :param params: extra keyword arguments for score_func (None means none).
    :param truncate: if True, aggregate y_i * S_i with the robust psi-based mean
        instead of the plain sum S^T Y. Note: this flag shadows the module-level
        `truncate` function inside this body.
    :param k: truncation level kappa, or 'default' for the data-driven choice
        2 * sqrt(n * log(p + d)) / sqrt((p + d) * M).
    :param side: 'both', 'positive' or 'negative'; forwarded to psi.

    :return: b_hat, array of shape (d, R) whose columns are the top-R right singular
        vectors of the aggregated p x d matrix.
    '''
    if params is None:
        params = {}
    n, *Rshape = RX.shape
    p, d = Rshape
    RX_vec = RX.reshape(n, -1)
    S = score_func(RX_vec, **params)
    if truncate:
        if isinstance(k, str) and k == 'default':
            # NOTE(review): M uses the 4th moment of the first sample's score entries only (S[0]).
            M = np.maximum(np.mean(Y ** 4), np.mean(S[0] ** 4))
            kappa = 2 * np.sqrt(n * np.log(p + d)) / np.sqrt((p + d) * M)
        else:
            kappa = k
        S = np.array([si.reshape(Rshape) for si in S])
        YS = np.mean([psi(kappa * yi * Si, side) / kappa for (yi, Si) in zip(Y, S)], axis=0).reshape(Rshape)
    else:
        # Plain sum rather than mean: the overall scale does not affect the singular vectors.
        YS = S.T.dot(Y).reshape(Rshape)

    _, _, Vt = np.linalg.svd(YS)
    b_hat = Vt[:R].T

    return b_hat

