import numpy as np
import copy
import matplotlib.pyplot as plt
from tqdm import tqdm



def log_offline_rl(logging, pi, env_eval, pi_eval):
    """Record the value gap between ``pi_eval`` and ``pi`` in ``logging``.

    Both policies are scored with ``env_eval.evaluate_policy``: first
    ``pi_eval``, then ``pi``. The difference (``pi_eval`` value minus ``pi``
    value) is appended to ``logging['offline_rl_subopt']``.

    Args:
        logging: Dict-like log that already holds a list under
            ``'offline_rl_subopt'``. It is mutated in place. The name shadows
            the stdlib ``logging`` module inside this function only.
        pi: Policy whose suboptimality is being measured.
        env_eval: Object exposing ``evaluate_policy(policy)``.
        pi_eval: Reference policy to compare against.

    Returns:
        The same ``logging`` object, for call chaining.
    """
    reference_value = env_eval.evaluate_policy(pi_eval)
    candidate_value = env_eval.evaluate_policy(pi)
    logging['offline_rl_subopt'].append(reference_value - candidate_value)
    return logging

def plot_suboptimalities(expt_dicts, N_samples_per_it=50, save_name=None):
    """Plot suboptimality for different experiments.

    For each experiment, draws the per-iteration mean of its runs' ``'subopt'``
    sequences. A shaded band of +/- 1.96 * std / sqrt(n_runs) surrounds each
    curve. Everything is drawn on the current matplotlib axes.

    Args:
        expt_dicts: Mapping from a legend label to a non-empty list of runs.
            Each run is a dict whose ``'subopt'`` entry is a sequence of
            values, one per iteration. All runs of an experiment are averaged
            element-wise, so they are expected to have equal length.
        N_samples_per_it: Human-feedback samples per iteration. Iteration
            ``i`` is placed at ``x = i * N_samples_per_it``.
        save_name: If truthy, the figure is written to this path and then
            closed. Otherwise ``plt.show()`` is called.
    """
    ax = plt.gca()
    for name, expt in expt_dicts.items():
        x_axis = np.arange(len(expt[0]['subopt'])) * N_samples_per_it
        subopt = [run['subopt'] for run in expt]
        average_curve = np.mean(subopt, axis=0)
        confidence_interval = 1.96 * np.std(subopt, axis=0) / np.sqrt(len(subopt))
        ax.plot(x_axis, average_curve, '-', label=name)
        ax.fill_between(x_axis,
                        average_curve - confidence_interval,
                        average_curve + confidence_interval,
                        alpha=0.2)
    ax.legend()
    ax.set_xlabel('Human Feedback Samples')
    ax.set_ylabel('Suboptimality')
    ax.grid(alpha=0.2)
    if save_name:
        ax.figure.savefig(save_name, bbox_inches="tight")
        # Close the saved figure. Otherwise the next call draws on top of
        # these curves and saves both together.
        plt.close(ax.figure)
    else:
        plt.show()

def _mean_and_ci(runs, key):
    """Return the element-wise mean of ``run[key]`` over ``runs`` and the
    band half-width 1.96 * std / sqrt(len(runs))."""
    values = [run[key] for run in runs]
    mean = np.mean(values, axis=0)
    half_width = 1.96 * np.std(values, axis=0) / np.sqrt(len(values))
    return mean, half_width


def _save_or_show(fig, save, suffix):
    """Save ``fig`` to ``f"{save}{suffix}"`` and close it, or call
    ``plt.show()`` when ``save`` is None."""
    if save is None:
        plt.show()
    else:
        fig.savefig(f"{save}{suffix}", bbox_inches="tight")
        # Close the figure so later plots never reuse it as the current figure.
        plt.close(fig)


def plot_results(results, N_samples_per_it=50, save=None):
    """Plot results of a single experiment.

    Produces two figures, each averaged over ``results`` with a shaded band of
    +/- 1.96 * std / sqrt(n_runs):

    1. Suboptimality (``'subopt'``) versus samples.
    2. ``'R_correct'`` (left axis, green) and ``'T_dist_L1'`` (right axis,
       blue) versus samples.

    Args:
        results: Non-empty list of run dicts with keys ``'subopt'``,
            ``'R_correct'`` and ``'T_dist_L1'``. Each key maps to a
            per-iteration sequence. All are plotted against an x axis sized
            from ``results[0]['subopt']``, so they are expected to share its
            length.
        N_samples_per_it: Samples per iteration. Iteration ``i`` is placed at
            ``x = i * N_samples_per_it``.
        save: Path prefix (str or path-like). If given, the figures are
            written to ``<save>_subopt.pdf`` and ``<save>_error.pdf`` and
            closed. Otherwise each is shown with ``plt.show()``.
    """
    x_axis = np.arange(len(results[0]['subopt'])) * N_samples_per_it

    # Suboptimality. Use a fresh figure so nothing left open elsewhere is
    # drawn over.
    fig, ax = plt.subplots()
    mean, ci = _mean_and_ci(results, 'subopt')
    ax.plot(x_axis, mean, 'o-')
    ax.fill_between(x_axis, mean - ci, mean + ci, alpha=0.2)
    ax.set_xlabel('Samples')
    ax.set_ylabel('Suboptimality')
    ax.grid(alpha=0.2)
    _save_or_show(fig, save, "_subopt.pdf")

    # Reward correctness (left axis) and transition error (right axis).
    fig, ax1 = plt.subplots()
    ax2 = ax1.twinx()
    mean, ci = _mean_and_ci(results, 'R_correct')
    ax1.plot(x_axis, mean, 'go-')
    ax1.fill_between(x_axis, mean - ci, mean + ci, color='g', alpha=0.2)

    mean, ci = _mean_and_ci(results, 'T_dist_L1')
    ax2.plot(x_axis, mean, 'bo-')
    ax2.fill_between(x_axis, mean - ci, mean + ci, color='b', alpha=0.2)

    ax1.set_xlabel('Samples')
    ax1.set_ylabel('Correct R', color='g')
    ax2.set_ylabel('Error in T', color='b')
    plt.grid(alpha=0.2)
    _save_or_show(fig, save, "_error.pdf")