import numpy as np
import os
from tqdm import trange
from .io import load_npy


def EJinf(a, sigma, g):
    """Closed-form ground-truth value V accumulated from t=0 to infinity.

    Args:
        a: drift coefficient; the case ``a == 0`` is evaluated separately.
        sigma: noise scaling factor; enters only through ``sigma**2``.
        g: discount factor; enters only through ``np.log(g)``.

    Returns:
        ``sigma**2 / log(g)**2`` when ``a == 0``, otherwise
        ``sigma**2 / (log(g) * (log(g) + 2 * a))``.
    """
    if a != 0:
        return sigma**2 / (np.log(g) * (np.log(g) + 2 * a))
    # Driftless case: the second factor collapses to log(g).
    return sigma**2 / np.log(g) ** 2


def EJ(a, T, sigma, g=1):
    """Closed-form ground-truth value V accumulated from t=0 to T.

    Four cases are distinguished, depending on whether the discount factor
    is exactly 1 and whether the drift coefficient is exactly 0.

    Args:
        a: drift coefficient.
        T: time horizon.
        sigma: noise scaling factor; enters only through ``sigma**2``.
        g: discount factor (default 1, i.e. undiscounted).

    Returns:
        The value of the closed-form expression for the matching case.
    """
    var = sigma**2

    if g == 1:
        # Undiscounted formulas.
        if a == 0:
            return var * (T**2) / 2
        return var / (4 * a**2) * (np.exp(2 * a * T) - 1 - 2 * a * T)

    # Discounted formulas; everything depends on g through log(g) and g**T.
    log_g = np.log(g)
    g_pow_T = g**T
    if a == 0:
        return var * (1 - g_pow_T + g_pow_T * T * log_g) / log_g**2

    first = (g_pow_T * np.exp(2 * a * T) - 1) / (log_g + 2 * a)
    second = (g_pow_T - 1) / log_g
    return var / (2 * a) * (first - second)


def get_theta_star(a, g):
    """Return ``1 / (log(g) + 2 * a)`` for drift coefficient ``a`` and discount ``g``."""
    denom = np.log(g) + 2 * a
    return 1.0 / denom


def get_lr_str(lr_scheduler, alpha, lr_gamma):
    """Return the string tag that identifies a learning-rate schedule.

    Args:
        lr_scheduler: 'constant', 'reciprocal' or 'exponential'
        alpha: the base learning rate
        lr_gamma: discount factor for exponential lr decay; only appears in
            the tag of the 'exponential' scheduler

    Returns:
        ``lr_<scheduler>_<alpha>`` for 'constant' and 'reciprocal', and
        ``lr_exponential_<alpha>_lg_<lr_gamma>`` for 'exponential'.

    Raises:
        AssertionError: if ``lr_scheduler`` is not one of the three names.
    """
    if lr_scheduler in ("constant", "reciprocal"):
        return f"lr_{lr_scheduler}_{alpha}"
    if lr_scheduler == "exponential":
        return f"lr_{lr_scheduler}_{alpha}_lg_{lr_gamma}"
    # Raised explicitly (rather than `assert False`) so the check survives
    # `python -O`; the exception type is kept for existing callers.
    raise AssertionError(f"lr_scheduler {lr_scheduler} is not recognized.")


def lr_schedule(alpha, t, lr_gamma=0.9, mode="constant"):
    """Evaluate the learning rate at iteration ``t``.

    Args:
        alpha: the base learning rate
        t: iteration number
        lr_gamma: discount factor for exponential lr decay
        mode: 'constant'(default), 'reciprocal' or 'exponential'

    Returns:
        ``alpha`` for 'constant', ``alpha / t`` for 'reciprocal' and
        ``alpha * lr_gamma ** (t - 1)`` for 'exponential'. Any other mode
        prints a notice and falls back to ``alpha``.
    """
    if mode == "reciprocal":
        return alpha / t
    if mode == "exponential":
        return alpha * (lr_gamma ** (t - 1))
    if mode != "constant":
        # Unknown modes are tolerated: report and use the base rate.
        print(f"lr scheduling mode {mode} is not recognized. Return base lr")
    return alpha


def compute_h_list(T, N0):
    """Build the grid of step sizes h = 2**k * (T / N0) for k = 1, ..., 15.

    Args:
        T: time horizon.
        N0: number of steps of the finest grid, so that ``T / N0`` is the
            finest step size (2^-16 * T when N0 = 2**16).

    Returns:
        ndarray of 15 step sizes in increasing order
        ([2^-15, ..., 2^-1] * T when N0 = 2**16).
    """
    num_levels = 15  # TD needs at least N>=2
    exponents = np.arange(1, num_levels + 1)
    finest_step = T / N0
    return np.power(2, exponents) * finest_step


