import os
import sys
from pathlib import Path

import mlflow
import numpy as np

from environments import *
from numerical_experiment import get_objects_from_config
from utils.logger import logger
from utils.save_func import (
    get_path_form_params,
    load_config,
    plot_results,
    save_result_json,
)

# Solver names that are driven with a proximal operator (``prox``) rather than
# a constraint object; checked in ``run_numerical_experiment`` to pick the
# ``solver.run`` call signature.
PROXIMAL_METHODS = [
    PROXIMAL_GRADIENT_DESCENT,
    ACCELERATED_PROXIMAL_GRADIENT_DESCENT,
]
# Quasi-Newton solver names; not referenced in the visible part of this script.
QUASI_NEWTONS = [
    BFGS_QUASI_NEWTON,
    RANDOM_BFGS,
    SUBSPACE_QUASI_NEWTON,
]


# Name of the MLflow experiment every run of this script is recorded under.
MLFLOW_EXPERIMENT_NAME = "MLP_DD_inv"


def run_numerical_experiment(config):
    """Run one optimization experiment described by ``config`` and record it.

    Builds the solver, objective, constraints, initial point and prox operator
    via ``get_objects_from_config``. It then runs the solver and saves and
    plots the results under a directory derived from the configuration. It
    writes ``result.json`` with the best recorded objective value and the
    elapsed time. Finally it logs parameters and metric curves to MLflow.

    Args:
        config: Experiment configuration. It must provide ``"iteration"``,
            ``"log_interval"``, ``"max_time"``,
            ``"algorithms"["solver_name"]``,
            ``"objective"["objective_name"]`` and
            ``"constraints"["constraints_name"]``.

    Returns:
        None. Returns early, without running the solver, when the problem is
        constrained and the initial point is not feasible.
    """
    iteration = config["iteration"]
    log_interval = config["log_interval"]
    max_time = config["max_time"]
    solver_name = config["algorithms"]["solver_name"]
    objective_name = config["objective"]["objective_name"]
    constraints_name = config["constraints"]["constraints_name"]

    use_prox = solver_name in PROXIMAL_METHODS
    is_constrained = constraints_name != NOCONSTRAINTS

    (
        solver,
        solver_params,
        f,
        function_properties,
        con,
        constraints_properties,
        x0,
        prox,
    ) = get_objects_from_config(config)
    f.set_type(DTYPE)
    x0 = x0.astype(DTYPE)
    logger.info(f"dimension:{f.get_dimension()}")

    solver_dir = get_path_form_params(solver_params)
    func_dir = get_path_form_params(function_properties)

    # Result layout:
    #   RESULTPATH/objective/func_params/constraints[/con_params]/solver/solver_params
    path_parts = [RESULTPATH, objective_name, func_dir, constraints_name]
    if is_constrained:
        path_parts.append(get_path_form_params(constraints_properties))
    path_parts += [solver_name, solver_dir]
    save_path = os.path.join(*path_parts)

    if is_constrained:
        con.set_type(DTYPE)
        # Skip the whole experiment when x0 violates the constraints.
        if not con.is_feasible(x0):
            logger.info("Initial point is not feasible")
            return
        logger.info("Initial point is feasible.")

    os.makedirs(save_path, exist_ok=True)
    logger.info(save_path)
    # Start the experiment.
    logger.info("Run Numerical Experiments")
    run_kwargs = {
        "x0": x0,
        "iteration": iteration,
        "params": solver_params,
        "save_path": save_path,
        "log_interval": log_interval,
        "max_time": max_time,
    }
    if use_prox:
        # Proximal solvers get the prox operator whether or not constraints
        # are configured; they never receive ``con``.
        solver.run(f=f, prox=prox, **run_kwargs)
    elif is_constrained:
        solver.run(f=f, con=con, **run_kwargs)
    else:
        solver.run(f=f, **run_kwargs)

    solver.save_results(save_path)
    values_dict = _summarize_results(solver.save_values)
    plot_results(save_path, solver.save_values)

    save_result_json(
        save_path=os.path.join(save_path, "result.json"),
        values_dict=values_dict,
        iteration=iteration,
    )
    logger.info("Finish Numerical Experiment")

    _log_to_mlflow(save_path, solver_name, solver_params, function_properties)


def _summarize_results(save_values):
    """Return the ``{"min_value", "time"}`` summary that goes into result.json."""
    func_values = np.asarray(save_values["func_values"])
    times = np.asarray(save_values["time"])

    # Zero entries are treated as unfilled slots. Presumably the solver
    # preallocates its history buffers with zeros (TODO confirm).
    recorded = func_values[func_values != 0]
    if recorded.size > 0:
        min_value = np.min(recorded)
    elif func_values.size > 0:
        # Every entry is exactly zero, so the minimum really is zero.
        # np.min on the empty filtered array would raise ValueError here.
        min_value = np.min(func_values)
    else:
        logger.info("No function values were recorded")
        min_value = float("nan")

    # Elapsed time never decreases, so the maximum equals the final recorded
    # time even if the buffer ends with zero padding (where ``[-1]`` would be 0).
    execution_time = np.max(times) if times.size > 0 else float("nan")
    return {"min_value": min_value, "time": execution_time}


def _log_to_mlflow(save_path, solver_name, solver_params, function_properties):
    """Log params, tags and per-log-step metric curves of a finished run."""
    result_dir = Path(save_path)

    mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME)
    with mlflow.start_run():
        mlflow.log_params(function_properties)
        mlflow.log_params(solver_params)
        mlflow.log_param("optimizer", solver_name)
        mlflow.log_param("dtype", DTYPE)

        if solver_name == SUBSPACE_REGULARIZED_NEWTON:
            mlflow.set_tag("impl", "jnp.linalg.inv")

        # Metrics are read back from the files written by solver.save_results.
        func_values = np.load(result_dir / "func_values.npy")
        grad_norms = np.load(result_dir / "grad_norm.npy")
        timestamps = np.load(result_dir / "time.npy")
        # Keep only entries with a positive elapsed time (unfilled slots are 0).
        mask = timestamps > 0
        for step, (func_value, grad_norm, elapsed) in enumerate(
            zip(func_values[mask], grad_norms[mask], timestamps[mask])
        ):
            # NOTE(review): this is elapsed time in ms, not epoch ms as MLflow
            # expects, so only "relative time" views in the UI are meaningful.
            timestamp_ms = int(elapsed * 1000)
            mlflow.log_metric(
                "func_value", float(func_value), step=step, timestamp=timestamp_ms
            )
            mlflow.log_metric(
                "grad_norm", float(grad_norm), step=step, timestamp=timestamp_ms
            )


if __name__ == "__main__":
    # Usage: python <this script> CONFIG_PATH  (extra arguments are ignored)
    if len(sys.argv) < 2:
        # Exit with a clear message instead of an IndexError traceback.
        sys.exit(f"usage: {sys.argv[0]} CONFIG_PATH")
    config = load_config(sys.argv[1])
    run_numerical_experiment(config)
