import json
import time
from options import prettyprint_args
import os
import numpy as np
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support
#import wandb

class Logger:
    """Accumulate per-epoch training/validation metrics and persist them as JSON.

    Each call to ``end_epoch`` computes accuracy, loss, precision, recall and F1
    for the training and validation predictions, appends the ones requested in
    ``metrics`` to ``self.log``, and stores both confusion matrices in
    ``self.aux``. ``save`` writes ``self.log`` merged with ``self.aux`` to disk.
    """

    def __init__(self, n, val_n, args, metrics=('acc', 'loss'), use_wandb=False):
        """
        Args:
            n: divisor applied to the training loss before it is recorded
                (``loss / n``); presumably the number of training samples or
                batches — confirm at the call site.
            val_n: same, for the validation loss.
            args: namespace-like object (``vars(args)`` must work) exposing
                ``n_classes``; its attributes are saved alongside the log.
            metrics: names of metrics to record. ``end_epoch`` fills 'acc',
                'loss', 'precision', 'recall', 'f1' and their 'val_'-prefixed
                counterparts; any other name keeps an empty list.
            use_wandb: if True, also send the recorded metrics to Weights &
                Biases each epoch. Requires the optional ``wandb`` package; the
                caller is expected to have called ``wandb.init()``.
        """
        self.log = {metric: [] for metric in metrics}
        self.n = n
        self.val_n = val_n
        self.epochs_finished = 0
        self.n_classes = args.n_classes
        self.aux = {'args': vars(args), 'cm': [], 'val_cm': []}
        self._wandb = None
        if use_wandb:
            # Optional dependency: imported lazily so the logger works without it.
            import wandb
            self._wandb = wandb

    def _classification_stats(self, preds, y):
        """Return ``(accuracy, confusion_matrix, precision, recall, f1)`` for one split.

        The confusion matrix is converted to nested Python lists so that
        ``save`` can serialise it with ``json``. Precision/recall/F1 use
        ``average='binary'`` for two classes and ``'macro'`` otherwise.
        """
        acc = (preds == y).sum().item() / y.size(0)
        # NOTE(review): without an explicit ``labels=`` argument sklearn sizes the
        # matrix from the labels present in y/preds, so its shape can change
        # between epochs if a class is absent — confirm whether that matters.
        cm = confusion_matrix(y, preds).tolist()
        average = 'binary' if self.n_classes == 2 else 'macro'
        prec, rec, f1, _ = precision_recall_fscore_support(y, preds, average=average)
        return acc, cm, prec, rec, f1

    def end_epoch(self, loss, preds, y, val_loss, val_preds, val_y):
        """Record the metrics of one finished epoch.

        Args:
            loss: training loss accumulated over the epoch; recorded as
                ``loss / self.n``.
            preds, y: predicted and true training labels (tensor-like: the code
                uses ``.sum().item()`` and ``.size(0)``, e.g. torch tensors).
            val_loss: validation loss accumulated over the epoch; recorded as
                ``val_loss / self.val_n``.
            val_preds, val_y: predicted and true validation labels.
        """
        acc, cm, prec, rec, f1 = self._classification_stats(preds, y)
        val_acc, val_cm, val_prec, val_rec, val_f1 = self._classification_stats(val_preds, val_y)

        self.aux['cm'].append(cm)
        self.aux['val_cm'].append(val_cm)

        epoch_values = {
            'acc': acc,
            'loss': loss / self.n,
            'precision': prec,
            'recall': rec,
            'f1': f1,
            'val_acc': val_acc,
            'val_loss': val_loss / self.val_n,
            'val_precision': val_prec,
            'val_recall': val_rec,
            'val_f1': val_f1,
        }
        # Only the metrics requested at construction time are recorded.
        recorded = {name: value for name, value in epoch_values.items() if name in self.log}
        for name, value in recorded.items():
            self.log[name].append(value)

        self.epochs_finished += 1
        if self._wandb is not None:
            # One committed call per epoch; values match what is stored in self.log.
            self._wandb.log({**recorded, 'epochs_finished': self.epochs_finished}, commit=True)

    def report(self):
        """Print the latest value of every recorded metric, sorted by name.

        Metrics with no values yet (e.g. a name ``end_epoch`` does not fill)
        are skipped instead of raising ``IndexError``.
        """
        latest = ", ".join(
            "{}: {:.4f}".format(name, values[-1])
            for name, values in sorted(self.log.items())
            if values
        )
        print("end epoch {} - ".format(self.epochs_finished) + latest)

    def save(self, path):
        """Write the metric log plus auxiliary data (args, confusion matrices) as JSON.

        If a file already exists at ``path`` it is not overwritten. Instead the
        data goes to ``<path without extension>_<i>.log`` for the smallest
        ``i >= 1`` whose path is free, and a warning is printed.
        """
        log_data = json.dumps({**self.log, **self.aux})
        if os.path.isfile(path):  # avoid clobbering
            # splitext strips only a real file extension, so dots inside
            # directory names are left intact.
            stem = os.path.splitext(path)[0]
            i = 1
            new_path = "{}_{}.log".format(stem, i)
            # exists (not isfile): also skip names taken by directories.
            while os.path.exists(new_path):
                i += 1
                new_path = "{}_{}.log".format(stem, i)
            print("WARNING: file {} already exists! Considering renaming this file manually using 'mv'. Saving as {}".format(path, new_path))
            path = new_path
        with open(path, "w") as f:
            f.write(log_data)

