import pickle
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt

from common.utils import discount_cumsum, nstep_cumsum


def _padded_mean(arrays, length: int) -> np.ndarray:
    """Zero-pad each 1-D array in *arrays* to *length* and return their mean.

    The mean is taken over the stacked arrays along axis 0, so positions past
    the end of a shorter array contribute zeros to the average.
    """
    padded = [
        np.concatenate([array, np.zeros(length - array.shape[0])], axis=0)
        for array in arrays
    ]
    return np.mean(np.stack(padded), axis=0)


def _save_comparison_plot(steps, actual, target, actual_label: str, target_label: str, path: Path) -> None:
    """Plot *actual* and *target* against *steps*, save to *path*, close the figure."""
    fig = plt.figure(figsize=(10, 5))
    try:
        ax = fig.add_subplot(111)
        ax.plot(steps, actual, label=actual_label)
        ax.plot(steps, target, label=target_label)
        ax.set_xlabel('Steps')
        ax.set_ylabel('Return')
        ax.legend()
        fig.savefig(str(path), dpi=100)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed.
        plt.close(fig)


def save_return_curve(return_results: dict, reward_results: dict, save_path: Path, nstep: int):
    """Plot averaged return curves per target return and save them under *save_path*.

    For every key of ``return_results`` the per-episode arrays are flattened,
    zero-padded to the longest episode of that key, and averaged. The same
    padding/averaging is applied to ``discount_cumsum(rewards, gamma=1.)`` of
    the matching ``reward_results`` entry and, when ``nstep > 0``, to
    ``nstep_cumsum(rewards, nstep)``.

    Files written into ``save_path``:

    * ``rtg_<target>.pdf`` -- averaged ``discount_cumsum`` curve vs. averaged
      input curve, one per target return.
    * ``nstep_rtg_<target>.pdf`` -- same comparison using ``nstep_cumsum``;
      only when ``nstep > 0``.
    * ``actual_return.pdf`` -- averaged input curves of all targets together.
    * ``actual_return.pkl`` -- pickle of ``return_results`` as passed in.

    Args:
        return_results: Maps a target return (formatted with ``:.1f``) to a
            list of arrays supporting ``reshape(-1)``.
        reward_results: Maps the same keys to lists of per-episode reward
            arrays. NOTE(review): assumed to be 1-D and no longer than the
            longest entry of ``return_results`` for that key -- confirm.
        save_path: Directory receiving the output files; it is not created
            here.
        nstep: Window passed to ``nstep_cumsum``; values ``<= 0`` skip the
            n-step plot.
    """
    return_fig = plt.figure(figsize=(10, 5))
    try:
        ax_return_fig = return_fig.add_subplot(111)

        for target_return, returns_list in return_results.items():
            returns_list = [r.reshape(-1) for r in returns_list]
            max_episode_length = max(len(r) for r in returns_list)
            average_returns = _padded_mean(returns_list, max_episode_length)

            rewards_list = reward_results[target_return]
            average_rtg = _padded_mean(
                [discount_cumsum(rewards, gamma=1.) for rewards in rewards_list],
                max_episode_length,
            )

            # Compute the n-step curve before any file is written so a failure
            # here does not leave a partial set of plots for this target.
            average_nstep_rtg = None
            if nstep > 0:
                average_nstep_rtg = _padded_mean(
                    [nstep_cumsum(rewards, nstep) for rewards in rewards_list],
                    max_episode_length,
                )

            steps = np.arange(max_episode_length)
            ax_return_fig.plot(steps, average_returns, label=f"{target_return:.1f}")

            _save_comparison_plot(
                steps,
                average_rtg,
                average_returns,
                'Actual Return',
                'Target Return (Input Return-to-go)',
                save_path / f"rtg_{target_return:.1f}.pdf",
            )

            if average_nstep_rtg is not None:
                _save_comparison_plot(
                    steps,
                    average_nstep_rtg,
                    average_returns,
                    f'Actual Return ({nstep} step)',
                    f'Target Return (Input Return-to-go) ({nstep} step)',
                    save_path / f"nstep_rtg_{target_return:.1f}.pdf",
                )

        ax_return_fig.set_xlabel('Steps')
        ax_return_fig.set_ylabel('Input Return-to-go')
        ax_return_fig.legend()
        return_fig.savefig(str(save_path / "actual_return.pdf"), dpi=100)
    finally:
        # Release the summary figure even if plotting or saving failed.
        plt.close(return_fig)

    with open(save_path / "actual_return.pkl", 'wb') as f:
        pickle.dump(return_results, f)
