import functools
import math
from pathlib import Path
from typing import Tuple, List

import click
import torch
from beam.distributed.ray_dispatcher import RayDispatcher, RayClient
from torch.nn import SmoothL1Loss
from torch.optim import Adam
from tqdm import tqdm

from algorithms.convergence_algorithms.cma import CMA
from algorithms.convergence_algorithms.egl_scheduler import EGLScheduler
from algorithms.convergence_algorithms.mixin import EGLCMA
from algorithms.mapping.trust_region import TanhTrustRegion, LinearTrustRegion
from algorithms.mapping.value_normalizers import AdaptedOutputUnconstrainedMapping
from algorithms.nn.datasets import PairsInEpsRangeDataset
from algorithms.nn.distributions import QuantileWeights
from algorithms.nn.modules import BasicNetwork, ModelToTrain
from compute_result.factory import StoreTypes
from problems.benchmarks import find_problems_to_run
from problems.types import Benchmarks
from run_options import (
    RESULT_PATH_OPTION,
    RUN_NAME_OPTION,
    BUDGET_OPTION,
    BENCHMARK_OPTION,
    CONCURRENCY_OPTION,
    PART_OPTION,
    FUNCTION_NUMBER_OPTION,
)
from scripts.callbacks import create_save_run_handlers
from utils.logger import create_logger, create_file_log_path
from utils.python import timestamp_file_signature

# NOTE(review): instantiated at import time purely for its side effect; the
# instance is discarded. Presumably this initialises / connects to the Ray
# cluster so that RayDispatcher can submit tasks below — confirm in
# beam.distributed.ray_dispatcher before moving or removing this line.
RayClient()


def _disable_tqdm_progress_bars():
    """Make ``disable=True`` the default for every ``tqdm`` bar in this process.

    The patch is applied at most once per process. If the same process runs
    :func:`run_ray_worker` several times (e.g. a reused Ray worker), re-wrapping
    ``tqdm.__init__`` on every call would stack one more ``partialmethod`` layer
    per call. That deepens every bar construction and can eventually exceed the
    interpreter's recursion limit.
    """
    if getattr(tqdm, "_disabled_by_run_ray_worker", False):
        return
    # A caller that passes ``disable=...`` explicitly still overrides this default.
    tqdm.__init__ = functools.partialmethod(tqdm.__init__, disable=True)
    tqdm._disabled_by_run_ray_worker = True


