from pathlib import Path

import pandas as pd
from matplotlib import pyplot as plt

from tabicl.utils.paths_and_filenames import METRICS_PLOT_FILE_NAME, METRICS_TRAIN_FILE_NAME, METRICS_VAL_FILE_NAME


def plot_loss(output_dir: Path, loss_graph_min_step: int):
    """Plot training loss and validation accuracies and save the figure.

    Reads the training and validation metric CSVs from ``output_dir``. It draws
    the training cross-entropy loss (``step`` vs ``loss``) on log-log axes and
    overlays the four normalized-accuracy curves (finetune/zeroshot x
    validation/test) on a secondary y-axis. The figure is written to
    ``output_dir / METRICS_PLOT_FILE_NAME``.

    Args:
        output_dir: Run directory containing the metric CSV files. The plot is
            saved here as well.
        loss_graph_min_step: First training step shown on the x-axis. The
            y-range of the loss axis is fitted to the losses logged at or after
            this step. If no logged step reaches this value, the axes are left
            at matplotlib's automatic limits.
    """
    metrics_train = pd.read_csv(output_dir / METRICS_TRAIN_FILE_NAME)
    metrics_val = pd.read_csv(output_dir / METRICS_VAL_FILE_NAME)

    fig, ax = plt.subplots(figsize=(15, 6))
    try:
        ax.plot(metrics_train['step'], metrics_train['loss'], color='blue', label='Cross Entropy Loss (training)')

        ax.set_ylabel('Cross Entropy Loss')

        ax.set_yscale('log')
        ax.set_xscale('log')

        min_step = loss_graph_min_step

        # Select rows by their logged step value, not by row position. The two
        # only coincide when a row is written for every single step.
        steps = metrics_train['step']
        # dropna() keeps NaN losses out of the axis limits; matplotlib raises
        # on NaN limits.
        late_loss = metrics_train.loc[steps >= min_step, 'loss'].dropna()

        if not late_loss.empty:
            ax.set_xlim(min_step, steps.max())
            ax.set_ylim(late_loss.min() * 0.9, late_loss.max() * 1.1)

        # Plain numbers instead of 10^k tick labels on the log axes.
        ax.get_xaxis().set_major_formatter(plt.ScalarFormatter())
        ax.get_yaxis().set_major_formatter(plt.ScalarFormatter())

        ax.set_xlabel('Step')

        # Accuracies share the step axis but use their own y-scale.
        ax2 = ax.twinx()
        ax2.plot(metrics_val['step'], metrics_val['norm_acc_val_finetune'], color='darkred', label='Finetune (validation)', linewidth=2)
        ax2.plot(metrics_val['step'], metrics_val['norm_acc_test_finetune'], color='red', label='Finetune (test)', linewidth=2)
        ax2.plot(metrics_val['step'], metrics_val['norm_acc_val_zeroshot'], color='darkgreen', label='Zeroshot (validation)', linewidth=2)
        ax2.plot(metrics_val['step'], metrics_val['norm_acc_test_zeroshot'], color='green', label='Zeroshot (test)', linewidth=2)
        ax2.set_ylabel('Normalized accuracy')

        fig.suptitle('PreTraining', fontsize=16)
        # A figure-level legend collects the labelled lines from both axes.
        fig.legend(loc='lower left', bbox_to_anchor=(0.13, 0.12))

        fig.savefig(output_dir / METRICS_PLOT_FILE_NAME)
    finally:
        # pyplot keeps every figure alive until it is closed. Without this,
        # repeated calls accumulate open figures and their memory.
        plt.close(fig)


if __name__ == '__main__':
    # Local import: argparse is only needed when run as a script.
    import argparse

    # Defaults reproduce the previously hard-coded run, so a bare invocation
    # behaves exactly as before.
    parser = argparse.ArgumentParser(description='Plot the training/validation metrics of a run.')
    parser.add_argument(
        'output_dir',
        nargs='?',
        type=Path,
        default=Path('outputs/runs/2024-05-30/10-46-31'),
        help='Run directory containing the metric CSV files (the plot is saved there too).',
    )
    parser.add_argument(
        '--min-step',
        type=int,
        default=999,
        help='First training step shown on the loss axis.',
    )
    args = parser.parse_args()

    plot_loss(args.output_dir, loss_graph_min_step=args.min_step)