import os
import torch
import logging.config
# import pandas as pd
# from bokeh.io import output_file, save, show
# from bokeh.plotting import figure
# from bokeh.layouts import column
# import matplotlib as plt
from pytorch_lightning.utilities import rank_zero_only


def get_logger(name: str = __name__, level: int = logging.INFO) -> logging.Logger:
    """Return a logger whose level methods only emit on the rank-zero process.

    Args:
        name: Name passed to ``logging.getLogger``; defaults to this module's
            ``__name__``.
        level: Threshold applied to the logger with ``setLevel``.

    Returns:
        The ``logging.Logger`` registered under ``name``. Its ``debug``,
        ``info``, ``warning``, ``error``, ``exception``, ``fatal`` and
        ``critical`` methods are replaced by ``rank_zero_only``-wrapped
        versions.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # Wrap every level method with the rank-zero decorator; otherwise each GPU
    # process in a multi-GPU setup would emit its own copy of every message.
    # NOTE(review): getLogger returns the same object for the same name, so
    # calling this twice wraps the methods twice — presumably harmless, confirm.
    for method_name in (
        "debug",
        "info",
        "warning",
        "error",
        "exception",
        "fatal",
        "critical",
    ):
        setattr(logger, method_name, rank_zero_only(getattr(logger, method_name)))

    return logger

###### Logging setup (writes the log.txt file) ##############


def setup_logging(log_file='log.txt'):
    """Configure root logging: DEBUG+ to ``log_file``, INFO+ to the console.

    ``logging.basicConfig`` sets the root logger up to write DEBUG-and-above
    records to ``log_file``. The file is truncated on open (``filemode='w'``),
    and records use a ``time-level-message`` format. A console
    ``StreamHandler`` at INFO level, printing only the message text, is then
    attached to the root logger.

    Note: ``basicConfig`` is a no-op when the root logger already has
    handlers. The console handler is added unconditionally, so every call
    attaches one more console handler and duplicates console output.

    Args:
        log_file: Path of the log file to create/overwrite.
    """
    logging.basicConfig(level=logging.DEBUG,
                        format="%(asctime)s-%(levelname)s-%(message)s",
                        datefmt="%Y-%m-%d %H:%M:%S",
                        filename=log_file,
                        filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    # The LogRecord attribute is ``message`` (singular). The previous
    # ``%(messages)s`` failed to format every console record.
    formatter = logging.Formatter('%(message)s')
    console.setFormatter(formatter)
    logging.getLogger('').addHandler(console)
# ############### CSV results logging (currently disabled) ################


# class ResultsLog(object):
#     def __init__(self, path='results.csv', plot_path=None):
#         self.path = path
#         self.plot_path = plot_path or (self.path + '.html')
#         self.figures = []
#         self.results = None

#     def add(self, **kwargs):
#         df = pd.DataFrame([kwargs.values()], columns=kwargs.keys())
#         if self.results is None:
#             self.results = df
#         else:
#             self.results = self.results.append(df, ignore_index=True)

#     def save(self, title='Training Results'):
#         if len(self.figures) > 0:
#             if os.path.isfile(self.plot_path):
#                 os.remove(self.plot_path)
#             output_file(self.plot_path, title=title)
#             plot = column(*self.figures)
#             save(plot)
#             self.figure = []
#         self.results.to_csv(self.path, index=False, index_label=False)

#     def load(self, path=None):
#         path = path or self.path
#         if os.path.isfile(path):
#             self.results.read_csv(path)

#     def show(self):
#         if len(self.figures > 0):
#             plt.column(*self.figures)
#             show(plt)

#     def image(self, *kargs, **kwargs):
#         fig = figure()
#         fig.image(*kargs, **kwargs)
#         self.figures.append(fig)

def init_path_test_h(cfg):
    """Prepare the output folders for a test-h run and start file logging.

    Args:
        cfg: Config object; ``cfg.path_test_h()`` supplies the results
            directory (concatenated with ``+``, so presumably a ``str``).

    Returns:
        ``(results_dir, images_dir, p_dir, None)``. ``images_dir`` is
        ``results_dir + '/storing_images'`` and ``p_dir`` is
        ``results_dir + '/storing_P'``; both are created if missing.
        Logging is configured to write ``results_dir/log.txt``.
    """
    base = cfg.path_test_h()
    images_dir, p_dir = (base + suffix for suffix in ('/storing_images', '/storing_P'))

    for folder in (images_dir, p_dir):
        os.makedirs(folder, exist_ok=True)

    setup_logging(os.path.join(base, 'log.txt'))
    logging.debug("run arguments:%s", cfg)
    return base, images_dir, p_dir, None


def init_path(cfg):
    """Prepare the output folders for a training trial and start file logging.

    Args:
        cfg: Config object; ``cfg.get_save_path()`` supplies the results
            directory (concatenated with ``+``, so presumably a ``str``).

    Returns:
        ``(results_dir, images_dir, p_dir, None)``. ``images_dir`` is
        ``results_dir + '/storing_images'`` and ``p_dir`` is
        ``results_dir + '/storing_P'``; both are created if missing.
        Logging is configured to write ``results_dir/log.txt``.

    Raises:
        NameError: if ``p_dir/log10symkl_20.pt`` already exists, which is
            treated as "this trial has already been run".
    """
    base = cfg.get_save_path()
    p_dir = base + '/storing_P'
    images_dir = base + '/storing_images'

    # Guard against re-running a trial whose output file is already present.
    if os.path.exists(p_dir + '/log10symkl_20.pt'):
        raise NameError('This trial is already runned')

    for folder in (images_dir, p_dir):
        os.makedirs(folder, exist_ok=True)

    setup_logging(os.path.join(base, 'log.txt'))
    logging.debug("run arguments:%s", cfg)
    return base, images_dir, p_dir, None


def dump_nn(h, map, epoch, path, save_h=True):
    """Save the ``state_dict`` of ``map`` (and optionally ``h``) for one epoch.

    Files are written with ``torch.save`` as ``map_epoch{epoch}.pt`` and
    ``h_epoch{epoch}.pt`` inside ``path``.

    Args:
        h: Module whose ``state_dict()`` is saved when ``save_h`` is truthy.
            It is not touched otherwise, so it may be ``None`` in that case.
        map: Module whose ``state_dict()`` is always saved. The name shadows
            the builtin but is kept for keyword-argument compatibility.
        epoch: Value interpolated into the file names.
        path: Directory (``str`` or ``os.PathLike``) to write into. Nothing
            here creates it, so it must already exist.
        save_h: Whether to also save ``h``.
    """
    # os.path.join also accepts path-like objects, unlike str concatenation.
    torch.save(map.state_dict(), os.path.join(path, f'map_epoch{epoch}.pt'))
    if save_h:
        torch.save(h.state_dict(), os.path.join(path, f'h_epoch{epoch}.pt'))
