import time
import traceback
from typing import Any, Dict, Optional, Tuple

import torch
from botorch.models import SingleTaskGP
from torch import Tensor
from torch.quasirandom import SobolEngine

from .helpers import model_fit_helper, replication_setup


# Identifiers of the methods compared in these experiments, in a fixed order.
# Presumably these are the valid values for the ``method_name`` argument of the
# replication runner below — confirm against ``replication_setup``.
METHOD_NAMES = """
    vanilla
    relevance_pursuit
    relevance_pursuit_fwd
    student_t
    trimmed_mll
    power_transform
    winsorize
    sobol
""".split()


def _draw_sobol_points(
    dim: int, n: int, seed: int, bounds: Tensor, tkwargs: Dict[str, Any]
) -> Tensor:
    """Draw ``n`` scrambled Sobol points and affinely map them onto ``bounds``.

    ``bounds[0]`` is used as the lower corner and ``bounds[1]`` as the upper one.
    """
    X = SobolEngine(dimension=dim, scramble=True, seed=seed).draw(n).to(**tkwargs)  # pyre-ignore
    return (bounds[1] - bounds[0]) * X + bounds[0]


def _evaluate_true_objective(objective_function: Any, X: Tensor) -> Tensor:
    """Return the noise-free objective at ``X`` as a column tensor.

    ``evaluate_true`` does not apply the objective's ``negate`` flag, so it is
    applied explicitly here, presumably to match the sign convention of the
    observations produced by ``objective_function(X)``.
    """
    Y = objective_function.evaluate_true(X).unsqueeze(-1)
    return -Y if objective_function.negate else Y


def run_one_regression_replication(
    results_fpath: str,
    seed: int,
    method_name: str,
    function_name: str,
    outlier_fraction: float,
    outlier_generator_name: str,
    n_train: int,
    n_test: Optional[int] = None,
    outlier_generator_kwargs: Optional[Dict[str, Any]] = None,
    dtype: torch.dtype = torch.double,
    device: Optional[torch.device] = None,
    noise_std: Optional[float] = None,
) -> None:
    """Run a single regression replication and save its results with ``torch.save``.

    Training inputs are ``n_train`` scrambled Sobol points (seeded with ``seed``)
    and test inputs are ``n_test`` Sobol points seeded with ``seed + 1``. Both are
    scaled to the problem ``bounds``. The model for ``method_name`` is fit on the
    observations ``objective_function(X)``. Train and test statistics are then
    computed against the noise-free targets.

    Fitting or evaluation errors are reported (with traceback) but do not abort
    the run. In that case the saved stats are empty and predictions are zeros.

    Args:
        results_fpath: Path that the output dictionary is written to.
        seed: Seed for the training Sobol points and observation RNG. The test
            points use ``seed + 1``.
        method_name: Model-fitting method, forwarded to ``replication_setup`` and
            ``model_fit_helper``.
        function_name: Test function name, forwarded to ``replication_setup``.
        outlier_fraction: Forwarded to ``replication_setup``.
        outlier_generator_name: Forwarded to ``replication_setup``.
        n_train: Number of training points.
        n_test: Number of test points. Defaults to ``n_train``.
        outlier_generator_kwargs: Forwarded to ``replication_setup`` (``{}`` if None).
        dtype: Dtype of all generated tensors.
        device: Device of all generated tensors.
        noise_std: Forwarded to ``replication_setup``.
    """
    tkwargs = {"dtype": dtype, "device": device}
    outlier_generator_kwargs = outlier_generator_kwargs or {}
    # The outlier generator is not used directly in this function.
    objective_function, _outlier_generator, dim, minimize, bounds = replication_setup(
        method_name=method_name,
        function_name=function_name,
        outlier_fraction=outlier_fraction,
        outlier_generator_name=outlier_generator_name,
        outlier_generator_kwargs=outlier_generator_kwargs,
        dtype=dtype,
        device=device,
        noise_std=noise_std,
    )
    if n_test is None:
        n_test = n_train

    X = _draw_sobol_points(dim=dim, n=n_train, seed=seed, bounds=bounds, tkwargs=tkwargs)
    # Important: increment the seed so the test set differs from the training set.
    X_test = _draw_sobol_points(
        dim=dim, n=n_test, seed=seed + 1, bounds=bounds, tkwargs=tkwargs
    )
    Y_train_true = _evaluate_true_objective(objective_function, X)
    Y_test_true = _evaluate_true_objective(objective_function, X_test)

    # Observed training targets. fork_rng restores the global RNG state on exit,
    # so seeding here does not leak into the caller.
    with torch.random.fork_rng():
        torch.manual_seed(seed)
        Y = objective_function(X).unsqueeze(-1)

    # Placeholders that are saved unchanged if fitting or evaluation fails.
    test_stats = {}
    train_stats = {}
    fit_time = 0
    Y_test_pred = torch.zeros_like(Y_test_true)
    Y_train_pred = torch.zeros_like(Y_train_true)
    try:
        start_time = time.monotonic()
        model = model_fit_helper(method_name=method_name, X=X, Y=Y, minimize=minimize)
        fit_time += time.monotonic() - start_time

        # Statistics are computed against the ground-truth (noise-free) targets.
        train_stats, Y_train_pred = compute_regression_statistics(
            model=model, X=X, Y=Y_train_true
        )
        test_stats, Y_test_pred = compute_regression_statistics(
            model=model, X=X_test, Y=Y_test_true
        )
    except Exception:
        # Best effort: one failing method must not abort the replication, but
        # keep the exception type and traceback so the failure can be diagnosed.
        print(traceback.format_exc())

    output_dict = {
        "method_name": method_name,
        "function_name": function_name,
        "outlier_generator_name": outlier_generator_name,
        "outlier_fraction": outlier_fraction,
        "outlier_generator_kwargs": outlier_generator_kwargs,
        "n_train": n_train,
        "n_test": n_test,
        "X": X.cpu(),
        "Y": Y.cpu(),
        "X_test": X_test.cpu(),
        "Y_test_true": Y_test_true.cpu(),
        "Y_test_pred": Y_test_pred.cpu(),
        "Y_train_true": Y_train_true.cpu(),
        "Y_train_pred": Y_train_pred.cpu(),
        "fit_time": fit_time,
        "train_stats": train_stats,
        "test_stats": test_stats,
    }
    torch.save(output_dict, results_fpath)


