from pathlib import Path
import json

import numpy as np
import torch as pt

from .callback import Callback


class AverageLog(Callback):
    """Accumulate per-step loss/metric values and report their epoch mean.

    The callback hooks are bound as class attributes at the bottom:

    * ``before_epoch`` -> :meth:`index`: start a new accumulation window.
    * ``after_step``   -> :meth:`append`: record one step's values.
    * ``after_epoch``  -> :meth:`mean`: average, print, optionally save.

    Args:
        log_file: Optional path. When truthy, each epoch's averages are
            appended to it as a single JSON object per line.
    """

    def __init__(self, log_file=None):
        self.log_file = log_file
        # Label of the current window: "{epoch}" or "{epoch}/val".
        self.idx = None
        # key -> list of per-step values (numpy arrays), cleared per epoch.
        self.current_dict = {}

    def index(self, epoch, model, **kwds):
        """Start a new accumulation window for ``epoch``.

        The label is ``"{epoch}"`` when ``model.training`` is true and
        ``"{epoch}/val"`` otherwise. Values from the previous window are
        discarded.
        """
        self.idx = f"{epoch}" if model.training else f"{epoch}/val"
        self.current_dict.clear()

    def append(self, loss, metric, **kwds):
        """Record one step's values from the ``loss`` and ``metric`` dicts.

        If a key appears in both, the ``metric`` value takes precedence.
        Tensors are detached and copied to host memory before conversion to
        numpy. Any other value (Python number, numpy array, ...) is
        converted with ``np.asarray``.
        """
        for key, value in {**loss, **metric}.items():
            if pt.is_tensor(value):
                value = value.detach().cpu().numpy()
            else:
                value = np.asarray(value)
            self.current_dict.setdefault(key, []).append(value)

    def mean(self, **kwds):
        """Average each key's recorded values over the step axis.

        The averages are converted to plain Python values via ``tolist``.
        They are saved to ``log_file`` (if set), printed, and returned.

        Returns:
            dict mapping each key to its mean (float or nested list).
        """
        # NOTE(review): if the inputs are float32, rounding before
        # ``.tolist()`` cannot remove the float32 -> Python float
        # representation noise. Round after conversion if needed.
        avg_dict = {
            key: np.array(values).mean(0).tolist()
            for key, values in self.current_dict.items()
        }
        if self.log_file:
            # Use ``self`` (not ``__class__``) so subclass overrides apply.
            self.save(self.idx, avg_dict, self.log_file)
        print(self.idx, avg_dict)
        return avg_dict

    @staticmethod
    def save(key, avg_dict, log_file):
        """Append ``{key: avg_dict}`` as one JSON line to ``log_file``."""
        line = json.dumps({key: avg_dict})
        with open(log_file, "a", encoding="utf-8") as f:
            f.write(line + "\n")

    before_epoch = index
    after_step = append
    after_epoch = mean


class SaveModel(Callback):
    """Save a model checkpoint at the end of each epoch.

    The checkpoint path is ``save_dir / f"{epoch:04d}.pth"``.

    Args:
        save_dir: Directory for checkpoints. It is created if missing.
            When ``None``, saving is skipped (mirrors ``log_file=None``
            disabling logging elsewhere).
        since_step: Save only once ``step_count >= since_step``.
        weights_only: Forwarded as the second positional argument to
            ``model.save``.
    """

    def __init__(self, save_dir=None, since_step=0, weights_only=True):
        self.save_dir = save_dir
        # Named ``since_step`` because ``after_step`` is already a callback
        # hook name and must not be shadowed by a plain attribute.
        self.since_step = since_step
        self.weights_only = weights_only

    def __call__(self, epoch, step_count, model, **kwds):
        """Save ``model`` for ``epoch`` if enabled and past ``since_step``."""
        # Previously ``Path(None)`` raised TypeError with the default save_dir.
        if self.save_dir is None or step_count < self.since_step:
            return
        save_dir = Path(self.save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)
        model.save(save_dir / f"{epoch:04d}.pth", self.weights_only)

    after_epoch = __call__
