import re
import pprint
import logging

from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities.rank_zero import rank_zero_only

import json
import csv
import os

class ConsoleLogger(LightningLoggerBase):
    """Lightning logger that pretty-prints selected metrics via the
    ``pytorch_lightning`` Python logger.

    Args:
        log_keys: Iterable of regular-expression strings. A metric is printed
            when any pattern matches its name (``re.search``). When empty or
            ``None``, every metric is printed.
    """

    def __init__(self, log_keys=None):
        super().__init__()
        # ``None`` default instead of a shared mutable ``[]``; empty means "log everything".
        self.log_keys = [re.compile(k) for k in (log_keys or ())]
        self.dict_printer = pprint.PrettyPrinter(indent=2, compact=False).pformat

    def match_log_keys(self, s):
        """Return True if the metric named ``s`` should be printed."""
        return not self.log_keys or any(r.search(s) for r in self.log_keys)

    @property
    def name(self):
        return 'console'

    @property
    def version(self):
        return '0'

    @property
    @rank_zero_experiment
    def experiment(self):
        """The stdlib logger the metrics are written to."""
        return logging.getLogger('pytorch_lightning')

    @rank_zero_only
    def log_hyperparams(self, params):
        # Hyperparameters are intentionally not printed by this logger.
        pass

    @rank_zero_only
    def log_metrics(self, metrics, step=None):
        """Print the metrics whose names match ``log_keys``.

        Args:
            metrics: Mapping of metric name to value.
            step: Global step to show in the header (may be ``None``).
        """
        metrics_ = {k: v for k, v in metrics.items() if self.match_log_keys(k)}
        if not metrics_:
            return
        # Callers are not required to include 'epoch' in ``metrics``; fall back to a
        # step-only header rather than raising KeyError.
        epoch = metrics.get('epoch')
        header = f"Epoch{epoch} Step{step}" if epoch is not None else f"Step{step}"
        self.experiment.info(f"\n{header}\n{self.dict_printer(metrics_)}")

class ResultLogger:
    """Append-only result log persisted on disk as a single JSON object.

    Each logged item is stored under a string index ("0", "1", ...) and the
    whole log can be exported to a CSV file.

    Args:
        file_path: Path of the JSON log file. It is created containing ``{}``
            if it does not already exist.
    """

    def __init__(self, file_path):
        self.file_path = file_path
        if not os.path.isfile(file_path):
            self._write({})

    def _read(self):
        """Load and return the whole log dictionary from disk."""
        with open(self.file_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _write(self, log_dict):
        """Replace the log file's contents with ``log_dict``.

        The data is written to a sibling temporary file which is then renamed
        over the log, so an interrupted write does not truncate the log.
        """
        tmp_path = self.file_path + '.tmp'
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(log_dict, f)
        os.replace(tmp_path, self.file_path)

    def log_new_item(self, log_item: dict):
        """Append ``log_item`` to the log under the next index.

        The index is the current number of entries as a string, which stays
        unique as long as entries are only appended.
        """
        log_dict = self._read()
        log_dict[str(len(log_dict))] = log_item
        self._write(log_dict)

    def export_csv(self):
        """Export the log to a CSV file next to the JSON log.

        The CSV path is the log path with its final extension replaced by
        ``.csv``. Columns are the union of keys over all items, in first-seen
        order; an item lacking a column gets an empty cell. An empty log
        produces an empty CSV file.

        Returns:
            The path of the written CSV file.
        """
        # splitext only touches the final extension. str.replace('.json', ...) would
        # also rewrite '.json' elsewhere in the path and, for a path without '.json',
        # return the log path itself, overwriting the JSON log with CSV.
        csv_path = os.path.splitext(self.file_path)[0] + '.csv'

        items = list(self._read().values())
        # dict.fromkeys de-duplicates while keeping first-seen order, so later
        # items with extra keys don't make DictWriter raise ValueError.
        field_names = list(dict.fromkeys(k for item in items for k in item))

        with open(csv_path, 'w', newline='', encoding='utf-8') as csvfile:
            if field_names:
                writer = csv.DictWriter(csvfile, fieldnames=field_names)
                writer.writeheader()
                writer.writerows(items)
        return csv_path

