import os
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns
from sklearn.decomposition import PCA
import atexit


def Plot_losses(train_losses, val_losses, title, path):
    """Plot training and validation loss curves and save the figure.

    Both curves are plotted against their index (0, 1, 2, ...), labelled as
    epochs on the x-axis.

    Args:
        train_losses: Sequence of training losses, one value per epoch.
        val_losses: Sequence of validation losses, one value per epoch.
        title: Name inserted into the plot title
            ("Training/Validation Loss of {title}").
        path: Destination passed to ``savefig``. When it is a str or
            path-like object, its parent directory is created if missing.
    """
    # Only filesystem paths have a parent directory; savefig also accepts
    # file-like objects, which must be passed through untouched.
    if isinstance(path, (str, os.PathLike)):
        parent = os.path.dirname(path)
        # dirname() is "" for a bare filename; os.makedirs("") would raise.
        if parent:
            os.makedirs(parent, exist_ok=True)

    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.plot(train_losses, label='Training Loss', color='blue', marker='o')
        ax.plot(val_losses, label='Validation Loss', color='orange', marker='x')
        ax.set_xlabel('Epochs')
        ax.set_ylabel('Loss')
        ax.set_title(f'Training/Validation Loss of {title}')
        ax.legend()
        ax.grid(True)
        fig.savefig(path)
    finally:
        # Always release the figure: pyplot keeps every open figure alive,
        # so repeated calls would otherwise accumulate memory.
        plt.close(fig)



class LamCollector:
    """Collect fixed-length sequences and summarize them per time step.

    Every sequence added must contain exactly ``t`` values. The summary is the
    element-wise mean and population standard deviation (numpy's default,
    ``ddof=0``) across all collected sequences.
    """

    def __init__(self, t):
        """Create an empty collector for sequences of length *t*."""
        self.t = t
        # One entry per add_sequence() call, stored as passed in (not copied).
        self.data = []

    def add_sequence(self, values):
        """Append one sequence of length ``self.t``.

        Raises:
            ValueError: If ``len(values) != self.t``.
        """
        if len(values) != self.t:
            raise ValueError(f"Input length must be {self.t}, but got {len(values)}")
        self.data.append(values)

    @staticmethod
    def _ensure_parent_dir(path):
        """Create the parent directory of *path*, if it names one."""
        parent = os.path.dirname(path)
        # dirname() returns "" for a bare filename such as the default
        # "temporal_summary.png"; os.makedirs("") raises FileNotFoundError.
        if parent:
            os.makedirs(parent, exist_ok=True)

    def summarize_and_plot(self, plot_path="temporal_summary.png",
                           csv_path="output/temporal_mean_std.csv", *, plot=False):
        """Compute per-time-step mean/std, write them to CSV, optionally plot.

        Args:
            plot_path: Destination for the mean ± std figure. Its parent
                directory is created even when ``plot`` is False.
            csv_path: Destination CSV with columns ``time_step``, ``mean``,
                ``std``. Its parent directory is created if missing.
            plot: When True, also render and save the figure to ``plot_path``.
                Defaults to False, matching the previous behavior in which
                plotting was disabled.

        Returns:
            Tuple ``(mean_curve, std_curve)`` of numpy arrays, each of length
            ``self.t``.

        Raises:
            ValueError: If fewer than 2 sequences have been collected.
        """
        data_array = np.array(self.data)  # shape: [n, t]
        # Validate before touching the filesystem so a failed call leaves no
        # stray directories behind.
        if data_array.shape[0] < 2:
            raise ValueError("Need at least 2 sequences to compute mean and std.")

        self._ensure_parent_dir(plot_path)
        self._ensure_parent_dir(csv_path)

        mean_curve = np.mean(data_array, axis=0)
        std_curve = np.std(data_array, axis=0)
        time_steps = np.arange(self.t)

        if plot:
            self._plot_summary(time_steps, mean_curve, std_curve, plot_path)

        df = pd.DataFrame({
            "time_step": time_steps,
            "mean": mean_curve,
            "std": std_curve
        })
        df.to_csv(csv_path, index=False)

        return mean_curve, std_curve

    @staticmethod
    def _plot_summary(time_steps, mean_curve, std_curve, plot_path):
        """Save a mean line with a ±1 std shaded band to *plot_path*."""
        fig, ax = plt.subplots(figsize=(10, 5))
        try:
            ax.plot(time_steps, mean_curve, label="Mean", color="blue")
            ax.fill_between(time_steps, mean_curve - std_curve, mean_curve + std_curve,
                            color="blue", alpha=0.2, label="±1 Std")
            ax.set_title("Temporal Data Mean ± Std")
            ax.set_xlabel("Time Step")
            ax.set_ylabel("Value")
            ax.legend()
            ax.grid(True)
            fig.tight_layout()
            fig.savefig(plot_path)
        finally:
            # Release the figure even if saving fails, so repeated calls
            # do not accumulate open pyplot figures.
            plt.close(fig)