import pickle
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt


def save_velocity_curve(velocity_results: dict, save_path: Path):
    """Plot the smoothed per-step mean velocity for each target return and save it.

    For every entry of ``velocity_results`` the episodes are averaged step by
    step; because episodes may differ in length, each step is averaged only
    over the episodes that actually reached it. The resulting curve is
    smoothed with ``smooth(..., 0.9)`` and drawn as one labelled line.

    Two files are written into ``save_path``:

    * ``velocity.pdf`` -- the figure with one curve per target return.
    * ``velocity.pkl`` -- the unmodified ``velocity_results`` dict, pickled.

    Args:
        velocity_results: Mapping from a target return (a number; it is
            formatted with ``:.1f`` for the legend) to a list of 1-D
            per-episode velocity sequences.
        save_path: Directory the two output files are written into.
            NOTE(review): the directory is not created here -- presumably
            the caller guarantees it exists; confirm.
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        for target_return, velocity_list in velocity_results.items():
            # Nothing to average for this target: skip it instead of
            # crashing in max() on an empty sequence.
            if len(velocity_list) == 0:
                continue
            max_episode_length = max(len(v) for v in velocity_list)
            # All episodes are empty, so there is no curve to draw.
            if max_episode_length == 0:
                continue

            # Running per-step sum and per-step episode count. This is
            # equivalent to zero-padding every episode to the longest length
            # and dividing by a validity mask, without building the padded
            # matrices.
            velocity_sum = np.zeros(max_episode_length)
            episode_count = np.zeros(max_episode_length)
            for velocity_array in velocity_list:
                episode_length = len(velocity_array)
                velocity_sum[:episode_length] += velocity_array
                episode_count[:episode_length] += 1

            # episode_count is >= 1 everywhere: the longest episode covers
            # every step up to max_episode_length.
            average_velocity = smooth(velocity_sum / episode_count, 0.9)
            steps = np.arange(average_velocity.shape[0])
            ax.plot(steps, average_velocity, label=f"{target_return:.1f}")

        ax.set_xlabel('Steps')
        ax.set_ylabel('X_velocity')
        ax.legend()
        fig.savefig(str(save_path / "velocity.pdf"), dpi=100)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated calls do not accumulate open figures.
        plt.close(fig)

    # Keep the raw data next to the plot so it can be re-plotted later.
    with open(save_path / "velocity.pkl", 'wb') as f:
        pickle.dump(velocity_results, f)


def smooth(scalars, weight: float) -> np.ndarray:
    """Return an exponentially smoothed copy of ``scalars``.

    Each output value is ``previous_output * weight + (1 - weight) * point``,
    with the running value seeded by the first element of ``scalars``.
    A ``weight`` of 0 reproduces the input; a ``weight`` of 1 repeats the
    first element.

    Args:
        scalars: Sequence (list or 1-D array) of numbers to smooth.
        weight: Fraction of the previous smoothed value carried over into
            the next one.

    Returns:
        A NumPy array with one smoothed value per input element; an empty
        array when ``scalars`` is empty.
    """
    # Guard against empty input: scalars[0] below would raise IndexError.
    if len(scalars) == 0:
        return np.array([])

    last = scalars[0]
    smoothed = []
    for point in scalars:
        smoothed_val = last * weight + (1 - weight) * point
        smoothed.append(smoothed_val)
        last = smoothed_val
    return np.array(smoothed)