def compute_regression_statistics(
    model: SingleTaskGP, X: Tensor, Y: Tensor
) -> Tuple[Dict[str, float], Tensor]:
    """Compute accuracy and calibration statistics of ``model`` at inputs ``X``.

    All statistics use the noise-free posterior (``observation_noise=False``),
    because the noisy predictive is not available in closed form for the
    Student-T model.

    Args:
        model: Fitted model exposing ``posterior(X, observation_noise=...)``.
            Non-``SingleTaskGP`` models with that interface also work.
        X: Inputs to evaluate at, one row per point (presumably ``n x d``).
        Y: Targets at ``X``. A trailing singleton dimension is squeezed away.

    Returns:
        A tuple ``(stats, post_mean)``. ``stats`` holds ``rmse``, ``mae``,
        ``rel_rmse``, ``loglike`` and, for p in 50..90, ``p{p}err`` (p-th
        percentile of |error|) and ``p{p}rmse``. ``post_mean`` is the detached
        posterior mean with the trailing output dimension squeezed.

    Raises:
        ValueError: If the residuals are not one-dimensional.
    """
    post_mean = model.posterior(X, observation_noise=False).mean.squeeze(-1).detach()
    Y = Y.squeeze(-1)
    err = Y - post_mean
    if err.ndim != 1:
        raise ValueError(f"Expected 1-d residuals, got shape {tuple(err.shape)}.")
    abs_err = err.abs()

    rmse = err.square().mean().sqrt().item()
    y_rms = Y.square().mean().sqrt().item()
    stats = {
        "rmse": rmse,
        "mae": abs_err.mean().item(),
        # Undefined (nan) for an all-zero target rather than ZeroDivisionError,
        # which would make the caller discard every statistic.
        "rel_rmse": rmse / y_rms if y_rms != 0 else float("nan"),
    }

    # Log likelihood quantifies predictive accuracy and calibration on the ground
    # truth. Each point gets its own (batched) posterior, so this is the mean of
    # the per-point marginal log densities.
    per_point_posterior = model.posterior(X.unsqueeze(-2), observation_noise=False)
    stats["loglike"] = per_point_posterior.log_prob(Y.unsqueeze(-1)).mean().item()

    # The following are particularly relevant for real-world applications, where
    # the ground truth Y is unobservable and validation data may be corrupted too.
    for percentile in (50, 60, 70, 80, 90):
        p_err = abs_err.quantile(percentile / 100).item()
        stats[f"p{percentile}err"] = p_err
        # RMSE over points whose *absolute* error is below the percentile.
        # Comparing the signed error would keep large negative residuals. The
        # value is nan if no point falls strictly below ``p_err``.
        stats[f"p{percentile}rmse"] = (
            err[abs_err < p_err].square().mean().sqrt().item()
        )

    return stats, post_mean
