from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


# Plot the final converged loss of the dynamics-prediction task for policy
# networks checkpointed at different RL iterations.
#
# Input: a CSV with one column per checkpoint, named
#   "analysis_model_it_<label>_mixed_data_10k_samples_input_0 - loss_train_epoch"
# Output: a scatter plot saved to OUTPUT_PATH.

CSV_PATH = 'logs/analysis/dynamics_analysis_0_to_4k_rich.csv'
OUTPUT_PATH = Path('logs/analysis/plots/dynamics_analysis_mixed_data_4.png')
# Row of the CSV whose values are plotted.
# NOTE(review): presumably the last recorded epoch (i.e. the "final converged"
# loss) — confirm against the CSV layout.
ROW_INDEX = 19

df = pd.read_csv(CSV_PATH)

# Checkpoints as (column-name label, x-axis position) pairs: every 25 iterations
# up to 225, every 250 up to 3750, then the final checkpoint. Keeping label and
# position together prevents the two lists below from drifting apart. The final
# checkpoint's columns are labelled "4k" but it is plotted at iteration 3999.
checkpoints = (
    [(str(i), i) for i in range(25, 250, 25)]
    + [(str(i), i) for i in range(250, 4000, 250)]
    + [("4k", 3999)]
)
iteration_strs = [label for label, _ in checkpoints]
x_num = [x for _, x in checkpoints]

column_names = [
    f"analysis_model_it_{it}_mixed_data_10k_samples_input_0 - loss_train_epoch"
    for it in iteration_strs
]

data = df.loc[ROW_INDEX, column_names].to_numpy()

# Scatter plot of loss vs. RL iteration.
plt.figure(figsize=(10, 5))
plt.scatter(x_num, data, marker='x', label='final converged loss', color='red')
plt.title('Final converged loss of dynamics prediction task for policy networks from different iteration numbers')
plt.xlabel('RL Iteration Number')
plt.ylabel('Loss')
plt.grid()
plt.legend()  # the scatter carries a label; without this it is never displayed
# savefig raises if the target directory is missing, so create it first.
OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(OUTPUT_PATH)
plt.close()