

from __future__ import annotations

import json
import os
from threading import Lock
from typing import Optional

from llm4ad.base import Function
from llm4ad.tools.profiler.profile import ProfilerBase

try:
    os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'  # Disable TF onednn for better performance
    from torch.utils.tensorboard import SummaryWriter
except:
    pass


class TensorboardProfiler(ProfilerBase):
    # _num_samples = 0

    def __init__(self,
                 log_dir: Optional[str] = None,
                 *,
                 evaluation_name='Problem',
                 method_name='Method',
                 initial_num_samples=0,
                 log_style='complex',
                 create_random_path=True,
                 **kwargs):
        """Base profiler for recording experimental results.
        Args:
            log_dir            : the directory of current run
            evaluation_name    : the name of the evaluation instance (the name of the problem to be solved).
            method_name        : the name of the search method.
            initial_num_samples: the sample order start with `initial_num_samples`.
            create_random_path : create a random log_path according to evaluation_name, method_name, time, ...
        """
        super().__init__(log_dir=log_dir,
                         evaluation_name=evaluation_name,
                         method_name=method_name,
                         initial_num_samples=initial_num_samples,
                         log_style=log_style,
                         create_random_path=create_random_path,
                         **kwargs)

        # summary writer instance for Tensorboard
        if log_dir:
            self._writer = SummaryWriter(log_dir=self._log_dir)


    def get_logger(self):
        return self._writer

    def register_function(self, function: Function, *, resume_mode=False):
        """Record an obtained function. This is a synchronized function.
        """
        try:
            self._register_function_lock.acquire()
            self.__class__._num_samples += 1
            self._record_and_print_verbose(function, resume_mode=resume_mode)
            self._write_tensorboard()
            self._write_json(function)
        finally:
            self._register_function_lock.release()

    def finish(self):
        if self._log_dir:
            self._writer.close()

    def _write_tensorboard(self, *args, **kwargs):
        if not self._log_dir:
            return

        self._writer.add_scalar(
            'Best Score of Function',
            self._cur_best_program_score,
            global_step=self.__class__._num_samples
        )
        self._writer.add_scalars(
            'Legal/Illegal Function',
            {
                'legal function num': self._evaluate_success_program_num,
                'illegal function num': self._evaluate_failed_program_num
            },
            global_step=self.__class__._num_samples
        )
        self._writer.add_scalars(
            'Total Sample/Evaluate Time',
            {'sample time': self._tot_sample_time, 'evaluate time': self._tot_evaluate_time},
            global_step=self.__class__._num_samples
        )