def run_ray_worker(run_name: str, env, results_path: Tuple[Path, StoreTypes]):
    """Optimise a single problem with EGL-CMA; intended to run as one Ray task.

    Creates a per-problem file logger and builds an ``EGLScheduler`` and a
    ``CMA`` instance. It combines them into ``EGLCMA`` and trains it for up to
    20 000 epochs, reporting through the handlers returned by
    ``create_save_run_handlers``.

    Side effects on the executing process: the torch default dtype is set to
    ``float64`` and ``tqdm`` progress bars are disabled by default.

    Args:
        run_name: Name of the run. Used in the log directory and passed to the
            result-saving callbacks.
        env: Problem to optimise. Must expose ``dimension``, ``lower_bound`` and
            ``upper_bound``; its ``repr`` becomes part of the log file name.
        results_path: ``(Path, StoreTypes)`` pair forwarded unchanged to
            ``create_save_run_handlers``.

    Returns:
        None. Any ``Exception`` raised while building or training the algorithm
        is printed and logged rather than re-raised. Errors raised while setting
        up the logger itself are not caught.
    """
    dtype = torch.float64
    # NOTE(review): assumes the task sees its assigned GPU as CUDA device 0
    # (Ray normally restricts CUDA_VISIBLE_DEVICES per task) — confirm.
    device = 0
    torch.set_default_dtype(dtype)
    base_dir = Path(__file__).parent

    algorithm_name = EGLCMA.ALGORITHM_NAME
    normal_logs_path = (
        create_file_log_path(base_dir, algorithm_name, run_name)
        / rf"logs_for_{algorithm_name}_{repr(env)}_parallel-{timestamp_file_signature()}"
    )
    logger = create_logger(normal_logs_path, None, run_name, algorithm_name, env)
    _disable_tqdm_progress_bars()

    try:
        dims = env.dimension
        helper = BasicNetwork(dims=dims, device=device)
        helper_opt = Adam(helper.parameters(), lr=0.001)
        model = ModelToTrain(device=device, dtype=dtype, dims=dims)
        model_opt = Adam(model.parameters(), lr=0.01)
        # Several size-like hyperparameters below scale with sqrt(dims), the
        # growth rate of the diagonal of a box as its dimension increases.
        egl = EGLScheduler(
            env=env,
            helper_network=helper,
            model_to_train=model,
            value_optimizer=helper_opt,
            model_to_train_optimizer=model_opt,
            epsilon=math.sqrt(dims) * 0.01,
            epsilon_factor=0.97,
            min_epsilon=1e-4,
            perturb=0,
            num_of_batch_reply=32,
            maximum_movement_for_shrink=math.sqrt(dims) * 0.2,
            output_mapping=AdaptedOutputUnconstrainedMapping(),
            input_mapping=TanhTrustRegion(
                env.lower_bound, env.upper_bound, device=device, min_trust_region_size=0
            ),
            dtype=dtype,
            device=device,
            logger=logger,
            grad_loss=SmoothL1Loss(),
            database_size=10_000 * (math.floor(math.sqrt(dims)) * 2),
            database_type=PairsInEpsRangeDataset,
            weights_creator=QuantileWeights(),
        )

        cma = CMA.from_space(
            env,
            input_mapping=LinearTrustRegion(
                env.lower_bound, env.upper_bound, device=device
            ),
        )
        algorithm = EGLCMA(egl, cma, env)
        algorithm.train(
            epochs=20_000,
            helper_model_training_epochs=1,
            exploration_size=8 * (math.ceil(math.sqrt(dims))),
            num_loop_without_improvement=10,
            min_iteration_before_shrink=40,
            shrink_trust_region=True,
            callback_handlers=create_save_run_handlers(
                algorithm_name, run_name, True, results_path
            ),
        )
    except Exception as e:
        # Deliberately best-effort: report the failure and return None, exactly
        # like a successful run, so one failing problem does not raise in the
        # driver collecting all results. Only ``Exception`` is caught, so
        # KeyboardInterrupt (also how Ray delivers non-forced task
        # cancellation) and SystemExit propagate instead of being swallowed.
        print(e)
        logger.exception(e)


@click.command()
@CONCURRENCY_OPTION
@BENCHMARK_OPTION(True)
@FUNCTION_NUMBER_OPTION()
@RUN_NAME_OPTION
@BUDGET_OPTION
@PART_OPTION
@RESULT_PATH_OPTION
def run_on_ray(
    concurrency: int,
    benchmark: Benchmarks,
    func_num: List[int],
    run_name: str,
    budget: int,
    part: Tuple[int, int],
    results_path: Tuple[Path, StoreTypes],
    **kwargs,
):
    """CLI entry point: launch one ``run_ray_worker`` task per selected problem.

    Problems are selected with ``find_problems_to_run``. Each one gets its own
    ``RayDispatcher`` that requests ``1 / concurrency`` of a GPU, so that
    presumably up to ``concurrency`` tasks can share one GPU. The dispatcher is
    then called with ``(run_name, problem, results_path)``. The function prints
    the number of problems and finally the ``.value`` of every dispatched call.

    ``part`` may be a single part or a collection of parts, in which case its
    first element is a tuple. Both forms are normalised to a list of parts.
    Extra click options land in ``**kwargs`` and are ignored.
    """
    leading_entry = part[0]
    if isinstance(leading_entry, tuple):
        parts = list(part)
    else:
        parts = [part]

    problems = find_problems_to_run(benchmark, budget, func_num=func_num, parts=parts)
    print(f"Working on {len(problems)}")

    # All dispatchers are created before any of them is invoked; each one gets
    # its own freshly built ``remote_kwargs`` dict.
    dispatchers = [
        RayDispatcher(run_ray_worker, remote_kwargs={"num_gpus": 1 / concurrency})
        for _problem in problems
    ]

    pending = []
    for dispatcher, problem in zip(dispatchers, problems):
        pending.append(dispatcher(run_name, problem, results_path))

    print([handle.value for handle in pending])


# Script entry point: click parses the command-line options and calls run_on_ray.
if __name__ == "__main__":
    run_on_ray()
