import numpy as np
import matplotlib.pyplot as plt
import pickle
from scipy.interpolate import UnivariateSpline
from scipy.ndimage import uniform_filter1d
import os
from pathlib import Path


def compute_global_distance_range(folder_path):
    """Compute a shared distance range across every ``.pkl`` file in a folder.

    Each pickle is expected to hold a mapping with a ``"distances"`` entry.
    For every file the distances are flattened to 1-D, and the mean of the
    five largest and the mean of the five smallest values are taken (averaging
    a few extremes instead of using a single one damps outliers). The global
    range is the largest per-file "max" and the smallest per-file "min".

    Files that cannot be read, lack ``"distances"``, or contain no distance
    values are skipped with a printed warning.

    Args:
        folder_path: Directory to scan (non-recursive) for ``.pkl`` files.

    Returns:
        Tuple ``(global_max, global_min)`` -- note the max-first order.

    Raises:
        ValueError: If no file in the folder produced a usable range.
    """
    max_vals = []
    min_vals = []

    for filename in os.listdir(folder_path):
        if not filename.endswith('.pkl'):
            continue

        file_path = os.path.join(folder_path, filename)
        try:
            # SECURITY: unpickling can execute arbitrary code; only point this
            # at folders whose contents are trusted.
            with open(file_path, 'rb') as f:
                results = pickle.load(f)

            # Flatten so nested / 2-D inputs are ranked element-wise, matching
            # how generate_memory_plot treats the same data.
            distances = np.asarray(results["distances"]).ravel().astype(float)
            if distances.size == 0:
                # An empty array would yield a NaN mean and corrupt max()/min().
                raise ValueError("'distances' is empty")

            sorted_distances = np.sort(distances)  # sort once, slice both ends
            top_5_max = np.mean(sorted_distances[-5:])  # mean of the 5 largest values
            top_5_min = np.mean(sorted_distances[:5])   # mean of the 5 smallest values
        except Exception as e:
            # Deliberately best-effort: one bad file must not abort the scan.
            print(f"Warning: failed to process {filename} due to {e}")
        else:
            max_vals.append(top_5_max)
            min_vals.append(top_5_min)

    if not max_vals or not min_vals:
        raise ValueError("No valid .pkl files with 'distances' found in the folder.")

    global_max = max(max_vals)
    global_min = min(min_vals)

    return global_max, global_min



