import jax.numpy as jnp
from typing import NamedTuple
import matplotlib.pyplot as plt

# Immutable per-step rollout record; as a NamedTuple it is also a JAX pytree,
# so whole batches of these can be passed through jax transformations.
# Field meanings below are inferred from the names only — confirm against the
# rollout code that constructs these records.
Transition = NamedTuple(
    "Transition",
    [
        ("done", jnp.ndarray),             # presumably the episode-termination flag
        ("action", jnp.ndarray),           # action taken by the policy
        ("value", jnp.ndarray),            # presumably the critic's value estimate
        ("reward", jnp.ndarray),           # reward used for the policy update
        ("log_prob", jnp.ndarray),         # presumably log-probability of `action`
        ("obs", jnp.ndarray),              # observation fed to the policy (possibly normalized)
        ("unnorm_obs", jnp.ndarray),       # presumably the raw observation before the step
        ("unnorm_next_obs", jnp.ndarray),  # presumably the raw observation after the step
        # NOTE(review): annotated as an array, but `plot_il_metrics` indexes the
        # `.info` of results["eval_transitions"] with string keys, so this likely
        # holds a dict of arrays — confirm and fix the annotation if so.
        ("info", jnp.ndarray),
    ],
)

# Paired policy ("pi_") and expert ("expert_") samples, presumably the batch
# consumed by the discriminator update — the name suggests that, confirm
# against the caller. Each side carries the action plus the raw (unnormalized)
# observations before and after the transition.
DiscTransitionData = NamedTuple(
    "DiscTransitionData",
    [
        # --- samples generated by the current policy ---
        ("pi_action", jnp.ndarray),
        ("pi_unnorm_obs", jnp.ndarray),
        ("pi_unnorm_next_obs", jnp.ndarray),
        # --- samples drawn from the expert demonstrations ---
        ("expert_action", jnp.ndarray),
        ("expert_unnorm_obs", jnp.ndarray),
        ("expert_unnorm_next_obs", jnp.ndarray),
    ],
)


def _plot_mean_std(mean, std, title, ylabel, path):
    """Save a line plot of ``mean`` with a shaded ``mean +/- std`` band.

    Args:
        mean: 1-D array of per-update values to plot.
        std: Spread shaded around ``mean``; must broadcast against it.
        title: Figure title.
        ylabel: Y-axis label (the x-axis is always labelled "Update").
        path: File path passed to ``savefig``.
    """
    # A dedicated figure avoids drawing onto (and then closing) whatever
    # figure the caller happens to have as the current pyplot figure.
    fig, ax = plt.subplots()
    try:
        ax.plot(mean)
        ax.fill_between(
            range(len(mean)),
            mean - std,
            mean + std,
            alpha=0.3,
        )
        ax.set_title(title)
        ax.set_xlabel("Update")
        ax.set_ylabel(ylabel)
        fig.savefig(path)
    finally:
        # Release the figure even if saving fails, so repeated calls in a long
        # training script do not accumulate open figures.
        plt.close(fig)


def plot_il_metrics(results, save_dir):
    """Print final train/eval returns and save one PNG plot per training metric.

    Args:
        results: Mapping with a ``'metrics'`` dict of arrays and an
            ``'eval_transitions'`` entry whose ``.info`` is indexed with the
            keys ``"returned_episode_returns"`` and ``"returned_episode"``.
            Each metric is reduced over axis 0 for its mean/std band (presumably
            the seed axis — confirm against the training script).
        save_dir: String prefix prepended verbatim to each file name. It must
            end with a path separator (e.g. ``"runs/exp1/"``) for the files to
            land inside that directory.
    """
    metrics = results['metrics']

    # Training returns: average over the trailing two axes (steps & envs) and
    # axis 0 for the curve; the band is the std across axis 0 of the
    # steps/envs average. The array must therefore have at least 4 dims.
    train_returns = metrics["returned_episode_returns"]
    avg_returns_per_update = train_returns.mean(axis=(-1, -2, 0))  # flatten across steps & envs
    std_returns_per_update = train_returns.mean(axis=(-1, -2)).std(axis=0)
    print(f"Final return (train): {avg_returns_per_update[-1]} +- {std_returns_per_update[-1]}")

    # Evaluation returns: keep only entries selected by "returned_episode"
    # (presumably a boolean mask of finished episodes). Computed once instead
    # of re-indexing three times.
    eval_info = results["eval_transitions"].info
    finished_eval_returns = eval_info["returned_episode_returns"][eval_info["returned_episode"]]
    print(f'Final return (eval): {finished_eval_returns}')
    mean_rews = finished_eval_returns.mean()
    std_rews = finished_eval_returns.std()
    print(f'Final return (eval mean, std (high var)): {mean_rews} +- {std_rews}')

    _plot_mean_std(
        avg_returns_per_update,
        std_returns_per_update,
        "Episode Return",
        "Average Return",
        save_dir + 'episode_returns.png',
    )

    # Every remaining metric is plotted as mean +/- std across axis 0.
    # Columns: (metrics key, plot title, y-axis label, output file name).
    metric_plots = (
        ("disc_total_loss", "Discriminator Loss", "Loss", 'disc_loss.png'),
        ("disc_ce_loss", "Discriminator CE Loss", "Loss", 'disc_ce_loss.png'),
        ("disc_gp_loss", "Discriminator GP Loss", "Loss", 'disc_gp_loss.png'),
        ("disc_rewards", "Discriminator Rewards", "Rewards", 'disc_rewards.png'),
        ("entropy", "Entropy", "Entropy", 'entropy.png'),
        ("actor_loss", "Actor Loss", "Loss", 'actor_loss.png'),
        ("value_loss", "Value Loss", "Loss", 'value_loss.png'),
    )
    for key, title, ylabel, filename in metric_plots:
        values = metrics[key]
        _plot_mean_std(values.mean(0), values.std(0), title, ylabel, save_dir + filename)