import torch
import matplotlib.pyplot as plt


def visualize_TS(x, output : str):
    """Plot a time series against its sample index and save the figure.

    Args:
        x: Sequence of values; anything that supports ``len()`` and is
            accepted by ``plt.plot`` as y-data (list, array, tensor).
        output (str): Output file path passed to ``plt.savefig``.
    """
    # Start a fresh figure (as visualize_loss does) so repeated calls do not
    # draw on top of whatever figure is currently active.
    plt.figure()
    # The x-axis is the sample index 0..len(x)-1; it must be built from the
    # length of x, not from x itself (range() needs an int).
    t = list(range(len(x)))
    plt.plot(t, x)
    plt.savefig(output)

def visualize_loss(epoch, loss, output : str):
    """Draw a loss curve on a brand-new figure and write it to disk.

    Args:
        epoch: Array of epoch values, used as the x-axis.
        loss: Array of loss values, used as the y-axis (same length as
            ``epoch``, as required by ``plt.plot``).
        output (str): Output file path handed to ``plt.savefig``.
    """
    axis_labels = (('Epoch', plt.xlabel), ('Loss', plt.ylabel))

    plt.figure()
    plt.plot(epoch, loss)
    for text, set_label in axis_labels:
        set_label(text)
    plt.savefig(output)

def visualize_space(x, method : str, output: str):
    """Render a spatial-analysis plot of ``x`` and save it to ``output``.

    Args:
        x: Array-like accepted by ``plt.imshow`` (e.g. a 2-D array).
        method (str): Type of visualization. Only ``"heatmap"`` is supported.
        output (str): Output file path passed to ``plt.savefig``.

    Raises:
        ValueError: If ``method`` is not a supported visualization type.
    """
    if method != "heatmap":
        # Without a recognised method nothing would be drawn, and savefig
        # would write out whatever unrelated figure happened to be active.
        raise ValueError(f"Unsupported visualization method: {method!r}")

    # Fresh figure so the heatmap is not overlaid on a previous plot.
    plt.figure()
    plt.imshow(x, cmap='hot', interpolation='nearest')
    plt.savefig(output)

def tensorboard_TS(writer, index, x, y, dataset):
    """Add input/output time series to a running TensorBoard ``SummaryWriter``.

    Args:
        writer (SummaryWriter): Summary writer for TensorBoard.
        index (int): Index of the current sample.
        x (torch.Tensor): Input time series; each element must be a single
            value (``.item()`` is called on it).
        y (torch.Tensor): Output time series.
        dataset (str): Dataset name selecting the logging layout.
            ``"ar"``: every value of ``x`` and of ``y[0]`` is logged at
            non-overlapping steps, i.e. sample ``index`` occupies steps
            ``index * len .. index * len + len - 1``.
            Anything else: ``x`` and ``y`` are flattened; sample 0 is logged
            in full and each later sample only logs its last value. This
            assumes consecutive samples are stride-1 sliding windows
            (TODO confirm against the data loader).
    """
    if dataset == "ar":
        # Only the first row of y is logged.
        y_row = y[0]
        for step, value in enumerate(x, start=len(x) * index):
            writer.add_scalar('Data/X', value.item(), step)
        for step, value in enumerate(y_row, start=len(y_row) * index):
            writer.add_scalar('Data/y', value.item(), step)
        return

    x = torch.flatten(x)
    y = torch.flatten(y)

    if index == 0:
        for step, value in enumerate(x):
            writer.add_scalar('Data/X', value.item(), step)
        for step, value in enumerate(y):
            writer.add_scalar('Data/y', value.item(), step)
    else:
        # Sample 0 filled steps 0..len-1, so the newest value of sample
        # `index` belongs at step index + len - 1 (no gap at step len).
        writer.add_scalar('Data/X', x[-1].item(), index + x.shape[0] - 1)
        writer.add_scalar('Data/y', y[-1].item(), index + y.shape[0] - 1)

def tensorboard_loss(writer, index, **kwargs):
    """Log each keyword argument as a scalar in TensorBoard.

    Args:
        writer (SummaryWriter): Summary writer to log to.
        index (int): Global step at which every value is recorded.
        **kwargs: Mapping of TensorBoard tag -> scalar value, logged in the
            order they were passed.
    """
    for tag in kwargs:
        writer.add_scalar(tag, kwargs[tag], index)

def tensorboard_preds(writer, index, spreds, rpreds, coord_matrix):
    """Log per-coordinate S-block and R-matrix predictions to TensorBoard.

    For every (coordinate, S-prediction, R-prediction) triple, one
    ``add_scalars`` entry is written under ``'Preds/rpred'`` and one under
    ``'Preds/Spred'``, each keyed by the string form of the coordinate.

    Args:
        writer (SummaryWriter): Summary writer to log to.
        index (int): Global step for the logged values.
        spreds (torch.Tensor): Predictions from the S-block (flattened).
        rpreds (torch.Tensor): Predictions from the R-matrix (flattened).
        coord_matrix (torch.Tensor): Coordinates identifying which R-block
            produced each prediction; its first two dims are merged so each
            remaining row is one coordinate. Pairing stops at the shortest
            of the three sequences.
    """
    flat_s = torch.flatten(spreds)
    flat_r = torch.flatten(rpreds)
    coords = torch.flatten(coord_matrix, start_dim=0, end_dim=1)

    for coord, s_val, r_val in zip(coords, flat_s, flat_r):
        writer.add_scalars('Preds/rpred', {str(coord): r_val.item()}, index)
        writer.add_scalars('Preds/Spred', {str(coord): s_val.item()}, index)

def tensorboard_graph(writer, model, x, add_graph):
    """Trace either the S-block or a sample R-block of ``model`` into TensorBoard.

    Args:
        writer (SummaryWriter): Summary writer whose ``add_graph`` is called.
        model: Model exposing ``get_sblock``, ``get_rmatrix`` and
            ``rmatrix_hidden``.
        x (torch.Tensor): Input batch; must be 3-D since dims 1 and 2 are
            swapped for the R-block input.
        add_graph (str): ``"S"`` traces the S-block with ``x``; ``"R"`` traces
            the R-block at coordinate [0, 0] with ``(x permuted, hidden)``
            verbosely; any other value traces nothing.

    Returns:
        Whatever ``writer.add_graph`` returns, or ``None`` when ``add_graph``
        is neither ``"S"`` nor ``"R"``.
    """
    # NOTE(review): all model calls run unconditionally and in this order,
    # even for add_graph == "S"; kept as-is in case rmatrix_hidden has
    # side effects — confirm before making any of this lazy.
    sblock = model.get_sblock()
    r_matrix = model.get_rmatrix()
    x_swapped = x.permute(0, 2, 1)
    hidden_adj = model.rmatrix_hidden(x_swapped)
    sample_rblock = r_matrix.get_rblock(coordinate=[0,0])

    if add_graph == "S":
        return writer.add_graph(sblock, x)
    if add_graph == "R":
        return writer.add_graph(sample_rblock, (x_swapped, hidden_adj), verbose=True)
    return None