def create_h_mask(h_list, T, B, eps=1e-6):
    """Flag the step sizes h that satisfy ``T / B - eps <= h < T / 2``.

    Args:
        h_list: array of step sizes.
        T: time horizon.
        B: data budget.
        eps: slack subtracted from the lower bound ``T / B``.

    Returns:
        Boolean array, elementwise True where h lies in the valid range.
    """
    above_lower = h_list >= T / B - eps
    below_upper = h_list < T / 2
    return below_upper & above_lower


def compute_Vhat_TD_mpath(
    h,
    x,
    a,
    sigma,
    B,
    V,
    gamma=0.9,
    theta0=0.0,
    alpha=1.0,
    lr_gamma=0.9,
    nb_steps=1,
    lr_scheduler="constant",
    verbose=False
):
    """Compute theta and Vhat for mean-path semi-grad TD(0) estimator.

    The feature is ``x**2 - sigma**2 / log(gamma)`` and the reward is
    ``-h * x**2``. Each step averages the semi-gradient over every
    consecutive sample pair of every trajectory and moves theta by
    ``lr * avg_ngrad``. The value estimate is
    ``-(sigma**2) / log(gamma) * theta``.

    Args:
        h (float): The value of h.
        x (ndarray): A two-dimensional array with M columns where each column contains the samples from one trajectory. Only the first ``int(B)`` rows are used when there are more than B rows.
        a (float): The value of the drift coefficient a (only used to compute theta_star).
        sigma (float): The value of the noise scaling factor sigma.
        B (float): The value of data budget B.
        V (float): The true value; used for verbose error reporting and returned unchanged.
        gamma (float, optional): The discount factor. Defaults to 0.9.
        theta0 (float, optional): The initial value of theta. Defaults to 0.0.
        alpha (float, optional): The learning rate. Defaults to 1.0.
        lr_gamma (float, optional): discount for exponential lr decay. Defaults to 0.9.
        nb_steps (int, optional): The number of steps of gradient update. Defaults to 1.
        lr_scheduler (str, optional): learning rate scheduler, 'constant', 'reciprocal' or 'exponential'. Defaults to 'constant'.
        verbose (bool, optional): Whether to print verbose output. Defaults to False.

    Returns:
        A tuple that contains:
            Vhat (float): The estimate of V from TD after the last step (the initial estimate when nb_steps is 0).
            theta (float): The estimate of theta from TD after the last step.
            Vhat_list (ndarray): The list of value estimate from all steps, initial value first.
            theta_list (ndarray): The list of theta estimate from all steps, initial value first.
            V (float): The true value.
            theta_star (float): The true value of theta.
            avg_ngrad_list (ndarray): The list of average negative gradient from all steps.
            lr_list (ndarray): The list of lr from all steps.
    """
    if x.shape[0] > B:
        if verbose:
            print(
                (
                    f"Number of samples {x.shape[0]} exceeds the budget {int(B)}. "
                    "Only use the first B samples."
                )
            )
        x_sq = x[: int(B), :].astype(float) ** 2
    else:
        x_sq = x.astype(float) ** 2
    if verbose:
        print("shape of data used for mean-path TD: ", x_sq.shape)

    feat = x_sq - sigma**2 / np.log(gamma)
    r = -h * x_sq
    theta_star = get_theta_star(a, gamma)
    v_scale = -(sigma**2) / np.log(gamma)  # maps theta to the value estimate

    theta_list = np.zeros((nb_steps + 1))  # th0, th1, ... th_{nb_steps}
    Vhat_list = np.zeros((nb_steps + 1))
    avg_ngrad_list = np.zeros((nb_steps))
    lr_list = np.zeros((nb_steps))

    theta = theta0
    # Define Vhat up front so the return value exists even when nb_steps == 0.
    Vhat = v_scale * theta
    theta_list[0] = theta
    Vhat_list[0] = Vhat

    N = x_sq.shape[0]  # number of samples in a trajectory
    M = x_sq.shape[1]  # number of trajectories
    num_data = (N - 1) * M  # total number of data pairs

    # Loop-invariant pieces of the TD error.
    discount = gamma**h
    feat_cur = feat[:-1]
    feat_next = feat[1:]
    r_cur = r[:-1]

    for i in range(nb_steps):
        # Sum the negative semi-gradient trajectory by trajectory; this
        # summation order is kept so results match earlier runs exactly.
        ngrad = sum(
            np.sum(
                (
                    r_cur[:, j]
                    + theta * (discount * feat_next[:, j] - feat_cur[:, j])
                )
                * feat_cur[:, j]
            )
            for j in range(M)
        )

        lr = lr_schedule(alpha, i + 1, lr_gamma=lr_gamma, mode=lr_scheduler)
        avg_ngrad = ngrad / num_data  # grad is the sum over all data points
        theta += lr * avg_ngrad
        Vhat = v_scale * theta

        # store the values
        theta_list[i + 1] = theta
        Vhat_list[i + 1] = Vhat
        avg_ngrad_list[i] = avg_ngrad
        lr_list[i] = lr

        if verbose:
            print(f"After gradient step {i}")
            print(
                f"-- learning rate: {lr}, theta: {theta}, avg_ngrad: {avg_ngrad}, "
                f"err in theta: {theta-theta_star}, Vhat: {Vhat}, err in Vhat: {Vhat-V}"
            )

    return (
        Vhat,
        theta,
        Vhat_list,
        theta_list,
        V,
        theta_star,
        avg_ngrad_list,
        lr_list,
    )


