import logging
import os
from collections import defaultdict

import numpy as np
import torch as th
import wandb


class Logger:
    """Buffer scalar training stats, mirror them to wandb and print summaries.

    Stats are kept per key as ``(t, value)`` pairs in ``self.stats`` until
    ``print_recent_stats`` is called. That call appends a row of selected
    metrics to a CSV file next to the console logger's log file, logs a
    summary, and then clears the buffer.
    """

    # Stats mirrored into the metrics CSV, in column order. Each CSV row is
    # ``t_env, battle_won_mean, test_return_mean, test_battle_won_mean``.
    _CSV_METRICS = ("battle_won_mean", "test_return_mean", "test_battle_won_mean")

    def __init__(self, console_logger):
        """
        Args:
            console_logger: ``logging.Logger`` that receives the printed
                summaries. Its first ``logging.FileHandler`` (if any)
                determines where the metrics CSV is written.
        """
        self.console_logger = console_logger

        # Backend flags; not read by this class, kept for callers that check them.
        self.use_tb = False
        self.use_sacred = False
        self.use_hdf = False
        # Metrics CSV path without the ".csv" suffix; refreshed on every CSV write.
        self.save_file = ''
        # key -> list of (t, value) pairs logged since the last print.
        self.stats = defaultdict(list)

    def log_stat(self, key, value, t):
        """Buffer ``value`` for ``key`` at timestep ``t`` and send it to wandb."""
        self.stats[key].append((t, value))
        wandb.log({key: value, "timestep": t})

    def save_metrics_line(self, save_list_value):
        """Append one comma-separated row to the metrics CSV.

        The CSV sits next to the console logger's log file: same path with the
        extension replaced by ``.csv``. Floats (including numpy floating
        scalars) are written with 6 decimals; everything else goes through
        ``str``. If the console logger has no file handler, there is nowhere
        to put the CSV and the row is dropped.

        Args:
            save_list_value: Sequence of cell values for one row.
        """
        log_path = next(
            (
                handler.baseFilename
                for handler in self.console_logger.handlers
                if isinstance(handler, logging.FileHandler)
            ),
            None,
        )
        if log_path is None:
            self.console_logger.debug("No file handler attached; skipping metrics CSV row")
            return
        # splitext handles any extension length (the old [:-4] assumed ".log").
        self.save_file = os.path.splitext(log_path)[0]
        line = ",".join(
            f"{v:.6f}" if isinstance(v, (float, np.floating)) else str(v)
            for v in save_list_value
        )
        with open(self.save_file + ".csv", "a") as f:
            f.write(line + "\n")

    def _metrics_row(self):
        """Build ``[t_env, *metrics]`` from the first buffered entry of each CSV metric.

        Metrics that were not logged since the last reset become empty fields,
        so columns stay aligned. ``t_env`` is the timestep of the first
        present metric in ``_CSV_METRICS`` order. Returns None when none of
        the metrics are present. Reads via ``.get`` so that no empty lists
        are inserted into the defaultdict.
        """
        firsts = {k: self.stats[k][0] for k in self._CSV_METRICS if self.stats.get(k)}
        if not firsts:
            return None
        t_env = next(iter(firsts.values()))[0]
        return [t_env] + [firsts[k][1] if k in firsts else "" for k in self._CSV_METRICS]

    def print_recent_stats(self):
        """Write a metrics CSV row, log a summary of buffered stats, then reset them.

        The summary shows the mean of the last 5 values of every key except
        ``"episode"`` (last 1 value for ``"epsilon"``), four entries per line.
        The most recent ``"episode"`` entry supplies the header's ``t_env``
        and episode count, so at least one must have been logged.
        """
        row = self._metrics_row()
        if row is not None:
            self.save_metrics_line(row)

        log_str = "Recent Stats | t_env: {:>10} | Episode: {:>8}\n".format(
            *self.stats["episode"][-1]
        )
        i = 0
        for k, v in sorted(self.stats.items()):
            # Empty lists (keys touched elsewhere but never appended) would give NaN.
            if k == "episode" or not v:
                continue
            i += 1
            window = 5 if k != "epsilon" else 1
            item = "{:.4f}".format(
                th.mean(th.tensor([float(x[1]) for x in v[-window:]]))
            )
            log_str += "{:<25}{:>8}".format(k + ":", item)
            log_str += "\n" if i % 4 == 0 else "\t"
        self.console_logger.info(log_str)
        # Reset stats to avoid accumulating logs in memory.
        self.stats = defaultdict(list)

# Set up a named logger that writes to stderr and, optionally, to a log file.
def get_logger(name, log_file=None, log_level=logging.INFO):
    """Return the named logger, reconfigured to log to stderr and optionally a file.

    Handlers already attached directly to this logger are removed and closed
    first, so repeated calls neither duplicate output nor leak open log files.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        log_file: Optional path. When given, a ``logging.FileHandler`` opened
            with mode ``"w"`` (truncating an existing file) is attached after
            the stream handler.
        log_level: Level applied to the logger and to every handler.

    Returns:
        The configured ``logging.Logger``. Its handlers are
        ``[StreamHandler]`` or ``[StreamHandler, FileHandler]``.
    """
    logger = logging.getLogger(name=name)
    # Close before detaching: clearing the list alone would leave a previous
    # FileHandler's file handle open.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    handlers = [logging.StreamHandler()]
    if log_file is not None:
        handlers.append(logging.FileHandler(log_file, mode="w"))

    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)
    logger.setLevel(log_level)

    return logger
