import pandas as pd
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d

# Load the training log: a header-less CSV (despite the .txt extension).
file_path = "models/model_D/log.txt"  # replace with your actual file path
data = pd.read_csv(file_path, header=None)

# Name the columns: epoch, overall train/valid loss, then four per-position
# validation losses valid_loss_A .. valid_loss_D. The file must have exactly
# 7 columns, otherwise pandas raises a length-mismatch ValueError here.
data.columns = ["epch", "train_loss", "valid_loss"] + [
    f"valid_loss_{chr(65 + i)}" for i in range(4)
]

# Keep only epochs up to and including MAX_EPOCH. The boolean mask yields a
# slice of the original frame; .copy() makes it an independent DataFrame so
# that the smoothed columns assigned back into ``data`` afterwards do not
# trigger pandas' SettingWithCopyWarning (or write into a temporary view).
MAX_EPOCH = 800
data = data[data["epch"] <= MAX_EPOCH].copy()


def smooth(series, window_size=5):
    """Smooth *series* with a 1-D Gaussian filter.

    This is a Gaussian filter, not a moving average: despite its name,
    ``window_size`` is passed as the standard deviation (``sigma``) of the
    Gaussian kernel, measured in samples. Edges use scipy's default
    ``mode="reflect"`` boundary handling.

    Args:
        series: Array-like of numbers to smooth (e.g. a pandas Series);
            expected to be 1-D.
        window_size: Gaussian sigma in samples. Defaults to 5.

    Returns:
        A float64 NumPy array with the same length as *series*.
    """
    # Force a floating-point result. By default gaussian_filter1d allocates
    # its output with the input's dtype, which would silently truncate the
    # smoothed values of an integer-valued series.
    return gaussian_filter1d(series, sigma=window_size, output=float)


# Replace every loss column with its Gaussian-smoothed version, one column at
# a time, in the same order as the column list below.
for column in (
    "train_loss",
    "valid_loss",
    "valid_loss_A",
    "valid_loss_B",
    "valid_loss_C",
    "valid_loss_D",
):
    data[column] = smooth(data[column])

# Figure 1: smoothed validation loss for each of the four retrieval-chain
# positions, written to disk in three formats (nothing is shown on screen).
fig, ax = plt.subplots(figsize=(6, 3))

# (column, legend label, extra line style); an empty style dict lets that
# curve take its colour from matplotlib's default property cycle.
chain_curves = (
    ("valid_loss_A", "X$_1$ (B)", {"color": "#2CA02C"}),
    ("valid_loss_B", "X$_2$ (C)", {}),
    ("valid_loss_C", "X$_3$ (D)", {"color": "#ef9f00"}),
    ("valid_loss_D", "X$_4$ (E)", {"color": "#EB3324"}),
)
for column, label, line_style in chain_curves:
    ax.plot(data["epch"], data[column], label=label, **line_style)

ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Validation loss for each position in the retrieval chain")
ax.legend()
ax.grid(True)
fig.tight_layout()

# Save the current figure as raster and vector outputs.
for out_path in ("chain_D.png", "chain_D.svg", "chain_D.pdf"):
    plt.savefig(out_path)

# Figure 2: smoothed overall training vs. validation loss, written to disk in
# three formats (nothing is shown on screen).
fig, ax = plt.subplots(figsize=(6, 3))

for column, label in (
    ("train_loss", "Training Loss"),
    ("valid_loss", "Validation Loss"),
):
    ax.plot(data["epch"], data[column], label=label)

ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Training and validation loss of Transformer D")
ax.legend()
ax.grid(True)
fig.tight_layout()

# Save the current figure as raster and vector outputs.
for out_path in ("loss_D.png", "loss_D.svg", "loss_D.pdf"):
    plt.savefig(out_path)