def compute_Vhat_LSTD(
    h,
    x,
    a,
    sigma,
    B,
    V,
    gamma=0.9,
    verbose=False
):
    """Compute theta and Vhat for LSTD(0) estimator.

    The feature is ``x**2 - sigma**2 / log(gamma)`` and the reward is
    ``-h * x**2``. With sums taken over every consecutive sample pair of
    every trajectory, theta is
    ``-sum(feat * r) / sum(feat * (gamma**h * feat_next - feat))``
    and the value estimate is ``-(sigma**2) / log(gamma) * theta``.

    Args:
        h (float): The value of h.
        x (ndarray): A two-dimensional array with M columns where each column contains the samples from one trajectory. Only the first ``int(B)`` rows are used when there are more than B rows.
        a (float): The value of the drift coefficient a (only used to compute theta_star).
        sigma (float): The value of the noise scaling factor sigma.
        B (float): The value of data budget B.
        V (float): The true value; returned unchanged.
        gamma (float, optional): The discount factor. Defaults to 0.9.
        verbose (bool, optional): Whether to print verbose output. Defaults to False.

    Returns:
        A tuple that contains:
            Vhat (float): The estimate of V from LSTD.
            theta (float): The estimate of theta from LSTD.
            V (float): The true value.
            theta_star (float): The true value of theta.
    """
    if x.shape[0] > B:
        if verbose:
            print(
                (
                    f"Number of samples {x.shape[0]} exceeds the data budget {int(B)}. "
                    "Only use the first B samples."
                )
            )
        x_sq = x[: int(B), :].astype(float) ** 2
    else:
        x_sq = x.astype(float) ** 2
    if verbose:
        # Message previously said "mean-path TD" (copied from the TD routine).
        print("shape of data used for LSTD: ", x_sq.shape)

    feat = x_sq - sigma**2 / np.log(gamma)
    r = -h * x_sq
    theta_star = get_theta_star(a, gamma)

    M = x_sq.shape[1]  # number of trajectories

    # Loop-invariant pieces of the normal equation.
    discount = gamma**h
    feat_cur = feat[:-1]
    feat_next = feat[1:]
    r_cur = r[:-1]

    # Accumulate trajectory by trajectory; this summation order is kept so
    # results match earlier runs exactly.
    numerator = sum(np.sum(feat_cur[:, j] * r_cur[:, j]) for j in range(M))
    denominator = sum(
        np.sum(feat_cur[:, j] * (discount * feat_next[:, j] - feat_cur[:, j]))
        for j in range(M)
    )

    theta = -numerator / denominator
    Vhat = -(sigma**2) / np.log(gamma) * theta

    return (
        Vhat,
        theta,
        V,
        theta_star
    )


