# Copyright (C) king.com Ltd 2025
# License: Apache 2.0
import argparse
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np


def plot_trajs(
        ax,
        rollouts,
        target_angle,
        target_radius
):
    """Draw every rollout's x/y path and mark the target with a gold star.

    Args:
        ax: matplotlib Axes to draw on.
        rollouts: iterable of rollouts; each is converted with ``np.array`` and
            its columns 0 and 1 are plotted as x and y.
        target_angle: target direction, multiplied by pi before use (i.e. given
            in units of pi radians).
        target_radius: distance of the target from the origin.
    """
    for path in map(np.array, rollouts):
        xs = path[:, 0]
        ys = path[:, 1]
        ax.plot(xs, ys, lw=1, alpha=0.5)

    # Polar -> Cartesian for the target marker.
    angle_rad = np.pi * target_angle
    goal_x = target_radius * np.cos(angle_rad)
    goal_y = target_radius * np.sin(angle_rad)
    ax.scatter(goal_x, goal_y, c='gold', marker='*', s=400)

    ax.set_title("Trajectories")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_aspect('equal')
    ax.set_xlim(-3, 3)
    ax.set_ylim(-3, 3)

def plot_prompt_segments(
        ax,
        prompt_states,
        state_mean,
        state_std
):
    """Plot the x/y path of every prompt segment after undoing state normalization.

    Args:
        ax: matplotlib Axes to draw on.
        prompt_states: per-rollout collections of segments; each segment is
            indexed as ``segment[:, 0]`` / ``segment[:, 1]`` after denormalization.
        state_mean: mean used when the states were normalized.
        state_std: standard deviation used when the states were normalized.
    """
    for rollout_segments in prompt_states:
        for segment in rollout_segments:
            # Inverse of the (state - mean) / std normalization applied upstream.
            segment_denormed = (segment * state_std) + state_mean
            ax.plot(
                segment_denormed[:, 0],
                segment_denormed[:, 1],
                alpha=0.75,
            )

    ax.set_title("Prompt segments")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    # Same world frame and limits as plot_trajs; equal aspect keeps the two
    # panels geometrically comparable side by side.
    ax.set_aspect('equal')
    ax.set_xlim(-3, 3)
    ax.set_ylim(-3, 3)

def plot_rewards(
        ax,
        rewards,
        colors=None
):
    """Scatter one reward per rollout and overlay their mean as a dashed line.

    Args:
        ax: matplotlib Axes to draw on.
        rewards: sequence of per-rollout rewards.
        colors: optional colour(s) forwarded to ``ax.scatter``.
    """
    mean_reward = np.mean(rewards)
    rollout_ids = np.arange(len(rewards))

    ax.scatter(rollout_ids, rewards, marker='x', color=colors, s=10)
    ax.axhline(
        y=mean_reward,
        color='k',
        linestyle='--',
        label=f"Mean reward: {np.round(mean_reward, 2)}",
    )
    ax.legend()

    ax.set_title("Rewards")
    ax.set_xlabel("Rollouts")
    ax.set_ylabel("Reward")

def plot_loss(
        ax,
        losses,
):
    """Plot one loss curve per entry of *losses* (e.g. one per bandit model).

    Rows are iterated directly rather than first stacking *losses* into a single
    2-D array, so curves of different lengths are supported: ``np.array`` on a
    ragged nested list raises ``ValueError`` in NumPy >= 1.24.

    Args:
        ax: matplotlib Axes to draw on.
        losses: sequence (or 2-D array) whose items are per-epoch loss histories.
    """
    for curve in losses:
        curve = np.asarray(curve)
        ax.plot(np.arange(len(curve)), curve, lw=0.75)

    ax.set_title("Bandit reward model losses")
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Loss")

def plot_prompt_times(
        ax,
        prompt_times,
):
    """Scatter the prompt time chosen at each rollout, one series per entry.

    Series are iterated directly instead of stacking *prompt_times* into one
    2-D array first, so histories of different lengths do not raise
    ``ValueError`` (ragged ``np.array`` creation fails in NumPy >= 1.24).

    Args:
        ax: matplotlib Axes to draw on.
        prompt_times: sequence (or 2-D array) whose items are per-rollout
            prompt-time histories.
    """
    for times in prompt_times:
        times = np.asarray(times)
        ax.scatter(np.arange(len(times)), times, alpha=0.5)

    ax.set_title("Prompt times")
    ax.set_xlabel("Rollouts")
    ax.set_ylabel("Prompt time")

