import os
from unittest.mock import DEFAULT
import torch
import pandas as pd
import numpy as np

# Default value of TrainLogger's `log_dir` argument: the directory the logger
# creates on construction and writes its CSV export into.
DEFAULT_LOG_DIR = './results'


def safe_make_dir(path):
    """Create the directory ``path`` if nothing exists there yet.

    Missing parent directories are created as well, so nested paths such as
    ``results/run1/logs`` work in a single call.

    Args:
        path: Directory path to create.
    """
    # Anything already present at ``path`` is left untouched (no error).
    if os.path.exists(path):
        return
    # exist_ok=True covers the case where another process creates the
    # directory between the check above and this call.
    os.makedirs(path, exist_ok=True)


class TrainLogger:
    """Accumulate named quantities during training and export them as tables.

    Every call to ``add`` appends one value to the list kept under a variable
    name in ``self.log``. ``export`` turns those lists into a DataFrame and
    writes it to ``<log_dir>/<name>.csv``; ``export_sample`` builds a table
    for one selected column of array-valued entries without writing a file.
    """

    def __init__(self, name='log', log_dir=DEFAULT_LOG_DIR):
        """Create an empty logger and make sure ``log_dir`` exists.

        Args:
            name: Base name (without extension) of the CSV written by
                ``export``.
            log_dir: Directory the CSV is written to.
        """
        self.log = {}
        self.name = name
        self.log_dir = log_dir

        safe_make_dir(self.log_dir)

    def add(self, varname, quantity, disp=False):
        """Append ``quantity`` to the list of values logged as ``varname``.

        Args:
            varname: Key under which the value is stored.
            quantity: Value to log. Torch tensors are converted to NumPy
                arrays; anything else is stored as given.
            disp: If true, also print the name and the stored value.
        """
        if isinstance(quantity, torch.Tensor):
            # Detach from the graph and clone so the logged array is an
            # independent copy, then move to CPU for the NumPy conversion.
            quantity = quantity.detach().clone().cpu().numpy()

        self.log.setdefault(varname, []).append(quantity)

        if disp:
            print(varname, quantity)

    def export(self):
        """Write all logged values to ``<log_dir>/<name>.csv``.

        Variables logged exactly once are repeated to the length of the
        longest variable so they can share a table with per-step logs.

        Returns:
            The DataFrame that was written.
        """
        out = {}
        max_len = 0
        for key, item in self.log.items():
            mat = np.array(item)
            out[key] = mat
            max_len = max(max_len, len(mat))

        # Pad single entries (non-epoch logs) to the common length.
        for key, item in out.items():
            if len(item) == 1:
                out[key] = np.repeat(item, max_len)

        table = pd.DataFrame(out)
        table.to_csv(
            os.path.join(self.log_dir, f'{self.name}.csv'),
            index=False
        )
        return table

    def export_sample(self, sample=0):
        """Build a table of the logged values for one sample index.

        Variables whose stacked log has more than one dimension are reduced
        to column ``sample`` along the second axis; one-dimensional logs are
        used as they are. As in ``export``, variables logged exactly once are
        repeated to the length of the longest variable; without this, mixing
        one-off and per-step logs made the DataFrame constructor fail on
        unequal column lengths.

        Args:
            sample: Index along the second axis to extract.

        Returns:
            A DataFrame with one column per logged variable. Nothing is
            written to disk.
        """
        out = {}
        max_len = 0
        for key, item in self.log.items():
            mat = np.array(item)
            if mat.ndim > 1:
                mat = mat[:, sample]
            out[key] = mat
            max_len = max(max_len, len(mat))

        # Only pad when there is something longer to align with, so logs
        # that already had equal lengths are returned exactly as before.
        if max_len > 1:
            for key, mat in out.items():
                if len(mat) == 1:
                    out[key] = np.repeat(mat, max_len)

        return pd.DataFrame(out)


class LogAggregator:
    """Collect the CSV logs found in a directory and merge them.

    File names are interpreted as ``<entry>_<entry>_..._<run>.csv``: the
    underscore-separated parts before the last one are labelled with the
    names given in ``entries`` and the last part is an integer run index.
    """

    def __init__(self, path):
        """List the CSV files directly inside ``path``.

        Args:
            path: Directory containing the log files.
        """
        self.path = path
        # Match on the real extension (a substring test would also pick up
        # names like "x.csv.bak"); sort because os.listdir order is arbitrary.
        self.logs = sorted(
            x for x in os.listdir(self.path) if x.endswith('.csv')
        )

    @staticmethod
    def _parse_run(token):
        """Return ``token`` as an int, or None if it is not an integer."""
        try:
            return int(token)
        except ValueError:
            return None

    def compile(self, entries=('dataset', 'metric'), runs=(0, 1, 2)):
        """Read the selected logs and store the merged result on ``self``.

        A file is selected when its name splits into ``len(entries) + 1``
        underscore-separated parts and the last part is an integer contained
        in ``runs``. Every other file is reported with a "skipping" message
        and ignored, including files whose run part is not a number.

        After the call, ``self.log_info`` holds one row per selected file
        (the entry values plus ``run`` and ``path``) and ``self.log_dfs``
        holds the rows of all selected files concatenated, with the same
        values added as extra columns. Both are empty DataFrames when no
        file is selected.

        Args:
            entries: Names for the leading parts of the file name, in order.
            runs: Run indices (ints) to include.
        """
        logs_ids = []
        logs_full = []
        for log in self.logs:
            # splitext removes only the extension; str.rstrip('.csv') would
            # strip any trailing run of the characters '.', 'c', 's', 'v'.
            stem = os.path.splitext(log)[0]
            str_entries = stem.split('_')

            run = self._parse_run(str_entries[-1])
            if (
                len(str_entries) != len(entries) + 1
                or run is None
                or run not in runs
            ):
                print('skipping: ', log)
                continue

            info = dict(zip(entries, str_entries))
            # The run is stored as it appears in the file name (a string).
            info['run'] = str_entries[-1]
            info['path'] = log

            dat = pd.read_csv(os.path.join(self.path, log))
            for k, v in info.items():
                dat[k] = v

            logs_ids.append(info)
            logs_full.append(dat)

        self.log_info = pd.DataFrame(logs_ids)
        # pd.concat raises on an empty list, so fall back to an empty frame.
        if logs_full:
            self.log_dfs = pd.concat(logs_full, axis=0)
        else:
            self.log_dfs = pd.DataFrame()