def _load_x_for_MSE(path, run_id, a, T, max_T, B, max_B, verbose=False):
    """Load the trajectory array of one run, preferring the largest-budget file.

    The file name is built from ``run_id``, ``a``, ``max_T`` and a budget:
    the ``max_B`` file is tried first and, if it does not exist, the file
    for ``B`` is used instead. ``T`` is only checked against ``max_T``; it
    does not appear in the file name.

    Args:
        path: directory that holds the ``x_run_*.npy`` files.
        run_id: index of the run.
        a: drift coefficient used in the file name.
        T: requested time horizon.
        max_T: horizon of the stored trajectories.
        B: data budget of the fallback file.
        max_B: data budget of the preferred file.
        verbose: forwarded to ``load_npy``.

    Returns:
        Whatever ``load_npy`` returns for the chosen file
        (presumably an array of shape (N0+1, M) -- confirm against ``load_npy``).

    Raises:
        AssertionError: if ``T > max_T``.
    """
    if T > max_T:
        # Raised explicitly (rather than `assert False`) so the guard
        # survives `python -O`; the exception type is kept for callers.
        raise AssertionError(f"No trajectory data for T={T}")

    def _fname_for(budget):
        # File name of the stored trajectories for a given data budget.
        return os.path.join(
            path,
            f"x_run_{run_id}_a_{a}_T_{max_T}_B_{budget}_seed_0_float16.npy",
        )

    npy_fname = _fname_for(max_B)
    if not os.path.exists(npy_fname):  # in case we only have data for a smaller B
        npy_fname = _fname_for(B)
    return load_npy(npy_fname, verbose)


def _get_trajectory_info(h, T, B, h0, verbose):
    """Derive the trajectory layout for step size ``h`` under budget ``B``.

    Args:
        h: step size.
        T: time horizon.
        B: data budget.
        h0: finest step size; ``h / h0`` is the subsampling stride.
        verbose: print the derived quantities when true.

    Returns:
        Tuple ``(N, M, h_ratio)``: samples per trajectory (``int(T / h)``
        capped at ``B``), number of trajectories (``int(B / N)`` but at
        least 1), and the stride ``int(h / h0)`` used to skip samples.
    """
    steps_in_horizon = int(T / h)
    samples_per_traj = min(steps_in_horizon, B)  # cannot exceed B
    num_traj = max(int(B / samples_per_traj), 1)  # at least 1
    stride = int(h / h0)  # used to skip samples
    if verbose:
        print(f"h={h}: T={T}, N={samples_per_traj}, M={num_traj}, h/h0={stride}")
    return (samples_per_traj, num_traj, stride)


def compute_MSE(
    a,
    T,
    B,
    N0=2**16,
    sigma=1.0,
    gamma=0.9,
    num_runs=50,
    nb_steps=1,
    theta0=0.0,
    alpha=1.0,
    lr_scheduler="constant",
    lr_gamma=0.9,
    max_T=8.0,
    max_B=2**16,
    path="data/share",
    verbose=True
):
    """Run mean-path TD for {a,T,B} over all runs and step sizes and save the result.

    For every run and every h from ``compute_h_list(T, N0)`` the trajectory
    data is subsampled and passed to ``compute_Vhat_TD_mpath``. The per-step
    estimates are written to an ``.npz`` file under ``path`` with the keys
    "Vhat_list", "avg_ngrad", "lr_list", "theta_list", "theta_star" and "V".
    No error metric is computed here; only the raw estimates and the
    reference values are stored. If the output file already exists, a
    message is printed and nothing is recomputed.

    Args:
        a: drift coefficient.
        T: time horizon.
        B: data budget.
        N0: number of steps of the finest grid; ``T / N0`` is the finest step.
        sigma: noise scaling factor.
        gamma: discount factor.
        num_runs: number of independent runs to process.
        nb_steps: number of gradient updates per estimate.
        theta0: initial value of theta.
        alpha: base learning rate.
        lr_scheduler: 'constant', 'reciprocal' or 'exponential'.
        lr_gamma: discount factor for exponential lr decay.
        max_T: horizon of the stored trajectory files.
        max_B: budget of the preferred trajectory files.
        path: directory with the trajectory files; also the output directory.
        verbose: forwarded to the helpers that print progress details.

    Returns:
        None. Results are written to disk.
    """
    h_list = compute_h_list(T, N0)
    len_h = len(h_list)
    V = -EJinf(a, sigma, gamma)
    theta_star = get_theta_star(a, gamma)
    lr_str = get_lr_str(lr_scheduler, alpha, lr_gamma)

    # Result buffers, indexed [run, h index, gradient step].
    Vhat_list = np.zeros((num_runs, len_h, nb_steps + 1))
    theta_list = np.zeros((num_runs, len_h, nb_steps + 1))
    avg_ngrad_list = np.zeros((num_runs, len_h, nb_steps))
    lr_list = np.zeros((num_runs, len_h, nb_steps))

    out_name = (
        f"Vhat_TD_a_{a}_T_{T}_B_{B}_g_{gamma}_step_{nb_steps}"
        f"_{lr_str}_theta0_{theta0}_runs_{num_runs}_seed_0.npz"
    )
    fname = os.path.join(path, out_name)

    # Cached result present: nothing to do.
    if os.path.exists(fname):
        print(f"The file {fname} already exists. Returning.")
        return

    # Otherwise compute from x.
    h0 = T / N0  # 2^-16 * T
    for run in trange(num_runs, desc="runs", disable=False):
        x = _load_x_for_MSE(path, run, a, T, max_T, B, max_B, verbose)

        for idx in trange(h_list.shape[0], desc="h_list", leave=False,
                          position=1, disable=True):
            h = h_list[idx]
            _, M, h_ratio = _get_trajectory_info(h, T, B, h0, verbose)

            # Keep every h_ratio-th sample (last row excluded) of the
            # first M trajectories.
            result = compute_Vhat_TD_mpath(
                h,
                x[:-1:h_ratio, :M],
                a,
                sigma,
                B,
                V,
                gamma=gamma,
                theta0=theta0,
                nb_steps=nb_steps,
                alpha=alpha,
                lr_scheduler=lr_scheduler,
                lr_gamma=lr_gamma,
                verbose=verbose
            )
            # Tuple positions: 2 = Vhat_list, 3 = theta_list,
            # 6 = avg_ngrad_list, 7 = lr_list.
            Vhat_list[run, idx, :] = result[2]
            theta_list[run, idx, :] = result[3]
            avg_ngrad_list[run, idx, :] = result[6]
            lr_list[run, idx, :] = result[7]

    with open(fname, "wb") as f:
        np.savez(
            f,
            Vhat_list=Vhat_list,
            avg_ngrad=avg_ngrad_list,
            lr_list=lr_list,
            theta_list=theta_list,
            theta_star=theta_star,
            V=V
        )