def plot_epsilon(
        ax,
        epsilon,
        which_bandit
):
    """Plot the bandit's exploration parameter history.

    For epsilon-greedy and ``"ts"`` a single series is plotted. For any other
    method, *epsilon* is converted to an array, transposed, and each resulting
    row is drawn as its own curve. This presumably means one curve per arm of a
    (rollouts, arms) history; TODO confirm against the bandit code.

    Args:
        ax: matplotlib Axes to draw on.
        epsilon: exploration-parameter history.
        which_bandit: sampling-method name. Both ``"eps_greedy"`` and
            ``"epsilon_greedy"`` are treated as epsilon-greedy, so the branch and
            the y-label always agree.
    """
    eps_greedy_names = ("eps_greedy", "epsilon_greedy")

    if which_bandit in eps_greedy_names or which_bandit == "ts":
        ax.plot(np.arange(len(epsilon)), epsilon)
    else:
        for series in np.array(epsilon).T:
            ax.plot(np.arange(len(series)), series, lw=0.75)

    ax.set_title("Bandit exploration")
    ax.set_xlabel("Rollouts")
    ax.set_ylabel("Epsilon" if which_bandit in eps_greedy_names else "Exploration bonus")

def plot_tuning_results(mab_results, rollouts, rewards, prompt_states, state_mean, state_std, results_dir, args):
    """Build the summary figure for one tuning run and save it as ``{results_dir}/results.png``.

    The first (or only) row always shows trajectories, prompt segments and
    rewards. When *mab_results* is not None, a second row adds the bandit
    diagnostics taken from its "mab_losses", "mab_segment_t_hist" and
    "mab_epsilon_hist" entries.

    *args* must provide ``target_angle``, ``target_radius``,
    ``sampling_method`` and ``hide_plots``. The figure is shown unless
    ``args.hide_plots`` is set, and is closed afterwards.
    """
    with_bandit = mab_results is not None

    if with_bandit:
        fig, axs = plt.subplots(2, 3, figsize=(15, 10))
        traj_ax, segment_ax, reward_ax = axs[0]
    else:
        fig, axs = plt.subplots(1, 3, figsize=(10, 5))
        traj_ax, segment_ax, reward_ax = axs

    plot_trajs(traj_ax, rollouts, target_angle=args.target_angle, target_radius=args.target_radius)
    plot_prompt_segments(segment_ax, prompt_states, state_mean, state_std)
    plot_rewards(reward_ax, rewards)

    if with_bandit:
        loss_ax, times_ax, explore_ax = axs[1]
        plot_loss(loss_ax, mab_results["mab_losses"])
        plot_prompt_times(times_ax, mab_results["mab_segment_t_hist"])
        plot_epsilon(explore_ax, mab_results["mab_epsilon_hist"], which_bandit=args.sampling_method)

    fig.suptitle(
        f"target angle: {args.target_angle}, target radius: {args.target_radius}, prompt selection: {args.sampling_method}")
    fig.tight_layout()

    plt.savefig(f"{results_dir}/results.png", dpi=300, bbox_inches='tight', pad_inches=0.1)
    if not args.hide_plots:
        plt.show()
    plt.close()


def _load_method_rewards(list_file, base_dir):
    """Load the reward histories of every run listed in *list_file* as a 2-D array.

    *list_file* is a text file with one pickled result path per line (blank lines
    are skipped); paths are joined onto *base_dir* when it is truthy. Each pickle
    must contain a "rewards" entry. Runs of unequal length are truncated to the
    shortest one so they can be stacked into shape (num_runs, num_rollouts).

    Raises:
        ValueError: if *list_file* lists no result files.
    """
    run_rewards = []
    with open(list_file, "r", encoding="utf-8") as list_fh:
        for line in list_fh:
            result_file = line.strip()
            if not result_file:
                continue
            print("result file:", result_file)
            result_path = os.path.join(base_dir, result_file) if base_dir else result_file
            # NOTE: pickle.load can execute arbitrary code; only use this on
            # trusted, locally produced result files.
            with open(result_path, "rb") as result_fh:
                result = pickle.load(result_fh)
            run_rewards.append(result["rewards"])

    if not run_rewards:
        raise ValueError(f"No result files listed in {list_file!r}")

    lengths = {len(rewards) for rewards in run_rewards}
    if len(lengths) > 1:
        common_len = min(lengths)
        print(f"Warning: runs listed in {list_file} have lengths {sorted(lengths)}; "
              f"truncating all to {common_len} rollouts.")
        run_rewards = [rewards[:common_len] for rewards in run_rewards]

    return np.array(run_rewards)


