import os
import matplotlib
matplotlib.use('Agg') # Use a non-GUI backend
import matplotlib.pyplot as plt
import numpy as np




class WriteResults:
    """Write plots and text dumps of experiment results into one directory.

    Every output file name carries ``self.identifier`` (currently always the
    empty string) just before its extension.
    """

    def __init__(self, parser, path):
        """Create the output directory if needed and remember it.

        Args:
            parser: Parsed options object. Currently unused; kept so existing
                callers keep working (presumably an ``argparse.Namespace``).
            path: Directory that all outputs are written into. It is created,
                including missing parents, if it does not exist yet.
        """
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(path, exist_ok=True)
        self.path = path
        self.identifier = ''

    def _filepath(self, filename):
        """Return ``filename`` placed inside the output directory."""
        return os.path.join(self.path, filename)

    def _save_trajectories(self, samples, no_trajs, columns, name):
        """Plot ``samples[k, columns]`` for the first rows and save as PNG.

        At most ``no_trajs`` rows (capped by ``samples.shape[0]``) are drawn as
        separate lines on one figure. The figure is always closed, even if
        plotting or saving fails, so repeated calls do not leak figures.
        """
        no_samples = min(no_trajs, samples.shape[0])
        fig = plt.figure()
        try:
            for k in range(no_samples):
                plt.plot(samples[k, columns])
            fig.savefig(self._filepath(name + self.identifier + ".png"))
        finally:
            plt.close(fig)

    def write_image_traj(self, samples, disc_steps, no_trajs, name):
        """Plot every ``disc_steps``-th column of the first rows of ``samples``.

        Args:
            samples: Indexable as ``samples[k, cols]`` with a ``shape``
                attribute; rows are presumably individual trajectories
                (TODO confirm with callers).
            disc_steps: Column stride used for subsampling each row.
            no_trajs: Maximum number of rows to plot.
            name: File-name stem; written to ``<path>/<name><identifier>.png``.
        """
        self._save_trajectories(
            samples, no_trajs, slice(None, None, disc_steps), name)

    def write_image_traj_first(self, samples, disc_steps, no_trajs, name):
        """Plot the first ``disc_steps`` columns of the first rows of ``samples``.

        Args:
            samples: Indexable as ``samples[k, cols]`` with a ``shape``
                attribute; rows are presumably individual trajectories
                (TODO confirm with callers).
            disc_steps: Number of leading columns to plot for each row.
            no_trajs: Maximum number of rows to plot.
            name: File-name stem; written to ``<path>/<name><identifier>.png``.
        """
        self._save_trajectories(
            samples, no_trajs, slice(None, disc_steps), name)

    def write_list(self, list, name=""):
        """Write each item of an iterable on its own line.

        Args:
            list: Iterable of items; each is written as ``str(item)``. (The
                name shadows the builtin but is kept for keyword callers.)
            name: Optional file-name stem. The file is
                ``<path>/<name><identifier>.txt``; with the default empty
                stem this is the same ``<path>/.txt`` file as before.
        """
        items = list
        with open(self._filepath(name + self.identifier + ".txt"), "w",
                  encoding="utf-8") as f:
            # str(item) instead of "%s" % item: the %-form breaks on tuples.
            f.writelines(f"{item!s}\n" for item in items)

    def plot_loss(self, loss):
        """Plot a loss curve to ``<path>/LOSS<identifier>.png``."""
        fig = plt.figure()
        try:
            plt.plot(loss)
            fig.savefig(self._filepath("LOSS" + self.identifier + ".png"))
        finally:
            plt.close(fig)

    def write_value(self, val, name):
        """Write ``str(val)`` plus a newline to ``<path>/<name>_@_<identifier>.txt``."""
        with open(self._filepath(name + "_@_" + self.identifier + ".txt"), "w",
                  encoding="utf-8") as f:
            # str(val) instead of "%s" % val: the %-form breaks on tuples.
            f.write(f"{val!s}\n")

    def save_tensor_as_txt(self, tensor, name):
        """Save a tensor or array as text via ``np.savetxt``.

        Args:
            tensor: An object with ``detach()`` (e.g. a torch tensor, which is
                detached and moved to CPU first) or anything ``np.asarray``
                accepts. Must have at most 2 dimensions (``np.savetxt`` raises
                otherwise); 0-d scalars are written as a single value.
            name: File-name stem; written to ``<path>/<name>_@_<identifier>.txt``.
        """
        if hasattr(tensor, "detach"):
            # Torch-like path: drop the autograd graph and copy to host memory.
            array = tensor.detach().cpu().numpy()
        else:
            array = np.asarray(tensor)
        # np.savetxt rejects 0-d input; promote scalars to a length-1 vector.
        np.savetxt(self._filepath(name + "_@_" + self.identifier + ".txt"),
                   np.atleast_1d(array))
