import math
from logging import FileHandler, Formatter, StreamHandler, getLogger
from tqdm import tqdm
from typing import Tuple

import numpy as np

# One formatter shared by every handler: timestamp, logger name, source line,
# level and function name precede each message.
log_fmt = Formatter(
    "%(asctime)s %(name)s L%(lineno)d [%(levelname)s][%(funcName)s] %(message)s "
)
logger = getLogger(__name__)
logger.setLevel("INFO")
# Mirror INFO-and-above records to the console and to ./result.log.
# The file is opened in "w" mode, so it is truncated each time this module is imported.
for handler in (StreamHandler(), FileHandler("./result.log", "w")):
    handler.setLevel("INFO")
    handler.setFormatter(log_fmt)
    logger.addHandler(handler)


def calc_tau(
    alpha: float, S: np.ndarray, diag_i: np.ndarray, diag_j: np.ndarray
) -> np.ndarray:
    """Compute ``tau`` element-wise.

        tau = 1/4 + arcsin(alpha**2 * S
                           / sqrt((alpha**2 * diag_i + 0.5) * (alpha**2 * diag_j + 0.5)))
                    / (2 * pi)

    All array arguments are combined with NumPy broadcasting, so scalars or
    arrays of matching/broadcastable shape are accepted.

    Args:
        alpha: Scale factor; it enters the formula squared.
        S: Pairwise inner products (callers in this module pass ``X @ X.T``).
        diag_i: Diagonal Gram entries for one member of each pair.
        diag_j: Diagonal Gram entries for the other member of each pair.
            The formula is symmetric in ``diag_i`` and ``diag_j``.

    Returns:
        The broadcast result of the formula above.
    """
    a2 = alpha ** 2
    # Normalised inner product fed to arcsin; kept in the original operation
    # order so results are bitwise unchanged.
    ratio = (a2 * S) / np.sqrt((a2 * diag_i + 0.5) * (a2 * diag_j + 0.5))
    return 1 / 4 + 1 / (2 * math.pi) * np.arcsin(ratio)


def calc_tau_dot(
    alpha: float, S: np.ndarray, diag_i: np.ndarray, diag_j: np.ndarray
) -> np.ndarray:
    """Compute ``tau_dot`` element-wise.

        tau_dot = alpha**2 / pi
                  / sqrt((2*alpha**2*diag_i + 1) * (2*alpha**2*diag_j + 1)
                         - 4*alpha**4*S**2)

    Inputs are combined with NumPy broadcasting.

    Args:
        alpha: Scale factor.
        S: Pairwise inner products (callers in this module pass ``X @ X.T``).
        diag_i: Diagonal Gram entries for one member of each pair.
        diag_j: Diagonal Gram entries for the other member of each pair
            (the formula is symmetric in the two).

    Returns:
        The broadcast result of the formula above.
    """
    a2 = alpha ** 2
    radicand = (2 * a2 * diag_i + 1) * (2 * a2 * diag_j + 1) - (
        4 * (alpha ** 4) * (S ** 2)
    )
    # (a2 / pi) / sqrt(...) is the same evaluation as the former
    # ``a2 / pi * 1 / sqrt(...)``; the multiplication by 1 was a no-op.
    return a2 / math.pi / np.sqrt(radicand)


def tree(
    X: np.ndarray, max_depth: int, alpha: float
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Evaluate the depth-``d`` tree kernel at power-of-two depths.

    For a depth ``d`` the kernel is

        H_d = 2 * S * 2**(d - 1) * d * tau_dot * tau**(d - 1) + 2**d * tau**d

    with ``S = X @ X.T`` and ``tau`` / ``tau_dot`` from ``calc_tau`` /
    ``calc_tau_dot``. Only depths 1, 2, 4, ..., 128 are stored: ``K[k]`` holds
    ``H_{2**k}`` for every ``2**k <= max_depth``. Slots whose depth exceeds
    ``max_depth`` remain zero.
    NOTE(review): presumably the NTK of a perfect binary soft tree; confirm
    against the reference derivation.

    Args:
        X: 2-D input whose rows are samples.
        max_depth: Largest depth to iterate over.
        alpha: Scale forwarded to ``calc_tau`` / ``calc_tau_dot``.

    Returns:
        ``(K, tau, tau_dot)`` with ``K.shape == (8, n, n)`` where
        ``n = X.shape[0]``, and ``tau`` / ``tau_dot`` of shape ``(n, n)``.
    """
    n_slots = 8  # depths 2**0 .. 2**7, i.e. 1 .. 128
    n = X.shape[0]
    K = np.zeros((n_slots, n, n))
    S = np.matmul(X, X.T)
    # Zero-copy (n, n) views: diag_i[a, b] == S[b, b], diag_j[a, b] == S[a, a].
    # These are the same values the explicit matrices held, without building an
    # n*n Python list.
    diag_i = np.broadcast_to(np.diag(S), S.shape)
    diag_j = diag_i.T

    tau = calc_tau(alpha, S, diag_i, diag_j)
    tau_dot = calc_tau_dot(alpha, S, diag_i, diag_j)

    slot_of_depth = {2 ** k: k for k in range(n_slots)}
    for depth in tqdm(range(1, max_depth + 1)):
        slot = slot_of_depth.get(depth)
        if slot is None:
            continue  # this depth is not stored, so skip the (n, n) computation
        K[slot] = (2 * S * (2 ** (depth - 1)) * depth * tau_dot * tau ** (depth - 1)) + (
            (2 ** depth) * (tau ** depth)
        )

    return K, tau, tau_dot


def asymtree(
    X: np.ndarray, max_depth: int, alpha: float
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Evaluate the asymmetric-tree kernel at power-of-two depths.

    With per-depth term ``t_k = k * S * tau_dot * tau**(k - 1) + tau**k`` and
    running sum ``H_d = t_1 + ... + t_d``, the stored kernel for depth ``d`` is
    ``H_d + t_d``; the deepest term is counted twice.
    NOTE(review): the doubled term presumably reflects the two leaves at the
    deepest level of an asymmetric tree; confirm.

    ``K[k]`` holds the depth-``2**k`` kernel for every ``2**k <= max_depth``
    (k = 0..7). Unreached slots stay zero.

    Args:
        X: 2-D input whose rows are samples; ``S = X @ X.T``.
        max_depth: Largest depth to iterate over.
        alpha: Scale forwarded to ``calc_tau`` / ``calc_tau_dot``.

    Returns:
        ``(K, tau, tau_dot)`` with ``K.shape == (8, n, n)`` where
        ``n = X.shape[0]``.
    """
    n_slots = 8  # depths 2**0 .. 2**7, i.e. 1 .. 128
    n = X.shape[0]
    K = np.zeros((n_slots, n, n))
    S = np.matmul(X, X.T)
    # Zero-copy (n, n) views: diag_i[a, b] == S[b, b], diag_j[a, b] == S[a, a].
    diag_i = np.broadcast_to(np.diag(S), S.shape)
    diag_j = diag_i.T

    tau = calc_tau(alpha, S, diag_i, diag_j)
    tau_dot = calc_tau_dot(alpha, S, diag_i, diag_j)

    slot_of_depth = {2 ** k: k for k in range(n_slots)}
    H = np.zeros_like(tau)  # running sum t_1 + ... + t_depth
    for depth in tqdm(range(1, max_depth + 1)):
        # Compute the two parts of t_depth once; they are reused below.
        grad_term = depth * S * tau_dot * (tau ** (depth - 1))
        prob_term = tau ** depth
        H += grad_term + prob_term
        slot = slot_of_depth.get(depth)
        if slot is not None:
            # Evaluated as (H + grad_term) + prob_term, same order as before.
            K[slot] = H + grad_term + prob_term

    return K, tau, tau_dot


def inf_asymtree(
    X: np.ndarray, max_depth: int, alpha: float
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Evaluate the closed-form (depth-independent) asymmetric-tree kernel.

        kernel = S * tau_dot / (1 - tau)**2 + tau / (1 - tau)

    This is the limit of ``asymtree``'s running sum
    ``sum_k (k * S * tau_dot * tau**(k-1) + tau**k)`` as depth grows (geometric
    series, valid for |tau| < 1). The same matrix is written into ``K[k]`` for
    every ``2**k <= max_depth`` (k = 0..7), so the output layout matches
    ``asymtree``. Unreached slots stay zero.

    Args:
        X: 2-D input whose rows are samples; ``S = X @ X.T``.
        max_depth: Determines how many of the 8 slots are filled.
        alpha: Scale forwarded to ``calc_tau`` / ``calc_tau_dot``.

    Returns:
        ``(K, tau, tau_dot)`` with ``K.shape == (8, n, n)`` where
        ``n = X.shape[0]``.
    """
    n_slots = 8  # depths 2**0 .. 2**7, i.e. 1 .. 128
    n = X.shape[0]
    K = np.zeros((n_slots, n, n))
    S = np.matmul(X, X.T)
    # Zero-copy (n, n) views: diag_i[a, b] == S[b, b], diag_j[a, b] == S[a, a].
    diag_i = np.broadcast_to(np.diag(S), S.shape)
    diag_j = diag_i.T

    tau = calc_tau(alpha, S, diag_i, diag_j)
    tau_dot = calc_tau_dot(alpha, S, diag_i, diag_j)

    # The kernel does not depend on depth, so compute it once.
    kernel = (S * tau_dot / ((1 - tau) ** 2)) + (tau / (1 - tau))
    # Slots 0..n_filled-1 correspond to depths 1, 2, 4, ... that are <= max_depth.
    n_filled = sum(1 for k in range(n_slots) if 2 ** k <= max_depth)
    K[:n_filled] = kernel

    return K, tau, tau_dot


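# Notebook variants: keep the kernel at every depth 1..max_depth (K[depth - 1])
# instead of only at power-of-two depths.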
def tree_viz(
    X: np.ndarray, max_depth: int, alpha: float
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Evaluate the depth-``d`` tree kernel at every depth 1..max_depth.

    Per-depth formula (same as ``tree``):

        H_d = 2 * S * 2**(d - 1) * d * tau_dot * tau**(d - 1) + 2**d * tau**d

    Every depth is kept: ``K[d - 1] = H_d``.

    Args:
        X: 2-D input whose rows are samples; ``S = X @ X.T``.
        max_depth: Number of depths evaluated (``K`` has this many slices).
        alpha: Scale forwarded to ``calc_tau`` / ``calc_tau_dot``.

    Returns:
        ``(K, tau, tau_dot)`` with ``K.shape == (max_depth, n, n)`` where
        ``n = X.shape[0]``.
    """
    n = X.shape[0]
    K = np.zeros((max_depth, n, n))
    S = np.matmul(X, X.T)
    # Zero-copy (n, n) views: diag_i[a, b] == S[b, b], diag_j[a, b] == S[a, a].
    diag_i = np.broadcast_to(np.diag(S), S.shape)
    diag_j = diag_i.T

    tau = calc_tau(alpha, S, diag_i, diag_j)
    tau_dot = calc_tau_dot(alpha, S, diag_i, diag_j)

    for depth in range(1, max_depth + 1):
        K[depth - 1] = (2 * S * (2 ** (depth - 1)) * depth * tau_dot * tau ** (depth - 1)) + (
            (2 ** depth) * (tau ** depth)
        )

    return K, tau, tau_dot


def asymtree_viz(
    X: np.ndarray, max_depth: int, alpha: float
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Evaluate the asymmetric-tree kernel at every depth 1..max_depth.

    With per-depth term ``t_k = k * S * tau_dot * tau**(k - 1) + tau**k`` and
    running sum ``H_d = t_1 + ... + t_d``, ``K[d - 1] = H_d + t_d``. The
    deepest term is counted twice, as in ``asymtree``.

    Args:
        X: 2-D input whose rows are samples; ``S = X @ X.T``.
        max_depth: Number of depths evaluated (``K`` has this many slices).
        alpha: Scale forwarded to ``calc_tau`` / ``calc_tau_dot``.

    Returns:
        ``(K, tau, tau_dot)`` with ``K.shape == (max_depth, n, n)`` where
        ``n = X.shape[0]``.
    """
    n = X.shape[0]
    K = np.zeros((max_depth, n, n))
    S = np.matmul(X, X.T)
    # Zero-copy (n, n) views: diag_i[a, b] == S[b, b], diag_j[a, b] == S[a, a].
    diag_i = np.broadcast_to(np.diag(S), S.shape)
    diag_j = diag_i.T

    tau = calc_tau(alpha, S, diag_i, diag_j)
    tau_dot = calc_tau_dot(alpha, S, diag_i, diag_j)

    H = np.zeros_like(tau)  # running sum t_1 + ... + t_depth
    for depth in range(1, max_depth + 1):
        # Compute the two parts of t_depth once; they are reused below.
        grad_term = depth * S * tau_dot * (tau ** (depth - 1))
        prob_term = tau ** depth
        H += grad_term + prob_term
        # Evaluated as (H + grad_term) + prob_term, same order as before.
        K[depth - 1] = H + grad_term + prob_term

    return K, tau, tau_dot


def inf_asymtree_viz(
    X: np.ndarray, max_depth: int, alpha: float
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Repeat the closed-form asymmetric-tree kernel across ``max_depth`` slices.

        kernel = S * tau_dot / (1 - tau)**2 + tau / (1 - tau)

    The kernel does not depend on depth; it is copied into every slice
    ``K[0] .. K[max_depth - 1]`` so the layout matches ``asymtree_viz``.

    Args:
        X: 2-D input whose rows are samples; ``S = X @ X.T``.
        max_depth: Number of slices in ``K``.
        alpha: Scale forwarded to ``calc_tau`` / ``calc_tau_dot``.

    Returns:
        ``(K, tau, tau_dot)`` with ``K.shape == (max_depth, n, n)`` where
        ``n = X.shape[0]``.
    """
    n = X.shape[0]
    K = np.zeros((max_depth, n, n))
    S = np.matmul(X, X.T)
    # Zero-copy (n, n) views: diag_i[a, b] == S[b, b], diag_j[a, b] == S[a, a].
    diag_i = np.broadcast_to(np.diag(S), S.shape)
    diag_j = diag_i.T

    tau = calc_tau(alpha, S, diag_i, diag_j)
    tau_dot = calc_tau_dot(alpha, S, diag_i, diag_j)

    # Depth-independent, so compute once and broadcast into every slice.
    K[:] = (S * tau_dot / ((1 - tau) ** 2)) + (tau / (1 - tau))

    return K, tau, tau_dot