def compute_MSE_LSTD(
    a,
    T,
    B,
    N0=2**16,
    sigma=1.0,
    gamma=0.9,
    num_runs=50,
    max_T=8.0,
    max_B=2**16,
    path="data/share",
    verbose=True
):
    """Run LSTD for {a,T,B} over all runs and step sizes and save the result.

    For every run and every h from ``compute_h_list(T, N0)`` the trajectory
    data is subsampled and passed to ``compute_Vhat_LSTD``. The estimates
    are written to an ``.npz`` file under ``path`` with the keys
    "Vhat_list", "theta_list", "theta_star" and "V". No error metric is
    computed here; only the raw estimates and the reference values are
    stored. If the output file already exists the function returns without
    recomputing (printing a note when ``verbose`` is true).

    Args:
        a: drift coefficient.
        T: time horizon.
        B: data budget.
        N0: number of steps of the finest grid; ``T / N0`` is the finest step.
        sigma: noise scaling factor.
        gamma: discount factor.
        num_runs: number of independent runs to process.
        max_T: horizon of the stored trajectory files.
        max_B: budget of the preferred trajectory files.
        path: directory with the trajectory files; also the output directory.
        verbose: controls the "already exists" note and is forwarded to the helpers.

    Returns:
        None. Results are written to disk.
    """
    h_list = compute_h_list(T, N0)
    len_h = len(h_list)
    V = -EJinf(a, sigma, gamma)
    theta_star = get_theta_star(a, gamma)

    # Result buffers, indexed [run, h index].
    Vhat_list = np.zeros((num_runs, len_h))
    theta_list = np.zeros((num_runs, len_h))

    out_name = f"Vhat_LSTD_a_{a}_T_{T}_B_{B}_g_{gamma}_runs_{num_runs}_seed_0.npz"
    fname = os.path.join(path, out_name)

    # Cached result present: nothing to do.
    if os.path.exists(fname):
        if verbose:
            print(f"{fname} already exists.")
        return

    # Otherwise compute from x.
    h0 = T / N0  # 2^-16 * T
    for run in trange(num_runs, desc="runs", disable=False):
        x = _load_x_for_MSE(path, run, a, T, max_T, B, max_B, verbose)

        for idx in trange(h_list.shape[0], desc="h_list", leave=False,
                          position=1, disable=True):
            h = h_list[idx]
            _, M, h_ratio = _get_trajectory_info(h, T, B, h0, verbose)

            # Keep every h_ratio-th sample (last row excluded) of the
            # first M trajectories.
            result = compute_Vhat_LSTD(
                h,
                x[:-1:h_ratio, :M],
                a,
                sigma,
                B,
                V,
                gamma=gamma,
                verbose=verbose
            )
            # Tuple positions: 0 = Vhat, 1 = theta.
            Vhat_list[run, idx] = result[0]
            theta_list[run, idx] = result[1]

    with open(fname, "wb") as f:
        np.savez(
            f,
            Vhat_list=Vhat_list,
            theta_list=theta_list,
            theta_star=theta_star,
            V=V
        )