import os
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

def plot_cka(args, cka_losses, rank = None):
    """Plot per-layer RepSim loss curves and save them as a PDF.

    Writes ``figures/<args.exp_name>/cka_losses.pdf``, creating the directory
    if needed. The figure is closed afterwards so repeated calls do not
    accumulate open pyplot figures.

    Args:
        args: object with an ``exp_name`` attribute used to build the output path.
        cka_losses: iterable whose yielded keys index back into it
            (``cka_losses[layer]``); each value is a sequence of per-step losses.
            Presumably a dict keyed by layer -- TODO confirm against callers.
        rank: if not None, appended to the title as ``" with rank {rank}"``.
    """
    fig = plt.figure(dpi = 150, figsize = (17, 4))
    try:
        for i, layer in enumerate(cka_losses):
            layer_cka_loss = np.asarray(cka_losses[layer])
            # Curves are labelled by enumeration order, not by the key itself.
            plt.plot(np.arange(len(layer_cka_loss)), layer_cka_loss,
                     label = f'layer_{i}', alpha = 0.7)
        plt.xlabel('Training Steps')
        plt.ylabel('1 - RepSim')
        plt.legend()
        title = 'RepSim Losses Across Layers'
        if rank is not None:
            title += f' with rank {rank}'
        plt.title(title)
        os.makedirs(f'figures/{args.exp_name}/', exist_ok = True)
        fig.savefig(f'figures/{args.exp_name}/cka_losses.pdf', format = 'pdf')
    finally:
        # Release the figure even if plotting/saving fails; pyplot otherwise
        # keeps every figure alive until explicitly closed.
        plt.close(fig)

def _plot_ft_series(ax, values, title, xlabel, ylabel):
    """Plot ``values`` against their 0-based index on ``ax`` and label the axes."""
    values = np.asarray(values)
    ax.plot(np.arange(len(values)), values)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    return ax


def plot_ft(args, ft_json, orig_acc):
    """Plot fine-tuning curves side by side and save them as a PDF.

    Produces three panels: training loss per step, validation loss per epoch,
    and top-1 accuracy per epoch, with a dashed red horizontal reference line
    at ``orig_acc`` on the accuracy panel. Writes
    ``figures/<args.exp_name>/ft_plots.pdf``, creating the directory if needed.
    The figure is closed afterwards so repeated calls do not leak figures.

    Args:
        args: object with an ``exp_name`` attribute used to build the output path.
        ft_json: mapping with keys ``'train_losses'``, ``'val_losses'`` and
            ``'top1_accs'``, each a sequence of numbers (x-axis is the 0-based
            position in the sequence).
        orig_acc: accuracy value drawn as the reference line; presumably the
            pre-fine-tuning accuracy -- TODO confirm against callers.
    """
    fig, axes = plt.subplots(nrows = 1, ncols = 3, figsize = (20, 4), dpi = 100)
    try:
        _plot_ft_series(axes[0], ft_json['train_losses'],
                        'Training Loss Over Steps', 'Training Steps', 'Training Loss')
        _plot_ft_series(axes[1], ft_json['val_losses'],
                        'Validation Loss Over Epochs', 'Epochs', 'Validation Loss')
        _plot_ft_series(axes[2], ft_json['top1_accs'],
                        'Top-1 Accuracy Over Epochs', 'Epochs', 'Top-1 Accuracy')
        # Reference line so the fine-tuned accuracy can be compared at a glance.
        axes[2].axhline(orig_acc, linestyle = '--', color = 'red')

        os.makedirs(f'figures/{args.exp_name}', exist_ok = True)
        # Save via the explicit handle rather than pyplot's implicit current figure.
        fig.savefig(f'figures/{args.exp_name}/ft_plots.pdf', format = 'pdf')
    finally:
        plt.close(fig)