import matplotlib.pyplot as plt
import seaborn
import numpy as np

# Global plotting style: applied once, at import time, to every figure the
# functions below draw.
# NOTE(review): the "seaborn-v0_8" style name only exists in matplotlib
# versions that ship it (presumably >= 3.6) — confirm the pinned version.
plt.style.use("seaborn-v0_8")

def plot_metric(solvers_metrics, metric):
    """Plot a per-epoch metric for every learning rate of every solver.

    Each solver contributes one curve per learning rate in its grid, all
    drawn on the current matplotlib axes. Loss metrics are drawn with a
    logarithmic y-axis, accuracy metrics with a linear one.

    Args:
        solvers_metrics: Mapping from solver name to a results dict. Each
            dict must hold ``metric`` as a 2-D array of shape
            ``(grid_length, n_epochs)`` (one row per learning rate) and
            ``'lr_grid'``, the learning rates in the same row order.
        metric: One of ``'train_loss'``, ``'test_loss'``, ``'train_acc'``,
            ``'test_acc'``.

    Raises:
        ValueError: If ``metric`` is not supported.
    """
    ylabels = {
        'train_loss': "Training loss",
        'test_loss': "Validation Loss",
        'train_acc': "Training accuracy",
        'test_acc': "Validation Accuracy",
    }
    if metric not in ylabels:
        raise ValueError("This metric is not supported!")

    # Losses span orders of magnitude, so both loss metrics use a log scale.
    plot_fn = plt.semilogy if metric in ('train_loss', 'test_loss') else plt.plot

    for solver, data in solvers_metrics.items():
        trajs = data[metric]
        _, n_epochs = trajs.shape  # also enforces a 2-D (grid, epochs) array
        epochs = np.arange(n_epochs) + 1
        for i, traj in enumerate(trajs):
            plot_fn(epochs, traj,
                    label=f"{solver}: lr = {data['lr_grid'][i]!r}",
                    linewidth=2.5)

    plt.xlabel("Epoch")
    plt.ylabel(ylabels[metric])
    plt.legend(title='Solver', frameon=True)


def plot_best_metric(solvers_metrics, metric, parameter, colors):
    """Plot each solver's metric curve at its best (tuned) learning rate.

    Loss metrics are drawn with a logarithmic y-axis, accuracy metrics with
    a linear one. Results stored under the solver name ``'Cronos_AM'`` are
    not plotted by this function.

    Args:
        solvers_metrics: Mapping from solver name to a results dict holding
            ``metric`` (indexable by learning-rate index, one trajectory per
            entry), ``'best_lr_idx'`` (the index of the tuned learning rate)
            and, when ``parameter == 'time'``, ``'times'`` (per-point
            durations in seconds, indexed like ``metric``).
        metric: One of ``'train_loss'``, ``'test_loss'``, ``'train_acc'``,
            ``'test_acc'``.
        parameter: x-axis quantity: ``'data_passes'`` (1..n_epochs) or
            ``'time'`` (cumulative sum of ``times``).
        colors: Mapping from solver name to a matplotlib color.

    Raises:
        ValueError: If ``metric`` or ``parameter`` is not supported.
    """
    ylabels = {
        'train_loss': "Training loss",
        'test_loss': "Validation Loss",
        'train_acc': "Training accuracy",
        'test_acc': "Validation Accuracy",
    }
    if metric not in ylabels:
        raise ValueError("This metric is not supported!")

    xlabels = {'data_passes': "Data passes", 'time': "Time (s)"}
    if parameter not in xlabels:
        raise ValueError("This parameter is not supported!")

    # Losses span orders of magnitude, so both loss metrics use a log scale.
    plot_fn = plt.semilogy if metric in ('train_loss', 'test_loss') else plt.plot

    for solver, data in solvers_metrics.items():
        # NOTE(review): Cronos_AM is deliberately excluded from this plot;
        # the reason is not visible in this function.
        if solver == 'Cronos_AM':
            continue

        idx = data['best_lr_idx']
        traj = data[metric][idx]
        if parameter == 'data_passes':
            x = np.arange(traj.shape[0]) + 1
        else:
            x = np.cumsum(data['times'][idx])

        plot_fn(x, traj, label=f"{solver} (tuned)",
                color=colors[solver], linewidth=2.5)

    plt.xlabel(xlabels[parameter])
    plt.ylabel(ylabels[metric])
    plt.legend(title='Solver', frameon=True)

def plot_median_metric(solvers_metrics, metric, parameter, colors):
    """Plot, per solver, the median metric curve across its learning-rate
    grid together with a shaded 5%-95% quantile band.

    Loss metrics are drawn with a logarithmic y-axis, accuracy metrics with
    a linear one. The median curve and its quantile band always share the
    same x coordinates.

    Args:
        solvers_metrics: Mapping from solver name to a results dict holding
            ``metric`` as a 2-D array of shape ``(grid_length, n_epochs)``
            (one row per learning rate) and, when ``parameter == 'time'``,
            ``'times'`` as a same-shaped array of per-point durations in
            seconds.
        metric: One of ``'train_loss'``, ``'test_loss'``, ``'train_acc'``,
            ``'test_acc'``.
        parameter: x-axis quantity: ``'data_passes'`` or ``'time'``.
        colors: Mapping from solver name to a matplotlib color.

    Raises:
        ValueError: If ``metric`` or ``parameter`` is not supported.
    """
    ylabels = {
        'train_loss': "Training loss",
        'test_loss': "Validation Loss",
        'train_acc': "Training accuracy",
        'test_acc': "Validation Accuracy",
    }
    if metric not in ylabels:
        raise ValueError("This metric is not supported!")

    # Validate up front: otherwise `x` below would be unbound (or stale from
    # a previous solver) when it reaches fill_between.
    xlabels = {'data_passes': "Data passes", 'time': "Time (s)"}
    if parameter not in xlabels:
        raise ValueError("This parameter is not supported!")

    alpha = 1
    # Losses span orders of magnitude, so both loss metrics use a log scale.
    plot_fn = plt.semilogy if metric in ('train_loss', 'test_loss') else plt.plot

    for solver, data in solvers_metrics.items():
        trajs = data[metric]
        _, n_epochs = trajs.shape  # also enforces a 2-D (grid, epochs) array

        if parameter == 'data_passes':
            if solver == 'Cronos_AM':
                # x = 0, 2, 4, ...; presumably each recorded Cronos_AM point
                # costs two data passes — confirm.
                x = 2 * np.arange(n_epochs)
            else:
                x = np.arange(n_epochs) + 1
        else:
            # Median (over the grid) of the cumulative elapsed time.
            x = np.median(np.cumsum(data['times'], axis=1), axis=0)

        plot_fn(x, np.median(trajs, axis=0), label=solver,
                color=colors[solver], alpha=alpha)
        plt.fill_between(
            x,
            np.quantile(trajs, 0.05, axis=0),
            np.quantile(trajs, 0.95, axis=0),
            alpha=0.2 * alpha,
            color=colors[solver],
            linewidth=0.0,
            rasterized=True)

    plt.xlabel(xlabels[parameter])
    plt.ylabel(ylabels[metric])
    plt.legend(title='Solver', frameon=True)
    