import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import torch
import os
import csv
from scipy.interpolate import griddata
from scipy.ndimage import zoom


def save_data(cfg, num_trials, cumulative_regret, regret, episode_return, gp_loss):
    """Write per-trial metric histories to CSV files in ``cfg.save_dir``.

    For each metric tensor, row ``t`` of its CSV holds ``metric[t]`` for
    ``t`` in ``range(num_trials)``. The files written (overwriting any
    existing ones) are ``cumulative_regret.csv``, ``instantaneous_regret.csv``,
    ``episode_return.csv`` and ``gp_losses.csv``.

    Args:
        cfg: Config object; only ``cfg.save_dir`` is read. The directory is
            not created here, so it must already exist.
        num_trials: Number of leading trial rows to write from each tensor.
        cumulative_regret, regret, episode_return, gp_loss: torch tensors
            indexed by trial along dim 0. Presumably shaped
            (trials, episodes); each ``metric[t]`` must be 1-D so that it
            forms a single CSV row.

    Raises:
        IndexError: if a tensor has fewer than ``num_trials`` rows.
    """

    def _write_trial_rows(filepath, data):
        # Named distinctly so it no longer shadows the enclosing public function.
        with open(filepath, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.writer(csvfile)
            for t in range(num_trials):
                # detach() lets tensors that still track gradients (e.g. a
                # loss) be converted; tolist() yields plain Python numbers.
                writer.writerow(data[t].detach().cpu().tolist())

    outputs = (
        ("cumulative_regret.csv", cumulative_regret),
        ("instantaneous_regret.csv", regret),
        ("episode_return.csv", episode_return),
        ("gp_losses.csv", gp_loss),
    )
    for filename, data in outputs:
        _write_trial_rows(os.path.join(cfg.save_dir, filename), data)

def _mean_std_over_trials(values, num_trials):
    """Return (mean, std) NumPy arrays over the first ``num_trials`` rows (dim 0)."""
    trials = values[:num_trials].detach()  # detach so grad-tracking tensors convert
    mean = trials.mean(dim=0)
    if trials.shape[0] > 1:
        std = trials.std(dim=0)  # torch default: sample (unbiased) std
    else:
        # Sample std of a single trial is NaN in torch, which would silently
        # drop the shaded band; show a zero-width band instead.
        std = torch.zeros_like(mean)
    return mean.cpu().numpy(), std.cpu().numpy()


def _plot_mean_std(mean, std, title, path):
    """Plot *mean* with a +/- 1 std shaded band against episode index and save it to *path*."""
    episodes = np.arange(mean.shape[-1])  # x-axis sized per metric
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        ax.plot(episodes, mean, label=title, color="b")
        ax.fill_between(episodes, mean + std, mean - std, alpha=0.2, color="b")
        ax.set_xlabel('Episodes')
        ax.set_ylabel(title)
        ax.grid(True)
        fig.savefig(path, bbox_inches='tight', dpi=300)
    finally:
        # Always release the figure, even if savefig fails.
        plt.close(fig)


def plot_metrics(cfg, num_trials, cumulative_regret, regret, episode_return, gp_loss):
    """Plot trial-averaged metrics with a +/- 1 std band and save them as PNGs.

    For each metric, the mean and sample standard deviation are taken over
    the first ``num_trials`` entries of dim 0 and plotted against the index
    of the last dimension ("Episodes"). Tensors are presumably shaped
    (trials, episodes); confirm against callers. Figures are written to
    ``cfg.save_dir`` as ``cumulative_regret.png``,
    ``instantaneous_regret.png``, ``episode_return.png`` and ``gp_loss.png``
    at 300 dpi.

    Side effect: updates global matplotlib rcParams (seaborn white/talk
    theme, serif font, explicit font sizes).

    Args:
        cfg: Config object; only ``cfg.save_dir`` is read.
        num_trials: Number of leading trials to aggregate.
        cumulative_regret, regret, episode_return, gp_loss: torch tensors
            indexed by trial along dim 0.
    """
    # sns.set_theme resets font.family and the context-scaled font sizes, so
    # it must run BEFORE the explicit overrides or they are silently discarded.
    sns.set_theme(style="white", context="talk")
    plt.rcParams.update({
        'font.family': 'serif',
        'font.size': 9,
        'axes.titlesize': 10,
        'axes.labelsize': 9
    })

    # (tensor, legend label / y-axis label, output filename)
    metrics = (
        (cumulative_regret, 'Cumulative Regret', "cumulative_regret.png"),
        (regret, 'Instantaneous Regret', "instantaneous_regret.png"),
        (episode_return, 'Episode Return', "episode_return.png"),
        (gp_loss, 'GP Loss', "gp_loss.png"),
    )
    for values, title, filename in metrics:
        mean, std = _mean_std_over_trials(values, num_trials)
        _plot_mean_std(mean, std, title, os.path.join(cfg.save_dir, filename))
