import pandas as pd
import re
import plotly.express as px
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np


def plot_losses(results_df, save_folder_path):
    '''
        Plot the loss/metric columns against time and save them as a PDF.

        One line plot is drawn per column, stacked vertically on a shared
        x axis ("t"), in this order: "loss_build", "acc", "cov",
        "loss_privReg", "loss_fairReg".

        Args:
            results_df: DataFrame holding the column "t" and the five
                columns listed above.
            save_folder_path: Folder (str or path-like) in which
                "losses.pdf" is written.
    '''
    metrics = ("loss_build", "acc", "cov", "loss_privReg", "loss_fairReg")
    fig, axes = plt.subplots(len(metrics), 1, figsize=(10, 8), sharex=True)
    try:
        for ax, metric in zip(axes, metrics):
            sns.lineplot(data=results_df, x="t", y=metric, ax=ax)
        # Save through the figure handle rather than pyplot's implicit
        # "current figure" so the right figure is always written.
        fig.savefig(f"{save_folder_path}/losses.pdf")
    finally:
        # Release the figure; otherwise every call leaves one more figure
        # open in pyplot's global registry.
        plt.close(fig)
    

def plot_parameters(results_df, save_folder_path, algorithm='fairPATE'):
    '''
        Plot the per-agent parameters against time and save them as a PDF.

        The top panel shows "epsilon" vs "t", coloured by "agent". The
        bottom panel shows "gamma" (algorithm 'fairPATE') or "tau"
        (algorithm 'fairdp') vs "t"; for any other algorithm value the
        bottom panel is left empty.

        Args:
            results_df: DataFrame holding the columns "t", "epsilon",
                "agent" and the algorithm-specific column.
            save_folder_path: Folder (str or path-like) in which
                "parameters.pdf" is written.
            algorithm: 'fairPATE' or 'fairdp'; selects the second column.
    '''
    second_param_by_algorithm = {'fairPATE': "gamma", 'fairdp': "tau"}
    fig, axes = plt.subplots(2, 1, figsize=(10, 6), sharex=True)
    try:
        sns.lineplot(data=results_df, x="t", y="epsilon", hue="agent", ax=axes[0])
        second_param = second_param_by_algorithm.get(algorithm)
        if second_param is not None:
            sns.lineplot(data=results_df, x="t", y=second_param, hue="agent", ax=axes[1])
        fig.savefig(f"{save_folder_path}/parameters.pdf")
    finally:
        # Release the figure; otherwise every call leaves one more figure
        # open in pyplot's global registry.
        plt.close(fig)
    
    
def plot_parameters_3d(results_df, save_folder_path, algorithm='fairPATE'):
    '''
        Plot the parameters vs time as a 3d scatter and save it as a PDF.

        The scatter uses "epsilon" on x, "gamma" (algorithm 'fairPATE') or
        "tau" (algorithm 'fairdp') on y, and "t" on z, coloured by "agent".

        Args:
            results_df: DataFrame holding the columns "epsilon", "t",
                "agent" and the algorithm-specific column.
            save_folder_path: Folder (str or path-like) in which
                "parameters_3d.pdf" is written.
            algorithm: 'fairPATE' or 'fairdp'; selects the y column.

        Raises:
            ValueError: If algorithm is neither 'fairPATE' nor 'fairdp'.
    '''
    if algorithm == 'fairPATE':
        y_var = "gamma"
    elif algorithm == 'fairdp':
        y_var = "tau"
    else:
        # Fail early with a clear message instead of hitting an unbound
        # y_var further down.
        raise ValueError(
            f"Unknown algorithm {algorithm!r}; expected 'fairPATE' or 'fairdp'"
        )
    fig = px.scatter_3d(
        data_frame=results_df,
        x="epsilon",
        y=y_var,
        z="t",
        color="agent",
        width=600,
        height=600,
    )
    fig.write_image(f"{save_folder_path}/parameters_3d.pdf")
    
    
def plot_figures(results_df, save_folder_path, algorithm='fairPATE'):
    '''
        Plot all the relevant figures using the experiment results.

        The "loss_build" column is shifted by +100 for plotting. The shift
        is applied to a copy, so the caller's DataFrame is left untouched
        and calling this function repeatedly does not compound the offset.

        Args:
            results_df: DataFrame with the experiment results; must contain
                "loss_build" plus the columns read by the individual
                plotting functions.
            save_folder_path: Folder in which the figures are written.
            algorithm: 'fairPATE' or 'fairdp'; forwarded to the parameter
                plots.
    '''
    # assign() returns a new DataFrame, leaving results_df unmodified.
    shifted_df = results_df.assign(loss_build=100 + results_df["loss_build"])
    plot_losses(shifted_df, save_folder_path)
    plot_parameters(shifted_df, save_folder_path, algorithm)
    plot_parameters_3d(shifted_df, save_folder_path, algorithm)
    
    
    