def _smooth(values, window_len):
    """Moving average of *values* over *window_len* samples, without edge padding ('valid' mode).

    The window is clamped to ``len(values)``; otherwise ``np.convolve`` would swap
    its operands for short inputs and return partial sums instead of averages.
    """
    window_len = max(1, min(int(window_len), len(values)))
    return np.convolve(values, np.ones(window_len) / window_len, mode='valid')


def _method_label(list_file):
    """Legend label for a method: the second '-'-separated field of the file's
    basename, or the whole basename when it contains no '-'."""
    name = os.path.basename(list_file)
    parts = name.split("-")
    return parts[1] if len(parts) > 1 else name


def plot_tuning_exp_results(
        no_tuning_results_file,
        with_tuning_results_file,
        base_dir,
        data_mixture,
        smooth_window_len=10,
        optimal_reward=10,
):
    """Compare reward curves without and with prompt tuning, averaged over runs.

    For each method, the per-rollout mean and std across runs are smoothed with a
    moving average and drawn as a line with a +/- std band. The figure is saved to
    ``promptTuning_experiment_{data_mixture}.png`` in the current directory, shown,
    and then closed.

    Args:
        no_tuning_results_file: text file listing the no-tuning result pickles.
        with_tuning_results_file: text file listing the with-tuning result pickles.
        base_dir: directory the listed result paths are relative to (falsy to use
            them as given).
        data_mixture: data-mixture name used in the title and output filename.
        smooth_window_len: moving-average window length (default 10).
        optimal_reward: y-value of the dashed "Optimal" reference line (default 10).

    Raises:
        ValueError: if a method's list file names no result files.
    """
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    ax.axhline(y=optimal_reward, color='k', linestyle='--', label="Optimal")

    for method_result_file in [no_tuning_results_file, with_tuning_results_file]:
        method_rewards = _load_method_rewards(method_result_file, base_dir)

        # Aggregate across runs (axis 0), then smooth for readability.
        mean_rewards = _smooth(np.mean(method_rewards, axis=0), smooth_window_len)
        std_rewards = _smooth(np.std(method_rewards, axis=0), smooth_window_len)

        rollout_ids = np.arange(len(mean_rewards))
        ax.plot(rollout_ids, mean_rewards, label=_method_label(method_result_file))
        ax.fill_between(rollout_ids, mean_rewards - std_rewards, mean_rewards + std_rewards, alpha=0.2)

    mixture_name = (data_mixture or "").replace('mixture-', '')
    ax.set_title(f"Tuning experiment, {mixture_name} data")
    ax.set_xlabel("Rollouts")
    ax.set_ylabel("Reward")
    ax.legend()
    fig.tight_layout()
    fig.savefig(f"promptTuning_experiment_{data_mixture}.png", dpi=300, bbox_inches='tight', pad_inches=0.1)
    plt.show()
    plt.close(fig)


if __name__ == "__main__":
    # CLI for comparing reward curves with vs. without prompt tuning.
    parser = argparse.ArgumentParser(
        description="Plot average rewards with and without prompt tuning across runs.")
    parser.add_argument("--no_tuning_result_files", type=str, required=True,
                        help="Text file listing one no-tuning result pickle per line.")
    parser.add_argument("--with_tuning_result_files", type=str, required=True,
                        help="Text file listing one with-tuning result pickle per line.")
    parser.add_argument("--base_dir", type=str, default=".",
                        help="Directory that the listed result files are relative to.")
    parser.add_argument("--data_mixture", type=str, required=True,
                        help="Data-mixture name, used in the plot title and output filename.")
    args = parser.parse_args()

    plot_tuning_exp_results(
        no_tuning_results_file=args.no_tuning_result_files,
        with_tuning_results_file=args.with_tuning_result_files,
        base_dir=args.base_dir,
        data_mixture=args.data_mixture
    )

