import logging
import os
import torch
import numpy as np
import scipy
import itertools
import matplotlib.pyplot as plt

from deepxde.callbacks import Callback
from deepxde.utils.internal import list_to_str

from utils import plot_lines

class TesterCallback(Callback):
    """Periodically score the model on a fixed test set and report the errors.

    Every ``log_every`` epochs the model predicts on ``test_x`` and the
    prediction is compared against ``test_y``. The following quantities are
    recorded, where ``r = prediction - test_y``:

    * MSE   -- ``mean(r**2)``
    * MAE   -- ``mean(|r|)`` (mean absolute error)
    * MXE   -- ``max(|r|)``
    * L1RE  -- MAE divided by ``mean(|test_y|)``
    * L2RE  -- ``sqrt(MSE)`` divided by ``sqrt(mean(test_y**2))``
    * CRMSE -- ``|mean(r)|``, the magnitude of the mean signed residual

    When training ends the history is written to ``errors.txt`` and plotted,
    after which all history and counters are cleared.

    Args:
        test_data: A ``(test_x, test_y)`` pair.
        save_path: String prepended verbatim to every output file name, so it
            should normally end with a path separator.
        log_every: Evaluation period in epochs; ``None`` disables evaluation.
        verbose: If true, print one summary line per evaluation.

    NOTE(review): ``self.model`` is assumed to be attached by the deepxde
    training loop before ``on_epoch_end`` runs -- it is not set here.
    """

    def __init__(self, test_data, save_path, log_every=100, verbose=True):
        super().__init__()

        self.log_every = log_every
        self.verbose = verbose
        self.disable = False  # when truthy, evaluation and saving are skipped
        self.save_path = save_path

        self._reset_history()

        self.test_x, self.test_y = test_data

    def _reset_history(self):
        """Clear every recorded metric series and both epoch counters."""
        self.indexes = []  # epoch number at which each evaluation happened
        self.maes = []     # mean absolute error
        self.mses = []     # mean squared error
        self.mxes = []     # maximum absolute error
        self.l1res = []    # relative L1 error
        self.l2res = []    # relative L2 error
        self.crmses = []   # |mean signed residual|

        self.epochs_since_last_resample = 0
        self.valid_epoch = 0

    def on_train_begin(self):
        """Cache the reference norms of ``test_y`` used by the relative errors."""
        reference = self.test_y
        self.solution_l1 = np.abs(reference).mean()
        self.solution_l2 = np.sqrt((reference**2).mean())

    def on_epoch_end(self):
        """Count the epoch and, once every ``log_every`` epochs, evaluate."""
        self.valid_epoch += 1
        self.epochs_since_last_resample += 1

        idle = (
            self.disable
            or self.log_every is None
            or self.epochs_since_last_resample < self.log_every
        )
        if idle:
            return
        self.epochs_since_last_resample = 0

        with torch.no_grad():
            prediction = self.model.predict(self.test_x)

        # Compute the residual once and derive every metric from it.
        residual = prediction - self.test_y
        abs_residual = np.abs(residual)

        mse = (residual**2).mean()
        mae = abs_residual.mean()
        mxe = np.max(abs_residual)
        l1re = mae / self.solution_l1
        l2re = np.sqrt(mse) / self.solution_l2
        crmse = np.abs(residual.mean())

        self.indexes.append(self.valid_epoch)
        for history, value in (
            (self.mses, mse),
            (self.maes, mae),
            (self.mxes, mxe),
            (self.l1res, l1re),
            (self.l2res, l2re),
            (self.crmses, crmse),
        ):
            history.append(value)

        if self.verbose:
            print(
                f'Validation: epoch {self.valid_epoch} MSE {mse:.5f} MAE {mae:.5f} '
                f'MXE {mxe:.5f} L1RE {l1re:.5f} L2RE {l2re:.5f} CRMSE {crmse:.5f}'
            )

    def on_train_end(self):
        """Write the metric history to disk, plot it, then reset all state."""
        if self.disable:
            return

        self.indexes = np.array(self.indexes)
        epochs = self.indexes
        prefix = self.save_path

        # One row per evaluation, one column per metric.
        table = np.array(
            [epochs, self.maes, self.mses, self.mxes, self.l1res, self.l2res, self.crmses]
        )
        np.savetxt(
            prefix + 'errors.txt',
            table.T,
            header="epochs, maes, mses, mxes, l1res, l2res, crmses",
        )

        single_series = (
            (self.maes, 'maes', "maes.png", "mean average error"),
            (self.mses, 'mses', "mses.png", "mean square error"),
            (self.mxes, 'mxes', "mxes.png", "maximum error"),
        )
        for values, tag, filename, title in single_series:
            plot_lines(
                [epochs, values],
                xlabel="epochs",
                labels=[tag],
                path=prefix + filename,
                title=title,
            )

        # Both relative errors share one figure.
        plot_lines(
            [epochs, self.l1res, self.l2res],
            xlabel="epochs",
            labels=['l1re', 'l2re'],
            path=prefix + "relerr.png",
            title="relative error",
        )

        self._reset_history()