from __future__ import annotations

import sys
from typing import Optional

from llm4ad.base import Function
from llm4ad.tools.profiler.profile import ProfilerBase

try:
    import wandb
except:
    pass


class WandBProfiler(ProfilerBase):

    def __init__(self,
                 wandb_project_name: str,
                 log_dir: Optional[str] = None,
                 *,
                 evaluation_name='Problem',
                 method_name='Method',
                 initial_num_samples=0,
                 log_style='complex',
                 create_random_path=True,
                 **wandb_init_kwargs):
        """
        Args:
            wandb_project_name : the project name in which you sync your results.
            log_dir            : the directory of current run
            evaluation_name    : the name of the evaluation instance (the name of the problem to be solved).
            method_name        : the name of the search method.
            initial_num_samples: the sample order start with `initial_num_samples`.
            create_random_path : create a random log_path according to evaluation_name, method_name, time, ...
        """
        super().__init__(log_dir=log_dir,
                         evaluation_name=evaluation_name,
                         method_name=method_name,
                         initial_num_samples=initial_num_samples,
                         log_style=log_style,
                         create_random_path=create_random_path,
                         **wandb_init_kwargs)

        self._wandb_project_name = wandb_project_name

        # for MacOS and Linux
        if sys.platform.startswith('darwin') or sys.platform.startswith('linux'):
            setting = wandb.Settings(start_method='fork')
            self._logger_wandb = wandb.init(
                project=self._wandb_project_name,
                dir=self._log_dir,
                settings=setting,
                **wandb_init_kwargs
            )
        else:  # for Windows
            wandb.setup()
            self._logger_wandb = wandb.init(
                project=self._wandb_project_name,
                dir=self._log_dir,
                **wandb_init_kwargs
            )

    def get_logger(self):
        return self._logger_wandb

    def register_function(self, function: Function, *, resume_mode=False):
        """Record an obtained function. This is a synchronized function.
        """
        try:
            self._register_function_lock.acquire()
            self.__class__._num_samples += 1
            self._record_and_print_verbose(function, resume_mode=resume_mode)
            self._write_wandb()
            self._write_json(function)
        finally:
            self._register_function_lock.release()

    def _write_wandb(self, *args, **kwargs):
        self._logger_wandb.log(
            {
                'Best Score of Function': self._cur_best_program_score
            },
            step=self.__class__._num_samples
        )
        self._logger_wandb.log(
            {
                'Valid Function Num': self._evaluate_success_program_num,
                'Invalid Function Num': self._evaluate_failed_program_num
            },
            step=self.__class__._num_samples
        )
        self._logger_wandb.log(
            {
                'Total Sample Time': self._tot_sample_time,
                'Total Evaluate Time': self._tot_evaluate_time
            },
            step=self.__class__._num_samples
        )

    def finish(self):
        wandb.finish()
