import math
from logging import FileHandler, Formatter, StreamHandler, getLogger
from typing import Tuple

import numpy as np

# Module-level logging: INFO and above is sent both to a console StreamHandler
# (stderr by default) and to ./result.log, which is truncated ("w") on import.
log_fmt = Formatter(
    "%(asctime)s %(name)s L%(lineno)d [%(levelname)s][%(funcName)s] %(message)s "
)
logger = getLogger(__name__)
logger.setLevel("INFO")
for handler in (StreamHandler(), FileHandler("./result.log", "w")):
    # Every sink shares the same threshold and record layout.
    handler.setLevel("INFO")
    handler.setFormatter(log_fmt)
    logger.addHandler(handler)


def mlp_relu(X: np.ndarray, max_depth: int) -> np.ndarray:
    """Stack the ReLU-MLP neural tangent kernel for ``max_depth`` successive depths.

    Starting from the linear kernel ``X @ X.T``, each step applies the ReLU
    arc-cosine recursion to the covariance and its derivative, and accumulates
    the tangent kernel as ``ntk <- ntk * sigma_dot + sigma``.

    Args:
        X: 2-D array with one sample per row.
        max_depth: Number of depths to emit (0 yields an empty stack).

    Returns:
        Array of shape ``(max_depth, n, n)`` where ``n = X.shape[0]``; slice 0 is
        exactly ``X @ X.T`` and each later slice is one further recursion step.
    """
    n_samples = X.shape[0]
    K = np.zeros((max_depth, n_samples, n_samples))
    sigma = np.matmul(X, X.T)  # covariance of the current layer
    ntk = np.zeros_like(sigma)

    for layer in range(max_depth):
        ntk += sigma
        K[layer] = ntk

        norms = np.diag(sigma)
        # Floor the normaliser so rows with zero norm do not divide by zero.
        scale = np.clip(np.sqrt(np.outer(norms, norms)), a_min=1e-9, a_max=None)
        # Clip the cosine so rounding cannot push arccos/sqrt out of domain.
        cos_t = np.clip(sigma / scale, a_min=-1, a_max=1)
        gap = math.pi - np.arccos(cos_t)  # pi - theta, shared by both updates

        sigma = (cos_t * gap + np.sqrt(1.0 - cos_t * cos_t)) * scale / 2.0 / math.pi
        ntk = ntk * gap / 2.0 / math.pi

    return K


def calc_tau(
    alpha: float, S: np.ndarray, diag_i: np.ndarray, diag_j: np.ndarray
) -> np.ndarray:
    """Element-wise ``tau`` term of the soft-tree kernel.

    Computes::

        tau = 1/4 + 1/(2*pi) * arcsin(
            a*S / sqrt((a*diag_i + 0.5) * (a*diag_j + 0.5))
        ),   a = alpha**2

    Args:
        alpha: Scale parameter; only ``alpha**2`` enters the formula.
        S: Inner products between samples (array or scalar).
        diag_i, diag_j: Squared norms broadcast to the shape of ``S``; the
            formula is symmetric in the two.

    Returns:
        Array broadcast from the inputs holding ``tau`` element-wise.
    """
    a = alpha ** 2
    denominator = np.sqrt((a * diag_i + 0.5) * (a * diag_j + 0.5))
    return 1 / 4 + 1 / (2 * math.pi) * np.arcsin((a * S) / denominator)


def calc_tau_dot(
    alpha: float, S: np.ndarray, diag_i: np.ndarray, diag_j: np.ndarray
) -> np.ndarray:
    """Element-wise ``tau_dot`` (derivative) term of the soft-tree kernel.

    Computes, with ``a = alpha**2``::

        tau_dot = a / pi / sqrt((2*a*diag_i + 1) * (2*a*diag_j + 1) - 4*a**2*S**2)

    The radicand is evaluated in the algebraically identical form
    ``4*a**2*(diag_i*diag_j - S**2) + 2*a*(diag_i + diag_j) + 1``. Expanding the
    product and then subtracting ``4*a**2*S**2`` cancels catastrophically when
    ``a*diag`` is large. On the diagonal the true radicand is only
    ``4*a*S_ii + 1``, yet the naive form returns 0 (and hence ``inf``) for e.g.
    ``alpha=1e8, S_ii=1``. When ``S`` is a Gram matrix with ``diag_i``/``diag_j``
    its diagonal, as ``tree`` builds it, Cauchy-Schwarz makes the first term
    non-negative, so the radicand is >= 1.

    Args:
        alpha: Scale parameter; only ``alpha**2`` enters the formula.
        S: Inner products between samples (array or scalar).
        diag_i, diag_j: Squared norms broadcast to the shape of ``S``; the
            formula is symmetric in the two.

    Returns:
        Array broadcast from the inputs holding ``tau_dot`` element-wise.
    """
    a = alpha ** 2
    radicand = 4 * a * a * (diag_i * diag_j - S ** 2) + 2 * a * (diag_i + diag_j) + 1
    return a / math.pi / np.sqrt(radicand)


def tree(
    X: np.ndarray, max_depth: int, alpha: float
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Soft-tree neural tangent kernels for tree depths ``1 .. max_depth``.

    For depth ``d`` the kernel is::

        K_d = 2**d * d * S * tau_dot * tau**(d - 1) + 2**d * tau**d

    where ``S = X @ X.T``. ``tau`` and ``tau_dot`` come from ``calc_tau`` and
    ``calc_tau_dot``. Both depend only on ``S`` and ``alpha``, so they are
    computed once rather than per depth. As a result ``max_depth == 0`` no
    longer fails with UnboundLocalError.

    Args:
        X: 2-D array with one sample per row.
        max_depth: Number of depths to emit (0 yields an empty stack).
        alpha: Scale parameter forwarded to ``calc_tau`` / ``calc_tau_dot``.

    Returns:
        ``(K, tau, tau_dot)``. ``K`` has shape ``(max_depth, n, n)`` with
        ``K[i]`` the kernel for depth ``i + 1``. ``tau`` and ``tau_dot`` are the
        ``(n, n)`` element-wise terms shared by every depth.
    """
    n = X.shape[0]
    K = np.zeros((max_depth, n, n))
    S = np.matmul(X, X.T)

    # diag_i[r, c] = S[c, c] and diag_j[r, c] = S[r, r]; calc_tau/calc_tau_dot
    # are symmetric in the two, so the orientation does not matter.
    diag_i = np.tile(np.diag(S), (n, 1))
    diag_j = diag_i.transpose()

    tau = calc_tau(alpha, S, diag_i, diag_j)
    tau_dot = calc_tau_dot(alpha, S, diag_i, diag_j)

    for depth in range(1, max_depth + 1):
        scale = 2 ** depth
        K[depth - 1] = (
            scale * depth * S * tau_dot * tau ** (depth - 1) + scale * tau ** depth
        )

    return K, tau, tau_dot
