import csv
import sys
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F


class Logger(object):
    """Stream that mirrors every write to the terminal and to a log file.

    It captures the current ``sys.stdout`` as the terminal and exposes
    ``write``/``flush``. Presumably it is installed as
    ``sys.stdout = Logger(path)`` — confirm against callers.
    """

    def __init__(self, filename):
        """Open ``filename`` for appending and remember the current stdout.

        Args:
            filename: Path of the log file; existing content is preserved.
        """
        self.terminal = sys.stdout
        # Append mode so successive runs accumulate in the same log file;
        # explicit UTF-8 so non-ASCII output does not depend on the locale.
        self.log = open(filename, "a", encoding="utf-8")

    def write(self, message):
        """Write ``message`` to both the terminal and the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        """Flush both underlying streams so buffered output is not lost."""
        self.terminal.flush()
        # Tolerate a flush after close() (e.g. at interpreter shutdown).
        if not self.log.closed:
            self.log.flush()

    def close(self):
        """Close the log file. The terminal stream is left untouched."""
        if not self.log.closed:
            self.log.close()


def display_loss(loss, save_path=None):
    """Plot a single training-loss curve, then save it or show it.

    Args:
        loss: Loss values to plot (anything ``plt.plot`` accepts); the x axis
            is labelled as epochs.
        save_path: If truthy, the figure is saved to this path and then
            closed; otherwise it is displayed with ``plt.show()``.
    """
    fig = plt.gcf()
    plt.plot(loss)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss Curve')

    if save_path:
        fig.savefig(save_path)
        # Close after saving: otherwise the next call draws onto this same
        # figure (stacking curves), and pyplot keeps the figure alive.
        plt.close(fig)
    else:
        plt.show()


def save_toFile(path, file_name, data_saved, rows=0):
    """Write ``data_saved`` as CSV to ``path + file_name`` (overwriting it).

    Args:
        path: Prefix concatenated verbatim with ``file_name``, so a directory
            should normally end with a path separator.
        file_name: File name appended to ``path``.
        data_saved: A single row (iterable of fields) when ``rows == 0``, or
            an iterable of rows when ``rows == 1``.
        rows: ``0`` to write one row, ``1`` to write many rows.

    Raises:
        ValueError: If ``rows`` is neither 0 nor 1. This is checked before the
            file is opened, so an existing file is not truncated.
    """
    if rows not in (0, 1):
        raise ValueError(f"rows must be 0 or 1, got {rows!r}")
    # newline='' lets the csv module control line endings itself; without
    # it, Windows output gets an extra blank line after every row.
    with open(path + file_name, 'w', newline='') as f:
        writer = csv.writer(f)
        if rows == 0:
            writer.writerow(data_saved)
        else:
            writer.writerows(data_saved)


def display_losses(loss1, loss2=None, save_path=None):
    """Plot one or two training-loss curves with a legend, then save or show.

    Args:
        loss1: Loss values labelled 'Training Loss - Agent A'.
        loss2: Optional loss values labelled 'Training Loss - Agent B'.
        save_path: If truthy, the figure is saved to this path and then
            closed; otherwise it is displayed with ``plt.show()``.
    """
    fig = plt.gcf()
    plt.plot(loss1, label='Training Loss - Agent A')

    if loss2 is not None:
        plt.plot(loss2, label='Training Loss - Agent B')

    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss Curve')
    plt.legend()

    if save_path:
        fig.savefig(save_path)
        # Close after saving so later calls start from a fresh figure instead
        # of stacking curves/legends, and pyplot does not keep it alive.
        plt.close(fig)
    else:
        plt.show()


def gumbel_softmax(logits, temperature=1.0, dim=-1):
    """Draw a relaxed (Gumbel-softmax) sample from categorical ``logits``.

    Adds Gumbel(0, 1) noise elementwise to ``logits`` and applies a
    temperature-scaled softmax along ``dim``.

    Args:
        logits: Unnormalised log-probabilities tensor.
        temperature: Softmax temperature. Lower values push samples towards
            one-hot. Expected to be positive (not validated here).
        dim: Dimension holding the categories; defaults to the last one.

    Returns:
        Tensor with the same shape as ``logits`` whose entries along ``dim``
        sum to 1.
    """
    gumbel = torch.distributions.gumbel.Gumbel(
        torch.zeros_like(logits), torch.ones_like(logits)
    )
    # sample() (not rsample()) returns noise without gradient, so gradients
    # flow only through ``logits`` (and ``temperature`` if it is a tensor).
    noise = gumbel.sample()
    return F.softmax((logits + noise) / temperature, dim=dim)


def straight_through_discretize(z_sampled_soft, dim=-1):
    """Discretize a soft sample to one-hot with the straight-through estimator.

    In the forward pass the output equals (up to float rounding) the one-hot
    of the argmax along ``dim``. In the backward pass the op acts as the
    identity, so gradients reach ``z_sampled_soft`` unchanged.

    Args:
        z_sampled_soft: Soft sample tensor, e.g. from ``gumbel_softmax``.
        dim: Category dimension; defaults to the last one.

    Returns:
        ``(one_hot, index)``: ``one_hot`` has the shape of ``z_sampled_soft``;
        ``index`` holds the argmax positions with ``dim`` removed.
    """
    index = torch.argmax(z_sampled_soft, dim=dim, keepdim=True)
    one_hot = torch.zeros_like(z_sampled_soft).scatter_(dim, index, 1)
    # The detached term carries no gradient, so d(out)/d(z_sampled_soft) is
    # the identity, while the forward value becomes one_hot.
    one_hot_with_grad = z_sampled_soft + (one_hot - z_sampled_soft).detach()
    return one_hot_with_grad, index.squeeze(dim)
