"""
Plot training and validation accuracy/loss curves parsed from a training log file.
"""

import re
from pathlib import Path

import matplotlib.pyplot as plt

# Training log to parse. Lines of interest look like
#   Epoch <n> lr:<lr> train_acc: <acc> train_loss: <loss> time:<t> val_acc: <acc> val_loss: <loss>
# (whitespace around ':' is optional for the acc/loss fields); all other lines are ignored.
file_path = './logs/log_12292207-avg-pure.txt'

# Per-epoch metric series, filled in the order the lines appear in the log.
epochs = []
train_acc = []
train_loss = []
val_acc = []
val_loss = []

# Compiled once instead of being looked up for every line.
# Groups: 1=epoch, 2=lr, 3=train_acc, 4=train_loss, 5=time, 6=val_acc, 7=val_loss.
LOG_LINE_RE = re.compile(
    r"Epoch (\d+)\s+lr:(\S+)\s+train_acc\s*:\s*(\S+)\s+train_loss\s*:\s*(\S+)"
    r"\s+time:(\S+)\s+val_acc\s*:\s*(\S+)\s+val_loss\s*:\s*(\S+)"
)

# Explicit encoding so parsing does not depend on the platform locale; undecodable
# bytes (e.g. in unrelated, non-ASCII log lines) are replaced rather than crashing,
# since the lines we match are plain ASCII.
with open(file_path, 'r', encoding='utf-8', errors='replace') as file:
    for line in file:
        match = LOG_LINE_RE.match(line)
        if match is None:
            continue
        # lr (group 2) and time (group 5) are not plotted, so they are skipped.
        epochs.append(int(match.group(1)))
        train_acc.append(float(match.group(3)))
        train_loss.append(float(match.group(4)))
        val_acc.append(float(match.group(6)))
        val_loss.append(float(match.group(7)))

# Fail loudly instead of silently saving an empty figure when the log format changed.
if not epochs:
    raise ValueError(
        f"No epoch metric lines matched in {file_path!r}; check the log format."
    )

# Destination of the rendered figure; its parent directory is created if missing.
output_path = Path('./generated/metrics/training_pop909.png')

# One row with two panels (accuracy left, loss right). A 2x2 grid would leave the
# bottom half of the figure empty and squash both plots.
fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 6))

# (axes, train series, validation series, metric name) for each panel.
panels = (
    (acc_ax, train_acc, val_acc, 'Accuracy'),
    (loss_ax, train_loss, val_loss, 'Loss'),
)
for ax, train_series, val_series, metric in panels:
    ax.plot(epochs, train_series, label=f'Train {metric}', color='blue', marker='o')
    ax.plot(epochs, val_series, label=f'Validation {metric}', color='orange', marker='x')
    ax.set_title(f'{metric} vs Epochs')
    ax.set_xlabel('Epochs')
    ax.set_ylabel(metric)
    ax.legend()
    ax.grid(True)

# Adjust subplot spacing so titles/labels do not overlap, then save.
fig.tight_layout()
# savefig does not create directories; without this it raises FileNotFoundError.
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path)