from pathlib import Path

import pandas as pd
from matplotlib import pyplot as plt

from tabicl.config.config_pretrain import ConfigPretrain
from tabicl.core.enums import Task
from tabicl.utils.paths_and_filenames import METRICS_PLOT_FILE_NAME, METRICS_TRAIN_FILE_NAME, METRICS_VAL_FILE_NAME


def plot_loss(cfg: ConfigPretrain):

    output_dir = cfg.output_dir
    loss_graph_min_step = cfg.testing.loss_graph_min_step
    
    match cfg.data.task:
        case Task.REGRESSION:
            loss_name = 'MSE Loss'
            metric_name = 'mse'
        case Task.CLASSIFICATION:
            loss_name = 'Cross Entropy Loss'
            metric_name = 'cross_entropy'

    metrics_train = pd.read_csv(output_dir / METRICS_TRAIN_FILE_NAME)
    metrics_val = pd.read_csv(output_dir / METRICS_VAL_FILE_NAME)

    fig, ax = plt.subplots(figsize=(15, 6))

    ax.plot(metrics_train['step'], metrics_train[metric_name], color='blue', label=f'{loss_name} (training)')

    ax.set_ylabel(loss_name)
    
    ax.set_yscale('log')
    ax.set_xscale('log')

    min_step = loss_graph_min_step

    if len(metrics_train[metric_name]) > min_step:
        ax.set_xlim(min_step, len(metrics_train['step']))
        ax.set_ylim(min(metrics_train[metric_name][min_step:]) * 0.9, max(metrics_train[metric_name][min_step:] ) * 1.1)  

    ax.get_xaxis().set_major_formatter(plt.ScalarFormatter())
    ax.get_yaxis().set_major_formatter(plt.ScalarFormatter())

    ax.set_xlabel('Step')

    ax2 = ax.twinx()
    ax2.plot(metrics_val['step'], metrics_val['norm_acc_val_finetune'], color='darkred', label='Finetune (validation)', linewidth=2)
    ax2.plot(metrics_val['step'], metrics_val['norm_acc_test_finetune'], color='red', label='Finetune (test)', linewidth=2)
    ax2.plot(metrics_val['step'], metrics_val['norm_acc_val_zeroshot'], color='darkgreen', label='Zeroshot (validation)', linewidth=2)
    ax2.plot(metrics_val['step'], metrics_val['norm_acc_test_zeroshot'], color='green', label='Zeroshot (test)', linewidth=2)
    ax2.set_ylabel('Normalized accuracy')


    fig.suptitle('PreTraining', fontsize=16)
    fig.legend(loc='lower left', bbox_to_anchor=(0.13, 0.12))

    fig.savefig(output_dir / METRICS_PLOT_FILE_NAME)


if __name__ == '__main__':
    # Manual entry point: re-plot the metrics of an existing run directory.
    # Usage: python <this_file> [RUN_DIR]   (RUN_DIR defaults to the path below)
    import sys

    output_dir = Path(sys.argv[1]) if len(sys.argv) > 1 else Path('outputs/runs/2024-05-30/10-46-31')

    class Config:
        # Minimal duck-typed stand-in for ConfigPretrain: only the attributes plot_loss reads.
        output_dir = output_dir
        testing = type('testing', (), {'loss_graph_min_step': 999})
        data = type('data', (), {'task': Task.CLASSIFICATION})

    plot_loss(Config())   # type: ignore