import numpy as np
import utils
import time
import wandb


def set_log(
    self,
    num_frames,
    start_time,
    episode,
    return_per_frame_,
    test_return_per_frame_,
    current_option,
    option_termination_prob,
    option_termination_error,
    sigmoid_termonations,
    bmc_epsilon,
):
    """Write one row of training statistics to the text, CSV, TensorBoard and wandb logs.

    The row always contains the episode, frame count, elapsed seconds, the
    statistics returned by ``utils.synthesize(self.logs["rewards"])`` (each
    prefixed with ``return_``), the mean of ``self.logs["loss"]`` and the mean
    of the last five test returns. Option and BMC columns are appended only
    when the corresponding ``utils.check_run`` predicate is truthy.

    Args:
        self: Object providing ``logs`` (with ``"rewards"`` and ``"loss"``
            entries), ``txt_logger``, ``csv_logger``, ``csv_file`` and
            ``tb_writer``.
        num_frames: Frame counter; used as the step for TensorBoard and wandb.
            The CSV header row is written only when this equals 0.
        start_time: Start timestamp on the ``time.time()`` clock.
        episode: Episode counter, logged as-is.
        return_per_frame_: List mutated in place; the third statistic returned
            by ``utils.synthesize`` is appended to it.
        test_return_per_frame_: Sequence of test returns; the mean of its last
            five entries is logged as ``test_return_sum`` (0 when empty).
        current_option: Logged as ``curr_option`` when option-Q is enabled.
        option_termination_prob: Logged when option-Q is enabled and both
            this and ``option_termination_error`` are truthy.
        option_termination_error: See ``option_termination_prob``.
        sigmoid_termonations: Sequence of 3 or 4 termination values logged
            when option-Q is enabled; other lengths are ignored. (The
            parameter spelling is kept for caller compatibility.)
        bmc_epsilon: Logged when ``utils.check_run.is_bmc(self)`` is truthy.
    """
    return_per_episode = utils.synthesize(self.logs["rewards"])
    return_values = list(return_per_episode.values())
    return_per_frame_.append(return_values[2])

    duration = int(time.time() - start_time)
    policy_loss = np.mean(self.logs["loss"])
    if len(test_return_per_frame_) > 0:
        test_return = np.mean(test_return_per_frame_[-5:])
    else:
        test_return = 0

    header = ["episode", "frames", "duration"]
    data = [episode, num_frames, duration]
    header += ["return_" + key for key in return_per_episode.keys()]
    data += return_values
    header += ["policy_loss", "test_return_sum"]
    data += [policy_loss, test_return]

    if utils.check_run.enable_optionQ(self):
        header += ["curr_option"]
        data += [current_option]
        # NOTE(review): this truthiness test also skips legitimate zero values
        # (e.g. an error of exactly 0.0); confirm whether callers rely on
        # falsy sentinels before switching to an `is not None` check.
        if option_termination_prob and option_termination_error:
            header += ["option_termination_prob", "option_termination_error"]
            data += [option_termination_prob, option_termination_error]
        if sigmoid_termonations:
            labels_by_count = {
                4: ["termination-random", "termination-z", "termination-rnd", "termination-e"],
                3: ["termination-1", "termination-2", "termination-3"],
            }
            labels = labels_by_count.get(len(sigmoid_termonations))
            # Any other length has no agreed column names, so nothing is logged.
            if labels is not None:
                header += labels
                data += [sigmoid_termonations[i] for i in range(len(labels))]
    if utils.check_run.is_bmc(self):
        header += ["bmc_epsilon"]
        data += [bmc_epsilon]

    # Format from named values rather than unpacking `data` positionally, so
    # "pL" always shows the policy loss no matter how many return statistics
    # `utils.synthesize` yields. With three statistics the text is unchanged.
    returns_text = " ".join("{:.2f}".format(value) for value in return_values)
    self.txt_logger.info(
        "E {} | F {:06} | D {} | rR:μσmM {} | pL {:.3f}".format(
            episode, num_frames, duration, returns_text, policy_loss
        )
    )

    if num_frames == 0:
        self.csv_logger.writerow(header)
    self.csv_logger.writerow(data)
    self.csv_file.flush()

    for field, value in zip(header, data):
        # test_return_sum is deliberately kept out of TensorBoard.
        if field != "test_return_sum":
            self.tb_writer.add_scalar(field, value, num_frames)
    # One call carrying the whole row replaces a wandb.log call per field.
    wandb.log(dict(zip(header, data)), step=num_frames)