def generate_memory_plot(positions, distances, distance_max, distance_min, alpha=5.0, region_ranges=None, file_path=None, output_dir="xxx"):
    """Plot memory strength against position and persist the plot and data.

    Distances are clipped to ``[distance_min, distance_max]`` and mapped to a
    strength in ``[0, 1]`` with an exponential decay::

        raw      = exp(-alpha * (d - distance_min))
        strength = (raw - raw_at_max) / (1 - raw_at_max)

    so a distance at ``distance_min`` scores 1 and one at ``distance_max``
    scores 0. The curve and a smoothed trend are drawn with the given regions
    shaded.

    Args:
        positions: Array-like of x positions; flattened and cast to float.
        distances: Array-like of distances, aligned with ``positions``;
            flattened and cast to float.
        distance_max: Upper bound used for clipping and normalisation.
        distance_min: Lower bound used for clipping and normalisation.
        alpha: Decay rate of the exponential mapping.
        region_ranges: Iterable of ``(start, end)`` x-spans to shade. Defaults
            to five 100-wide spans starting at 0, 300, 600, 900 and 1200.
        file_path: Path of the source data file. The image is written next to
            it as ``<stem>_smooth.png`` and the data pickle is named
            ``<stem>.pkl``. When ``None`` the figure is only shown and nothing
            is written to disk.
        output_dir: Directory receiving the data pickle; created if missing.

    Returns:
        None.
    """
    positions = np.asarray(positions).ravel().astype(float)
    distances = np.asarray(distances).ravel().astype(float)
    if region_ranges is None:
        region_ranges = [(0, 100), (300, 400), (600, 700), (900, 1000), (1200, 1300)]

    # The per-file extremes are reported for diagnostics only; normalisation
    # always uses the caller-supplied range so plots are comparable across files.
    sorted_distances = np.sort(distances)
    local_min = np.mean(sorted_distances[:5])
    local_max = np.mean(sorted_distances[-5:])
    print("min distance in local:", local_min, "max distance in local:", local_max)
    d_min = distance_min
    d_max = distance_max
    print("min distance in calcuation:", d_min, "max distance in calcuation:", d_max)

    distances_clipped = np.clip(distances, d_min, d_max)
    raw = np.exp(-alpha * (distances_clipped - d_min))

    raw_max = np.exp(-alpha * (d_max - d_min))
    denominator = 1 - raw_max
    if denominator != 0:
        memory_strength = (raw - raw_max) / denominator
    else:
        # d_max == d_min (or alpha == 0) makes the normalisation 0/0; report
        # zero strength instead of propagating NaN into the plot and pickle.
        memory_strength = np.zeros_like(raw)
    memory_strength = np.clip(memory_strength, 0.0, 1.0)

    fig = plt.figure(figsize=(12, 6))
    for (start, end) in region_ranges:
        plt.axvspan(start, end, color='lightgreen', alpha=0.2)

    plt.plot(positions, memory_strength, label="Memory Strength", linewidth=2)

    # Trend: 50-sample moving average over a copy that is zero-padded on the
    # left by (window - 1) samples, then trimmed back to the original length.
    # uniform_filter1d centres its window by default, so early samples are
    # averaged together with the zero padding.
    window = 50
    padded = np.concatenate([np.zeros(window - 1), memory_strength.astype(float)])
    smoothed = uniform_filter1d(padded, size=window, mode='nearest')
    memory_strength_trend = smoothed[window - 1:]

    # NOTE(review): the legend says "Spline" but the trend is a moving average;
    # label kept unchanged so existing figures stay comparable.
    plt.plot(positions, memory_strength_trend, color='red', linewidth=2, label="Trend Line (Spline)")

    # Vertical guide every 300 units; skipped when there is no data to span.
    if positions.size:
        for x in range(0, int(positions.max()) + 1, 300):
            plt.axvline(x=x, color='gray', linestyle=':', linewidth=0.8)
    plt.ylim(0, 1.1)
    plt.xlabel("Input Content Length")
    plt.ylabel("Memory Retention")
    plt.title("Memory Retention vs. Input Content Length")
    plt.legend()
    plt.tight_layout()
    plt.grid(True)

    source_path = Path(file_path) if file_path is not None else None

    # Save BEFORE show(): with interactive backends show() blocks until the
    # window is closed and the figure is destroyed, so saving afterwards
    # writes a blank image.
    if source_path is not None:
        image_path = source_path.parent / f"{source_path.stem}_smooth.png"
        fig.savefig(image_path, dpi=300)

    plt.show()
    # Release the figure so batch runs do not accumulate open figures.
    plt.close(fig)

    if source_path is None:
        # Without a source path there is no name to save the data under.
        return

    original_data_location = Path(output_dir)
    original_data_location.mkdir(parents=True, exist_ok=True)
    original_data_path = original_data_location / f"{source_path.stem}.pkl"
    with open(original_data_path, "wb") as f:
        # NOTE(review): the key is "memory_strength_trend" but the value is the
        # unsmoothed memory_strength; kept as-is because readers of this file
        # may rely on it -- confirm which one is intended.
        pickle.dump({
            "positions": positions,
            "memory_strength_trend": memory_strength
        }, f)

    print(f"Data saved to: {original_data_path}")



# Batch driver: render a memory plot for every pickle in `folder_path`,
# normalising all of them against one folder-wide distance range.
# SECURITY: unpickling can execute arbitrary code; only run on trusted files.
folder_path = "xxx"

# The folder-wide range does not depend on the file being processed, so it is
# computed once (lazily, on the first .pkl encountered) instead of re-reading
# every pickle in the folder for each file.
global_range = None

for filename in os.listdir(folder_path):
    if not filename.endswith('.pkl'):
        continue

    file_path = os.path.join(folder_path, filename)
    print(f"Processing file: {file_path}")
    with open(file_path, "rb") as f:
        results = pickle.load(f)

    print(type(results))
    if global_range is None:
        global_range = compute_global_distance_range(Path(file_path).parent)
    global_max, global_min = global_range
    print(f"Global max: {global_max}, Global min: {global_min}")
    generate_memory_plot(results['positions'], results['distances'], global_max, global_min, alpha=5.0, region_ranges=None, file_path=file_path)